1157 lines
42 KiB
Python
1157 lines
42 KiB
Python
"""
|
|
Admin Database Operations
|
|
|
|
Database interface for admin document management, annotations, and training tasks.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func
|
|
from sqlmodel import select
|
|
|
|
from src.data.database import get_session_context
|
|
from src.data.admin_models import (
|
|
AdminToken,
|
|
AdminDocument,
|
|
AdminAnnotation,
|
|
TrainingTask,
|
|
TrainingLog,
|
|
BatchUpload,
|
|
BatchUploadFile,
|
|
TrainingDocumentLink,
|
|
AnnotationHistory,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)


class AdminDB:
    """Database interface for admin operations using SQLModel.

    Groups persistence helpers for admin tokens, documents, annotations,
    training tasks and logs, batch uploads, training-document links,
    annotation history and annotation locks. Every method opens its own
    session through ``get_session_context()``; ORM instances handed back to
    callers are expunged from that session first so their loaded attributes
    stay readable after the session is closed.
    """
|
|
|
|
# ==========================================================================
|
|
# Admin Token Operations
|
|
# ==========================================================================
|
|
|
|
def is_valid_admin_token(self, token: str) -> bool:
    """Report whether ``token`` is a usable admin token.

    A token is usable when a row exists for it, its ``is_active`` flag is
    truthy, and it either has no ``expires_at`` or that expiry is not before
    the current naive UTC time (``datetime.utcnow()``).
    """
    with get_session_context() as session:
        record = session.get(AdminToken, token)
        if record is None or not record.is_active:
            return False
        expiry = record.expires_at
        # Only consult the clock when an expiry is actually set.
        expired = bool(expiry) and expiry < datetime.utcnow()
        return not expired
|
|
|
|
def get_admin_token(self, token: str) -> AdminToken | None:
    """Fetch the ``AdminToken`` row keyed by ``token``.

    Returns:
        The token detached from its session, or None when no row matches.
    """
    with get_session_context() as session:
        record = session.get(AdminToken, token)
        if not record:
            return record
        session.expunge(record)
        return record
|
|
|
|
def create_admin_token(
    self,
    token: str,
    name: str,
    expires_at: datetime | None = None,
) -> None:
    """Insert an admin token, or refresh it when the token already exists.

    An existing row gets its ``name`` and ``expires_at`` overwritten and is
    re-activated. Otherwise a new ``AdminToken`` is added with the given name
    and expiry (other columns keep their model defaults).
    """
    with get_session_context() as session:
        record = session.get(AdminToken, token)
        if not record:
            session.add(AdminToken(token=token, name=name, expires_at=expires_at))
            return
        record.name = name
        record.expires_at = expires_at
        record.is_active = True
        session.add(record)
|
|
|
|
def update_admin_token_usage(self, token: str) -> None:
    """Set ``last_used_at`` to the current naive UTC time.

    Does nothing when no row exists for ``token``.
    """
    with get_session_context() as session:
        record = session.get(AdminToken, token)
        if not record:
            return
        record.last_used_at = datetime.utcnow()
        session.add(record)
|
|
|
|
def deactivate_admin_token(self, token: str) -> bool:
    """Flag an admin token as inactive.

    Returns:
        True if the token existed and was deactivated, False otherwise.
    """
    with get_session_context() as session:
        record = session.get(AdminToken, token)
        if not record:
            return False
        record.is_active = False
        session.add(record)
        return True
|
|
|
|
# ==========================================================================
|
|
# Document Operations
|
|
# ==========================================================================
|
|
|
|
def create_document(
    self,
    filename: str,
    file_size: int,
    content_type: str,
    file_path: str,
    page_count: int = 1,
    upload_source: str = "ui",
    csv_field_values: dict[str, Any] | None = None,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> str:
    """Insert an ``AdminDocument`` row and return its id as a string.

    ``admin_token`` is accepted only for backward compatibility and is not
    stored on the document.
    """
    with get_session_context() as session:
        new_doc = AdminDocument(
            filename=filename,
            file_path=file_path,
            file_size=file_size,
            content_type=content_type,
            page_count=page_count,
            upload_source=upload_source,
            csv_field_values=csv_field_values,
        )
        session.add(new_doc)
        # Flush so document_id is populated before it is read.
        session.flush()
        new_id = new_doc.document_id
    return str(new_id)
|
|
|
|
def get_document(self, document_id: str) -> AdminDocument | None:
    """Look up a document by its string UUID.

    Returns:
        The detached document, or None if it does not exist.

    Raises:
        ValueError: if ``document_id`` is not a valid UUID string.
    """
    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if not doc:
            return doc
        session.expunge(doc)
        return doc
|
|
|
|
def get_document_by_token(
    self,
    document_id: str,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> AdminDocument | None:
    """Compatibility alias for ``get_document``; ``admin_token`` is ignored."""
    return self.get_document(document_id=document_id)
|
|
|
|
def get_documents_by_token(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    status: str | None = None,
    upload_source: str | None = None,
    has_annotations: bool | None = None,
    auto_label_status: str | None = None,
    batch_id: str | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[AdminDocument], int]:
    """Return one page of documents plus the total number of matches.

    ``admin_token`` is deprecated and ignored. The string filters are applied
    only when truthy; ``has_annotations`` is applied only when not None.

    Args:
        status: Exact match on ``AdminDocument.status``.
        upload_source: Exact match on ``AdminDocument.upload_source``.
        has_annotations: True keeps documents with at least one annotation,
            False keeps documents with none.
        auto_label_status: Exact match on ``AdminDocument.auto_label_status``.
        batch_id: String UUID matched against ``AdminDocument.batch_id``.
        limit: Page size.
        offset: Number of matching rows to skip.

    Returns:
        ``(documents, total)``. ``documents`` is ordered newest ``created_at``
        first and detached from the session. ``total`` counts every match,
        ignoring ``limit``/``offset``.

    Raises:
        ValueError: if ``batch_id`` is not a valid UUID string.
    """
    with get_session_context() as session:
        where_clauses = []
        if status:
            where_clauses.append(AdminDocument.status == status)
        if upload_source:
            where_clauses.append(AdminDocument.upload_source == upload_source)
        if auto_label_status:
            where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
        if batch_id:
            where_clauses.append(AdminDocument.batch_id == UUID(batch_id))

        if has_annotations is not None:
            # Correlated EXISTS instead of JOIN + GROUP BY. A grouped count
            # query returns one row per document, so ``.one()`` blew up (or
            # returned a per-document annotation count) instead of the total.
            annotated = (
                select(AdminAnnotation.annotation_id)
                .where(AdminAnnotation.document_id == AdminDocument.document_id)
                .correlate(AdminDocument)
                .exists()
            )
            where_clauses.append(annotated if has_annotations else ~annotated)

        count_stmt = select(func.count()).select_from(AdminDocument)
        statement = select(AdminDocument)
        if where_clauses:
            count_stmt = count_stmt.where(*where_clauses)
            statement = statement.where(*where_clauses)

        total = session.exec(count_stmt).one()

        statement = (
            statement.order_by(AdminDocument.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        documents = list(session.exec(statement).all())
        for doc in documents:
            session.expunge(doc)
        return documents, total
|
|
|
|
def update_document_status(
    self,
    document_id: str,
    status: str,
    auto_label_status: str | None = None,
    auto_label_error: str | None = None,
) -> None:
    """Set a document's status and bump ``updated_at`` (naive UTC).

    The auto-label fields are written only when passed as non-None, so this
    method cannot clear them. Unknown documents are ignored.
    """
    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if not doc:
            return
        doc.status = status
        doc.updated_at = datetime.utcnow()
        optional_fields = {
            "auto_label_status": auto_label_status,
            "auto_label_error": auto_label_error,
        }
        for attr, value in optional_fields.items():
            if value is not None:
                setattr(doc, attr, value)
        session.add(doc)
|
|
|
|
def update_document_file_path(self, document_id: str, file_path: str) -> None:
    """Point a document at a new ``file_path`` and bump ``updated_at``.

    Unknown documents are ignored.
    """
    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if not doc:
            return
        doc.file_path = file_path
        doc.updated_at = datetime.utcnow()
        session.add(doc)
|
|
|
|
def delete_document(self, document_id: str) -> bool:
    """Delete a document together with all of its annotations.

    Annotation rows are deleted before the document row.

    Returns:
        True if the document existed, False otherwise.
    """
    with get_session_context() as session:
        doc_uuid = UUID(document_id)
        doc = session.get(AdminDocument, doc_uuid)
        if not doc:
            return False
        children = session.exec(
            select(AdminAnnotation).where(AdminAnnotation.document_id == doc_uuid)
        ).all()
        for child in children:
            session.delete(child)
        session.delete(doc)
        return True
|
|
|
|
# ==========================================================================
|
|
# Annotation Operations
|
|
# ==========================================================================
|
|
|
|
def create_annotation(
    self,
    document_id: str,
    page_number: int,
    class_id: int,
    class_name: str,
    x_center: float,
    y_center: float,
    width: float,
    height: float,
    bbox_x: int,
    bbox_y: int,
    bbox_width: int,
    bbox_height: int,
    text_value: str | None = None,
    confidence: float | None = None,
    source: str = "manual",
) -> str:
    """Insert a single annotation and return its id as a string.

    ``x_center``/``y_center``/``width``/``height`` are stored as given
    alongside the integer ``bbox_*`` box; no conversion happens here.

    Raises:
        ValueError: if ``document_id`` is not a valid UUID string.
    """
    with get_session_context() as session:
        record = AdminAnnotation(
            document_id=UUID(document_id),
            page_number=page_number,
            class_id=class_id,
            class_name=class_name,
            # float box
            x_center=x_center,
            y_center=y_center,
            width=width,
            height=height,
            # integer box
            bbox_x=bbox_x,
            bbox_y=bbox_y,
            bbox_width=bbox_width,
            bbox_height=bbox_height,
            text_value=text_value,
            confidence=confidence,
            source=source,
        )
        session.add(record)
        # Flush so annotation_id is populated before it is read.
        session.flush()
        new_id = record.annotation_id
    return str(new_id)
|
|
|
|
def create_annotations_batch(
    self,
    annotations: list[dict[str, Any]],
) -> list[str]:
    """Insert several annotations in one session and return their ids.

    Each dict must provide ``document_id`` (string UUID), ``class_id``,
    ``class_name``, ``x_center``, ``y_center``, ``width``, ``height``,
    ``bbox_x``, ``bbox_y``, ``bbox_width`` and ``bbox_height``. Optional keys
    are ``page_number`` (default 1), ``text_value``, ``confidence`` and
    ``source``. ``source`` defaults to ``"auto"`` here, unlike
    ``create_annotation``, which defaults to ``"manual"``.

    Returns:
        Annotation ids as strings, in the same order as ``annotations``.

    Raises:
        KeyError: if a required key is missing from an entry.
        ValueError: if a ``document_id`` is not a valid UUID string.
    """
    with get_session_context() as session:
        created: list[AdminAnnotation] = []
        for ann_data in annotations:
            annotation = AdminAnnotation(
                document_id=UUID(ann_data["document_id"]),
                page_number=ann_data.get("page_number", 1),
                class_id=ann_data["class_id"],
                class_name=ann_data["class_name"],
                x_center=ann_data["x_center"],
                y_center=ann_data["y_center"],
                width=ann_data["width"],
                height=ann_data["height"],
                bbox_x=ann_data["bbox_x"],
                bbox_y=ann_data["bbox_y"],
                bbox_width=ann_data["bbox_width"],
                bbox_height=ann_data["bbox_height"],
                text_value=ann_data.get("text_value"),
                confidence=ann_data.get("confidence"),
                source=ann_data.get("source", "auto"),
            )
            session.add(annotation)
            created.append(annotation)
        # One flush for the whole batch instead of a round trip per row; ids
        # are read only after it, once they are guaranteed to be populated.
        if created:
            session.flush()
        return [str(annotation.annotation_id) for annotation in created]
|
|
|
|
def get_annotation(self, annotation_id: str) -> AdminAnnotation | None:
    """Look up one annotation by string UUID; return it detached, or None."""
    with get_session_context() as session:
        annotation = session.get(AdminAnnotation, UUID(annotation_id))
        if not annotation:
            return annotation
        session.expunge(annotation)
        return annotation
|
|
|
|
def get_annotations_for_document(
    self,
    document_id: str,
    page_number: int | None = None,
) -> list[AdminAnnotation]:
    """List a document's annotations ordered by ``class_id``.

    When ``page_number`` is given, only annotations on that page are
    returned. Results are detached from the session.
    """
    with get_session_context() as session:
        conditions = [AdminAnnotation.document_id == UUID(document_id)]
        if page_number is not None:
            conditions.append(AdminAnnotation.page_number == page_number)
        query = (
            select(AdminAnnotation)
            .where(*conditions)
            .order_by(AdminAnnotation.class_id)
        )
        rows = list(session.exec(query).all())
        for row in rows:
            session.expunge(row)
        return rows
|
|
|
|
def update_annotation(
    self,
    annotation_id: str,
    x_center: float | None = None,
    y_center: float | None = None,
    width: float | None = None,
    height: float | None = None,
    bbox_x: int | None = None,
    bbox_y: int | None = None,
    bbox_width: int | None = None,
    bbox_height: int | None = None,
    text_value: str | None = None,
    class_id: int | None = None,
    class_name: str | None = None,
) -> bool:
    """Patch an annotation with every argument that is not None.

    Because None means "leave unchanged", this method cannot clear a field.
    ``updated_at`` is always bumped when the annotation exists.

    Returns:
        True if the annotation existed, False otherwise.
    """
    with get_session_context() as session:
        annotation = session.get(AdminAnnotation, UUID(annotation_id))
        if not annotation:
            return False
        requested = {
            "x_center": x_center,
            "y_center": y_center,
            "width": width,
            "height": height,
            "bbox_x": bbox_x,
            "bbox_y": bbox_y,
            "bbox_width": bbox_width,
            "bbox_height": bbox_height,
            "text_value": text_value,
            "class_id": class_id,
            "class_name": class_name,
        }
        for field_name, new_value in requested.items():
            if new_value is not None:
                setattr(annotation, field_name, new_value)
        annotation.updated_at = datetime.utcnow()
        session.add(annotation)
        return True
|
|
|
|
def delete_annotation(self, annotation_id: str) -> bool:
    """Delete one annotation; return True if it existed."""
    with get_session_context() as session:
        annotation = session.get(AdminAnnotation, UUID(annotation_id))
        if not annotation:
            return False
        session.delete(annotation)
        return True
|
|
|
|
def delete_annotations_for_document(
    self,
    document_id: str,
    source: str | None = None,
) -> int:
    """Delete a document's annotations, optionally only those from ``source``.

    Rows are loaded and deleted one by one through the ORM.

    Returns:
        The number of annotations deleted.
    """
    with get_session_context() as session:
        query = select(AdminAnnotation).where(
            AdminAnnotation.document_id == UUID(document_id)
        )
        if source:
            query = query.where(AdminAnnotation.source == source)
        doomed = session.exec(query).all()
        for annotation in doomed:
            session.delete(annotation)
        return len(doomed)
|
|
|
|
# ==========================================================================
|
|
# Training Task Operations
|
|
# ==========================================================================
|
|
|
|
def create_training_task(
    self,
    admin_token: str,
    name: str,
    task_type: str = "train",
    description: str | None = None,
    config: dict[str, Any] | None = None,
    scheduled_at: datetime | None = None,
    cron_expression: str | None = None,
    is_recurring: bool = False,
) -> str:
    """Insert a training task and return its id as a string.

    The initial status is ``"scheduled"`` when ``scheduled_at`` is truthy and
    ``"pending"`` otherwise.
    """
    initial_status = "pending" if not scheduled_at else "scheduled"
    with get_session_context() as session:
        task = TrainingTask(
            name=name,
            task_type=task_type,
            description=description,
            admin_token=admin_token,
            config=config,
            scheduled_at=scheduled_at,
            cron_expression=cron_expression,
            is_recurring=is_recurring,
            status=initial_status,
        )
        session.add(task)
        # Flush so task_id is populated before it is read.
        session.flush()
        new_id = task.task_id
    return str(new_id)
|
|
|
|
def get_training_task(self, task_id: str) -> TrainingTask | None:
    """Look up a training task by string UUID; return it detached, or None."""
    with get_session_context() as session:
        task = session.get(TrainingTask, UUID(task_id))
        if not task:
            return task
        session.expunge(task)
        return task
|
|
|
|
def get_training_task_by_token(
    self,
    task_id: str,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> TrainingTask | None:
    """Compatibility alias for ``get_training_task``; ``admin_token`` is ignored."""
    return self.get_training_task(task_id=task_id)
|
|
|
|
def get_training_tasks_by_token(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    status: str | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[TrainingTask], int]:
    """Page through training tasks, newest ``created_at`` first.

    ``admin_token`` is deprecated and ignored. When ``status`` is truthy,
    both the page and the total are restricted to that exact status.

    Returns:
        ``(tasks, total)``: the detached tasks for this page and the number
        of matching tasks across all pages.
    """
    with get_session_context() as session:
        counter = select(func.count()).select_from(TrainingTask)
        listing = select(TrainingTask)
        if status:
            status_matches = TrainingTask.status == status
            counter = counter.where(status_matches)
            listing = listing.where(status_matches)

        total = session.exec(counter).one()

        listing = (
            listing.order_by(TrainingTask.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        tasks = list(session.exec(listing).all())
        for task in tasks:
            session.expunge(task)
        return tasks, total
|
|
|
|
def get_pending_training_tasks(self) -> list[TrainingTask]:
    """List tasks that are due to run, oldest ``created_at`` first.

    A task is due when its status is ``"pending"`` or ``"scheduled"`` and its
    ``scheduled_at`` is NULL or not later than the current naive UTC time.
    Results are detached from the session.
    """
    with get_session_context() as session:
        cutoff = datetime.utcnow()
        is_due = TrainingTask.scheduled_at.is_(None) | (TrainingTask.scheduled_at <= cutoff)
        query = (
            select(TrainingTask)
            .where(TrainingTask.status.in_(["pending", "scheduled"]), is_due)
            .order_by(TrainingTask.created_at)
        )
        tasks = list(session.exec(query).all())
        for task in tasks:
            session.expunge(task)
        return tasks
|
|
|
|
def update_training_task_status(
    self,
    task_id: str,
    status: str,
    error_message: str | None = None,
    result_metrics: dict[str, Any] | None = None,
    model_path: str | None = None,
) -> None:
    """Move a training task to ``status`` and stamp the relevant timestamps.

    ``updated_at`` is always set. ``"running"`` also sets ``started_at``;
    ``"completed"`` and ``"failed"`` also set ``completed_at``. Any other
    status (e.g. ``"cancelled"``) touches only ``updated_at``. The optional
    result fields are written only when passed as non-None. Unknown tasks
    are ignored.
    """
    with get_session_context() as session:
        task = session.get(TrainingTask, UUID(task_id))
        if not task:
            return
        # One clock reading so updated_at matches started_at/completed_at
        # exactly instead of differing by the time between two utcnow() calls.
        now = datetime.utcnow()
        task.status = status
        task.updated_at = now
        if status == "running":
            task.started_at = now
        elif status in ("completed", "failed"):
            task.completed_at = now
        if error_message is not None:
            task.error_message = error_message
        if result_metrics is not None:
            task.result_metrics = result_metrics
        if model_path is not None:
            task.model_path = model_path
        session.add(task)
|
|
|
|
def cancel_training_task(self, task_id: str) -> bool:
    """Cancel a task that has not started yet.

    Only tasks whose status is ``"pending"`` or ``"scheduled"`` can be
    cancelled.

    Returns:
        True if the task was cancelled, False if it is missing or not
        cancellable.
    """
    with get_session_context() as session:
        task = session.get(TrainingTask, UUID(task_id))
        cancellable = bool(task) and task.status in ("pending", "scheduled")
        if not cancellable:
            return False
        task.status = "cancelled"
        task.updated_at = datetime.utcnow()
        session.add(task)
        return True
|
|
|
|
# ==========================================================================
|
|
# Training Log Operations
|
|
# ==========================================================================
|
|
|
|
def add_training_log(
    self,
    task_id: str,
    level: str,
    message: str,
    details: dict[str, Any] | None = None,
) -> None:
    """Append a log entry for the training task identified by ``task_id``."""
    with get_session_context() as session:
        entry = TrainingLog(
            task_id=UUID(task_id),
            level=level,
            message=message,
            details=details,
        )
        session.add(entry)
|
|
|
|
def get_training_logs(
    self,
    task_id: str,
    limit: int = 100,
    offset: int = 0,
) -> list[TrainingLog]:
    """Return a page of a task's log entries, newest ``created_at`` first.

    Results are detached from the session.
    """
    with get_session_context() as session:
        query = select(TrainingLog).where(TrainingLog.task_id == UUID(task_id))
        query = query.order_by(TrainingLog.created_at.desc())
        query = query.offset(offset).limit(limit)
        entries = list(session.exec(query).all())
        for entry in entries:
            session.expunge(entry)
        return entries
|
|
|
|
# ==========================================================================
|
|
# Export Operations
|
|
# ==========================================================================
|
|
|
|
def get_labeled_documents_for_export(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> list[AdminDocument]:
    """Return every document whose status is ``"labeled"``, oldest first.

    ``admin_token`` is deprecated and ignored, like in the other token-scoped
    helpers of this class. ``create_document`` no longer stores a token on
    new documents, so filtering on it would silently drop them from the
    export. Results are detached from the session.
    """
    with get_session_context() as session:
        statement = (
            select(AdminDocument)
            .where(AdminDocument.status == "labeled")
            .order_by(AdminDocument.created_at)
        )
        documents = list(session.exec(statement).all())
        for doc in documents:
            session.expunge(doc)
        return documents
|
|
|
|
def count_documents_by_status(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> dict[str, int]:
    """Map each document status to how many documents currently have it.

    ``admin_token`` is deprecated and ignored; all documents are counted.
    Statuses with no documents are absent from the result.
    """
    with get_session_context() as session:
        query = (
            select(AdminDocument.status, func.count(AdminDocument.document_id))
            .group_by(AdminDocument.status)
        )
        counts: dict[str, int] = {}
        for doc_status, n_docs in session.exec(query).all():
            counts[doc_status] = n_docs
        return counts
|
|
|
|
# ==========================================================================
|
|
# Batch Upload Operations (v2)
|
|
# ==========================================================================
|
|
|
|
def create_batch_upload(
    self,
    admin_token: str,
    filename: str,
    file_size: int,
    upload_source: str = "ui",
) -> BatchUpload:
    """Insert and commit a batch-upload record, returning it detached."""
    with get_session_context() as session:
        record = BatchUpload(
            filename=filename,
            file_size=file_size,
            upload_source=upload_source,
            admin_token=admin_token,
        )
        session.add(record)
        session.commit()
        # Reload attributes (the commit may have expired them) before detaching.
        session.refresh(record)
        session.expunge(record)
        return record
|
|
|
|
def get_batch_upload(self, batch_id: UUID) -> BatchUpload | None:
    """Look up a batch upload by UUID; return it detached, or None."""
    with get_session_context() as session:
        record = session.get(BatchUpload, batch_id)
        if not record:
            return record
        session.expunge(record)
        return record
|
|
|
|
def update_batch_upload(
    self,
    batch_id: UUID,
    **kwargs: Any,
) -> None:
    """Assign keyword arguments onto a batch upload.

    Keys that are not attributes of the record are silently skipped.
    Unknown batches are ignored.
    """
    with get_session_context() as session:
        record = session.get(BatchUpload, batch_id)
        if not record:
            return
        known = {key: val for key, val in kwargs.items() if hasattr(record, key)}
        for key, val in known.items():
            setattr(record, key, val)
        session.add(record)
|
|
|
|
def create_batch_upload_file(
    self,
    batch_id: UUID,
    filename: str,
    **kwargs: Any,
) -> BatchUploadFile:
    """Insert and commit one file record of a batch upload.

    Extra keyword arguments are passed straight to ``BatchUploadFile``.
    The committed record is returned detached.
    """
    with get_session_context() as session:
        record = BatchUploadFile(batch_id=batch_id, filename=filename, **kwargs)
        session.add(record)
        session.commit()
        # Reload attributes (the commit may have expired them) before detaching.
        session.refresh(record)
        session.expunge(record)
        return record
|
|
|
|
def update_batch_upload_file(
    self,
    file_id: UUID,
    **kwargs: Any,
) -> None:
    """Assign keyword arguments onto a batch-upload file record.

    Keys that are not attributes of the record are silently skipped.
    Unknown file ids are ignored.
    """
    with get_session_context() as session:
        record = session.get(BatchUploadFile, file_id)
        if not record:
            return
        known = {key: val for key, val in kwargs.items() if hasattr(record, key)}
        for key, val in known.items():
            setattr(record, key, val)
        session.add(record)
|
|
|
|
def get_batch_upload_files(
    self,
    batch_id: UUID,
) -> list[BatchUploadFile]:
    """List a batch's file records in ``created_at`` order, detached."""
    with get_session_context() as session:
        query = select(BatchUploadFile).where(BatchUploadFile.batch_id == batch_id)
        query = query.order_by(BatchUploadFile.created_at)
        files = list(session.exec(query).all())
        for file_record in files:
            session.expunge(file_record)
        return files
