"""
Tests for advanced utility modules:

- FuzzyMatcher
- OCRCorrections
- ContextExtractor
"""
|
|
|
|
import pytest
|
|
from shared.utils.fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
|
|
from shared.utils.ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants
|
|
from shared.utils.context_extractor import ContextExtractor, extract_field_with_context
|
|
|
|
|
|
class TestFuzzyMatcher:
    """Unit tests for the FuzzyMatcher distance, similarity and match helpers."""

    # -- Levenshtein distance ---------------------------------------------

    def test_levenshtein_distance_identical(self):
        """Identical inputs are zero edits apart."""
        distance = FuzzyMatcher.levenshtein_distance("hello", "hello")
        assert distance == 0

    def test_levenshtein_distance_one_char(self):
        """A single substitution, deletion or insertion costs one edit."""
        single_edit_pairs = (
            ("hello", "hallo"),   # substitution
            ("hello", "hell"),    # deletion
            ("hello", "helloo"),  # insertion
        )
        for source, target in single_edit_pairs:
            assert FuzzyMatcher.levenshtein_distance(source, target) == 1

    def test_levenshtein_distance_multiple(self):
        """Several edits accumulate; an empty side costs the other's length."""
        expectations = (
            ("hello", "world", 4),
            ("", "hello", 5),
        )
        for source, target, expected in expectations:
            assert FuzzyMatcher.levenshtein_distance(source, target) == expected

    # -- Similarity ratios ------------------------------------------------

    def test_similarity_ratio_identical(self):
        """Identical strings are perfectly similar."""
        ratio = FuzzyMatcher.similarity_ratio("hello", "hello")
        assert ratio == 1.0

    def test_similarity_ratio_similar(self):
        """One differing character out of five lands in the 0.8-0.9 band."""
        lower, upper = 0.8, 0.9
        ratio = FuzzyMatcher.similarity_ratio("hello", "hallo")
        assert lower <= ratio <= upper

    def test_similarity_ratio_different(self):
        """Mostly different strings score below one half."""
        ratio = FuzzyMatcher.similarity_ratio("hello", "world")
        assert ratio < 0.5

    # -- OCR-aware similarity ---------------------------------------------

    def test_ocr_aware_similarity_exact(self):
        """An exact digit string scores a perfect 1.0."""
        score = FuzzyMatcher.ocr_aware_similarity("12345", "12345")
        assert score == 1.0

    def test_ocr_aware_similarity_ocr_error(self):
        """A letter O read in place of a zero still scores as a near match."""
        ocr_read, truth = "1234O", "12340"
        score = FuzzyMatcher.ocr_aware_similarity(ocr_read, truth)
        assert score >= 0.9

    def test_ocr_aware_similarity_multiple_errors(self):
        """Two confusions (l for 1, O for 0) keep the score high."""
        ocr_read, truth = "l234O", "12340"
        score = FuzzyMatcher.ocr_aware_similarity(ocr_read, truth)
        assert score >= 0.85

    # -- Digit matching ---------------------------------------------------

    def test_match_digits_exact(self):
        """Identical digit strings produce an exact, full-score match."""
        outcome = FuzzyMatcher.match_digits("12345", "12345")
        assert outcome.matched is True
        assert outcome.score == 1.0
        assert outcome.match_type == 'exact'

    def test_match_digits_with_separators(self):
        """Separators such as '-' are stripped before comparing digits."""
        outcome = FuzzyMatcher.match_digits("123-4567", "1234567")
        assert outcome.matched is True
        assert outcome.normalized_ocr == "1234567"

    def test_match_digits_ocr_error(self):
        """A misread O inside a long digit run still matches with a high score."""
        outcome = FuzzyMatcher.match_digits("556O234567", "5560234567")
        assert outcome.matched is True
        assert outcome.score >= 0.9

    # -- Amount matching --------------------------------------------------

    def test_match_amount_exact(self):
        """Identical amounts are a full-score match."""
        outcome = FuzzyMatcher.match_amount("1234.56", "1234.56")
        assert outcome.matched is True
        assert outcome.score == 1.0

    def test_match_amount_different_formats(self):
        """A Swedish decimal comma matches a US decimal point."""
        swedish, us = "1234,56", "1234.56"
        outcome = FuzzyMatcher.match_amount(swedish, us)
        assert outcome.matched is True
        assert outcome.score >= 0.99

    def test_match_amount_with_spaces(self):
        """Space thousand separators are tolerated."""
        outcome = FuzzyMatcher.match_amount("1 234,56", "1234.56")
        assert outcome.matched is True

    # -- Date matching ----------------------------------------------------

    def test_match_date_same_date_different_format(self):
        """ISO and day.month.year renderings of one date match."""
        outcome = FuzzyMatcher.match_date("2024-12-29", "29.12.2024")
        assert outcome.matched is True
        assert outcome.score >= 0.9

    def test_match_date_different_dates(self):
        """Adjacent but distinct dates do not match."""
        outcome = FuzzyMatcher.match_date("2024-12-29", "2024-12-30")
        assert outcome.matched is False

    # -- Generic string matching ------------------------------------------

    def test_match_string_exact(self):
        """Byte-identical strings are reported as an exact match."""
        outcome = FuzzyMatcher.match_string("Hello World", "Hello World")
        assert outcome.matched is True
        assert outcome.match_type == 'exact'

    def test_match_string_case_insensitive(self):
        """Case-only differences match via normalization."""
        outcome = FuzzyMatcher.match_string("HELLO", "hello")
        assert outcome.matched is True
        assert outcome.match_type == 'normalized'

    def test_match_string_ocr_corrected(self):
        """A lower-case l standing in for a 1 is still accepted."""
        outcome = FuzzyMatcher.match_string("5561234567", "556l234567")
        assert outcome.matched is True

    # -- Field routing and candidate search --------------------------------

    def test_match_field_routes_correctly(self):
        """match_field dispatches amount and date fields to suitable matchers."""
        routed_cases = (
            ("Amount", "1234.56", "1234,56"),
            ("InvoiceDate", "2024-12-29", "29.12.2024"),
        )
        for field_name, first, second in routed_cases:
            outcome = FuzzyMatcher.match_field(field_name, first, second)
            assert outcome.matched is True

    def test_find_best_match(self):
        """The exact candidate wins with a perfect score."""
        pool = ["12345", "12346", "99999"]
        best = FuzzyMatcher.find_best_match("12345", pool, "InvoiceNumber")

        assert best is not None
        assert best[0] == "12345"
        assert best[1].score == 1.0

    def test_find_best_match_no_match(self):
        """None is returned when no candidate clears the threshold."""
        pool = ["99999", "88888", "77777"]
        best = FuzzyMatcher.find_best_match("12345", pool, "InvoiceNumber")

        assert best is None
