457 lines
16 KiB
Python
457 lines
16 KiB
Python
"""
Inference Pipeline Integration Tests

Tests the complete pipeline from input to output.

Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import numpy as np
|
|
|
|
from inference.pipeline.pipeline import (
|
|
InferencePipeline,
|
|
InferenceResult,
|
|
CrossValidationResult,
|
|
)
|
|
from inference.pipeline.yolo_detector import Detection
|
|
from inference.pipeline.field_extractor import ExtractedField
|
|
|
|
|
|
@pytest.fixture
def mock_detection():
    """Provide a sample ``Detection`` of class ``invoice_number`` with page_no 0."""
    detection = Detection(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )
    return detection
|
|
|
|
|
|
@pytest.fixture
def mock_extracted_field():
    """Provide a sample, valid ``ExtractedField`` named ``InvoiceNumber``."""
    extracted = ExtractedField(
        field_name="InvoiceNumber",
        raw_text="INV-2024-001",
        normalized_value="INV-2024-001",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )
    return extracted
|
|
|
|
|
|
class TestInferenceResultConstruction:
    """Checks on building ``InferenceResult`` objects and on their accessor methods."""

    def test_default_result(self):
        """A result built with no arguments exposes the expected empty defaults."""
        blank = InferenceResult()

        # Attributes are grouped by the default value each one must carry.
        for attr in ("document_id", "cross_validation"):
            assert getattr(blank, attr) is None
        for attr in ("success", "fallback_used"):
            assert getattr(blank, attr) is False
        for attr in ("fields", "confidence"):
            assert getattr(blank, attr) == {}
        for attr in ("raw_detections", "extracted_fields", "errors"):
            assert getattr(blank, attr) == []

    def test_result_to_json(self):
        """``to_json`` exposes the document id, status, fields, confidences and bboxes."""
        field_values = {
            "InvoiceNumber": "INV-001",
            "Amount": "1500.00",
        }
        field_scores = {
            "InvoiceNumber": 0.95,
            "Amount": 0.92,
        }
        field_boxes = {
            "InvoiceNumber": (100, 50, 200, 30),
        }
        populated = InferenceResult(
            document_id="test-doc",
            success=True,
            fields=field_values,
            confidence=field_scores,
            bboxes=field_boxes,
        )

        payload = populated.to_json()

        assert payload["DocumentId"] == "test-doc"
        assert payload["success"] is True
        assert payload["InvoiceNumber"] == "INV-001"
        assert payload["Amount"] == "1500.00"
        assert payload["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in payload

    def test_result_get_field(self):
        """``get_field`` yields (value, confidence), or (None, 0.0) for an absent field."""
        populated = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        # Field that is present.
        found_value, found_score = populated.get_field("InvoiceNumber")
        assert found_value == "INV-001"
        assert found_score == 0.95

        # Field that was never set.
        missing_value, missing_score = populated.get_field("Amount")
        assert missing_value is None
        assert missing_score == 0.0
|
|
|
|
|
|
class TestCrossValidation:
    """Checks on the ``CrossValidationResult`` container."""

    def test_cross_validation_default(self):
        """A ``CrossValidationResult`` built with no arguments has the expected defaults."""
        blank = CrossValidationResult()

        assert blank.is_valid is False
        # Every match flag and payment-line value checked here starts out unset.
        unset_attrs = (
            "ocr_match",
            "amount_match",
            "bankgiro_match",
            "plusgiro_match",
            "payment_line_ocr",
            "payment_line_amount",
        )
        for attr in unset_attrs:
            assert getattr(blank, attr) is None
        assert blank.details == []

    def test_cross_validation_with_matches(self):
        """Values passed to the constructor are readable back from the instance."""
        match_notes = ["OCR match", "Amount match", "Bankgiro match"]
        populated = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=match_notes,
        )

        for attr in ("is_valid", "ocr_match", "amount_match"):
            assert getattr(populated, attr) is True
        assert len(populated.details) == 3
|
|
|
|
|
|
class TestPipelineMergeFields:
    """Checks on how ``_merge_fields`` folds extracted fields into the result."""

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline created without its real constructor.

        ``__init__`` stays patched for the whole test, and the payment line
        parser is a mock whose ``parse`` reports an invalid payment line.
        """
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            bare = InferencePipeline.__new__(InferencePipeline)
            parser = MagicMock()
            parser.parse.return_value = MagicMock(is_valid=False)
            bare.payment_line_parser = parser
            yield bare

    def test_merge_selects_highest_confidence(self, pipeline):
        """When a field is extracted twice, the higher-confidence entry is kept."""
        weaker = ExtractedField(
            field_name="InvoiceNumber",
            raw_text="INV-001",
            normalized_value="INV-001",
            confidence=0.85,
            detection_confidence=0.90,
            ocr_confidence=0.85,
            bbox=(100, 50, 200, 30),
            page_no=0,
            is_valid=True,
        )
        stronger = ExtractedField(
            field_name="InvoiceNumber",
            raw_text="INV-001",
            normalized_value="INV-001",
            confidence=0.95,  # the higher of the two confidences
            detection_confidence=0.95,
            ocr_confidence=0.95,
            bbox=(105, 52, 198, 28),
            page_no=0,
            is_valid=True,
        )
        merged = InferenceResult()
        merged.extracted_fields = [weaker, stronger]

        pipeline._merge_fields(merged)

        assert merged.fields["InvoiceNumber"] == "INV-001"
        assert merged.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self, pipeline):
        """Extracted fields flagged as invalid are left out of the merged fields."""
        rejected = ExtractedField(
            field_name="InvoiceNumber",
            raw_text="",
            normalized_value=None,
            confidence=0.95,
            detection_confidence=0.95,
            ocr_confidence=0.95,
            bbox=(100, 50, 200, 30),
            page_no=0,
            is_valid=False,  # marks this entry as invalid
        )
        accepted = ExtractedField(
            field_name="Amount",
            raw_text="1500.00",
            normalized_value="1500.00",
            confidence=0.92,
            detection_confidence=0.92,
            ocr_confidence=0.92,
            bbox=(200, 100, 100, 25),
            page_no=0,
            is_valid=True,
        )
        merged = InferenceResult()
        merged.extracted_fields = [rejected, accepted]

        pipeline._merge_fields(merged)

        assert "InvoiceNumber" not in merged.fields
        assert merged.fields["Amount"] == "1500.00"
