"""
Tests for Inference Pipeline

Tests the cross-validation logic between payment_line and detected fields:
- OCR override from payment_line
- Amount override from payment_line
- Bankgiro/Plusgiro comparison (no override)
- Validation scoring

Also covers business-feature serialisation (line items, VAT summary/validation),
error handling in _extract_business_features, the PDF text-token path in
process_pdf(), DPI passthrough to FieldExtractor, the fallback regex extraction,
InvoiceDueDate/InvoiceDate validation, InvoiceNumber cross-field de-duplication,
and the _needs_fallback trigger.
"""

import re
from unittest.mock import MagicMock, patch

import pytest

from backend.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult


# ---------------------------------------------------------------------------
# Shared helpers (underscore-prefixed so pytest does not collect them)
# ---------------------------------------------------------------------------

def _parse_payment_line(payment_line):
    """Split a ``KEY:value KEY:value`` payment_line string into a dict.

    Simulates the payment_line parsing the tests attribute to the pipeline:
    whitespace-separated tokens containing ``:`` are split on the FIRST colon
    and stored under the upper-cased key; tokens without a colon are ignored.

    Args:
        payment_line: The payment_line text, or ``None``.

    Returns:
        Dict of upper-cased keys (e.g. ``'OCR'``, ``'AMOUNT'``, ``'BG'``,
        ``'PG'``) to their string values; empty for ``None``/empty input.
    """
    parts = {}
    if payment_line is None:
        return parts
    for token in str(payment_line).split():
        key, sep, value = token.partition(':')
        if sep:  # only tokens that actually contain a colon
            parts[key.upper()] = value
    return parts


def _apply_payment_line_overrides(fields):
    """Simulate the payment_line override step on ``fields`` (mutated in place).

    OCR and Amount are overwritten with the payment_line values when present.
    Account numbers (Bankgiro/Plusgiro) are deliberately NOT overridden; per the
    module docstring they are only compared for validation.
    """
    parts = _parse_payment_line(fields.get('payment_line'))
    for pl_key, field_name in (('OCR', 'OCR'), ('AMOUNT', 'Amount')):
        value = parts.get(pl_key)
        if value:
            fields[field_name] = value


def _normalize_amount(amount):
    """Reduce an amount string to digits and drop a trailing ``00`` öre part.

    NOTE(review): an integer amount that legitimately ends in ``00`` (e.g.
    ``"11700"``) is also truncated (to ``"117"``). The tests below only use
    values with explicit ``,00``/``.00`` öre, so this case is not exercised.
    """
    digits = re.sub(r'[^\d]', '', amount)
    if len(digits) > 2 and digits.endswith('00'):
        digits = digits[:-2]
    return digits


def _bare_pipeline():
    """Return an InferencePipeline instance created with a no-op ``__init__``.

    The real constructor (which instantiates YOLODetector and FieldExtractor)
    never runs; callers assign only the attributes the method under test reads.
    """
    with patch.object(InferencePipeline, '__init__', lambda self, **kwargs: None):
        return InferencePipeline()


def _merge_pipeline():
    """Bare pipeline for _merge_fields tests whose payment_line never parses as valid."""
    p = _bare_pipeline()
    p.payment_line_parser = MagicMock()
    p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
    return p


def _extracted_field(field_name, raw_text, normalized, confidence=0.9,
                     bbox=(0, 0, 100, 50), page_no=0):
    """Build an ExtractedField whose detection confidence equals ``confidence``
    and whose OCR confidence is 1.0."""
    from backend.pipeline.field_extractor import ExtractedField
    return ExtractedField(
        field_name=field_name,
        raw_text=raw_text,
        normalized_value=normalized,
        confidence=confidence,
        detection_confidence=confidence,
        ocr_confidence=1.0,
        bbox=bbox,
        page_no=page_no,
    )


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

