"""
Tests for Inference Pipeline

Tests the cross-validation logic between payment_line and detected fields:
- OCR override from payment_line
- Amount override from payment_line
- Bankgiro/Plusgiro comparison (no override)
- Validation scoring

Several tests do not call the pipeline itself; they re-create ("simulate")
the parsing / override / normalization steps locally.  Those simulations live
in the private helpers defined right below the imports so that every test
exercises the same code instead of its own copy.
"""

import re
from unittest.mock import MagicMock, patch

import pytest

from backend.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult


# ---------------------------------------------------------------------------
# Test-local helpers
# ---------------------------------------------------------------------------

def _parse_payment_line(payment_line):
    """Split a ``KEY:value KEY:value ...`` payment_line into a dict.

    Args:
        payment_line: The payment_line value.  It is passed through ``str()``
            first, so a missing value (``None``) simply yields no parts.

    Returns:
        dict mapping the upper-cased key of every whitespace-separated
        ``key:value`` token to its value.  Tokens without a colon are
        ignored; only the first colon of a token separates key from value.
    """
    parts = {}
    for token in str(payment_line).split():
        key, colon, value = token.partition(':')
        if colon:
            parts[key.upper()] = value
    return parts


def _apply_payment_line_overrides(fields):
    """Apply the simulated payment_line override step to ``fields`` in place.

    This is the stand-in these tests use for the override step: OCR and
    Amount are replaced by the payment_line values when the payment_line
    carries a non-empty value for them.  Account fields (Bankgiro/Plusgiro)
    are deliberately never written.

    Args:
        fields: dict of detected fields, optionally containing
            ``'payment_line'``.

    Returns:
        The parsed payment_line parts (see ``_parse_payment_line``).
    """
    parts = _parse_payment_line(fields.get('payment_line'))
    for field_name, part_key in (('OCR', 'OCR'), ('Amount', 'AMOUNT')):
        override = parts.get(part_key)
        if override:
            fields[field_name] = override
    return parts


def _digits_only(value):
    """Return ``value`` with every non-digit character removed."""
    return re.sub(r'\D', '', value)


def _normalize_amount(amount):
    """Normalize an amount string for comparison.

    All non-digit characters (decimal comma/dot, thousand separators) are
    removed, then a trailing ``00`` (zero öre) is dropped when more than two
    digits remain.
    """
    digits = re.sub(r'[^\d]', '', amount)
    if len(digits) > 2 and digits.endswith('00'):
        digits = digits[:-2]
    return digits


def _make_pipeline_with_mocked_extractors():
    """Build an InferencePipeline without running its real ``__init__``.

    ``__init__`` is patched out only for the duration of construction; the
    three business-feature collaborators are then replaced with MagicMocks
    so individual tests can configure return values or side effects.
    """
    with patch.object(InferencePipeline, '__init__', lambda self, **kwargs: None):
        pipeline = InferencePipeline()
    pipeline.line_items_extractor = MagicMock()
    pipeline.vat_extractor = MagicMock()
    pipeline.vat_validator = MagicMock()
    return pipeline


# ---------------------------------------------------------------------------
# Dataclass tests
# ---------------------------------------------------------------------------

class TestCrossValidationResult:
    """Tests for CrossValidationResult dataclass."""

    def test_default_values(self):
        """Test default values."""
        cv = CrossValidationResult()
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.payment_line_account is None
        assert cv.payment_line_account_type is None

    def test_attributes(self):
        """Test setting attributes."""
        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.amount_match = True
        cv.payment_line_ocr = '12345678901'
        cv.payment_line_amount = '100'
        cv.details = ['OCR match', 'Amount match']

        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert cv.payment_line_ocr == '12345678901'
        assert 'OCR match' in cv.details


