"""
Tests for Inference Pipeline

Tests the cross-validation logic between payment_line and detected fields:
- OCR override from payment_line
- Amount override from payment_line
- Bankgiro/Plusgiro comparison (no override)
- Validation scoring
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from backend.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
|
|
|
|
|
|
class TestCrossValidationResult:
    """Tests for CrossValidationResult dataclass."""

    def test_default_values(self):
        """A freshly constructed result has every field unset."""
        cv = CrossValidationResult()
        unset_attrs = (
            'ocr_match',
            'amount_match',
            'bankgiro_match',
            'plusgiro_match',
            'payment_line_ocr',
            'payment_line_amount',
            'payment_line_account',
            'payment_line_account_type',
        )
        for attr in unset_attrs:
            assert getattr(cv, attr) is None

    def test_attributes(self):
        """Attributes can be assigned and read back."""
        cv = CrossValidationResult()
        cv.ocr_match = True
        cv.amount_match = True
        cv.payment_line_ocr = '12345678901'
        cv.payment_line_amount = '100'
        cv.details = ['OCR match', 'Amount match']

        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert cv.payment_line_ocr == '12345678901'
        assert 'OCR match' in cv.details
class TestInferenceResult:
    """Tests for InferenceResult dataclass."""

    def test_default_fields(self):
        """New results start with empty fields/confidence/errors."""
        fresh = InferenceResult()
        assert fresh.fields == {}
        assert fresh.confidence == {}
        assert fresh.errors == []

    def test_set_fields(self):
        """Field values and confidences can be stored and read back."""
        fresh = InferenceResult()
        expected = {'OCR': '12345678901', 'Amount': '100', 'Bankgiro': '782-1713'}
        fresh.fields = dict(expected)
        fresh.confidence = {'OCR': 0.95, 'Amount': 0.90, 'Bankgiro': 0.88}

        for key, value in expected.items():
            assert fresh.fields[key] == value

    def test_cross_validation_assignment(self):
        """A CrossValidationResult can be attached to the result."""
        fresh = InferenceResult()
        fresh.fields = {'OCR': '12345678901'}

        validation = CrossValidationResult()
        validation.ocr_match = True
        validation.payment_line_ocr = '12345678901'
        fresh.cross_validation = validation

        assert fresh.cross_validation is not None
        assert fresh.cross_validation.ocr_match is True
class TestPaymentLineParsingInPipeline:
    """Tests for payment_line parsing in cross-validation.

    The 'split on whitespace, then split each part on the first colon'
    logic was previously copy-pasted into every test; it now lives in a
    single helper so the simulated pipeline behavior is defined once.
    """

    @staticmethod
    def _parse_payment_line(payment_line):
        """Replicate the pipeline's 'KEY:value' payment-line parsing.

        Whitespace-separated parts are split on the first ':' and keyed by
        the upper-cased prefix; parts without a colon are ignored.
        Returns a dict such as {'OCR': ..., 'AMOUNT': ..., 'BG': ...}.
        """
        pl_parts = {}
        for part in payment_line.split():
            if ':' in part:
                key, value = part.split(':', 1)
                pl_parts[key.upper()] = value
        return pl_parts

    def test_parse_payment_line_format(self):
        """Test parsing of payment_line format: OCR:xxx Amount:xxx BG:xxx"""
        pl_parts = self._parse_payment_line(
            "OCR:310196187399952 Amount:11699 BG:782-1713"
        )
        assert pl_parts.get('OCR') == '310196187399952'
        assert pl_parts.get('AMOUNT') == '11699'
        assert pl_parts.get('BG') == '782-1713'

    def test_parse_payment_line_with_plusgiro(self):
        """Test parsing with Plusgiro."""
        pl_parts = self._parse_payment_line("OCR:12345678901 Amount:500 PG:1234567-8")
        assert pl_parts.get('OCR') == '12345678901'
        assert pl_parts.get('PG') == '1234567-8'
        assert pl_parts.get('BG') is None

    def test_parse_empty_payment_line(self):
        """Test parsing empty payment_line."""
        pl_parts = self._parse_payment_line("")
        assert pl_parts.get('OCR') is None
        assert pl_parts.get('AMOUNT') is None
class TestOCROverride:
    """Tests for OCR override logic."""

    def test_ocr_override_when_different(self):
        """Test OCR is overridden when payment_line value differs."""
        result = InferenceResult()
        result.fields = {
            'OCR': 'wrong_ocr_12345',
            'payment_line': 'OCR:correct_ocr_67890 Amount:100 BG:782-1713',
        }

        # Replicate the pipeline's override logic: parse KEY:value pairs
        # out of the payment line, then prefer its OCR over the detected one.
        parsed = {}
        for token in str(result.fields.get('payment_line')).split():
            if ':' not in token:
                continue
            key, _, value = token.partition(':')
            parsed[key.upper()] = value

        override = parsed.get('OCR')
        if override:
            result.fields['OCR'] = override

        assert result.fields['OCR'] == 'correct_ocr_67890'

    def test_ocr_no_override_when_no_payment_line(self):
        """Test OCR is not overridden when no payment_line."""
        result = InferenceResult()
        result.fields = {'OCR': 'original_ocr_12345'}

        # Without a payment_line there is nothing to override from.
        assert result.fields['OCR'] == 'original_ocr_12345'
class TestAmountOverride:
    """Tests for Amount override logic."""

    def test_amount_override(self):
        """Test Amount is overridden from payment_line."""
        result = InferenceResult()
        result.fields = {
            'Amount': '999.00',
            'payment_line': 'OCR:12345 Amount:11699 BG:782-1713',
        }

        # Parse the payment line the same way the pipeline does.
        parsed = {}
        for token in str(result.fields.get('payment_line')).split():
            if ':' not in token:
                continue
            key, _, value = token.partition(':')
            parsed[key.upper()] = value

        pl_amount = parsed.get('AMOUNT')
        if pl_amount:
            result.fields['Amount'] = pl_amount

        assert result.fields['Amount'] == '11699'
class TestBankgiroComparison:
|
|
"""Tests for Bankgiro comparison (no override)."""
|
|
|
|
def test_bankgiro_match(self):
|
|
"""Test Bankgiro match detection."""
|
|
import re
|
|
|
|
detected_bankgiro = '782-1713'
|
|
payment_line_account = '782-1713'
|
|
|
|
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
|
pl_digits = re.sub(r'\D', '', payment_line_account)
|
|
|
|
assert det_digits == pl_digits
|
|
assert det_digits == '7821713'
|
|
|
|
def test_bankgiro_mismatch(self):
|
|
"""Test Bankgiro mismatch detection."""
|
|
import re
|
|
|
|
detected_bankgiro = '782-1713'
|
|
payment_line_account = '123-4567'
|
|
|
|
det_digits = re.sub(r'\D', '', detected_bankgiro)
|
|
pl_digits = re.sub(r'\D', '', payment_line_account)
|
|
|
|
assert det_digits != pl_digits
|
|
|
|
def test_bankgiro_not_overridden(self):
|
|
"""Test that Bankgiro is NOT overridden from payment_line."""
|
|
result = InferenceResult()
|
|
result.fields = {
|
|
'Bankgiro': '999-9999', # Different value
|
|
'payment_line': 'OCR:12345 Amount:100 BG:782-1713'
|
|
}
|
|
|
|
# Bankgiro should NOT be overridden (per current logic)
|
|
# Only compared for validation
|
|
original_bankgiro = result.fields['Bankgiro']
|
|
|
|
# The override logic explicitly skips Bankgiro
|
|
# So we verify it remains unchanged
|
|
assert result.fields['Bankgiro'] == '999-9999'
|
|
assert result.fields['Bankgiro'] == original_bankgiro
|
|
|
|
|
|
class TestValidationScoring:
    """Tests for validation scoring logic."""

    def test_all_fields_match(self):
        """Test score when all fields match."""
        matches = [True, True, True]  # OCR, Amount, Bankgiro
        assert sum(map(bool, matches)) == 3
        assert len(matches) == 3

    def test_partial_match(self):
        """Test score with partial matches."""
        matches = [True, True, False]  # OCR match, Amount match, Bankgiro mismatch
        assert sum(map(bool, matches)) == 2

    def test_no_matches(self):
        """Test score when nothing matches."""
        assert sum(map(bool, [False, False, False])) == 0

    def test_only_count_present_fields(self):
        """Test that only present fields are counted."""
        # When the invoice shows both BG and PG but the payment line only
        # carries BG, validation must only score the BG comparison.
        account_type = 'bankgiro'
        bankgiro_match = True
        plusgiro_match = None  # Not compared: payment_line has no PG

        counted = []
        if account_type == 'bankgiro' and bankgiro_match is not None:
            counted.append(bankgiro_match)
        elif account_type == 'plusgiro' and plusgiro_match is not None:
            counted.append(plusgiro_match)

        assert len(counted) == 1
        assert counted[0] is True
class TestAmountNormalization:
    """Tests for amount normalization for comparison.

    The strip-digits-then-drop-öre logic was previously duplicated in all
    three tests; it now lives in one helper mirroring the pipeline.
    """

    @staticmethod
    def _normalize(amount):
        """Normalize an amount string the way the pipeline compares them.

        Strips every non-digit character, then drops a trailing '00'
        (the öre part) so '11 699,00', '11699.00' and '11699' all
        normalize to the same digit string.
        """
        import re

        digits = re.sub(r'[^\d]', '', amount)
        if len(digits) > 2 and digits[-2:] == '00':
            digits = digits[:-2]
        return digits

    def test_normalize_amount_with_comma(self):
        """Test normalizing amount with comma decimal."""
        assert self._normalize("11699,00") == '11699'

    def test_normalize_amount_with_dot(self):
        """Test normalizing amount with dot decimal."""
        assert self._normalize("11699.00") == '11699'

    def test_normalize_amount_with_space_separator(self):
        """Test normalizing amount with space thousand separator."""
        assert self._normalize("11 699,00") == '11699'
class TestBusinessFeatures:
    """Tests for business invoice features (line items, VAT, validation)."""

    def test_inference_result_has_business_fields(self):
        """Test that InferenceResult has business feature fields."""
        res = InferenceResult()
        assert res.line_items is None
        assert res.vat_summary is None
        assert res.vat_validation is None

    def test_to_json_without_business_features(self):
        """Test to_json works without business features."""
        res = InferenceResult()
        res.fields = {'InvoiceNumber': '12345'}
        res.confidence = {'InvoiceNumber': 0.95}

        payload = res.to_json()

        assert payload['InvoiceNumber'] == '12345'
        for absent_key in ('line_items', 'vat_summary', 'vat_validation'):
            assert absent_key not in payload

    def test_to_json_with_line_items(self):
        """Test to_json includes line items when present."""
        from backend.table.line_items_extractor import LineItem, LineItemsResult

        first_item = LineItem(
            row_index=0,
            description="Product A",
            quantity="2",
            unit_price="5000,00",
            amount="10000,00",
            vat_rate="25",
            confidence=0.9,
        )
        res = InferenceResult()
        res.fields = {'Amount': '12500.00'}
        res.line_items = LineItemsResult(
            items=[first_item],
            header_row=["Beskrivning", "Antal", "Pris", "Belopp", "Moms"],
            raw_html="<table>...</table>",
        )

        payload = res.to_json()

        assert 'line_items' in payload
        serialized_items = payload['line_items']['items']
        assert len(serialized_items) == 1
        assert serialized_items[0]['description'] == "Product A"
        assert serialized_items[0]['amount'] == "10000,00"

    def test_to_json_with_vat_summary(self):
        """Test to_json includes VAT summary when present."""
        from backend.vat.vat_extractor import VATBreakdown, VATSummary

        breakdown = VATBreakdown(
            rate=25.0, base_amount="10000,00", vat_amount="2500,00", source="regex"
        )
        res = InferenceResult()
        res.vat_summary = VATSummary(
            breakdowns=[breakdown],
            total_excl_vat="10000,00",
            total_vat="2500,00",
            total_incl_vat="12500,00",
            confidence=0.9,
        )

        payload = res.to_json()

        assert 'vat_summary' in payload
        assert len(payload['vat_summary']['breakdowns']) == 1
        assert payload['vat_summary']['breakdowns'][0]['rate'] == 25.0
        assert payload['vat_summary']['total_incl_vat'] == "12500,00"

    def test_to_json_with_vat_validation(self):
        """Test to_json includes VAT validation when present."""
        from backend.validation.vat_validator import VATValidationResult, MathCheckResult

        math_check = MathCheckResult(
            rate=25.0,
            base_amount=10000.0,
            expected_vat=2500.0,
            actual_vat=2500.0,
            is_valid=True,
            tolerance=0.5,
        )
        res = InferenceResult()
        res.vat_validation = VATValidationResult(
            is_valid=True,
            confidence_score=0.95,
            math_checks=[math_check],
            total_check=True,
            line_items_vs_summary=True,
            amount_consistency=True,
            needs_review=False,
            review_reasons=[],
        )

        payload = res.to_json()

        assert 'vat_validation' in payload
        assert payload['vat_validation']['is_valid'] is True
        assert payload['vat_validation']['confidence_score'] == 0.95
        assert len(payload['vat_validation']['math_checks']) == 1
class TestBusinessFeaturesAvailable:
    """Tests for BUSINESS_FEATURES_AVAILABLE flag."""

    def test_business_features_available(self):
        """The backend must report business features as enabled."""
        from backend.pipeline import BUSINESS_FEATURES_AVAILABLE

        assert BUSINESS_FEATURES_AVAILABLE is True
class TestExtractBusinessFeaturesErrorHandling:
    """Tests for _extract_business_features error handling."""

    def test_pipeline_module_has_logger(self):
        """Test that pipeline module defines logger correctly."""
        from backend.pipeline import pipeline

        assert hasattr(pipeline, 'logger')
        assert pipeline.logger is not None

    def _stub_pipeline(self):
        """Build a pipeline whose extractors are MagicMocks, bypassing __init__."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kwargs: None):
            stub = InferencePipeline()
        stub.line_items_extractor = MagicMock()
        stub.vat_extractor = MagicMock()
        stub.vat_validator = MagicMock()
        return stub

    def test_extract_business_features_logs_errors(self):
        """Errors raised by extractors are recorded with their type name."""
        stub = self._stub_pipeline()
        stub.line_items_extractor.extract_from_pdf.side_effect = ValueError(
            "Test error message"
        )

        outcome = InferenceResult()
        stub._extract_business_features("/fake/path.pdf", outcome, "full text")

        # The captured error must mention both the exception type and message.
        assert len(outcome.errors) == 1
        assert "ValueError" in outcome.errors[0]
        assert "Test error message" in outcome.errors[0]

    def test_extract_business_features_handles_numeric_exceptions(self):
        """Exceptions whose str() is uninformative still record the type."""

        # Simulates an exception that stringifies to a bare number (like an
        # exit code) — the type name must still appear in the error entry.
        class NumericException(Exception):
            def __str__(self):
                return "0"

        stub = self._stub_pipeline()
        stub.line_items_extractor.extract_from_pdf.side_effect = NumericException()

        outcome = InferenceResult()
        stub._extract_business_features("/fake/path.pdf", outcome, "full text")

        assert len(outcome.errors) == 1
        assert "NumericException" in outcome.errors[0]