class TestCrossValidationResult:
    """Tests for CrossValidationResult dataclass."""

    def test_default_values(self):
        """Test default values."""
        cv = CrossValidationResult()
        for attr in (
            'ocr_match',
            'amount_match',
            'bankgiro_match',
            'plusgiro_match',
            'payment_line_ocr',
            'payment_line_amount',
            'payment_line_account',
            'payment_line_account_type',
        ):
            assert getattr(cv, attr) is None, attr

    def test_attributes(self):
        """Test setting attributes."""
        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.amount_match = True
        cv.payment_line_ocr = '12345678901'
        cv.payment_line_amount = '100'
        cv.details = ['OCR match', 'Amount match']

        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert cv.payment_line_ocr == '12345678901'
        assert 'OCR match' in cv.details


class TestInferenceResult:
    """Tests for InferenceResult dataclass."""

    def test_default_fields(self):
        """Test default field values."""
        result = InferenceResult()
        assert result.fields == {}
        assert result.confidence == {}
        assert result.errors == []

    def test_set_fields(self):
        """Test setting field values."""
        result = InferenceResult()
        result.fields = {
            'OCR': '12345678901',
            'Amount': '100',
            'Bankgiro': '782-1713',
        }
        result.confidence = {
            'OCR': 0.95,
            'Amount': 0.90,
            'Bankgiro': 0.88,
        }

        assert result.fields['OCR'] == '12345678901'
        assert result.fields['Amount'] == '100'
        assert result.fields['Bankgiro'] == '782-1713'
        assert result.confidence['OCR'] == 0.95

    def test_cross_validation_assignment(self):
        """Test cross validation assignment."""
        result = InferenceResult()
        result.fields = {'OCR': '12345678901'}

        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.payment_line_ocr = '12345678901'
        result.cross_validation = cv

        assert result.cross_validation is not None
        assert result.cross_validation.ocr_match is True


# ---------------------------------------------------------------------------
# payment_line parsing / override / comparison (simulated logic)
# ---------------------------------------------------------------------------

class TestPaymentLineParsingInPipeline:
    """Tests for payment_line parsing in cross-validation."""

    def test_parse_payment_line_format(self):
        """Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx"""
        pl_parts = _parse_payment_line("OCR:310196187399952 Amount:11699 BG:782-1713")

        assert pl_parts.get('OCR') == '310196187399952'
        assert pl_parts.get('AMOUNT') == '11699'
        assert pl_parts.get('BG') == '782-1713'

    def test_parse_payment_line_with_plusgiro(self):
        """Test parsing with Plusgiro."""
        pl_parts = _parse_payment_line("OCR:12345678901 Amount:500 PG:1234567-8")

        assert pl_parts.get('OCR') == '12345678901'
        assert pl_parts.get('PG') == '1234567-8'
        assert pl_parts.get('BG') is None

    def test_parse_empty_payment_line(self):
        """Test parsing empty payment_line."""
        pl_parts = _parse_payment_line("")

        assert pl_parts.get('OCR') is None
        assert pl_parts.get('AMOUNT') is None


class TestOCROverride:
    """Tests for OCR override logic."""

    def test_ocr_override_when_different(self):
        """Test OCR is overridden when payment_line value differs."""
        result = InferenceResult()
        result.fields = {
            'OCR': 'wrong_ocr_12345',
            'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713',
        }

        _apply_payment_line_overrides(result.fields)

        assert result.fields['OCR'] == 'correct_ocr_67890'

    def test_ocr_no_override_when_no_payment_line(self):
        """Test OCR is not overridden when no payment_line."""
        result = InferenceResult()
        result.fields = {'OCR': 'original_ocr_12345'}

        # Run the override step: with no payment_line there is nothing to apply.
        _apply_payment_line_overrides(result.fields)

        assert result.fields['OCR'] == 'original_ocr_12345'


class TestAmountOverride:
    """Tests for Amount override logic."""

    def test_amount_override(self):
        """Test Amount is overridden from payment_line."""
        result = InferenceResult()
        result.fields = {
            'Amount': '999.00',
            'payment_line': 'OCR:12345 Amount:11699 BG:782-1713',
        }

        _apply_payment_line_overrides(result.fields)

        assert result.fields['Amount'] == '11699'