class TestInferenceResult:
    """Tests for InferenceResult dataclass."""

    def test_default_fields(self):
        """Test default field values."""
        result = InferenceResult()
        assert result.fields == {}
        assert result.confidence == {}
        assert result.errors == []

    def test_set_fields(self):
        """Test setting field values."""
        result = InferenceResult()
        result.fields = {
            'OCR': '12345678901',
            'Amount': '100',
            'Bankgiro': '782-1713'
        }
        result.confidence = {
            'OCR': 0.95,
            'Amount': 0.90,
            'Bankgiro': 0.88
        }

        assert result.fields['OCR'] == '12345678901'
        assert result.fields['Amount'] == '100'
        assert result.fields['Bankgiro'] == '782-1713'

    def test_cross_validation_assignment(self):
        """Test cross validation assignment."""
        result = InferenceResult()
        result.fields = {'OCR': '12345678901'}

        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.payment_line_ocr = '12345678901'
        result.cross_validation = cv

        assert result.cross_validation is not None
        assert result.cross_validation.ocr_match is True


# ---------------------------------------------------------------------------
# Simulated cross-validation logic
# ---------------------------------------------------------------------------

class TestPaymentLineParsingInPipeline:
    """Tests for payment_line parsing in cross-validation."""

    def test_parse_payment_line_format(self):
        """Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx"""
        # Simulate the parsing logic from pipeline
        pl_parts = _parse_payment_line("OCR:310196187399952 Amount:11699 BG:782-1713")

        assert pl_parts.get('OCR') == '310196187399952'
        assert pl_parts.get('AMOUNT') == '11699'
        assert pl_parts.get('BG') == '782-1713'

    def test_parse_payment_line_with_plusgiro(self):
        """Test parsing with Plusgiro."""
        pl_parts = _parse_payment_line("OCR:12345678901 Amount:500 PG:1234567-8")

        assert pl_parts.get('OCR') == '12345678901'
        assert pl_parts.get('PG') == '1234567-8'
        assert pl_parts.get('BG') is None

    def test_parse_empty_payment_line(self):
        """Test parsing empty payment_line."""
        pl_parts = _parse_payment_line("")

        assert pl_parts.get('OCR') is None
        assert pl_parts.get('AMOUNT') is None


class TestOCROverride:
    """Tests for OCR override logic."""

    def test_ocr_override_when_different(self):
        """Test OCR is overridden when payment_line value differs."""
        result = InferenceResult()
        result.fields = {
            'OCR': 'wrong_ocr_12345',
            'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713'
        }

        # Simulate the override logic: detected OCR is replaced by the
        # payment_line OCR.
        _apply_payment_line_overrides(result.fields)

        assert result.fields['OCR'] == 'correct_ocr_67890'

    def test_ocr_no_override_when_no_payment_line(self):
        """Test OCR is not overridden when no payment_line."""
        result = InferenceResult()
        result.fields = {'OCR': 'original_ocr_12345'}

        # No payment_line present: running the override step must be a no-op.
        _apply_payment_line_overrides(result.fields)

        assert result.fields['OCR'] == 'original_ocr_12345'


class TestAmountOverride:
    """Tests for Amount override logic."""

    def test_amount_override(self):
        """Test Amount is overridden from payment_line."""
        result = InferenceResult()
        result.fields = {
            'Amount': '999.00',
            'payment_line': 'OCR:12345 Amount:11699 BG:782-1713'
        }

        _apply_payment_line_overrides(result.fields)

        assert result.fields['Amount'] == '11699'


class TestBankgiroComparison:
    """Tests for Bankgiro comparison (no override)."""

    def test_bankgiro_match(self):
        """Test Bankgiro match detection."""
        detected_bankgiro = '782-1713'
        payment_line_account = '782-1713'

        # Comparison is done on digits only so formatting does not matter.
        det_digits = _digits_only(detected_bankgiro)
        pl_digits = _digits_only(payment_line_account)

        assert det_digits == pl_digits
        assert det_digits == '7821713'

    def test_bankgiro_mismatch(self):
        """Test Bankgiro mismatch detection."""
        detected_bankgiro = '782-1713'
        payment_line_account = '123-4567'

        det_digits = _digits_only(detected_bankgiro)
        pl_digits = _digits_only(payment_line_account)

        assert det_digits != pl_digits

    def test_bankgiro_not_overridden(self):
        """Test that Bankgiro is NOT overridden from payment_line."""
        result = InferenceResult()
        result.fields = {
            'Bankgiro': '999-9999',  # Different value
            'payment_line': 'OCR:12345 Amount:100 BG:782-1713'
        }

        # Bankgiro should NOT be overridden (per current logic);
        # it is only compared for validation.
        original_bankgiro = result.fields['Bankgiro']

        # Run the simulated override step: it skips Bankgiro, so the
        # detected value must survive even though the payment_line differs.
        _apply_payment_line_overrides(result.fields)

        assert result.fields['Bankgiro'] == '999-9999'
        assert result.fields['Bankgiro'] == original_bankgiro


