336 lines
11 KiB
Python
336 lines
11 KiB
Python
"""
Admin Auto-Labeling Service

Uses FieldMatcher to automatically create annotations by locating known field
values among the OCR tokens of a document's pages.
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from shared.config import DEFAULT_DPI
|
|
from inference.data.admin_db import AdminDB
|
|
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
|
|
from shared.matcher.field_matcher import FieldMatcher
|
|
from shared.ocr.paddle_ocr import OCREngine, OCRToken
|
|
|
|
# Module logger, named after this module so handlers can be configured per package.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AutoLabelService:
    """Service for automatic document labeling using field matching.

    Each page is OCRed with ``OCREngine``. Every requested field value is then
    located among the page's tokens with ``FieldMatcher``, and the first match
    per field and page is stored as an annotation with ``source="auto"``.
    """

    # Alternative spellings accepted for field names, mapped to the canonical
    # keys of FIELD_CLASS_IDS. Canonical names are looked up directly first,
    # so they need no entry here. Built once instead of on every lookup.
    _FIELD_NAME_ALIASES: dict[str, str] = {
        "InvoiceNumber": "invoice_number",
        "InvoiceDate": "invoice_date",
        "InvoiceDueDate": "invoice_due_date",
        "OCR": "ocr_number",
        "Bankgiro": "bankgiro",
        "Plusgiro": "plusgiro",
        "Amount": "amount",
        "PaymentLine": "payment_line",
    }

    def __init__(self, ocr_engine: OCREngine | None = None):
        """
        Initialize auto-label service.

        Args:
            ocr_engine: OCR engine instance. When omitted, an English
                ``OCREngine`` is created lazily on first use.
        """
        self._ocr_engine = ocr_engine
        self._field_matcher = FieldMatcher()

    @property
    def ocr_engine(self) -> OCREngine:
        """Lazy initialization of OCR engine."""
        if self._ocr_engine is None:
            self._ocr_engine = OCREngine(lang="en")
        return self._ocr_engine

    def auto_label_document(
        self,
        document_id: str,
        file_path: str,
        field_values: dict[str, str],
        db: AdminDB,
        replace_existing: bool = False,
        skip_lock_check: bool = False,
    ) -> dict[str, Any]:
        """
        Auto-label a document using field matching.

        Status transitions:

        - While work is in progress, the document is set to
          ``status="auto_labeling"``, ``auto_label_status="running"``.
        - On success it becomes ``labeled`` (or ``pending`` if nothing
          matched), with ``auto_label_status="completed"``.
        - On error it becomes ``pending``, with ``auto_label_status="failed"``
          and the error text recorded.

        A missing document, or one currently locked for annotation, is
        rejected without writing any status.

        OCR and matching finish for the whole document before stored
        annotations are touched. An OCR/matching failure therefore neither
        deletes the existing auto annotations (``replace_existing``) nor
        leaves a partial set behind.

        Args:
            document_id: Document UUID
            file_path: Path to document file (``.pdf`` = all pages, anything
                else is treated as a single image)
            field_values: Dict of field_name -> value to match
            db: Admin database instance
            replace_existing: Whether to replace existing auto annotations
            skip_lock_check: Skip annotation lock check (for batch processing)

        Returns:
            ``{"status": "completed", "annotations_created": n}`` on success,
            otherwise ``{"status": "failed", "error": msg,
            "annotations_created": 0}``. Errors are reported through this
            dict rather than raised.
        """
        try:
            document = db.get_document(document_id)
            if document is None:
                # No row to update: report the failure without touching the DB.
                message = f"Document not found: {document_id}"
                logger.warning("Auto-labeling skipped: %s", message)
                return self._failure_result(message)

            if not skip_lock_check:
                lock_message = self._lock_message(document)
                if lock_message is not None:
                    # Do not overwrite the status of a locked document.
                    logger.warning(
                        "Auto-labeling skipped for %s: %s", document_id, lock_message
                    )
                    return self._failure_result(lock_message)

            db.update_document_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )

            annotations = self._extract_annotations(
                document_id, Path(file_path), field_values
            )

            # Extraction succeeded; only now modify stored annotations.
            if replace_existing:
                deleted = db.delete_annotations_for_document(
                    document_id=document_id,
                    source="auto",
                )
                logger.info("Deleted %s existing auto annotations", deleted)

            if annotations:
                db.create_annotations_batch(annotations)

            annotations_created = len(annotations)
            db.update_document_status(
                document_id=document_id,
                status="labeled" if annotations_created > 0 else "pending",
                auto_label_status="completed",
            )

            return {
                "status": "completed",
                "annotations_created": annotations_created,
            }

        except Exception as e:
            logger.exception("Auto-labeling failed for %s", document_id)
            self._mark_failed(db, document_id, e)
            return self._failure_result(str(e))

    @staticmethod
    def _lock_message(document: Any) -> str | None:
        """Return an error message if *document* is currently locked, else None."""
        from datetime import datetime, timezone

        lock_until = getattr(document, "annotation_lock_until", None)
        # NOTE(review): comparing a naive datetime with this aware UTC "now"
        # raises TypeError; the caller records that as an auto-label failure.
        if lock_until and lock_until > datetime.now(timezone.utc):
            return (
                f"Document is locked for annotation until {lock_until}. "
                "Auto-labeling skipped."
            )
        return None

    @staticmethod
    def _failure_result(error: str) -> dict[str, Any]:
        """Build the result dict returned when auto-labeling does not complete."""
        return {
            "status": "failed",
            "error": error,
            "annotations_created": 0,
        }

    @staticmethod
    def _mark_failed(db: AdminDB, document_id: str, error: Exception) -> None:
        """Record the failure on the document.

        Never raises, so the caller can always return a result dict.
        """
        try:
            db.update_document_status(
                document_id=document_id,
                status="pending",
                auto_label_status="failed",
                auto_label_error=str(error),
            )
        except Exception:
            logger.exception(
                "Could not record auto-label failure for %s", document_id
            )

    def _extract_annotations(
        self,
        document_id: str,
        path: Path,
        field_values: dict[str, str],
    ) -> list[dict[str, Any]]:
        """OCR *path* and return matched annotations without saving them.

        ``.pdf`` files (case-insensitive) are processed page by page.
        Anything else is processed as a single image on page 1.
        """
        if path.suffix.lower() == ".pdf":
            return self._extract_pdf_annotations(document_id, path, field_values)
        return self._extract_image_annotations(
            document_id, path, field_values, page_number=1
        )

    def _extract_pdf_annotations(
        self,
        document_id: str,
        pdf_path: Path,
        field_values: dict[str, str],
    ) -> list[dict[str, Any]]:
        """Render every PDF page at DEFAULT_DPI and collect its annotations."""
        from shared.pdf.renderer import render_pdf_to_images
        import io

        annotations: list[dict[str, Any]] = []
        for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
            # Pixel data is copied into the array, so the image can be closed.
            with Image.open(io.BytesIO(image_bytes)) as image:
                image_array = np.array(image)
            annotations.extend(
                self._annotate_array(
                    document_id,
                    image_array,
                    field_values,
                    page_number=page_no + 1,  # renderer pages are 0-indexed
                )
            )
        return annotations

    def _extract_image_annotations(
        self,
        document_id: str,
        image_path: Path,
        field_values: dict[str, str],
        page_number: int = 1,
    ) -> list[dict[str, Any]]:
        """Load a single image file and collect its annotations."""
        with Image.open(image_path) as image:
            image_array = np.array(image)
        return self._annotate_array(document_id, image_array, field_values, page_number)

    def _annotate_array(
        self,
        document_id: str,
        image_array: np.ndarray,
        field_values: dict[str, str],
        page_number: int,
    ) -> list[dict[str, Any]]:
        """OCR one page image and match the field values against its tokens.

        Args:
            page_number: 1-indexed page number. The OCR engine and matcher
                both receive the 0-indexed ``page_number - 1`` so their page
                ids agree.
        """
        tokens = self.ocr_engine.extract_from_image(
            image_array,
            page_no=page_number - 1,
        )
        return self._find_annotations(
            document_id,
            tokens,
            field_values,
            page_number=page_number,
            image_width=image_array.shape[1],
            image_height=image_array.shape[0],
        )

    def _process_pdf(
        self,
        document_id: str,
        pdf_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
    ) -> int:
        """Process PDF document, save its annotations, and return how many were created."""
        annotations = self._extract_pdf_annotations(document_id, pdf_path, field_values)
        if annotations:
            db.create_annotations_batch(annotations)
        return len(annotations)

    def _process_image(
        self,
        document_id: str,
        image_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
        page_number: int = 1,
    ) -> int:
        """Process single image, save its annotations, and return how many were created."""
        annotations = self._extract_image_annotations(
            document_id, image_path, field_values, page_number=page_number
        )
        if annotations:
            db.create_annotations_batch(annotations)
        return len(annotations)

    def _find_annotations(
        self,
        document_id: str,
        tokens: list[OCRToken],
        field_values: dict[str, str],
        page_number: int,
        image_width: int,
        image_height: int,
    ) -> list[dict[str, Any]]:
        """Find annotations for field values using token matching.

        Blank values and unknown field names are skipped. At most one
        annotation per field is produced for this page.

        Returns:
            Annotation dicts carrying both YOLO-normalized
            (``x_center``/``y_center``/``width``/``height``) and pixel
            (``bbox_*``) coordinates.
        """
        from shared.normalize import normalize_field

        annotations: list[dict[str, Any]] = []

        for field_name, value in field_values.items():
            if not value or not value.strip():
                continue

            class_id = self._get_class_id(field_name)
            if class_id is None:
                logger.warning("Unknown field name: %s", field_name)
                continue

            class_name = FIELD_CLASSES[class_id]

            # Deliberate best-effort: fall back to the raw value if
            # normalization fails.
            try:
                normalized_values = normalize_field(field_name, value)
            except Exception as e:
                logger.warning("Failed to normalize %s=%s: %s", field_name, value, e)
                normalized_values = [value]

            matches = self._field_matcher.find_matches(
                tokens=tokens,
                field_name=field_name,
                normalized_values=normalized_values,
                page_no=page_number - 1,  # 0-indexed for matcher
            )
            if not matches:
                continue

            # NOTE(review): assumes find_matches returns matches best-first.
            best_match = matches[0]
            x0, y0, x1, y1 = best_match.bbox  # pixel coordinates

            annotations.append({
                "document_id": document_id,
                "page_number": page_number,
                "class_id": class_id,
                "class_name": class_name,
                # Normalized coordinates (YOLO format)
                "x_center": (x0 + x1) / 2 / image_width,
                "y_center": (y0 + y1) / 2 / image_height,
                "width": (x1 - x0) / image_width,
                "height": (y1 - y0) / image_height,
                # Pixel coordinates
                "bbox_x": int(x0),
                "bbox_y": int(y0),
                "bbox_width": int(x1 - x0),
                "bbox_height": int(y1 - y0),
                "text_value": best_match.matched_value,
                "confidence": best_match.score,
                "source": "auto",
            })

        return annotations

    def _get_class_id(self, field_name: str) -> int | None:
        """Map a canonical or alias field name to its class ID, or None if unknown."""
        if field_name in FIELD_CLASS_IDS:
            return FIELD_CLASS_IDS[field_name]

        canonical = self._FIELD_NAME_ALIASES.get(field_name)
        if canonical is not None and canonical in FIELD_CLASS_IDS:
            return FIELD_CLASS_IDS[canonical]

        return None
|
|
|
|
|
|
# Process-wide AutoLabelService, created on demand by get_auto_label_service().
_auto_label_service: AutoLabelService | None = None


def get_auto_label_service() -> AutoLabelService:
    """Return the shared AutoLabelService, constructing it on the first call."""
    global _auto_label_service
    if _auto_label_service is not None:
        return _auto_label_service
    _auto_label_service = AutoLabelService()
    return _auto_label_service


def reset_auto_label_service() -> None:
    """Forget the shared AutoLabelService so the next get builds a new one (for testing)."""
    global _auto_label_service
    _auto_label_service = None
|