class TestProcessPdfTokenPath:
    """Tests for PDF text token extraction path in process_pdf()."""

    def _make_pipeline(self):
        """Create pipeline with mocked internals, bypassing __init__.

        Since __init__ is replaced with a no-op, every attribute that
        process_pdf reads is assigned by hand with inert defaults
        (fallback and business features off, mocked detector/extractor).
        """
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
            p.detector = MagicMock()
            p.extractor = MagicMock()
            p.payment_line_parser = MagicMock()
            p.dpi = 300
            p.enable_fallback = False
            p.enable_business_features = False
            p.vat_tolerance = 0.5
            p.line_items_extractor = None
            p.vat_extractor = None
            p.vat_validator = None
            p._business_ocr_engine = None
            p._table_detector = None
            return p

    def _make_detection(self, class_name='Amount', confidence=0.85, page_no=0):
        """Create a Detection object with a fixed bbox at (100, 200)-(300, 250)."""
        from backend.pipeline.yolo_detector import Detection
        return Detection(
            class_id=6,
            class_name=class_name,
            confidence=confidence,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=page_no,
        )

    def _make_extracted_field(self, field_name='Amount', raw_text='2.254,50',
                              normalized='2254.50', confidence=0.85):
        """Create an ExtractedField object matching the detection's bbox."""
        from backend.pipeline.field_extractor import ExtractedField
        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized,
            confidence=confidence,
            detection_confidence=confidence,
            ocr_confidence=1.0,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=0,
        )

    def _make_image_bytes(self):
        """Create minimal valid PNG bytes (100x100 white image)."""
        from PIL import Image as PILImage
        import io as _io
        img = PILImage.new('RGB', (100, 100), color='white')
        buf = _io.BytesIO()
        img.save(buf, format='PNG')
        return buf.getvalue()

    # NOTE: @patch decorators apply bottom-up, so the renderer patch arrives
    # as the first mock argument and the PDFDocument patch as the second.
    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_uses_pdf_tokens(self, mock_render, mock_pdf_doc_cls):
        """When PDF is text-based, extract_from_detection_with_pdf is used."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # Setup PDFDocument mock - text PDF with tokens.
        # PDFDocument is used as a context manager, so __enter__/__exit__
        # are wired by hand to yield the configured mock document.
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [Token(text="2.254,50", bbox=(100, 200, 200, 220), page_no=0)]
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )

        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')

        # Token-based extraction must be chosen over OCR for text PDFs.
        pipeline.extractor.extract_from_detection_with_pdf.assert_called_once()
        pipeline.extractor.extract_from_detection.assert_not_called()
        assert result.fields.get('Amount') == '2254.50'
        assert result.success is True

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_scanned_pdf_uses_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF is scanned, extract_from_detection (OCR) is used."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # is_text_pdf() -> False marks the document as scanned.
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = False
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )

        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_pdf_detection_error_falls_back_to_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF text detection throws, fall back to OCR."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # Entering the PDFDocument context raises, simulating a corrupt file.
        mock_ctx = MagicMock()
        mock_ctx.__enter__ = MagicMock(side_effect=Exception("corrupt PDF"))
        mock_ctx.__exit__ = MagicMock(return_value=False)
        mock_pdf_doc_cls.return_value = mock_ctx

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )

        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')

        # The failure must route extraction through OCR, not the token path.
        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_passes_correct_args(self, mock_render, mock_pdf_doc_cls):
        """Verify correct token list and image dimensions are passed."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()  # 100x100 PNG

        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [
            Token(text="Fakturabelopp:", bbox=(50, 190, 100, 210), page_no=0),
            Token(text="2.254,50", bbox=(105, 190, 180, 210), page_no=0),
            Token(text="SEK", bbox=(185, 190, 210, 210), page_no=0),
        ]
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )

        mock_render.return_value = iter([(0, image_bytes)])
        pipeline.process_pdf('/fake/invoice.pdf')

        # Positional args: (detection, tokens, image_width, image_height, ...).
        call_args = pipeline.extractor.extract_from_detection_with_pdf.call_args[0]
        assert call_args[0] == detection
        assert len(call_args[1]) == 3  # 3 tokens passed
        assert call_args[2] == 100  # image width
        assert call_args[3] == 100  # image height