class TestValidationScoring:
    """Tests for validation scoring logic."""

    def test_all_fields_match(self):
        """Test score when all fields match."""
        matches = [True, True, True]  # OCR, Amount, Bankgiro
        match_count = sum(1 for m in matches if m)
        total = len(matches)

        assert match_count == 3
        assert total == 3

    def test_partial_match(self):
        """Test score with partial matches."""
        matches = [True, True, False]  # OCR match, Amount match, Bankgiro mismatch
        match_count = sum(1 for m in matches if m)

        assert match_count == 2

    def test_no_matches(self):
        """Test score when nothing matches."""
        matches = [False, False, False]
        match_count = sum(1 for m in matches if m)

        assert match_count == 0

    def test_only_count_present_fields(self):
        """Test that only present fields are counted."""
        # When invoice has both BG and PG but payment_line only has BG,
        # we should only count BG in validation
        payment_line_account_type = 'bankgiro'
        bankgiro_match = True
        plusgiro_match = None  # Not compared because payment_line doesn't have PG

        matches = []
        if payment_line_account_type == 'bankgiro' and bankgiro_match is not None:
            matches.append(bankgiro_match)
        elif payment_line_account_type == 'plusgiro' and plusgiro_match is not None:
            matches.append(plusgiro_match)

        assert len(matches) == 1
        assert matches[0] is True


class TestAmountNormalization:
    """Tests for amount normalization for comparison."""

    def test_normalize_amount_with_comma(self):
        """Test normalizing amount with comma decimal."""
        assert _normalize_amount("11699,00") == '11699'

    def test_normalize_amount_with_dot(self):
        """Test normalizing amount with dot decimal."""
        assert _normalize_amount("11699.00") == '11699'

    def test_normalize_amount_with_space_separator(self):
        """Test normalizing amount with space thousand separator."""
        assert _normalize_amount("11 699,00") == '11699'


# ---------------------------------------------------------------------------
# Business features (line items, VAT)
# ---------------------------------------------------------------------------

