"""
Tests for Inference Pipeline Normalizers

These normalizers extract and validate field values from OCR text.
They are different from shared/normalize/normalizers which generate
matching variants from known values.
"""

from unittest.mock import patch

import pytest

from inference.pipeline.normalizers import (
    NormalizationResult,
    InvoiceNumberNormalizer,
    OcrNumberNormalizer,
    BankgiroNormalizer,
    PlusgiroNormalizer,
    AmountNormalizer,
    EnhancedAmountNormalizer,
    DateNormalizer,
    EnhancedDateNormalizer,
    SupplierOrgNumberNormalizer,
    create_normalizer_registry,
)


class TestNormalizationResult:
    """Tests for NormalizationResult dataclass."""

    def test_success(self):
        """success() yields a valid result with no error."""
        result = NormalizationResult.success("123")
        assert result.value == "123"
        assert result.is_valid is True
        assert result.error is None

    def test_success_with_warning(self):
        """success_with_warning() stays valid but carries the message in ``error``."""
        result = NormalizationResult.success_with_warning("123", "Warning message")
        assert result.value == "123"
        assert result.is_valid is True
        assert result.error == "Warning message"

    def test_failure(self):
        """failure() yields no value, is invalid, and carries the error message."""
        result = NormalizationResult.failure("Error message")
        assert result.value is None
        assert result.is_valid is False
        assert result.error == "Error message"

    def test_to_tuple(self):
        """to_tuple() unpacks as (value, is_valid, error)."""
        result = NormalizationResult.success("123")
        value, is_valid, error = result.to_tuple()
        assert value == "123"
        assert is_valid is True
        assert error is None


class TestInvoiceNumberNormalizer:
    """Tests for InvoiceNumberNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return InvoiceNumberNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "InvoiceNumber"

    def test_alphanumeric(self, normalizer):
        """A letter-prefixed invoice number is returned unchanged."""
        result = normalizer.normalize("A3861")
        assert result.value == "A3861"
        assert result.is_valid is True

    def test_with_prefix(self, normalizer):
        """A value is extracted from text carrying a 'Faktura:' label (loose check)."""
        result = normalizer.normalize("Faktura: INV12345")
        assert result.value is not None
        assert "INV" in result.value or "12345" in result.value

    def test_year_prefix(self, normalizer):
        """A year-dash-number invoice number is kept including the dash."""
        result = normalizer.normalize("2024-12345")
        assert result.value == "2024-12345"
        assert result.is_valid is True

    def test_numeric_only(self, normalizer):
        """A purely numeric invoice number is returned unchanged."""
        result = normalizer.normalize("12345678")
        assert result.value == "12345678"
        assert result.is_valid is True

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_callable(self, normalizer):
        """Calling the normalizer instance directly produces a result."""
        result = normalizer("A3861")
        assert result.value == "A3861"

    def test_skip_date_like_sequence(self, normalizer):
        """Test that 8-digit sequences starting with 20 (dates) are skipped."""
        result = normalizer.normalize("Invoice 12345 Date 20240115")
        assert result.value == "12345"

    def test_skip_long_ocr_sequence(self, normalizer):
        """Test that sequences > 10 digits are skipped."""
        result = normalizer.normalize("Invoice 54321 OCR 12345678901234")
        assert result.value == "54321"

    def test_fallback_extraction(self, normalizer):
        """Test fallback to digit extraction."""
        # This matches Pattern 3 (short digit sequence 3-10 digits)
        result = normalizer.normalize("Some text with number 123 embedded")
        assert result.value == "123"
        assert result.is_valid is True

    def test_no_valid_sequence(self, normalizer):
        """Test failure when no valid sequence found."""
        result = normalizer.normalize("no numbers here")
        assert result.is_valid is False
        assert "Cannot extract" in result.error


class TestOcrNumberNormalizer:
    """Tests for OcrNumberNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return OcrNumberNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "OCR"

    def test_standard_ocr(self, normalizer):
        """A 15-digit OCR reference is returned unchanged."""
        result = normalizer.normalize("310196187399952")
        assert result.value == "310196187399952"
        assert result.is_valid is True

    def test_with_spaces(self, normalizer):
        """Grouping spaces are removed from the OCR reference."""
        result = normalizer.normalize("3101 9618 7399 952")
        assert result.value == "310196187399952"
        assert " " not in result.value

    def test_too_short(self, normalizer):
        """A 4-digit input is rejected as an OCR reference."""
        result = normalizer.normalize("1234")
        assert result.is_valid is False

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False