|
|
|
|
|
|
class TestOCRCorrections:
    """Unit tests for OCRCorrections and its module-level convenience helpers."""

    # -- Digit correction -------------------------------------------------

    def test_correct_digits_simple(self):
        """A single letter O inside a digit run becomes 0 and is recorded once."""
        fixed = OCRCorrections.correct_digits("556O23", aggressive=False)
        assert fixed.corrected == "556023"
        assert len(fixed.corrections_applied) == 1

    def test_correct_digits_multiple(self):
        """Both S->5 and l->1 are repaired and each repair is recorded."""
        fixed = OCRCorrections.correct_digits("5S6l23", aggressive=False)
        assert fixed.corrected == "556123"
        assert len(fixed.corrections_applied) == 2

    def test_correct_digits_aggressive(self):
        """Aggressive mode rewrites every confusable letter (A->4, B->8)."""
        fixed = OCRCorrections.correct_digits("AB123", aggressive=True)
        assert fixed.corrected == "48123"

    def test_correct_digits_non_aggressive(self):
        """Non-aggressive mode leaves characters that are already digits intact.

        A and B sit next to each other and are both in CHAR_TO_DIGIT, so they
        may or may not be rewritten; only the digit run is pinned here.
        """
        fixed = OCRCorrections.correct_digits("AB 123", aggressive=False)
        assert "123" in fixed.corrected

    # -- Variant generation -----------------------------------------------

    def test_generate_digit_variants(self):
        """Variants include the input itself plus at least one letter swap."""
        variants = OCRCorrections.generate_digit_variants("10")
        assert "10" in variants
        assert any(option in variants for option in ("1O", "l0"))

    def test_generate_digit_variants_limits(self):
        """A long input does not explode combinatorially."""
        variants = OCRCorrections.generate_digit_variants("1234567890")
        # The cap is roughly 100 but may be overshot slightly, so allow headroom.
        ceiling = 150
        assert len(variants) <= ceiling

    # -- Error classification ---------------------------------------------

    def test_is_likely_ocr_error(self):
        """Known digit/letter confusions are flagged; unrelated pairs are not."""
        confusable_pairs = (('0', 'O'), ('O', '0'), ('1', 'l'), ('5', 'S'))
        for first, second in confusable_pairs:
            assert OCRCorrections.is_likely_ocr_error(first, second) is True
        assert OCRCorrections.is_likely_ocr_error('A', 'Z') is False

    def test_count_potential_ocr_errors(self):
        """Mismatches are split into OCR-style confusions and other errors."""
        scenarios = (
            ("1O3", "103", 1, 0),  # O vs 0 is a known confusion pair
            ("1X3", "103", 0, 1),  # X vs 0 is not a known pair
        )
        for observed, reference, want_ocr, want_other in scenarios:
            ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors(
                observed, reference
            )
            assert ocr_errors == want_ocr
            assert other_errors == want_other

    def test_suggest_corrections(self):
        """The top suggestion for a digit field is the corrected string."""
        suggestions = OCRCorrections.suggest_corrections("556O23", expected_type='digit')
        assert len(suggestions) > 0

        top_suggestion = suggestions[0]
        assert top_suggestion[0] == "556023"

    # -- Convenience wrappers ---------------------------------------------

    def test_convenience_function_correct(self):
        """correct_ocr_digits returns the corrected string directly."""
        repaired = correct_ocr_digits("556O23")
        assert repaired == "556023"

    def test_convenience_function_variants(self):
        """generate_ocr_variants always includes the original input."""
        assert "10" in generate_ocr_variants("10")