class TestDpiPassthrough:
    """Tests for DPI being passed from pipeline to FieldExtractor (Bug 1)."""

    @staticmethod
    def _construct_with_dpi(dpi):
        """Build a pipeline with mocked detector/extractor; return the FieldExtractor mock class."""
        with patch('backend.pipeline.pipeline.YOLODetector'):
            with patch('backend.pipeline.pipeline.FieldExtractor') as mock_fe_cls:
                InferencePipeline(
                    model_path='/fake/model.pt',
                    dpi=dpi,
                    use_gpu=False,
                )
                return mock_fe_cls

    def test_field_extractor_receives_pipeline_dpi(self):
        """FieldExtractor should receive the pipeline's DPI, not default to 300."""
        mock_fe_cls = self._construct_with_dpi(150)
        mock_fe_cls.assert_called_once_with(ocr_lang='en', use_gpu=False, dpi=150)

    def test_field_extractor_receives_default_dpi(self):
        """When dpi=300 (default), FieldExtractor should also get 300."""
        mock_fe_cls = self._construct_with_dpi(300)
        mock_fe_cls.assert_called_once_with(ocr_lang='en', use_gpu=False, dpi=300)
class TestFallbackPatternExtraction:
    """Tests for _extract_with_patterns fallback regex (Bugs 2, 3).

    Every test previously repeated the same three-line pipeline/result
    setup; `_fields_from` centralizes it so each test states only its
    input text and expectation.
    """

    def _make_pipeline_with_patterns(self):
        """Create pipeline with mocked internals for pattern testing."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
            p.dpi = 150
            p.enable_fallback = True
            return p

    def _fields_from(self, text):
        """Run fallback pattern extraction over *text*, returning result.fields."""
        result = InferenceResult()
        self._make_pipeline_with_patterns()._extract_with_patterns(text, result)
        return result.fields

    def test_bankgiro_no_match_in_org_number(self):
        """Bankgiro regex must NOT match digits embedded in an org number."""
        assert 'Bankgiro' not in self._fields_from("Org.nr 802546-1610")

    def test_bankgiro_matches_labeled(self):
        """Bankgiro regex should match when preceded by 'Bankgiro' label."""
        assert self._fields_from("Bankgiro 5393-9484").get('Bankgiro') == '5393-9484'

    def test_bankgiro_matches_standalone(self):
        """Bankgiro regex should match a standalone 4-4 digit pattern."""
        fields = self._fields_from("Betala till 5393-9484 senast")
        assert fields.get('Bankgiro') == '5393-9484'

    def test_amount_rejects_bare_integer(self):
        """Amount regex must NOT match bare integers like 'Summa 1'."""
        assert 'Amount' not in self._fields_from("Summa 1 Medlemsavgift")

    def test_amount_requires_decimal(self):
        """Amount regex should require a decimal separator."""
        assert 'Amount' not in self._fields_from("Total 5 items")

    def test_amount_with_decimal_works(self):
        """Amount regex should match Swedish decimal amounts."""
        fields = self._fields_from("Att betala 1 234,56 SEK")
        assert 'Amount' in fields
        assert float(fields['Amount']) == pytest.approx(1234.56, abs=0.01)

    def test_amount_with_sek_suffix(self):
        """Amount regex should match amounts ending with SEK."""
        fields = self._fields_from("7 500,00 SEK")
        assert 'Amount' in fields
        assert float(fields['Amount']) == pytest.approx(7500.00, abs=0.01)

    def test_fallback_extracts_invoice_date(self):
        """Fallback should extract InvoiceDate from Swedish text."""
        fields = self._fields_from("Fakturadatum 2025-01-15 Referens ABC")
        assert fields.get('InvoiceDate') == '2025-01-15'

    def test_fallback_extracts_due_date(self):
        """Fallback should extract InvoiceDueDate from Swedish text."""
        fields = self._fields_from("Forfallodag 2025-02-15 Belopp")
        assert fields.get('InvoiceDueDate') == '2025-02-15'

    def test_fallback_extracts_supplier_org(self):
        """Fallback should extract supplier_organisation_number."""
        fields = self._fields_from("Org.nr 556123-4567 Stockholm")
        assert fields.get('supplier_organisation_number') == '556123-4567'

    def test_fallback_extracts_plusgiro(self):
        """Fallback should extract Plusgiro number."""
        assert 'Plusgiro' in self._fields_from("Plusgiro 12 34 56-7 betalning")

    def test_fallback_skips_year_as_invoice_number(self):
        """Fallback should NOT extract year-like value as InvoiceNumber."""
        fields = self._fields_from("Fakturanr 2025 Datum 2025-01-15")
        assert 'InvoiceNumber' not in fields

    def test_fallback_accepts_valid_invoice_number(self):
        """Fallback should extract valid non-year InvoiceNumber."""
        assert self._fields_from("Fakturanr 12345 Summa").get('InvoiceNumber') == '12345'
class TestDateValidation:
    """Tests for InvoiceDueDate < InvoiceDate validation (Bug 6).

    The ExtractedField construction boilerplate was triplicated across
    the tests; `_date_field` and `_merge_dates` define it once.
    """

    def _make_pipeline_for_merge(self):
        """Create pipeline with mocked internals for merge testing."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
            p.payment_line_parser = MagicMock()
            p.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
            return p

    @staticmethod
    def _date_field(field_name, value, bbox):
        """Build an ExtractedField for a date with the fixed test confidences."""
        from backend.pipeline.field_extractor import ExtractedField

        return ExtractedField(
            field_name=field_name, raw_text=value,
            normalized_value=value, confidence=0.9,
            detection_confidence=0.9, ocr_confidence=1.0,
            bbox=bbox, page_no=0,
        )

    def _merge_dates(self, invoice_date, due_date):
        """Merge a result holding the two dates; return the merged result."""
        result = InferenceResult()
        result.extracted_fields = [
            self._date_field('InvoiceDate', invoice_date, (0, 0, 100, 50)),
            self._date_field('InvoiceDueDate', due_date, (0, 60, 100, 110)),
        ]
        self._make_pipeline_for_merge()._merge_fields(result)
        return result

    def test_due_date_before_invoice_date_dropped(self):
        """DueDate earlier than InvoiceDate should be removed."""
        result = self._merge_dates('2026-01-16', '2025-12-01')
        assert 'InvoiceDate' in result.fields
        assert 'InvoiceDueDate' not in result.fields

    def test_valid_dates_preserved(self):
        """Both dates kept when DueDate >= InvoiceDate."""
        result = self._merge_dates('2026-01-16', '2026-02-15')
        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-02-15'

    def test_same_dates_preserved(self):
        """Same InvoiceDate and DueDate should both be kept."""
        result = self._merge_dates('2026-01-16', '2026-01-16')
        assert result.fields['InvoiceDate'] == '2026-01-16'
        assert result.fields['InvoiceDueDate'] == '2026-01-16'
