diff --git a/.coverage b/.coverage
index d71ec33..7898222 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/packages/backend/backend/domain/__init__.py b/packages/backend/backend/domain/__init__.py
new file mode 100644
index 0000000..90fa440
--- /dev/null
+++ b/packages/backend/backend/domain/__init__.py
@@ -0,0 +1,25 @@
+"""
+Domain Layer
+
+Business logic separated from technical implementation.
+Contains document classification and invoice validation logic.
+"""
+from backend.domain.document_classifier import (
+    ClassificationResult,
+    DocumentClassifier,
+)
+from backend.domain.invoice_validator import (
+    InvoiceValidator,
+    ValidationIssue,
+    ValidationResult,
+)
+from backend.domain.utils import has_value
+
+__all__ = [
+    "ClassificationResult",
+    "DocumentClassifier",
+    "InvoiceValidator",
+    "ValidationIssue",
+    "ValidationResult",
+    "has_value",
+]
diff --git a/packages/backend/backend/domain/document_classifier.py b/packages/backend/backend/domain/document_classifier.py
new file mode 100644
index 0000000..2ae7e16
--- /dev/null
+++ b/packages/backend/backend/domain/document_classifier.py
@@ -0,0 +1,112 @@
+"""
+Document Classifier
+
+Business logic for classifying documents based on extracted fields.
+Separates classification logic from inference pipeline.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from backend.domain.utils import has_value
+
+
+@dataclass(frozen=True)
+class ClassificationResult:
+    """
+    Immutable result of document classification.
+
+    Attributes:
+        document_type: Either "invoice" or "letter"
+        confidence: Confidence score between 0.0 and 1.0
+        reason: Human-readable explanation of classification
+    """
+
+    document_type: str
+    confidence: float
+    reason: str
+
+
+class DocumentClassifier:
+    """
+    Classifies documents as invoice or letter based on extracted fields.
+
+    Classification Rules:
+    1. If payment_line is present -> invoice (high confidence)
+    2. If 2+ invoice indicators present -> invoice (medium confidence)
+    3. If 1 invoice indicator present -> invoice (lower confidence)
+    4. Otherwise -> letter
+
+    Invoice indicator fields:
+    - payment_line (strongest indicator)
+    - OCR
+    - Amount
+    - Bankgiro
+    - Plusgiro
+    - InvoiceNumber
+    """
+
+    INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
+        {
+            "payment_line",
+            "OCR",
+            "Amount",
+            "Bankgiro",
+            "Plusgiro",
+            "InvoiceNumber",
+        }
+    )
+
+    def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
+        """
+        Classify document type based on extracted fields.
+
+        Args:
+            fields: Dictionary of field names to extracted values.
+                Empty strings or whitespace-only strings are treated as missing.
+
+        Returns:
+            Immutable ClassificationResult with type, confidence, and reason.
+            Indicator names in the reason are listed in sorted order.
+        """
+        # Rule 1: payment_line is the strongest indicator
+        if has_value(fields.get("payment_line")):
+            return ClassificationResult(
+                document_type="invoice",
+                confidence=0.95,
+                reason="payment_line detected",
+            )
+
+        # Collect the remaining indicators (payment_line was handled above).
+        # Iterating a frozenset of strings has no stable order across
+        # interpreter runs (string hashing is randomized), so sort the names
+        # to keep the reason text reproducible.
+        secondary_fields = sorted(self.INVOICE_INDICATOR_FIELDS - {"payment_line"})
+        present_indicators = [
+            name for name in secondary_fields if has_value(fields.get(name))
+        ]
+        indicator_count = len(present_indicators)
+
+        # Rule 2: Multiple indicators -> invoice with medium-high confidence
+        if indicator_count >= 2:
+            indicator_names = ", ".join(present_indicators)
+            return ClassificationResult(
+                document_type="invoice",
+                confidence=0.8,
+                reason=f"{indicator_count} invoice indicators present: {indicator_names}",
+            )
+
+        # Rule 3: Single indicator -> invoice with lower confidence
+        if indicator_count == 1:
+            return ClassificationResult(
+                document_type="invoice",
+                confidence=0.6,
+                reason=f"1 invoice indicator present: {present_indicators[0]}",
+            )
+
+        # Rule 4: No indicators -> letter
+        return ClassificationResult(
+            document_type="letter",
+            confidence=0.7,
+            reason="no invoice indicators found",
+        )
diff --git
a/packages/backend/backend/domain/invoice_validator.py b/packages/backend/backend/domain/invoice_validator.py new file mode 100644 index 0000000..45caa86 --- /dev/null +++ b/packages/backend/backend/domain/invoice_validator.py @@ -0,0 +1,141 @@ +""" +Invoice Validator + +Business logic for validating extracted invoice fields. +Checks for required fields, format validity, and confidence thresholds. +""" +from __future__ import annotations + +from dataclasses import dataclass + +from backend.domain.utils import has_value + + +@dataclass(frozen=True) +class ValidationIssue: + """ + Single validation issue. + + Attributes: + field: Name of the field with the issue + severity: One of "error", "warning", "info" + message: Human-readable description of the issue + """ + + field: str + severity: str + message: str + + +@dataclass(frozen=True) +class ValidationResult: + """ + Immutable result of invoice validation. + + Attributes: + is_valid: True if no errors (warnings are allowed) + issues: Tuple of validation issues found + confidence: Average confidence score of validated fields + """ + + is_valid: bool + issues: tuple[ValidationIssue, ...] + confidence: float + + +class InvoiceValidator: + """ + Validates extracted invoice fields for completeness and consistency. + + Validation Rules: + 1. Required fields must be present (Amount) + 2. At least one payment reference should be present (warning if missing) + 3. Field confidence should be above threshold (warning if below) + + Required fields: + - Amount + + Payment reference fields (at least one expected): + - OCR + - Bankgiro + - Plusgiro + - payment_line + """ + + REQUIRED_FIELDS: tuple[str, ...] = ("Amount",) + PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line") + DEFAULT_MIN_CONFIDENCE: float = 0.5 + + def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None: + """ + Initialize validator. + + Args: + min_confidence: Minimum confidence threshold for valid fields. 
+ Fields below this threshold produce warnings. + """ + self._min_confidence = min_confidence + + def validate( + self, + fields: dict[str, str | None], + confidence: dict[str, float], + ) -> ValidationResult: + """ + Validate extracted invoice fields. + + Args: + fields: Dictionary of field names to extracted values + confidence: Dictionary of field names to confidence scores + + Returns: + Immutable ValidationResult with validity status and issues + """ + issues: list[ValidationIssue] = [] + + # Check required fields + for field in self.REQUIRED_FIELDS: + if not has_value(fields.get(field)): + issues.append( + ValidationIssue( + field=field, + severity="error", + message=f"Required field '{field}' is missing", + ) + ) + + # Check payment reference (at least one expected) + has_payment_ref = any( + has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS + ) + if not has_payment_ref: + issues.append( + ValidationIssue( + field="payment_reference", + severity="warning", + message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)", + ) + ) + + # Check confidence thresholds + for field, conf in confidence.items(): + if conf < self._min_confidence: + issues.append( + ValidationIssue( + field=field, + severity="warning", + message=f"Low confidence ({conf:.2f}) for field '{field}'", + ) + ) + + # Calculate overall validity + has_errors = any(i.severity == "error" for i in issues) + avg_confidence = ( + sum(confidence.values()) / len(confidence) if confidence else 0.0 + ) + + return ValidationResult( + is_valid=not has_errors, + issues=tuple(issues), + confidence=avg_confidence, + ) diff --git a/packages/backend/backend/domain/utils.py b/packages/backend/backend/domain/utils.py new file mode 100644 index 0000000..53f4fca --- /dev/null +++ b/packages/backend/backend/domain/utils.py @@ -0,0 +1,23 @@ +""" +Domain Layer Utilities + +Shared helper functions for domain layer classes. 
+""" +from __future__ import annotations + + +def has_value(value: str | None) -> bool: + """ + Check if a field value is present and non-empty. + + Args: + value: Field value to check + + Returns: + True if value is a non-empty, non-whitespace string + """ + if value is None: + return False + if not isinstance(value, str): + return bool(value) + return bool(value.strip()) diff --git a/packages/backend/backend/web/services/inference.py b/packages/backend/backend/web/services/inference.py index 576c2ee..fb645d9 100644 --- a/packages/backend/backend/web/services/inference.py +++ b/packages/backend/backend/web/services/inference.py @@ -1,21 +1,24 @@ """ -Inference Service +Inference Service (Adapter Layer) -Business logic for invoice field extraction. +Orchestrates technical pipeline and business domain logic. +Acts as adapter between API layer and internal components. """ from __future__ import annotations +import io import logging import time import uuid +from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Generator -import numpy as np from PIL import Image +from backend.domain.document_classifier import DocumentClassifier from backend.web.services.storage_helpers import get_storage_helper if TYPE_CHECKING: @@ -50,9 +53,12 @@ class ServiceResult: class InferenceService: """ - Service for running invoice field extraction. + Service for running invoice field extraction (Adapter Pattern). + + Orchestrates: + - Technical layer: InferencePipeline, YOLODetector + - Business layer: DocumentClassifier - Encapsulates YOLO detection and OCR extraction logic. Supports dynamic model loading from database. 
""" @@ -61,6 +67,7 @@ class InferenceService: model_config: ModelConfig, storage_config: StorageConfig, model_path_resolver: ModelPathResolver | None = None, + document_classifier: DocumentClassifier | None = None, ) -> None: """ Initialize inference service. @@ -71,12 +78,19 @@ class InferenceService: model_path_resolver: Optional function to resolve model path from database. If provided, will be called to get active model path. If returns None, falls back to model_config.model_path. + document_classifier: Optional custom classifier (uses default if None) """ self.model_config = model_config self.storage_config = storage_config self._model_path_resolver = model_path_resolver + + # Technical layer (lazy initialized) self._pipeline = None self._detector = None + + # Business layer (eagerly initialized, no heavy resources) + self._classifier = document_classifier or DocumentClassifier() + self._is_initialized = False self._current_model_path: Path | None = None self._business_features_enabled = False @@ -219,22 +233,12 @@ class InferenceService: result.success = pipeline_result.success result.errors = pipeline_result.errors - # Determine document type based on payment_line presence - # If no payment_line found, it's likely a letter, not an invoice - if not result.fields.get('payment_line'): - result.document_type = "letter" - else: - result.document_type = "invoice" + # Business layer: classify document type + classification = self._classifier.classify(result.fields) + result.document_type = classification.document_type # Get raw detections for visualization - result.detections = [ - { - "field": d.class_name, - "confidence": d.confidence, - "bbox": list(d.bbox), - } - for d in pipeline_result.raw_detections - ] + result.detections = self._format_detections(pipeline_result.raw_detections) # Save visualization if requested if save_visualization and pipeline_result.raw_detections: @@ -293,22 +297,12 @@ class InferenceService: result.success = pipeline_result.success 
result.errors = pipeline_result.errors - # Determine document type based on payment_line presence - # If no payment_line found, it's likely a letter, not an invoice - if not result.fields.get('payment_line'): - result.document_type = "letter" - else: - result.document_type = "invoice" + # Business layer: classify document type + classification = self._classifier.classify(result.fields) + result.document_type = classification.document_type # Get raw detections - result.detections = [ - { - "field": d.class_name, - "confidence": d.confidence, - "bbox": list(d.bbox), - } - for d in pipeline_result.raw_detections - ] + result.detections = self._format_detections(pipeline_result.raw_detections) # Include business features if extracted if extract_line_items: @@ -329,10 +323,19 @@ class InferenceService: result.processing_time_ms = (time.time() - start_time) * 1000 return result - def _save_visualization(self, image_path: Path, doc_id: str) -> Path: - """Save visualization image with detections.""" - from ultralytics import YOLO + def _format_detections(self, raw_detections: list) -> list[dict]: + """Format raw detections for response.""" + return [ + { + "field": d.class_name, + "confidence": d.confidence, + "bbox": list(d.bbox), + } + for d in raw_detections + ] + def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None: + """Save visualization image with detections using existing detector.""" # Get storage helper for results directory storage = get_storage_helper() results_dir = storage.get_results_base_path() @@ -340,9 +343,8 @@ class InferenceService: logger.warning("Cannot save visualization: local storage not available") return None - # Load model and run prediction with visualization - model = YOLO(str(self.model_config.model_path)) - results = model.predict(str(image_path), verbose=False) + # Reuse self._detector instead of creating new YOLO instance + results = self._detector.model.predict(str(image_path), verbose=False) # Save annotated image 
output_path = results_dir / f"{doc_id}_result.png" @@ -351,11 +353,20 @@ class InferenceService: return output_path - def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path: - """Save visualization for PDF (first page).""" + @contextmanager + def _temp_image_file( + self, results_dir: Path, doc_id: str + ) -> Generator[Path, None, None]: + """Context manager for temporary image file with guaranteed cleanup.""" + temp_path = results_dir / f"{doc_id}_temp.png" + try: + yield temp_path + finally: + temp_path.unlink(missing_ok=True) + + def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None: + """Save visualization for PDF (first page) using existing detector.""" from shared.pdf.renderer import render_pdf_to_images - from ultralytics import YOLO - import io # Get storage helper for results directory storage = get_storage_helper() @@ -369,20 +380,19 @@ class InferenceService: pdf_path, dpi=self.model_config.dpi ): image = Image.open(io.BytesIO(image_bytes)) - temp_path = results_dir / f"{doc_id}_temp.png" - image.save(temp_path) - # Run YOLO and save visualization - model = YOLO(str(self.model_config.model_path)) - results = model.predict(str(temp_path), verbose=False) + # Use context manager for temp file to guarantee cleanup + with self._temp_image_file(results_dir, doc_id) as temp_path: + image.save(temp_path) - output_path = results_dir / f"{doc_id}_result.png" - for r in results: - r.save(filename=str(output_path)) + # Reuse self._detector instead of creating new YOLO instance + results = self._detector.model.predict(str(temp_path), verbose=False) - # Cleanup temp file - temp_path.unlink(missing_ok=True) - return output_path + output_path = results_dir / f"{doc_id}_result.png" + for r in results: + r.save(filename=str(output_path)) + + return output_path # If no pages rendered return None diff --git a/tests/domain/__init__.py b/tests/domain/__init__.py new file mode 100644 index 0000000..bf25238 --- /dev/null +++ 
b/tests/domain/__init__.py @@ -0,0 +1 @@ +# Domain layer tests diff --git a/tests/domain/test_document_classifier.py b/tests/domain/test_document_classifier.py new file mode 100644 index 0000000..3f20887 --- /dev/null +++ b/tests/domain/test_document_classifier.py @@ -0,0 +1,176 @@ +""" +Tests for DocumentClassifier - TDD RED phase. + +Test document type classification based on extracted fields. +""" +import pytest + +from backend.domain.document_classifier import DocumentClassifier, ClassificationResult + + +class TestDocumentClassifier: + """Test document classification logic.""" + + @pytest.fixture + def classifier(self) -> DocumentClassifier: + """Create classifier instance.""" + return DocumentClassifier() + + # ==================== Invoice Detection Tests ==================== + + def test_classify_with_payment_line_returns_invoice( + self, classifier: DocumentClassifier + ) -> None: + """Payment line is the strongest invoice indicator.""" + fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"} + + result = classifier.classify(fields) + + assert result.document_type == "invoice" + assert result.confidence >= 0.9 + assert "payment_line" in result.reason + + def test_classify_with_multiple_indicators_returns_invoice( + self, classifier: DocumentClassifier + ) -> None: + """Multiple invoice indicators -> invoice with medium confidence.""" + fields = { + "Amount": "1200.00", + "Bankgiro": "123-4567", + "payment_line": None, + } + + result = classifier.classify(fields) + + assert result.document_type == "invoice" + assert result.confidence >= 0.7 + + def test_classify_with_ocr_and_amount_returns_invoice( + self, classifier: DocumentClassifier + ) -> None: + """OCR + Amount is typical invoice pattern.""" + fields = { + "OCR": "123456789012", + "Amount": "500.00", + } + + result = classifier.classify(fields) + + assert result.document_type == "invoice" + assert result.confidence >= 0.7 + + def 
test_classify_with_single_indicator_returns_invoice_lower_confidence( + self, classifier: DocumentClassifier + ) -> None: + """Single indicator -> invoice but lower confidence.""" + fields = {"Amount": "100.00"} + + result = classifier.classify(fields) + + assert result.document_type == "invoice" + assert 0.5 <= result.confidence < 0.8 + + def test_classify_with_invoice_number_only( + self, classifier: DocumentClassifier + ) -> None: + """Invoice number alone suggests invoice.""" + fields = {"InvoiceNumber": "INV-2024-001"} + + result = classifier.classify(fields) + + assert result.document_type == "invoice" + + # ==================== Letter Detection Tests ==================== + + def test_classify_with_no_indicators_returns_letter( + self, classifier: DocumentClassifier + ) -> None: + """No invoice indicators -> letter.""" + fields: dict[str, str | None] = {} + + result = classifier.classify(fields) + + assert result.document_type == "letter" + assert result.confidence >= 0.5 + + def test_classify_with_empty_fields_returns_letter( + self, classifier: DocumentClassifier + ) -> None: + """All fields empty or None -> letter.""" + fields = { + "payment_line": None, + "OCR": None, + "Amount": None, + "Bankgiro": None, + } + + result = classifier.classify(fields) + + assert result.document_type == "letter" + + def test_classify_with_only_non_indicator_fields_returns_letter( + self, classifier: DocumentClassifier + ) -> None: + """Fields that don't indicate invoice -> letter.""" + fields = { + "CustomerNumber": "C12345", + "SupplierOrgNumber": "556677-8899", + } + + result = classifier.classify(fields) + + assert result.document_type == "letter" + + # ==================== Edge Cases ==================== + + def test_classify_with_empty_string_fields_returns_letter( + self, classifier: DocumentClassifier + ) -> None: + """Empty strings should be treated as missing.""" + fields = { + "payment_line": "", + "Amount": "", + } + + result = classifier.classify(fields) + + 
assert result.document_type == "letter" + + def test_classify_with_whitespace_only_fields_returns_letter( + self, classifier: DocumentClassifier + ) -> None: + """Whitespace-only strings should be treated as missing.""" + fields = { + "payment_line": " ", + "Amount": "\t\n", + } + + result = classifier.classify(fields) + + assert result.document_type == "letter" + + # ==================== ClassificationResult Immutability ==================== + + def test_classification_result_is_immutable( + self, classifier: DocumentClassifier + ) -> None: + """ClassificationResult should be a frozen dataclass.""" + fields = {"payment_line": "test"} + result = classifier.classify(fields) + + with pytest.raises((AttributeError, TypeError)): + result.document_type = "modified" # type: ignore + + def test_classification_result_has_required_fields( + self, classifier: DocumentClassifier + ) -> None: + """ClassificationResult must have document_type, confidence, reason.""" + fields = {"Amount": "100.00"} + result = classifier.classify(fields) + + assert hasattr(result, "document_type") + assert hasattr(result, "confidence") + assert hasattr(result, "reason") + assert isinstance(result.document_type, str) + assert isinstance(result.confidence, float) + assert isinstance(result.reason, str) diff --git a/tests/domain/test_invoice_validator.py b/tests/domain/test_invoice_validator.py new file mode 100644 index 0000000..369ecdc --- /dev/null +++ b/tests/domain/test_invoice_validator.py @@ -0,0 +1,232 @@ +""" +Tests for InvoiceValidator - TDD RED phase. + +Test invoice field validation logic. 
+""" +import pytest + +from backend.domain.invoice_validator import ( + InvoiceValidator, + ValidationResult, + ValidationIssue, +) + + +class TestInvoiceValidator: + """Test invoice validation logic.""" + + @pytest.fixture + def validator(self) -> InvoiceValidator: + """Create validator instance with default settings.""" + return InvoiceValidator() + + @pytest.fixture + def validator_strict(self) -> InvoiceValidator: + """Create validator with strict confidence threshold.""" + return InvoiceValidator(min_confidence=0.8) + + # ==================== Valid Invoice Tests ==================== + + def test_validate_complete_invoice_is_valid( + self, validator: InvoiceValidator + ) -> None: + """Complete invoice with all required fields is valid.""" + fields = { + "Amount": "1200.00", + "OCR": "123456789012", + "Bankgiro": "123-4567", + } + confidence = { + "Amount": 0.95, + "OCR": 0.90, + "Bankgiro": 0.85, + } + + result = validator.validate(fields, confidence) + + assert result.is_valid is True + assert len([i for i in result.issues if i.severity == "error"]) == 0 + + def test_validate_invoice_with_payment_line_is_valid( + self, validator: InvoiceValidator + ) -> None: + """Invoice with payment_line as payment reference is valid.""" + fields = { + "Amount": "500.00", + "payment_line": "# 123 # 500 00 5 > 308#", + } + confidence = {"Amount": 0.9, "payment_line": 0.85} + + result = validator.validate(fields, confidence) + + assert result.is_valid is True + + # ==================== Invalid Invoice Tests ==================== + + def test_validate_missing_amount_is_invalid( + self, validator: InvoiceValidator + ) -> None: + """Missing Amount field should produce error.""" + fields = { + "OCR": "123456789012", + "Bankgiro": "123-4567", + } + confidence = {"OCR": 0.9, "Bankgiro": 0.85} + + result = validator.validate(fields, confidence) + + assert result.is_valid is False + error_fields = [i.field for i in result.issues if i.severity == "error"] + assert "Amount" in 
error_fields + + def test_validate_missing_payment_reference_produces_warning( + self, validator: InvoiceValidator + ) -> None: + """Missing all payment references should produce warning.""" + fields = {"Amount": "1200.00"} + confidence = {"Amount": 0.9} + + result = validator.validate(fields, confidence) + + # Missing payment ref is warning, not error + warning_fields = [i.field for i in result.issues if i.severity == "warning"] + assert "payment_reference" in warning_fields + + # ==================== Confidence Threshold Tests ==================== + + def test_validate_low_confidence_produces_warning( + self, validator: InvoiceValidator + ) -> None: + """Fields below confidence threshold should produce warning.""" + fields = { + "Amount": "1200.00", + "OCR": "123456789012", + } + confidence = { + "Amount": 0.9, + "OCR": 0.3, # Below default threshold of 0.5 + } + + result = validator.validate(fields, confidence) + + low_conf_warnings = [ + i for i in result.issues + if i.severity == "warning" and "confidence" in i.message.lower() + ] + assert len(low_conf_warnings) > 0 + + def test_validate_strict_threshold_more_warnings( + self, validator_strict: InvoiceValidator + ) -> None: + """Strict validator should produce more warnings.""" + fields = { + "Amount": "1200.00", + "OCR": "123456789012", + } + confidence = { + "Amount": 0.7, # Below 0.8 threshold + "OCR": 0.6, # Below 0.8 threshold + } + + result = validator_strict.validate(fields, confidence) + + low_conf_warnings = [ + i for i in result.issues + if i.severity == "warning" and "confidence" in i.message.lower() + ] + assert len(low_conf_warnings) >= 2 + + # ==================== Edge Cases ==================== + + def test_validate_empty_fields_is_invalid( + self, validator: InvoiceValidator + ) -> None: + """Empty fields dict should be invalid.""" + fields: dict[str, str | None] = {} + confidence: dict[str, float] = {} + + result = validator.validate(fields, confidence) + + assert result.is_valid is False + + 
def test_validate_none_field_values_treated_as_missing( + self, validator: InvoiceValidator + ) -> None: + """None values should be treated as missing.""" + fields = { + "Amount": None, + "OCR": "123456789012", + } + confidence = {"OCR": 0.9} + + result = validator.validate(fields, confidence) + + assert result.is_valid is False + error_fields = [i.field for i in result.issues if i.severity == "error"] + assert "Amount" in error_fields + + def test_validate_empty_string_treated_as_missing( + self, validator: InvoiceValidator + ) -> None: + """Empty string should be treated as missing.""" + fields = { + "Amount": "", + "OCR": "123456789012", + } + confidence = {"OCR": 0.9} + + result = validator.validate(fields, confidence) + + assert result.is_valid is False + + # ==================== ValidationResult Properties ==================== + + def test_validation_result_is_immutable( + self, validator: InvoiceValidator + ) -> None: + """ValidationResult should be a frozen dataclass.""" + fields = {"Amount": "100.00", "OCR": "123"} + confidence = {"Amount": 0.9, "OCR": 0.9} + result = validator.validate(fields, confidence) + + with pytest.raises((AttributeError, TypeError)): + result.is_valid = False # type: ignore + + def test_validation_result_issues_is_tuple( + self, validator: InvoiceValidator + ) -> None: + """Issues should be a tuple (immutable).""" + fields = {"Amount": "100.00"} + confidence = {"Amount": 0.9} + result = validator.validate(fields, confidence) + + assert isinstance(result.issues, tuple) + + def test_validation_result_has_confidence( + self, validator: InvoiceValidator + ) -> None: + """ValidationResult should have confidence score.""" + fields = {"Amount": "100.00", "OCR": "123"} + confidence = {"Amount": 0.9, "OCR": 0.8} + result = validator.validate(fields, confidence) + + assert hasattr(result, "confidence") + assert 0.0 <= result.confidence <= 1.0 + + # ==================== ValidationIssue Tests ==================== + + def 
test_validation_issue_has_required_fields( + self, validator: InvoiceValidator + ) -> None: + """ValidationIssue must have field, severity, message.""" + fields: dict[str, str | None] = {} + confidence: dict[str, float] = {} + result = validator.validate(fields, confidence) + + assert len(result.issues) > 0 + issue = result.issues[0] + + assert hasattr(issue, "field") + assert hasattr(issue, "severity") + assert hasattr(issue, "message") + assert issue.severity in ("error", "warning", "info") diff --git a/tests/web/test_inference_service.py b/tests/web/test_inference_service.py index 69f439a..93c6399 100644 --- a/tests/web/test_inference_service.py +++ b/tests/web/test_inference_service.py @@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering: @patch('backend.pipeline.pipeline.InferencePipeline') @patch('backend.pipeline.yolo_detector.YOLODetector') @patch('shared.pdf.renderer.render_pdf_to_images') - @patch('ultralytics.YOLO') def test_pdf_visualization_imports_correctly( self, - mock_yolo_class, mock_render_pdf, mock_yolo_detector, mock_pipeline, @@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering: This catches the import error we had with: from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct) """ - # Setup mocks + # Setup mocks for detector mock_detector_instance = Mock() - mock_pipeline_instance = Mock() + mock_model = Mock() + mock_result = Mock() + mock_result.save = Mock() + mock_model.predict.return_value = [mock_result] + mock_detector_instance.model = mock_model mock_yolo_detector.return_value = mock_detector_instance + + # Setup mock for pipeline + mock_pipeline_instance = Mock() mock_pipeline.return_value = mock_pipeline_instance + # Initialize service to setup _detector + inference_service.initialize() + # Create a fake PDF path pdf_path = tmp_path / "test.pdf" pdf_path.touch() @@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering: img.save(image_bytes, format='PNG') mock_render_pdf.return_value = [(1, 
image_bytes.getvalue())] - # Mock YOLO - mock_model_instance = Mock() - mock_result = Mock() - mock_result.save = Mock() - mock_model_instance.predict.return_value = [mock_result] - mock_yolo_class.return_value = mock_model_instance - - # This should not raise ImportError + # This should not raise ImportError and should use self._detector.model result_path = inference_service._save_pdf_visualization(pdf_path, "test123") - # Verify import was successful + # Verify import was successful and detector.model was used mock_render_pdf.assert_called_once() + mock_model.predict.assert_called_once() assert result_path is not None