class TestBankgiroComparison:
    """Tests for Bankgiro comparison (no override)."""

    def test_bankgiro_match(self):
        """Test Bankgiro match detection."""
        detected_bankgiro = '782-1713'
        payment_line_account = '782-1713'

        det_digits = re.sub(r'\D', '', detected_bankgiro)
        pl_digits = re.sub(r'\D', '', payment_line_account)

        assert det_digits == pl_digits
        assert det_digits == '7821713'

    def test_bankgiro_mismatch(self):
        """Test Bankgiro mismatch detection."""
        detected_bankgiro = '782-1713'
        payment_line_account = '123-4567'

        det_digits = re.sub(r'\D', '', detected_bankgiro)
        pl_digits = re.sub(r'\D', '', payment_line_account)

        assert det_digits != pl_digits

    def test_bankgiro_not_overridden(self):
        """Test that Bankgiro is NOT overridden from payment_line."""
        result = InferenceResult()
        result.fields = {
            'Bankgiro': '999-9999',  # differs from the payment_line BG
            'payment_line': 'OCR:12345 Amount:100 BG:782-1713',
        }
        original_bankgiro = result.fields['Bankgiro']

        # Run the override step: OCR/Amount are applied, Bankgiro is only compared.
        _apply_payment_line_overrides(result.fields)

        assert result.fields['OCR'] == '12345'  # proves the override step ran
        assert result.fields['Bankgiro'] == '999-9999'
        assert result.fields['Bankgiro'] == original_bankgiro


class TestValidationScoring:
    """Tests for validation scoring logic."""

    def test_all_fields_match(self):
        """Test score when all fields match."""
        matches = [True, True, True]  # OCR, Amount, Bankgiro
        match_count = sum(1 for m in matches if m)
        total = len(matches)

        assert match_count == 3
        assert total == 3

    def test_partial_match(self):
        """Test score with partial matches."""
        matches = [True, True, False]  # OCR match, Amount match, Bankgiro mismatch
        match_count = sum(1 for m in matches if m)

        assert match_count == 2

    def test_no_matches(self):
        """Test score when nothing matches."""
        matches = [False, False, False]
        match_count = sum(1 for m in matches if m)

        assert match_count == 0

    def test_only_count_present_fields(self):
        """Test that only present fields are counted."""
        # When the invoice has both BG and PG but payment_line only has BG,
        # only BG should count towards validation.
        payment_line_account_type = 'bankgiro'
        bankgiro_match = True
        plusgiro_match = None  # not compared: payment_line has no PG

        matches = []
        if payment_line_account_type == 'bankgiro' and bankgiro_match is not None:
            matches.append(bankgiro_match)
        elif payment_line_account_type == 'plusgiro' and plusgiro_match is not None:
            matches.append(plusgiro_match)

        assert len(matches) == 1
        assert matches[0] is True


class TestAmountNormalization:
    """Tests for amount normalization for comparison."""

    def test_normalize_amount_with_comma(self):
        """Test normalizing amount with comma decimal."""
        assert _normalize_amount("11699,00") == '11699'

    def test_normalize_amount_with_dot(self):
        """Test normalizing amount with dot decimal."""
        assert _normalize_amount("11699.00") == '11699'

    def test_normalize_amount_with_space_separator(self):
        """Test normalizing amount with space thousand separator."""
        assert _normalize_amount("11 699,00") == '11699'


# ---------------------------------------------------------------------------
# Business features
# ---------------------------------------------------------------------------