|
|
|
|
|
|
class TestContextExtractor:
    """Tests for ContextExtractor class."""

    def test_extract_invoice_number_with_label(self):
        """Test extracting invoice number after label."""
        text = "Fakturanummer: 12345678"
        candidates = ContextExtractor.extract_with_label(text, "InvoiceNumber")

        assert len(candidates) > 0
        assert candidates[0].value == "12345678"
        assert candidates[0].extraction_method == 'label'

    def test_extract_invoice_number_swedish(self):
        """Test extracting with the Swedish label "Faktura nr"."""
        text = "Faktura nr: A12345"
        candidates = ContextExtractor.extract_with_label(text, "InvoiceNumber")

        assert len(candidates) > 0
        # Either the full alphanumeric number or its numeric tail is acceptable.
        assert any(c.value in ("A12345", "12345") for c in candidates)

    def test_extract_amount_with_label(self):
        """Test extracting amount after label."""
        text = "Att betala: 1 234,56"
        candidates = ContextExtractor.extract_with_label(text, "Amount")

        assert len(candidates) > 0

    def test_extract_amount_total(self):
        """Test extracting with total label."""
        text = "Total: 5678,90 kr"
        candidates = ContextExtractor.extract_with_label(text, "Amount")

        assert len(candidates) > 0

    def test_extract_date_with_label(self):
        """Test extracting date after label."""
        text = "Fakturadatum: 2024-12-29"
        candidates = ContextExtractor.extract_with_label(text, "InvoiceDate")

        assert len(candidates) > 0
        assert "2024-12-29" in candidates[0].value

    def test_extract_due_date(self):
        """Test extracting due date."""
        text = "Förfallodatum: 2025-01-15"
        candidates = ContextExtractor.extract_with_label(text, "InvoiceDueDate")

        assert len(candidates) > 0

    def test_extract_bankgiro(self):
        """Test extracting Bankgiro."""
        text = "Bankgiro: 1234-5678"
        candidates = ContextExtractor.extract_with_label(text, "Bankgiro")

        assert len(candidates) > 0
        # The hyphen may or may not be preserved in the extracted value.
        assert "1234-5678" in candidates[0].value or "12345678" in candidates[0].value

    def test_extract_plusgiro(self):
        """Test extracting Plusgiro."""
        text = "Plusgiro: 1234567-8"
        candidates = ContextExtractor.extract_with_label(text, "Plusgiro")

        assert len(candidates) > 0

    def test_extract_ocr(self):
        """Test extracting OCR number."""
        text = "OCR: 12345678901234"
        candidates = ContextExtractor.extract_with_label(text, "OCR")

        assert len(candidates) > 0
        assert candidates[0].value == "12345678901234"

    def test_extract_org_number(self):
        """Test extracting organization number."""
        text = "Org.nr: 556123-4567"
        candidates = ContextExtractor.extract_with_label(text, "supplier_organisation_number")

        assert len(candidates) > 0

    def test_extract_customer_number(self):
        """Test extracting customer number."""
        text = "Kundnummer: EMM 256-6"
        candidates = ContextExtractor.extract_with_label(text, "customer_number")

        assert len(candidates) > 0

    def test_extract_field_returns_sorted(self):
        """Test that extract_field returns candidates sorted by descending confidence."""
        text = "Fakturanummer: 12345 Invoice number: 67890"
        candidates = ContextExtractor.extract_field(text, "InvoiceNumber")

        # Check every adjacent pair, not just the first two. This is trivially
        # true for zero or one candidate.
        confidences = [c.confidence for c in candidates]
        assert confidences == sorted(confidences, reverse=True)

    def test_extract_best(self):
        """Test extract_best returns single best candidate."""
        text = "Fakturanummer: 12345678"
        best = ContextExtractor.extract_best(text, "InvoiceNumber")

        assert best is not None
        assert best.value == "12345678"

    def test_extract_best_no_match(self):
        """Test extract_best on text that carries no invoice number.

        Validation may still accept a spurious hit, so only the return contract
        is pinned: the result is either None or a candidate with a string value.
        """
        text = "No invoice information here"
        best = ContextExtractor.extract_best(text, "InvoiceNumber", validate=True)

        assert best is None or isinstance(best.value, str)

    def test_extract_all_fields(self):
        """Test extracting all fields from text."""
        text = """
        Fakturanummer: 12345
        Datum: 2024-12-29
        Belopp: 1234,56
        Bankgiro: 1234-5678
        """
        results = ContextExtractor.extract_all_fields(text)

        # Should find at least some fields
        assert len(results) > 0

    def test_identify_field_type(self):
        """Test identifying field type from context."""
        text = "Fakturanummer: 12345"
        field_type = ContextExtractor.identify_field_type(text, "12345")

        assert field_type == "InvoiceNumber"

    def test_convenience_function_extract(self):
        """Test convenience function."""
        text = "Fakturanummer: 12345678"
        value = extract_field_with_context(text, "InvoiceNumber")

        assert value == "12345678"
