Files
invoice-master-poc-v2/packages/inference/inference/web/services/autolabel.py
Yaojia Wang a516de4320 WIP
2026-02-01 00:08:40 +01:00

336 lines
11 KiB
Python

"""
Admin Auto-Labeling Service
Uses FieldMatcher to automatically create annotations from field values.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken
logger = logging.getLogger(__name__)


class AutoLabelService:
    """Service for automatic document labeling using field matching.

    Each page of a document is OCR'd, the resulting tokens are handed to
    ``FieldMatcher`` together with the expected field values, and the first
    match returned per field is stored as an ``"auto"`` annotation through
    ``AdminDB``.
    """

    # Alternative field names accepted from callers, mapped to canonical keys
    # of FIELD_CLASS_IDS. Names that are already canonical are resolved by the
    # direct lookup in _get_class_id, so no identity entries are needed here.
    _FIELD_NAME_ALIASES: dict[str, str] = {
        "InvoiceNumber": "invoice_number",
        "InvoiceDate": "invoice_date",
        "InvoiceDueDate": "invoice_due_date",
        "OCR": "ocr_number",
        "Bankgiro": "bankgiro",
        "Plusgiro": "plusgiro",
        "Amount": "amount",
        "PaymentLine": "payment_line",
    }

    def __init__(self, ocr_engine: OCREngine | None = None):
        """
        Initialize auto-label service.

        Args:
            ocr_engine: OCR engine instance. If omitted, an English
                ``OCREngine`` is created lazily on first access of
                the ``ocr_engine`` property.
        """
        self._ocr_engine = ocr_engine
        self._field_matcher = FieldMatcher()

    @property
    def ocr_engine(self) -> OCREngine:
        """Return the OCR engine, creating an English one on first access."""
        if self._ocr_engine is None:
            self._ocr_engine = OCREngine(lang="en")
        return self._ocr_engine

    def auto_label_document(
        self,
        document_id: str,
        file_path: str,
        field_values: dict[str, str],
        db: AdminDB,
        replace_existing: bool = False,
        skip_lock_check: bool = False,
    ) -> dict[str, Any]:
        """
        Auto-label a document using field matching.

        Args:
            document_id: Document UUID
            file_path: Path to document file (``.pdf`` is rendered page by
                page; any other suffix is opened as a single image)
            field_values: Dict of field_name -> value to match
            db: Admin database instance
            replace_existing: Whether to delete existing auto annotations first
            skip_lock_check: Skip annotation lock check (for batch processing)

        Returns:
            ``{"status": "completed", "annotations_created": n}`` on success,
            otherwise ``{"status": "failed", "error": msg,
            "annotations_created": 0}``. Errors are reported through the
            return value rather than raised.

        A missing or annotation-locked document is reported as failed without
        modifying the document. Any other error marks the document
        ``status="pending"`` / ``auto_label_status="failed"``. NOTE: annotations
        already saved for earlier pages (or auto annotations deleted via
        ``replace_existing``) are not rolled back on failure.
        """
        try:
            document = db.get_document(document_id)

            # Precondition failures happen before anything is written, so the
            # document is left untouched (in particular a locked document must
            # not have its status reset to "pending").
            if document is None:
                skip_reason: str | None = f"Document not found: {document_id}"
            elif skip_lock_check:
                skip_reason = None
            else:
                skip_reason = self._lock_error(document)
            if skip_reason is not None:
                logger.warning("Auto-labeling skipped for %s: %s", document_id, skip_reason)
                return self._failed_result(skip_reason)

            db.update_document_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )

            if replace_existing:
                deleted = db.delete_annotations_for_document(
                    document_id=document_id,
                    source="auto",
                )
                logger.info("Deleted %s existing auto annotations", deleted)

            path = Path(file_path)
            if path.suffix.lower() == ".pdf":
                annotations_created = self._process_pdf(document_id, path, field_values, db)
            else:
                annotations_created = self._process_image(
                    document_id, path, field_values, db, page_number=1
                )

            db.update_document_status(
                document_id=document_id,
                status="labeled" if annotations_created > 0 else "pending",
                auto_label_status="completed",
            )
            return {
                "status": "completed",
                "annotations_created": annotations_created,
            }
        except Exception as exc:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Auto-labeling failed for %s: %s", document_id, exc)
            self._mark_failed(db, document_id, str(exc))
            return self._failed_result(str(exc))

    @staticmethod
    def _lock_error(document: Any) -> str | None:
        """Return the "locked" error message if the document's annotation lock is active, else None."""
        from datetime import datetime, timezone

        lock_until = getattr(document, "annotation_lock_until", None)
        if not lock_until:
            return None
        # Comparing a naive datetime with an aware one raises TypeError.
        # NOTE(review): naive values are assumed to be UTC — confirm against AdminDB storage.
        if lock_until.tzinfo is None:
            aware_lock_until = lock_until.replace(tzinfo=timezone.utc)
        else:
            aware_lock_until = lock_until
        if aware_lock_until <= datetime.now(timezone.utc):
            return None
        return (
            f"Document is locked for annotation until {lock_until}. "
            "Auto-labeling skipped."
        )

    @staticmethod
    def _failed_result(error: str) -> dict[str, Any]:
        """Build the result dict returned for a failed auto-labeling run."""
        return {
            "status": "failed",
            "error": error,
            "annotations_created": 0,
        }

    @staticmethod
    def _mark_failed(db: AdminDB, document_id: str, error: str) -> None:
        """Record a failed run on the document without letting a DB error escape."""
        try:
            db.update_document_status(
                document_id=document_id,
                status="pending",
                auto_label_status="failed",
                auto_label_error=error,
            )
        except Exception:
            # A secondary DB failure must not hide the original error from the caller.
            logger.exception("Could not record auto-label failure for %s", document_id)

    @staticmethod
    def _load_rgb_array(source: Any) -> np.ndarray:
        """Open an image (path or file-like), convert it to RGB and return it as an array.

        The PIL handle is closed before returning; the array owns its own data.
        """
        with Image.open(source) as image:
            return np.array(image.convert("RGB"))

    def _annotate_page(
        self,
        document_id: str,
        image_array: np.ndarray,
        field_values: dict[str, str],
        db: AdminDB,
        *,
        ocr_page_no: int,
        page_number: int,
    ) -> int:
        """OCR one page image, match field values, save the annotations; return how many were saved."""
        tokens = self.ocr_engine.extract_from_image(
            image_array,
            page_no=ocr_page_no,
        )
        image_height, image_width = image_array.shape[:2]
        annotations = self._find_annotations(
            document_id,
            tokens,
            field_values,
            page_number=page_number,
            image_width=image_width,
            image_height=image_height,
        )
        if annotations:
            db.create_annotations_batch(annotations)
        return len(annotations)

    def _process_pdf(
        self,
        document_id: str,
        pdf_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
    ) -> int:
        """Process every page of a PDF; return the total number of annotations created."""
        import io

        from shared.pdf.renderer import render_pdf_to_images

        total_annotations = 0
        for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
            image_array = self._load_rgb_array(io.BytesIO(image_bytes))
            total_annotations += self._annotate_page(
                document_id,
                image_array,
                field_values,
                db,
                ocr_page_no=page_no,
                page_number=page_no + 1,  # renderer page_no treated as 0-indexed; stored 1-indexed
            )
        return total_annotations

    def _process_image(
        self,
        document_id: str,
        image_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
        page_number: int = 1,
    ) -> int:
        """Process a single image file; return the number of annotations created."""
        image_array = self._load_rgb_array(image_path)
        return self._annotate_page(
            document_id,
            image_array,
            field_values,
            db,
            ocr_page_no=0,
            page_number=page_number,
        )

    def _find_annotations(
        self,
        document_id: str,
        tokens: list[OCRToken],
        field_values: dict[str, str],
        page_number: int,
        image_width: int,
        image_height: int,
    ) -> list[dict[str, Any]]:
        """Build annotation dicts for the field values found among the OCR tokens.

        Falsy / blank values and unknown field names are skipped. At most one
        annotation is produced per field.
        """
        from shared.normalize import normalize_field

        annotations: list[dict[str, Any]] = []
        for field_name, raw_value in field_values.items():
            if not raw_value:
                continue
            # Values may arrive as non-strings (e.g. numeric amounts); coerce so
            # .strip() and normalization don't raise and abort the whole document.
            value = raw_value if isinstance(raw_value, str) else str(raw_value)
            if not value.strip():
                continue

            class_id = self._get_class_id(field_name)
            if class_id is None:
                logger.warning("Unknown field name: %s", field_name)
                continue

            try:
                normalized_values = normalize_field(field_name, value)
            except Exception as exc:
                # Best effort: fall back to matching the raw value.
                logger.warning("Failed to normalize %s=%s: %s", field_name, value, exc)
                normalized_values = [value]

            matches = self._field_matcher.find_matches(
                tokens=tokens,
                field_name=field_name,
                normalized_values=normalized_values,
                page_no=page_number - 1,  # 0-indexed for matcher
            )
            if not matches:
                continue

            # NOTE(review): assumes find_matches returns matches best-first.
            best_match = matches[0]
            x0, y0, x1, y1 = best_match.bbox
            box_width = x1 - x0
            box_height = y1 - y0
            annotations.append({
                "document_id": document_id,
                "page_number": page_number,
                "class_id": class_id,
                "class_name": FIELD_CLASSES[class_id],
                # Center/size normalized by image dimensions (YOLO format)
                "x_center": (x0 + x1) / 2 / image_width,
                "y_center": (y0 + y1) / 2 / image_height,
                "width": box_width / image_width,
                "height": box_height / image_height,
                # Pixel coordinates, truncated by int()
                "bbox_x": int(x0),
                "bbox_y": int(y0),
                "bbox_width": int(box_width),
                "bbox_height": int(box_height),
                "text_value": best_match.matched_value,
                "confidence": best_match.score,
                "source": "auto",
            })
        return annotations

    def _get_class_id(self, field_name: str) -> int | None:
        """Map a canonical or alias field name to its class ID; None if unknown."""
        if field_name in FIELD_CLASS_IDS:
            return FIELD_CLASS_IDS[field_name]
        canonical_name = self._FIELD_NAME_ALIASES.get(field_name)
        if canonical_name is None:
            return None
        return FIELD_CLASS_IDS.get(canonical_name)
# Process-wide service instance, built lazily by get_auto_label_service().
_auto_label_service: AutoLabelService | None = None


def get_auto_label_service() -> AutoLabelService:
    """Return the shared AutoLabelService, constructing it on the first call."""
    global _auto_label_service
    service = _auto_label_service
    if service is None:
        service = AutoLabelService()
        _auto_label_service = service
    return service
def reset_auto_label_service() -> None:
    """Forget the cached service so the next get_auto_label_service() call builds a new one.

    Intended for tests.
    """
    global _auto_label_service
    _auto_label_service = None