WIP
This commit is contained in:
@@ -1,21 +1,24 @@
|
||||
"""
|
||||
Inference Service
|
||||
Inference Service (Adapter Layer)
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
Orchestrates technical pipeline and business domain logic.
|
||||
Acts as adapter between API layer and internal components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,9 +53,12 @@ class ServiceResult:
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
Service for running invoice field extraction (Adapter Pattern).
|
||||
|
||||
Orchestrates:
|
||||
- Technical layer: InferencePipeline, YOLODetector
|
||||
- Business layer: DocumentClassifier
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
@@ -61,6 +67,7 @@ class InferenceService:
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
document_classifier: DocumentClassifier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
@@ -71,12 +78,19 @@ class InferenceService:
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
document_classifier: Optional custom classifier (uses default if None)
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
|
||||
# Technical layer (lazy initialized)
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
|
||||
# Business layer (eagerly initialized, no heavy resources)
|
||||
self._classifier = document_classifier or DocumentClassifier()
|
||||
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
self._business_features_enabled = False
|
||||
@@ -219,22 +233,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
@@ -293,22 +297,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Include business features if extracted
|
||||
if extract_line_items:
|
||||
@@ -329,10 +323,19 @@ class InferenceService:
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
def _format_detections(self, raw_detections: list) -> list[dict]:
|
||||
"""Format raw detections for response."""
|
||||
return [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in raw_detections
|
||||
]
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization image with detections using existing detector."""
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
@@ -340,9 +343,8 @@ class InferenceService:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
@@ -351,11 +353,20 @@ class InferenceService:
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
@contextmanager
|
||||
def _temp_image_file(
|
||||
self, results_dir: Path, doc_id: str
|
||||
) -> Generator[Path, None, None]:
|
||||
"""Context manager for temporary image file with guaranteed cleanup."""
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization for PDF (first page) using existing detector."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
@@ -369,20 +380,19 @@ class InferenceService:
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
# Use context manager for temp file to guarantee cleanup
|
||||
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
||||
image.save(temp_path)
|
||||
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(temp_path), verbose=False)
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user