WIP
This commit is contained in:
25
packages/backend/backend/domain/__init__.py
Normal file
25
packages/backend/backend/domain/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
Domain Layer
|
||||||
|
|
||||||
|
Business logic separated from technical implementation.
|
||||||
|
Contains document classification and invoice validation logic.
|
||||||
|
"""
|
||||||
|
from backend.domain.document_classifier import (
|
||||||
|
ClassificationResult,
|
||||||
|
DocumentClassifier,
|
||||||
|
)
|
||||||
|
from backend.domain.invoice_validator import (
|
||||||
|
InvoiceValidator,
|
||||||
|
ValidationIssue,
|
||||||
|
ValidationResult,
|
||||||
|
)
|
||||||
|
from backend.domain.utils import has_value
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClassificationResult",
|
||||||
|
"DocumentClassifier",
|
||||||
|
"InvoiceValidator",
|
||||||
|
"ValidationIssue",
|
||||||
|
"ValidationResult",
|
||||||
|
"has_value",
|
||||||
|
]
|
||||||
108
packages/backend/backend/domain/document_classifier.py
Normal file
108
packages/backend/backend/domain/document_classifier.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Document Classifier
|
||||||
|
|
||||||
|
Business logic for classifying documents based on extracted fields.
|
||||||
|
Separates classification logic from inference pipeline.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from backend.domain.utils import has_value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ClassificationResult:
    """
    Immutable result of document classification.

    Frozen so a result can be shared/cached safely after classification.

    Attributes:
        document_type: Either "invoice" or "letter"
        confidence: Confidence score between 0.0 and 1.0
        reason: Human-readable explanation of classification
    """

    document_type: str  # "invoice" or "letter"
    confidence: float  # 0.0-1.0; higher means stronger evidence for the type
    reason: str  # e.g. "payment_line detected" or "no invoice indicators found"
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentClassifier:
    """
    Classifies documents as invoice or letter based on extracted fields.

    Classification Rules:
        1. If payment_line is present -> invoice (high confidence)
        2. If 2+ invoice indicators present -> invoice (medium confidence)
        3. If 1 invoice indicator present -> invoice (lower confidence)
        4. Otherwise -> letter

    Invoice indicator fields:
        - payment_line (strongest indicator)
        - OCR
        - Amount
        - Bankgiro
        - Plusgiro
        - InvoiceNumber
    """

    INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
        {
            "payment_line",
            "OCR",
            "Amount",
            "Bankgiro",
            "Plusgiro",
            "InvoiceNumber",
        }
    )

    def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
        """
        Classify document type based on extracted fields.

        Args:
            fields: Dictionary of field names to extracted values.
                Empty strings or whitespace-only strings are treated as missing.

        Returns:
            Immutable ClassificationResult with type, confidence, and reason.
        """
        # Rule 1: payment_line is the strongest indicator
        if has_value(fields.get("payment_line")):
            return ClassificationResult(
                document_type="invoice",
                confidence=0.95,
                reason="payment_line detected",
            )

        # Count present invoice indicators (excluding payment_line already checked).
        # FIX: sort the indicators so the reason string is deterministic —
        # frozenset iteration order varies between runs (str hash randomization
        # via PYTHONHASHSEED), which previously made the reason text flap.
        present_indicators = sorted(
            field
            for field in self.INVOICE_INDICATOR_FIELDS
            if field != "payment_line" and has_value(fields.get(field))
        )
        indicator_count = len(present_indicators)

        # Rule 2: Multiple indicators -> invoice with medium-high confidence
        if indicator_count >= 2:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.8,
                reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
            )

        # Rule 3: Single indicator -> invoice with lower confidence
        if indicator_count == 1:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.6,
                reason=f"1 invoice indicator present: {present_indicators[0]}",
            )

        # Rule 4: No indicators -> letter
        return ClassificationResult(
            document_type="letter",
            confidence=0.7,
            reason="no invoice indicators found",
        )
|
||||||
141
packages/backend/backend/domain/invoice_validator.py
Normal file
141
packages/backend/backend/domain/invoice_validator.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Invoice Validator
|
||||||
|
|
||||||
|
Business logic for validating extracted invoice fields.
|
||||||
|
Checks for required fields, format validity, and confidence thresholds.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from backend.domain.utils import has_value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ValidationIssue:
    """
    Single validation issue.

    Attributes:
        field: Name of the field with the issue
        severity: One of "error", "warning", "info"
        message: Human-readable description of the issue
    """

    field: str  # field name, or a synthetic name such as "payment_reference"
    severity: str  # "error" blocks validity; "warning"/"info" do not
    message: str  # human-readable description of the problem
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ValidationResult:
    """
    Immutable result of invoice validation.

    Attributes:
        is_valid: True if no errors (warnings are allowed)
        issues: Tuple of validation issues found
        confidence: Average confidence score of validated fields
    """

    is_valid: bool  # False only when at least one "error"-severity issue exists
    issues: tuple[ValidationIssue, ...]  # immutable; may mix errors and warnings
    confidence: float  # mean of per-field confidences; 0.0 when none were given
|
||||||
|
|
||||||
|
|
||||||
|
class InvoiceValidator:
    """
    Validates extracted invoice fields for completeness and consistency.

    Validation Rules:
        1. Required fields must be present (Amount)
        2. At least one payment reference should be present (warning if missing)
        3. Field confidence should be above threshold (warning if below)

    Required fields:
        - Amount

    Payment reference fields (at least one expected):
        - OCR
        - Bankgiro
        - Plusgiro
        - payment_line
    """

    REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
    PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
    DEFAULT_MIN_CONFIDENCE: float = 0.5

    def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
        """
        Initialize validator.

        Args:
            min_confidence: Minimum confidence threshold for valid fields.
                Fields below this threshold produce warnings.
        """
        self._min_confidence = min_confidence

    def validate(
        self,
        fields: dict[str, str | None],
        confidence: dict[str, float],
    ) -> ValidationResult:
        """
        Validate extracted invoice fields.

        Args:
            fields: Dictionary of field names to extracted values
            confidence: Dictionary of field names to confidence scores

        Returns:
            Immutable ValidationResult with validity status and issues
        """
        # Issue order is part of the observable output: required-field errors
        # first, then the payment-reference warning, then confidence warnings.
        issues: list[ValidationIssue] = [
            *self._missing_required_issues(fields),
            *self._payment_reference_issues(fields),
            *self._low_confidence_issues(confidence),
        ]

        # Valid means "no errors"; warnings alone do not invalidate.
        has_errors = any(i.severity == "error" for i in issues)
        avg_confidence = (
            sum(confidence.values()) / len(confidence) if confidence else 0.0
        )

        return ValidationResult(
            is_valid=not has_errors,
            issues=tuple(issues),
            confidence=avg_confidence,
        )

    def _missing_required_issues(
        self, fields: dict[str, str | None]
    ) -> list[ValidationIssue]:
        """One error per required field that is missing or blank."""
        return [
            ValidationIssue(
                field=field,
                severity="error",
                message=f"Required field '{field}' is missing",
            )
            for field in self.REQUIRED_FIELDS
            if not has_value(fields.get(field))
        ]

    def _payment_reference_issues(
        self, fields: dict[str, str | None]
    ) -> list[ValidationIssue]:
        """Single warning when no payment reference field is present."""
        if any(has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS):
            return []
        return [
            ValidationIssue(
                field="payment_reference",
                severity="warning",
                message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
            )
        ]

    def _low_confidence_issues(
        self, confidence: dict[str, float]
    ) -> list[ValidationIssue]:
        """One warning per field whose confidence is below the threshold."""
        return [
            ValidationIssue(
                field=field,
                severity="warning",
                message=f"Low confidence ({conf:.2f}) for field '{field}'",
            )
            for field, conf in confidence.items()
            if conf < self._min_confidence
        ]
|
||||||
23
packages/backend/backend/domain/utils.py
Normal file
23
packages/backend/backend/domain/utils.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
Domain Layer Utilities
|
||||||
|
|
||||||
|
Shared helper functions for domain layer classes.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def has_value(value: str | None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a field value is present and non-empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Field value to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if value is a non-empty, non-whitespace string
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return bool(value)
|
||||||
|
return bool(value.strip())
|
||||||
@@ -1,21 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
Inference Service
|
Inference Service (Adapter Layer)
|
||||||
|
|
||||||
Business logic for invoice field extraction.
|
Orchestrates technical pipeline and business domain logic.
|
||||||
|
Acts as adapter between API layer and internal components.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable, Generator
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from backend.domain.document_classifier import DocumentClassifier
|
||||||
from backend.web.services.storage_helpers import get_storage_helper
|
from backend.web.services.storage_helpers import get_storage_helper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -50,9 +53,12 @@ class ServiceResult:
|
|||||||
|
|
||||||
class InferenceService:
|
class InferenceService:
|
||||||
"""
|
"""
|
||||||
Service for running invoice field extraction.
|
Service for running invoice field extraction (Adapter Pattern).
|
||||||
|
|
||||||
|
Orchestrates:
|
||||||
|
- Technical layer: InferencePipeline, YOLODetector
|
||||||
|
- Business layer: DocumentClassifier
|
||||||
|
|
||||||
Encapsulates YOLO detection and OCR extraction logic.
|
|
||||||
Supports dynamic model loading from database.
|
Supports dynamic model loading from database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -61,6 +67,7 @@ class InferenceService:
|
|||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
storage_config: StorageConfig,
|
storage_config: StorageConfig,
|
||||||
model_path_resolver: ModelPathResolver | None = None,
|
model_path_resolver: ModelPathResolver | None = None,
|
||||||
|
document_classifier: DocumentClassifier | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize inference service.
|
Initialize inference service.
|
||||||
@@ -71,12 +78,19 @@ class InferenceService:
|
|||||||
model_path_resolver: Optional function to resolve model path from database.
|
model_path_resolver: Optional function to resolve model path from database.
|
||||||
If provided, will be called to get active model path.
|
If provided, will be called to get active model path.
|
||||||
If returns None, falls back to model_config.model_path.
|
If returns None, falls back to model_config.model_path.
|
||||||
|
document_classifier: Optional custom classifier (uses default if None)
|
||||||
"""
|
"""
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.storage_config = storage_config
|
self.storage_config = storage_config
|
||||||
self._model_path_resolver = model_path_resolver
|
self._model_path_resolver = model_path_resolver
|
||||||
|
|
||||||
|
# Technical layer (lazy initialized)
|
||||||
self._pipeline = None
|
self._pipeline = None
|
||||||
self._detector = None
|
self._detector = None
|
||||||
|
|
||||||
|
# Business layer (eagerly initialized, no heavy resources)
|
||||||
|
self._classifier = document_classifier or DocumentClassifier()
|
||||||
|
|
||||||
self._is_initialized = False
|
self._is_initialized = False
|
||||||
self._current_model_path: Path | None = None
|
self._current_model_path: Path | None = None
|
||||||
self._business_features_enabled = False
|
self._business_features_enabled = False
|
||||||
@@ -219,22 +233,12 @@ class InferenceService:
|
|||||||
result.success = pipeline_result.success
|
result.success = pipeline_result.success
|
||||||
result.errors = pipeline_result.errors
|
result.errors = pipeline_result.errors
|
||||||
|
|
||||||
# Determine document type based on payment_line presence
|
# Business layer: classify document type
|
||||||
# If no payment_line found, it's likely a letter, not an invoice
|
classification = self._classifier.classify(result.fields)
|
||||||
if not result.fields.get('payment_line'):
|
result.document_type = classification.document_type
|
||||||
result.document_type = "letter"
|
|
||||||
else:
|
|
||||||
result.document_type = "invoice"
|
|
||||||
|
|
||||||
# Get raw detections for visualization
|
# Get raw detections for visualization
|
||||||
result.detections = [
|
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||||
{
|
|
||||||
"field": d.class_name,
|
|
||||||
"confidence": d.confidence,
|
|
||||||
"bbox": list(d.bbox),
|
|
||||||
}
|
|
||||||
for d in pipeline_result.raw_detections
|
|
||||||
]
|
|
||||||
|
|
||||||
# Save visualization if requested
|
# Save visualization if requested
|
||||||
if save_visualization and pipeline_result.raw_detections:
|
if save_visualization and pipeline_result.raw_detections:
|
||||||
@@ -293,22 +297,12 @@ class InferenceService:
|
|||||||
result.success = pipeline_result.success
|
result.success = pipeline_result.success
|
||||||
result.errors = pipeline_result.errors
|
result.errors = pipeline_result.errors
|
||||||
|
|
||||||
# Determine document type based on payment_line presence
|
# Business layer: classify document type
|
||||||
# If no payment_line found, it's likely a letter, not an invoice
|
classification = self._classifier.classify(result.fields)
|
||||||
if not result.fields.get('payment_line'):
|
result.document_type = classification.document_type
|
||||||
result.document_type = "letter"
|
|
||||||
else:
|
|
||||||
result.document_type = "invoice"
|
|
||||||
|
|
||||||
# Get raw detections
|
# Get raw detections
|
||||||
result.detections = [
|
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||||
{
|
|
||||||
"field": d.class_name,
|
|
||||||
"confidence": d.confidence,
|
|
||||||
"bbox": list(d.bbox),
|
|
||||||
}
|
|
||||||
for d in pipeline_result.raw_detections
|
|
||||||
]
|
|
||||||
|
|
||||||
# Include business features if extracted
|
# Include business features if extracted
|
||||||
if extract_line_items:
|
if extract_line_items:
|
||||||
@@ -329,10 +323,19 @@ class InferenceService:
|
|||||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
def _format_detections(self, raw_detections: list) -> list[dict]:
|
||||||
"""Save visualization image with detections."""
|
"""Format raw detections for response."""
|
||||||
from ultralytics import YOLO
|
return [
|
||||||
|
{
|
||||||
|
"field": d.class_name,
|
||||||
|
"confidence": d.confidence,
|
||||||
|
"bbox": list(d.bbox),
|
||||||
|
}
|
||||||
|
for d in raw_detections
|
||||||
|
]
|
||||||
|
|
||||||
|
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
||||||
|
"""Save visualization image with detections using existing detector."""
|
||||||
# Get storage helper for results directory
|
# Get storage helper for results directory
|
||||||
storage = get_storage_helper()
|
storage = get_storage_helper()
|
||||||
results_dir = storage.get_results_base_path()
|
results_dir = storage.get_results_base_path()
|
||||||
@@ -340,9 +343,8 @@ class InferenceService:
|
|||||||
logger.warning("Cannot save visualization: local storage not available")
|
logger.warning("Cannot save visualization: local storage not available")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Load model and run prediction with visualization
|
# Reuse self._detector instead of creating new YOLO instance
|
||||||
model = YOLO(str(self.model_config.model_path))
|
results = self._detector.model.predict(str(image_path), verbose=False)
|
||||||
results = model.predict(str(image_path), verbose=False)
|
|
||||||
|
|
||||||
# Save annotated image
|
# Save annotated image
|
||||||
output_path = results_dir / f"{doc_id}_result.png"
|
output_path = results_dir / f"{doc_id}_result.png"
|
||||||
@@ -351,11 +353,20 @@ class InferenceService:
|
|||||||
|
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
@contextmanager
|
||||||
"""Save visualization for PDF (first page)."""
|
def _temp_image_file(
|
||||||
|
self, results_dir: Path, doc_id: str
|
||||||
|
) -> Generator[Path, None, None]:
|
||||||
|
"""Context manager for temporary image file with guaranteed cleanup."""
|
||||||
|
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||||
|
try:
|
||||||
|
yield temp_path
|
||||||
|
finally:
|
||||||
|
temp_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||||
|
"""Save visualization for PDF (first page) using existing detector."""
|
||||||
from shared.pdf.renderer import render_pdf_to_images
|
from shared.pdf.renderer import render_pdf_to_images
|
||||||
from ultralytics import YOLO
|
|
||||||
import io
|
|
||||||
|
|
||||||
# Get storage helper for results directory
|
# Get storage helper for results directory
|
||||||
storage = get_storage_helper()
|
storage = get_storage_helper()
|
||||||
@@ -369,20 +380,19 @@ class InferenceService:
|
|||||||
pdf_path, dpi=self.model_config.dpi
|
pdf_path, dpi=self.model_config.dpi
|
||||||
):
|
):
|
||||||
image = Image.open(io.BytesIO(image_bytes))
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
|
||||||
image.save(temp_path)
|
|
||||||
|
|
||||||
# Run YOLO and save visualization
|
# Use context manager for temp file to guarantee cleanup
|
||||||
model = YOLO(str(self.model_config.model_path))
|
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
||||||
results = model.predict(str(temp_path), verbose=False)
|
image.save(temp_path)
|
||||||
|
|
||||||
output_path = results_dir / f"{doc_id}_result.png"
|
# Reuse self._detector instead of creating new YOLO instance
|
||||||
for r in results:
|
results = self._detector.model.predict(str(temp_path), verbose=False)
|
||||||
r.save(filename=str(output_path))
|
|
||||||
|
|
||||||
# Cleanup temp file
|
output_path = results_dir / f"{doc_id}_result.png"
|
||||||
temp_path.unlink(missing_ok=True)
|
for r in results:
|
||||||
return output_path
|
r.save(filename=str(output_path))
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
# If no pages rendered
|
# If no pages rendered
|
||||||
return None
|
return None
|
||||||
|
|||||||
1
tests/domain/__init__.py
Normal file
1
tests/domain/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Domain layer tests
|
||||||
176
tests/domain/test_document_classifier.py
Normal file
176
tests/domain/test_document_classifier.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""
|
||||||
|
Tests for DocumentClassifier - TDD RED phase.
|
||||||
|
|
||||||
|
Test document type classification based on extracted fields.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.domain.document_classifier import DocumentClassifier, ClassificationResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentClassifier:
    """Test document classification logic.

    Covers: invoice detection (payment_line, multiple/single indicators),
    letter fallback (no or non-indicator fields), blank/whitespace handling,
    and ClassificationResult immutability and shape.
    """

    @pytest.fixture
    def classifier(self) -> DocumentClassifier:
        """Create classifier instance."""
        return DocumentClassifier()

    # ==================== Invoice Detection Tests ====================

    def test_classify_with_payment_line_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """Payment line is the strongest invoice indicator."""
        fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"}

        result = classifier.classify(fields)

        assert result.document_type == "invoice"
        assert result.confidence >= 0.9
        assert "payment_line" in result.reason

    def test_classify_with_multiple_indicators_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """Multiple invoice indicators -> invoice with medium confidence."""
        fields = {
            "Amount": "1200.00",
            "Bankgiro": "123-4567",
            "payment_line": None,
        }

        result = classifier.classify(fields)

        assert result.document_type == "invoice"
        assert result.confidence >= 0.7

    def test_classify_with_ocr_and_amount_returns_invoice(
        self, classifier: DocumentClassifier
    ) -> None:
        """OCR + Amount is typical invoice pattern."""
        fields = {
            "OCR": "123456789012",
            "Amount": "500.00",
        }

        result = classifier.classify(fields)

        assert result.document_type == "invoice"
        assert result.confidence >= 0.7

    def test_classify_with_single_indicator_returns_invoice_lower_confidence(
        self, classifier: DocumentClassifier
    ) -> None:
        """Single indicator -> invoice but lower confidence."""
        fields = {"Amount": "100.00"}

        result = classifier.classify(fields)

        assert result.document_type == "invoice"
        assert 0.5 <= result.confidence < 0.8

    def test_classify_with_invoice_number_only(
        self, classifier: DocumentClassifier
    ) -> None:
        """Invoice number alone suggests invoice."""
        fields = {"InvoiceNumber": "INV-2024-001"}

        result = classifier.classify(fields)

        assert result.document_type == "invoice"

    # ==================== Letter Detection Tests ====================

    def test_classify_with_no_indicators_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """No invoice indicators -> letter."""
        fields: dict[str, str | None] = {}

        result = classifier.classify(fields)

        assert result.document_type == "letter"
        assert result.confidence >= 0.5

    def test_classify_with_empty_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """All fields empty or None -> letter."""
        fields = {
            "payment_line": None,
            "OCR": None,
            "Amount": None,
            "Bankgiro": None,
        }

        result = classifier.classify(fields)

        assert result.document_type == "letter"

    def test_classify_with_only_non_indicator_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Fields that don't indicate invoice -> letter."""
        fields = {
            "CustomerNumber": "C12345",
            "SupplierOrgNumber": "556677-8899",
        }

        result = classifier.classify(fields)

        assert result.document_type == "letter"

    # ==================== Edge Cases ====================

    def test_classify_with_empty_string_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Empty strings should be treated as missing."""
        fields = {
            "payment_line": "",
            "Amount": "",
        }

        result = classifier.classify(fields)

        assert result.document_type == "letter"

    def test_classify_with_whitespace_only_fields_returns_letter(
        self, classifier: DocumentClassifier
    ) -> None:
        """Whitespace-only strings should be treated as missing."""
        fields = {
            "payment_line": "   ",
            "Amount": "\t\n",
        }

        result = classifier.classify(fields)

        assert result.document_type == "letter"

    # ==================== ClassificationResult Immutability ====================

    def test_classification_result_is_immutable(
        self, classifier: DocumentClassifier
    ) -> None:
        """ClassificationResult should be a frozen dataclass."""
        fields = {"payment_line": "test"}
        result = classifier.classify(fields)

        # frozen dataclasses raise FrozenInstanceError (an AttributeError subclass)
        with pytest.raises((AttributeError, TypeError)):
            result.document_type = "modified"  # type: ignore

    def test_classification_result_has_required_fields(
        self, classifier: DocumentClassifier
    ) -> None:
        """ClassificationResult must have document_type, confidence, reason."""
        fields = {"Amount": "100.00"}
        result = classifier.classify(fields)

        assert hasattr(result, "document_type")
        assert hasattr(result, "confidence")
        assert hasattr(result, "reason")
        assert isinstance(result.document_type, str)
        assert isinstance(result.confidence, float)
        assert isinstance(result.reason, str)
|
||||||
232
tests/domain/test_invoice_validator.py
Normal file
232
tests/domain/test_invoice_validator.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""
|
||||||
|
Tests for InvoiceValidator - TDD RED phase.
|
||||||
|
|
||||||
|
Test invoice field validation logic.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.domain.invoice_validator import (
|
||||||
|
InvoiceValidator,
|
||||||
|
ValidationResult,
|
||||||
|
ValidationIssue,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvoiceValidator:
|
||||||
|
"""Test invoice validation logic."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def validator(self) -> InvoiceValidator:
|
||||||
|
"""Create validator instance with default settings."""
|
||||||
|
return InvoiceValidator()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def validator_strict(self) -> InvoiceValidator:
|
||||||
|
"""Create validator with strict confidence threshold."""
|
||||||
|
return InvoiceValidator(min_confidence=0.8)
|
||||||
|
|
||||||
|
# ==================== Valid Invoice Tests ====================
|
||||||
|
|
||||||
|
def test_validate_complete_invoice_is_valid(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Complete invoice with all required fields is valid."""
|
||||||
|
fields = {
|
||||||
|
"Amount": "1200.00",
|
||||||
|
"OCR": "123456789012",
|
||||||
|
"Bankgiro": "123-4567",
|
||||||
|
}
|
||||||
|
confidence = {
|
||||||
|
"Amount": 0.95,
|
||||||
|
"OCR": 0.90,
|
||||||
|
"Bankgiro": 0.85,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert len([i for i in result.issues if i.severity == "error"]) == 0
|
||||||
|
|
||||||
|
def test_validate_invoice_with_payment_line_is_valid(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Invoice with payment_line as payment reference is valid."""
|
||||||
|
fields = {
|
||||||
|
"Amount": "500.00",
|
||||||
|
"payment_line": "# 123 # 500 00 5 > 308#",
|
||||||
|
}
|
||||||
|
confidence = {"Amount": 0.9, "payment_line": 0.85}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
|
||||||
|
# ==================== Invalid Invoice Tests ====================
|
||||||
|
|
||||||
|
def test_validate_missing_amount_is_invalid(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Missing Amount field should produce error."""
|
||||||
|
fields = {
|
||||||
|
"OCR": "123456789012",
|
||||||
|
"Bankgiro": "123-4567",
|
||||||
|
}
|
||||||
|
confidence = {"OCR": 0.9, "Bankgiro": 0.85}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||||
|
assert "Amount" in error_fields
|
||||||
|
|
||||||
|
def test_validate_missing_payment_reference_produces_warning(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Missing all payment references should produce warning."""
|
||||||
|
fields = {"Amount": "1200.00"}
|
||||||
|
confidence = {"Amount": 0.9}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
# Missing payment ref is warning, not error
|
||||||
|
warning_fields = [i.field for i in result.issues if i.severity == "warning"]
|
||||||
|
assert "payment_reference" in warning_fields
|
||||||
|
|
||||||
|
# ==================== Confidence Threshold Tests ====================
|
||||||
|
|
||||||
|
def test_validate_low_confidence_produces_warning(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Fields below confidence threshold should produce warning."""
|
||||||
|
fields = {
|
||||||
|
"Amount": "1200.00",
|
||||||
|
"OCR": "123456789012",
|
||||||
|
}
|
||||||
|
confidence = {
|
||||||
|
"Amount": 0.9,
|
||||||
|
"OCR": 0.3, # Below default threshold of 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
low_conf_warnings = [
|
||||||
|
i for i in result.issues
|
||||||
|
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||||
|
]
|
||||||
|
assert len(low_conf_warnings) > 0
|
||||||
|
|
||||||
|
def test_validate_strict_threshold_more_warnings(
|
||||||
|
self, validator_strict: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Strict validator should produce more warnings."""
|
||||||
|
fields = {
|
||||||
|
"Amount": "1200.00",
|
||||||
|
"OCR": "123456789012",
|
||||||
|
}
|
||||||
|
confidence = {
|
||||||
|
"Amount": 0.7, # Below 0.8 threshold
|
||||||
|
"OCR": 0.6, # Below 0.8 threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
result = validator_strict.validate(fields, confidence)
|
||||||
|
|
||||||
|
low_conf_warnings = [
|
||||||
|
i for i in result.issues
|
||||||
|
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||||
|
]
|
||||||
|
assert len(low_conf_warnings) >= 2
|
||||||
|
|
||||||
|
# ==================== Edge Cases ====================
|
||||||
|
|
||||||
|
def test_validate_empty_fields_is_invalid(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Empty fields dict should be invalid."""
|
||||||
|
fields: dict[str, str | None] = {}
|
||||||
|
confidence: dict[str, float] = {}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
|
||||||
|
def test_validate_none_field_values_treated_as_missing(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""None values should be treated as missing."""
|
||||||
|
fields = {
|
||||||
|
"Amount": None,
|
||||||
|
"OCR": "123456789012",
|
||||||
|
}
|
||||||
|
confidence = {"OCR": 0.9}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||||
|
assert "Amount" in error_fields
|
||||||
|
|
||||||
|
def test_validate_empty_string_treated_as_missing(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Empty string should be treated as missing."""
|
||||||
|
fields = {
|
||||||
|
"Amount": "",
|
||||||
|
"OCR": "123456789012",
|
||||||
|
}
|
||||||
|
confidence = {"OCR": 0.9}
|
||||||
|
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
|
||||||
|
# ==================== ValidationResult Properties ====================
|
||||||
|
|
||||||
|
def test_validation_result_is_immutable(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""ValidationResult should be a frozen dataclass."""
|
||||||
|
fields = {"Amount": "100.00", "OCR": "123"}
|
||||||
|
confidence = {"Amount": 0.9, "OCR": 0.9}
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
with pytest.raises((AttributeError, TypeError)):
|
||||||
|
result.is_valid = False # type: ignore
|
||||||
|
|
||||||
|
def test_validation_result_issues_is_tuple(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""Issues should be a tuple (immutable)."""
|
||||||
|
fields = {"Amount": "100.00"}
|
||||||
|
confidence = {"Amount": 0.9}
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert isinstance(result.issues, tuple)
|
||||||
|
|
||||||
|
def test_validation_result_has_confidence(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""ValidationResult should have confidence score."""
|
||||||
|
fields = {"Amount": "100.00", "OCR": "123"}
|
||||||
|
confidence = {"Amount": 0.9, "OCR": 0.8}
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert hasattr(result, "confidence")
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
|
||||||
|
# ==================== ValidationIssue Tests ====================
|
||||||
|
|
||||||
|
def test_validation_issue_has_required_fields(
|
||||||
|
self, validator: InvoiceValidator
|
||||||
|
) -> None:
|
||||||
|
"""ValidationIssue must have field, severity, message."""
|
||||||
|
fields: dict[str, str | None] = {}
|
||||||
|
confidence: dict[str, float] = {}
|
||||||
|
result = validator.validate(fields, confidence)
|
||||||
|
|
||||||
|
assert len(result.issues) > 0
|
||||||
|
issue = result.issues[0]
|
||||||
|
|
||||||
|
assert hasattr(issue, "field")
|
||||||
|
assert hasattr(issue, "severity")
|
||||||
|
assert hasattr(issue, "message")
|
||||||
|
assert issue.severity in ("error", "warning", "info")
|
||||||
@@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering:
|
|||||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||||
@patch('shared.pdf.renderer.render_pdf_to_images')
|
@patch('shared.pdf.renderer.render_pdf_to_images')
|
||||||
@patch('ultralytics.YOLO')
|
|
||||||
def test_pdf_visualization_imports_correctly(
|
def test_pdf_visualization_imports_correctly(
|
||||||
self,
|
self,
|
||||||
mock_yolo_class,
|
|
||||||
mock_render_pdf,
|
mock_render_pdf,
|
||||||
mock_yolo_detector,
|
mock_yolo_detector,
|
||||||
mock_pipeline,
|
mock_pipeline,
|
||||||
@@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering:
|
|||||||
This catches the import error we had with:
|
This catches the import error we had with:
|
||||||
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
|
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
|
||||||
"""
|
"""
|
||||||
# Setup mocks
|
# Setup mocks for detector
|
||||||
mock_detector_instance = Mock()
|
mock_detector_instance = Mock()
|
||||||
mock_pipeline_instance = Mock()
|
mock_model = Mock()
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.save = Mock()
|
||||||
|
mock_model.predict.return_value = [mock_result]
|
||||||
|
mock_detector_instance.model = mock_model
|
||||||
mock_yolo_detector.return_value = mock_detector_instance
|
mock_yolo_detector.return_value = mock_detector_instance
|
||||||
|
|
||||||
|
# Setup mock for pipeline
|
||||||
|
mock_pipeline_instance = Mock()
|
||||||
mock_pipeline.return_value = mock_pipeline_instance
|
mock_pipeline.return_value = mock_pipeline_instance
|
||||||
|
|
||||||
|
# Initialize service to setup _detector
|
||||||
|
inference_service.initialize()
|
||||||
|
|
||||||
# Create a fake PDF path
|
# Create a fake PDF path
|
||||||
pdf_path = tmp_path / "test.pdf"
|
pdf_path = tmp_path / "test.pdf"
|
||||||
pdf_path.touch()
|
pdf_path.touch()
|
||||||
@@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering:
|
|||||||
img.save(image_bytes, format='PNG')
|
img.save(image_bytes, format='PNG')
|
||||||
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
|
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
|
||||||
|
|
||||||
# Mock YOLO
|
# This should not raise ImportError and should use self._detector.model
|
||||||
mock_model_instance = Mock()
|
|
||||||
mock_result = Mock()
|
|
||||||
mock_result.save = Mock()
|
|
||||||
mock_model_instance.predict.return_value = [mock_result]
|
|
||||||
mock_yolo_class.return_value = mock_model_instance
|
|
||||||
|
|
||||||
# This should not raise ImportError
|
|
||||||
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
|
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
|
||||||
|
|
||||||
# Verify import was successful
|
# Verify import was successful and detector.model was used
|
||||||
mock_render_pdf.assert_called_once()
|
mock_render_pdf.assert_called_once()
|
||||||
|
mock_model.predict.assert_called_once()
|
||||||
assert result_path is not None
|
assert result_path is not None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user