class TestBusinessFeatures:
    """Tests for business invoice features (line items, VAT, validation)."""

    def test_inference_result_has_business_fields(self):
        """Test that InferenceResult has business feature fields."""
        result = InferenceResult()
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

    def test_to_json_without_business_features(self):
        """Test to_json works without business features."""
        result = InferenceResult()
        result.fields = {'InvoiceNumber': '12345'}
        result.confidence = {'InvoiceNumber': 0.95}

        json_result = result.to_json()

        assert json_result['InvoiceNumber'] == '12345'
        assert 'line_items' not in json_result
        assert 'vat_summary' not in json_result
        assert 'vat_validation' not in json_result

    def test_to_json_with_line_items(self):
        """Test to_json includes line items when present."""
        from backend.table.line_items_extractor import LineItem, LineItemsResult

        result = InferenceResult()
        result.fields = {'Amount': '12500.00'}
        result.line_items = LineItemsResult(
            items=[
                LineItem(
                    row_index=0,
                    description="Product A",
                    quantity="2",
                    unit_price="5000,00",
                    amount="10000,00",
                    vat_rate="25",
                    confidence=0.9,
                ),
            ],
            header_row=["Beskrivning", "Antal", "Pris", "Belopp", "Moms"],
            raw_html="...",
        )

        json_result = result.to_json()

        assert 'line_items' in json_result
        assert len(json_result['line_items']['items']) == 1
        assert json_result['line_items']['items'][0]['description'] == "Product A"
        assert json_result['line_items']['items'][0]['amount'] == "10000,00"

    def test_to_json_with_vat_summary(self):
        """Test to_json includes VAT summary when present."""
        from backend.vat.vat_extractor import VATBreakdown, VATSummary

        result = InferenceResult()
        result.vat_summary = VATSummary(
            breakdowns=[
                VATBreakdown(rate=25.0, base_amount="10000,00",
                             vat_amount="2500,00", source="regex"),
            ],
            total_excl_vat="10000,00",
            total_vat="2500,00",
            total_incl_vat="12500,00",
            confidence=0.9,
        )

        json_result = result.to_json()

        assert 'vat_summary' in json_result
        assert len(json_result['vat_summary']['breakdowns']) == 1
        assert json_result['vat_summary']['breakdowns'][0]['rate'] == 25.0
        assert json_result['vat_summary']['total_incl_vat'] == "12500,00"

    def test_to_json_with_vat_validation(self):
        """Test to_json includes VAT validation when present."""
        from backend.validation.vat_validator import VATValidationResult, MathCheckResult

        result = InferenceResult()
        result.vat_validation = VATValidationResult(
            is_valid=True,
            confidence_score=0.95,
            math_checks=[
                MathCheckResult(
                    rate=25.0,
                    base_amount=10000.0,
                    expected_vat=2500.0,
                    actual_vat=2500.0,
                    is_valid=True,
                    tolerance=0.5,
                ),
            ],
            total_check=True,
            line_items_vs_summary=True,
            amount_consistency=True,
            needs_review=False,
            review_reasons=[],
        )

        json_result = result.to_json()

        assert 'vat_validation' in json_result
        assert json_result['vat_validation']['is_valid'] is True
        assert json_result['vat_validation']['confidence_score'] == 0.95
        assert len(json_result['vat_validation']['math_checks']) == 1


class TestBusinessFeaturesAvailable:
    """Tests for BUSINESS_FEATURES_AVAILABLE flag."""

    def test_business_features_available(self):
        """Test that business features are available."""
        from backend.pipeline import BUSINESS_FEATURES_AVAILABLE
        assert BUSINESS_FEATURES_AVAILABLE is True


class TestExtractBusinessFeaturesErrorHandling:
    """Tests for _extract_business_features error handling."""

    def _pipeline_raising(self, exc):
        """Bare pipeline with mocked extractors whose line-items extraction raises ``exc``."""
        pipeline = _bare_pipeline()
        pipeline.line_items_extractor = MagicMock()
        pipeline.vat_extractor = MagicMock()
        pipeline.vat_validator = MagicMock()
        pipeline.line_items_extractor.extract_from_pdf.side_effect = exc
        return pipeline

    def test_pipeline_module_has_logger(self):
        """Test that pipeline module defines logger correctly."""
        from backend.pipeline import pipeline as pipeline_module
        assert hasattr(pipeline_module, 'logger')
        assert pipeline_module.logger is not None

    def test_extract_business_features_logs_errors(self):
        """Test that _extract_business_features logs detailed errors."""
        pipeline = self._pipeline_raising(ValueError("Test error message"))
        result = InferenceResult()

        pipeline._extract_business_features("/fake/path.pdf", result, "full text")

        # The error must be recorded together with its exception type.
        assert len(result.errors) == 1
        assert "ValueError" in result.errors[0]
        assert "Test error message" in result.errors[0]

    def test_extract_business_features_handles_numeric_exceptions(self):
        """Test that _extract_business_features handles non-standard exceptions."""

        # An exception whose str() is just a number (like an exit code).
        class NumericException(Exception):
            def __str__(self):
                return "0"

        pipeline = self._pipeline_raising(NumericException())
        result = InferenceResult()

        pipeline._extract_business_features("/fake/path.pdf", result, "full text")

        # The type name must be present even when str(e) is just "0".
        assert len(result.errors) == 1
        assert "NumericException" in result.errors[0]