|
|
|
|
def get_batch_uploads_by_token(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    limit: int = 50,
    offset: int = 0,
) -> tuple[list[BatchUpload], int]:
    """Page through all batch uploads, newest ``created_at`` first.

    ``admin_token`` is deprecated and ignored.

    Returns:
        ``(batches, total)``: the detached batches for this page and the
        total number of batch uploads.
    """
    with get_session_context() as session:
        total = session.exec(select(func.count()).select_from(BatchUpload)).one()
        page_query = (
            select(BatchUpload)
            .order_by(BatchUpload.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        batches = list(session.exec(page_query).all())
        for batch in batches:
            session.expunge(batch)
        return batches, total
|
|
|
|
# ==========================================================================
|
|
# Training Document Link Operations (v2)
|
|
# ==========================================================================
|
|
|
|
def create_training_document_link(
    self,
    task_id: UUID,
    document_id: UUID,
    annotation_snapshot: dict[str, Any] | None = None,
) -> TrainingDocumentLink:
    """Record that a document was used by a training task; return it detached."""
    with get_session_context() as session:
        link = TrainingDocumentLink(
            document_id=document_id,
            task_id=task_id,
            annotation_snapshot=annotation_snapshot,
        )
        session.add(link)
        session.commit()
        # Reload attributes (the commit may have expired them) before detaching.
        session.refresh(link)
        session.expunge(link)
        return link
|
|
|
|
def get_training_document_links(
    self,
    task_id: UUID,
) -> list[TrainingDocumentLink]:
    """List a training task's document links, oldest ``created_at`` first."""
    with get_session_context() as session:
        query = select(TrainingDocumentLink).where(TrainingDocumentLink.task_id == task_id)
        query = query.order_by(TrainingDocumentLink.created_at)
        links = list(session.exec(query).all())
        for link in links:
            session.expunge(link)
        return links
|
|
|
|
def get_document_training_tasks(
    self,
    document_id: UUID,
) -> list[TrainingDocumentLink]:
    """List the training-task links of one document, newest first.

    Returns link rows (not ``TrainingTask`` rows), detached from the session.
    """
    with get_session_context() as session:
        query = select(TrainingDocumentLink).where(
            TrainingDocumentLink.document_id == document_id
        )
        query = query.order_by(TrainingDocumentLink.created_at.desc())
        links = list(session.exec(query).all())
        for link in links:
            session.expunge(link)
        return links
|
|
|
|
# ==========================================================================
|
|
# Annotation History Operations (v2)
|
|
# ==========================================================================
|
|
|
|
def create_annotation_history(
    self,
    annotation_id: UUID,
    document_id: UUID,
    action: str,
    previous_value: dict[str, Any] | None = None,
    new_value: dict[str, Any] | None = None,
    changed_by: str | None = None,
    change_reason: str | None = None,
) -> AnnotationHistory:
    """Insert and commit an annotation-history entry; return it detached."""
    with get_session_context() as session:
        entry = AnnotationHistory(
            action=action,
            annotation_id=annotation_id,
            document_id=document_id,
            previous_value=previous_value,
            new_value=new_value,
            changed_by=changed_by,
            change_reason=change_reason,
        )
        session.add(entry)
        session.commit()
        # Reload attributes (the commit may have expired them) before detaching.
        session.refresh(entry)
        session.expunge(entry)
        return entry
|
|
|
|
def get_annotation_history(
    self,
    annotation_id: UUID,
) -> list[AnnotationHistory]:
    """List one annotation's history entries, newest first, detached."""
    with get_session_context() as session:
        query = select(AnnotationHistory).where(
            AnnotationHistory.annotation_id == annotation_id
        )
        query = query.order_by(AnnotationHistory.created_at.desc())
        entries = list(session.exec(query).all())
        for entry in entries:
            session.expunge(entry)
        return entries
|
|
|
|
def get_document_annotation_history(
    self,
    document_id: UUID,
) -> list[AnnotationHistory]:
    """List history entries for all annotations of a document, newest first."""
    with get_session_context() as session:
        query = select(AnnotationHistory).where(
            AnnotationHistory.document_id == document_id
        )
        query = query.order_by(AnnotationHistory.created_at.desc())
        entries = list(session.exec(query).all())
        for entry in entries:
            session.expunge(entry)
        return entries
|
|
|
|
# =========================================================================
|
|
# Annotation Lock Methods
|
|
# =========================================================================
|
|
|
|
def acquire_annotation_lock(
    self,
    document_id: str,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    duration_seconds: int = 300,
) -> AdminDocument | None:
    """Try to take the annotation lock on a document.

    The lock is the ``annotation_lock_until`` timestamp and counts as held
    while that timestamp is later than ``datetime.now(timezone.utc)``.
    ``admin_token`` is ignored, so the lock does not record a holder; even
    the current holder cannot re-acquire an unexpired lock.

    NOTE(review): other methods in this class use naive ``utcnow()``; confirm
    ``annotation_lock_until`` is stored timezone-aware, otherwise the
    comparison below raises TypeError.

    Args:
        document_id: String UUID of the document.
        admin_token: Deprecated, ignored.
        duration_seconds: Lock lifetime, counted from now.

    Returns:
        The updated, detached document if the lock was acquired; None if the
        document does not exist or is already locked.
    """
    from datetime import datetime, timedelta, timezone

    with get_session_context() as session:
        # SELECT ... FOR UPDATE: without a row lock, two concurrent callers
        # can both see "unlocked" and both believe they acquired the lock.
        # SQLAlchemy omits the clause on dialects without row locking.
        doc = session.get(AdminDocument, UUID(document_id), with_for_update=True)
        if not doc:
            return None

        now = datetime.now(timezone.utc)
        if doc.annotation_lock_until and doc.annotation_lock_until > now:
            # Still locked.
            return None

        doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
        session.add(doc)
        session.commit()
        session.refresh(doc)
        session.expunge(doc)
        return doc
|
|
|
|
def release_annotation_lock(
    self,
    document_id: str,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    force: bool = False,
) -> AdminDocument | None:
    """Clear the annotation lock on a document.

    The lock is cleared unconditionally. Neither ``admin_token`` nor
    ``force`` is consulted, so any caller can release any lock.

    Args:
        document_id: String UUID of the document.
        admin_token: Deprecated, ignored.
        force: Currently has no effect; kept for API compatibility.

    Returns:
        The updated, detached document, or None if no such document exists.
    """
    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if doc:
            doc.annotation_lock_until = None
            session.add(doc)
            session.commit()
            session.refresh(doc)
            session.expunge(doc)
            return doc
        return None
|
|
|
|
def extend_annotation_lock(
    self,
    document_id: str,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    additional_seconds: int = 300,
) -> AdminDocument | None:
    """Push back the expiry of a currently held annotation lock.

    The extension is added to the existing ``annotation_lock_until``, not to
    the current time. ``admin_token`` is ignored.

    Returns:
        The updated, detached document; None if the document does not exist
        or its lock is unset or already expired (compared against
        ``datetime.now(timezone.utc)``).
    """
    from datetime import datetime, timedelta, timezone

    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if not doc:
            return None

        current_expiry = doc.annotation_lock_until
        lock_is_live = bool(current_expiry) and current_expiry > datetime.now(timezone.utc)
        if not lock_is_live:
            return None

        doc.annotation_lock_until = current_expiry + timedelta(seconds=additional_seconds)
        session.add(doc)
        session.commit()
        session.refresh(doc)
        session.expunge(doc)
        return doc
|
|
|
|
# ==========================================================================
|
|
# Phase 4 & 5: Training Data Management and Annotation Enhancement
|
|
# ==========================================================================
|
|
|
|
def get_documents_for_training(
    self,
    admin_token: str | None = None,  # Deprecated, kept for compatibility
    status: str = "labeled",
    has_annotations: bool = True,
    min_annotation_count: int | None = None,
    exclude_used_in_training: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[AdminDocument], int]:
    """Select documents eligible for a training run.

    Args:
        admin_token: Deprecated, ignored.
        status: Required document status (default ``"labeled"``).
        has_annotations: When true, require at least one annotation. When
            false, this filter is simply skipped; it does not select
            unannotated documents.
        min_annotation_count: When truthy, require at least this many
            annotations.
        exclude_used_in_training: When true, drop documents that already have
            a ``TrainingDocumentLink``.
        limit: Page size.
        offset: Pagination offset.

    Returns:
        ``(documents, total)``. Documents are newest ``created_at`` first and
        detached; ``total`` counts all matches regardless of paging.
    """
    from sqlalchemy import exists

    with get_session_context() as session:
        query = select(AdminDocument).where(AdminDocument.status == status)

        if has_annotations or min_annotation_count:
            # Correlated per-document annotation count.
            n_annotations = (
                select(func.count(AdminAnnotation.annotation_id))
                .where(AdminAnnotation.document_id == AdminDocument.document_id)
                .correlate(AdminDocument)
                .scalar_subquery()
            )
            if has_annotations:
                query = query.where(n_annotations > 0)
            if min_annotation_count:
                query = query.where(n_annotations >= min_annotation_count)

        if exclude_used_in_training:
            already_used = exists(
                select(1)
                .select_from(TrainingDocumentLink)
                .where(TrainingDocumentLink.document_id == AdminDocument.document_id)
            )
            query = query.where(~already_used)

        total = session.exec(select(func.count()).select_from(query.subquery())).one()

        page = (
            query.order_by(AdminDocument.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        documents = list(session.exec(page).all())
        for doc in documents:
            session.expunge(doc)
        return documents, total
|
|
|
|
def verify_annotation(
    self,
    annotation_id: str,
    admin_token: str,
) -> AdminAnnotation | None:
    """Mark an annotation as verified by ``admin_token``.

    Sets ``is_verified``, ``verified_at``, ``verified_by`` and bumps
    ``updated_at``, then commits.

    Returns:
        The updated, detached annotation, or None if it does not exist.
    """
    with get_session_context() as session:
        annotation = session.get(AdminAnnotation, UUID(annotation_id))
        if annotation:
            annotation.is_verified = True
            annotation.verified_at = datetime.utcnow()
            annotation.verified_by = admin_token
            annotation.updated_at = datetime.utcnow()
            session.add(annotation)
            session.commit()
            session.refresh(annotation)
            session.expunge(annotation)
            return annotation
        return None
|
|
|
|
def override_annotation(
    self,
    annotation_id: str,
    admin_token: str,
    change_reason: str | None = None,
    **updates: Any,
) -> AdminAnnotation | None:
    """Apply a manual override to an annotation and record it in history.

    Every key in ``updates`` that is an attribute of the annotation is
    assigned. If the annotation's source was ``"auto"`` before the override,
    it is re-labelled ``source="manual"`` with ``override_source="auto"``.
    An ``AnnotationHistory`` row with action ``"override"`` stores a snapshot
    of the prior state and ``updates`` verbatim. That includes keys that were
    not annotation attributes and were therefore ignored.

    Args:
        annotation_id: String UUID of the annotation.
        admin_token: Recorded as ``changed_by`` on the history row.
        change_reason: Optional free-text reason stored on the history row.
        **updates: Annotation attributes to overwrite.

    Returns:
        The updated, detached annotation, or None if it does not exist.
    """
    with get_session_context() as session:
        annotation = session.get(AdminAnnotation, UUID(annotation_id))
        if not annotation:
            return None

        # Snapshot of the state being overridden, for the history row.
        previous_value = {
            "class_id": annotation.class_id,
            "class_name": annotation.class_name,
            "bbox": {
                "x": annotation.bbox_x,
                "y": annotation.bbox_y,
                "width": annotation.bbox_width,
                "height": annotation.bbox_height,
            },
            "normalized": {
                "x_center": annotation.x_center,
                "y_center": annotation.y_center,
                "width": annotation.width,
                "height": annotation.height,
            },
            "text_value": annotation.text_value,
            "confidence": annotation.confidence,
            "source": annotation.source,
        }
        # Decide on the ORIGINAL source. Checking after applying ``updates``
        # let a caller-supplied ``source`` hide (or fake) an auto override.
        was_auto = previous_value["source"] == "auto"

        for key, value in updates.items():
            if hasattr(annotation, key):
                setattr(annotation, key, value)

        if was_auto:
            annotation.override_source = "auto"
            annotation.source = "manual"

        annotation.updated_at = datetime.utcnow()
        session.add(annotation)

        history = AnnotationHistory(
            annotation_id=UUID(annotation_id),
            document_id=annotation.document_id,
            action="override",
            previous_value=previous_value,
            new_value=updates,
            changed_by=admin_token,
            change_reason=change_reason,
        )
        session.add(history)

        session.commit()
        session.refresh(annotation)
        session.expunge(annotation)
        return annotation
|