This commit is contained in:
Yaojia Wang
2026-02-03 22:29:53 +01:00
parent 4c7fc3015c
commit c2c8f2dd04
10 changed files with 786 additions and 68 deletions

BIN
.coverage

Binary file not shown.

View File

@@ -0,0 +1,25 @@
"""
Domain Layer
Business logic separated from technical implementation.
Contains document classification and invoice validation logic.
"""
from backend.domain.document_classifier import (
ClassificationResult,
DocumentClassifier,
)
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationIssue,
ValidationResult,
)
from backend.domain.utils import has_value
__all__ = [
"ClassificationResult",
"DocumentClassifier",
"InvoiceValidator",
"ValidationIssue",
"ValidationResult",
"has_value",
]

View File

@@ -0,0 +1,108 @@
"""
Document Classifier
Business logic for classifying documents based on extracted fields.
Separates classification logic from inference pipeline.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ClassificationResult:
    """
    Immutable result of document classification.

    Attributes:
        document_type: Either "invoice" or "letter"
        confidence: Confidence score between 0.0 and 1.0
        reason: Human-readable explanation of classification
    """

    # frozen=True prevents mutation after creation and makes instances hashable.
    document_type: str
    confidence: float
    reason: str
class DocumentClassifier:
    """
    Classifies documents as invoice or letter based on extracted fields.

    Classification Rules:
        1. If payment_line is present -> invoice (high confidence)
        2. If 2+ invoice indicators present -> invoice (medium confidence)
        3. If 1 invoice indicator present -> invoice (lower confidence)
        4. Otherwise -> letter

    Invoice indicator fields:
        - payment_line (strongest indicator)
        - OCR
        - Amount
        - Bankgiro
        - Plusgiro
        - InvoiceNumber
    """

    # Fields whose presence suggests the document is an invoice.
    INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
        {
            "payment_line",
            "OCR",
            "Amount",
            "Bankgiro",
            "Plusgiro",
            "InvoiceNumber",
        }
    )

    def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
        """
        Classify document type based on extracted fields.

        Args:
            fields: Dictionary of field names to extracted values.
                Empty strings or whitespace-only strings are treated as missing.

        Returns:
            Immutable ClassificationResult with type, confidence, and reason.
        """
        # Rule 1: payment_line is the strongest indicator
        if has_value(fields.get("payment_line")):
            return ClassificationResult(
                document_type="invoice",
                confidence=0.95,
                reason="payment_line detected",
            )

        # Count present invoice indicators (excluding payment_line already checked).
        # sorted() makes the reason message deterministic: iterating a frozenset
        # of strings yields an order that varies between interpreter runs due to
        # hash randomization, which previously made `reason` nondeterministic.
        present_indicators = sorted(
            field
            for field in self.INVOICE_INDICATOR_FIELDS
            if field != "payment_line" and has_value(fields.get(field))
        )
        indicator_count = len(present_indicators)

        # Rule 2: Multiple indicators -> invoice with medium-high confidence
        if indicator_count >= 2:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.8,
                reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
            )

        # Rule 3: Single indicator -> invoice with lower confidence
        if indicator_count == 1:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.6,
                reason=f"1 invoice indicator present: {present_indicators[0]}",
            )

        # Rule 4: No indicators -> letter
        return ClassificationResult(
            document_type="letter",
            confidence=0.7,
            reason="no invoice indicators found",
        )

View File

@@ -0,0 +1,141 @@
"""
Invoice Validator
Business logic for validating extracted invoice fields.
Checks for required fields, format validity, and confidence thresholds.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ValidationIssue:
    """
    Single validation issue.

    Attributes:
        field: Name of the field with the issue
        severity: One of "error", "warning", "info"
        message: Human-readable description of the issue
    """

    # frozen=True keeps individual issues immutable once reported.
    field: str
    severity: str
    message: str
