WIP
This commit is contained in:
25
packages/backend/backend/domain/__init__.py
Normal file
25
packages/backend/backend/domain/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Domain Layer
|
||||
|
||||
Business logic separated from technical implementation.
|
||||
Contains document classification and invoice validation logic.
|
||||
"""
|
||||
from backend.domain.document_classifier import (
|
||||
ClassificationResult,
|
||||
DocumentClassifier,
|
||||
)
|
||||
from backend.domain.invoice_validator import (
|
||||
InvoiceValidator,
|
||||
ValidationIssue,
|
||||
ValidationResult,
|
||||
)
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
__all__ = [
|
||||
"ClassificationResult",
|
||||
"DocumentClassifier",
|
||||
"InvoiceValidator",
|
||||
"ValidationIssue",
|
||||
"ValidationResult",
|
||||
"has_value",
|
||||
]
|
||||
108
packages/backend/backend/domain/document_classifier.py
Normal file
108
packages/backend/backend/domain/document_classifier.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Document Classifier
|
||||
|
||||
Business logic for classifying documents based on extracted fields.
|
||||
Separates classification logic from inference pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClassificationResult:
|
||||
"""
|
||||
Immutable result of document classification.
|
||||
|
||||
Attributes:
|
||||
document_type: Either "invoice" or "letter"
|
||||
confidence: Confidence score between 0.0 and 1.0
|
||||
reason: Human-readable explanation of classification
|
||||
"""
|
||||
|
||||
document_type: str
|
||||
confidence: float
|
||||
reason: str
|
||||
|
||||
|
||||
class DocumentClassifier:
|
||||
"""
|
||||
Classifies documents as invoice or letter based on extracted fields.
|
||||
|
||||
Classification Rules:
|
||||
1. If payment_line is present -> invoice (high confidence)
|
||||
2. If 2+ invoice indicators present -> invoice (medium confidence)
|
||||
3. If 1 invoice indicator present -> invoice (lower confidence)
|
||||
4. Otherwise -> letter
|
||||
|
||||
Invoice indicator fields:
|
||||
- payment_line (strongest indicator)
|
||||
- OCR
|
||||
- Amount
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- InvoiceNumber
|
||||
"""
|
||||
|
||||
INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"payment_line",
|
||||
"OCR",
|
||||
"Amount",
|
||||
"Bankgiro",
|
||||
"Plusgiro",
|
||||
"InvoiceNumber",
|
||||
}
|
||||
)
|
||||
|
||||
def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
|
||||
"""
|
||||
Classify document type based on extracted fields.
|
||||
|
||||
Args:
|
||||
fields: Dictionary of field names to extracted values.
|
||||
Empty strings or whitespace-only strings are treated as missing.
|
||||
|
||||
Returns:
|
||||
Immutable ClassificationResult with type, confidence, and reason.
|
||||
"""
|
||||
# Rule 1: payment_line is the strongest indicator
|
||||
if has_value(fields.get("payment_line")):
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.95,
|
||||
reason="payment_line detected",
|
||||
)
|
||||
|
||||
# Count present invoice indicators (excluding payment_line already checked)
|
||||
present_indicators = [
|
||||
field
|
||||
for field in self.INVOICE_INDICATOR_FIELDS
|
||||
if field != "payment_line" and has_value(fields.get(field))
|
||||
]
|
||||
indicator_count = len(present_indicators)
|
||||
|
||||
# Rule 2: Multiple indicators -> invoice with medium-high confidence
|
||||
if indicator_count >= 2:
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.8,
|
||||
reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
|
||||
)
|
||||
|
||||
# Rule 3: Single indicator -> invoice with lower confidence
|
||||
if indicator_count == 1:
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.6,
|
||||
reason=f"1 invoice indicator present: {present_indicators[0]}",
|
||||
)
|
||||
|
||||
# Rule 4: No indicators -> letter
|
||||
return ClassificationResult(
|
||||
document_type="letter",
|
||||
confidence=0.7,
|
||||
reason="no invoice indicators found",
|
||||
)
|
||||
141
packages/backend/backend/domain/invoice_validator.py
Normal file
141
packages/backend/backend/domain/invoice_validator.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Invoice Validator
|
||||
|
||||
Business logic for validating extracted invoice fields.
|
||||
Checks for required fields, format validity, and confidence thresholds.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValidationIssue:
|
||||
"""
|
||||
Single validation issue.
|
||||
|
||||
Attributes:
|
||||
field: Name of the field with the issue
|
||||
severity: One of "error", "warning", "info"
|
||||
message: Human-readable description of the issue
|
||||
"""
|
||||
|
||||
field: str
|
||||
severity: str
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValidationResult:
|
||||
"""
|
||||
Immutable result of invoice validation.
|
||||
|
||||
Attributes:
|
||||
is_valid: True if no errors (warnings are allowed)
|
||||
issues: Tuple of validation issues found
|
||||
confidence: Average confidence score of validated fields
|
||||
"""
|
||||
|
||||
is_valid: bool
|
||||
issues: tuple[ValidationIssue, ...]
|
||||
confidence: float
|
||||
|
||||
|
||||
class InvoiceValidator:
|
||||
"""
|
||||
Validates extracted invoice fields for completeness and consistency.
|
||||
|
||||
Validation Rules:
|
||||
1. Required fields must be present (Amount)
|
||||
2. At least one payment reference should be present (warning if missing)
|
||||
3. Field confidence should be above threshold (warning if below)
|
||||
|
||||
Required fields:
|
||||
- Amount
|
||||
|
||||
Payment reference fields (at least one expected):
|
||||
- OCR
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- payment_line
|
||||
"""
|
||||
|
||||
REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
|
||||
PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
|
||||
DEFAULT_MIN_CONFIDENCE: float = 0.5
|
||||
|
||||
def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
|
||||
"""
|
||||
Initialize validator.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum confidence threshold for valid fields.
|
||||
Fields below this threshold produce warnings.
|
||||
"""
|
||||
self._min_confidence = min_confidence
|
||||
|
||||
def validate(
|
||||
self,
|
||||
fields: dict[str, str | None],
|
||||
confidence: dict[str, float],
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate extracted invoice fields.
|
||||
|
||||
Args:
|
||||
fields: Dictionary of field names to extracted values
|
||||
confidence: Dictionary of field names to confidence scores
|
||||
|
||||
Returns:
|
||||
Immutable ValidationResult with validity status and issues
|
||||
"""
|
||||
issues: list[ValidationIssue] = []
|
||||
|
||||
# Check required fields
|
||||
for field in self.REQUIRED_FIELDS:
|
||||
if not has_value(fields.get(field)):
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field=field,
|
||||
severity="error",
|
||||
message=f"Required field '{field}' is missing",
|
||||
)
|
||||
)
|
||||
|
||||
# Check payment reference (at least one expected)
|
||||
has_payment_ref = any(
|
||||
has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS
|
||||
)
|
||||
if not has_payment_ref:
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field="payment_reference",
|
||||
severity="warning",
|
||||
message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
|
||||
)
|
||||
)
|
||||
|
||||
# Check confidence thresholds
|
||||
for field, conf in confidence.items():
|
||||
if conf < self._min_confidence:
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field=field,
|
||||
severity="warning",
|
||||
message=f"Low confidence ({conf:.2f}) for field '{field}'",
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate overall validity
|
||||
has_errors = any(i.severity == "error" for i in issues)
|
||||
avg_confidence = (
|
||||
sum(confidence.values()) / len(confidence) if confidence else 0.0
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=not has_errors,
|
||||
issues=tuple(issues),
|
||||
confidence=avg_confidence,
|
||||
)
|
||||
23
packages/backend/backend/domain/utils.py
Normal file
23
packages/backend/backend/domain/utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Domain Layer Utilities
|
||||
|
||||
Shared helper functions for domain layer classes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def has_value(value: str | None) -> bool:
|
||||
"""
|
||||
Check if a field value is present and non-empty.
|
||||
|
||||
Args:
|
||||
value: Field value to check
|
||||
|
||||
Returns:
|
||||
True if value is a non-empty, non-whitespace string
|
||||
"""
|
||||
if value is None:
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return bool(value)
|
||||
return bool(value.strip())
|
||||
@@ -1,21 +1,24 @@
|
||||
"""
|
||||
Inference Service
|
||||
Inference Service (Adapter Layer)
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
Orchestrates technical pipeline and business domain logic.
|
||||
Acts as adapter between API layer and internal components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,9 +53,12 @@ class ServiceResult:
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
Service for running invoice field extraction (Adapter Pattern).
|
||||
|
||||
Orchestrates:
|
||||
- Technical layer: InferencePipeline, YOLODetector
|
||||
- Business layer: DocumentClassifier
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
@@ -61,6 +67,7 @@ class InferenceService:
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
document_classifier: DocumentClassifier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
@@ -71,12 +78,19 @@ class InferenceService:
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
document_classifier: Optional custom classifier (uses default if None)
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
|
||||
# Technical layer (lazy initialized)
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
|
||||
# Business layer (eagerly initialized, no heavy resources)
|
||||
self._classifier = document_classifier or DocumentClassifier()
|
||||
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
self._business_features_enabled = False
|
||||
@@ -219,22 +233,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
@@ -293,22 +297,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Include business features if extracted
|
||||
if extract_line_items:
|
||||
@@ -329,10 +323,19 @@ class InferenceService:
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
def _format_detections(self, raw_detections: list) -> list[dict]:
|
||||
"""Format raw detections for response."""
|
||||
return [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in raw_detections
|
||||
]
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization image with detections using existing detector."""
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
@@ -340,9 +343,8 @@ class InferenceService:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
@@ -351,11 +353,20 @@ class InferenceService:
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
@contextmanager
|
||||
def _temp_image_file(
|
||||
self, results_dir: Path, doc_id: str
|
||||
) -> Generator[Path, None, None]:
|
||||
"""Context manager for temporary image file with guaranteed cleanup."""
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization for PDF (first page) using existing detector."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
@@ -369,20 +380,19 @@ class InferenceService:
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
# Use context manager for temp file to guarantee cleanup
|
||||
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
||||
image.save(temp_path)
|
||||
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(temp_path), verbose=False)
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
|
||||
1
tests/domain/__init__.py
Normal file
1
tests/domain/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Domain layer tests
|
||||
176
tests/domain/test_document_classifier.py
Normal file
176
tests/domain/test_document_classifier.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Tests for DocumentClassifier - TDD RED phase.
|
||||
|
||||
Test document type classification based on extracted fields.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier, ClassificationResult
|
||||
|
||||
|
||||
class TestDocumentClassifier:
|
||||
"""Test document classification logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def classifier(self) -> DocumentClassifier:
|
||||
"""Create classifier instance."""
|
||||
return DocumentClassifier()
|
||||
|
||||
# ==================== Invoice Detection Tests ====================
|
||||
|
||||
def test_classify_with_payment_line_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Payment line is the strongest invoice indicator."""
|
||||
fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.9
|
||||
assert "payment_line" in result.reason
|
||||
|
||||
def test_classify_with_multiple_indicators_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Multiple invoice indicators -> invoice with medium confidence."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"Bankgiro": "123-4567",
|
||||
"payment_line": None,
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.7
|
||||
|
||||
def test_classify_with_ocr_and_amount_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""OCR + Amount is typical invoice pattern."""
|
||||
fields = {
|
||||
"OCR": "123456789012",
|
||||
"Amount": "500.00",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.7
|
||||
|
||||
def test_classify_with_single_indicator_returns_invoice_lower_confidence(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Single indicator -> invoice but lower confidence."""
|
||||
fields = {"Amount": "100.00"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert 0.5 <= result.confidence < 0.8
|
||||
|
||||
def test_classify_with_invoice_number_only(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Invoice number alone suggests invoice."""
|
||||
fields = {"InvoiceNumber": "INV-2024-001"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
|
||||
# ==================== Letter Detection Tests ====================
|
||||
|
||||
def test_classify_with_no_indicators_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""No invoice indicators -> letter."""
|
||||
fields: dict[str, str | None] = {}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
assert result.confidence >= 0.5
|
||||
|
||||
def test_classify_with_empty_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""All fields empty or None -> letter."""
|
||||
fields = {
|
||||
"payment_line": None,
|
||||
"OCR": None,
|
||||
"Amount": None,
|
||||
"Bankgiro": None,
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
def test_classify_with_only_non_indicator_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Fields that don't indicate invoice -> letter."""
|
||||
fields = {
|
||||
"CustomerNumber": "C12345",
|
||||
"SupplierOrgNumber": "556677-8899",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
# ==================== Edge Cases ====================
|
||||
|
||||
def test_classify_with_empty_string_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Empty strings should be treated as missing."""
|
||||
fields = {
|
||||
"payment_line": "",
|
||||
"Amount": "",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
def test_classify_with_whitespace_only_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Whitespace-only strings should be treated as missing."""
|
||||
fields = {
|
||||
"payment_line": " ",
|
||||
"Amount": "\t\n",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
# ==================== ClassificationResult Immutability ====================
|
||||
|
||||
def test_classification_result_is_immutable(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""ClassificationResult should be a frozen dataclass."""
|
||||
fields = {"payment_line": "test"}
|
||||
result = classifier.classify(fields)
|
||||
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
result.document_type = "modified" # type: ignore
|
||||
|
||||
def test_classification_result_has_required_fields(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""ClassificationResult must have document_type, confidence, reason."""
|
||||
fields = {"Amount": "100.00"}
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert hasattr(result, "document_type")
|
||||
assert hasattr(result, "confidence")
|
||||
assert hasattr(result, "reason")
|
||||
assert isinstance(result.document_type, str)
|
||||
assert isinstance(result.confidence, float)
|
||||
assert isinstance(result.reason, str)
|
||||
232
tests/domain/test_invoice_validator.py
Normal file
232
tests/domain/test_invoice_validator.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for InvoiceValidator - TDD RED phase.
|
||||
|
||||
Test invoice field validation logic.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from backend.domain.invoice_validator import (
|
||||
InvoiceValidator,
|
||||
ValidationResult,
|
||||
ValidationIssue,
|
||||
)
|
||||
|
||||
|
||||
class TestInvoiceValidator:
|
||||
"""Test invoice validation logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def validator(self) -> InvoiceValidator:
|
||||
"""Create validator instance with default settings."""
|
||||
return InvoiceValidator()
|
||||
|
||||
@pytest.fixture
|
||||
def validator_strict(self) -> InvoiceValidator:
|
||||
"""Create validator with strict confidence threshold."""
|
||||
return InvoiceValidator(min_confidence=0.8)
|
||||
|
||||
# ==================== Valid Invoice Tests ====================
|
||||
|
||||
def test_validate_complete_invoice_is_valid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Complete invoice with all required fields is valid."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
"Bankgiro": "123-4567",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.95,
|
||||
"OCR": 0.90,
|
||||
"Bankgiro": 0.85,
|
||||
}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len([i for i in result.issues if i.severity == "error"]) == 0
|
||||
|
||||
def test_validate_invoice_with_payment_line_is_valid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Invoice with payment_line as payment reference is valid."""
|
||||
fields = {
|
||||
"Amount": "500.00",
|
||||
"payment_line": "# 123 # 500 00 5 > 308#",
|
||||
}
|
||||
confidence = {"Amount": 0.9, "payment_line": 0.85}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
# ==================== Invalid Invoice Tests ====================
|
||||
|
||||
def test_validate_missing_amount_is_invalid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Missing Amount field should produce error."""
|
||||
fields = {
|
||||
"OCR": "123456789012",
|
||||
"Bankgiro": "123-4567",
|
||||
}
|
||||
confidence = {"OCR": 0.9, "Bankgiro": 0.85}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||
assert "Amount" in error_fields
|
||||
|
||||
def test_validate_missing_payment_reference_produces_warning(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Missing all payment references should produce warning."""
|
||||
fields = {"Amount": "1200.00"}
|
||||
confidence = {"Amount": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
# Missing payment ref is warning, not error
|
||||
warning_fields = [i.field for i in result.issues if i.severity == "warning"]
|
||||
assert "payment_reference" in warning_fields
|
||||
|
||||
# ==================== Confidence Threshold Tests ====================
|
||||
|
||||
def test_validate_low_confidence_produces_warning(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Fields below confidence threshold should produce warning."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.9,
|
||||
"OCR": 0.3, # Below default threshold of 0.5
|
||||
}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
low_conf_warnings = [
|
||||
i for i in result.issues
|
||||
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||
]
|
||||
assert len(low_conf_warnings) > 0
|
||||
|
||||
def test_validate_strict_threshold_more_warnings(
|
||||
self, validator_strict: InvoiceValidator
|
||||
) -> None:
|
||||
"""Strict validator should produce more warnings."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.7, # Below 0.8 threshold
|
||||
"OCR": 0.6, # Below 0.8 threshold
|
||||
}
|
||||
|
||||
result = validator_strict.validate(fields, confidence)
|
||||
|
||||
low_conf_warnings = [
|
||||
i for i in result.issues
|
||||
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||
]
|
||||
assert len(low_conf_warnings) >= 2
|
||||
|
||||
# ==================== Edge Cases ====================
|
||||
|
||||
def test_validate_empty_fields_is_invalid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Empty fields dict should be invalid."""
|
||||
fields: dict[str, str | None] = {}
|
||||
confidence: dict[str, float] = {}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
def test_validate_none_field_values_treated_as_missing(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""None values should be treated as missing."""
|
||||
fields = {
|
||||
"Amount": None,
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {"OCR": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||
assert "Amount" in error_fields
|
||||
|
||||
def test_validate_empty_string_treated_as_missing(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Empty string should be treated as missing."""
|
||||
fields = {
|
||||
"Amount": "",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {"OCR": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
# ==================== ValidationResult Properties ====================
|
||||
|
||||
def test_validation_result_is_immutable(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationResult should be a frozen dataclass."""
|
||||
fields = {"Amount": "100.00", "OCR": "123"}
|
||||
confidence = {"Amount": 0.9, "OCR": 0.9}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
result.is_valid = False # type: ignore
|
||||
|
||||
def test_validation_result_issues_is_tuple(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Issues should be a tuple (immutable)."""
|
||||
fields = {"Amount": "100.00"}
|
||||
confidence = {"Amount": 0.9}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert isinstance(result.issues, tuple)
|
||||
|
||||
def test_validation_result_has_confidence(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationResult should have confidence score."""
|
||||
fields = {"Amount": "100.00", "OCR": "123"}
|
||||
confidence = {"Amount": 0.9, "OCR": 0.8}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert hasattr(result, "confidence")
|
||||
assert 0.0 <= result.confidence <= 1.0
|
||||
|
||||
# ==================== ValidationIssue Tests ====================
|
||||
|
||||
def test_validation_issue_has_required_fields(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationIssue must have field, severity, message."""
|
||||
fields: dict[str, str | None] = {}
|
||||
confidence: dict[str, float] = {}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert len(result.issues) > 0
|
||||
issue = result.issues[0]
|
||||
|
||||
assert hasattr(issue, "field")
|
||||
assert hasattr(issue, "severity")
|
||||
assert hasattr(issue, "message")
|
||||
assert issue.severity in ("error", "warning", "info")
|
||||
@@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering:
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('shared.pdf.renderer.render_pdf_to_images')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_pdf_visualization_imports_correctly(
|
||||
self,
|
||||
mock_yolo_class,
|
||||
mock_render_pdf,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
@@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering:
|
||||
This catches the import error we had with:
|
||||
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
|
||||
"""
|
||||
# Setup mocks
|
||||
# Setup mocks for detector
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_model = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.save = Mock()
|
||||
mock_model.predict.return_value = [mock_result]
|
||||
mock_detector_instance.model = mock_model
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
|
||||
# Setup mock for pipeline
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Initialize service to setup _detector
|
||||
inference_service.initialize()
|
||||
|
||||
# Create a fake PDF path
|
||||
pdf_path = tmp_path / "test.pdf"
|
||||
pdf_path.touch()
|
||||
@@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering:
|
||||
img.save(image_bytes, format='PNG')
|
||||
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
|
||||
|
||||
# Mock YOLO
|
||||
mock_model_instance = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.save = Mock()
|
||||
mock_model_instance.predict.return_value = [mock_result]
|
||||
mock_yolo_class.return_value = mock_model_instance
|
||||
|
||||
# This should not raise ImportError
|
||||
# This should not raise ImportError and should use self._detector.model
|
||||
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
|
||||
|
||||
# Verify import was successful
|
||||
# Verify import was successful and detector.model was used
|
||||
mock_render_pdf.assert_called_once()
|
||||
mock_model.predict.assert_called_once()
|
||||
assert result_path is not None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user