"""
Admin Auto-Labeling Service

Uses FieldMatcher to automatically create annotations from field values.
"""

import io
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image

from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken

logger = logging.getLogger(__name__)

# Alternative field names accepted by ``AutoLabelService._get_class_id``,
# mapped to the canonical ``FIELD_CLASS_IDS`` key. Canonical names are looked
# up directly, so they need no entry here. Built once at import time instead
# of on every lookup.
_FIELD_NAME_ALIASES: dict[str, str] = {
    "InvoiceNumber": "invoice_number",
    "InvoiceDate": "invoice_date",
    "InvoiceDueDate": "invoice_due_date",
    "OCR": "ocr_number",
    "Bankgiro": "bankgiro",
    "Plusgiro": "plusgiro",
    "Amount": "amount",
    "PaymentLine": "payment_line",
}


class AutoLabelService:
    """Service for automatic document labeling using field matching."""

    def __init__(self, ocr_engine: OCREngine | None = None):
        """
        Initialize auto-label service.

        Args:
            ocr_engine: OCR engine instance (created lazily on first use if
                not provided)
        """
        self._ocr_engine = ocr_engine
        self._field_matcher = FieldMatcher()

    @property
    def ocr_engine(self) -> OCREngine:
        """Return the OCR engine, creating an English ``OCREngine`` on first access."""
        if self._ocr_engine is None:
            self._ocr_engine = OCREngine(lang="en")
        return self._ocr_engine

    def auto_label_document(
        self,
        document_id: str,
        file_path: str,
        field_values: dict[str, str],
        db: AdminDB,
        replace_existing: bool = False,
        skip_lock_check: bool = False,
    ) -> dict[str, Any]:
        """
        Auto-label a document using field matching.

        The document's status is set to ``auto_labeling``/``running`` while
        work is in progress. On success it becomes ``labeled`` (or ``pending``
        when nothing matched) with ``auto_label_status="completed"``. On any
        error it becomes ``pending``/``failed`` with the error text recorded.

        Args:
            document_id: Document UUID
            file_path: Path to document file (``.pdf`` is rendered page by
                page; any other suffix is opened as a single image)
            field_values: Dict of field_name -> value to match
            db: Admin database instance
            replace_existing: Whether to delete existing auto annotations first
            skip_lock_check: Skip annotation lock check (for batch processing)

        Returns:
            ``{"status": "completed", "annotations_created": n}`` on success,
            or ``{"status": "failed", "error": msg, "annotations_created": 0}``.
            Errors are reported through the return value rather than raised.
        """
        try:
            document = db.get_document(document_id)
            if document is None:
                raise ValueError(f"Document not found: {document_id}")

            if not skip_lock_check:
                self._check_annotation_lock(document)

            db.update_document_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )

            if replace_existing:
                # NOTE(review): old auto annotations are deleted before
                # processing; if processing fails afterwards they are not
                # restored.
                deleted = db.delete_annotations_for_document(
                    document_id=document_id,
                    source="auto",
                )
                logger.info("Deleted %s existing auto annotations", deleted)

            path = Path(file_path)
            if path.suffix.lower() == ".pdf":
                annotations_created = self._process_pdf(
                    document_id, path, field_values, db
                )
            else:
                annotations_created = self._process_image(
                    document_id, path, field_values, db, page_number=1
                )

            status = "labeled" if annotations_created > 0 else "pending"
            db.update_document_status(
                document_id=document_id,
                status=status,
                auto_label_status="completed",
            )

            return {
                "status": "completed",
                "annotations_created": annotations_created,
            }

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Auto-labeling failed for %s: %s", document_id, e)
            try:
                db.update_document_status(
                    document_id=document_id,
                    status="pending",
                    auto_label_status="failed",
                    auto_label_error=str(e),
                )
            except Exception:
                # Without this guard a DB error here would mask the original
                # failure and break the "always return a status dict" contract.
                logger.exception(
                    "Failed to record auto-label failure for %s", document_id
                )
            return {
                "status": "failed",
                "error": str(e),
                "annotations_created": 0,
            }

    @staticmethod
    def _check_annotation_lock(document: Any) -> None:
        """
        Raise ``ValueError`` if the document is annotation-locked until a future time.

        A missing or falsy ``annotation_lock_until`` attribute means "not locked".
        """
        lock_until = getattr(document, "annotation_lock_until", None)
        if not lock_until:
            return
        if lock_until.tzinfo is None:
            # Comparing a naive datetime with the aware "now" below raises
            # TypeError. This assumes naive DB timestamps are UTC (TODO
            # confirm against the DB layer).
            lock_until = lock_until.replace(tzinfo=timezone.utc)
        if lock_until > datetime.now(timezone.utc):
            raise ValueError(
                f"Document is locked for annotation until {document.annotation_lock_until}. "
                "Auto-labeling skipped."
            )

    def _label_image_array(
        self,
        document_id: str,
        image_array: np.ndarray,
        field_values: dict[str, str],
        db: AdminDB,
        page_number: int,
    ) -> int:
        """
        Label one page image: run OCR, match field values, persist annotations.

        Args:
            document_id: Document UUID
            image_array: Page pixels as produced by ``np.array(PIL.Image)``
            field_values: Dict of field_name -> value to match
            db: Admin database instance
            page_number: 1-indexed page number; OCR receives ``page_number - 1``

        Returns:
            Number of annotations created for this page.
        """
        tokens = self.ocr_engine.extract_from_image(
            image_array,
            page_no=page_number - 1,  # OCR page numbers are 0-indexed
        )

        annotations = self._find_annotations(
            document_id,
            tokens,
            field_values,
            page_number=page_number,
            image_width=image_array.shape[1],
            image_height=image_array.shape[0],
        )

        if annotations:
            db.create_annotations_batch(annotations)
        return len(annotations)

    def _process_pdf(
        self,
        document_id: str,
        pdf_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
    ) -> int:
        """Render every PDF page at ``DEFAULT_DPI``, label each one, and return the total annotation count."""
        # Local import kept as in the original (deferred dependency).
        from shared.pdf.renderer import render_pdf_to_images

        total_annotations = 0
        for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=DEFAULT_DPI):
            # np.array copies the pixels, so the image can be closed right away.
            with Image.open(io.BytesIO(image_bytes)) as image:
                image_array = np.array(image)
            total_annotations += self._label_image_array(
                document_id,
                image_array,
                field_values,
                db,
                page_number=page_no + 1,  # renderer page_no is 0-indexed
            )
        return total_annotations

    def _process_image(
        self,
        document_id: str,
        image_path: Path,
        field_values: dict[str, str],
        db: AdminDB,
        page_number: int = 1,
    ) -> int:
        """Label a single image file as page ``page_number`` and return the annotation count."""
        # Context manager closes the underlying file handle (previously leaked).
        with Image.open(image_path) as image:
            image_array = np.array(image)
        return self._label_image_array(
            document_id,
            image_array,
            field_values,
            db,
            page_number=page_number,
        )

    def _find_annotations(
        self,
        document_id: str,
        tokens: list[OCRToken],
        field_values: dict[str, str],
        page_number: int,
        image_width: int,
        image_height: int,
    ) -> list[dict[str, Any]]:
        """
        Find annotations for field values using token matching.

        For every non-blank field value with a known class, the value is
        normalized and matched against the OCR tokens. The first match returned
        by the matcher is converted into an annotation dict. That dict holds
        both YOLO-style coordinates (normalized by the image size) and integer
        pixel coordinates.

        Returns:
            List of annotation dicts ready for ``db.create_annotations_batch``.
        """
        from shared.normalize import normalize_field

        annotations = []

        for field_name, value in field_values.items():
            if not value or not value.strip():
                continue

            class_id = self._get_class_id(field_name)
            if class_id is None:
                logger.warning("Unknown field name: %s", field_name)
                continue

            class_name = FIELD_CLASSES[class_id]

            # Fall back to the raw value if normalization fails.
            try:
                normalized_values = normalize_field(field_name, value)
            except Exception as e:
                logger.warning("Failed to normalize %s=%s: %s", field_name, value, e)
                normalized_values = [value]

            matches = self._field_matcher.find_matches(
                tokens=tokens,
                field_name=field_name,
                normalized_values=normalized_values,
                page_no=page_number - 1,  # 0-indexed for matcher
            )
            if not matches:
                continue

            # Take the first match as the best one.
            best_match = matches[0]
            bbox = best_match.bbox  # (x0, y0, x1, y1)
            x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            annotations.append({
                "document_id": document_id,
                "page_number": page_number,
                "class_id": class_id,
                "class_name": class_name,
                # Normalized coordinates (YOLO format)
                "x_center": (x0 + x1) / 2 / image_width,
                "y_center": (y0 + y1) / 2 / image_height,
                "width": (x1 - x0) / image_width,
                "height": (y1 - y0) / image_height,
                # Pixel coordinates
                "bbox_x": int(x0),
                "bbox_y": int(y0),
                "bbox_width": int(x1 - x0),
                "bbox_height": int(y1 - y0),
                "text_value": best_match.matched_value,
                "confidence": best_match.score,
                "source": "auto",
            })

        return annotations

    def _get_class_id(self, field_name: str) -> int | None:
        """Map a field name (canonical or alias) to its class ID, or ``None`` if unknown."""
        if field_name in FIELD_CLASS_IDS:
            return FIELD_CLASS_IDS[field_name]

        mapped_name = _FIELD_NAME_ALIASES.get(field_name)
        if mapped_name is not None and mapped_name in FIELD_CLASS_IDS:
            return FIELD_CLASS_IDS[mapped_name]

        return None


# Global service instance
_auto_label_service: AutoLabelService | None = None


def get_auto_label_service() -> AutoLabelService:
    """Get the auto-label service instance, creating it on first call."""
    global _auto_label_service
    if _auto_label_service is None:
        _auto_label_service = AutoLabelService()
    return _auto_label_service


def reset_auto_label_service() -> None:
    """Reset the auto-label service (for testing)."""
    global _auto_label_service
    _auto_label_service = None