class TestBusinessFeatures:
    """Tests for business invoice features (line items, VAT, validation)."""

    def test_inference_result_has_business_fields(self):
        """Test that InferenceResult has business feature fields."""
        result = InferenceResult()
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

    def test_to_json_without_business_features(self):
        """Test to_json works without business features."""
        result = InferenceResult()
        result.fields = {'InvoiceNumber': '12345'}
        result.confidence = {'InvoiceNumber': 0.95}

        json_result = result.to_json()

        assert json_result['InvoiceNumber'] == '12345'
        assert 'line_items' not in json_result
        assert 'vat_summary' not in json_result
        assert 'vat_validation' not in json_result

    def test_to_json_with_line_items(self):
        """Test to_json includes line items when present."""
        from backend.table.line_items_extractor import LineItem, LineItemsResult

        result = InferenceResult()
        result.fields = {'Amount': '12500.00'}
        result.line_items = LineItemsResult(
            items=[
                LineItem(
                    row_index=0,
                    description="Product A",
                    quantity="2",
                    unit_price="5000,00",
                    amount="10000,00",
                    vat_rate="25",
                    confidence=0.9
                )
            ],
            header_row=["Beskrivning", "Antal", "Pris", "Belopp", "Moms"],
            # Placeholder markup; not inspected by the assertions below.
            raw_html="..."
        )

        json_result = result.to_json()

        assert 'line_items' in json_result
        assert len(json_result['line_items']['items']) == 1
        assert json_result['line_items']['items'][0]['description'] == "Product A"
        assert json_result['line_items']['items'][0]['amount'] == "10000,00"

    def test_to_json_with_vat_summary(self):
        """Test to_json includes VAT summary when present."""
        from backend.vat.vat_extractor import VATBreakdown, VATSummary

        result = InferenceResult()
        result.vat_summary = VATSummary(
            breakdowns=[
                VATBreakdown(rate=25.0, base_amount="10000,00", vat_amount="2500,00", source="regex")
            ],
            total_excl_vat="10000,00",
            total_vat="2500,00",
            total_incl_vat="12500,00",
            confidence=0.9
        )

        json_result = result.to_json()

        assert 'vat_summary' in json_result
        assert len(json_result['vat_summary']['breakdowns']) == 1
        assert json_result['vat_summary']['breakdowns'][0]['rate'] == 25.0
        assert json_result['vat_summary']['total_incl_vat'] == "12500,00"

    def test_to_json_with_vat_validation(self):
        """Test to_json includes VAT validation when present."""
        from backend.validation.vat_validator import VATValidationResult, MathCheckResult

        result = InferenceResult()
        result.vat_validation = VATValidationResult(
            is_valid=True,
            confidence_score=0.95,
            math_checks=[
                MathCheckResult(
                    rate=25.0,
                    base_amount=10000.0,
                    expected_vat=2500.0,
                    actual_vat=2500.0,
                    is_valid=True,
                    tolerance=0.5
                )
            ],
            total_check=True,
            line_items_vs_summary=True,
            amount_consistency=True,
            needs_review=False,
            review_reasons=[]
        )

        json_result = result.to_json()

        assert 'vat_validation' in json_result
        assert json_result['vat_validation']['is_valid'] is True
        assert json_result['vat_validation']['confidence_score'] == 0.95
        assert len(json_result['vat_validation']['math_checks']) == 1


class TestBusinessFeaturesAvailable:
    """Tests for BUSINESS_FEATURES_AVAILABLE flag."""

    def test_business_features_available(self):
        """Test that business features are available."""
        from backend.pipeline import BUSINESS_FEATURES_AVAILABLE

        assert BUSINESS_FEATURES_AVAILABLE is True


class TestExtractBusinessFeaturesErrorHandling:
    """Tests for _extract_business_features error handling."""

    def test_pipeline_module_has_logger(self):
        """Test that pipeline module defines logger correctly."""
        from backend.pipeline import pipeline

        assert hasattr(pipeline, 'logger')
        assert pipeline.logger is not None

    def test_extract_business_features_logs_errors(self):
        """Test that _extract_business_features logs detailed errors."""
        # Create a pipeline with mocked extractors that raise an exception
        pipeline = _make_pipeline_with_mocked_extractors()

        # Make line_items_extractor raise an exception
        test_error = ValueError("Test error message")
        pipeline.line_items_extractor.extract_from_pdf.side_effect = test_error

        result = InferenceResult()

        # Call the method
        pipeline._extract_business_features("/fake/path.pdf", result, "full text")

        # Verify error was captured with type info
        assert len(result.errors) == 1
        assert "ValueError" in result.errors[0]
        assert "Test error message" in result.errors[0]

    def test_extract_business_features_handles_numeric_exceptions(self):
        """Test that _extract_business_features handles non-standard exceptions."""
        pipeline = _make_pipeline_with_mocked_extractors()

        # Simulate an exception that might have a numeric value (like exit codes)
        class NumericException(Exception):
            def __str__(self):
                return "0"

        pipeline.line_items_extractor.extract_from_pdf.side_effect = NumericException()

        result = InferenceResult()
        pipeline._extract_business_features("/fake/path.pdf", result, "full text")

        # Should include type name even when str(e) is just "0"
        assert len(result.errors) == 1
        assert "NumericException" in result.errors[0]


if __name__ == '__main__':
    pytest.main([__file__, '-v'])