"""
Inference Pipeline Integration Tests

Tests the pipeline's result containers (InferenceResult,
CrossValidationResult) and its post-processing steps: field merging,
payment-line cross-validation, fallback detection, regex-based fallback
extraction and amount normalization.

Note: These tests never load YOLO or OCR models. Pipelines are created with
``InferencePipeline.__new__`` (which skips ``__init__``), and each test
attaches, as mocks, only the collaborators the method under test reads.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
import numpy as np

from backend.pipeline.pipeline import (
    InferencePipeline,
    InferenceResult,
    CrossValidationResult,
)
from backend.pipeline.yolo_detector import Detection
from backend.pipeline.field_extractor import ExtractedField


def _bare_pipeline() -> InferencePipeline:
    """Return an ``InferencePipeline`` instance without running ``__init__``.

    ``__new__`` only allocates the object and never invokes ``__init__``, so
    none of the constructor's setup runs and there is no need to patch
    ``__init__``. Callers attach whatever attributes the tested method reads.
    """
    return InferencePipeline.__new__(InferencePipeline)


def _pipeline_with_parser(parsed: Any) -> InferencePipeline:
    """Return a bare pipeline whose ``payment_line_parser.parse`` returns *parsed*."""
    pipeline = _bare_pipeline()
    pipeline.payment_line_parser = MagicMock()
    pipeline.payment_line_parser.parse.return_value = parsed
    return pipeline


def _valid_parsed_line(ocr_number, amount, account_number) -> MagicMock:
    """Build a mock parse result flagged valid, with the given component values."""
    parsed = MagicMock()
    parsed.is_valid = True
    parsed.ocr_number = ocr_number
    parsed.amount = amount
    parsed.account_number = account_number
    return parsed


@pytest.fixture
def mock_detection():
    """Create a mock detection."""
    return Detection(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )


@pytest.fixture
def mock_extracted_field():
    """Create a mock extracted field."""
    return ExtractedField(
        field_name="InvoiceNumber",
        raw_text="INV-2024-001",
        normalized_value="INV-2024-001",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )


class TestInferenceResultConstruction:
    """Tests for InferenceResult construction and methods."""

    def test_default_result(self):
        """Test default InferenceResult values."""
        result = InferenceResult()

        assert result.document_id is None
        assert result.success is False
        assert result.fields == {}
        assert result.confidence == {}
        assert result.raw_detections == []
        assert result.extracted_fields == []
        assert result.errors == []
        assert result.fallback_used is False
        assert result.cross_validation is None

    def test_result_to_json(self):
        """Test JSON serialization of result."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={
                "InvoiceNumber": "INV-001",
                "Amount": "1500.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "Amount": 0.92,
            },
            bboxes={
                "InvoiceNumber": (100, 50, 200, 30),
            },
        )

        json_data = result.to_json()

        assert json_data["DocumentId"] == "test-doc"
        assert json_data["success"] is True
        assert json_data["InvoiceNumber"] == "INV-001"
        assert json_data["Amount"] == "1500.00"
        assert json_data["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in json_data

    def test_result_get_field(self):
        """Test getting field value and confidence."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        value, conf = result.get_field("InvoiceNumber")
        assert value == "INV-001"
        assert conf == 0.95

        # A field that was never extracted yields (None, 0.0).
        value, conf = result.get_field("Amount")
        assert value is None
        assert conf == 0.0


class TestCrossValidation:
    """Tests for cross-validation logic."""

    def test_cross_validation_default(self):
        """Test default CrossValidationResult values."""
        cv = CrossValidationResult()

        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Test CrossValidationResult with matches."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )

        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert len(cv.details) == 3


class TestPipelineMergeFields:
    """Tests for field merging logic."""

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        # An invalid parse keeps payment-line cross-validation out of the way.
        pipeline = _pipeline_with_parser(MagicMock(is_valid=False))

        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="INV-001",
                normalized_value="INV-001",
                confidence=0.85,
                detection_confidence=0.90,
                ocr_confidence=0.85,
                bbox=(100, 50, 200, 30),
                page_no=0,
                is_valid=True,
            ),
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="INV-001",
                normalized_value="INV-001",
                confidence=0.95,  # Higher confidence
                detection_confidence=0.95,
                ocr_confidence=0.95,
                bbox=(105, 52, 198, 28),
                page_no=0,
                is_valid=True,
            ),
        ]

        pipeline._merge_fields(result)

        assert result.fields["InvoiceNumber"] == "INV-001"
        assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        pipeline = _pipeline_with_parser(MagicMock(is_valid=False))

        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="",
                normalized_value=None,
                confidence=0.95,
                detection_confidence=0.95,
                ocr_confidence=0.95,
                bbox=(100, 50, 200, 30),
                page_no=0,
                is_valid=False,  # Invalid
            ),
            ExtractedField(
                field_name="Amount",
                raw_text="1500.00",
                normalized_value="1500.00",
                confidence=0.92,
                detection_confidence=0.92,
                ocr_confidence=0.92,
                bbox=(200, 100, 100, 25),
                page_no=0,
                is_valid=True,
            ),
        ]

        pipeline._merge_fields(result)

        assert "InvoiceNumber" not in result.fields
        assert result.fields["Amount"] == "1500.00"


class TestPaymentLineValidation:
    """Tests for payment line cross-validation."""

    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        pipeline = _pipeline_with_parser(
            _valid_parsed_line(
                ocr_number="12345678901234",
                amount="1500.00",
                account_number="12345678",
            )
        )

        result = InferenceResult(
            fields={
                "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                "OCR": "99999999999999",  # Different OCR
            },
            confidence={"OCR": 0.85},
        )

        pipeline._cross_validate_payment_line(result)

        # Payment line OCR should override
        assert result.fields["OCR"] == "12345678901234"
        assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        pipeline = _pipeline_with_parser(
            _valid_parsed_line(
                ocr_number=None,
                amount="2500.50",
                account_number=None,
            )
        )

        result = InferenceResult(
            fields={
                "payment_line": "# ... # 2500 50 5 > ...",
                "Amount": "2500.00",  # Slightly different
            },
            confidence={"Amount": 0.80},
        )

        pipeline._cross_validate_payment_line(result)

        assert result.fields["Amount"] == "2500.50"
        assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        pipeline = _pipeline_with_parser(
            _valid_parsed_line(
                ocr_number="12345678901234",
                amount="1500.00",
                account_number="12345678",
            )
        )

        result = InferenceResult(
            fields={
                "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                "OCR": "12345678901234",  # Matching
                "Amount": "1500.00",  # Matching
                "Bankgiro": "1234-5678",  # Matching
            },
            confidence={},
        )

        pipeline._cross_validate_payment_line(result)

        assert result.cross_validation is not None
        assert result.cross_validation.ocr_match is True
        assert result.cross_validation.amount_match is True
        assert result.cross_validation.is_valid is True


class TestFallbackLogic:
    """Tests for fallback detection logic."""

    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        pipeline = _bare_pipeline()

        # Only one key field present
        result = InferenceResult(fields={"Amount": "1500.00"})

        assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        pipeline = _bare_pipeline()

        # All key fields present
        result = InferenceResult(
            fields={
                "Amount": "1500.00",
                "InvoiceNumber": "INV-001",
                "OCR": "12345678901234",
            }
        )

        assert pipeline._needs_fallback(result) is False


class TestPatternExtraction:
    """Tests for fallback pattern extraction."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        pipeline = _bare_pipeline()

        text = "Att betala: 1 500,00 SEK"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "Amount" in result.fields
        assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        pipeline = _bare_pipeline()

        text = "Bankgiro: 1234-5678"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "Bankgiro" in result.fields
        assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        pipeline = _bare_pipeline()

        text = "OCR: 12345678901234567890"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "OCR" in result.fields
        assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        pipeline = _bare_pipeline()

        text = "Fakturanr: 999"
        result = InferenceResult(fields={"InvoiceNumber": "INV-001"})

        pipeline._extract_with_patterns(text, result)

        # Should keep existing value
        assert result.fields["InvoiceNumber"] == "INV-001"


class TestAmountNormalization:
    """Tests for amount normalization."""

    def test_normalize_swedish_format(self):
        """Test normalizing Swedish amount format."""
        pipeline = _bare_pipeline()

        # Swedish format: space as thousands separator, comma as decimal
        assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
        # Standard format: dot as decimal
        assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
        # Swedish format with comma as decimal
        assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00

    def test_normalize_invalid_amount(self):
        """Test normalizing invalid amount returns None."""
        pipeline = _bare_pipeline()

        assert pipeline._normalize_amount_for_compare("invalid") is None
        assert pipeline._normalize_amount_for_compare("") is None


class TestResultSerialization:
    """Tests for result serialization with cross-validation."""

    def test_to_json_with_cross_validation(self):
        """Test JSON serialization includes cross-validation."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )

        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=cv,
        )

        json_data = result.to_json()

        assert "cross_validation" in json_data
        assert json_data["cross_validation"]["is_valid"] is True
        assert json_data["cross_validation"]["ocr_match"] is True
        assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"