This commit is contained in:
Yaojia Wang
2026-02-12 23:06:00 +01:00
parent ad5ed46b4c
commit 58d36c8927
26 changed files with 3903 additions and 2551 deletions

View File

@@ -213,8 +213,8 @@ class TestNormalizeOCR:
assert ' ' not in result.value # Spaces should be removed
def test_short_ocr_invalid(self, normalizer):
"""Test that too short OCR is invalid."""
result = normalizer.normalize("123")
"""Test that single-digit OCR is invalid (min 2 digits)."""
result = normalizer.normalize("5")
assert result.is_valid is False

View File

@@ -100,6 +100,22 @@ class TestInvoiceNumberNormalizer:
result = normalizer.normalize("Invoice 54321 OCR 12345678901234")
assert result.value == "54321"
def test_year_not_extracted_when_real_number_exists(self, normalizer):
"""4-digit year should be skipped when a real invoice number is present."""
result = normalizer.normalize("Faktura 12345 Datum 2025")
assert result.value == "12345"
def test_year_2026_not_extracted(self, normalizer):
"""Year '2026' should not be preferred over a real invoice number."""
result = normalizer.normalize("Invoice 54321 Date 2026")
assert result.value == "54321"
def test_non_year_4_digit_still_matches(self, normalizer):
"""4-digit numbers that are NOT years should still match."""
result = normalizer.normalize("Invoice 3456")
assert result.value == "3456"
assert result.is_valid is True
def test_fallback_extraction(self, normalizer):
"""Test fallback to digit extraction."""
# This matches Pattern 3 (short digit sequence 3-10 digits)
@@ -107,6 +123,16 @@ class TestInvoiceNumberNormalizer:
assert result.value == "123"
assert result.is_valid is True
def test_amount_fragment_not_selected(self, normalizer):
"""Amount fragment '775' from '9 775,96' should lose to real invoice number."""
result = normalizer.normalize("9 775,96 Belopp Kontoutdragsnr 04862823")
assert result.value == "04862823"
def test_prefer_medium_length_over_shortest(self, normalizer):
"""Prefer 4-8 digit sequences over very short 3-digit ones."""
result = normalizer.normalize("Ref 999 Fakturanr 12345")
assert result.value == "12345"
def test_no_valid_sequence(self, normalizer):
"""Test failure when no valid sequence found."""
result = normalizer.normalize("no numbers here")
@@ -134,8 +160,21 @@ class TestOcrNumberNormalizer:
assert result.value == "310196187399952"
assert " " not in result.value
def test_4_digit_ocr_valid(self, normalizer):
"""4-digit OCR numbers like '3046' should be accepted."""
result = normalizer.normalize("3046")
assert result.is_valid is True
assert result.value == "3046"
def test_2_digit_ocr_valid(self, normalizer):
"""2-digit OCR numbers should be accepted."""
result = normalizer.normalize("42")
assert result.is_valid is True
assert result.value == "42"
def test_too_short(self, normalizer):
result = normalizer.normalize("1234")
"""Single-digit OCR should be rejected."""
result = normalizer.normalize("5")
assert result.is_valid is False
def test_empty_string(self, normalizer):
@@ -477,6 +516,38 @@ class TestAmountNormalizer:
assert result.value == "100.00"
assert result.is_valid is True
def test_astronomical_amount_rejected(self, normalizer):
"""IBAN digits should NOT produce astronomical amounts (>10M)."""
# IBAN "SE14120000001201138650" contains long digit sequences
# The standalone fallback pattern should not extract these as amounts
result = normalizer.normalize("SE14120000001201138650")
if result.is_valid:
assert float(result.value) < 10_000_000
def test_large_valid_amount_accepted(self, normalizer):
"""Valid large amount like 108000,00 should be accepted."""
result = normalizer.normalize("108000,00")
assert result.value == "108000.00"
assert result.is_valid is True
def test_standalone_iban_digits_rejected(self, normalizer):
"""Very long digit sequence (IBAN fragment) should not produce >10M."""
result = normalizer.normalize("1036149234823114")
if result.is_valid:
assert float(result.value) < 10_000_000
def test_main_pattern_rejects_over_10m(self, normalizer):
"""Main regex path should reject amounts over 10M (e.g. IBAN-like digits)."""
result = normalizer.normalize("Belopp 81648164,00 kr")
# 81648164.00 > 10M, should be rejected
assert not result.is_valid or float(result.value) < 10_000_000
def test_main_pattern_accepts_under_10m(self, normalizer):
"""Main regex path should accept valid amounts under 10M."""
result = normalizer.normalize("Summa 999999,99 kr")
assert result.value == "999999.99"
assert result.is_valid is True
class TestEnhancedAmountNormalizer:
"""Tests for EnhancedAmountNormalizer."""

