299 lines
9.7 KiB
Python
299 lines
9.7 KiB
Python
"""
|
|
Integration tests for inference API endpoints.
|
|
|
|
Tests the /api/v1/infer endpoint to ensure it works end-to-end.
|
|
"""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
from fastapi.testclient import TestClient
|
|
from PIL import Image
|
|
import io
|
|
|
|
from src.web.app import create_app
|
|
from src.web.config import ModelConfig, StorageConfig, AppConfig
|
|
|
|
|
|
@pytest.fixture
def test_app(tmp_path):
    """Build the FastAPI application under test.

    Upload and result directories are created fresh under pytest's
    per-test ``tmp_path``. The model section points at a relative
    weights path with CPU-only inference, a 0.5 confidence threshold and
    dpi=150; uploads accept PDF/PNG/JPG/JPEG up to 50 MB.

    NOTE(review): model_path is relative to the working directory; whether
    the app loads it eagerly depends on create_app — confirm.
    """
    upload_dir = tmp_path / "uploads"
    result_dir = tmp_path / "results"
    for directory in (upload_dir, result_dir):
        directory.mkdir()

    model_config = ModelConfig(
        model_path=Path("runs/train/invoice_fields/weights/best.pt"),
        confidence_threshold=0.5,
        use_gpu=False,
        dpi=150,
    )
    storage_config = StorageConfig(
        upload_dir=upload_dir,
        result_dir=result_dir,
        allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
        max_file_size_mb=50,
    )
    return create_app(AppConfig(model=model_config, storage=storage_config))
|
|
|
|
|
|
@pytest.fixture
def client(test_app):
    """Return a TestClient bound to the ``test_app`` fixture.

    NOTE(review): the client is not entered as a context manager, so the
    app's startup/shutdown (lifespan) handlers are presumably not run —
    confirm this is intentional.
    """
    return TestClient(app=test_app)
|
|
|
|
|
|
@pytest.fixture
def sample_png_bytes():
    """Return an in-memory 800x1200 all-white RGB PNG, rewound to offset 0."""
    buffer = io.BytesIO()
    Image.new("RGB", (800, 1200), color="white").save(buffer, format="PNG")
    # Rewind so the upload reads the image from its first byte.
    buffer.seek(0)
    return buffer
