Add integration tests for the inference pipeline
1 tests/integration/pipeline/__init__.py Normal file
@@ -0,0 +1 @@
"""Pipeline integration tests."""
456 tests/integration/pipeline/test_pipeline_integration.py Normal file
@@ -0,0 +1,456 @@
"""
Inference Pipeline Integration Tests

Tests the complete pipeline from input to output.
Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
import numpy as np

from inference.pipeline.pipeline import (
    InferencePipeline,
    InferenceResult,
    CrossValidationResult,
)
from inference.pipeline.yolo_detector import Detection
from inference.pipeline.field_extractor import ExtractedField


@pytest.fixture
def mock_detection():
    """Create a mock detection."""
    return Detection(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )


@pytest.fixture
def mock_extracted_field():
    """Create a mock extracted field."""
    return ExtractedField(
        field_name="InvoiceNumber",
        raw_text="INV-2024-001",
        normalized_value="INV-2024-001",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )


class TestInferenceResultConstruction:
    """Tests for InferenceResult construction and methods."""

    def test_default_result(self):
        """Test default InferenceResult values."""
        result = InferenceResult()

        assert result.document_id is None
        assert result.success is False
        assert result.fields == {}
        assert result.confidence == {}
        assert result.raw_detections == []
        assert result.extracted_fields == []
        assert result.errors == []
        assert result.fallback_used is False
        assert result.cross_validation is None

    def test_result_to_json(self):
        """Test JSON serialization of result."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={
                "InvoiceNumber": "INV-001",
                "Amount": "1500.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "Amount": 0.92,
            },
            bboxes={
                "InvoiceNumber": (100, 50, 200, 30),
            },
        )

        json_data = result.to_json()

        assert json_data["DocumentId"] == "test-doc"
        assert json_data["success"] is True
        assert json_data["InvoiceNumber"] == "INV-001"
        assert json_data["Amount"] == "1500.00"
        assert json_data["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in json_data

    def test_result_get_field(self):
        """Test getting field value and confidence."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        value, conf = result.get_field("InvoiceNumber")
        assert value == "INV-001"
        assert conf == 0.95

        value, conf = result.get_field("Amount")
        assert value is None
        assert conf == 0.0


class TestCrossValidation:
    """Tests for cross-validation logic."""

    def test_cross_validation_default(self):
        """Test default CrossValidationResult values."""
        cv = CrossValidationResult()

        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Test CrossValidationResult with matches."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )

        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert len(cv.details) == 3


class TestPipelineMergeFields:
    """Tests for field merging logic."""

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        # Create mock pipeline with minimal mocking
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.85,
                    detection_confidence=0.90,
                    ocr_confidence=0.85,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=True,
                ),
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.95,  # Higher confidence
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(105, 52, 198, 28),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert result.fields["InvoiceNumber"] == "INV-001"
            assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="",
                    normalized_value=None,
                    confidence=0.95,
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=False,  # Invalid
                ),
                ExtractedField(
                    field_name="Amount",
                    raw_text="1500.00",
                    normalized_value="1500.00",
                    confidence=0.92,
                    detection_confidence=0.92,
                    ocr_confidence=0.92,
                    bbox=(200, 100, 100, 25),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert "InvoiceNumber" not in result.fields
            assert result.fields["Amount"] == "1500.00"


class TestPaymentLineValidation:
    """Tests for payment line cross-validation."""

    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Mock payment line parser
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "99999999999999",  # Different OCR
                },
                confidence={"OCR": 0.85},
            )

            pipeline._cross_validate_payment_line(result)

            # Payment line OCR should override
            assert result.fields["OCR"] == "12345678901234"
            assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = None
            mock_parsed.amount = "2500.50"
            mock_parsed.account_number = None

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# ... # 2500 50 5 > ...",
                    "Amount": "2500.00",  # Slightly different
                },
                confidence={"Amount": 0.80},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.fields["Amount"] == "2500.50"
            assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "12345678901234",  # Matching
                    "Amount": "1500.00",  # Matching
                    "Bankgiro": "1234-5678",  # Matching
                },
                confidence={},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.cross_validation is not None
            assert result.cross_validation.ocr_match is True
            assert result.cross_validation.amount_match is True
            assert result.cross_validation.is_valid is True


class TestFallbackLogic:
    """Tests for fallback detection logic."""

    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Only one key field present
            result = InferenceResult(fields={"Amount": "1500.00"})

            assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # All key fields present
            result = InferenceResult(
                fields={
                    "Amount": "1500.00",
                    "InvoiceNumber": "INV-001",
                    "OCR": "12345678901234",
                }
            )

            assert pipeline._needs_fallback(result) is False


class TestPatternExtraction:
    """Tests for fallback pattern extraction."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Att betala: 1 500,00 SEK"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Amount" in result.fields
            assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Bankgiro: 1234-5678"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Bankgiro" in result.fields
            assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "OCR: 12345678901234567890"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "OCR" in result.fields
            assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Fakturanr: 999"
            result = InferenceResult(fields={"InvoiceNumber": "INV-001"})

            pipeline._extract_with_patterns(text, result)

            # Should keep existing value
            assert result.fields["InvoiceNumber"] == "INV-001"


class TestAmountNormalization:
    """Tests for amount normalization."""

    def test_normalize_swedish_format(self):
        """Test normalizing Swedish amount format."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Swedish format: space as thousands separator, comma as decimal
            assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
            # Standard format: dot as decimal
            assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
            # Swedish format with comma as decimal
            assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00

    def test_normalize_invalid_amount(self):
        """Test normalizing invalid amount returns None."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            assert pipeline._normalize_amount_for_compare("invalid") is None
            assert pipeline._normalize_amount_for_compare("") is None


class TestResultSerialization:
    """Tests for result serialization with cross-validation."""

    def test_to_json_with_cross_validation(self):
        """Test JSON serialization includes cross-validation."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )

        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=cv,
        )

        json_data = result.to_json()

        assert "cross_validation" in json_data
        assert json_data["cross_validation"]["is_valid"] is True
        assert json_data["cross_validation"]["ocr_match"] is True
        assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"