class TestBankgiroNormalizer:
    """Tests for BankgiroNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return BankgiroNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "Bankgiro"

    def test_7_digit_format(self, normalizer):
        """A 3-4 dashed Bankgiro number is returned unchanged."""
        result = normalizer.normalize("782-1713")
        assert result.value == "782-1713"
        assert result.is_valid is True

    def test_8_digit_format(self, normalizer):
        """A 4-4 dashed Bankgiro number is returned unchanged."""
        result = normalizer.normalize("5393-9484")
        assert result.value == "5393-9484"
        assert result.is_valid is True

    def test_without_dash(self, normalizer):
        """An undashed 7-digit number is returned in dashed form."""
        result = normalizer.normalize("7821713")
        assert result.value is not None
        assert "-" in result.value

    def test_with_prefix(self, normalizer):
        """A 'Bankgiro:' label does not end up in the value."""
        result = normalizer.normalize("Bankgiro: 782-1713")
        assert result.value == "782-1713"

    def test_invalid_too_short(self, normalizer):
        result = normalizer.normalize("123")
        assert result.is_valid is False

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_invalid_luhn_with_warning(self, normalizer):
        """Test BG with invalid Luhn checksum returns warning."""
        # 1234-5679 has invalid Luhn
        result = normalizer.normalize("1234-5679")
        assert result.value is not None
        assert "Luhn checksum failed" in (result.error or "")

    def test_pg_format_excluded(self, normalizer):
        """Test that PG format (X-X) is not matched as BG."""
        result = normalizer.normalize("1234567-8")  # PG format
        assert result.is_valid is False

    def test_raw_7_digits_fallback(self, normalizer):
        """Test fallback to raw 7 digits without dash."""
        result = normalizer.normalize("BG number is 7821713 here")
        assert result.value is not None
        assert "-" in result.value

    def test_raw_8_digits_invalid_luhn(self, normalizer):
        """Test raw 8 digits with invalid Luhn."""
        result = normalizer.normalize("12345679")  # 8 digits, invalid Luhn
        assert result.value is not None
        assert "Luhn" in (result.error or "")


class TestPlusgiroNormalizer:
    """Tests for PlusgiroNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return PlusgiroNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "Plusgiro"

    def test_standard_format(self, normalizer):
        """A dashed 7+1 Plusgiro number keeps its dash."""
        result = normalizer.normalize("1234567-8")
        assert result.value is not None
        assert "-" in result.value

    def test_short_format(self, normalizer):
        """A minimal 2+1 dashed number still yields a value."""
        result = normalizer.normalize("12-3")
        assert result.value is not None

    def test_without_dash(self, normalizer):
        """An undashed 8-digit number is returned in dashed form."""
        result = normalizer.normalize("12345678")
        assert result.value is not None
        assert "-" in result.value

    def test_with_spaces(self, normalizer):
        """A space-grouped dashed number still yields a value."""
        result = normalizer.normalize("486 98 63-6")
        assert result.value is not None

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_invalid_luhn_with_warning(self, normalizer):
        """Test PG with invalid Luhn returns warning."""
        result = normalizer.normalize("1234567-9")  # Invalid Luhn
        assert result.value is not None
        assert "Luhn checksum failed" in (result.error or "")

    def test_all_digits_fallback(self, normalizer):
        """Test fallback to all digits extraction."""
        result = normalizer.normalize("PG 12345")
        assert result.value is not None

    def test_digit_sequence_fallback(self, normalizer):
        """Test finding digit sequence in text."""
        result = normalizer.normalize("Account number: 54321")
        assert result.value is not None

    def test_too_long_fails(self, normalizer):
        """Test that > 8 digits fails (no PG format found)."""
        result = normalizer.normalize("123456789")  # 9 digits, too long
        # PG is 2-8 digits, so 9 digits is invalid
        assert result.is_valid is False

    def test_no_digits_fails(self, normalizer):
        """Test failure when no valid digits found."""
        result = normalizer.normalize("no numbers")
        assert result.is_valid is False

    def test_pg_display_format_valid_luhn(self, normalizer):
        """Test PG display format with valid Luhn checksum."""
        # 1000009 has valid Luhn checksum
        result = normalizer.normalize("PG: 100000-9")
        assert result.value == "100000-9"
        assert result.is_valid is True
        assert result.error is None  # No warning for valid Luhn

    def test_pg_all_digits_valid_luhn(self, normalizer):
        """Test all digits extraction with valid Luhn."""
        # When no PG format found, extract all digits
        # 10000008 has valid Luhn (8 digits)
        result = normalizer.normalize("PG number 10000008")
        assert result.value == "1000000-8"
        assert result.is_valid is True
        assert result.error is None

    def test_pg_digit_sequence_valid_luhn(self, normalizer):
        """Test digit sequence fallback with valid Luhn."""
        # Find word-bounded digit sequence
        # 1000017 has valid Luhn
        result = normalizer.normalize("Account: 1000017 registered")
        assert result.value == "100001-7"
        assert result.is_valid is True
        assert result.error is None

    def test_pg_digit_sequence_invalid_luhn(self, normalizer):
        """Test digit sequence fallback with invalid Luhn."""
        result = normalizer.normalize("Account: 12345678 registered")
        assert result.value == "1234567-8"
        assert result.is_valid is True
        assert "Luhn" in (result.error or "")

    def test_pg_digit_sequence_when_all_digits_too_long(self, normalizer):
        """Test digit sequence search when all_digits > 8 (lines 79-86)."""
        # Total digits > 8, so all_digits fallback fails
        # But there's a word-bounded 7-digit sequence with valid Luhn
        result = normalizer.normalize("PG is 1000017 but ID is 9999999999")
        assert result.value == "100001-7"
        assert result.is_valid is True
        assert result.error is None  # Valid Luhn

    def test_pg_digit_sequence_invalid_luhn_when_all_digits_too_long(self, normalizer):
        """Test digit sequence with invalid Luhn when all_digits > 8."""
        # Total digits > 8, word-bounded sequence has invalid Luhn
        result = normalizer.normalize("Account 12345 in document 987654321")
        assert result.value == "1234-5"
        assert result.is_valid is True
        assert "Luhn" in (result.error or "")