|
|
|
|
|
|
class TestPaymentLineValidation:
    """Checks on ``_cross_validate_payment_line`` using a mocked payment line parser."""

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline created without its real constructor.

        ``__init__`` stays patched for as long as the test runs.
        """
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            yield InferencePipeline.__new__(InferencePipeline)

    @staticmethod
    def _install_parser(pipeline, ocr_number, amount, account_number):
        """Attach a mock parser whose ``parse`` returns a valid parsed payment line."""
        parsed_line = MagicMock(
            is_valid=True,
            ocr_number=ocr_number,
            amount=amount,
            account_number=account_number,
        )
        parser = MagicMock()
        parser.parse.return_value = parsed_line
        pipeline.payment_line_parser = parser

    def test_payment_line_overrides_ocr(self, pipeline):
        """The OCR number parsed from the payment line replaces the detected one."""
        self._install_parser(
            pipeline,
            ocr_number="12345678901234",
            amount="1500.00",
            account_number="12345678",
        )
        detected_fields = {
            "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
            "OCR": "99999999999999",  # disagrees with the payment line
        }
        outcome = InferenceResult(
            fields=detected_fields,
            confidence={"OCR": 0.85},
        )

        pipeline._cross_validate_payment_line(outcome)

        # The payment line value wins, and its confidence is raised.
        assert outcome.fields["OCR"] == "12345678901234"
        assert outcome.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self, pipeline):
        """The amount parsed from the payment line replaces the detected one."""
        self._install_parser(
            pipeline,
            ocr_number=None,
            amount="2500.50",
            account_number=None,
        )
        detected_fields = {
            "payment_line": "# ... # 2500 50 5 > ...",
            "Amount": "2500.00",  # slightly off from the payment line
        }
        outcome = InferenceResult(
            fields=detected_fields,
            confidence={"Amount": 0.80},
        )

        pipeline._cross_validate_payment_line(outcome)

        assert outcome.fields["Amount"] == "2500.50"
        assert outcome.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self, pipeline):
        """Agreeing values are recorded as matches on ``cross_validation``."""
        self._install_parser(
            pipeline,
            ocr_number="12345678901234",
            amount="1500.00",
            account_number="12345678",
        )
        # All three detected values agree with the mocked payment line.
        detected_fields = {
            "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
            "OCR": "12345678901234",
            "Amount": "1500.00",
            "Bankgiro": "1234-5678",
        }
        outcome = InferenceResult(
            fields=detected_fields,
            confidence={},
        )

        pipeline._cross_validate_payment_line(outcome)

        report = outcome.cross_validation
        assert report is not None
        assert report.ocr_match is True
        assert report.amount_match is True
        assert report.is_valid is True
|
|
|
|
|
|
class TestFallbackLogic:
    """Checks on the ``_needs_fallback`` decision."""

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline created without its real constructor.

        ``__init__`` stays patched for as long as the test runs.
        """
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            yield InferencePipeline.__new__(InferencePipeline)

    def test_needs_fallback_when_key_fields_missing(self, pipeline):
        """A result holding only one key field asks for the fallback."""
        sparse = InferenceResult(fields={"Amount": "1500.00"})

        assert pipeline._needs_fallback(sparse) is True

    def test_no_fallback_when_fields_present(self, pipeline):
        """A result holding all key fields does not ask for the fallback."""
        key_fields = {
            "Amount": "1500.00",
            "InvoiceNumber": "INV-001",
            "OCR": "12345678901234",
        }
        complete = InferenceResult(fields=key_fields)

        assert pipeline._needs_fallback(complete) is False
