This commit is contained in:
Yaojia Wang
2026-02-03 22:29:53 +01:00
parent 4c7fc3015c
commit c2c8f2dd04
10 changed files with 786 additions and 68 deletions

View File

@@ -0,0 +1,25 @@
"""
Domain Layer
Business logic separated from technical implementation.
Contains document classification and invoice validation logic.
"""
from backend.domain.document_classifier import (
ClassificationResult,
DocumentClassifier,
)
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationIssue,
ValidationResult,
)
from backend.domain.utils import has_value
__all__ = [
"ClassificationResult",
"DocumentClassifier",
"InvoiceValidator",
"ValidationIssue",
"ValidationResult",
"has_value",
]

View File

@@ -0,0 +1,108 @@
"""
Document Classifier
Business logic for classifying documents based on extracted fields.
Separates classification logic from inference pipeline.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ClassificationResult:
"""
Immutable result of document classification.
Attributes:
document_type: Either "invoice" or "letter"
confidence: Confidence score between 0.0 and 1.0
reason: Human-readable explanation of classification
"""
document_type: str
confidence: float
reason: str
class DocumentClassifier:
"""
Classifies documents as invoice or letter based on extracted fields.
Classification Rules:
1. If payment_line is present -> invoice (high confidence)
2. If 2+ invoice indicators present -> invoice (medium confidence)
3. If 1 invoice indicator present -> invoice (lower confidence)
4. Otherwise -> letter
Invoice indicator fields:
- payment_line (strongest indicator)
- OCR
- Amount
- Bankgiro
- Plusgiro
- InvoiceNumber
"""
INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
{
"payment_line",
"OCR",
"Amount",
"Bankgiro",
"Plusgiro",
"InvoiceNumber",
}
)
def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
"""
Classify document type based on extracted fields.
Args:
fields: Dictionary of field names to extracted values.
Empty strings or whitespace-only strings are treated as missing.
Returns:
Immutable ClassificationResult with type, confidence, and reason.
"""
# Rule 1: payment_line is the strongest indicator
if has_value(fields.get("payment_line")):
return ClassificationResult(
document_type="invoice",
confidence=0.95,
reason="payment_line detected",
)
# Count present invoice indicators (excluding payment_line already checked)
present_indicators = [
field
for field in self.INVOICE_INDICATOR_FIELDS
if field != "payment_line" and has_value(fields.get(field))
]
indicator_count = len(present_indicators)
# Rule 2: Multiple indicators -> invoice with medium-high confidence
if indicator_count >= 2:
return ClassificationResult(
document_type="invoice",
confidence=0.8,
reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
)
# Rule 3: Single indicator -> invoice with lower confidence
if indicator_count == 1:
return ClassificationResult(
document_type="invoice",
confidence=0.6,
reason=f"1 invoice indicator present: {present_indicators[0]}",
)
# Rule 4: No indicators -> letter
return ClassificationResult(
document_type="letter",
confidence=0.7,
reason="no invoice indicators found",
)

View File

@@ -0,0 +1,141 @@
"""
Invoice Validator
Business logic for validating extracted invoice fields.
Checks for required fields, format validity, and confidence thresholds.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ValidationIssue:
"""
Single validation issue.
Attributes:
field: Name of the field with the issue
severity: One of "error", "warning", "info"
message: Human-readable description of the issue
"""
field: str
severity: str
message: str
@dataclass(frozen=True)
class ValidationResult:
"""
Immutable result of invoice validation.
Attributes:
is_valid: True if no errors (warnings are allowed)
issues: Tuple of validation issues found
confidence: Average confidence score of validated fields
"""
is_valid: bool
issues: tuple[ValidationIssue, ...]
confidence: float
class InvoiceValidator:
"""
Validates extracted invoice fields for completeness and consistency.
Validation Rules:
1. Required fields must be present (Amount)
2. At least one payment reference should be present (warning if missing)
3. Field confidence should be above threshold (warning if below)
Required fields:
- Amount
Payment reference fields (at least one expected):
- OCR
- Bankgiro
- Plusgiro
- payment_line
"""
REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
DEFAULT_MIN_CONFIDENCE: float = 0.5
def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
"""
Initialize validator.
Args:
min_confidence: Minimum confidence threshold for valid fields.
Fields below this threshold produce warnings.
"""
self._min_confidence = min_confidence
def validate(
self,
fields: dict[str, str | None],
confidence: dict[str, float],
) -> ValidationResult:
"""
Validate extracted invoice fields.
Args:
fields: Dictionary of field names to extracted values
confidence: Dictionary of field names to confidence scores
Returns:
Immutable ValidationResult with validity status and issues
"""
issues: list[ValidationIssue] = []
# Check required fields
for field in self.REQUIRED_FIELDS:
if not has_value(fields.get(field)):
issues.append(
ValidationIssue(
field=field,
severity="error",
message=f"Required field '{field}' is missing",
)
)
# Check payment reference (at least one expected)
has_payment_ref = any(
has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS
)
if not has_payment_ref:
issues.append(
ValidationIssue(
field="payment_reference",
severity="warning",
message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
)
)
# Check confidence thresholds
for field, conf in confidence.items():
if conf < self._min_confidence:
issues.append(
ValidationIssue(
field=field,
severity="warning",
message=f"Low confidence ({conf:.2f}) for field '{field}'",
)
)
# Calculate overall validity
has_errors = any(i.severity == "error" for i in issues)
avg_confidence = (
sum(confidence.values()) / len(confidence) if confidence else 0.0
)
return ValidationResult(
is_valid=not has_errors,
issues=tuple(issues),
confidence=avg_confidence,
)