# ---------------------------------------------------------------------------
# process_pdf token path
# ---------------------------------------------------------------------------

class TestProcessPdfTokenPath:
    """Tests for PDF text token extraction path in process_pdf()."""

    def _make_pipeline(self):
        """Create pipeline with mocked internals, bypassing __init__."""
        p = _bare_pipeline()
        p.detector = MagicMock()
        p.extractor = MagicMock()
        p.payment_line_parser = MagicMock()
        p.dpi = 300
        p.enable_fallback = False
        p.enable_business_features = False
        p.vat_tolerance = 0.5
        p.line_items_extractor = None
        p.vat_extractor = None
        p.vat_validator = None
        p._business_ocr_engine = None
        p._table_detector = None
        return p

    def _make_detection(self, class_name='Amount', confidence=0.85, page_no=0):
        """Create a Detection object."""
        from backend.pipeline.yolo_detector import Detection
        return Detection(
            class_id=6,
            class_name=class_name,
            confidence=confidence,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=page_no,
        )

    def _make_extracted_field(self, field_name='Amount', raw_text='2.254,50',
                              normalized='2254.50', confidence=0.85):
        """Create an ExtractedField located at the detection's bbox."""
        return _extracted_field(field_name, raw_text, normalized, confidence,
                                bbox=(100.0, 200.0, 300.0, 250.0))

    def _make_image_bytes(self):
        """Create minimal valid PNG bytes (100x100 white image)."""
        from PIL import Image as PILImage
        import io as _io
        img = PILImage.new('RGB', (100, 100), color='white')
        buf = _io.BytesIO()
        img.save(buf, format='PNG')
        return buf.getvalue()

    @staticmethod
    def _install_pdf_doc(mock_pdf_doc_cls, tokens=None):
        """Make ``PDFDocument(...)`` a context manager that yields a mock document.

        With ``tokens`` the document reports a one-page text PDF that yields
        those tokens; without ``tokens`` it reports a scanned (non-text) PDF.
        """
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = tokens is not None
        if tokens is not None:
            mock_pdf_doc.page_count = 1
            mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)
        return mock_pdf_doc

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_uses_pdf_tokens(self, mock_render, mock_pdf_doc_cls):
        """When PDF is text-based, extract_from_detection_with_pdf is used."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        self._install_pdf_doc(
            mock_pdf_doc_cls,
            tokens=[Token(text="2.254,50", bbox=(100, 200, 200, 220), page_no=0)],
        )
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )
        mock_render.return_value = iter([(0, image_bytes)])

        result = pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection_with_pdf.assert_called_once()
        pipeline.extractor.extract_from_detection.assert_not_called()
        assert result.fields.get('Amount') == '2254.50'
        assert result.success is True

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_scanned_pdf_uses_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF is scanned, extract_from_detection (OCR) is used."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        self._install_pdf_doc(mock_pdf_doc_cls)  # scanned: no text tokens
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )
        mock_render.return_value = iter([(0, image_bytes)])

        pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_pdf_detection_error_falls_back_to_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF text detection throws, fall back to OCR."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # Entering the PDFDocument context fails, simulating a corrupt PDF.
        mock_ctx = MagicMock()
        mock_ctx.__enter__ = MagicMock(side_effect=Exception("corrupt PDF"))
        mock_ctx.__exit__ = MagicMock(return_value=False)
        mock_pdf_doc_cls.return_value = mock_ctx

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )
        mock_render.return_value = iter([(0, image_bytes)])

        pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_passes_correct_args(self, mock_render, mock_pdf_doc_cls):
        """Verify correct token list and image dimensions are passed."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()  # 100x100 PNG

        self._install_pdf_doc(
            mock_pdf_doc_cls,
            tokens=[
                Token(text="Fakturabelopp:", bbox=(50, 190, 100, 210), page_no=0),
                Token(text="2.254,50", bbox=(105, 190, 180, 210), page_no=0),
                Token(text="SEK", bbox=(185, 190, 210, 210), page_no=0),
            ],
        )
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )
        mock_render.return_value = iter([(0, image_bytes)])

        pipeline.process_pdf('/fake/invoice.pdf')

        call_args = pipeline.extractor.extract_from_detection_with_pdf.call_args[0]
        assert call_args[0] == detection
        assert len(call_args[1]) == 3  # 3 tokens passed
        assert call_args[2] == 100  # image width
        assert call_args[3] == 100  # image height


