Files
invoice-master-poc-v2/packages/inference/inference/data/repositories/annotation_repository.py
2026-02-01 22:40:41 +01:00

358 lines
12 KiB
Python

"""
Annotation Repository
Handles annotation operations following Single Responsibility Principle.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.base import BaseRepository

logger = logging.getLogger(__name__)
# Attributes that ``AnnotationRepository.override`` refuses to modify.
# Rewriting the primary key would leave the history row written in the same
# call pointing at an ID the annotation no longer has.
_NON_OVERRIDABLE_FIELDS = frozenset({"annotation_id"})


def _as_uuid(value: str | UUID) -> UUID:
    """Return *value* as a ``UUID``.

    Args:
        value: An existing ``UUID`` (returned unchanged) or its string form.

    Raises:
        ValueError: If a string is not a well-formed UUID.
    """
    return value if isinstance(value, UUID) else UUID(value)


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same kind of value as ``datetime.utcnow()`` (UTC wall-clock
    time without ``tzinfo``) but avoids that call, which is deprecated since
    Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AnnotationRepository(BaseRepository[AdminAnnotation]):
    """Repository for annotation management.

    Handles:
    - Annotation CRUD operations
    - Batch annotation creation
    - Annotation verification
    - Annotation override tracking

    Every method opens its own session through ``get_session_context()``.
    ORM objects handed back to callers are expunged from that session before
    they are returned.

    ID parameters accept either a ``UUID`` or its string form; a malformed
    string raises ``ValueError``.
    """

    def create(
        self,
        document_id: str,
        page_number: int,
        class_id: int,
        class_name: str,
        x_center: float,
        y_center: float,
        width: float,
        height: float,
        bbox_x: int,
        bbox_y: int,
        bbox_width: int,
        bbox_height: int,
        text_value: str | None = None,
        confidence: float | None = None,
        source: str = "manual",
    ) -> str:
        """Create a new annotation.

        Args:
            document_id: ID of the document the annotation belongs to.
            page_number: Page the annotation is on.
            class_id: Numeric class of the annotation.
            class_name: Name of the class.
            x_center: Stored as the annotation's ``x_center``.
            y_center: Stored as the annotation's ``y_center``.
            width: Stored as the annotation's ``width``.
            height: Stored as the annotation's ``height``.
            bbox_x: Stored as the annotation's ``bbox_x``.
            bbox_y: Stored as the annotation's ``bbox_y``.
            bbox_width: Stored as the annotation's ``bbox_width``.
            bbox_height: Stored as the annotation's ``bbox_height``.
            text_value: Optional text attached to the annotation.
            confidence: Optional confidence value.
            source: Origin label of the annotation (defaults to "manual").

        Returns:
            Annotation ID as string

        Raises:
            ValueError: If ``document_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            annotation = AdminAnnotation(
                document_id=_as_uuid(document_id),
                page_number=page_number,
                class_id=class_id,
                class_name=class_name,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                bbox_x=bbox_x,
                bbox_y=bbox_y,
                bbox_width=bbox_width,
                bbox_height=bbox_height,
                text_value=text_value,
                confidence=confidence,
                source=source,
            )
            session.add(annotation)
            # Emit the INSERT before the ID is read back.
            # NOTE(review): there is no explicit commit here; presumably
            # get_session_context() commits on exit -- confirm.
            session.flush()
            return str(annotation.annotation_id)

    def create_batch(
        self,
        annotations: list[dict[str, Any]],
    ) -> list[str]:
        """Create multiple annotations in a batch.

        Args:
            annotations: List of annotation data dicts. Each dict must contain
                ``document_id``, ``class_id``, ``class_name``, ``x_center``,
                ``y_center``, ``width``, ``height``, ``bbox_x``, ``bbox_y``,
                ``bbox_width`` and ``bbox_height``. ``page_number`` defaults
                to 1, ``source`` to "auto", and ``text_value`` /
                ``confidence`` to ``None``.

        Returns:
            List of annotation IDs, in the same order as ``annotations``

        Raises:
            KeyError: If a required key is missing from one of the dicts.
            ValueError: If a ``document_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            pending = [
                AdminAnnotation(
                    document_id=_as_uuid(ann_data["document_id"]),
                    page_number=ann_data.get("page_number", 1),
                    class_id=ann_data["class_id"],
                    class_name=ann_data["class_name"],
                    x_center=ann_data["x_center"],
                    y_center=ann_data["y_center"],
                    width=ann_data["width"],
                    height=ann_data["height"],
                    bbox_x=ann_data["bbox_x"],
                    bbox_y=ann_data["bbox_y"],
                    bbox_width=ann_data["bbox_width"],
                    bbox_height=ann_data["bbox_height"],
                    text_value=ann_data.get("text_value"),
                    confidence=ann_data.get("confidence"),
                    source=ann_data.get("source", "auto"),
                )
                for ann_data in annotations
            ]
            session.add_all(pending)
            # A single flush for the whole batch instead of one per row.
            session.flush()
            return [str(annotation.annotation_id) for annotation in pending]

    def get(self, annotation_id: str) -> AdminAnnotation | None:
        """Get an annotation by ID.

        Returns:
            The annotation, detached from the session, or None if not found

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, _as_uuid(annotation_id))
            if annotation is not None:
                session.expunge(annotation)
            return annotation

    def get_for_document(
        self,
        document_id: str,
        page_number: int | None = None,
    ) -> list[AdminAnnotation]:
        """Get all annotations for a document.

        Args:
            document_id: ID of the document.
            page_number: When given, only annotations on this page are
                returned.

        Returns:
            Detached annotations ordered by ``class_id``

        Raises:
            ValueError: If ``document_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == _as_uuid(document_id)
            )
            if page_number is not None:
                statement = statement.where(
                    AdminAnnotation.page_number == page_number
                )
            statement = statement.order_by(AdminAnnotation.class_id)
            rows = list(session.exec(statement).all())
            for row in rows:
                session.expunge(row)
            return rows

    def update(
        self,
        annotation_id: str,
        x_center: float | None = None,
        y_center: float | None = None,
        width: float | None = None,
        height: float | None = None,
        bbox_x: int | None = None,
        bbox_y: int | None = None,
        bbox_width: int | None = None,
        bbox_height: int | None = None,
        text_value: str | None = None,
        class_id: int | None = None,
        class_name: str | None = None,
    ) -> bool:
        """Update an annotation.

        Only arguments that are not ``None`` are written, so this method
        cannot reset a field (for example ``text_value``) back to ``None``.
        ``updated_at`` is refreshed whenever the annotation exists, even if
        no field was passed.

        Returns:
            True if updated, False if not found

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID.
        """
        requested = {
            "x_center": x_center,
            "y_center": y_center,
            "width": width,
            "height": height,
            "bbox_x": bbox_x,
            "bbox_y": bbox_y,
            "bbox_width": bbox_width,
            "bbox_height": bbox_height,
            "text_value": text_value,
            "class_id": class_id,
            "class_name": class_name,
        }
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, _as_uuid(annotation_id))
            if annotation is None:
                return False
            for field_name, value in requested.items():
                if value is not None:
                    setattr(annotation, field_name, value)
            annotation.updated_at = _utcnow_naive()
            # NOTE(review): there is no explicit commit here; presumably
            # get_session_context() commits on exit -- confirm.
            session.add(annotation)
            return True

    def delete(self, annotation_id: str) -> bool:
        """Delete an annotation.

        Returns:
            True if deleted, False if not found

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, _as_uuid(annotation_id))
            if annotation is None:
                return False
            session.delete(annotation)
            session.commit()
            return True

    def delete_for_document(
        self,
        document_id: str,
        source: str | None = None,
    ) -> int:
        """Delete all annotations for a document.

        Args:
            document_id: ID of the document.
            source: When truthy, only annotations with this ``source`` are
                deleted. ``None`` or an empty string applies no filter, so
                every annotation of the document is deleted.

        Returns:
            Count of deleted annotations

        Raises:
            ValueError: If ``document_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == _as_uuid(document_id)
            )
            if source:
                statement = statement.where(AdminAnnotation.source == source)
            doomed = session.exec(statement).all()
            for annotation in doomed:
                session.delete(annotation)
            session.commit()
            return len(doomed)

    def verify(
        self,
        annotation_id: str,
        admin_token: str,
    ) -> AdminAnnotation | None:
        """Mark an annotation as verified.

        Args:
            annotation_id: ID of the annotation.
            admin_token: Stored in ``verified_by``.

        Returns:
            The refreshed, detached annotation, or None if not found

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, _as_uuid(annotation_id))
            if annotation is None:
                return None
            # One timestamp so verified_at and updated_at agree exactly.
            now = _utcnow_naive()
            annotation.is_verified = True
            annotation.verified_at = now
            annotation.verified_by = admin_token
            annotation.updated_at = now
            session.add(annotation)
            session.commit()
            # Reload after the commit so the detached object carries its
            # attribute values.
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def override(
        self,
        annotation_id: str,
        admin_token: str,
        change_reason: str | None = None,
        **updates: Any,
    ) -> AdminAnnotation | None:
        """Override an auto-generated annotation.

        Creates a history record and updates the annotation.

        Args:
            annotation_id: ID of the annotation.
            admin_token: Stored in the history record's ``changed_by``.
            change_reason: Optional reason stored in the history record.
            **updates: Attribute names and their new values. A key is ignored
                (with a warning in the log) when it starts with an
                underscore, names the primary key ``annotation_id``, or is
                not an attribute of the annotation.

        If the annotation's ``source`` is "auto" once the updates are
        applied, ``override_source`` is set to "auto" and ``source`` becomes
        "manual". The history record's ``new_value`` holds only the updates
        that were actually applied.

        Returns:
            The refreshed, detached annotation, or None if not found

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID.
        """
        annotation_uuid = _as_uuid(annotation_id)
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, annotation_uuid)
            if annotation is None:
                return None

            # Snapshot taken before any change, for the history record.
            previous_value = {
                "class_id": annotation.class_id,
                "class_name": annotation.class_name,
                "bbox": {
                    "x": annotation.bbox_x,
                    "y": annotation.bbox_y,
                    "width": annotation.bbox_width,
                    "height": annotation.bbox_height,
                },
                "normalized": {
                    "x_center": annotation.x_center,
                    "y_center": annotation.y_center,
                    "width": annotation.width,
                    "height": annotation.height,
                },
                "text_value": annotation.text_value,
                "confidence": annotation.confidence,
                "source": annotation.source,
            }

            applied: dict[str, Any] = {}
            for key, value in updates.items():
                if (
                    key.startswith("_")
                    or key in _NON_OVERRIDABLE_FIELDS
                    or not hasattr(annotation, key)
                ):
                    # Private/ORM-internal names and the primary key must not
                    # be reachable through caller-supplied keywords.
                    logger.warning(
                        "Ignoring non-overridable field %r for annotation %s",
                        key,
                        annotation_id,
                    )
                    continue
                setattr(annotation, key, value)
                applied[key] = value

            if annotation.source == "auto":
                annotation.override_source = "auto"
                annotation.source = "manual"
            annotation.updated_at = _utcnow_naive()
            session.add(annotation)

            session.add(
                AnnotationHistory(
                    annotation_id=annotation_uuid,
                    document_id=annotation.document_id,
                    action="override",
                    previous_value=previous_value,
                    new_value=applied,
                    changed_by=admin_token,
                    change_reason=change_reason,
                )
            )
            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def create_history(
        self,
        annotation_id: UUID | str,
        document_id: UUID | str,
        action: str,
        previous_value: dict[str, Any] | None = None,
        new_value: dict[str, Any] | None = None,
        changed_by: str | None = None,
        change_reason: str | None = None,
    ) -> AnnotationHistory:
        """Create an annotation history record.

        Args:
            annotation_id: ID of the annotation the record refers to.
            document_id: ID of the document the annotation belongs to.
            action: Label of the recorded action.
            previous_value: Optional state before the change.
            new_value: Optional state after the change.
            changed_by: Optional identifier of who made the change.
            change_reason: Optional reason for the change.

        Returns:
            The committed, refreshed and detached history record

        Raises:
            ValueError: If an ID given as string is not a valid UUID.
        """
        with get_session_context() as session:
            history = AnnotationHistory(
                annotation_id=_as_uuid(annotation_id),
                document_id=_as_uuid(document_id),
                action=action,
                previous_value=previous_value,
                new_value=new_value,
                changed_by=changed_by,
                change_reason=change_reason,
            )
            session.add(history)
            session.commit()
            session.refresh(history)
            session.expunge(history)
            return history

    def get_history(self, annotation_id: UUID | str) -> list[AnnotationHistory]:
        """Get history for a specific annotation.

        Returns:
            Detached history records, newest ``created_at`` first

        Raises:
            ValueError: If an ID given as string is not a valid UUID.
        """
        with get_session_context() as session:
            statement = (
                select(AnnotationHistory)
                .where(AnnotationHistory.annotation_id == _as_uuid(annotation_id))
                .order_by(AnnotationHistory.created_at.desc())
            )
            records = list(session.exec(statement).all())
            for record in records:
                session.expunge(record)
            return records

    def get_document_history(
        self, document_id: UUID | str
    ) -> list[AnnotationHistory]:
        """Get all annotation history for a document.

        Returns:
            Detached history records, newest ``created_at`` first

        Raises:
            ValueError: If an ID given as string is not a valid UUID.
        """
        with get_session_context() as session:
            statement = (
                select(AnnotationHistory)
                .where(AnnotationHistory.document_id == _as_uuid(document_id))
                .order_by(AnnotationHistory.created_at.desc())
            )
            records = list(session.exec(statement).all())
            for record in records:
                session.expunge(record)
            return records