class TestAmountNormalizer:
    """Tests for AmountNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return AmountNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "Amount"

    def test_swedish_format(self, normalizer):
        """Space thousands separator and comma decimal are accepted."""
        result = normalizer.normalize("11 699,00")
        assert result.value is not None
        assert result.is_valid is True

    def test_with_currency(self, normalizer):
        """A trailing currency code does not prevent extraction."""
        result = normalizer.normalize("11 699,00 SEK")
        assert result.value is not None

    def test_dot_decimal(self, normalizer):
        """A dot-decimal amount is returned unchanged."""
        result = normalizer.normalize("1234.56")
        assert result.value == "1234.56"

    def test_integer_amount(self, normalizer):
        """An integer after a 'Belopp:' label yields a value."""
        result = normalizer.normalize("Belopp: 11699")
        assert result.value is not None

    def test_multiple_amounts_returns_last(self, normalizer):
        """With one amount per line, the last line's amount is returned."""
        result = normalizer.normalize("Subtotal: 100,00\nMoms: 25,00\nTotal: 125,00")
        assert result.value == "125.00"

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_empty_lines_skipped(self, normalizer):
        """Test that empty lines are skipped."""
        result = normalizer.normalize("\n\n100,00\n\n")
        assert result.value == "100.00"

    def test_simple_decimal_fallback(self, normalizer):
        """Test simple decimal pattern fallback."""
        result = normalizer.normalize("Price is 99.99 dollars")
        assert result.value == "99.99"

    # NOTE: this name must stay distinct from test_standalone_number_fallback
    # below. A duplicate method name silently replaces the earlier definition
    # in the class body, so pytest would never collect this case.
    def test_standalone_number_fallback_with_label(self, normalizer):
        """Test an integer after an "Amount" label (no colon) gets two decimals."""
        result = normalizer.normalize("Amount 12345")
        assert result.value == "12345.00"

    def test_no_amount_fails(self, normalizer):
        """Test failure when no amount found."""
        result = normalizer.normalize("no amount here")
        assert result.is_valid is False

    def test_value_error_in_amount_parsing(self, normalizer):
        """Test that ValueError in float conversion is handled."""
        # A pattern that matches but cannot be converted to float
        # This is hard to trigger since regex already validates digits
        result = normalizer.normalize("Amount: abc")
        assert result.is_valid is False

    def test_shared_validator_fallback(self, normalizer):
        """Test fallback to shared validator."""
        # Input that doesn't match primary pattern but shared validator handles
        result = normalizer.normalize("kr 1234")
        assert result.value is not None

    def test_simple_decimal_pattern_fallback(self, normalizer):
        """Test simple decimal pattern fallback."""
        # Pattern that requires simple_pattern fallback
        result = normalizer.normalize("Total: 99,99")
        assert result.value == "99.99"

    def test_integer_pattern_fallback(self, normalizer):
        """Test integer amount pattern fallback."""
        result = normalizer.normalize("Amount: 5000")
        assert result.value == "5000.00"

    def test_standalone_number_fallback(self, normalizer):
        """Test standalone number >= 3 digits fallback (lines 99-104)."""
        # No amount/belopp/summa/total keywords, no decimal - reaches standalone pattern
        result = normalizer.normalize("Reference 12500")
        assert result.value == "12500.00"

    def test_zero_amount_rejected(self, normalizer):
        """Test that zero amounts are rejected."""
        result = normalizer.normalize("0,00 kr")
        assert result.is_valid is False

    def test_negative_sign_ignored(self, normalizer):
        """Test that negative sign is ignored (code extracts digits only)."""
        result = normalizer.normalize("-100,00")
        # The pattern extracts "100,00" ignoring the negative sign
        assert result.value == "100.00"
        assert result.is_valid is True