class TestDpiPassthrough:
    """Tests for DPI being passed from pipeline to FieldExtractor (Bug 1)."""

    def test_field_extractor_receives_pipeline_dpi(self):
        """FieldExtractor should receive the pipeline's DPI, not default to 300."""
        with patch('backend.pipeline.pipeline.YOLODetector'), \
                patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
            InferencePipeline(
                model_path='/fake/model.pt',
                dpi=150,
                use_gpu=False,
            )
            mock_fe_cls.assert_called_once_with(
                ocr_lang='en', use_gpu=False, dpi=150
            )

    def test_field_extractor_receives_default_dpi(self):
        """When dpi=300 (default), FieldExtractor should also get 300."""
        with patch('backend.pipeline.pipeline.YOLODetector'), \
                patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
            InferencePipeline(
                model_path='/fake/model.pt',
                dpi=300,
                use_gpu=False,
            )
            mock_fe_cls.assert_called_once_with(
                ocr_lang='en', use_gpu=False, dpi=300
            )


# ---------------------------------------------------------------------------
# Fallback extraction, merging and de-duplication
# ---------------------------------------------------------------------------

class TestFallbackPatternExtraction:
    """Tests for _extract_with_patterns fallback regex (Bugs 2, 3)."""

    def _make_pipeline_with_patterns(self):
        """Create pipeline with mocked internals for pattern testing."""
        p = _bare_pipeline()
        p.dpi = 150
        p.enable_fallback = True
        return p

    def _extract(self, text):
        """Run the fallback regex extraction over ``text`` and return the result."""
        result = InferenceResult()
        self._make_pipeline_with_patterns()._extract_with_patterns(text, result)
        return result

    def test_bankgiro_no_match_in_org_number(self):
        """Bankgiro regex must NOT match digits embedded in an org number."""
        result = self._extract("Org.nr 802546-1610")
        assert 'Bankgiro' not in result.fields

    def test_bankgiro_matches_labeled(self):
        """Bankgiro regex should match when preceded by 'Bankgiro' label."""
        result = self._extract("Bankgiro 5393-9484")
        assert result.fields.get('Bankgiro') == '5393-9484'

    def test_bankgiro_matches_standalone(self):
        """Bankgiro regex should match a standalone 4-4 digit pattern."""
        result = self._extract("Betala till 5393-9484 senast")
        assert result.fields.get('Bankgiro') == '5393-9484'

    def test_amount_rejects_bare_integer(self):
        """Amount regex must NOT match bare integers like 'Summa 1'."""
        result = self._extract("Summa 1 Medlemsavgift")
        assert 'Amount' not in result.fields

    def test_amount_requires_decimal(self):
        """Amount regex should require a decimal separator."""
        result = self._extract("Total 5 items")
        assert 'Amount' not in result.fields

    def test_amount_with_decimal_works(self):
        """Amount regex should match Swedish decimal amounts."""
        result = self._extract("Att betala 1 234,56 SEK")
        assert 'Amount' in result.fields
        assert float(result.fields['Amount']) == pytest.approx(1234.56, abs=0.01)

    def test_amount_with_sek_suffix(self):
        """Amount regex should match amounts ending with SEK."""
        result = self._extract("7 500,00 SEK")
        assert 'Amount' in result.fields
        assert float(result.fields['Amount']) == pytest.approx(7500.00, abs=0.01)

    def test_fallback_extracts_invoice_date(self):
        """Fallback should extract InvoiceDate from Swedish text."""
        result = self._extract("Fakturadatum 2025-01-15 Referens ABC")
        assert result.fields.get('InvoiceDate') == '2025-01-15'

    def test_fallback_extracts_due_date(self):
        """Fallback should extract InvoiceDueDate from Swedish text."""
        result = self._extract("Forfallodag 2025-02-15 Belopp")
        assert result.fields.get('InvoiceDueDate') == '2025-02-15'

    def test_fallback_extracts_supplier_org(self):
        """Fallback should extract supplier_organisation_number."""
        result = self._extract("Org.nr 556123-4567 Stockholm")
        assert result.fields.get('supplier_organisation_number') == '556123-4567'

    def test_fallback_extracts_plusgiro(self):
        """Fallback should extract Plusgiro number."""
        result = self._extract("Plusgiro 12 34 56-7 betalning")
        assert 'Plusgiro' in result.fields

    def test_fallback_skips_year_as_invoice_number(self):
        """Fallback should NOT extract year-like value as InvoiceNumber."""
        result = self._extract("Fakturanr 2025 Datum 2025-01-15")
        assert 'InvoiceNumber' not in result.fields

    def test_fallback_accepts_valid_invoice_number(self):
        """Fallback should extract valid non-year InvoiceNumber."""
        result = self._extract("Fakturanr 12345 Summa")
        assert result.fields.get('InvoiceNumber') == '12345'


