327 lines
10 KiB
Python
327 lines
10 KiB
Python
"""
Tests for Inference Pipeline

Tests the cross-validation logic between payment_line and detected fields:
- OCR override from payment_line
- Amount override from payment_line
- Bankgiro/Plusgiro comparison (no override)
- Validation scoring
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from inference.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
|
|
|
|
|
|
class TestCrossValidationResult:
    """Tests for CrossValidationResult dataclass."""

    def test_default_values(self):
        """A freshly constructed result has every match flag and payment-line value unset."""
        cv = CrossValidationResult()

        # Match flags.
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None

        # Values extracted from the payment line.
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.payment_line_account is None
        assert cv.payment_line_account_type is None

    def test_attributes(self):
        """Every attribute assigned on a result can be read back unchanged."""
        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.amount_match = True
        cv.payment_line_ocr = '12345678901'
        cv.payment_line_amount = '100'
        cv.details = ['OCR match', 'Amount match']

        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert cv.payment_line_ocr == '12345678901'
        # These two values were assigned above but previously never checked.
        assert cv.payment_line_amount == '100'
        assert 'OCR match' in cv.details
        assert 'Amount match' in cv.details
|
|
|
|
|
|
class TestInferenceResult:
    """Tests for InferenceResult dataclass."""

    def test_default_fields(self):
        """A freshly constructed result starts with empty fields, confidence and errors."""
        result = InferenceResult()
        assert result.fields == {}
        assert result.confidence == {}
        assert result.errors == []

    def test_set_fields(self):
        """Assigned field values and their confidences can be read back."""
        result = InferenceResult()
        result.fields = {
            'OCR': '12345678901',
            'Amount': '100',
            'Bankgiro': '782-1713',
        }
        result.confidence = {
            'OCR': 0.95,
            'Amount': 0.90,
            'Bankgiro': 0.88,
        }

        assert result.fields['OCR'] == '12345678901'
        assert result.fields['Amount'] == '100'
        assert result.fields['Bankgiro'] == '782-1713'

        # The confidence mapping was assigned above but previously never checked.
        assert result.confidence['OCR'] == 0.95
        assert result.confidence['Amount'] == 0.90
        assert result.confidence['Bankgiro'] == 0.88

    def test_cross_validation_assignment(self):
        """A CrossValidationResult attached to a result is reachable with its values intact."""
        result = InferenceResult()
        result.fields = {'OCR': '12345678901'}

        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.payment_line_ocr = '12345678901'
        result.cross_validation = cv

        assert result.cross_validation is not None
        assert result.cross_validation.ocr_match is True
        # Assigned above but previously never checked.
        assert result.cross_validation.payment_line_ocr == '12345678901'
|
|
|
|
|
|
class TestPaymentLineParsingInPipeline:
    """Tests for payment_line parsing in cross-validation."""

    @staticmethod
    def _parse_payment_line(payment_line):
        """Split a ``KEY:value`` payment line into a dict keyed by upper-cased key.

        This is a local stand-in for the parsing step; these tests do not call
        the pipeline itself. Whitespace-separated tokens without a colon are
        ignored, and only the first colon in a token separates key from value.
        """
        pl_parts = {}
        for part in payment_line.split():
            key, sep, value = part.partition(':')
            if sep:  # the token contained a colon
                pl_parts[key.upper()] = value
        return pl_parts

    def test_parse_payment_line_format(self):
        """Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx"""
        pl_parts = self._parse_payment_line(
            "OCR:310196187399952 Amount:11699 BG:782-1713"
        )

        assert pl_parts.get('OCR') == '310196187399952'
        assert pl_parts.get('AMOUNT') == '11699'
        assert pl_parts.get('BG') == '782-1713'

    def test_parse_payment_line_with_plusgiro(self):
        """Test parsing with Plusgiro."""
        pl_parts = self._parse_payment_line("OCR:12345678901 Amount:500 PG:1234567-8")

        assert pl_parts.get('OCR') == '12345678901'
        assert pl_parts.get('PG') == '1234567-8'
        assert pl_parts.get('BG') is None

    def test_parse_empty_payment_line(self):
        """Test parsing empty payment_line."""
        pl_parts = self._parse_payment_line("")

        # An empty line must yield no keys at all, not just lack OCR/AMOUNT.
        assert pl_parts == {}
        assert pl_parts.get('OCR') is None
        assert pl_parts.get('AMOUNT') is None
|
|
|
|
|
|
class TestOCROverride:
|
|
"""Tests for OCR override logic."""
|
|
|
|
def test_ocr_override_when_different(self):
|
|
"""Test OCR is overridden when payment_line value differs."""
|
|
result = InferenceResult()
|
|
result.fields = {'OCR': 'wrong_ocr_12345', 'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713'}
|
|
|
|
# Simulate the override logic
|
|
payment_line = result.fields.get('payment_line')
|
|
pl_parts = {}
|
|
for part in str(payment_line).split():
|
|
if ':' in part:
|
|
key, value = part.split(':', 1)
|
|
pl_parts[key.upper()] = value
|
|
|
|
payment_line_ocr = pl_parts.get('OCR')
|
|
|
|
# Override detected OCR with payment_line OCR
|
|
if payment_line_ocr:
|
|
result.fields['OCR'] = payment_line_ocr
|
|
|
|
assert result.fields['OCR'] == 'correct_ocr_67890'
|
|
|
|
def test_ocr_no_override_when_no_payment_line(self):
|
|
"""Test OCR is not overridden when no payment_line."""
|
|
result = InferenceResult()
|
|
result.fields = {'OCR': 'original_ocr_12345'}
|
|
|
|
# No payment_line, no override
|
|
assert result.fields['OCR'] == 'original_ocr_12345'
|
|
|
|
|
|
class TestAmountOverride:
    """Tests for Amount override logic."""

    def test_amount_override(self):
        """The Amount token of payment_line replaces the detected Amount."""
        result = InferenceResult()
        result.fields = {
            'Amount': '999.00',
            'payment_line': 'OCR:12345 Amount:11699 BG:782-1713',
        }

        # Parse "KEY:value" tokens; tokens without a colon are skipped.
        raw_line = str(result.fields.get('payment_line'))
        tokens = (token.partition(':') for token in raw_line.split())
        parsed = {name.upper(): text for name, colon, text in tokens if colon}

        override = parsed.get('AMOUNT')
        if override:
            result.fields['Amount'] = override

        assert result.fields['Amount'] == '11699'
|
|
|
|
|
|
class TestBankgiroComparison:
|
|
"""Tests for Bankgiro comparison (no override)."""
|
|
|
|
def test_bankgiro_match(self):
|
|
"""Test Bankgiro match detection."""
|
|
import re
|
|
|
|
detected_bankgiro = '782-1713'
|
|
payment_line_account = '782-1713'
|
|
|
|
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
|
pl_digits = re.sub(r'\D', '', payment_line_account)
|
|
|
|
assert det_digits == pl_digits
|
|
assert det_digits == '7821713'
|
|
|
|
def test_bankgiro_mismatch(self):
|
|
"""Test Bankgiro mismatch detection."""
|
|
import re
|
|
|
|
detected_bankgiro = '782-1713'
|
|
payment_line_account = '123-4567'
|
|
|
|
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
|
pl_digits = re.sub(r'\D', '', payment_line_account)
|
|
|
|
assert det_digits != pl_digits
|
|
|
|
def test_bankgiro_not_overridden(self):
|
|
"""Test that Bankgiro is NOT overridden from payment_line."""
|
|
result = InferenceResult()
|
|
result.fields = {
|
|
'Bankgiro': '999-9999', # Different value
|
|
'payment_line': 'OCR:12345 Amount:100 BG:782-1713'
|
|
}
|
|
|
|
# Bankgiro should NOT be overridden (per current logic)
|
|
# Only compared for validation
|
|
original_bankgiro = result.fields['Bankgiro']
|
|
|
|
# The override logic explicitly skips Bankgiro
|
|
# So we verify it remains unchanged
|
|
assert result.fields['Bankgiro'] == '999-9999'
|
|
assert result.fields['Bankgiro'] == original_bankgiro
|
|
|
|
|
|
class TestValidationScoring:
    """Tests for validation scoring logic."""

    def test_all_fields_match(self):
        """Three matches out of three compared fields."""
        matches = [True, True, True]  # OCR, Amount, Bankgiro
        hits = [flag for flag in matches if flag]

        match_count = len(hits)
        total = len(matches)

        assert match_count == 3
        assert total == 3

    def test_partial_match(self):
        """Two matches when one of three fields disagrees."""
        matches = [True, True, False]  # OCR match, Amount match, Bankgiro mismatch
        hits = [flag for flag in matches if flag]

        assert len(hits) == 2

    def test_no_matches(self):
        """Zero matches when every field disagrees."""
        matches = [False, False, False]
        hits = [flag for flag in matches if flag]

        assert len(hits) == 0

    def test_only_count_present_fields(self):
        """Only the account type present in payment_line contributes a match."""
        # The invoice has both BG and PG but payment_line only has BG,
        # so only the BG comparison should be counted.
        account_type = 'bankgiro'
        outcome_by_type = {
            'bankgiro': True,
            'plusgiro': None,  # not compared: payment_line has no PG
        }

        outcome = outcome_by_type.get(account_type)
        matches = [] if outcome is None else [outcome]

        assert len(matches) == 1
        assert matches[0] is True
|
|
|
|
|
|
class TestAmountNormalization:
    """Tests for amount normalization for comparison."""

    @staticmethod
    def _normalize_amount(amount):
        """Reduce an amount string to its digits, dropping a trailing ``00``.

        This is a local stand-in for the normalisation step; these tests do
        not call the pipeline itself. Every non-digit character (decimal
        comma or dot, thousands space) is removed, then a trailing ``00`` is
        stripped when more than two digits remain.
        """
        import re

        normalized = re.sub(r'[^\d]', '', amount)

        # Remove trailing zeros for öre.
        # NOTE(review): an amount written without decimals that ends in "00"
        # (e.g. "100") would also lose its last two digits here; the tests
        # below only cover inputs with an explicit ",00" / ".00" part.
        if len(normalized) > 2 and normalized.endswith('00'):
            normalized = normalized[:-2]
        return normalized

    def test_normalize_amount_with_comma(self):
        """Test normalizing amount with comma decimal."""
        assert self._normalize_amount("11699,00") == '11699'

    def test_normalize_amount_with_dot(self):
        """Test normalizing amount with dot decimal."""
        assert self._normalize_amount("11699.00") == '11699'

    def test_normalize_amount_with_space_separator(self):
        """Test normalizing amount with space thousand separator."""
        assert self._normalize_amount("11 699,00") == '11699'
|
|
|
|
|
|
if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file as a script
    # exits non-zero when a test fails (the return value used to be dropped).
    raise SystemExit(pytest.main([__file__, '-v']))
|