""" Tests for Inference Pipeline Tests the cross-validation logic between payment_line and detected fields: - OCR override from payment_line - Amount override from payment_line - Bankgiro/Plusgiro comparison (no override) - Validation scoring """ import pytest from unittest.mock import MagicMock, patch from src.inference.pipeline import InferencePipeline, InferenceResult, CrossValidationResult class TestCrossValidationResult: """Tests for CrossValidationResult dataclass.""" def test_default_values(self): """Test default values.""" cv = CrossValidationResult() assert cv.ocr_match is None assert cv.amount_match is None assert cv.bankgiro_match is None assert cv.plusgiro_match is None assert cv.payment_line_ocr is None assert cv.payment_line_amount is None assert cv.payment_line_account is None assert cv.payment_line_account_type is None def test_attributes(self): """Test setting attributes.""" cv = CrossValidationResult() cv.ocr_match = True cv.amount_match = True cv.payment_line_ocr = '12345678901' cv.payment_line_amount = '100' cv.details = ['OCR match', 'Amount match'] assert cv.ocr_match is True assert cv.amount_match is True assert cv.payment_line_ocr == '12345678901' assert 'OCR match' in cv.details class TestInferenceResult: """Tests for InferenceResult dataclass.""" def test_default_fields(self): """Test default field values.""" result = InferenceResult() assert result.fields == {} assert result.confidence == {} assert result.errors == [] def test_set_fields(self): """Test setting field values.""" result = InferenceResult() result.fields = { 'OCR': '12345678901', 'Amount': '100', 'Bankgiro': '782-1713' } result.confidence = { 'OCR': 0.95, 'Amount': 0.90, 'Bankgiro': 0.88 } assert result.fields['OCR'] == '12345678901' assert result.fields['Amount'] == '100' assert result.fields['Bankgiro'] == '782-1713' def test_cross_validation_assignment(self): """Test cross validation assignment.""" result = InferenceResult() result.fields = {'OCR': '12345678901'} cv = CrossValidationResult() cv.ocr_match = True cv.payment_line_ocr = '12345678901' result.cross_validation = cv assert result.cross_validation is not None assert result.cross_validation.ocr_match is True class TestPaymentLineParsingInPipeline: """Tests for payment_line parsing in cross-validation.""" def test_parse_payment_line_format(self): """Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx""" # Simulate the parsing logic from pipeline payment_line = "OCR:310196187399952 Amount:11699 BG:782-1713" pl_parts = {} for part in payment_line.split(): if ':' in part: key, value = part.split(':', 1) pl_parts[key.upper()] = value assert pl_parts.get('OCR') == '310196187399952' assert pl_parts.get('AMOUNT') == '11699' assert pl_parts.get('BG') == '782-1713' def test_parse_payment_line_with_plusgiro(self): """Test parsing with Plusgiro.""" payment_line = "OCR:12345678901 Amount:500 PG:1234567-8" pl_parts = {} for part in payment_line.split(): if ':' in part: key, value = part.split(':', 1) pl_parts[key.upper()] = value assert pl_parts.get('OCR') == '12345678901' assert pl_parts.get('PG') == '1234567-8' assert pl_parts.get('BG') is None def test_parse_empty_payment_line(self): """Test parsing empty payment_line.""" payment_line = "" pl_parts = {} for part in payment_line.split(): if ':' in part: key, value = part.split(':', 1) pl_parts[key.upper()] = value assert pl_parts.get('OCR') is None assert pl_parts.get('AMOUNT') is None class TestOCROverride: """Tests for OCR override logic.""" def test_ocr_override_when_different(self): """Test OCR is overridden when payment_line value differs.""" result = InferenceResult() result.fields = {'OCR': 'wrong_ocr_12345', 'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713'} # Simulate the override logic payment_line = result.fields.get('payment_line') pl_parts = {} for part in str(payment_line).split(): if ':' in part: key, value = part.split(':', 1) pl_parts[key.upper()] = value payment_line_ocr = pl_parts.get('OCR') # Override detected OCR with payment_line OCR if payment_line_ocr: result.fields['OCR'] = payment_line_ocr assert result.fields['OCR'] == 'correct_ocr_67890' def test_ocr_no_override_when_no_payment_line(self): """Test OCR is not overridden when no payment_line.""" result = InferenceResult() result.fields = {'OCR': 'original_ocr_12345'} # No payment_line, no override assert result.fields['OCR'] == 'original_ocr_12345' class TestAmountOverride: """Tests for Amount override logic.""" def test_amount_override(self): """Test Amount is overridden from payment_line.""" result = InferenceResult() result.fields = { 'Amount': '999.00', 'payment_line': 'OCR:12345 Amount:11699 BG:782-1713' } payment_line = result.fields.get('payment_line') pl_parts = {} for part in str(payment_line).split(): if ':' in part: key, value = part.split(':', 1) pl_parts[key.upper()] = value payment_line_amount = pl_parts.get('AMOUNT') if payment_line_amount: result.fields['Amount'] = payment_line_amount assert result.fields['Amount'] == '11699' class TestBankgiroComparison: """Tests for Bankgiro comparison (no override).""" def test_bankgiro_match(self): """Test Bankgiro match detection.""" import re detected_bankgiro = '782-1713' payment_line_account = '782-1713' det_digits = re.sub(r'\D', '', detected_bankgiro) pl_digits = re.sub(r'\D', '', payment_line_account) assert det_digits == pl_digits assert det_digits == '7821713' def test_bankgiro_mismatch(self): """Test Bankgiro mismatch detection.""" import re detected_bankgiro = '782-1713' payment_line_account = '123-4567' det_digits = re.sub(r'\D', '', detected_bankgiro) pl_digits = re.sub(r'\D', '', payment_line_account) assert det_digits != pl_digits def test_bankgiro_not_overridden(self): """Test that Bankgiro is NOT overridden from payment_line.""" result = InferenceResult() result.fields = { 'Bankgiro': '999-9999', # Different value 'payment_line': 'OCR:12345 Amount:100 BG:782-1713' } # Bankgiro should NOT be overridden (per current logic) # Only compared for validation original_bankgiro = result.fields['Bankgiro'] # The override logic explicitly skips Bankgiro # So we verify it remains unchanged assert result.fields['Bankgiro'] == '999-9999' assert result.fields['Bankgiro'] == original_bankgiro class TestValidationScoring: """Tests for validation scoring logic.""" def test_all_fields_match(self): """Test score when all fields match.""" matches = [True, True, True] # OCR, Amount, Bankgiro match_count = sum(1 for m in matches if m) total = len(matches) assert match_count == 3 assert total == 3 def test_partial_match(self): """Test score with partial matches.""" matches = [True, True, False] # OCR match, Amount match, Bankgiro mismatch match_count = sum(1 for m in matches if m) assert match_count == 2 def test_no_matches(self): """Test score when nothing matches.""" matches = [False, False, False] match_count = sum(1 for m in matches if m) assert match_count == 0 def test_only_count_present_fields(self): """Test that only present fields are counted.""" # When invoice has both BG and PG but payment_line only has BG, # we should only count BG in validation payment_line_account_type = 'bankgiro' bankgiro_match = True plusgiro_match = None # Not compared because payment_line doesn't have PG matches = [] if payment_line_account_type == 'bankgiro' and bankgiro_match is not None: matches.append(bankgiro_match) elif payment_line_account_type == 'plusgiro' and plusgiro_match is not None: matches.append(plusgiro_match) assert len(matches) == 1 assert matches[0] is True class TestAmountNormalization: """Tests for amount normalization for comparison.""" def test_normalize_amount_with_comma(self): """Test normalizing amount with comma decimal.""" import re amount = "11699,00" normalized = re.sub(r'[^\d]', '', amount) # Remove trailing zeros for öre if len(normalized) > 2 and normalized[-2:] == '00': normalized = normalized[:-2] assert normalized == '11699' def test_normalize_amount_with_dot(self): """Test normalizing amount with dot decimal.""" import re amount = "11699.00" normalized = re.sub(r'[^\d]', '', amount) if len(normalized) > 2 and normalized[-2:] == '00': normalized = normalized[:-2] assert normalized == '11699' def test_normalize_amount_with_space_separator(self): """Test normalizing amount with space thousand separator.""" import re amount = "11 699,00" normalized = re.sub(r'[^\d]', '', amount) if len(normalized) > 2 and normalized[-2:] == '00': normalized = normalized[:-2] assert normalized == '11699' if __name__ == '__main__': pytest.main([__file__, '-v'])