|
|
|
|
|
|
class TestIntegration:
    """Integration tests combining multiple modules."""

    def test_fuzzy_match_with_ocr_correction(self):
        """Test fuzzy matching with OCR correction."""
        # Simulate OCR error: 0 -> O
        ocr_text = "556O234567"
        expected = "5560234567"

        # First correct
        corrected = correct_ocr_digits(ocr_text)
        assert corrected == expected

        # Then match the raw OCR text, so the matcher itself must tolerate the error
        result = FuzzyMatcher.match_digits(ocr_text, expected)
        assert result.matched is True

    def test_context_extraction_with_fuzzy_match(self):
        """Test extracting a value and fuzzy matching it against the truth."""
        text = "Fakturanummer: 1234S678"  # S is OCR error for 5

        # Extract
        candidate = ContextExtractor.extract_best(text, "InvoiceNumber", validate=False)
        assert candidate is not None

        # Fuzzy match against expected
        result = FuzzyMatcher.match_string(candidate.value, "12345678")
        # Whether this counts as a match depends on the matcher's threshold, so
        # only the result's contract is pinned: a boolean verdict and a score
        # in the same [0, 1] range the other matcher tests rely on.
        assert isinstance(result.matched, bool)
        assert 0.0 <= result.score <= 1.0
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code rather than exiting. Re-raise it so a
    # direct `python <this file>` run reports failures with a non-zero status.
    raise SystemExit(pytest.main([__file__, "-v"]))
|