WIP
This commit is contained in:
25
packages/backend/backend/domain/__init__.py
Normal file
25
packages/backend/backend/domain/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Domain Layer
|
||||
|
||||
Business logic separated from technical implementation.
|
||||
Contains document classification and invoice validation logic.
|
||||
"""
|
||||
from backend.domain.document_classifier import (
|
||||
ClassificationResult,
|
||||
DocumentClassifier,
|
||||
)
|
||||
from backend.domain.invoice_validator import (
|
||||
InvoiceValidator,
|
||||
ValidationIssue,
|
||||
ValidationResult,
|
||||
)
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
__all__ = [
|
||||
"ClassificationResult",
|
||||
"DocumentClassifier",
|
||||
"InvoiceValidator",
|
||||
"ValidationIssue",
|
||||
"ValidationResult",
|
||||
"has_value",
|
||||
]
|
||||
108
packages/backend/backend/domain/document_classifier.py
Normal file
108
packages/backend/backend/domain/document_classifier.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Document Classifier
|
||||
|
||||
Business logic for classifying documents based on extracted fields.
|
||||
Separates classification logic from inference pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ClassificationResult:
    """
    Outcome of classifying one document.

    Instances are frozen: attributes cannot be reassigned after creation.

    Attributes:
        document_type: Label given to the document, "invoice" or "letter"
        confidence: Score attached to the label, between 0.0 and 1.0
        reason: Short human-readable justification for the label
    """

    document_type: str
    confidence: float
    reason: str
|
||||
|
||||
|
||||
class DocumentClassifier:
    """
    Classifies documents as invoice or letter based on extracted fields.

    Classification Rules (evaluated in order, first match wins):
        1. payment_line present          -> invoice, confidence 0.95
        2. 2+ other indicators present   -> invoice, confidence 0.8
        3. exactly 1 other indicator     -> invoice, confidence 0.6
        4. no indicator present          -> letter,  confidence 0.7

    Invoice indicator fields:
        - payment_line (strongest indicator)
        - OCR
        - Amount
        - Bankgiro
        - Plusgiro
        - InvoiceNumber
    """

    INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
        {
            "payment_line",
            "OCR",
            "Amount",
            "Bankgiro",
            "Plusgiro",
            "InvoiceNumber",
        }
    )

    def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
        """
        Classify document type based on extracted fields.

        Args:
            fields: Dictionary of field names to extracted values.
                Presence is decided by ``has_value``, so ``None``, empty
                strings and whitespace-only strings count as missing.

        Returns:
            Immutable ClassificationResult with type, confidence, and reason.
            Indicator names inside ``reason`` are listed in sorted order so
            the text is identical from run to run.
        """
        # Rule 1: payment_line is the strongest indicator
        if has_value(fields.get("payment_line")):
            return ClassificationResult(
                document_type="invoice",
                confidence=0.95,
                reason="payment_line detected",
            )

        # Collect the remaining indicators (payment_line was handled above).
        # INVOICE_INDICATOR_FIELDS is a frozenset of strings, whose iteration
        # order changes between interpreter runs (string hash randomisation);
        # sorting keeps the reason text reproducible.
        present_indicators = sorted(
            name
            for name in self.INVOICE_INDICATOR_FIELDS
            if name != "payment_line" and has_value(fields.get(name))
        )
        indicator_count = len(present_indicators)

        # Rule 2: Multiple indicators -> invoice with medium-high confidence
        if indicator_count >= 2:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.8,
                reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
            )

        # Rule 3: Single indicator -> invoice with lower confidence
        if indicator_count == 1:
            return ClassificationResult(
                document_type="invoice",
                confidence=0.6,
                reason=f"1 invoice indicator present: {present_indicators[0]}",
            )

        # Rule 4: No indicators -> letter
        return ClassificationResult(
            document_type="letter",
            confidence=0.7,
            reason="no invoice indicators found",
        )
|
||||
141
packages/backend/backend/domain/invoice_validator.py
Normal file
141
packages/backend/backend/domain/invoice_validator.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Invoice Validator
|
||||
|
||||
Business logic for validating extracted invoice fields.
|
||||
Checks for required fields, format validity, and confidence thresholds.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ValidationIssue:
    """
    One problem found while validating an invoice.

    Instances are frozen: attributes cannot be reassigned after creation.

    Attributes:
        field: Name of the field the issue refers to
        severity: One of "error", "warning", "info"
        message: Human-readable description of the problem
    """

    field: str
    severity: str
    message: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ValidationResult:
    """
    Outcome of validating one invoice.

    Instances are frozen: attributes cannot be reassigned after creation.

    Attributes:
        is_valid: True when no issue has severity "error" (warnings are allowed)
        issues: Tuple holding every validation issue that was found
        confidence: Average confidence score of validated fields
    """

    is_valid: bool
    issues: tuple[ValidationIssue, ...]
    confidence: float
|
||||
|
||||
|
||||
class InvoiceValidator:
    """
    Checks extracted invoice fields for completeness and consistency.

    Validation Rules:
        1. Every field in REQUIRED_FIELDS must have a value (error if missing)
        2. At least one field in PAYMENT_REF_FIELDS should have a value
           (warning if none does)
        3. Each supplied confidence score should reach the configured
           threshold (warning if below)

    Required fields:
        - Amount

    Payment reference fields (at least one expected):
        - OCR
        - Bankgiro
        - Plusgiro
        - payment_line
    """

    REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
    PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
    DEFAULT_MIN_CONFIDENCE: float = 0.5

    def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
        """
        Create a validator.

        Args:
            min_confidence: Threshold below which a field's confidence score
                produces a warning.
        """
        self._min_confidence = min_confidence

    def validate(
        self,
        fields: dict[str, str | None],
        confidence: dict[str, float],
    ) -> ValidationResult:
        """
        Validate extracted invoice fields.

        Issues are reported in a fixed order: missing required fields first,
        then the payment-reference warning, then low-confidence warnings in
        the iteration order of ``confidence``.

        Args:
            fields: Dictionary of field names to extracted values
            confidence: Dictionary of field names to confidence scores

        Returns:
            Immutable ValidationResult; ``is_valid`` is False only when at
            least one issue has severity "error", and ``confidence`` is the
            mean of all supplied scores (0.0 when none are supplied).
        """
        # Required fields: one error per field without a value.
        found: list[ValidationIssue] = [
            ValidationIssue(
                field=name,
                severity="error",
                message=f"Required field '{name}' is missing",
            )
            for name in self.REQUIRED_FIELDS
            if not has_value(fields.get(name))
        ]

        # Payment reference: warn only when none of the candidates has a value.
        for ref_name in self.PAYMENT_REF_FIELDS:
            if has_value(fields.get(ref_name)):
                break
        else:
            found.append(
                ValidationIssue(
                    field="payment_reference",
                    severity="warning",
                    message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
                )
            )

        # Confidence thresholds: one warning per score under the threshold.
        threshold = self._min_confidence
        found.extend(
            ValidationIssue(
                field=name,
                severity="warning",
                message=f"Low confidence ({score:.2f}) for field '{name}'",
            )
            for name, score in confidence.items()
            if score < threshold
        )

        # Overall verdict: warnings never invalidate, errors always do.
        error_free = all(issue.severity != "error" for issue in found)

        if confidence:
            mean_confidence = sum(confidence.values()) / len(confidence)
        else:
            mean_confidence = 0.0

        return ValidationResult(
            is_valid=error_free,
            issues=tuple(found),
            confidence=mean_confidence,
        )
|
||||
23
packages/backend/backend/domain/utils.py
Normal file
23
packages/backend/backend/domain/utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Domain Layer Utilities
|
||||
|
||||
Shared helper functions for domain layer classes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def has_value(value: object) -> bool:
    """
    Check if a field value is present and non-empty.

    The parameter is annotated ``object`` because non-string values are
    handled explicitly below; ``str`` and ``None`` callers are unaffected.

    Args:
        value: Field value to check

    Returns:
        For strings: True if the text is non-empty after stripping whitespace.
        For anything else (including None): the value's truthiness, so
        None, 0, and empty containers all count as "no value".
    """
    if isinstance(value, str):
        # Whitespace-only text counts as missing.
        return bool(value.strip())
    # None is falsy, so it needs no separate branch.
    return bool(value)
|
||||
@@ -1,21 +1,24 @@
|
||||
"""
|
||||
Inference Service
|
||||
Inference Service (Adapter Layer)
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
Orchestrates technical pipeline and business domain logic.
|
||||
Acts as adapter between API layer and internal components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,9 +53,12 @@ class ServiceResult:
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
Service for running invoice field extraction (Adapter Pattern).
|
||||
|
||||
Orchestrates:
|
||||
- Technical layer: InferencePipeline, YOLODetector
|
||||
- Business layer: DocumentClassifier
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
@@ -61,6 +67,7 @@ class InferenceService:
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
document_classifier: DocumentClassifier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
@@ -71,12 +78,19 @@ class InferenceService:
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
document_classifier: Optional custom classifier (uses default if None)
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
|
||||
# Technical layer (lazy initialized)
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
|
||||
# Business layer (eagerly initialized, no heavy resources)
|
||||
self._classifier = document_classifier or DocumentClassifier()
|
||||
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
self._business_features_enabled = False
|
||||
@@ -219,22 +233,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
@@ -293,22 +297,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Include business features if extracted
|
||||
if extract_line_items:
|
||||
@@ -329,10 +323,19 @@ class InferenceService:
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
def _format_detections(self, raw_detections: list) -> list[dict]:
|
||||
"""Format raw detections for response."""
|
||||
return [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in raw_detections
|
||||
]
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization image with detections using existing detector."""
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
@@ -340,9 +343,8 @@ class InferenceService:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
@@ -351,11 +353,20 @@ class InferenceService:
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
@contextmanager
|
||||
def _temp_image_file(
|
||||
self, results_dir: Path, doc_id: str
|
||||
) -> Generator[Path, None, None]:
|
||||
"""Context manager for temporary image file with guaranteed cleanup."""
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization for PDF (first page) using existing detector."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
@@ -369,20 +380,19 @@ class InferenceService:
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
# Use context manager for temp file to guarantee cleanup
|
||||
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
||||
image.save(temp_path)
|
||||
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(temp_path), verbose=False)
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user