class TestCrossFieldDedup:
    """Tests for cross-field deduplication of InvoiceNumber vs OCR/Bankgiro."""

    def _make_pipeline_for_merge(self):
        """Build an InferencePipeline with its constructor bypassed and a
        payment-line parser stub that always reports an invalid parse."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            pipeline = InferencePipeline()
        pipeline.payment_line_parser = MagicMock()
        pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
        return pipeline

    def _make_extracted_field(self, field_name, raw_text, normalized, confidence=0.9):
        """Construct an ExtractedField with a fixed bbox/page and the given values."""
        from backend.pipeline.field_extractor import ExtractedField

        kwargs = {
            'field_name': field_name,
            'raw_text': raw_text,
            'normalized_value': normalized,
            'confidence': confidence,
            'detection_confidence': confidence,
            'ocr_confidence': 1.0,
            'bbox': (0, 0, 100, 50),
            'page_no': 0,
        }
        return ExtractedField(**kwargs)

    def _run_merge(self, specs):
        """Merge (field_name, raw, normalized) triples and return the result."""
        pipeline = self._make_pipeline_for_merge()
        merged = InferenceResult()
        merged.extracted_fields = [
            self._make_extracted_field(name, raw, norm) for name, raw, norm in specs
        ]
        pipeline._merge_fields(merged)
        return merged

    def test_invoice_number_not_same_as_ocr(self):
        """When InvoiceNumber == OCR, InvoiceNumber should be dropped."""
        merged = self._run_merge([
            ('InvoiceNumber', '9179845608', '9179845608'),
            ('OCR', '9179845608', '9179845608'),
            ('Amount', '1234,56', '1234.56'),
        ])
        assert 'OCR' in merged.fields
        assert merged.fields['OCR'] == '9179845608'
        assert 'InvoiceNumber' not in merged.fields

    def test_invoice_number_not_same_as_bankgiro_digits(self):
        """When InvoiceNumber digits == Bankgiro digits, InvoiceNumber should be dropped."""
        merged = self._run_merge([
            ('InvoiceNumber', '53939484', '53939484'),
            ('Bankgiro', '5393-9484', '5393-9484'),
            ('Amount', '500,00', '500.00'),
        ])
        assert 'Bankgiro' in merged.fields
        assert merged.fields['Bankgiro'] == '5393-9484'
        assert 'InvoiceNumber' not in merged.fields

    def test_unrelated_values_kept(self):
        """When InvoiceNumber, OCR, and Bankgiro are all different, keep all."""
        merged = self._run_merge([
            ('InvoiceNumber', '19061', '19061'),
            ('OCR', '9179845608', '9179845608'),
            ('Bankgiro', '5393-9484', '5393-9484'),
        ])
        assert merged.fields['InvoiceNumber'] == '19061'
        assert merged.fields['OCR'] == '9179845608'
        assert merged.fields['Bankgiro'] == '5393-9484'

    def test_dedup_after_fallback_re_add(self):
        """Dedup should remove InvoiceNumber re-added by fallback if it matches OCR."""
        pipeline = self._make_pipeline_for_merge()
        merged = InferenceResult()

        # Simulate state after fallback re-adds InvoiceNumber = OCR
        merged.fields = {
            'OCR': '758200602426',
            'Amount': '164.00',
            'InvoiceNumber': '758200602426',  # re-added by fallback
        }
        merged.confidence = {
            'OCR': 0.9,
            'Amount': 0.9,
            'InvoiceNumber': 0.5,  # fallback confidence
        }
        merged.bboxes = {}

        pipeline._dedup_invoice_number(merged)

        assert 'InvoiceNumber' not in merged.fields
        assert 'OCR' in merged.fields

    def test_invoice_number_substring_of_bankgiro(self):
        """When InvoiceNumber digits are a substring of Bankgiro digits, drop InvoiceNumber."""
        merged = self._run_merge([
            ('InvoiceNumber', '4639', '4639'),
            ('Bankgiro', '134-4639', '134-4639'),
            ('Amount', '500,00', '500.00'),
        ])
        assert 'Bankgiro' in merged.fields
        assert merged.fields['Bankgiro'] == '134-4639'
        assert 'InvoiceNumber' not in merged.fields

    def test_invoice_number_not_substring_of_unrelated_bankgiro(self):
        """When InvoiceNumber is NOT a substring of Bankgiro, keep both."""
        merged = self._run_merge([
            ('InvoiceNumber', '19061', '19061'),
            ('Bankgiro', '5393-9484', '5393-9484'),
            ('Amount', '500,00', '500.00'),
        ])
        assert merged.fields['InvoiceNumber'] == '19061'
        assert merged.fields['Bankgiro'] == '5393-9484'
|
|
|
|
|
|
class TestFallbackTrigger:
    """Tests for _needs_fallback trigger threshold."""

    def _make_pipeline(self):
        """Instantiate InferencePipeline while bypassing its real constructor."""
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            return InferencePipeline()

    def test_fallback_triggers_when_1_key_field_missing(self):
        """Should trigger when only 1 key field (e.g. InvoiceNumber) is missing."""
        pipeline = self._make_pipeline()
        result = InferenceResult()
        result.fields = dict(
            Amount='1234.56',
            OCR='12345678901',
            InvoiceDate='2025-01-15',
            InvoiceDueDate='2025-02-15',
            supplier_organisation_number='556123-4567',
        )

        # InvoiceNumber missing -> should trigger
        assert pipeline._needs_fallback(result) is True

    def test_fallback_triggers_when_dates_missing(self):
        """Should trigger when all key fields present but 2+ important fields missing."""
        pipeline = self._make_pipeline()
        result = InferenceResult()
        result.fields = dict(
            Amount='1234.56',
            InvoiceNumber='12345',
            OCR='12345678901',
        )

        # InvoiceDate, InvoiceDueDate, supplier_org all missing -> should trigger
        assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_all_fields_present(self):
        """Should NOT trigger when all key and important fields present."""
        pipeline = self._make_pipeline()
        result = InferenceResult()
        result.fields = dict(
            Amount='1234.56',
            InvoiceNumber='12345',
            OCR='12345678901',
            InvoiceDate='2025-01-15',
            InvoiceDueDate='2025-02-15',
            supplier_organisation_number='556123-4567',
        )

        assert pipeline._needs_fallback(result) is False
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this module directly: executes the suite verbosely.
    _cli_args = [__file__, '-v']
    pytest.main(_cli_args)
|