|
|
|
|
|
|
class TestPatternExtraction:
    """Checks on ``_extract_with_patterns``, the text-pattern fallback extractor."""

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline created without its real constructor.

        ``__init__`` stays patched for as long as the test runs.
        """
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            yield InferencePipeline.__new__(InferencePipeline)

    def test_extract_amount_pattern(self, pipeline):
        """An amount found in the text is stored with confidence 0.5."""
        outcome = InferenceResult()

        pipeline._extract_with_patterns("Att betala: 1 500,00 SEK", outcome)

        assert "Amount" in outcome.fields
        assert outcome.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self, pipeline):
        """A bankgiro number found in the text is stored as written."""
        outcome = InferenceResult()

        pipeline._extract_with_patterns("Bankgiro: 1234-5678", outcome)

        assert "Bankgiro" in outcome.fields
        assert outcome.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self, pipeline):
        """An OCR number found in the text is stored as written."""
        outcome = InferenceResult()

        pipeline._extract_with_patterns("OCR: 12345678901234567890", outcome)

        assert "OCR" in outcome.fields
        assert outcome.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self, pipeline):
        """A field already present on the result is not replaced by a pattern hit."""
        outcome = InferenceResult(fields={"InvoiceNumber": "INV-001"})

        pipeline._extract_with_patterns("Fakturanr: 999", outcome)

        # The value set before extraction must survive.
        assert outcome.fields["InvoiceNumber"] == "INV-001"
|
|
|
|
|
|
class TestAmountNormalization:
    """Checks on ``_normalize_amount_for_compare``."""

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline created without its real constructor.

        ``__init__`` stays patched for as long as the test runs.
        """
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            yield InferencePipeline.__new__(InferencePipeline)

    def test_normalize_swedish_format(self, pipeline):
        """Swedish and dot-decimal spellings of 1500 all normalize to 1500.00."""
        spellings = (
            "1 500,00",  # space as thousands separator, comma as decimal
            "1500.00",  # dot as decimal
            "1500,00",  # comma as decimal, no thousands separator
        )
        for raw_amount in spellings:
            assert pipeline._normalize_amount_for_compare(raw_amount) == 1500.00

    def test_normalize_invalid_amount(self, pipeline):
        """Unparseable or empty input normalizes to ``None``."""
        for raw_amount in ("invalid", ""):
            assert pipeline._normalize_amount_for_compare(raw_amount) is None
|
|
|
|
|
|
class TestResultSerialization:
    """Checks that ``to_json`` carries the cross-validation data."""

    def test_to_json_with_cross_validation(self):
        """The serialized result has a ``cross_validation`` entry mirroring the input."""
        report = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )
        outcome = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=report,
        )

        payload = outcome.to_json()

        assert "cross_validation" in payload
        serialized_report = payload["cross_validation"]
        assert serialized_report["is_valid"] is True
        assert serialized_report["ocr_match"] is True
        assert serialized_report["payment_line_ocr"] == "12345678901234"
|