"""
Integration tests for inference API endpoints.

Exercises the public v1 API end-to-end through a FastAPI ``TestClient``:
``/api/v1/health``, ``/api/v1/infer`` and ``/api/v1/results/{filename}``,
plus an import smoke test for ``InferenceService``.
"""

import io
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
from fastapi.testclient import TestClient
from PIL import Image

from inference.web.app import create_app
from inference.web.config import ModelConfig, StorageConfig, AppConfig


# Patch targets for the heavy ML components. Stacked @patch decorators are
# applied bottom-up, so test methods receive the YOLODetector mock first and
# the InferencePipeline mock second.
PIPELINE_TARGET = "inference.pipeline.pipeline.InferencePipeline"
DETECTOR_TARGET = "inference.pipeline.yolo_detector.YOLODetector"


def _make_pipeline_result(fields, confidence):
    """Build a Mock shaped like a successful pipeline result.

    Args:
        fields: Mapping of field name to extracted value.
        confidence: Mapping of field name to confidence score.

    Returns:
        A ``Mock`` with the result attributes the tests rely on set
        explicitly. Any other attribute access falls back to ``Mock``'s
        auto-created attributes.
    """
    result = Mock()
    result.fields = fields
    result.confidence = confidence
    result.success = True
    result.errors = []
    result.raw_detections = []
    result.document_id = "test123"
    result.document_type = "invoice"
    result.processing_time_ms = 100.0
    result.visualization_path = None
    result.detections = []
    return result


def _wire_pipeline_mocks(mock_yolo_detector, mock_pipeline):
    """Make the patched detector/pipeline classes return fresh Mock instances.

    Returns:
        The pipeline instance, so the caller can configure ``process_image``.
    """
    mock_yolo_detector.return_value = Mock()
    pipeline_instance = Mock()
    mock_pipeline.return_value = pipeline_instance
    return pipeline_instance


@pytest.fixture
def test_app(tmp_path):
    """Create a test FastAPI application with storage under ``tmp_path``."""
    upload_dir = tmp_path / "uploads"
    result_dir = tmp_path / "results"
    upload_dir.mkdir()
    result_dir.mkdir()

    app_config = AppConfig(
        model=ModelConfig(
            model_path=Path("runs/train/invoice_fields/weights/best.pt"),
            confidence_threshold=0.5,
            use_gpu=False,
            dpi=150,
        ),
        file=StorageConfig(
            upload_dir=upload_dir,
            result_dir=result_dir,
            allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
            max_file_size_mb=50,
        ),
    )
    return create_app(app_config)


@pytest.fixture
def client(test_app):
    """Create a test client bound to ``test_app``."""
    return TestClient(test_app)


@pytest.fixture
def sample_png_bytes():
    """Return an in-memory 800x1200 white PNG, rewound to position 0."""
    img = Image.new('RGB', (800, 1200), color='white')
    img_bytes = io.BytesIO()
    img.save(img_bytes, format='PNG')
    img_bytes.seek(0)
    return img_bytes


class TestHealthEndpoint:
    """Test /api/v1/health endpoint."""

    def test_health_check_returns_200(self, client):
        """Test health check returns 200 OK."""
        response = client.get("/api/v1/health")

        assert response.status_code == 200

    def test_health_check_response_structure(self, client):
        """Test health check response has correct structure."""
        response = client.get("/api/v1/health")
        data = response.json()

        assert "status" in data
        assert "model_loaded" in data
        assert "gpu_available" in data
        assert "version" in data
        assert data["status"] == "healthy"
        assert isinstance(data["model_loaded"], bool)
        assert isinstance(data["gpu_available"], bool)


