Files
invoice-master-poc-v2/tests/web/test_inference_api.py
2026-01-27 23:58:17 +01:00

299 lines
9.7 KiB
Python

"""
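Integration tests for the inference web API.
Covers /api/v1/health, /api/v1/infer (with the YOLO detector and inference
pipeline patched where processing is exercised), /api/v1/results/{filename},
and the import chain used by InferenceService.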
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from PIL import Image
import io
from inference.web.app import create_app
from inference.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
def test_app(tmp_path):
    """Create a FastAPI test application with isolated storage.

    Fresh ``uploads`` and ``results`` directories are created under pytest's
    per-test ``tmp_path`` and wired into the storage config. The model path
    is relative (NOTE(review): presumably resolved against the repo root as
    the working directory — confirm).

    Returns:
        The application produced by ``create_app`` for this test config.
    """
    upload_dir = tmp_path / "uploads"
    result_dir = tmp_path / "results"
    for directory in (upload_dir, result_dir):
        directory.mkdir()

    model_config = ModelConfig(
        model_path=Path("runs/train/invoice_fields/weights/best.pt"),
        confidence_threshold=0.5,
        use_gpu=False,
        dpi=150,
    )
    storage_config = StorageConfig(
        upload_dir=upload_dir,
        result_dir=result_dir,
        allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
        max_file_size_mb=50,
    )
    return create_app(AppConfig(model=model_config, storage=storage_config))
@pytest.fixture
def client(test_app):
    """Return a ``TestClient`` bound to the ``test_app`` fixture's application."""
    test_client = TestClient(test_app)
    return test_client
@pytest.fixture
def sample_png_bytes():
    """Return an in-memory PNG of a blank white 800x1200 RGB page.

    The buffer is rewound to position 0 so it can be uploaded directly.
    """
    buffer = io.BytesIO()
    Image.new("RGB", (800, 1200), color="white").save(buffer, format="PNG")
    buffer.seek(0)
    return buffer
class TestHealthEndpoint:
    """Test /api/v1/health endpoint."""

    def test_health_check_returns_200(self, client):
        """Health check answers with HTTP 200."""
        assert client.get("/api/v1/health").status_code == 200

    def test_health_check_response_structure(self, client):
        """Health payload exposes the expected keys with the expected types."""
        payload = client.get("/api/v1/health").json()

        for key in ("status", "model_loaded", "gpu_available", "version"):
            assert key in payload

        assert payload["status"] == "healthy"
        for flag in ("model_loaded", "gpu_available"):
            assert isinstance(payload[flag], bool)
class TestInferEndpoint:
    """Test /api/v1/infer endpoint."""

    @staticmethod
    def _wire_pipeline_mocks(mock_yolo_detector, mock_pipeline):
        """Make the patched classes return fresh mocks; return the pipeline mock.

        Args:
            mock_yolo_detector: Patched ``YOLODetector`` class.
            mock_pipeline: Patched ``InferencePipeline`` class.

        Returns:
            The ``Mock`` produced when ``InferencePipeline(...)`` is called,
            so tests can configure ``process_image`` on it.
        """
        mock_yolo_detector.return_value = Mock()
        mock_pipeline_instance = Mock()
        mock_pipeline.return_value = mock_pipeline_instance
        return mock_pipeline_instance

    @staticmethod
    def _build_mock_result(fields, confidence):
        """Build a successful pipeline-result mock.

        Only ``fields`` and ``confidence`` vary between tests. The remaining
        attributes get fixed, JSON-friendly values; any attribute not set here
        is auto-created by ``Mock`` on access.

        Args:
            fields: Mapping of field name to extracted value.
            confidence: Mapping of field name to confidence score.

        Returns:
            A ``Mock`` standing in for the pipeline's result object.
        """
        result = Mock()
        result.fields = fields
        result.confidence = confidence
        result.success = True
        result.errors = []
        result.raw_detections = []
        result.document_id = "test123"
        result.document_type = "invoice"
        result.processing_time_ms = 100.0
        result.visualization_path = None
        result.detections = []
        return result

    # Stacked @patch decorators inject mocks bottom-up: the innermost patch
    # (YOLODetector) is the first mock argument, InferencePipeline the second.
    @patch('inference.pipeline.pipeline.InferencePipeline')
    @patch('inference.pipeline.yolo_detector.YOLODetector')
    def test_infer_accepts_png_file(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that /infer endpoint accepts PNG files."""
        mock_pipeline_instance = self._wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        mock_pipeline_instance.process_image.return_value = self._build_mock_result(
            fields={"InvoiceNumber": "12345"},
            confidence={"InvoiceNumber": 0.95},
        )

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert "result" in data
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"
        assert data["result"]["confidence"]["InvoiceNumber"] == 0.95

    def test_infer_rejects_invalid_file_type(self, client):
        """Test that /infer rejects unsupported file types."""
        invalid_file = io.BytesIO(b"fake txt content")
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", invalid_file, "text/plain")},
        )
        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_infer_requires_file(self, client):
        """Test that /infer requires a file parameter."""
        response = client.post("/api/v1/infer")
        assert response.status_code == 422  # Unprocessable Entity

    @patch('inference.pipeline.pipeline.InferencePipeline')
    @patch('inference.pipeline.yolo_detector.YOLODetector')
    def test_infer_returns_cross_validation_if_available(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that a result carrying cross-validation data is still served.

        The response schema does not expose cross-validation yet, so this pins
        that the request succeeds and the extracted fields pass through.
        """
        mock_pipeline_instance = self._wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        mock_result = self._build_mock_result(
            fields={
                "InvoiceNumber": "12345",
                "OCR": "1234567",
                "Amount": "100.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "OCR": 0.90,
                "Amount": 0.88,
            },
        )
        mock_cv = Mock()
        mock_cv.is_valid = True
        mock_cv.payment_line_ocr = "1234567"
        mock_cv.ocr_match = True
        mock_result.cross_validation = mock_cv
        mock_pipeline_instance.process_image.return_value = mock_result

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"
        # TODO: cross_validation is not currently in the response schema; once
        # it is added, assert on data["result"]["cross_validation"] here.

    @patch('inference.pipeline.pipeline.InferencePipeline')
    @patch('inference.pipeline.yolo_detector.YOLODetector')
    def test_infer_handles_processing_errors_gracefully(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that processing errors are handled gracefully."""
        mock_pipeline_instance = self._wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        # The service catches the exception and reports a partial result
        # instead of failing the request.
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "partial"
        assert data["result"]["success"] is False
        assert len(data["result"]["errors"]) > 0
        assert "Model inference failed" in data["result"]["errors"][0]
class TestResultsEndpoint:
    """Test /api/v1/results/{filename} endpoint."""

    def test_get_result_image_returns_404_if_not_found(self, client):
        """Test that getting non-existent result returns 404."""
        response = client.get("/api/v1/results/nonexistent.png")
        assert response.status_code == 404

    def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
        """Test that an existing result file is served with its contents.

        Uses the storage config exposed on ``test_app.extra`` when present.
        Otherwise it falls back to ``tmp_path / "results"``, the ``result_dir``
        the ``test_app`` fixture configures. ``tmp_path`` is the same per-test
        directory in both fixtures. Previously the test silently skipped in
        that case.
        """
        storage_config = test_app.extra.get("storage_config")
        # NOTE(review): the fallback assumes the results endpoint serves files
        # from StorageConfig.result_dir — confirm against the route.
        result_dir = storage_config.result_dir if storage_config else tmp_path / "results"

        result_file = result_dir / "test_result.png"
        Image.new('RGB', (100, 100), color='red').save(result_file)

        response = client.get("/api/v1/results/test_result.png")
        assert response.status_code == 200
        assert response.headers["content-type"] == "image/png"
        assert response.content == result_file.read_bytes()
class TestInferenceServiceImports:
    """Guard against broken imports in the inference service."""

    def test_inference_service_can_import_modules(self):
        """
        Test that InferenceService can import its dependencies.
        This test will fail if there are ImportError issues like:
        - from ..inference.pipeline (wrong relative import)
        - from inference.web.inference (non-existent module)
        It ensures the imports are correct before runtime.
        """
        from inference.web.services.inference import InferenceService
        # The modules InferenceService itself relies on.
        from inference.pipeline.pipeline import InferencePipeline
        from inference.pipeline.yolo_detector import YOLODetector
        from shared.pdf.renderer import render_pdf_to_images

        # Reaching this point means every import resolved.
        imported = (
            InferencePipeline,
            YOLODetector,
            render_pdf_to_images,
            InferenceService,
        )
        assert all(obj is not None for obj in imported)