View File

@@ -0,0 +1,23 @@
"""
Domain Layer Utilities
Shared helper functions for domain layer classes.
"""
from __future__ import annotations
def has_value(value: str | None) -> bool:
"""
Check if a field value is present and non-empty.
Args:
value: Field value to check
Returns:
True if value is a non-empty, non-whitespace string
"""
if value is None:
return False
if not isinstance(value, str):
return bool(value)
return bool(value.strip())

View File

@@ -1,21 +1,24 @@
"""
Inference Service
Inference Service (Adapter Layer)
Business logic for invoice field extraction.
Orchestrates technical pipeline and business domain logic.
Acts as adapter between API layer and internal components.
"""
from __future__ import annotations
import io
import logging
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Generator
import numpy as np
from PIL import Image
from backend.domain.document_classifier import DocumentClassifier
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
@@ -50,9 +53,12 @@ class ServiceResult:
class InferenceService:
"""
Service for running invoice field extraction.
Service for running invoice field extraction (Adapter Pattern).
Orchestrates:
- Technical layer: InferencePipeline, YOLODetector
- Business layer: DocumentClassifier
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database.
"""
@@ -61,6 +67,7 @@ class InferenceService:
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
document_classifier: DocumentClassifier | None = None,
) -> None:
"""
Initialize inference service.
@@ -71,12 +78,19 @@ class InferenceService:
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
document_classifier: Optional custom classifier (uses default if None)
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
# Technical layer (lazy initialized)
self._pipeline = None
self._detector = None
# Business layer (eagerly initialized, no heavy resources)
self._classifier = document_classifier or DocumentClassifier()
self._is_initialized = False
self._current_model_path: Path | None = None
self._business_features_enabled = False
@@ -219,22 +233,12 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
result.detections = self._format_detections(pipeline_result.raw_detections)
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
@@ -293,22 +297,12 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
result.detections = self._format_detections(pipeline_result.raw_detections)
# Include business features if extracted
if extract_line_items:
@@ -329,10 +323,19 @@ class InferenceService:
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
"""Save visualization image with detections."""
from ultralytics import YOLO
def _format_detections(self, raw_detections: list) -> list[dict]:
"""Format raw detections for response."""
return [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in raw_detections
]
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
"""Save visualization image with detections using existing detector."""
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
@@ -340,9 +343,8 @@ class InferenceService:
logger.warning("Cannot save visualization: local storage not available")
return None
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = results_dir / f"{doc_id}_result.png"
@@ -351,11 +353,20 @@ class InferenceService:
return output_path
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
"""Save visualization for PDF (first page)."""
@contextmanager
def _temp_image_file(
self, results_dir: Path, doc_id: str
) -> Generator[Path, None, None]:
"""Context manager for temporary image file with guaranteed cleanup."""
temp_path = results_dir / f"{doc_id}_temp.png"
try:
yield temp_path
finally:
temp_path.unlink(missing_ok=True)
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page) using existing detector."""
from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Get storage helper for results directory
storage = get_storage_helper()
@@ -369,20 +380,19 @@ class InferenceService:
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = results_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
# Use context manager for temp file to guarantee cleanup
with self._temp_image_file(results_dir, doc_id) as temp_path:
image.save(temp_path)
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(temp_path), verbose=False)
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
# If no pages rendered
return None