class TestDateValidation:
    """Tests for InvoiceDueDate < InvoiceDate validation (Bug 6)."""

    def _merge_dates(self, invoice_date, due_date):
        """Merge an InvoiceDate/InvoiceDueDate pair and return the result."""
        result = InferenceResult()
        result.extracted_fields = [
            _extracted_field('InvoiceDate', invoice_date, invoice_date),
            _extracted_field('InvoiceDueDate', due_date, due_date,
                             bbox=(0, 60, 100, 110)),
        ]
        _merge_pipeline()._merge_fields(result)
        return result

    def test_due_date_before_invoice_date_dropped(self):
        """DueDate earlier than InvoiceDate should be removed."""
        result = self._merge_dates('2026-01-16', '2025-12-01')

        assert 'InvoiceDate' in result.fields
        assert 'InvoiceDueDate' not in result.fields

    def test_valid_dates_preserved(self):
        """Both dates kept when DueDate >= InvoiceDate."""
        result = self._merge_dates('2026-01-16', '2026-02-15')

        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-02-15'

    def test_same_dates_preserved(self):
        """Same InvoiceDate and DueDate should both be kept."""
        result = self._merge_dates('2026-01-16', '2026-01-16')

        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-01-16'


class TestCrossFieldDedup:
    """Tests for cross-field deduplication of InvoiceNumber vs OCR/Bankgiro."""

    def _merge(self, *fields):
        """Merge the given ExtractedFields and return the result."""
        result = InferenceResult()
        result.extracted_fields = list(fields)
        _merge_pipeline()._merge_fields(result)
        return result

    def test_invoice_number_not_same_as_ocr(self):
        """When InvoiceNumber == OCR, InvoiceNumber should be dropped."""
        result = self._merge(
            _extracted_field('InvoiceNumber', '9179845608', '9179845608'),
            _extracted_field('OCR', '9179845608', '9179845608'),
            _extracted_field('Amount', '1234,56', '1234.56'),
        )

        assert 'OCR' in result.fields
        assert result.fields['OCR'] == '9179845608'
        assert 'InvoiceNumber' not in result.fields

    def test_invoice_number_not_same_as_bankgiro_digits(self):
        """When InvoiceNumber digits == Bankgiro digits, InvoiceNumber should be dropped."""
        result = self._merge(
            _extracted_field('InvoiceNumber', '53939484', '53939484'),
            _extracted_field('Bankgiro', '5393-9484', '5393-9484'),
            _extracted_field('Amount', '500,00', '500.00'),
        )

        assert 'Bankgiro' in result.fields
        assert result.fields['Bankgiro'] == '5393-9484'
        assert 'InvoiceNumber' not in result.fields

    def test_unrelated_values_kept(self):
        """When InvoiceNumber, OCR, and Bankgiro are all different, keep all."""
        result = self._merge(
            _extracted_field('InvoiceNumber', '19061', '19061'),
            _extracted_field('OCR', '9179845608', '9179845608'),
            _extracted_field('Bankgiro', '5393-9484', '5393-9484'),
        )

        assert result.fields['InvoiceNumber'] == '19061'
        assert result.fields['OCR'] == '9179845608'
        assert result.fields['Bankgiro'] == '5393-9484'

    def test_dedup_after_fallback_re_add(self):
        """Dedup should remove InvoiceNumber re-added by fallback if it matches OCR."""
        p = _merge_pipeline()
        result = InferenceResult()
        # Simulate the state after fallback re-added InvoiceNumber == OCR.
        result.fields = {
            'OCR': '758200602426',
            'Amount': '164.00',
            'InvoiceNumber': '758200602426',  # re-added by fallback
        }
        result.confidence = {
            'OCR': 0.9,
            'Amount': 0.9,
            'InvoiceNumber': 0.5,  # fallback confidence
        }
        result.bboxes = {}

        p._dedup_invoice_number(result)

        assert 'InvoiceNumber' not in result.fields
        assert 'OCR' in result.fields

    def test_invoice_number_substring_of_bankgiro(self):
        """When InvoiceNumber digits are a substring of Bankgiro digits, drop InvoiceNumber."""
        result = self._merge(
            _extracted_field('InvoiceNumber', '4639', '4639'),
            _extracted_field('Bankgiro', '134-4639', '134-4639'),
            _extracted_field('Amount', '500,00', '500.00'),
        )

        assert 'Bankgiro' in result.fields
        assert result.fields['Bankgiro'] == '134-4639'
        assert 'InvoiceNumber' not in result.fields

    def test_invoice_number_not_substring_of_unrelated_bankgiro(self):
        """When InvoiceNumber is NOT a substring of Bankgiro, keep both."""
        result = self._merge(
            _extracted_field('InvoiceNumber', '19061', '19061'),
            _extracted_field('Bankgiro', '5393-9484', '5393-9484'),
            _extracted_field('Amount', '500,00', '500.00'),
        )

        assert result.fields['InvoiceNumber'] == '19061'
        assert result.fields['Bankgiro'] == '5393-9484'