class TestInferEndpoint:
    """Test /api/v1/infer endpoint."""

    @patch(PIPELINE_TARGET)
    @patch(DETECTOR_TARGET)
    def test_infer_accepts_png_file(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that /infer endpoint accepts PNG files."""
        pipeline_instance = _wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        pipeline_instance.process_image.return_value = _make_pipeline_result(
            fields={"InvoiceNumber": "12345"},
            confidence={"InvoiceNumber": 0.95},
        )

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert "result" in data
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"
        assert data["result"]["confidence"]["InvoiceNumber"] == 0.95

    def test_infer_rejects_invalid_file_type(self, client):
        """Test that /infer rejects unsupported file types."""
        invalid_file = io.BytesIO(b"fake txt content")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", invalid_file, "text/plain")},
        )

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_infer_requires_file(self, client):
        """Test that /infer requires a file parameter."""
        response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Unprocessable Entity

    @patch(PIPELINE_TARGET)
    @patch(DETECTOR_TARGET)
    def test_infer_returns_cross_validation_if_available(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that a result carrying cross-validation is still served.

        ``cross_validation`` is not currently part of the response schema, so
        this test pins that the request succeeds and that the extracted fields
        are returned.
        TODO: assert on the cross-validation payload once the response
        schema exposes it.
        """
        pipeline_instance = _wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        fields = {
            "InvoiceNumber": "12345",
            "OCR": "1234567",
            "Amount": "100.00",
        }
        mock_result = _make_pipeline_result(
            fields=fields,
            confidence={
                "InvoiceNumber": 0.95,
                "OCR": 0.90,
                "Amount": 0.88,
            },
        )

        # Attach a cross-validation result to the pipeline output.
        mock_cv = Mock()
        mock_cv.is_valid = True
        mock_cv.payment_line_ocr = "1234567"
        mock_cv.ocr_match = True
        mock_result.cross_validation = mock_cv
        pipeline_instance.process_image.return_value = mock_result

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        data = response.json()
        # Previously `data` was never inspected, so this test passed on any 200.
        assert data["status"] == "success"
        for name, value in fields.items():
            assert data["result"]["fields"][name] == value

    @patch(PIPELINE_TARGET)
    @patch(DETECTOR_TARGET)
    def test_infer_handles_processing_errors_gracefully(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that processing errors are handled gracefully."""
        pipeline_instance = _wire_pipeline_mocks(mock_yolo_detector, mock_pipeline)
        pipeline_instance.process_image.side_effect = Exception("Model inference failed")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        # The service catches exceptions and returns partial results rather
        # than a 5xx.
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "partial"
        assert data["result"]["success"] is False
        assert len(data["result"]["errors"]) > 0
        assert "Model inference failed" in data["result"]["errors"][0]


class TestResultsEndpoint:
    """Test /api/v1/results/{filename} endpoint."""

    def test_get_result_image_returns_404_if_not_found(self, client):
        """Test that getting non-existent result returns 404."""
        response = client.get("/api/v1/results/nonexistent.png")

        assert response.status_code == 404

    def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
        """Test that an existing result file is returned byte-for-byte."""
        # tmp_path is shared with the test_app fixture within this test, so
        # "results" may already exist.
        result_dir = tmp_path / "results"
        result_dir.mkdir(exist_ok=True)
        result_file = result_dir / "test_result.png"
        img = Image.new('RGB', (100, 100), color='red')
        img.save(result_file)

        # Point the storage helper at our test file.
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper"
        ) as mock_storage:
            mock_helper = Mock()
            mock_helper.get_result_local_path.return_value = result_file
            mock_storage.return_value = mock_helper

            response = client.get("/api/v1/results/test_result.png")

        assert response.status_code == 200
        assert response.headers["content-type"] == "image/png"
        assert response.content == result_file.read_bytes()


class TestInferenceServiceImports:
    """Critical test to catch import errors."""

    def test_inference_service_can_import_modules(self):
        """
        Test that InferenceService can import its dependencies.

        This test will fail if there are ImportError issues like:
        - from ..inference.pipeline (wrong relative import)
        - from inference.web.inference (non-existent module)

        It ensures the imports are correct before runtime.
        """
        from inference.web.services.inference import InferenceService

        # Import the modules that InferenceService tries to import.
        from inference.pipeline.pipeline import InferencePipeline
        from inference.pipeline.yolo_detector import YOLODetector
        from shared.pdf.renderer import render_pdf_to_images

        # Reaching this point means all imports resolved.
        assert InferencePipeline is not None
        assert YOLODetector is not None
        assert render_pdf_to_images is not None
        assert InferenceService is not None