View File

@@ -670,5 +670,387 @@ class TestProcessPdfTokenPath:
assert call_args[3] == 100 # image height
class TestDpiPassthrough:
"""Tests for DPI being passed from pipeline to FieldExtractor (Bug 1)."""
def test_field_extractor_receives_pipeline_dpi(self):
"""FieldExtractor should receive the pipeline's DPI, not default to 300."""
with patch('backend.pipeline.pipeline.YOLODetector'):
with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
InferencePipeline(
model_path='/fake/model.pt',
dpi=150,
use_gpu=False,
)
mock_fe_cls.assert_called_once_with(
ocr_lang='en', use_gpu=False, dpi=150
)
def test_field_extractor_receives_default_dpi(self):
"""When dpi=300 (default), FieldExtractor should also get 300."""
with patch('backend.pipeline.pipeline.YOLODetector'):
with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
InferencePipeline(
model_path='/fake/model.pt',
dpi=300,
use_gpu=False,
)
mock_fe_cls.assert_called_once_with(
ocr_lang='en', use_gpu=False, dpi=300
)
class TestFallbackPatternExtraction:
"""Tests for _extract_with_patterns fallback regex (Bugs 2, 3)."""
def _make_pipeline_with_patterns(self):
"""Create pipeline with mocked internals for pattern testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.dpi = 150
p.enable_fallback = True
return p
def test_bankgiro_no_match_in_org_number(self):
"""Bankgiro regex must NOT match digits embedded in an org number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Org.nr 802546-1610", result)
assert 'Bankgiro' not in result.fields
def test_bankgiro_matches_labeled(self):
"""Bankgiro regex should match when preceded by 'Bankgiro' label."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Bankgiro 5393-9484", result)
assert result.fields.get('Bankgiro') == '5393-9484'
def test_bankgiro_matches_standalone(self):
"""Bankgiro regex should match a standalone 4-4 digit pattern."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Betala till 5393-9484 senast", result)
assert result.fields.get('Bankgiro') == '5393-9484'
def test_amount_rejects_bare_integer(self):
"""Amount regex must NOT match bare integers like 'Summa 1'."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Summa 1 Medlemsavgift", result)
assert 'Amount' not in result.fields
def test_amount_requires_decimal(self):
"""Amount regex should require a decimal separator."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Total 5 items", result)
assert 'Amount' not in result.fields
def test_amount_with_decimal_works(self):
"""Amount regex should match Swedish decimal amounts."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Att betala 1 234,56 SEK", result)
assert 'Amount' in result.fields
assert float(result.fields['Amount']) == pytest.approx(1234.56, abs=0.01)
def test_amount_with_sek_suffix(self):
"""Amount regex should match amounts ending with SEK."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("7 500,00 SEK", result)
assert 'Amount' in result.fields
assert float(result.fields['Amount']) == pytest.approx(7500.00, abs=0.01)
def test_fallback_extracts_invoice_date(self):
"""Fallback should extract InvoiceDate from Swedish text."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturadatum 2025-01-15 Referens ABC", result)
assert result.fields.get('InvoiceDate') == '2025-01-15'
def test_fallback_extracts_due_date(self):
"""Fallback should extract InvoiceDueDate from Swedish text."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Forfallodag 2025-02-15 Belopp", result)
assert result.fields.get('InvoiceDueDate') == '2025-02-15'
def test_fallback_extracts_supplier_org(self):
"""Fallback should extract supplier_organisation_number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Org.nr 556123-4567 Stockholm", result)
assert result.fields.get('supplier_organisation_number') == '556123-4567'
def test_fallback_extracts_plusgiro(self):
"""Fallback should extract Plusgiro number."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Plusgiro 12 34 56-7 betalning", result)
assert 'Plusgiro' in result.fields
def test_fallback_skips_year_as_invoice_number(self):
"""Fallback should NOT extract year-like value as InvoiceNumber."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturanr 2025 Datum 2025-01-15", result)
assert 'InvoiceNumber' not in result.fields
def test_fallback_accepts_valid_invoice_number(self):
"""Fallback should extract valid non-year InvoiceNumber."""
p = self._make_pipeline_with_patterns()
result = InferenceResult()
p._extract_with_patterns("Fakturanr 12345 Summa", result)
assert result.fields.get('InvoiceNumber') == '12345'
class TestDateValidation:
"""Tests for InvoiceDueDate < InvoiceDate validation (Bug 6)."""
def _make_pipeline_for_merge(self):
"""Create pipeline with mocked internals for merge testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.payment_line_parser = MagicMock()
p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
return p
def test_due_date_before_invoice_date_dropped(self):
"""DueDate earlier than InvoiceDate should be removed."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2025-12-01',
normalized_value='2025-12-01', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert 'InvoiceDate' in result.fields
assert 'InvoiceDueDate' not in result.fields
def test_valid_dates_preserved(self):
"""Both dates kept when DueDate >= InvoiceDate."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2026-02-15',
normalized_value='2026-02-15', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert result.fields['InvoiceDate'] == '2026-01-16'
assert result.fields['InvoiceDueDate'] == '2026-02-15'
def test_same_dates_preserved(self):
"""Same InvoiceDate and DueDate should both be kept."""
from backend.pipeline.field_extractor import ExtractedField
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name='InvoiceDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 0, 100, 50), page_no=0,
),
ExtractedField(
field_name='InvoiceDueDate', raw_text='2026-01-16',
normalized_value='2026-01-16', confidence=0.9,
detection_confidence=0.9, ocr_confidence=1.0,
bbox=(0, 60, 100, 110), page_no=0,
),
]
p._merge_fields(result)
assert result.fields['InvoiceDate'] == '2026-01-16'
assert result.fields['InvoiceDueDate'] == '2026-01-16'
class TestCrossFieldDedup:
"""Tests for cross-field deduplication of InvoiceNumber vs OCR/Bankgiro."""
def _make_pipeline_for_merge(self):
"""Create pipeline with mocked internals for merge testing."""
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
p.payment_line_parser = MagicMock()
p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
return p
def _make_extracted_field(self, field_name, raw_text, normalized, confidence=0.9):
from backend.pipeline.field_extractor import ExtractedField
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized,
confidence=confidence,
detection_confidence=confidence,
ocr_confidence=1.0,
bbox=(0, 0, 100, 50),
page_no=0,
)
def test_invoice_number_not_same_as_ocr(self):
"""When InvoiceNumber == OCR, InvoiceNumber should be dropped."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '9179845608', '9179845608'),
self._make_extracted_field('OCR', '9179845608', '9179845608'),
self._make_extracted_field('Amount', '1234,56', '1234.56'),
]
p._merge_fields(result)
assert 'OCR' in result.fields
assert result.fields['OCR'] == '9179845608'
assert 'InvoiceNumber' not in result.fields
def test_invoice_number_not_same_as_bankgiro_digits(self):
"""When InvoiceNumber digits == Bankgiro digits, InvoiceNumber should be dropped."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '53939484', '53939484'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert 'Bankgiro' in result.fields
assert result.fields['Bankgiro'] == '5393-9484'
assert 'InvoiceNumber' not in result.fields
def test_unrelated_values_kept(self):
"""When InvoiceNumber, OCR, and Bankgiro are all different, keep all."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '19061', '19061'),
self._make_extracted_field('OCR', '9179845608', '9179845608'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
]
p._merge_fields(result)
assert result.fields['InvoiceNumber'] == '19061'
assert result.fields['OCR'] == '9179845608'
assert result.fields['Bankgiro'] == '5393-9484'
def test_dedup_after_fallback_re_add(self):
"""Dedup should remove InvoiceNumber re-added by fallback if it matches OCR."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
# Simulate state after fallback re-adds InvoiceNumber = OCR
result.fields = {
'OCR': '758200602426',
'Amount': '164.00',
'InvoiceNumber': '758200602426', # re-added by fallback
}
result.confidence = {
'OCR': 0.9,
'Amount': 0.9,
'InvoiceNumber': 0.5, # fallback confidence
}
result.bboxes = {}
p._dedup_invoice_number(result)
assert 'InvoiceNumber' not in result.fields
assert 'OCR' in result.fields
def test_invoice_number_substring_of_bankgiro(self):
"""When InvoiceNumber digits are a substring of Bankgiro digits, drop InvoiceNumber."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '4639', '4639'),
self._make_extracted_field('Bankgiro', '134-4639', '134-4639'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert 'Bankgiro' in result.fields
assert result.fields['Bankgiro'] == '134-4639'
assert 'InvoiceNumber' not in result.fields
def test_invoice_number_not_substring_of_unrelated_bankgiro(self):
"""When InvoiceNumber is NOT a substring of Bankgiro, keep both."""
p = self._make_pipeline_for_merge()
result = InferenceResult()
result.extracted_fields = [
self._make_extracted_field('InvoiceNumber', '19061', '19061'),
self._make_extracted_field('Bankgiro', '5393-9484', '5393-9484'),
self._make_extracted_field('Amount', '500,00', '500.00'),
]
p._merge_fields(result)
assert result.fields['InvoiceNumber'] == '19061'
assert result.fields['Bankgiro'] == '5393-9484'
class TestFallbackTrigger:
"""Tests for _needs_fallback trigger threshold."""
def _make_pipeline(self):
with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
p = InferencePipeline()
return p
def test_fallback_triggers_when_1_key_field_missing(self):
"""Should trigger when only 1 key field (e.g. InvoiceNumber) is missing."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'OCR': '12345678901',
'InvoiceDate': '2025-01-15',
'InvoiceDueDate': '2025-02-15',
'supplier_organisation_number': '556123-4567',
}
# InvoiceNumber missing -> should trigger
assert p._needs_fallback(result) is True
def test_fallback_triggers_when_dates_missing(self):
"""Should trigger when all key fields present but 2+ important fields missing."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'InvoiceNumber': '12345',
'OCR': '12345678901',
}
# InvoiceDate, InvoiceDueDate, supplier_org all missing -> should trigger
assert p._needs_fallback(result) is True
def test_no_fallback_when_all_fields_present(self):
"""Should NOT trigger when all key and important fields present."""
p = self._make_pipeline()
result = InferenceResult()
result.fields = {
'Amount': '1234.56',
'InvoiceNumber': '12345',
'OCR': '12345678901',
'InvoiceDate': '2025-01-15',
'InvoiceDueDate': '2025-02-15',
'supplier_organisation_number': '556123-4567',
}
assert p._needs_fallback(result) is False
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -335,12 +335,15 @@ class TestFallbackLogic:
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# All key fields present
# All key and important fields present
result = InferenceResult(
fields={
"Amount": "1500.00",
"InvoiceNumber": "INV-001",
"OCR": "12345678901234",
"InvoiceDate": "2025-01-15",
"InvoiceDueDate": "2025-02-15",
"supplier_organisation_number": "556123-4567",
}
)

View File

@@ -203,15 +203,33 @@ class TestValueSelectorOcrField:
assert len(result) == 1
assert result[0].text == "94228110015950070"
def test_ignores_short_digit_tokens(self):
"""Tokens with fewer than 5 digits are not OCR references."""
tokens = _tokens("OCR", "123")
def test_ignores_single_digit_tokens(self):
"""Tokens with fewer than 2 digits are not OCR references."""
tokens = _tokens("OCR", "5")
result = ValueSelector.select_value_tokens(tokens, "OCR")
# Fallback: return all tokens since no valid OCR found
assert len(result) == 2
def test_ocr_4_digit_token_selected(self):
"""4-digit OCR token should be selected."""
tokens = _tokens("OCR", "3046")
result = ValueSelector.select_value_tokens(tokens, "OCR")
assert len(result) == 1
assert result[0].text == "3046"
def test_ocr_2_digit_token_selected(self):
"""2-digit OCR token should be selected."""
tokens = _tokens("OCR", "42")
result = ValueSelector.select_value_tokens(tokens, "OCR")
assert len(result) == 1
assert result[0].text == "42"
class TestValueSelectorInvoiceNumberField:
"""Tests for InvoiceNumber field value selection."""