class TestFallbackTrigger:
    """Tests for _needs_fallback trigger threshold."""

    def test_fallback_triggers_when_1_key_field_missing(self):
        """Should trigger when only 1 key field (e.g. InvoiceNumber) is missing."""
        p = _bare_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'OCR': '12345678901',
            'InvoiceDate': '2025-01-15',
            'InvoiceDueDate': '2025-02-15',
            'supplier_organisation_number': '556123-4567',
        }
        # InvoiceNumber missing -> should trigger
        assert p._needs_fallback(result) is True

    def test_fallback_triggers_when_dates_missing(self):
        """Should trigger when all key fields present but 2+ important fields missing."""
        p = _bare_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'InvoiceNumber': '12345',
            'OCR': '12345678901',
        }
        # InvoiceDate, InvoiceDueDate, supplier_org all missing -> should trigger
        assert p._needs_fallback(result) is True

    def test_no_fallback_when_all_fields_present(self):
        """Should NOT trigger when all key and important fields present."""
        p = _bare_pipeline()
        result = InferenceResult()
        result.fields = {
            'Amount': '1234.56',
            'InvoiceNumber': '12345',
            'OCR': '12345678901',
            'InvoiceDate': '2025-01-15',
            'InvoiceDueDate': '2025-02-15',
            'supplier_organisation_number': '556123-4567',
        }
        assert p._needs_fallback(result) is False


if __name__ == '__main__':
    pytest.main([__file__, '-v'])