"""
Inference Service (Adapter Layer)

Orchestrates the technical pipeline and business domain logic.
Acts as an adapter between the API layer and internal components.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Callable, Generator
|
|
|
|
from PIL import Image
|
|
|
|
from backend.domain.document_classifier import DocumentClassifier
|
|
from backend.web.services.storage_helpers import get_storage_helper
|
|
|
|
if TYPE_CHECKING:
|
|
from .config import ModelConfig, StorageConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Type alias for model path resolver function
|
|
ModelPathResolver = Callable[[], Path | None]
|
|
|
|
|
|
@dataclass
|
|
class ServiceResult:
|
|
"""Result from inference service."""
|
|
|
|
document_id: str
|
|
success: bool = False
|
|
document_type: str = "invoice" # "invoice" or "letter"
|
|
fields: dict[str, str | None] = field(default_factory=dict)
|
|
confidence: dict[str, float] = field(default_factory=dict)
|
|
detections: list[dict] = field(default_factory=list)
|
|
processing_time_ms: float = 0.0
|
|
visualization_path: Path | None = None
|
|
errors: list[str] = field(default_factory=list)
|
|
|
|
# Business features (optional, populated when extract_line_items=True)
|
|
line_items: dict | None = None
|
|
vat_summary: dict | None = None
|
|
vat_validation: dict | None = None
|
|
|
|
|
|
class InferenceService:
|
|
"""
|
|
Service for running invoice field extraction (Adapter Pattern).
|
|
|
|
Orchestrates:
|
|
- Technical layer: InferencePipeline, YOLODetector
|
|
- Business layer: DocumentClassifier
|
|
|
|
Supports dynamic model loading from database.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_config: ModelConfig,
|
|
storage_config: StorageConfig,
|
|
model_path_resolver: ModelPathResolver | None = None,
|
|
document_classifier: DocumentClassifier | None = None,
|
|
) -> None:
|
|
"""
|
|
Initialize inference service.
|
|
|
|
Args:
|
|
model_config: Model configuration (default model settings)
|
|
storage_config: Storage configuration
|
|
model_path_resolver: Optional function to resolve model path from database.
|
|
If provided, will be called to get active model path.
|
|
If returns None, falls back to model_config.model_path.
|
|
document_classifier: Optional custom classifier (uses default if None)
|
|
"""
|
|
self.model_config = model_config
|
|
self.storage_config = storage_config
|
|
self._model_path_resolver = model_path_resolver
|
|
|
|
# Technical layer (lazy initialized)
|
|
self._pipeline = None
|
|
self._detector = None
|
|
|
|
# Business layer (eagerly initialized, no heavy resources)
|
|
self._classifier = document_classifier or DocumentClassifier()
|
|
|
|
self._is_initialized = False
|
|
self._current_model_path: Path | None = None
|
|
self._business_features_enabled = False
|
|
|
|
def _resolve_model_path(self) -> Path:
|
|
"""Resolve the model path to use for inference.
|
|
|
|
Priority:
|
|
1. Active model from database (via resolver)
|
|
2. Default model from config
|
|
"""
|
|
if self._model_path_resolver:
|
|
try:
|
|
db_model_path = self._model_path_resolver()
|
|
if db_model_path and Path(db_model_path).exists():
|
|
logger.info(f"Using active model from database: {db_model_path}")
|
|
return Path(db_model_path)
|
|
elif db_model_path:
|
|
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
|
|
|
|
return self.model_config.model_path
|
|
|
|
def initialize(self, enable_business_features: bool = False) -> None:
|
|
"""Initialize the inference pipeline (lazy loading).
|
|
|
|
Args:
|
|
enable_business_features: Whether to enable line items and VAT extraction
|
|
"""
|
|
if self._is_initialized:
|
|
return
|
|
|
|
logger.info(f"Initializing inference service (business_features={enable_business_features})...")
|
|
start_time = time.time()
|
|
|
|
try:
|
|
from backend.pipeline.pipeline import InferencePipeline
|
|
from backend.pipeline.yolo_detector import YOLODetector
|
|
|
|
# Resolve model path (from DB or config)
|
|
model_path = self._resolve_model_path()
|
|
self._current_model_path = model_path
|
|
|
|
# Initialize YOLO detector for visualization
|
|
self._detector = YOLODetector(
|
|
str(model_path),
|
|
confidence_threshold=self.model_config.confidence_threshold,
|
|
device="cuda" if self.model_config.use_gpu else "cpu",
|
|
)
|
|
|
|
# Initialize full pipeline with optional business features
|
|
self._pipeline = InferencePipeline(
|
|
model_path=str(model_path),
|
|
confidence_threshold=self.model_config.confidence_threshold,
|
|
use_gpu=self.model_config.use_gpu,
|
|
dpi=self.model_config.dpi,
|
|
enable_fallback=True,
|
|
enable_business_features=enable_business_features,
|
|
)
|
|
|
|
self._is_initialized = True
|
|
self._business_features_enabled = enable_business_features
|
|
elapsed = time.time() - start_time
|
|
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize inference service: {e}")
|
|
raise
|
|
|
|
def reload_model(self) -> bool:
|
|
"""Reload the model if active model has changed.
|
|
|
|
Returns:
|
|
True if model was reloaded, False if no change needed.
|
|
"""
|
|
new_model_path = self._resolve_model_path()
|
|
|
|
if self._current_model_path == new_model_path:
|
|
logger.debug("Model unchanged, no reload needed")
|
|
return False
|
|
|
|
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
|
|
self._is_initialized = False
|
|
self._pipeline = None
|
|
self._detector = None
|
|
self.initialize()
|
|
return True
|
|
|
|
@property
|
|
def current_model_path(self) -> Path | None:
|
|
"""Get the currently loaded model path."""
|
|
return self._current_model_path
|
|
|
|
@property
|
|
def is_initialized(self) -> bool:
|
|
"""Check if service is initialized."""
|
|
return self._is_initialized
|
|
|
|
@property
|
|
def gpu_available(self) -> bool:
|
|
"""Check if GPU is available."""
|
|
try:
|
|
import torch
|
|
return torch.cuda.is_available()
|
|
except ImportError:
|
|
return False
|
|
|
|
def process_image(
|
|
self,
|
|
image_path: Path,
|
|
document_id: str | None = None,
|
|
save_visualization: bool = True,
|
|
) -> ServiceResult:
|
|
"""
|
|
Process an image file and extract invoice fields.
|
|
|
|
Args:
|
|
image_path: Path to image file
|
|
document_id: Optional document ID
|
|
save_visualization: Whether to save visualization
|
|
|
|
Returns:
|
|
ServiceResult with extracted fields
|
|
"""
|
|
if not self._is_initialized:
|
|
self.initialize()
|
|
|
|
doc_id = document_id or str(uuid.uuid4())[:8]
|
|
start_time = time.time()
|
|
|
|
result = ServiceResult(document_id=doc_id)
|
|
|
|
try:
|
|
# Run inference pipeline
|
|
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
|
|
|
|
result.fields = pipeline_result.fields
|
|
result.confidence = pipeline_result.confidence
|
|
result.success = pipeline_result.success
|
|
result.errors = pipeline_result.errors
|
|
|
|
# Business layer: classify document type
|
|
classification = self._classifier.classify(result.fields)
|
|
result.document_type = classification.document_type
|
|
|
|
# Get raw detections for visualization
|
|
result.detections = self._format_detections(pipeline_result.raw_detections)
|
|
|
|
# Save visualization if requested
|
|
if save_visualization and pipeline_result.raw_detections:
|
|
viz_path = self._save_visualization(image_path, doc_id)
|
|
result.visualization_path = viz_path
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing image {image_path}: {e}")
|
|
result.errors.append(str(e))
|
|
result.success = False
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
return result
|
|
|
|
def process_pdf(
|
|
self,
|
|
pdf_path: Path,
|
|
document_id: str | None = None,
|
|
save_visualization: bool = True,
|
|
extract_line_items: bool = False,
|
|
) -> ServiceResult:
|
|
"""
|
|
Process a PDF file and extract invoice fields.
|
|
|
|
Args:
|
|
pdf_path: Path to PDF file
|
|
document_id: Optional document ID
|
|
save_visualization: Whether to save visualization
|
|
extract_line_items: Whether to extract line items and VAT info
|
|
|
|
Returns:
|
|
ServiceResult with extracted fields
|
|
"""
|
|
if not self._is_initialized:
|
|
self.initialize(enable_business_features=extract_line_items)
|
|
elif extract_line_items and not self._business_features_enabled:
|
|
# Reinitialize with business features if needed
|
|
self._is_initialized = False
|
|
self.initialize(enable_business_features=True)
|
|
|
|
doc_id = document_id or str(uuid.uuid4())[:8]
|
|
start_time = time.time()
|
|
|
|
result = ServiceResult(document_id=doc_id)
|
|
|
|
try:
|
|
# Run inference pipeline with optional business features
|
|
pipeline_result = self._pipeline.process_pdf(
|
|
pdf_path,
|
|
document_id=doc_id,
|
|
extract_line_items=extract_line_items,
|
|
)
|
|
|
|
result.fields = pipeline_result.fields
|
|
result.confidence = pipeline_result.confidence
|
|
result.success = pipeline_result.success
|
|
result.errors = pipeline_result.errors
|
|
|
|
# Business layer: classify document type
|
|
classification = self._classifier.classify(result.fields)
|
|
result.document_type = classification.document_type
|
|
|
|
# Get raw detections
|
|
result.detections = self._format_detections(pipeline_result.raw_detections)
|
|
|
|
# Include business features if extracted
|
|
if extract_line_items:
|
|
result.line_items = pipeline_result._line_items_to_json() if pipeline_result.line_items else None
|
|
result.vat_summary = pipeline_result._vat_summary_to_json() if pipeline_result.vat_summary else None
|
|
result.vat_validation = pipeline_result._vat_validation_to_json() if pipeline_result.vat_validation else None
|
|
|
|
# Save visualization (render first page)
|
|
if save_visualization and pipeline_result.raw_detections:
|
|
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
|
|
result.visualization_path = viz_path
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing PDF {pdf_path}: {e}")
|
|
result.errors.append(str(e))
|
|
result.success = False
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
return result
|
|
|
|
def _format_detections(self, raw_detections: list) -> list[dict]:
|
|
"""Format raw detections for response."""
|
|
return [
|
|
{
|
|
"field": d.class_name,
|
|
"confidence": d.confidence,
|
|
"bbox": list(d.bbox),
|
|
}
|
|
for d in raw_detections
|
|
]
|
|
|
|
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
|
"""Save visualization image with detections using existing detector."""
|
|
# Get storage helper for results directory
|
|
storage = get_storage_helper()
|
|
results_dir = storage.get_results_base_path()
|
|
if results_dir is None:
|
|
logger.warning("Cannot save visualization: local storage not available")
|
|
return None
|
|
|
|
# Reuse self._detector instead of creating new YOLO instance
|
|
results = self._detector.model.predict(str(image_path), verbose=False)
|
|
|
|
# Save annotated image
|
|
output_path = results_dir / f"{doc_id}_result.png"
|
|
for r in results:
|
|
r.save(filename=str(output_path))
|
|
|
|
return output_path
|
|
|
|
@contextmanager
|
|
def _temp_image_file(
|
|
self, results_dir: Path, doc_id: str
|
|
) -> Generator[Path, None, None]:
|
|
"""Context manager for temporary image file with guaranteed cleanup."""
|
|
temp_path = results_dir / f"{doc_id}_temp.png"
|
|
try:
|
|
yield temp_path
|
|
finally:
|
|
temp_path.unlink(missing_ok=True)
|
|
|
|
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
|
"""Save visualization for PDF (first page) using existing detector."""
|
|
from shared.pdf.renderer import render_pdf_to_images
|
|
|
|
# Get storage helper for results directory
|
|
storage = get_storage_helper()
|
|
results_dir = storage.get_results_base_path()
|
|
if results_dir is None:
|
|
logger.warning("Cannot save visualization: local storage not available")
|
|
return None
|
|
|
|
# Render first page
|
|
for page_no, image_bytes in render_pdf_to_images(
|
|
pdf_path, dpi=self.model_config.dpi
|
|
):
|
|
image = Image.open(io.BytesIO(image_bytes))
|
|
|
|
# Use context manager for temp file to guarantee cleanup
|
|
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
|
image.save(temp_path)
|
|
|
|
# Reuse self._detector instead of creating new YOLO instance
|
|
results = self._detector.model.predict(str(temp_path), verbose=False)
|
|
|
|
output_path = results_dir / f"{doc_id}_result.png"
|
|
for r in results:
|
|
r.save(filename=str(output_path))
|
|
|
|
return output_path
|
|
|
|
# If no pages rendered
|
|
return None
|