re-structure

This commit is contained in:
Yaojia Wang
2026-02-01 22:55:31 +01:00
parent 400b12a967
commit b602d0a340
176 changed files with 856 additions and 853 deletions

View File

@@ -0,0 +1,360 @@
"""
Inference Service
Business logic for invoice field extraction.
"""
from __future__ import annotations
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]
@dataclass
class ServiceResult:
    """Result from inference service."""

    # Unique identifier of the processed document.
    document_id: str
    # Whether processing succeeded; set False when the pipeline fails or raises.
    success: bool = False
    document_type: str = "invoice"  # "invoice" or "letter"
    # Extracted field name -> value (None when a field was not extracted).
    fields: dict[str, str | None] = field(default_factory=dict)
    # Field name -> confidence score reported by the pipeline.
    confidence: dict[str, float] = field(default_factory=dict)
    # Raw detections ({"field", "confidence", "bbox"} dicts) kept for visualization.
    detections: list[dict] = field(default_factory=list)
    # Total wall-clock processing time in milliseconds.
    processing_time_ms: float = 0.0
    # Path of the saved annotated image, if visualization was requested and saved.
    visualization_path: Path | None = None
    # Error messages accumulated during processing.
    errors: list[str] = field(default_factory=list)
class InferenceService:
    """
    Service for running invoice field extraction.

    Encapsulates YOLO detection and OCR extraction logic.
    Supports dynamic model loading from database.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        storage_config: StorageConfig,
        model_path_resolver: ModelPathResolver | None = None,
    ) -> None:
        """
        Initialize inference service.

        Args:
            model_config: Model configuration (default model settings)
            storage_config: Storage configuration
            model_path_resolver: Optional function to resolve model path from database.
                If provided, will be called to get active model path.
                If returns None, falls back to model_config.model_path.
        """
        self.model_config = model_config
        self.storage_config = storage_config
        self._model_path_resolver = model_path_resolver
        # Pipeline and detector are loaded lazily in initialize().
        self._pipeline = None
        self._detector = None
        self._is_initialized = False
        # Path of the model actually loaded (may differ from the config default
        # when a database-resolved active model is in use).
        self._current_model_path: Path | None = None

    def _resolve_model_path(self) -> Path:
        """Resolve the model path to use for inference.

        Priority:
            1. Active model from database (via resolver)
            2. Default model from config
        """
        if self._model_path_resolver:
            try:
                db_model_path = self._model_path_resolver()
                if db_model_path and Path(db_model_path).exists():
                    logger.info(f"Using active model from database: {db_model_path}")
                    return Path(db_model_path)
                elif db_model_path:
                    # Resolver returned a path, but the file is gone on disk.
                    logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
            except Exception as e:
                # Resolver failures must never block inference; fall back.
                logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
        return self.model_config.model_path

    def initialize(self) -> None:
        """Initialize the inference pipeline (lazy loading).

        Safe to call repeatedly; a no-op once initialized.

        Raises:
            Exception: re-raised if detector or pipeline construction fails.
        """
        if self._is_initialized:
            return

        logger.info("Initializing inference service...")
        start_time = time.time()
        try:
            # Imported lazily so constructing the service does not require
            # the heavyweight ML stack.
            from backend.pipeline.pipeline import InferencePipeline
            from backend.pipeline.yolo_detector import YOLODetector

            # Resolve model path (from DB or config)
            model_path = self._resolve_model_path()
            self._current_model_path = model_path

            # Initialize YOLO detector for visualization
            self._detector = YOLODetector(
                str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                device="cuda" if self.model_config.use_gpu else "cpu",
            )
            # Initialize full pipeline
            self._pipeline = InferencePipeline(
                model_path=str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                use_gpu=self.model_config.use_gpu,
                dpi=self.model_config.dpi,
                enable_fallback=True,
            )
            self._is_initialized = True
            elapsed = time.time() - start_time
            logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
        except Exception as e:
            logger.error(f"Failed to initialize inference service: {e}")
            raise

    def reload_model(self) -> bool:
        """Reload the model if active model has changed.

        Returns:
            True if model was reloaded, False if no change needed.
        """
        new_model_path = self._resolve_model_path()
        if self._current_model_path == new_model_path:
            logger.debug("Model unchanged, no reload needed")
            return False

        logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
        # Drop existing components so initialize() rebuilds from scratch.
        self._is_initialized = False
        self._pipeline = None
        self._detector = None
        self.initialize()
        return True

    @property
    def current_model_path(self) -> Path | None:
        """Get the currently loaded model path."""
        return self._current_model_path

    @property
    def is_initialized(self) -> bool:
        """Check if service is initialized."""
        return self._is_initialized

    @property
    def gpu_available(self) -> bool:
        """Check if GPU is available (False when torch is not installed)."""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    def _populate_result(self, result: ServiceResult, pipeline_result) -> None:
        """Copy pipeline output onto *result* and classify the document type.

        Shared by process_image() and process_pdf(), which previously
        duplicated this logic verbatim.
        """
        result.fields = pipeline_result.fields
        result.confidence = pipeline_result.confidence
        result.success = pipeline_result.success
        result.errors = pipeline_result.errors
        # Determine document type based on payment_line presence.
        # If no payment_line found, it's likely a letter, not an invoice.
        if not result.fields.get('payment_line'):
            result.document_type = "letter"
        else:
            result.document_type = "invoice"
        # Keep raw detections for visualization/debugging.
        result.detections = [
            {
                "field": d.class_name,
                "confidence": d.confidence,
                "bbox": list(d.bbox),
            }
            for d in pipeline_result.raw_detections
        ]

    def process_image(
        self,
        image_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process an image file and extract invoice fields.

        Args:
            image_path: Path to image file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()
        result = ServiceResult(document_id=doc_id)
        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
            self._populate_result(result, pipeline_result)

            # Save visualization if requested
            if save_visualization and pipeline_result.raw_detections:
                result.visualization_path = self._save_visualization(image_path, doc_id)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            result.errors.append(str(e))
            result.success = False
        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def process_pdf(
        self,
        pdf_path: Path,
        document_id: str | None = None,
        save_visualization: bool = True,
    ) -> ServiceResult:
        """
        Process a PDF file and extract invoice fields.

        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID
            save_visualization: Whether to save visualization

        Returns:
            ServiceResult with extracted fields
        """
        if not self._is_initialized:
            self.initialize()

        doc_id = document_id or str(uuid.uuid4())[:8]
        start_time = time.time()
        result = ServiceResult(document_id=doc_id)
        try:
            # Run inference pipeline
            pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
            self._populate_result(result, pipeline_result)

            # Save visualization (renders first page only)
            if save_visualization and pipeline_result.raw_detections:
                result.visualization_path = self._save_pdf_visualization(pdf_path, doc_id)
        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {e}")
            result.errors.append(str(e))
            result.success = False
        result.processing_time_ms = (time.time() - start_time) * 1000
        return result

    def _visualization_model_path(self) -> Path:
        """Model to use for visualization: the one actually loaded.

        Fix: previously visualization always used the config default, which
        diverged from inference when a database-resolved model was active.
        """
        return self._current_model_path or self.model_config.model_path

    def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
        """Save visualization image with detections.

        Returns:
            Path of the annotated image, or None when local storage is
            not available.
        """
        from ultralytics import YOLO

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Load model and run prediction with visualization
        model = YOLO(str(self._visualization_model_path()))
        results = model.predict(str(image_path), verbose=False)

        # Save annotated image
        output_path = results_dir / f"{doc_id}_result.png"
        for r in results:
            r.save(filename=str(output_path))
        return output_path

    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
        """Save visualization for PDF (first page).

        Returns:
            Path of the annotated image, or None when local storage is not
            available or the PDF renders no pages.
        """
        import io

        from ultralytics import YOLO

        from shared.pdf.renderer import render_pdf_to_images

        # Get storage helper for results directory
        storage = get_storage_helper()
        results_dir = storage.get_results_base_path()
        if results_dir is None:
            logger.warning("Cannot save visualization: local storage not available")
            return None

        # Only the first rendered page is visualized; loop exits via return.
        for _page_no, image_bytes in render_pdf_to_images(
            pdf_path, dpi=self.model_config.dpi
        ):
            image = Image.open(io.BytesIO(image_bytes))
            temp_path = results_dir / f"{doc_id}_temp.png"
            try:
                image.save(temp_path)
                # Run YOLO and save visualization
                model = YOLO(str(self._visualization_model_path()))
                results = model.predict(str(temp_path), verbose=False)
                output_path = results_dir / f"{doc_id}_result.png"
                for r in results:
                    r.save(filename=str(output_path))
                return output_path
            finally:
                # Always remove the temp page render, even if prediction fails
                # (previously it leaked on error).
                temp_path.unlink(missing_ok=True)
        # If no pages rendered
        return None