@dataclass(frozen=True)
class ValidationResult:
    """
    Immutable result of invoice validation.

    Attributes:
        is_valid: True if no errors (warnings are allowed)
        issues: Tuple of validation issues found
        confidence: Average confidence score of validated fields
    """

    # issues is a tuple (not a list) so the frozen result is fully immutable.
    is_valid: bool
    issues: tuple[ValidationIssue, ...]
    confidence: float
class InvoiceValidator:
    """
    Validates extracted invoice fields for completeness and consistency.

    Validation Rules:
        1. Required fields must be present (Amount)
        2. At least one payment reference should be present (warning if missing)
        3. Field confidence should be above threshold (warning if below)

    Required fields:
        - Amount

    Payment reference fields (at least one expected):
        - OCR
        - Bankgiro
        - Plusgiro
        - payment_line
    """

    REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
    PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
    DEFAULT_MIN_CONFIDENCE: float = 0.5

    def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
        """
        Initialize validator.

        Args:
            min_confidence: Minimum confidence threshold for valid fields.
                Fields below this threshold produce warnings.
        """
        self._min_confidence = min_confidence

    def validate(
        self,
        fields: dict[str, str | None],
        confidence: dict[str, float],
    ) -> ValidationResult:
        """
        Validate extracted invoice fields.

        Args:
            fields: Dictionary of field names to extracted values
            confidence: Dictionary of field names to confidence scores

        Returns:
            Immutable ValidationResult with validity status and issues
        """
        # Issue order matters for callers inspecting result.issues:
        # required-field errors, then payment-reference warning, then
        # low-confidence warnings.
        issues: list[ValidationIssue] = [
            *self._required_field_errors(fields),
            *self._payment_reference_warnings(fields),
            *self._low_confidence_warnings(confidence),
        ]

        # Warnings never invalidate; only "error" severity does.
        is_valid = all(issue.severity != "error" for issue in issues)
        mean_confidence = (
            sum(confidence.values()) / len(confidence) if confidence else 0.0
        )

        return ValidationResult(
            is_valid=is_valid,
            issues=tuple(issues),
            confidence=mean_confidence,
        )

    def _required_field_errors(
        self, fields: dict[str, str | None]
    ) -> list[ValidationIssue]:
        """One error per required field that is absent, None, or blank."""
        return [
            ValidationIssue(
                field=name,
                severity="error",
                message=f"Required field '{name}' is missing",
            )
            for name in self.REQUIRED_FIELDS
            if not has_value(fields.get(name))
        ]

    def _payment_reference_warnings(
        self, fields: dict[str, str | None]
    ) -> list[ValidationIssue]:
        """Single warning when no payment reference field carries a value."""
        if any(has_value(fields.get(name)) for name in self.PAYMENT_REF_FIELDS):
            return []
        return [
            ValidationIssue(
                field="payment_reference",
                severity="warning",
                message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
            )
        ]

    def _low_confidence_warnings(
        self, confidence: dict[str, float]
    ) -> list[ValidationIssue]:
        """One warning per field whose confidence falls below the threshold."""
        return [
            ValidationIssue(
                field=name,
                severity="warning",
                message=f"Low confidence ({score:.2f}) for field '{name}'",
            )
            for name, score in confidence.items()
            if score < self._min_confidence
        ]

View File

@@ -0,0 +1,23 @@
"""
Domain Layer Utilities
Shared helper functions for domain layer classes.
"""
from __future__ import annotations
def has_value(value: str | None) -> bool:
"""
Check if a field value is present and non-empty.
Args:
value: Field value to check
Returns:
True if value is a non-empty, non-whitespace string
"""
if value is None:
return False
if not isinstance(value, str):
return bool(value)
return bool(value.strip())

View File

