Files
invoice-master-poc-v2/tests/integration/pipeline/test_pipeline_integration.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

457 lines
16 KiB
Python

"""
Inference Pipeline Integration Tests

Exercises the building blocks of ``InferencePipeline`` together with its
result types: ``InferenceResult`` construction, field lookup and JSON
serialization, ``CrossValidationResult`` defaults, field merging,
payment-line cross-validation, fallback detection, regex pattern extraction,
and amount normalization.

Note: No YOLO or OCR models are loaded. Pipeline instances are created with
``InferencePipeline.__new__`` (which does not run ``__init__``), and any
collaborator a test needs, such as ``payment_line_parser``, is replaced with a
``MagicMock``.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import numpy as np
from backend.pipeline.pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
)
from backend.pipeline.yolo_detector import Detection
from backend.pipeline.field_extractor import ExtractedField
@pytest.fixture
def mock_detection():
    """Provide a single high-confidence ``invoice_number`` detection on page 0."""
    detection_kwargs = dict(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )
    return Detection(**detection_kwargs)
@pytest.fixture
def mock_extracted_field():
    """Provide a valid ``InvoiceNumber`` field whose raw and normalized text match."""
    invoice_no = "INV-2024-001"
    return ExtractedField(
        field_name="InvoiceNumber",
        raw_text=invoice_no,
        normalized_value=invoice_no,
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )
class TestInferenceResultConstruction:
    """Construction defaults, JSON output and field lookup of ``InferenceResult``."""

    def test_default_result(self):
        """A freshly built result is unsuccessful and holds no data."""
        result = InferenceResult()

        assert result.document_id is None
        assert result.success is False
        assert result.fallback_used is False
        assert result.cross_validation is None
        # Every container attribute starts out empty.
        for attr_name, empty in (
            ("fields", {}),
            ("confidence", {}),
            ("raw_detections", []),
            ("extracted_fields", []),
            ("errors", []),
        ):
            assert getattr(result, attr_name) == empty, attr_name

    def test_result_to_json(self):
        """``to_json`` exposes each field under its own top-level key next to metadata."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001", "Amount": "1500.00"},
            confidence={"InvoiceNumber": 0.95, "Amount": 0.92},
            bboxes={"InvoiceNumber": (100, 50, 200, 30)},
        )

        payload = result.to_json()

        assert payload["DocumentId"] == "test-doc"
        assert payload["success"] is True
        for name, expected in (("InvoiceNumber", "INV-001"), ("Amount", "1500.00")):
            assert payload[name] == expected
        assert payload["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in payload

    def test_result_get_field(self):
        """``get_field`` yields (value, confidence), or (None, 0.0) for a missing field."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        found_value, found_conf = result.get_field("InvoiceNumber")
        assert (found_value, found_conf) == ("INV-001", 0.95)

        missing_value, missing_conf = result.get_field("Amount")
        assert missing_value is None
        assert missing_conf == 0.0
class TestCrossValidation:
    """Tests for ``CrossValidationResult`` construction."""

    def test_cross_validation_default(self):
        """A default result is not valid and records no matches or payment-line data."""
        cv = CrossValidationResult()
        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Every value passed to the constructor is exposed unchanged."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )
        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        # The remaining constructor arguments were previously never asserted.
        assert cv.bankgiro_match is True
        assert cv.payment_line_ocr == "12345678901234"
        assert cv.payment_line_amount == "1500.00"
        assert cv.payment_line_account == "1234-5678"
        assert cv.payment_line_account_type == "bankgiro"
        assert cv.details == ["OCR match", "Amount match", "Bankgiro match"]
        # plusgiro_match was not supplied, so it keeps its default of None.
        assert cv.plusgiro_match is None
class TestPipelineMergeFields:
    """Tests for ``InferencePipeline._merge_fields``."""

    @staticmethod
    def _bare_pipeline():
        """Build a pipeline without ``__init__`` whose payment-line parser returns an invalid parse."""
        # Calling __new__ directly allocates the instance without invoking
        # __init__, so patching __init__ (as was done before) had no effect.
        pipeline = InferencePipeline.__new__(InferencePipeline)
        pipeline.payment_line_parser = MagicMock()
        pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
        return pipeline

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        pipeline = self._bare_pipeline()
        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="INV-001",
                normalized_value="INV-001",
                confidence=0.85,
                detection_confidence=0.90,
                ocr_confidence=0.85,
                bbox=(100, 50, 200, 30),
                page_no=0,
                is_valid=True,
            ),
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="INV-001",
                normalized_value="INV-001",
                confidence=0.95,  # Higher confidence
                detection_confidence=0.95,
                ocr_confidence=0.95,
                bbox=(105, 52, 198, 28),
                page_no=0,
                is_valid=True,
            ),
        ]

        pipeline._merge_fields(result)

        assert result.fields["InvoiceNumber"] == "INV-001"
        assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        pipeline = self._bare_pipeline()
        result = InferenceResult()
        result.extracted_fields = [
            ExtractedField(
                field_name="InvoiceNumber",
                raw_text="",
                normalized_value=None,
                confidence=0.95,
                detection_confidence=0.95,
                ocr_confidence=0.95,
                bbox=(100, 50, 200, 30),
                page_no=0,
                is_valid=False,  # Invalid
            ),
            ExtractedField(
                field_name="Amount",
                raw_text="1500.00",
                normalized_value="1500.00",
                confidence=0.92,
                detection_confidence=0.92,
                ocr_confidence=0.92,
                bbox=(200, 100, 100, 25),
                page_no=0,
                is_valid=True,
            ),
        ]

        pipeline._merge_fields(result)

        assert "InvoiceNumber" not in result.fields
        assert result.fields["Amount"] == "1500.00"
class TestPaymentLineValidation:
    """Tests for ``InferencePipeline._cross_validate_payment_line``."""

    @staticmethod
    def _pipeline_with_parse(ocr_number, amount, account_number):
        """Build a pipeline without ``__init__`` whose parser returns a valid parse with these values."""
        # __new__ allocates the instance without invoking __init__, so no
        # patching of __init__ is needed.
        pipeline = InferencePipeline.__new__(InferencePipeline)
        parsed = MagicMock()
        parsed.is_valid = True
        parsed.ocr_number = ocr_number
        parsed.amount = amount
        parsed.account_number = account_number
        pipeline.payment_line_parser = MagicMock()
        pipeline.payment_line_parser.parse.return_value = parsed
        return pipeline

    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        pipeline = self._pipeline_with_parse(
            ocr_number="12345678901234",
            amount="1500.00",
            account_number="12345678",
        )
        result = InferenceResult(
            fields={
                "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                "OCR": "99999999999999",  # Different OCR
            },
            confidence={"OCR": 0.85},
        )

        pipeline._cross_validate_payment_line(result)

        # Payment line OCR should override
        assert result.fields["OCR"] == "12345678901234"
        assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        pipeline = self._pipeline_with_parse(
            ocr_number=None,
            amount="2500.50",
            account_number=None,
        )
        result = InferenceResult(
            fields={
                "payment_line": "# ... # 2500 50 5 > ...",
                "Amount": "2500.00",  # Slightly different
            },
            confidence={"Amount": 0.80},
        )

        pipeline._cross_validate_payment_line(result)

        assert result.fields["Amount"] == "2500.50"
        assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        pipeline = self._pipeline_with_parse(
            ocr_number="12345678901234",
            amount="1500.00",
            account_number="12345678",
        )
        result = InferenceResult(
            fields={
                "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                "OCR": "12345678901234",  # Matching
                "Amount": "1500.00",  # Matching
                "Bankgiro": "1234-5678",  # Matching
            },
            confidence={},
        )

        pipeline._cross_validate_payment_line(result)

        assert result.cross_validation is not None
        assert result.cross_validation.ocr_match is True
        assert result.cross_validation.amount_match is True
        assert result.cross_validation.is_valid is True
class TestFallbackLogic:
    """Tests for ``InferencePipeline._needs_fallback``."""

    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        # __new__ allocates the instance without running __init__; the former
        # patch.object(..., "__init__", ...) wrapper had no effect and is gone.
        pipeline = InferencePipeline.__new__(InferencePipeline)
        # Only one key field present
        result = InferenceResult(fields={"Amount": "1500.00"})

        assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        pipeline = InferencePipeline.__new__(InferencePipeline)
        # All key fields present
        result = InferenceResult(
            fields={
                "Amount": "1500.00",
                "InvoiceNumber": "INV-001",
                "OCR": "12345678901234",
            }
        )

        assert pipeline._needs_fallback(result) is False
class TestPatternExtraction:
    """Tests for the regex fallback ``InferencePipeline._extract_with_patterns``."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        # __new__ allocates the instance without running __init__; the former
        # patch.object(..., "__init__", ...) wrapper had no effect and is gone.
        pipeline = InferencePipeline.__new__(InferencePipeline)
        text = "Att betala: 1 500,00 SEK"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "Amount" in result.fields
        assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        pipeline = InferencePipeline.__new__(InferencePipeline)
        text = "Bankgiro: 1234-5678"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "Bankgiro" in result.fields
        assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        pipeline = InferencePipeline.__new__(InferencePipeline)
        text = "OCR: 12345678901234567890"
        result = InferenceResult()

        pipeline._extract_with_patterns(text, result)

        assert "OCR" in result.fields
        assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        pipeline = InferencePipeline.__new__(InferencePipeline)
        text = "Fakturanr: 999"
        result = InferenceResult(fields={"InvoiceNumber": "INV-001"})

        pipeline._extract_with_patterns(text, result)

        # Should keep existing value
        assert result.fields["InvoiceNumber"] == "INV-001"
class TestAmountNormalization:
    """Tests for ``InferencePipeline._normalize_amount_for_compare``."""

    def test_normalize_swedish_format(self):
        """Test normalizing Swedish amount format."""
        # __new__ allocates the instance without running __init__; the former
        # patch.object(..., "__init__", ...) wrapper had no effect and is gone.
        pipeline = InferencePipeline.__new__(InferencePipeline)
        # Swedish format: space as thousands separator, comma as decimal
        assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
        # Standard format: dot as decimal
        assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
        # Swedish format with comma as decimal
        assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00

    def test_normalize_invalid_amount(self):
        """Test normalizing invalid amount returns None."""
        pipeline = InferencePipeline.__new__(InferencePipeline)
        assert pipeline._normalize_amount_for_compare("invalid") is None
        assert pipeline._normalize_amount_for_compare("") is None
class TestResultSerialization:
    """Serialization of ``InferenceResult`` when a cross-validation result is attached."""

    def test_to_json_with_cross_validation(self):
        """The attached ``CrossValidationResult`` appears under the ``cross_validation`` key."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=CrossValidationResult(
                is_valid=True,
                ocr_match=True,
                amount_match=True,
                payment_line_ocr="12345678901234",
                payment_line_amount="1500.00",
                details=["OCR match", "Amount match"],
            ),
        )

        serialized = result.to_json()

        assert "cross_validation" in serialized
        cv_json = serialized["cross_validation"]
        assert cv_json["is_valid"] is True
        assert cv_json["ocr_match"] is True
        assert cv_json["payment_line_ocr"] == "12345678901234"