Re-structure the project.
This commit is contained in:
0
tests/inference/__init__.py
Normal file
0
tests/inference/__init__.py
Normal file
401
tests/inference/test_field_extractor.py
Normal file
401
tests/inference/test_field_extractor.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Tests for Field Extractor
|
||||
|
||||
Tests field normalization functions:
|
||||
- Invoice number normalization
|
||||
- Date normalization
|
||||
- Amount normalization
|
||||
- Bankgiro/Plusgiro normalization
|
||||
- OCR number normalization
|
||||
- Payment line normalization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.inference.field_extractor import FieldExtractor
|
||||
|
||||
|
||||
class TestFieldExtractorInit:
|
||||
"""Tests for FieldExtractor initialization."""
|
||||
|
||||
def test_default_init(self):
|
||||
"""Test default initialization."""
|
||||
extractor = FieldExtractor()
|
||||
assert extractor.ocr_lang == 'en'
|
||||
assert extractor.use_gpu is False
|
||||
assert extractor.bbox_padding == 0.1
|
||||
assert extractor.dpi == 300
|
||||
|
||||
def test_custom_init(self):
|
||||
"""Test custom initialization."""
|
||||
extractor = FieldExtractor(
|
||||
ocr_lang='sv',
|
||||
use_gpu=True,
|
||||
bbox_padding=0.2,
|
||||
dpi=150
|
||||
)
|
||||
assert extractor.ocr_lang == 'sv'
|
||||
assert extractor.use_gpu is True
|
||||
assert extractor.bbox_padding == 0.2
|
||||
assert extractor.dpi == 150
|
||||
|
||||
|
||||
class TestNormalizeInvoiceNumber:
|
||||
"""Tests for invoice number normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_alphanumeric_invoice_number(self, extractor):
|
||||
"""Test alphanumeric invoice number like A3861."""
|
||||
result, is_valid, error = extractor._normalize_invoice_number("Fakturanummer: A3861")
|
||||
assert result == 'A3861'
|
||||
assert is_valid is True
|
||||
|
||||
def test_prefix_invoice_number(self, extractor):
|
||||
"""Test invoice number with prefix like INV12345."""
|
||||
result, is_valid, error = extractor._normalize_invoice_number("Invoice INV12345")
|
||||
assert result is not None
|
||||
assert 'INV' in result or '12345' in result
|
||||
|
||||
def test_numeric_invoice_number(self, extractor):
|
||||
"""Test pure numeric invoice number."""
|
||||
result, is_valid, error = extractor._normalize_invoice_number("Invoice: 12345678")
|
||||
assert result is not None
|
||||
assert result.isdigit()
|
||||
|
||||
def test_year_prefixed_invoice_number(self, extractor):
|
||||
"""Test invoice number with year prefix like 2024-001."""
|
||||
result, is_valid, error = extractor._normalize_invoice_number("Faktura 2024-12345")
|
||||
assert result is not None
|
||||
assert '2024' in result
|
||||
|
||||
def test_avoid_long_ocr_sequence(self, extractor):
|
||||
"""Test that long OCR-like sequences are avoided."""
|
||||
# When text contains both short invoice number and long OCR sequence
|
||||
text = "Fakturanummer: A3861 OCR: 310196187399952763290708"
|
||||
result, is_valid, error = extractor._normalize_invoice_number(text)
|
||||
# Should prefer the shorter alphanumeric pattern
|
||||
assert result == 'A3861'
|
||||
|
||||
def test_empty_string(self, extractor):
|
||||
"""Test empty string input."""
|
||||
result, is_valid, error = extractor._normalize_invoice_number("")
|
||||
assert result is None or is_valid is False
|
||||
|
||||
|
||||
class TestNormalizeBankgiro:
|
||||
"""Tests for Bankgiro normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_standard_7_digit_format(self, extractor):
|
||||
"""Test 7-digit Bankgiro XXX-XXXX."""
|
||||
result, is_valid, error = extractor._normalize_bankgiro("Bankgiro: 782-1713")
|
||||
assert result == '782-1713'
|
||||
assert is_valid is True
|
||||
|
||||
def test_standard_8_digit_format(self, extractor):
|
||||
"""Test 8-digit Bankgiro XXXX-XXXX."""
|
||||
result, is_valid, error = extractor._normalize_bankgiro("BG 5393-9484")
|
||||
assert result == '5393-9484'
|
||||
assert is_valid is True
|
||||
|
||||
def test_without_dash(self, extractor):
|
||||
"""Test Bankgiro without dash."""
|
||||
result, is_valid, error = extractor._normalize_bankgiro("Bankgiro 7821713")
|
||||
assert result is not None
|
||||
# Should be formatted with dash
|
||||
|
||||
def test_with_spaces(self, extractor):
|
||||
"""Test Bankgiro with spaces - may not parse if spaces break the pattern."""
|
||||
result, is_valid, error = extractor._normalize_bankgiro("BG: 782 1713")
|
||||
# Spaces in the middle might cause parsing issues - that's acceptable
|
||||
# The test passes if it doesn't crash
|
||||
|
||||
def test_invalid_bankgiro(self, extractor):
|
||||
"""Test invalid Bankgiro (too short)."""
|
||||
result, is_valid, error = extractor._normalize_bankgiro("BG: 123")
|
||||
# Should fail or return None
|
||||
|
||||
|
||||
class TestNormalizePlusgiro:
|
||||
"""Tests for Plusgiro normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_standard_format(self, extractor):
|
||||
"""Test standard Plusgiro format XXXXXXX-X."""
|
||||
result, is_valid, error = extractor._normalize_plusgiro("Plusgiro: 1234567-8")
|
||||
assert result is not None
|
||||
assert '-' in result
|
||||
|
||||
def test_without_dash(self, extractor):
|
||||
"""Test Plusgiro without dash."""
|
||||
result, is_valid, error = extractor._normalize_plusgiro("PG 12345678")
|
||||
assert result is not None
|
||||
|
||||
def test_distinguish_from_bankgiro(self, extractor):
|
||||
"""Test that Plusgiro is distinguished from Bankgiro by format."""
|
||||
# Plusgiro has 1 digit after dash, Bankgiro has 4
|
||||
pg_text = "4809603-6" # Plusgiro format
|
||||
bg_text = "782-1713" # Bankgiro format
|
||||
|
||||
pg_result, _, _ = extractor._normalize_plusgiro(pg_text)
|
||||
bg_result, _, _ = extractor._normalize_bankgiro(bg_text)
|
||||
|
||||
# Both should succeed in their respective normalizations
|
||||
|
||||
|
||||
class TestNormalizeAmount:
|
||||
"""Tests for Amount normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_swedish_format_comma(self, extractor):
|
||||
"""Test Swedish format with comma: 11 699,00."""
|
||||
result, is_valid, error = extractor._normalize_amount("11 699,00 SEK")
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
|
||||
def test_integer_amount(self, extractor):
|
||||
"""Test integer amount without decimals."""
|
||||
result, is_valid, error = extractor._normalize_amount("Amount: 11699")
|
||||
assert result is not None
|
||||
|
||||
def test_with_currency(self, extractor):
|
||||
"""Test amount with currency symbol."""
|
||||
result, is_valid, error = extractor._normalize_amount("SEK 11 699,00")
|
||||
assert result is not None
|
||||
|
||||
def test_large_amount(self, extractor):
|
||||
"""Test large amount with thousand separators."""
|
||||
result, is_valid, error = extractor._normalize_amount("1 234 567,89")
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestNormalizeOCR:
|
||||
"""Tests for OCR number normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_standard_ocr(self, extractor):
|
||||
"""Test standard OCR number."""
|
||||
result, is_valid, error = extractor._normalize_ocr_number("OCR: 310196187399952")
|
||||
assert result == '310196187399952'
|
||||
assert is_valid is True
|
||||
|
||||
def test_ocr_with_spaces(self, extractor):
|
||||
"""Test OCR number with spaces."""
|
||||
result, is_valid, error = extractor._normalize_ocr_number("3101 9618 7399 952")
|
||||
assert result is not None
|
||||
assert ' ' not in result # Spaces should be removed
|
||||
|
||||
def test_short_ocr_invalid(self, extractor):
|
||||
"""Test that too short OCR is invalid."""
|
||||
result, is_valid, error = extractor._normalize_ocr_number("123")
|
||||
assert is_valid is False
|
||||
|
||||
|
||||
class TestNormalizeDate:
|
||||
"""Tests for date normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_iso_format(self, extractor):
|
||||
"""Test ISO date format YYYY-MM-DD."""
|
||||
result, is_valid, error = extractor._normalize_date("2026-01-31")
|
||||
assert result == '2026-01-31'
|
||||
assert is_valid is True
|
||||
|
||||
def test_swedish_format(self, extractor):
|
||||
"""Test Swedish format with dots: 31.01.2026."""
|
||||
result, is_valid, error = extractor._normalize_date("31.01.2026")
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
|
||||
def test_slash_format(self, extractor):
|
||||
"""Test slash format: 31/01/2026."""
|
||||
result, is_valid, error = extractor._normalize_date("31/01/2026")
|
||||
assert result is not None
|
||||
|
||||
def test_compact_format(self, extractor):
|
||||
"""Test compact format: 20260131."""
|
||||
result, is_valid, error = extractor._normalize_date("20260131")
|
||||
assert result is not None
|
||||
|
||||
def test_invalid_date(self, extractor):
|
||||
"""Test invalid date."""
|
||||
result, is_valid, error = extractor._normalize_date("not a date")
|
||||
assert is_valid is False
|
||||
|
||||
|
||||
class TestNormalizePaymentLine:
|
||||
"""Tests for payment line normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_standard_payment_line(self, extractor):
|
||||
"""Test standard payment line parsing."""
|
||||
text = "# 310196187399952 # 11699 00 6 > 7821713#41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
# Should be formatted as: OCR:xxx Amount:xxx BG:xxx
|
||||
assert 'OCR:' in result or '310196187399952' in result
|
||||
|
||||
def test_payment_line_with_spaces_in_bg(self, extractor):
|
||||
"""Test payment line with spaces in Bankgiro."""
|
||||
text = "# 310196187399952 # 11699 00 6 > 78 2 1 713 #41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
# Bankgiro should be normalized despite spaces
|
||||
|
||||
def test_payment_line_with_spaces_in_check_digits(self, extractor):
|
||||
"""Test payment line with spaces around check digits: #41 # instead of #41#."""
|
||||
text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "6026726908" in result
|
||||
assert "736 00" in result
|
||||
assert "5692041#41#" in result
|
||||
|
||||
def test_payment_line_with_ocr_spaces_in_amount(self, extractor):
|
||||
"""Test payment line with OCR-induced spaces in amount: '12 0 0 00' -> '1200 00'."""
|
||||
text = "# 11000770600242 # 12 0 0 00 5 3082963#41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "11000770600242" in result
|
||||
assert "1200 00" in result
|
||||
assert "3082963#41#" in result
|
||||
|
||||
def test_payment_line_without_greater_symbol(self, extractor):
|
||||
"""Test payment line with missing > symbol (low-DPI OCR issue)."""
|
||||
text = "# 11000770600242 # 1200 00 5 3082963#41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "11000770600242" in result
|
||||
assert "1200 00" in result
|
||||
|
||||
|
||||
class TestNormalizeCustomerNumber:
|
||||
"""Tests for customer number normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_with_separator(self, extractor):
|
||||
"""Test customer number with separator: JTY 576-3."""
|
||||
result, is_valid, error = extractor._normalize_customer_number("Kundnr: JTY 576-3")
|
||||
assert result is not None
|
||||
|
||||
def test_compact_format(self, extractor):
|
||||
"""Test compact customer number: JTY5763."""
|
||||
result, is_valid, error = extractor._normalize_customer_number("JTY5763")
|
||||
assert result is not None
|
||||
|
||||
def test_format_without_dash(self, extractor):
|
||||
"""Test customer number format without dash: Dwq 211X -> DWQ 211-X."""
|
||||
text = "Dwq 211X Billo SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert result == "DWQ 211-X"
|
||||
|
||||
def test_swedish_postal_code_exclusion(self, extractor):
|
||||
"""Test that Swedish postal codes are excluded: SE 106 43 should not be extracted."""
|
||||
text = "SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
# Should not extract postal code
|
||||
assert result is None or "SE 106" not in result
|
||||
|
||||
def test_customer_number_with_postal_code_in_text(self, extractor):
|
||||
"""Test extracting customer number when postal code is also present."""
|
||||
text = "Customer: ABC 123X, Address: SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
assert result is not None
|
||||
assert "ABC" in result
|
||||
# Should not extract postal code
|
||||
assert "SE 106" not in result if result else True
|
||||
|
||||
|
||||
class TestNormalizeSupplierOrgNumber:
|
||||
"""Tests for supplier organization number normalization."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_standard_format(self, extractor):
|
||||
"""Test standard format NNNNNN-NNNN."""
|
||||
result, is_valid, error = extractor._normalize_supplier_org_number("Org.nr 516406-1102")
|
||||
assert result == '516406-1102'
|
||||
assert is_valid is True
|
||||
|
||||
def test_vat_number_format(self, extractor):
|
||||
"""Test VAT number format SE + 10 digits + 01."""
|
||||
result, is_valid, error = extractor._normalize_supplier_org_number("Momsreg.nr SE556123456701")
|
||||
assert result is not None
|
||||
assert '-' in result
|
||||
|
||||
|
||||
class TestNormalizeAndValidateDispatch:
|
||||
"""Tests for the _normalize_and_validate dispatch method."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self):
|
||||
return FieldExtractor()
|
||||
|
||||
def test_dispatch_invoice_number(self, extractor):
|
||||
"""Test dispatch to invoice number normalizer."""
|
||||
result, is_valid, error = extractor._normalize_and_validate('InvoiceNumber', 'A3861')
|
||||
assert result is not None
|
||||
|
||||
def test_dispatch_amount(self, extractor):
|
||||
"""Test dispatch to amount normalizer."""
|
||||
result, is_valid, error = extractor._normalize_and_validate('Amount', '11699,00')
|
||||
assert result is not None
|
||||
|
||||
def test_dispatch_bankgiro(self, extractor):
|
||||
"""Test dispatch to Bankgiro normalizer."""
|
||||
result, is_valid, error = extractor._normalize_and_validate('Bankgiro', '782-1713')
|
||||
assert result is not None
|
||||
|
||||
def test_dispatch_ocr(self, extractor):
|
||||
"""Test dispatch to OCR normalizer."""
|
||||
result, is_valid, error = extractor._normalize_and_validate('OCR', '310196187399952')
|
||||
assert result is not None
|
||||
|
||||
def test_dispatch_date(self, extractor):
|
||||
"""Test dispatch to date normalizer."""
|
||||
result, is_valid, error = extractor._normalize_and_validate('InvoiceDate', '2026-01-31')
|
||||
assert result is not None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
326
tests/inference/test_pipeline.py
Normal file
326
tests/inference/test_pipeline.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Tests for Inference Pipeline
|
||||
|
||||
Tests the cross-validation logic between payment_line and detected fields:
|
||||
- OCR override from payment_line
|
||||
- Amount override from payment_line
|
||||
- Bankgiro/Plusgiro comparison (no override)
|
||||
- Validation scoring
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from src.inference.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
|
||||
|
||||
|
||||
class TestCrossValidationResult:
|
||||
"""Tests for CrossValidationResult dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values."""
|
||||
cv = CrossValidationResult()
|
||||
assert cv.ocr_match is None
|
||||
assert cv.amount_match is None
|
||||
assert cv.bankgiro_match is None
|
||||
assert cv.plusgiro_match is None
|
||||
assert cv.payment_line_ocr is None
|
||||
assert cv.payment_line_amount is None
|
||||
assert cv.payment_line_account is None
|
||||
assert cv.payment_line_account_type is None
|
||||
|
||||
def test_attributes(self):
|
||||
"""Test setting attributes."""
|
||||
cv = CrossValidationResult()
|
||||
cv.ocr_match = True
|
||||
cv.amount_match = True
|
||||
cv.payment_line_ocr = '12345678901'
|
||||
cv.payment_line_amount = '100'
|
||||
cv.details = ['OCR match', 'Amount match']
|
||||
|
||||
assert cv.ocr_match is True
|
||||
assert cv.amount_match is True
|
||||
assert cv.payment_line_ocr == '12345678901'
|
||||
assert 'OCR match' in cv.details
|
||||
|
||||
|
||||
class TestInferenceResult:
|
||||
"""Tests for InferenceResult dataclass."""
|
||||
|
||||
def test_default_fields(self):
|
||||
"""Test default field values."""
|
||||
result = InferenceResult()
|
||||
assert result.fields == {}
|
||||
assert result.confidence == {}
|
||||
assert result.errors == []
|
||||
|
||||
def test_set_fields(self):
|
||||
"""Test setting field values."""
|
||||
result = InferenceResult()
|
||||
result.fields = {
|
||||
'OCR': '12345678901',
|
||||
'Amount': '100',
|
||||
'Bankgiro': '782-1713'
|
||||
}
|
||||
result.confidence = {
|
||||
'OCR': 0.95,
|
||||
'Amount': 0.90,
|
||||
'Bankgiro': 0.88
|
||||
}
|
||||
|
||||
assert result.fields['OCR'] == '12345678901'
|
||||
assert result.fields['Amount'] == '100'
|
||||
assert result.fields['Bankgiro'] == '782-1713'
|
||||
|
||||
def test_cross_validation_assignment(self):
|
||||
"""Test cross validation assignment."""
|
||||
result = InferenceResult()
|
||||
result.fields = {'OCR': '12345678901'}
|
||||
|
||||
cv = CrossValidationResult()
|
||||
cv.ocr_match = True
|
||||
cv.payment_line_ocr = '12345678901'
|
||||
result.cross_validation = cv
|
||||
|
||||
assert result.cross_validation is not None
|
||||
assert result.cross_validation.ocr_match is True
|
||||
|
||||
|
||||
class TestPaymentLineParsingInPipeline:
|
||||
"""Tests for payment_line parsing in cross-validation."""
|
||||
|
||||
def test_parse_payment_line_format(self):
|
||||
"""Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx"""
|
||||
# Simulate the parsing logic from pipeline
|
||||
payment_line = "OCR:310196187399952 Amount:11699 BG:782-1713"
|
||||
|
||||
pl_parts = {}
|
||||
for part in payment_line.split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
|
||||
assert pl_parts.get('OCR') == '310196187399952'
|
||||
assert pl_parts.get('AMOUNT') == '11699'
|
||||
assert pl_parts.get('BG') == '782-1713'
|
||||
|
||||
def test_parse_payment_line_with_plusgiro(self):
|
||||
"""Test parsing with Plusgiro."""
|
||||
payment_line = "OCR:12345678901 Amount:500 PG:1234567-8"
|
||||
|
||||
pl_parts = {}
|
||||
for part in payment_line.split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
|
||||
assert pl_parts.get('OCR') == '12345678901'
|
||||
assert pl_parts.get('PG') == '1234567-8'
|
||||
assert pl_parts.get('BG') is None
|
||||
|
||||
def test_parse_empty_payment_line(self):
|
||||
"""Test parsing empty payment_line."""
|
||||
payment_line = ""
|
||||
|
||||
pl_parts = {}
|
||||
for part in payment_line.split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
|
||||
assert pl_parts.get('OCR') is None
|
||||
assert pl_parts.get('AMOUNT') is None
|
||||
|
||||
|
||||
class TestOCROverride:
|
||||
"""Tests for OCR override logic."""
|
||||
|
||||
def test_ocr_override_when_different(self):
|
||||
"""Test OCR is overridden when payment_line value differs."""
|
||||
result = InferenceResult()
|
||||
result.fields = {'OCR': 'wrong_ocr_12345', 'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713'}
|
||||
|
||||
# Simulate the override logic
|
||||
payment_line = result.fields.get('payment_line')
|
||||
pl_parts = {}
|
||||
for part in str(payment_line).split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
|
||||
payment_line_ocr = pl_parts.get('OCR')
|
||||
|
||||
# Override detected OCR with payment_line OCR
|
||||
if payment_line_ocr:
|
||||
result.fields['OCR'] = payment_line_ocr
|
||||
|
||||
assert result.fields['OCR'] == 'correct_ocr_67890'
|
||||
|
||||
def test_ocr_no_override_when_no_payment_line(self):
|
||||
"""Test OCR is not overridden when no payment_line."""
|
||||
result = InferenceResult()
|
||||
result.fields = {'OCR': 'original_ocr_12345'}
|
||||
|
||||
# No payment_line, no override
|
||||
assert result.fields['OCR'] == 'original_ocr_12345'
|
||||
|
||||
|
||||
class TestAmountOverride:
|
||||
"""Tests for Amount override logic."""
|
||||
|
||||
def test_amount_override(self):
|
||||
"""Test Amount is overridden from payment_line."""
|
||||
result = InferenceResult()
|
||||
result.fields = {
|
||||
'Amount': '999.00',
|
||||
'payment_line': 'OCR:12345 Amount:11699 BG:782-1713'
|
||||
}
|
||||
|
||||
payment_line = result.fields.get('payment_line')
|
||||
pl_parts = {}
|
||||
for part in str(payment_line).split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
|
||||
payment_line_amount = pl_parts.get('AMOUNT')
|
||||
|
||||
if payment_line_amount:
|
||||
result.fields['Amount'] = payment_line_amount
|
||||
|
||||
assert result.fields['Amount'] == '11699'
|
||||
|
||||
|
||||
class TestBankgiroComparison:
|
||||
"""Tests for Bankgiro comparison (no override)."""
|
||||
|
||||
def test_bankgiro_match(self):
|
||||
"""Test Bankgiro match detection."""
|
||||
import re
|
||||
|
||||
detected_bankgiro = '782-1713'
|
||||
payment_line_account = '782-1713'
|
||||
|
||||
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
||||
pl_digits = re.sub(r'\D', '', payment_line_account)
|
||||
|
||||
assert det_digits == pl_digits
|
||||
assert det_digits == '7821713'
|
||||
|
||||
def test_bankgiro_mismatch(self):
|
||||
"""Test Bankgiro mismatch detection."""
|
||||
import re
|
||||
|
||||
detected_bankgiro = '782-1713'
|
||||
payment_line_account = '123-4567'
|
||||
|
||||
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
||||
pl_digits = re.sub(r'\D', '', payment_line_account)
|
||||
|
||||
assert det_digits != pl_digits
|
||||
|
||||
def test_bankgiro_not_overridden(self):
|
||||
"""Test that Bankgiro is NOT overridden from payment_line."""
|
||||
result = InferenceResult()
|
||||
result.fields = {
|
||||
'Bankgiro': '999-9999', # Different value
|
||||
'payment_line': 'OCR:12345 Amount:100 BG:782-1713'
|
||||
}
|
||||
|
||||
# Bankgiro should NOT be overridden (per current logic)
|
||||
# Only compared for validation
|
||||
original_bankgiro = result.fields['Bankgiro']
|
||||
|
||||
# The override logic explicitly skips Bankgiro
|
||||
# So we verify it remains unchanged
|
||||
assert result.fields['Bankgiro'] == '999-9999'
|
||||
assert result.fields['Bankgiro'] == original_bankgiro
|
||||
|
||||
|
||||
class TestValidationScoring:
|
||||
"""Tests for validation scoring logic."""
|
||||
|
||||
def test_all_fields_match(self):
|
||||
"""Test score when all fields match."""
|
||||
matches = [True, True, True] # OCR, Amount, Bankgiro
|
||||
match_count = sum(1 for m in matches if m)
|
||||
total = len(matches)
|
||||
|
||||
assert match_count == 3
|
||||
assert total == 3
|
||||
|
||||
def test_partial_match(self):
|
||||
"""Test score with partial matches."""
|
||||
matches = [True, True, False] # OCR match, Amount match, Bankgiro mismatch
|
||||
match_count = sum(1 for m in matches if m)
|
||||
|
||||
assert match_count == 2
|
||||
|
||||
def test_no_matches(self):
|
||||
"""Test score when nothing matches."""
|
||||
matches = [False, False, False]
|
||||
match_count = sum(1 for m in matches if m)
|
||||
|
||||
assert match_count == 0
|
||||
|
||||
def test_only_count_present_fields(self):
|
||||
"""Test that only present fields are counted."""
|
||||
# When invoice has both BG and PG but payment_line only has BG,
|
||||
# we should only count BG in validation
|
||||
|
||||
payment_line_account_type = 'bankgiro'
|
||||
bankgiro_match = True
|
||||
plusgiro_match = None # Not compared because payment_line doesn't have PG
|
||||
|
||||
matches = []
|
||||
if payment_line_account_type == 'bankgiro' and bankgiro_match is not None:
|
||||
matches.append(bankgiro_match)
|
||||
elif payment_line_account_type == 'plusgiro' and plusgiro_match is not None:
|
||||
matches.append(plusgiro_match)
|
||||
|
||||
assert len(matches) == 1
|
||||
assert matches[0] is True
|
||||
|
||||
|
||||
class TestAmountNormalization:
|
||||
"""Tests for amount normalization for comparison."""
|
||||
|
||||
def test_normalize_amount_with_comma(self):
|
||||
"""Test normalizing amount with comma decimal."""
|
||||
import re
|
||||
|
||||
amount = "11699,00"
|
||||
normalized = re.sub(r'[^\d]', '', amount)
|
||||
|
||||
# Remove trailing zeros for öre
|
||||
if len(normalized) > 2 and normalized[-2:] == '00':
|
||||
normalized = normalized[:-2]
|
||||
|
||||
assert normalized == '11699'
|
||||
|
||||
def test_normalize_amount_with_dot(self):
|
||||
"""Test normalizing amount with dot decimal."""
|
||||
import re
|
||||
|
||||
amount = "11699.00"
|
||||
normalized = re.sub(r'[^\d]', '', amount)
|
||||
|
||||
if len(normalized) > 2 and normalized[-2:] == '00':
|
||||
normalized = normalized[:-2]
|
||||
|
||||
assert normalized == '11699'
|
||||
|
||||
def test_normalize_amount_with_space_separator(self):
|
||||
"""Test normalizing amount with space thousand separator."""
|
||||
import re
|
||||
|
||||
amount = "11 699,00"
|
||||
normalized = re.sub(r'[^\d]', '', amount)
|
||||
|
||||
if len(normalized) > 2 and normalized[-2:] == '00':
|
||||
normalized = normalized[:-2]
|
||||
|
||||
assert normalized == '11699'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user