class TestEnhancedAmountNormalizer:
    """Tests for EnhancedAmountNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return EnhancedAmountNormalizer()

    def test_labeled_amount(self, normalizer):
        """An amount after the Swedish 'Att betala:' label is extracted."""
        result = normalizer.normalize("Att betala: 1 234,56")
        assert result.value is not None
        assert result.is_valid is True

    def test_total_keyword(self, normalizer):
        """An amount after a 'Total:' label is extracted."""
        result = normalizer.normalize("Total: 9 999,00 kr")
        assert result.value is not None

    def test_ocr_correction(self, normalizer):
        """A letter O misread for a zero still yields a value."""
        # O -> 0 correction
        result = normalizer.normalize("1O23,45")
        assert result.value is not None

    def test_summa_keyword(self, normalizer):
        """Test Swedish 'summa' keyword."""
        result = normalizer.normalize("Summa: 5 000,00")
        assert result.value is not None

    def test_moms_lower_priority(self, normalizer):
        """Test that moms (VAT) has lower priority than summa/total."""
        # 'summa' keyword has priority 1.0, 'moms' has 0.8
        result = normalizer.normalize("Moms: 250,00 Summa: 1250,00")
        assert result.value == "1250.00"

    def test_decimal_pattern_fallback(self, normalizer):
        """Test decimal pattern extraction."""
        result = normalizer.normalize("Invoice for 1 234 567,89 kr")
        assert result.value is not None

    def test_no_amount_fails(self, normalizer):
        """Test failure when no amount found."""
        result = normalizer.normalize("no amount")
        assert result.is_valid is False

    def test_enhanced_empty_string(self, normalizer):
        """Test empty string fails."""
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_enhanced_shared_validator_fallback(self, normalizer):
        """Test fallback to shared validator when no labeled patterns match."""
        # Input that doesn't match labeled patterns but shared validator handles
        result = normalizer.normalize("kr 1234")
        assert result.value is not None

    def test_enhanced_decimal_pattern_fallback(self, normalizer):
        """Test Strategy 4 decimal pattern fallback."""
        # Input that bypasses labeled patterns and shared validator
        result = normalizer.normalize("Price: 1 234 567,89")
        assert result.value is not None

    def test_amount_out_of_range_rejected(self, normalizer):
        """Test that amounts >= 10,000,000 are rejected."""
        result = normalizer.normalize("Summa: 99 999 999,00")
        # Should fail since amount is >= 10,000,000
        assert result.is_valid is False

    def test_value_error_in_labeled_pattern(self, normalizer):
        """Test ValueError handling in labeled pattern parsing."""
        # This is defensive code that's hard to trigger
        result = normalizer.normalize("Total: abc,00")
        # Should fall through to other strategies
        assert result.is_valid is False

    def test_enhanced_decimal_pattern_multiple_amounts(self, normalizer):
        """Test Strategy 4 with multiple decimal amounts (lines 168-183)."""
        # Need input that bypasses labeled patterns AND shared validator
        # but has decimal pattern matches
        with patch(
            "inference.pipeline.normalizers.amount.FieldValidators.parse_amount",
            return_value=None,
        ):
            result = normalizer.normalize("Items: 100,00 and 200,00 and 300,00")
        # Should return max amount
        assert result.value == "300.00"
        assert result.is_valid is True


class TestDateNormalizer:
    """Tests for DateNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return DateNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "Date"

    def test_iso_format(self, normalizer):
        """An ISO date is returned unchanged."""
        result = normalizer.normalize("2026-01-31")
        assert result.value == "2026-01-31"
        assert result.is_valid is True

    def test_european_dot_format(self, normalizer):
        """DD.MM.YYYY is converted to ISO."""
        result = normalizer.normalize("31.01.2026")
        assert result.value == "2026-01-31"

    def test_european_slash_format(self, normalizer):
        """DD/MM/YYYY is converted to ISO."""
        result = normalizer.normalize("31/01/2026")
        assert result.value == "2026-01-31"

    def test_compact_format(self, normalizer):
        """YYYYMMDD is converted to ISO."""
        result = normalizer.normalize("20260131")
        assert result.value == "2026-01-31"

    def test_invalid_date(self, normalizer):
        result = normalizer.normalize("not a date")
        assert result.is_valid is False

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_dot_format_ymd(self, normalizer):
        """Test YYYY.MM.DD format."""
        result = normalizer.normalize("2025.08.29")
        assert result.value == "2025-08-29"

    def test_invalid_date_value_continues(self, normalizer):
        """Test that invalid date values are skipped."""
        result = normalizer.normalize("2025-13-45")  # Invalid month/day
        assert result.is_valid is False

    def test_year_out_of_range(self, normalizer):
        """Test that years outside 2000-2100 are rejected."""
        result = normalizer.normalize("1999-01-01")
        assert result.is_valid is False

    def test_fallback_pattern_single_digit_day(self, normalizer):
        """Test fallback pattern with single digit day (European slash format)."""
        # The shared validator returns None for single digit day like 8/12/2025
        # So it falls back to the PATTERNS list (European DD/MM/YYYY)
        result = normalizer.normalize("8/12/2025")
        assert result.value == "2025-12-08"
        assert result.is_valid is True

    def test_fallback_pattern_with_mock(self, normalizer):
        """Test fallback PATTERNS when shared validator returns None (line 83)."""
        with patch(
            "inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
            return_value=None,
        ):
            result = normalizer.normalize("2025-08-29")
        assert result.value == "2025-08-29"
        assert result.is_valid is True