@@ -1,21 +1,24 @@
""" """
Inference Service Inference Service (Adapter Layer)
Business logic for invoice field extraction. Orchestrates technical pipeline and business domain logic.
Acts as adapter between API layer and internal components.
""" """
from __future__ import annotations from __future__ import annotations
import io
import logging import logging
import time import time
import uuid import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable, Generator
import numpy as np
from PIL import Image from PIL import Image
from backend.domain.document_classifier import DocumentClassifier
from backend.web.services.storage_helpers import get_storage_helper from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -50,9 +53,12 @@ class ServiceResult:
class InferenceService: class InferenceService:
""" """
Service for running invoice field extraction. Service for running invoice field extraction (Adapter Pattern).
Orchestrates:
- Technical layer: InferencePipeline, YOLODetector
- Business layer: DocumentClassifier
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database. Supports dynamic model loading from database.
""" """
@@ -61,6 +67,7 @@ class InferenceService:
model_config: ModelConfig, model_config: ModelConfig,
storage_config: StorageConfig, storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None, model_path_resolver: ModelPathResolver | None = None,
document_classifier: DocumentClassifier | None = None,
) -> None: ) -> None:
""" """
Initialize inference service. Initialize inference service.
@@ -71,12 +78,19 @@ class InferenceService:
model_path_resolver: Optional function to resolve model path from database. model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path. If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path. If returns None, falls back to model_config.model_path.
document_classifier: Optional custom classifier (uses default if None)
""" """
self.model_config = model_config self.model_config = model_config
self.storage_config = storage_config self.storage_config = storage_config
self._model_path_resolver = model_path_resolver self._model_path_resolver = model_path_resolver
# Technical layer (lazy initialized)
self._pipeline = None self._pipeline = None
self._detector = None self._detector = None
# Business layer (eagerly initialized, no heavy resources)
self._classifier = document_classifier or DocumentClassifier()
self._is_initialized = False self._is_initialized = False
self._current_model_path: Path | None = None self._current_model_path: Path | None = None
self._business_features_enabled = False self._business_features_enabled = False
@@ -219,22 +233,12 @@ class InferenceService:
result.success = pipeline_result.success result.success = pipeline_result.success
result.errors = pipeline_result.errors result.errors = pipeline_result.errors
# Determine document type based on payment_line presence # Business layer: classify document type
# If no payment_line found, it's likely a letter, not an invoice classification = self._classifier.classify(result.fields)
if not result.fields.get('payment_line'): result.document_type = classification.document_type
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections for visualization # Get raw detections for visualization
result.detections = [ result.detections = self._format_detections(pipeline_result.raw_detections)
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization if requested # Save visualization if requested
if save_visualization and pipeline_result.raw_detections: if save_visualization and pipeline_result.raw_detections:
@@ -293,22 +297,12 @@ class InferenceService:
result.success = pipeline_result.success result.success = pipeline_result.success
result.errors = pipeline_result.errors result.errors = pipeline_result.errors
# Determine document type based on payment_line presence # Business layer: classify document type
# If no payment_line found, it's likely a letter, not an invoice classification = self._classifier.classify(result.fields)
if not result.fields.get('payment_line'): result.document_type = classification.document_type
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections # Get raw detections
result.detections = [ result.detections = self._format_detections(pipeline_result.raw_detections)
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Include business features if extracted # Include business features if extracted
if extract_line_items: if extract_line_items:
@@ -329,10 +323,19 @@ class InferenceService:
result.processing_time_ms = (time.time() - start_time) * 1000 result.processing_time_ms = (time.time() - start_time) * 1000
return result return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path: def _format_detections(self, raw_detections: list) -> list[dict]:
"""Save visualization image with detections.""" """Format raw detections for response."""
from ultralytics import YOLO return [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in raw_detections
]
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
"""Save visualization image with detections using existing detector."""
# Get storage helper for results directory # Get storage helper for results directory
storage = get_storage_helper() storage = get_storage_helper()
results_dir = storage.get_results_base_path() results_dir = storage.get_results_base_path()
@@ -340,9 +343,8 @@ class InferenceService:
logger.warning("Cannot save visualization: local storage not available") logger.warning("Cannot save visualization: local storage not available")
return None return None
# Load model and run prediction with visualization # Reuse self._detector instead of creating new YOLO instance
model = YOLO(str(self.model_config.model_path)) results = self._detector.model.predict(str(image_path), verbose=False)
results = model.predict(str(image_path), verbose=False)
# Save annotated image # Save annotated image
output_path = results_dir / f"{doc_id}_result.png" output_path = results_dir / f"{doc_id}_result.png"
@@ -351,11 +353,20 @@ class InferenceService:
return output_path return output_path
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path: @contextmanager
"""Save visualization for PDF (first page).""" def _temp_image_file(
self, results_dir: Path, doc_id: str
) -> Generator[Path, None, None]:
"""Context manager for temporary image file with guaranteed cleanup."""
temp_path = results_dir / f"{doc_id}_temp.png"
try:
yield temp_path
finally:
temp_path.unlink(missing_ok=True)
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page) using existing detector."""
from shared.pdf.renderer import render_pdf_to_images from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Get storage helper for results directory # Get storage helper for results directory
storage = get_storage_helper() storage = get_storage_helper()
@@ -369,19 +380,18 @@ class InferenceService:
pdf_path, dpi=self.model_config.dpi pdf_path, dpi=self.model_config.dpi
): ):
image = Image.open(io.BytesIO(image_bytes)) image = Image.open(io.BytesIO(image_bytes))
temp_path = results_dir / f"{doc_id}_temp.png"
# Use context manager for temp file to guarantee cleanup
with self._temp_image_file(results_dir, doc_id) as temp_path:
image.save(temp_path) image.save(temp_path)
# Run YOLO and save visualization # Reuse self._detector instead of creating new YOLO instance
model = YOLO(str(self.model_config.model_path)) results = self._detector.model.predict(str(temp_path), verbose=False)
results = model.predict(str(temp_path), verbose=False)
output_path = results_dir / f"{doc_id}_result.png" output_path = results_dir / f"{doc_id}_result.png"
for r in results: for r in results:
r.save(filename=str(output_path)) r.save(filename=str(output_path))
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path return output_path
# If no pages rendered # If no pages rendered

1
tests/domain/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Domain layer tests

View File

@@ -0,0 +1,176 @@
"""
Tests for DocumentClassifier - TDD RED phase.
Test document type classification based on extracted fields.
"""
import pytest
from backend.domain.document_classifier import DocumentClassifier, ClassificationResult
class TestDocumentClassifier:
    """Test document classification logic."""

    @pytest.fixture
    def classifier(self) -> DocumentClassifier:
        """Create classifier instance."""
        return DocumentClassifier()

    # ==================== Invoice Detection Tests ====================

    def test_classify_with_payment_line_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """Payment line is the strongest invoice indicator."""
        fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"}
        result = classifier.classify(fields)
        assert result.document_type == "invoice"
        assert result.confidence >= 0.9
        assert "payment_line" in result.reason

    def test_classify_with_multiple_indicators_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """Multiple invoice indicators -> invoice with medium confidence."""
        fields = {
            "Amount": "1200.00",
            "Bankgiro": "123-4567",
            "payment_line": None,
        }
        result = classifier.classify(fields)
        assert result.document_type == "invoice"
        assert result.confidence >= 0.7

    def test_classify_with_ocr_and_amount_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """OCR + Amount is typical invoice pattern."""
        fields = {
            "OCR": "123456789012",
            "Amount": "500.00",
        }
        result = classifier.classify(fields)
        assert result.document_type == "invoice"
        assert result.confidence >= 0.7

    def test_classify_with_single_indicator_returns_invoice_lower_confidence(
        self, classifier: DocumentClassifier
    ) -> None:
        """Single indicator -> invoice but lower confidence."""
        fields = {"Amount": "100.00"}
        result = classifier.classify(fields)
        assert result.document_type == "invoice"
        # Single-indicator confidence sits between letter and multi-indicator.
        assert 0.5 <= result.confidence < 0.8

    def test_classify_with_invoice_number_only(
        self, classifier: DocumentClassifier
    ) -> None:
        """Invoice number alone suggests invoice."""
        fields = {"InvoiceNumber": "INV-2024-001"}
        result = classifier.classify(fields)
        assert result.document_type == "invoice"

    # ==================== Letter Detection Tests ====================

    def test_classify_with_no_indicators_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """No invoice indicators -> letter."""
        fields: dict[str, str | None] = {}
        result = classifier.classify(fields)
        assert result.document_type == "letter"
        assert result.confidence >= 0.5

    def test_classify_with_empty_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """All fields empty or None -> letter."""
        fields = {
            "payment_line": None,
            "OCR": None,
            "Amount": None,
            "Bankgiro": None,
        }
        result = classifier.classify(fields)
        assert result.document_type == "letter"

    def test_classify_with_only_non_indicator_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Fields that don't indicate invoice -> letter."""
        fields = {
            "CustomerNumber": "C12345",
            "SupplierOrgNumber": "556677-8899",
        }
        result = classifier.classify(fields)
        assert result.document_type == "letter"

    # ==================== Edge Cases ====================

    def test_classify_with_empty_string_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Empty strings should be treated as missing."""
        fields = {
            "payment_line": "",
            "Amount": "",
        }
        result = classifier.classify(fields)
        assert result.document_type == "letter"

    def test_classify_with_whitespace_only_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Whitespace-only strings should be treated as missing."""
        fields = {
            "payment_line": "   ",
            "Amount": "\t\n",
        }
        result = classifier.classify(fields)
        assert result.document_type == "letter"

    # ==================== ClassificationResult Immutability ====================

    def test_classification_result_is_immutable(
        self, classifier: DocumentClassifier
    ) -> None:
        """ClassificationResult should be a frozen dataclass."""
        fields = {"payment_line": "test"}
        result = classifier.classify(fields)
        # Frozen dataclasses raise FrozenInstanceError (a TypeError/AttributeError subclass).
        with pytest.raises((AttributeError, TypeError)):
            result.document_type = "modified"  # type: ignore

    def test_classification_result_has_required_fields(
        self, classifier: DocumentClassifier
    ) -> None:
        """ClassificationResult must have document_type, confidence, reason."""
        fields = {"Amount": "100.00"}
        result = classifier.classify(fields)
        assert hasattr(result, "document_type")
        assert hasattr(result, "confidence")
        assert hasattr(result, "reason")
        assert isinstance(result.document_type, str)
        assert isinstance(result.confidence, float)
        assert isinstance(result.reason, str)

View File

@@ -0,0 +1,232 @@
"""
Tests for InvoiceValidator - TDD RED phase.
Test invoice field validation logic.
"""
import pytest
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationResult,
ValidationIssue,
)
class TestInvoiceValidator:
    """Test invoice validation logic."""

    @pytest.fixture
    def validator(self) -> InvoiceValidator:
        """Create validator instance with default settings."""
        return InvoiceValidator()

    @pytest.fixture
    def validator_strict(self) -> InvoiceValidator:
        """Create validator with strict confidence threshold."""
        return InvoiceValidator(min_confidence=0.8)

    # ==================== Valid Invoice Tests ====================

    def test_validate_complete_invoice_is_valid(
        self, validator: InvoiceValidator
    ) -> None:
        """Complete invoice with all required fields is valid."""
        fields = {
            "Amount": "1200.00",
            "OCR": "123456789012",
            "Bankgiro": "123-4567",
        }
        confidence = {
            "Amount": 0.95,
            "OCR": 0.90,
            "Bankgiro": 0.85,
        }
        result = validator.validate(fields, confidence)
        assert result.is_valid is True
        assert len([i for i in result.issues if i.severity == "error"]) == 0

    def test_validate_invoice_with_payment_line_is_valid(
        self, validator: InvoiceValidator
    ) -> None:
        """Invoice with payment_line as payment reference is valid."""
        fields = {
            "Amount": "500.00",
            "payment_line": "# 123 # 500 00 5 > 308#",
        }
        confidence = {"Amount": 0.9, "payment_line": 0.85}
        result = validator.validate(fields, confidence)
        assert result.is_valid is True

    # ==================== Invalid Invoice Tests ====================

    def test_validate_missing_amount_is_invalid(
        self, validator: InvoiceValidator
    ) -> None:
        """Missing Amount field should produce error."""
        fields = {
            "OCR": "123456789012",
            "Bankgiro": "123-4567",
        }
        confidence = {"OCR": 0.9, "Bankgiro": 0.85}
        result = validator.validate(fields, confidence)
        assert result.is_valid is False
        error_fields = [i.field for i in result.issues if i.severity == "error"]
        assert "Amount" in error_fields

    def test_validate_missing_payment_reference_produces_warning(
        self, validator: InvoiceValidator
    ) -> None:
        """Missing all payment references should produce warning."""
        fields = {"Amount": "1200.00"}
        confidence = {"Amount": 0.9}
        result = validator.validate(fields, confidence)
        # Missing payment ref is warning, not error
        warning_fields = [i.field for i in result.issues if i.severity == "warning"]
        assert "payment_reference" in warning_fields

    # ==================== Confidence Threshold Tests ====================

    def test_validate_low_confidence_produces_warning(
        self, validator: InvoiceValidator
    ) -> None:
        """Fields below confidence threshold should produce warning."""
        fields = {
            "Amount": "1200.00",
            "OCR": "123456789012",
        }
        confidence = {
            "Amount": 0.9,
            "OCR": 0.3,  # Below default threshold of 0.5
        }
        result = validator.validate(fields, confidence)
        low_conf_warnings = [
            i for i in result.issues
            if i.severity == "warning" and "confidence" in i.message.lower()
        ]
        assert len(low_conf_warnings) > 0

    def test_validate_strict_threshold_more_warnings(
        self, validator_strict: InvoiceValidator
    ) -> None:
        """Strict validator should produce more warnings."""
        fields = {
            "Amount": "1200.00",
            "OCR": "123456789012",
        }
        confidence = {
            "Amount": 0.7,  # Below 0.8 threshold
            "OCR": 0.6,  # Below 0.8 threshold
        }
        result = validator_strict.validate(fields, confidence)
        low_conf_warnings = [
            i for i in result.issues
            if i.severity == "warning" and "confidence" in i.message.lower()
        ]
        assert len(low_conf_warnings) >= 2

    # ==================== Edge Cases ====================

    def test_validate_empty_fields_is_invalid(
        self, validator: InvoiceValidator
    ) -> None:
        """Empty fields dict should be invalid."""
        fields: dict[str, str | None] = {}
        confidence: dict[str, float] = {}
        result = validator.validate(fields, confidence)
        assert result.is_valid is False

    def test_validate_none_field_values_treated_as_missing(
        self, validator: InvoiceValidator
    ) -> None:
        """None values should be treated as missing."""
        fields = {
            "Amount": None,
            "OCR": "123456789012",
        }
        confidence = {"OCR": 0.9}
        result = validator.validate(fields, confidence)
        assert result.is_valid is False
        error_fields = [i.field for i in result.issues if i.severity == "error"]
        assert "Amount" in error_fields

    def test_validate_empty_string_treated_as_missing(
        self, validator: InvoiceValidator
    ) -> None:
        """Empty string should be treated as missing."""
        fields = {
            "Amount": "",
            "OCR": "123456789012",
        }
        confidence = {"OCR": 0.9}
        result = validator.validate(fields, confidence)
        assert result.is_valid is False

    # ==================== ValidationResult Properties ====================

    def test_validation_result_is_immutable(
        self, validator: InvoiceValidator
    ) -> None:
        """ValidationResult should be a frozen dataclass."""
        fields = {"Amount": "100.00", "OCR": "123"}
        confidence = {"Amount": 0.9, "OCR": 0.9}
        result = validator.validate(fields, confidence)
        with pytest.raises((AttributeError, TypeError)):
            result.is_valid = False  # type: ignore

    def test_validation_result_issues_is_tuple(
        self, validator: InvoiceValidator
    ) -> None:
        """Issues should be a tuple (immutable)."""
        fields = {"Amount": "100.00"}
        confidence = {"Amount": 0.9}
        result = validator.validate(fields, confidence)
        assert isinstance(result.issues, tuple)

    def test_validation_result_has_confidence(
        self, validator: InvoiceValidator
    ) -> None:
        """ValidationResult should have confidence score."""
        fields = {"Amount": "100.00", "OCR": "123"}
        confidence = {"Amount": 0.9, "OCR": 0.8}
        result = validator.validate(fields, confidence)
        assert hasattr(result, "confidence")
        assert 0.0 <= result.confidence <= 1.0

    # ==================== ValidationIssue Tests ====================

    def test_validation_issue_has_required_fields(
        self, validator: InvoiceValidator
    ) -> None:
        """ValidationIssue must have field, severity, message."""
        fields: dict[str, str | None] = {}
        confidence: dict[str, float] = {}
        result = validator.validate(fields, confidence)
        assert len(result.issues) > 0
        issue = result.issues[0]
        assert hasattr(issue, "field")
        assert hasattr(issue, "severity")
        assert hasattr(issue, "message")
        assert issue.severity in ("error", "warning", "info")

