# Source listing: 358 lines, 12 KiB, Python.
"""
Annotation Repository

Handles annotation operations following the Single Responsibility Principle.
"""

import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from sqlmodel import select

from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.database import get_session_context
from inference.data.repositories.base import BaseRepository

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class AnnotationRepository(BaseRepository[AdminAnnotation]):
    """Repository for annotation management.

    Handles:
    - Annotation CRUD operations
    - Batch annotation creation
    - Annotation verification
    - Annotation override tracking

    Every public method opens its own session through ``get_session_context()``.
    Model instances handed back to callers are expunged from that session so
    they remain usable after it closes.

    NOTE(review): ``create``, ``create_batch`` and ``update`` never call
    ``session.commit()`` themselves; they presumably rely on
    ``get_session_context`` committing when the ``with`` block exits —
    confirm against that helper.
    """

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for ``datetime.utcnow()``, which is deprecated
        since Python 3.12. The tzinfo is stripped so new values stay
        consistent with the naive UTC timestamps this repository wrote before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def create(
        self,
        document_id: str,
        page_number: int,
        class_id: int,
        class_name: str,
        x_center: float,
        y_center: float,
        width: float,
        height: float,
        bbox_x: int,
        bbox_y: int,
        bbox_width: int,
        bbox_height: int,
        text_value: str | None = None,
        confidence: float | None = None,
        source: str = "manual",
    ) -> str:
        """Create a new annotation.

        Args:
            document_id: UUID string of the document the annotation belongs to.
            page_number: Page the annotation is on.
            class_id: Numeric class label.
            class_name: Human-readable class label.
            x_center: Normalized box center x.
            y_center: Normalized box center y.
            width: Normalized box width.
            height: Normalized box height.
            bbox_x: Un-normalized box x (presumably pixels — confirm).
            bbox_y: Un-normalized box y.
            bbox_width: Un-normalized box width.
            bbox_height: Un-normalized box height.
            text_value: Optional text associated with the box.
            confidence: Optional confidence score.
            source: Origin tag stored on the row; defaults to ``"manual"``.

        Returns:
            Annotation ID as string.

        Raises:
            ValueError: If ``document_id`` is not a valid UUID string.
        """
        with get_session_context() as session:
            annotation = AdminAnnotation(
                document_id=UUID(document_id),
                page_number=page_number,
                class_id=class_id,
                class_name=class_name,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
                bbox_x=bbox_x,
                bbox_y=bbox_y,
                bbox_width=bbox_width,
                bbox_height=bbox_height,
                text_value=text_value,
                confidence=confidence,
                source=source,
            )
            session.add(annotation)
            # Flush so ``annotation_id`` is populated before it is returned
            # (it may only be assigned when the row is inserted).
            session.flush()
            return str(annotation.annotation_id)

    def create_batch(
        self,
        annotations: list[dict[str, Any]],
    ) -> list[str]:
        """Create multiple annotations in a batch.

        Each dict must provide ``document_id``, ``class_id``, ``class_name``,
        ``x_center``, ``y_center``, ``width``, ``height``, ``bbox_x``,
        ``bbox_y``, ``bbox_width`` and ``bbox_height``. Optional keys are
        ``page_number`` (default 1), ``text_value``, ``confidence`` and
        ``source`` (default ``"auto"``, unlike ``create``).

        Args:
            annotations: List of annotation data dicts.

        Returns:
            List of annotation IDs, in the same order as ``annotations``.

        Raises:
            KeyError: If a required key is missing from an entry.
            ValueError: If a ``document_id`` is not a valid UUID string.
        """
        with get_session_context() as session:
            created = [
                AdminAnnotation(
                    document_id=UUID(data["document_id"]),
                    page_number=data.get("page_number", 1),
                    class_id=data["class_id"],
                    class_name=data["class_name"],
                    x_center=data["x_center"],
                    y_center=data["y_center"],
                    width=data["width"],
                    height=data["height"],
                    bbox_x=data["bbox_x"],
                    bbox_y=data["bbox_y"],
                    bbox_width=data["bbox_width"],
                    bbox_height=data["bbox_height"],
                    text_value=data.get("text_value"),
                    confidence=data.get("confidence"),
                    source=data.get("source", "auto"),
                )
                for data in annotations
            ]
            session.add_all(created)
            # A single flush assigns every primary key at once; flushing inside
            # the loop cost one database round trip per annotation.
            session.flush()
            return [str(annotation.annotation_id) for annotation in created]

    def get(self, annotation_id: str) -> AdminAnnotation | None:
        """Get an annotation by ID.

        Returns:
            The detached annotation, or None if no row has that ID.

        Raises:
            ValueError: If ``annotation_id`` is not a valid UUID string.
        """
        with get_session_context() as session:
            result = session.get(AdminAnnotation, UUID(annotation_id))
            if result is not None:
                session.expunge(result)
            return result

    def get_for_document(
        self,
        document_id: str,
        page_number: int | None = None,
    ) -> list[AdminAnnotation]:
        """Get all annotations for a document.

        Args:
            document_id: UUID string of the document.
            page_number: If given, only annotations on this page are returned.

        Returns:
            Detached annotations ordered by ``class_id``.
        """
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == UUID(document_id)
            )
            if page_number is not None:
                statement = statement.where(AdminAnnotation.page_number == page_number)
            statement = statement.order_by(AdminAnnotation.class_id)

            results = list(session.exec(statement).all())
            for row in results:
                session.expunge(row)
            return results

    def update(
        self,
        annotation_id: str,
        x_center: float | None = None,
        y_center: float | None = None,
        width: float | None = None,
        height: float | None = None,
        bbox_x: int | None = None,
        bbox_y: int | None = None,
        bbox_width: int | None = None,
        bbox_height: int | None = None,
        text_value: str | None = None,
        class_id: int | None = None,
        class_name: str | None = None,
    ) -> bool:
        """Update an annotation.

        Only arguments that are not None are written. None means "leave
        unchanged", so this method cannot clear a field such as
        ``text_value``. ``updated_at`` is refreshed whenever the annotation
        exists, even if no field argument was supplied.

        Returns:
            True if updated, False if not found.
        """
        # Insertion order matches the original assignment order.
        changes: dict[str, Any] = {
            "x_center": x_center,
            "y_center": y_center,
            "width": width,
            "height": height,
            "bbox_x": bbox_x,
            "bbox_y": bbox_y,
            "bbox_width": bbox_width,
            "bbox_height": bbox_height,
            "text_value": text_value,
            "class_id": class_id,
            "class_name": class_name,
        }
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation is None:
                return False

            for field_name, value in changes.items():
                if value is not None:
                    setattr(annotation, field_name, value)
            annotation.updated_at = self._utcnow()
            session.add(annotation)
            return True

    def delete(self, annotation_id: str) -> bool:
        """Delete an annotation.

        Returns:
            True if a row was deleted, False if not found.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation is None:
                return False
            session.delete(annotation)
            session.commit()
            return True

    def delete_for_document(
        self,
        document_id: str,
        source: str | None = None,
    ) -> int:
        """Delete all annotations for a document.

        Args:
            document_id: UUID string of the document.
            source: If given (and non-empty), only annotations with this
                ``source`` are deleted.

        Returns:
            Count of deleted annotations.
        """
        with get_session_context() as session:
            statement = select(AdminAnnotation).where(
                AdminAnnotation.document_id == UUID(document_id)
            )
            # NOTE: truthiness check — an empty string behaves like None and
            # deletes annotations from every source.
            if source:
                statement = statement.where(AdminAnnotation.source == source)

            annotations = session.exec(statement).all()
            for annotation in annotations:
                session.delete(annotation)
            session.commit()
            return len(annotations)

    def verify(
        self,
        annotation_id: str,
        admin_token: str,
    ) -> AdminAnnotation | None:
        """Mark an annotation as verified.

        Args:
            annotation_id: UUID string of the annotation.
            admin_token: Stored as ``verified_by``.

        Returns:
            The refreshed, detached annotation, or None if not found.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation is None:
                return None

            # A single timestamp so verified_at and updated_at agree exactly.
            now = self._utcnow()
            annotation.is_verified = True
            annotation.verified_at = now
            annotation.verified_by = admin_token
            annotation.updated_at = now

            session.add(annotation)
            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def override(
        self,
        annotation_id: str,
        admin_token: str,
        change_reason: str | None = None,
        **updates: Any,
    ) -> AdminAnnotation | None:
        """Override an auto-generated annotation.

        Creates a history record and updates the annotation.

        Args:
            annotation_id: UUID string of the annotation to change.
            admin_token: Stored as ``changed_by`` on the history record.
            change_reason: Optional reason stored on the history record.
            **updates: Attribute values to set on the annotation. Keys that
                are not attributes of the annotation are skipped, but every
                key is still recorded in the history record's ``new_value``.

        Returns:
            The refreshed, detached annotation, or None if not found.
        """
        with get_session_context() as session:
            annotation = session.get(AdminAnnotation, UUID(annotation_id))
            if annotation is None:
                return None

            previous_value = {
                "class_id": annotation.class_id,
                "class_name": annotation.class_name,
                "bbox": {
                    "x": annotation.bbox_x,
                    "y": annotation.bbox_y,
                    "width": annotation.bbox_width,
                    "height": annotation.bbox_height,
                },
                "normalized": {
                    "x_center": annotation.x_center,
                    "y_center": annotation.y_center,
                    "width": annotation.width,
                    "height": annotation.height,
                },
                "text_value": annotation.text_value,
                "confidence": annotation.confidence,
                "source": annotation.source,
            }
            # Captured before ``updates`` are applied. Otherwise an update that
            # itself sets ``source`` would hide that an auto annotation was
            # overridden.
            was_auto = annotation.source == "auto"

            for key, value in updates.items():
                # NOTE(review): any existing attribute may be overwritten here,
                # including identifiers and verification fields.
                if hasattr(annotation, key):
                    setattr(annotation, key, value)

            if was_auto or annotation.source == "auto":
                annotation.override_source = "auto"
                annotation.source = "manual"

            annotation.updated_at = self._utcnow()
            session.add(annotation)

            history = AnnotationHistory(
                annotation_id=UUID(annotation_id),
                document_id=annotation.document_id,
                action="override",
                previous_value=previous_value,
                new_value=updates,
                changed_by=admin_token,
                change_reason=change_reason,
            )
            session.add(history)

            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation

    def create_history(
        self,
        annotation_id: UUID,
        document_id: UUID,
        action: str,
        previous_value: dict[str, Any] | None = None,
        new_value: dict[str, Any] | None = None,
        changed_by: str | None = None,
        change_reason: str | None = None,
    ) -> AnnotationHistory:
        """Create an annotation history record.

        Returns:
            The committed, refreshed and detached history record.
        """
        with get_session_context() as session:
            history = AnnotationHistory(
                annotation_id=annotation_id,
                document_id=document_id,
                action=action,
                previous_value=previous_value,
                new_value=new_value,
                changed_by=changed_by,
                change_reason=change_reason,
            )
            session.add(history)
            session.commit()
            session.refresh(history)
            session.expunge(history)
            return history

    def get_history(self, annotation_id: UUID) -> list[AnnotationHistory]:
        """Get history for a specific annotation, newest first."""
        return self._list_history(AnnotationHistory.annotation_id == annotation_id)

    def get_document_history(self, document_id: UUID) -> list[AnnotationHistory]:
        """Get all annotation history for a document, newest first."""
        return self._list_history(AnnotationHistory.document_id == document_id)

    def _list_history(self, condition: Any) -> list[AnnotationHistory]:
        """Return detached history rows matching ``condition``, newest first."""
        with get_session_context() as session:
            statement = (
                select(AnnotationHistory)
                .where(condition)
                .order_by(AnnotationHistory.created_at.desc())
            )
            results = list(session.exec(statement).all())
            for row in results:
                session.expunge(row)
            return results