class TestEnhancedDateNormalizer:
    """Tests for EnhancedDateNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return EnhancedDateNormalizer()

    def test_swedish_text_date(self, normalizer):
        """A 'DD <month name> YYYY' date is converted to ISO."""
        result = normalizer.normalize("29 december 2024")
        assert result.value == "2024-12-29"
        assert result.is_valid is True

    def test_swedish_abbreviated(self, normalizer):
        """An abbreviated month name is accepted."""
        result = normalizer.normalize("15 jan 2025")
        assert result.value == "2025-01-15"

    def test_ocr_correction(self, normalizer):
        """A letter O misread for a zero in the year is corrected."""
        # O -> 0 correction
        result = normalizer.normalize("2O26-01-31")
        assert result.value == "2026-01-31"

    def test_empty_string(self, normalizer):
        """Test empty string fails."""
        result = normalizer.normalize("")
        assert result.is_valid is False

    # Parametrized so each month is reported on its own instead of the first
    # failing month hiding the rest.
    @pytest.mark.parametrize(
        ("text", "expected"),
        [
            ("15 januari 2025", "2025-01-15"),
            ("15 februari 2025", "2025-02-15"),
            ("15 mars 2025", "2025-03-15"),
            ("15 maj 2025", "2025-05-15"),
            ("15 juni 2025", "2025-06-15"),
            ("15 september 2025", "2025-09-15"),
            ("15 november 2025", "2025-11-15"),
            ("15 december 2025", "2025-12-15"),
        ],
    )
    def test_swedish_months(self, normalizer, text, expected):
        """Test Swedish month names that work with OCR correction.

        Note: OCRCorrections.correct_digits corrupts some month names:
        - april -> apr11, juli -> ju11, augusti -> augu571, oktober -> ok706er
        These months are excluded from this test.
        """
        result = normalizer.normalize(text)
        assert result.value == expected, f"Failed for {text}"

    def test_extended_ymd_slash(self, normalizer):
        """Test YYYY/MM/DD format."""
        result = normalizer.normalize("2025/08/29")
        assert result.value == "2025-08-29"

    def test_extended_dmy_dash(self, normalizer):
        """Test DD-MM-YYYY format."""
        result = normalizer.normalize("29-08-2025")
        assert result.value == "2025-08-29"

    def test_extended_compact(self, normalizer):
        """Test YYYYMMDD compact format."""
        result = normalizer.normalize("20250829")
        assert result.value == "2025-08-29"

    def test_invalid_swedish_month(self, normalizer):
        """Test invalid Swedish month name falls through."""
        result = normalizer.normalize("15 invalidmonth 2025")
        assert result.is_valid is False

    def test_invalid_extended_date_continues(self, normalizer):
        """Test that invalid dates in extended patterns are skipped."""
        result = normalizer.normalize("32-13-2025")  # Invalid day/month
        assert result.is_valid is False

    def test_swedish_pattern_invalid_date(self, normalizer):
        """Test Swedish pattern with invalid date (Feb 31) falls through.

        When shared validator returns an invalid date like 2025-02-31,
        is_valid_date returns False, so it tries Swedish pattern,
        which also fails due to invalid datetime.
        """
        result = normalizer.normalize("31 feb 2025")
        assert result.is_valid is False

    def test_swedish_pattern_year_out_of_range(self, normalizer):
        """Test Swedish pattern with year outside 2000-2100."""
        # Use abbreviated month to avoid OCR corruption
        result = normalizer.normalize("15 jan 1999")
        # is_valid_date returns False for 1999-01-15, falls through
        # Swedish pattern matches but year < 2000
        assert result.is_valid is False

    def test_ymd_compact_format_with_prefix(self, normalizer):
        """Test YYYYMMDD compact format with surrounding text."""
        # The compact pattern requires word boundaries
        result = normalizer.normalize("Date code: 20250315")
        assert result.value == "2025-03-15"

    def test_swedish_pattern_fallback_with_mock(self, normalizer):
        """Test Swedish pattern when shared validator returns None (line 170)."""
        with patch(
            "inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
            return_value=None,
        ):
            result = normalizer.normalize("15 maj 2025")
        assert result.value == "2025-05-15"
        assert result.is_valid is True

    def test_ymd_compact_fallback_with_mock(self, normalizer):
        """Test ymd_compact pattern when shared validator returns None (lines 187-192)."""
        with patch(
            "inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
            return_value=None,
        ):
            result = normalizer.normalize("20250315")
        assert result.value == "2025-03-15"
        assert result.is_valid is True


class TestSupplierOrgNumberNormalizer:
    """Tests for SupplierOrgNumberNormalizer."""

    @pytest.fixture
    def normalizer(self):
        return SupplierOrgNumberNormalizer()

    def test_field_name(self, normalizer):
        assert normalizer.field_name == "supplier_org_number"

    def test_standard_format(self, normalizer):
        """A dashed 6-4 org number is returned unchanged."""
        result = normalizer.normalize("516406-1102")
        assert result.value == "516406-1102"
        assert result.is_valid is True

    def test_with_prefix(self, normalizer):
        """An 'Org.nr' label does not end up in the value."""
        result = normalizer.normalize("Org.nr 516406-1102")
        assert result.value == "516406-1102"

    def test_without_dash(self, normalizer):
        """An undashed 10-digit org number is returned in 6-4 dashed form."""
        result = normalizer.normalize("5164061102")
        assert result.value == "516406-1102"

    def test_vat_format(self, normalizer):
        """An SE-prefixed VAT number yields a dashed value."""
        result = normalizer.normalize("SE556123456701")
        assert result.value is not None
        assert "-" in result.value

    def test_empty_string(self, normalizer):
        result = normalizer.normalize("")
        assert result.is_valid is False

    def test_10_consecutive_digits(self, normalizer):
        """Test 10 consecutive digits pattern."""
        result = normalizer.normalize("Company org 5164061102 registered")
        assert result.value == "516406-1102"

    def test_10_digits_starting_with_zero_accepted(self, normalizer):
        """Test that 10 digits starting with 0 are accepted by Pattern 1.

        Pattern 1 (NNNNNN-?NNNN) matches any 10 digits with optional dash.
        Only Pattern 3 (standalone 10 digits) validates first digit != 0.
        """
        result = normalizer.normalize("0164061102")
        assert result.is_valid is True
        assert result.value == "016406-1102"

    def test_no_org_number_fails(self, normalizer):
        """Test failure when no org number found."""
        result = normalizer.normalize("no org number here")
        assert result.is_valid is False


class TestNormalizerRegistry:
    """Tests for normalizer registry factory."""

    def test_create_registry(self):
        """The default registry has an entry for every expected field."""
        registry = create_normalizer_registry()
        expected_fields = (
            "InvoiceNumber",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "InvoiceDate",
            "InvoiceDueDate",
            "supplier_org_number",
        )
        # Collect every missing key so a failure reports all of them at once.
        missing = [field for field in expected_fields if field not in registry]
        assert not missing, f"Registry is missing fields: {missing}"

    def test_registry_with_enhanced(self):
        registry = create_normalizer_registry(use_enhanced=True)
        # Enhanced normalizers should be used for Amount and Date
        assert isinstance(registry["Amount"], EnhancedAmountNormalizer)
        assert isinstance(registry["InvoiceDate"], EnhancedDateNormalizer)

    def test_registry_without_enhanced(self):
        registry = create_normalizer_registry(use_enhanced=False)
        assert isinstance(registry["Amount"], AmountNormalizer)
        assert isinstance(registry["InvoiceDate"], DateNormalizer)
        # Guard against the enhanced variants being returned anyway (they would
        # still satisfy the isinstance checks above if they are subclasses).
        assert not isinstance(registry["Amount"], EnhancedAmountNormalizer)
        assert not isinstance(registry["InvoiceDate"], EnhancedDateNormalizer)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])