""" Integration tests for inference API endpoints. Tests the /api/v1/infer endpoint to ensure it works end-to-end. """ import pytest from pathlib import Path from unittest.mock import Mock, patch from fastapi.testclient import TestClient from PIL import Image import io from backend.web.app import create_app from backend.web.config import ModelConfig, StorageConfig, AppConfig @pytest.fixture def test_app(tmp_path): """Create test FastAPI application.""" # Setup test directories upload_dir = tmp_path / "uploads" result_dir = tmp_path / "results" upload_dir.mkdir() result_dir.mkdir() # Create test config app_config = AppConfig( model=ModelConfig( model_path=Path("runs/train/invoice_fields/weights/best.pt"), confidence_threshold=0.5, use_gpu=False, dpi=150, ), file=StorageConfig( upload_dir=upload_dir, result_dir=result_dir, allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"), max_file_size_mb=50, ), ) # Create app app = create_app(app_config) return app @pytest.fixture def client(test_app): """Create test client.""" return TestClient(test_app) @pytest.fixture def sample_png_bytes(): """Create sample PNG image bytes.""" img = Image.new('RGB', (800, 1200), color='white') img_bytes = io.BytesIO() img.save(img_bytes, format='PNG') img_bytes.seek(0) return img_bytes class TestHealthEndpoint: """Test /api/v1/health endpoint.""" def test_health_check_returns_200(self, client): """Test health check returns 200 OK.""" response = client.get("/api/v1/health") assert response.status_code == 200 def test_health_check_response_structure(self, client): """Test health check response has correct structure.""" response = client.get("/api/v1/health") data = response.json() assert "status" in data assert "model_loaded" in data assert "gpu_available" in data assert "version" in data assert data["status"] == "healthy" assert isinstance(data["model_loaded"], bool) assert isinstance(data["gpu_available"], bool) class TestInferEndpoint: """Test /api/v1/infer endpoint.""" @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') def test_infer_accepts_png_file( self, mock_yolo_detector, mock_pipeline, client, sample_png_bytes, ): """Test that /infer endpoint accepts PNG files.""" # Setup mocks mock_detector_instance = Mock() mock_pipeline_instance = Mock() mock_yolo_detector.return_value = mock_detector_instance mock_pipeline.return_value = mock_pipeline_instance # Mock pipeline result mock_result = Mock() mock_result.fields = {"InvoiceNumber": "12345"} mock_result.confidence = {"InvoiceNumber": 0.95} mock_result.success = True mock_result.errors = [] mock_result.raw_detections = [] mock_result.document_id = "test123" mock_result.document_type = "invoice" mock_result.processing_time_ms = 100.0 mock_result.visualization_path = None mock_result.detections = [] mock_pipeline_instance.process_image.return_value = mock_result # Make request response = client.post( "/api/v1/infer", files={"file": ("test.png", sample_png_bytes, "image/png")}, ) # Verify response assert response.status_code == 200 data = response.json() assert data["status"] == "success" assert "result" in data assert data["result"]["fields"]["InvoiceNumber"] == "12345" assert data["result"]["confidence"]["InvoiceNumber"] == 0.95 def test_infer_rejects_invalid_file_type(self, client): """Test that /infer rejects unsupported file types.""" invalid_file = io.BytesIO(b"fake txt content") response = client.post( "/api/v1/infer", files={"file": ("test.txt", invalid_file, "text/plain")}, ) assert response.status_code == 400 assert "Unsupported file type" in response.json()["detail"] def test_infer_requires_file(self, client): """Test that /infer requires a file parameter.""" response = client.post("/api/v1/infer") assert response.status_code == 422 # Unprocessable Entity @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') def test_infer_returns_cross_validation_if_available( self, mock_yolo_detector, mock_pipeline, client, sample_png_bytes, ): """Test that cross-validation results are included if available.""" # Setup mocks mock_detector_instance = Mock() mock_pipeline_instance = Mock() mock_yolo_detector.return_value = mock_detector_instance mock_pipeline.return_value = mock_pipeline_instance # Mock pipeline result with cross-validation mock_result = Mock() mock_result.fields = { "InvoiceNumber": "12345", "OCR": "1234567", "Amount": "100.00", } mock_result.confidence = { "InvoiceNumber": 0.95, "OCR": 0.90, "Amount": 0.88, } mock_result.success = True mock_result.errors = [] mock_result.raw_detections = [] mock_result.document_id = "test123" mock_result.document_type = "invoice" mock_result.processing_time_ms = 100.0 mock_result.visualization_path = None mock_result.detections = [] # Add cross-validation result mock_cv = Mock() mock_cv.is_valid = True mock_cv.payment_line_ocr = "1234567" mock_cv.ocr_match = True mock_result.cross_validation = mock_cv mock_pipeline_instance.process_image.return_value = mock_result # Make request response = client.post( "/api/v1/infer", files={"file": ("test.png", sample_png_bytes, "image/png")}, ) # Verify response includes cross-validation assert response.status_code == 200 data = response.json() # Note: cross_validation is not currently in the response schema # This test documents that it should be added @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') def test_infer_handles_processing_errors_gracefully( self, mock_yolo_detector, mock_pipeline, client, sample_png_bytes, ): """Test that processing errors are handled gracefully.""" # Setup mocks mock_detector_instance = Mock() mock_pipeline_instance = Mock() mock_yolo_detector.return_value = mock_detector_instance mock_pipeline.return_value = mock_pipeline_instance # Make pipeline raise an error mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed") # Make request response = client.post( "/api/v1/infer", files={"file": ("test.png", sample_png_bytes, "image/png")}, ) # Verify error handling - service catches exceptions and returns partial results assert response.status_code == 200 data = response.json() assert data["status"] == "partial" assert data["result"]["success"] is False assert len(data["result"]["errors"]) > 0 assert "Model inference failed" in data["result"]["errors"][0] class TestResultsEndpoint: """Test /api/v1/results/{filename} endpoint.""" def test_get_result_image_returns_404_if_not_found(self, client): """Test that getting non-existent result returns 404.""" response = client.get("/api/v1/results/nonexistent.png") assert response.status_code == 404 def test_get_result_image_returns_file_if_exists(self, client, tmp_path): """Test that existing result file is returned.""" # Create a test result file in temp directory result_dir = tmp_path / "results" result_dir.mkdir(exist_ok=True) result_file = result_dir / "test_result.png" img = Image.new('RGB', (100, 100), color='red') img.save(result_file) # Mock the storage helper to return our test file path with patch( "backend.web.api.v1.public.inference.get_storage_helper" ) as mock_storage: mock_helper = Mock() mock_helper.get_result_local_path.return_value = result_file mock_storage.return_value = mock_helper # Request the file response = client.get("/api/v1/results/test_result.png") assert response.status_code == 200 assert response.headers["content-type"] == "image/png" class TestInferenceServiceImports: """Critical test to catch import errors.""" def test_inference_service_can_import_modules(self): """ Test that InferenceService can import its dependencies. This test will fail if there are ImportError issues like: - from ..inference.pipeline (wrong relative import) - from backend.web.inference (non-existent module) It ensures the imports are correct before runtime. """ from backend.web.services.inference import InferenceService # Import the modules that InferenceService tries to import from backend.pipeline.pipeline import InferencePipeline from backend.pipeline.yolo_detector import YOLODetector from shared.pdf.renderer import render_pdf_to_images # If we got here, all imports work correctly assert InferencePipeline is not None assert YOLODetector is not None assert render_pdf_to_images is not None assert InferenceService is not None class TestBusinessFeaturesAPI: """Tests for business features (line items, VAT) in API.""" @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') def test_infer_with_extract_line_items_false_by_default( self, mock_yolo_detector, mock_pipeline, client, sample_png_bytes, ): """Test that extract_line_items defaults to False.""" # Setup mocks mock_detector_instance = Mock() mock_pipeline_instance = Mock() mock_yolo_detector.return_value = mock_detector_instance mock_pipeline.return_value = mock_pipeline_instance # Mock pipeline result mock_result = Mock() mock_result.fields = {"InvoiceNumber": "12345"} mock_result.confidence = {"InvoiceNumber": 0.95} mock_result.success = True mock_result.errors = [] mock_result.raw_detections = [] mock_result.document_id = "test123" mock_result.document_type = "invoice" mock_result.processing_time_ms = 100.0 mock_result.visualization_path = None mock_result.detections = [] mock_pipeline_instance.process_image.return_value = mock_result # Make request without extract_line_items parameter response = client.post( "/api/v1/infer", files={"file": ("test.png", sample_png_bytes, "image/png")}, ) assert response.status_code == 200 data = response.json() # Business features should be None when not requested assert data["result"]["line_items"] is None assert data["result"]["vat_summary"] is None assert data["result"]["vat_validation"] is None @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') def test_infer_with_extract_line_items_returns_business_features( self, mock_yolo_detector, mock_pipeline, client, tmp_path, ): """Test that extract_line_items=True returns business features.""" # Setup mocks mock_detector_instance = Mock() mock_pipeline_instance = Mock() mock_yolo_detector.return_value = mock_detector_instance mock_pipeline.return_value = mock_pipeline_instance # Create a test PDF file pdf_path = tmp_path / "test.pdf" pdf_path.write_bytes(b'%PDF-1.4 fake pdf content') # Mock pipeline result with business features mock_result = Mock() mock_result.fields = {"Amount": "12500,00"} mock_result.confidence = {"Amount": 0.95} mock_result.success = True mock_result.errors = [] mock_result.raw_detections = [] mock_result.document_id = "test123" mock_result.document_type = "invoice" mock_result.processing_time_ms = 150.0 mock_result.visualization_path = None mock_result.detections = [] # Mock line items mock_result.line_items = Mock() mock_result._line_items_to_json.return_value = { "items": [ { "row_index": 0, "description": "Product A", "quantity": "2", "unit": "st", "unit_price": "5000,00", "amount": "10000,00", "article_number": "ART001", "vat_rate": "25", "confidence": 0.9, } ], "header_row": ["Beskrivning", "Antal", "Pris", "Belopp"], "total_amount": "10000,00", } # Mock VAT summary mock_result.vat_summary = Mock() mock_result._vat_summary_to_json.return_value = { "breakdowns": [ { "rate": 25.0, "base_amount": "10000,00", "vat_amount": "2500,00", "source": "regex", } ], "total_excl_vat": "10000,00", "total_vat": "2500,00", "total_incl_vat": "12500,00", "confidence": 0.9, } # Mock VAT validation mock_result.vat_validation = Mock() mock_result._vat_validation_to_json.return_value = { "is_valid": True, "confidence_score": 0.95, "math_checks": [ { "rate": 25.0, "base_amount": 10000.0, "expected_vat": 2500.0, "actual_vat": 2500.0, "is_valid": True, "tolerance": 0.5, } ], "total_check": True, "line_items_vs_summary": True, "amount_consistency": True, "needs_review": False, "review_reasons": [], } mock_pipeline_instance.process_pdf.return_value = mock_result # Make request with extract_line_items=true response = client.post( "/api/v1/infer", files={"file": ("test.pdf", pdf_path.open("rb"), "application/pdf")}, data={"extract_line_items": "true"}, ) assert response.status_code == 200 data = response.json() # Verify business features are included assert data["result"]["line_items"] is not None assert len(data["result"]["line_items"]["items"]) == 1 assert data["result"]["line_items"]["items"][0]["description"] == "Product A" assert data["result"]["line_items"]["items"][0]["amount"] == "10000,00" assert data["result"]["vat_summary"] is not None assert len(data["result"]["vat_summary"]["breakdowns"]) == 1 assert data["result"]["vat_summary"]["breakdowns"][0]["rate"] == 25.0 assert data["result"]["vat_summary"]["total_incl_vat"] == "12500,00" assert data["result"]["vat_validation"] is not None assert data["result"]["vat_validation"]["is_valid"] is True assert data["result"]["vat_validation"]["confidence_score"] == 0.95 def test_schema_imports_work_correctly(self): """Test that all business feature schemas can be imported.""" from backend.web.schemas.inference import ( LineItemSchema, LineItemsResultSchema, VATBreakdownSchema, VATSummarySchema, MathCheckResultSchema, VATValidationResultSchema, InferenceResult, ) # Verify schemas can be instantiated line_item = LineItemSchema( row_index=0, description="Test", amount="100", ) assert line_item.description == "Test" vat_breakdown = VATBreakdownSchema( rate=25.0, base_amount="100", vat_amount="25", ) assert vat_breakdown.rate == 25.0 # Verify InferenceResult includes business feature fields result = InferenceResult( document_id="test", success=True, processing_time_ms=100.0, ) assert result.line_items is None assert result.vat_summary is None assert result.vat_validation is None def test_service_result_has_business_feature_fields(self): """Test that ServiceResult dataclass includes business feature fields.""" from backend.web.services.inference import ServiceResult result = ServiceResult(document_id="test123") # Verify business feature fields exist and default to None assert result.line_items is None assert result.vat_summary is None assert result.vat_validation is None # Verify they can be set result.line_items = {"items": []} result.vat_summary = {"breakdowns": []} result.vat_validation = {"is_valid": True} assert result.line_items == {"items": []} assert result.vat_summary == {"breakdowns": []} assert result.vat_validation == {"is_valid": True}