re-structure
This commit is contained in:
360
packages/backend/backend/web/services/inference.py
Normal file
360
packages/backend/backend/web/services/inference.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
Inference Service
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import ModelConfig, StorageConfig
|
||||
|
||||
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


# Type alias for model path resolver function.
# Returns the active model's weights path (e.g. resolved from the database),
# or None to fall back to the default path from ModelConfig.
ModelPathResolver = Callable[[], Path | None]
|
||||
|
||||
|
||||
@dataclass
class ServiceResult:
    """Result from inference service.

    Aggregates everything a caller needs after processing one document:
    extracted field values, per-field confidence, raw detections, timing,
    and any errors collected along the way.
    """

    # Identifier of the processed document (caller-supplied or generated).
    document_id: str
    # True when the underlying pipeline reported a successful extraction.
    success: bool = False
    document_type: str = "invoice"  # "invoice" or "letter"
    # Extracted field name -> value (None when a field was not found).
    fields: dict[str, str | None] = field(default_factory=dict)
    # Per-field detection confidence scores.
    confidence: dict[str, float] = field(default_factory=dict)
    # Raw detections as {"field", "confidence", "bbox"} dicts.
    detections: list[dict] = field(default_factory=list)
    # Wall-clock processing time in milliseconds.
    processing_time_ms: float = 0.0
    # Where the annotated visualization was written, if any.
    visualization_path: Path | None = None
    # Human-readable error messages accumulated during processing.
    errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration (default model settings)
|
||||
storage_config: Storage configuration
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
|
||||
def _resolve_model_path(self) -> Path:
|
||||
"""Resolve the model path to use for inference.
|
||||
|
||||
Priority:
|
||||
1. Active model from database (via resolver)
|
||||
2. Default model from config
|
||||
"""
|
||||
if self._model_path_resolver:
|
||||
try:
|
||||
db_model_path = self._model_path_resolver()
|
||||
if db_model_path and Path(db_model_path).exists():
|
||||
logger.info(f"Using active model from database: {db_model_path}")
|
||||
return Path(db_model_path)
|
||||
elif db_model_path:
|
||||
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
|
||||
|
||||
return self.model_config.model_path
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize the inference pipeline (lazy loading)."""
|
||||
if self._is_initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing inference service...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
from backend.pipeline.pipeline import InferencePipeline
|
||||
from backend.pipeline.yolo_detector import YOLODetector
|
||||
|
||||
# Resolve model path (from DB or config)
|
||||
model_path = self._resolve_model_path()
|
||||
self._current_model_path = model_path
|
||||
|
||||
# Initialize YOLO detector for visualization
|
||||
self._detector = YOLODetector(
|
||||
str(model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
device="cuda" if self.model_config.use_gpu else "cpu",
|
||||
)
|
||||
|
||||
# Initialize full pipeline
|
||||
self._pipeline = InferencePipeline(
|
||||
model_path=str(model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
use_gpu=self.model_config.use_gpu,
|
||||
dpi=self.model_config.dpi,
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
self._is_initialized = True
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
raise
|
||||
|
||||
def reload_model(self) -> bool:
|
||||
"""Reload the model if active model has changed.
|
||||
|
||||
Returns:
|
||||
True if model was reloaded, False if no change needed.
|
||||
"""
|
||||
new_model_path = self._resolve_model_path()
|
||||
|
||||
if self._current_model_path == new_model_path:
|
||||
logger.debug("Model unchanged, no reload needed")
|
||||
return False
|
||||
|
||||
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
|
||||
self._is_initialized = False
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
self.initialize()
|
||||
return True
|
||||
|
||||
@property
|
||||
def current_model_path(self) -> Path | None:
|
||||
"""Get the currently loaded model path."""
|
||||
return self._current_model_path
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if service is initialized."""
|
||||
return self._is_initialized
|
||||
|
||||
@property
|
||||
def gpu_available(self) -> bool:
|
||||
"""Check if GPU is available."""
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def process_image(
|
||||
self,
|
||||
image_path: Path,
|
||||
document_id: str | None = None,
|
||||
save_visualization: bool = True,
|
||||
) -> ServiceResult:
|
||||
"""
|
||||
Process an image file and extract invoice fields.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
document_id: Optional document ID
|
||||
save_visualization: Whether to save visualization
|
||||
|
||||
Returns:
|
||||
ServiceResult with extracted fields
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self.initialize()
|
||||
|
||||
doc_id = document_id or str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
|
||||
result = ServiceResult(document_id=doc_id)
|
||||
|
||||
try:
|
||||
# Run inference pipeline
|
||||
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
|
||||
|
||||
result.fields = pipeline_result.fields
|
||||
result.confidence = pipeline_result.confidence
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
viz_path = self._save_visualization(image_path, doc_id)
|
||||
result.visualization_path = viz_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {image_path}: {e}")
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def process_pdf(
|
||||
self,
|
||||
pdf_path: Path,
|
||||
document_id: str | None = None,
|
||||
save_visualization: bool = True,
|
||||
) -> ServiceResult:
|
||||
"""
|
||||
Process a PDF file and extract invoice fields.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
save_visualization: Whether to save visualization
|
||||
|
||||
Returns:
|
||||
ServiceResult with extracted fields
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self.initialize()
|
||||
|
||||
doc_id = document_id or str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
|
||||
result = ServiceResult(document_id=doc_id)
|
||||
|
||||
try:
|
||||
# Run inference pipeline
|
||||
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
|
||||
|
||||
result.fields = pipeline_result.fields
|
||||
result.confidence = pipeline_result.confidence
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
|
||||
# Save visualization (render first page)
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
|
||||
result.visualization_path = viz_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF {pdf_path}: {e}")
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir is None:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
if results_dir is None:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Render first page
|
||||
for page_no, image_bytes in render_pdf_to_images(
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
Reference in New Issue
Block a user