View File

@@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering:
@patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector') @patch('backend.pipeline.yolo_detector.YOLODetector')
@patch('shared.pdf.renderer.render_pdf_to_images') @patch('shared.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly( def test_pdf_visualization_imports_correctly(
self, self,
mock_yolo_class,
mock_render_pdf, mock_render_pdf,
mock_yolo_detector, mock_yolo_detector,
mock_pipeline, mock_pipeline,
@@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering:
This catches the import error we had with: This catches the import error we had with:
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct) from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
""" """
# Setup mocks # Setup mocks for detector
mock_detector_instance = Mock() mock_detector_instance = Mock()
mock_pipeline_instance = Mock() mock_model = Mock()
mock_result = Mock()
mock_result.save = Mock()
mock_model.predict.return_value = [mock_result]
mock_detector_instance.model = mock_model
mock_yolo_detector.return_value = mock_detector_instance mock_yolo_detector.return_value = mock_detector_instance
# Setup mock for pipeline
mock_pipeline_instance = Mock()
mock_pipeline.return_value = mock_pipeline_instance mock_pipeline.return_value = mock_pipeline_instance
# Initialize service to setup _detector
inference_service.initialize()
# Create a fake PDF path # Create a fake PDF path
pdf_path = tmp_path / "test.pdf" pdf_path = tmp_path / "test.pdf"
pdf_path.touch() pdf_path.touch()
@@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering:
img.save(image_bytes, format='PNG') img.save(image_bytes, format='PNG')
mock_render_pdf.return_value = [(1, image_bytes.getvalue())] mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
# Mock YOLO # This should not raise ImportError and should use self._detector.model
mock_model_instance = Mock()
mock_result = Mock()
mock_result.save = Mock()
mock_model_instance.predict.return_value = [mock_result]
mock_yolo_class.return_value = mock_model_instance
# This should not raise ImportError
result_path = inference_service._save_pdf_visualization(pdf_path, "test123") result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
# Verify import was successful # Verify import was successful and detector.model was used
mock_render_pdf.assert_called_once() mock_render_pdf.assert_called_once()
mock_model.predict.assert_called_once()
assert result_path is not None assert result_path is not None