Files
invoice-master-poc-v2/packages/backend/backend/web/services/inference.py
Yaojia Wang c2c8f2dd04 WIP
2026-02-03 22:29:53 +01:00

399 lines
14 KiB
Python

"""
Inference Service (Adapter Layer)
Orchestrates technical pipeline and business domain logic.
Acts as adapter between API layer and internal components.
"""
from __future__ import annotations
import io
import logging
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Generator
from PIL import Image
from backend.domain.document_classifier import DocumentClassifier
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]
@dataclass
class ServiceResult:
    """Result from inference service.

    Aggregates the technical pipeline output (fields, confidence,
    detections), the business-layer classification, and optional
    line-item/VAT data into one payload for the API layer.
    """

    # Identifier of the processed document (caller-supplied or generated).
    document_id: str
    # True only when the underlying pipeline reported success and no
    # service-level exception occurred.
    success: bool = False
    document_type: str = "invoice"  # "invoice" or "letter"
    # Extracted field name -> value (None when a field was not found).
    fields: dict[str, str | None] = field(default_factory=dict)
    # Field name -> detection confidence score.
    confidence: dict[str, float] = field(default_factory=dict)
    # Formatted raw detections: {"field", "confidence", "bbox"} dicts.
    detections: list[dict] = field(default_factory=list)
    # Wall-clock processing time, including classification/visualization.
    processing_time_ms: float = 0.0
    # Path to the saved annotated image, if visualization was requested
    # and local storage was available.
    visualization_path: Path | None = None
    # Error messages accumulated during processing.
    errors: list[str] = field(default_factory=list)
    # Business features (optional, populated when extract_line_items=True)
    line_items: dict | None = None
    vat_summary: dict | None = None
    vat_validation: dict | None = None
class InferenceService:
"""
Service for running invoice field extraction (Adapter Pattern).
Orchestrates:
- Technical layer: InferencePipeline, YOLODetector
- Business layer: DocumentClassifier
Supports dynamic model loading from database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
document_classifier: DocumentClassifier | None = None,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration (default model settings)
storage_config: Storage configuration
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
document_classifier: Optional custom classifier (uses default if None)
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
# Technical layer (lazy initialized)
self._pipeline = None
self._detector = None
# Business layer (eagerly initialized, no heavy resources)
self._classifier = document_classifier or DocumentClassifier()
self._is_initialized = False
self._current_model_path: Path | None = None
self._business_features_enabled = False
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
Priority:
1. Active model from database (via resolver)
2. Default model from config
"""
if self._model_path_resolver:
try:
db_model_path = self._model_path_resolver()
if db_model_path and Path(db_model_path).exists():
logger.info(f"Using active model from database: {db_model_path}")
return Path(db_model_path)
elif db_model_path:
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
except Exception as e:
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
return self.model_config.model_path
def initialize(self, enable_business_features: bool = False) -> None:
"""Initialize the inference pipeline (lazy loading).
Args:
enable_business_features: Whether to enable line items and VAT extraction
"""
if self._is_initialized:
return
logger.info(f"Initializing inference service (business_features={enable_business_features})...")
start_time = time.time()
try:
from backend.pipeline.pipeline import InferencePipeline
from backend.pipeline.yolo_detector import YOLODetector
# Resolve model path (from DB or config)
model_path = self._resolve_model_path()
self._current_model_path = model_path
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline with optional business features
self._pipeline = InferencePipeline(
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
enable_business_features=enable_business_features,
)
self._is_initialized = True
self._business_features_enabled = enable_business_features
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
def reload_model(self) -> bool:
"""Reload the model if active model has changed.
Returns:
True if model was reloaded, False if no change needed.
"""
new_model_path = self._resolve_model_path()
if self._current_model_path == new_model_path:
logger.debug("Model unchanged, no reload needed")
return False
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
self._is_initialized = False
self._pipeline = None
self._detector = None
self.initialize()
return True
@property
def current_model_path(self) -> Path | None:
"""Get the currently loaded model path."""
return self._current_model_path
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
return self._is_initialized
@property
def gpu_available(self) -> bool:
"""Check if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def process_image(
self,
image_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process an image file and extract invoice fields.
Args:
image_path: Path to image file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections for visualization
result.detections = self._format_detections(pipeline_result.raw_detections)
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_visualization(image_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def process_pdf(
self,
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
extract_line_items: bool = False,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
extract_line_items: Whether to extract line items and VAT info
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize(enable_business_features=extract_line_items)
elif extract_line_items and not self._business_features_enabled:
# Reinitialize with business features if needed
self._is_initialized = False
self.initialize(enable_business_features=True)
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline with optional business features
pipeline_result = self._pipeline.process_pdf(
pdf_path,
document_id=doc_id,
extract_line_items=extract_line_items,
)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections
result.detections = self._format_detections(pipeline_result.raw_detections)
# Include business features if extracted
if extract_line_items:
result.line_items = pipeline_result._line_items_to_json() if pipeline_result.line_items else None
result.vat_summary = pipeline_result._vat_summary_to_json() if pipeline_result.vat_summary else None
result.vat_validation = pipeline_result._vat_validation_to_json() if pipeline_result.vat_validation else None
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _format_detections(self, raw_detections: list) -> list[dict]:
"""Format raw detections for response."""
return [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in raw_detections
]
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
"""Save visualization image with detections using existing detector."""
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
@contextmanager
def _temp_image_file(
self, results_dir: Path, doc_id: str
) -> Generator[Path, None, None]:
"""Context manager for temporary image file with guaranteed cleanup."""
temp_path = results_dir / f"{doc_id}_temp.png"
try:
yield temp_path
finally:
temp_path.unlink(missing_ok=True)
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page) using existing detector."""
from shared.pdf.renderer import render_pdf_to_images
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
if results_dir is None:
logger.warning("Cannot save visualization: local storage not available")
return None
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
# Use context manager for temp file to guarantee cleanup
with self._temp_image_file(results_dir, doc_id) as temp_path:
image.save(temp_path)
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(temp_path), verbose=False)
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
# If no pages rendered
return None