|
|
|
|
|
|
class TestHealthEndpoint:
|
|
"""Test /api/v1/health endpoint."""
|
|
|
|
def test_health_check_returns_200(self, client):
|
|
"""Test health check returns 200 OK."""
|
|
response = client.get("/api/v1/health")
|
|
assert response.status_code == 200
|
|
|
|
def test_health_check_response_structure(self, client):
|
|
"""Test health check response has correct structure."""
|
|
response = client.get("/api/v1/health")
|
|
data = response.json()
|
|
|
|
assert "status" in data
|
|
assert "model_loaded" in data
|
|
assert "gpu_available" in data
|
|
assert "version" in data
|
|
|
|
assert data["status"] == "healthy"
|
|
assert isinstance(data["model_loaded"], bool)
|
|
assert isinstance(data["gpu_available"], bool)
|
|
|
|
|
|
class TestInferEndpoint:
|
|
"""Test /api/v1/infer endpoint."""
|
|
|
|
@patch('src.inference.pipeline.InferencePipeline')
|
|
@patch('src.inference.yolo_detector.YOLODetector')
|
|
def test_infer_accepts_png_file(
|
|
self,
|
|
mock_yolo_detector,
|
|
mock_pipeline,
|
|
client,
|
|
sample_png_bytes,
|
|
):
|
|
"""Test that /infer endpoint accepts PNG files."""
|
|
# Setup mocks
|
|
mock_detector_instance = Mock()
|
|
mock_pipeline_instance = Mock()
|
|
mock_yolo_detector.return_value = mock_detector_instance
|
|
mock_pipeline.return_value = mock_pipeline_instance
|
|
|
|
# Mock pipeline result
|
|
mock_result = Mock()
|
|
mock_result.fields = {"InvoiceNumber": "12345"}
|
|
mock_result.confidence = {"InvoiceNumber": 0.95}
|
|
mock_result.success = True
|
|
mock_result.errors = []
|
|
mock_result.raw_detections = []
|
|
mock_result.document_id = "test123"
|
|
mock_result.document_type = "invoice"
|
|
mock_result.processing_time_ms = 100.0
|
|
mock_result.visualization_path = None
|
|
mock_result.detections = []
|
|
mock_pipeline_instance.process_image.return_value = mock_result
|
|
|
|
# Make request
|
|
response = client.post(
|
|
"/api/v1/infer",
|
|
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
|
)
|
|
|
|
# Verify response
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
assert data["status"] == "success"
|
|
assert "result" in data
|
|
assert data["result"]["fields"]["InvoiceNumber"] == "12345"
|
|
assert data["result"]["confidence"]["InvoiceNumber"] == 0.95
|
|
|
|
def test_infer_rejects_invalid_file_type(self, client):
|
|
"""Test that /infer rejects unsupported file types."""
|
|
invalid_file = io.BytesIO(b"fake txt content")
|
|
|
|
response = client.post(
|
|
"/api/v1/infer",
|
|
files={"file": ("test.txt", invalid_file, "text/plain")},
|
|
)
|
|
|
|
assert response.status_code == 400
|
|
assert "Unsupported file type" in response.json()["detail"]
|
|
|
|
def test_infer_requires_file(self, client):
|
|
"""Test that /infer requires a file parameter."""
|
|
response = client.post("/api/v1/infer")
|
|
|
|
assert response.status_code == 422 # Unprocessable Entity
|
|
|
|
@patch('src.inference.pipeline.InferencePipeline')
|
|
@patch('src.inference.yolo_detector.YOLODetector')
|
|
def test_infer_returns_cross_validation_if_available(
|
|
self,
|
|
mock_yolo_detector,
|
|
mock_pipeline,
|
|
client,
|
|
sample_png_bytes,
|
|
):
|
|
"""Test that cross-validation results are included if available."""
|
|
# Setup mocks
|
|
mock_detector_instance = Mock()
|
|
mock_pipeline_instance = Mock()
|
|
mock_yolo_detector.return_value = mock_detector_instance
|
|
mock_pipeline.return_value = mock_pipeline_instance
|
|
|
|
# Mock pipeline result with cross-validation
|
|
mock_result = Mock()
|
|
mock_result.fields = {
|
|
"InvoiceNumber": "12345",
|
|
"OCR": "1234567",
|
|
"Amount": "100.00",
|
|
}
|
|
mock_result.confidence = {
|
|
"InvoiceNumber": 0.95,
|
|
"OCR": 0.90,
|
|
"Amount": 0.88,
|
|
}
|
|
mock_result.success = True
|
|
mock_result.errors = []
|
|
mock_result.raw_detections = []
|
|
mock_result.document_id = "test123"
|
|
mock_result.document_type = "invoice"
|
|
mock_result.processing_time_ms = 100.0
|
|
mock_result.visualization_path = None
|
|
mock_result.detections = []
|
|
|
|
# Add cross-validation result
|
|
mock_cv = Mock()
|
|
mock_cv.is_valid = True
|
|
mock_cv.payment_line_ocr = "1234567"
|
|
mock_cv.ocr_match = True
|
|
mock_result.cross_validation = mock_cv
|
|
|
|
mock_pipeline_instance.process_image.return_value = mock_result
|
|
|
|
# Make request
|
|
response = client.post(
|
|
"/api/v1/infer",
|
|
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
|
)
|
|
|
|
# Verify response includes cross-validation
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Note: cross_validation is not currently in the response schema
|
|
# This test documents that it should be added
|
|
|
|
|
|
@patch('src.inference.pipeline.InferencePipeline')
|
|
@patch('src.inference.yolo_detector.YOLODetector')
|
|
def test_infer_handles_processing_errors_gracefully(
|
|
self,
|
|
mock_yolo_detector,
|
|
mock_pipeline,
|
|
client,
|
|
sample_png_bytes,
|
|
):
|
|
"""Test that processing errors are handled gracefully."""
|
|
# Setup mocks
|
|
mock_detector_instance = Mock()
|
|
mock_pipeline_instance = Mock()
|
|
mock_yolo_detector.return_value = mock_detector_instance
|
|
mock_pipeline.return_value = mock_pipeline_instance
|
|
|
|
# Make pipeline raise an error
|
|
mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed")
|
|
|
|
# Make request
|
|
response = client.post(
|
|
"/api/v1/infer",
|
|
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
|
)
|
|
|
|
# Verify error handling - service catches exceptions and returns partial results
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["status"] == "partial"
|
|
assert data["result"]["success"] is False
|
|
assert len(data["result"]["errors"]) > 0
|
|
assert "Model inference failed" in data["result"]["errors"][0]
|
|
|
|
|
|
class TestResultsEndpoint:
|
|
"""Test /api/v1/results/{filename} endpoint."""
|
|
|
|
def test_get_result_image_returns_404_if_not_found(self, client):
|
|
"""Test that getting non-existent result returns 404."""
|
|
response = client.get("/api/v1/results/nonexistent.png")
|
|
assert response.status_code == 404
|
|
|
|
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
|
|
"""Test that existing result file is returned."""
|
|
# Get storage config from app
|
|
storage_config = test_app.extra.get("storage_config")
|
|
if not storage_config:
|
|
pytest.skip("Storage config not available in test app")
|
|
|
|
# Create a test result file
|
|
result_file = storage_config.result_dir / "test_result.png"
|
|
img = Image.new('RGB', (100, 100), color='red')
|
|
img.save(result_file)
|
|
|
|
# Request the file
|
|
response = client.get("/api/v1/results/test_result.png")
|
|
|
|
assert response.status_code == 200
|
|
assert response.headers["content-type"] == "image/png"
|
|
|
|
|
|
class TestInferenceServiceImports:
|
|
"""Critical test to catch import errors."""
|
|
|
|
def test_inference_service_can_import_modules(self):
|
|
"""
|
|
Test that InferenceService can import its dependencies.
|
|
|
|
This test will fail if there are ImportError issues like:
|
|
- from ..inference.pipeline (wrong relative import)
|
|
- from src.web.inference (non-existent module)
|
|
|
|
It ensures the imports are correct before runtime.
|
|
"""
|
|
from src.web.services.inference import InferenceService
|
|
|
|
# Import the modules that InferenceService tries to import
|
|
from src.inference.pipeline import InferencePipeline
|
|
from src.inference.yolo_detector import YOLODetector
|
|
from src.pdf.renderer import render_pdf_to_images
|
|
|
|
# If we got here, all imports work correctly
|
|
assert InferencePipeline is not None
|
|
assert YOLODetector is not None
|
|
assert render_pdf_to_images is not None
|
|
assert InferenceService is not None
|