613 lines
21 KiB
Python
613 lines
21 KiB
Python
"""
|
|
Admin Document Routes
|
|
|
|
FastAPI endpoints for admin document management.
|
|
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
|
|
|
from inference.web.config import DEFAULT_DPI, StorageConfig
|
|
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
|
from inference.web.schemas.admin import (
|
|
AnnotationItem,
|
|
AnnotationSource,
|
|
AutoLabelStatus,
|
|
BoundingBox,
|
|
DocumentDetailResponse,
|
|
DocumentItem,
|
|
DocumentListResponse,
|
|
DocumentStatus,
|
|
DocumentStatsResponse,
|
|
DocumentUploadResponse,
|
|
ModelMetrics,
|
|
TrainingHistoryItem,
|
|
)
|
|
from inference.web.schemas.common import ErrorResponse
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _validate_uuid(value: str, name: str = "ID") -> None:
|
|
"""Validate UUID format."""
|
|
try:
|
|
UUID(value)
|
|
except ValueError:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Invalid {name} format. Must be a valid UUID.",
|
|
)
|
|
|
|
|
|
def _convert_pdf_to_images(
    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
    """Convert PDF pages to PNG images for annotation.

    Pages are written to ``images_dir / document_id`` as ``page_<n>.png``,
    where ``n`` starts at 1.

    Args:
        document_id: Identifier used as the name of the per-document folder.
        content: Raw bytes of the PDF.
        page_count: Number of pages to render, starting from the first page.
        images_dir: Root directory under which the per-document folder is
            created.
        dpi: Rendering resolution; the page is scaled by ``dpi / 72``.

    Raises:
        Exception: Whatever PyMuPDF or the filesystem raises; nothing is
            caught here, so callers decide how to handle failures.
    """
    import fitz

    doc_images_dir = images_dir / document_id
    doc_images_dir.mkdir(parents=True, exist_ok=True)

    # Render at configured DPI for consistency with training.
    # The matrix does not depend on the page, so build it once.
    scale = dpi / 72
    mat = fitz.Matrix(scale, scale)

    # The context manager closes the document even when rendering or saving
    # a page raises part-way through.
    with fitz.open(stream=content, filetype="pdf") as pdf_doc:
        for page_num in range(page_count):
            pix = pdf_doc[page_num].get_pixmap(matrix=mat)
            image_path = doc_images_dir / f"page_{page_num + 1}.png"
            pix.save(str(image_path))
|
|
|
|
|
|
def _annotation_lock_state(record: object) -> tuple[bool, "datetime | None"]:
    """Return ``(can_annotate, lock_until)`` for a document record.

    A record without an ``annotation_lock_until`` attribute, or with a falsy
    value for it, is annotatable and reports ``None`` as its lock time.
    Otherwise the record is annotatable only once the lock time is in the past.

    NOTE(review): the comparison uses a timezone-aware "now" (UTC), so it
    assumes ``annotation_lock_until`` is timezone-aware as well - confirm
    against the model definition.
    """
    from datetime import datetime, timezone

    lock_until = getattr(record, "annotation_lock_until", None)
    if not lock_until:
        return True, None
    return lock_until < datetime.now(timezone.utc), lock_until


def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    """Create the admin documents router.

    The returned router is mounted under ``/admin/documents`` and exposes
    upload, list, stats, detail, delete, status-update and group-key-update
    endpoints. All handlers close over ``storage_config`` for file locations,
    allowed extensions and rendering DPI.

    Args:
        storage_config: Storage settings; this function reads
            ``allowed_extensions``, ``admin_upload_dir``, ``admin_images_dir``
            and ``dpi`` from it.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])

    # Directories are created by StorageConfig.__post_init__
    allowed_extensions = storage_config.allowed_extensions

    @router.post(
        "",
        response_model=DocumentUploadResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid file"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Upload document",
        description="Upload a PDF or image document for labeling.",
    )
    async def upload_document(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        file: UploadFile = File(..., description="PDF or image file"),
        auto_label: Annotated[
            bool,
            Query(description="Trigger auto-labeling after upload"),
        ] = True,
        group_key: Annotated[
            str | None,
            Query(description="Optional group key for document organization", max_length=255),
        ] = None,
    ) -> DocumentUploadResponse:
        """Upload a document for labeling.

        Validates the filename and extension, records the document in the
        database, stores the file under the admin upload directory, renders
        PDF pages to images, and optionally marks the document as
        auto-labeling.

        Raises:
            HTTPException: 400 for an over-long group key, a missing filename,
                an unsupported extension or an unreadable upload; 500 when the
                file cannot be written to disk.
        """
        # Validate group_key length
        if group_key and len(group_key) > 255:
            raise HTTPException(
                status_code=400,
                detail="Group key must be 255 characters or less",
            )

        # Validate filename
        if not file.filename:
            raise HTTPException(status_code=400, detail="Filename is required")

        # Validate extension
        file_ext = Path(file.filename).suffix.lower()
        if file_ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {file_ext}. "
                f"Allowed: {', '.join(allowed_extensions)}",
            )

        # Read file content
        try:
            content = await file.read()
        except Exception as e:
            logger.error(f"Failed to read uploaded file: {e}")
            raise HTTPException(status_code=400, detail="Failed to read file") from e

        # Get page count (for PDF). A failure here is only logged and the
        # upload continues with the default of one page.
        page_count = 1
        if file_ext == ".pdf":
            try:
                import fitz

                # Context manager closes the handle even if len() raises.
                with fitz.open(stream=content, filetype="pdf") as pdf_doc:
                    page_count = len(pdf_doc)
            except Exception as e:
                logger.warning(f"Failed to get PDF page count: {e}")

        # Create document record (token only used for auth, not stored)
        document_id = db.create_document(
            filename=file.filename,
            file_size=len(content),
            content_type=file.content_type or "application/octet-stream",
            file_path="",  # Will update after saving
            page_count=page_count,
            group_key=group_key,
        )

        # Save file to admin uploads
        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
        try:
            file_path.write_bytes(content)
        except Exception as e:
            logger.error(f"Failed to save file: {e}")
            # Best-effort rollback: without it the failed upload would leave
            # a record with an empty file_path behind.
            try:
                db.delete_document(document_id)
            except Exception as cleanup_error:
                logger.error(
                    f"Failed to remove document record {document_id} "
                    f"after save failure: {cleanup_error}"
                )
            raise HTTPException(status_code=500, detail="Failed to save file") from e

        # Update file path in database
        from inference.data.database import get_session_context
        from inference.data.admin_models import AdminDocument

        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if doc:
                doc.file_path = str(file_path)
                session.add(doc)

        # Convert PDF to images for annotation. Failures are logged only;
        # the upload itself still succeeds.
        if file_ext == ".pdf":
            try:
                _convert_pdf_to_images(
                    document_id,
                    content,
                    page_count,
                    storage_config.admin_images_dir,
                    storage_config.dpi,
                )
            except Exception as e:
                logger.error(f"Failed to convert PDF to images: {e}")

        # Trigger auto-labeling if requested
        auto_label_started = False
        if auto_label:
            # Auto-labeling will be triggered by a background task
            db.update_document_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )
            auto_label_started = True

        return DocumentUploadResponse(
            document_id=document_id,
            filename=file.filename,
            file_size=len(content),
            page_count=page_count,
            status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
            group_key=group_key,
            auto_label_started=auto_label_started,
            message="Document uploaded successfully",
        )

    @router.get(
        "",
        response_model=DocumentListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="List documents",
        description="List all documents for the current admin.",
    )
    async def list_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status"),
        ] = None,
        upload_source: Annotated[
            str | None,
            Query(description="Filter by upload source (ui or api)"),
        ] = None,
        has_annotations: Annotated[
            bool | None,
            Query(description="Filter by annotation presence"),
        ] = None,
        auto_label_status: Annotated[
            str | None,
            Query(description="Filter by auto-label status"),
        ] = None,
        batch_id: Annotated[
            str | None,
            Query(description="Filter by batch ID"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> DocumentListResponse:
        """List documents, optionally filtered and paginated.

        Raises:
            HTTPException: 400 when ``status``, ``upload_source`` or
                ``auto_label_status`` is not one of the accepted values.
        """
        # Validate status
        if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}",
            )

        # Validate upload_source
        if upload_source and upload_source not in ("ui", "api"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload_source: {upload_source}",
            )

        # Validate auto_label_status
        if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid auto_label_status: {auto_label_status}",
            )

        documents, total = db.get_documents_by_token(
            admin_token=admin_token,
            status=status,
            upload_source=upload_source,
            has_annotations=has_annotations,
            auto_label_status=auto_label_status,
            batch_id=batch_id,
            limit=limit,
            offset=offset,
        )

        # Get annotation counts and build items
        items = []
        for doc in documents:
            # One annotation lookup per document, used only for the count.
            annotations = db.get_annotations_for_document(str(doc.document_id))

            # Determine if document can be annotated (not locked)
            can_annotate, _ = _annotation_lock_state(doc)

            # Optional attributes fall back to defaults when the record does
            # not carry them.
            batch_ref = getattr(doc, "batch_id", None)

            items.append(
                DocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    file_size=doc.file_size,
                    page_count=doc.page_count,
                    status=DocumentStatus(doc.status),
                    auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
                    annotation_count=len(annotations),
                    upload_source=getattr(doc, "upload_source", "ui"),
                    batch_id=str(batch_ref) if batch_ref else None,
                    group_key=getattr(doc, "group_key", None),
                    can_annotate=can_annotate,
                    created_at=doc.created_at,
                    updated_at=doc.updated_at,
                )
            )

        return DocumentListResponse(
            total=total,
            limit=limit,
            offset=offset,
            documents=items,
        )

    @router.get(
        "/stats",
        response_model=DocumentStatsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get document statistics",
        description="Get document count by status.",
    )
    async def get_document_stats(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DocumentStatsResponse:
        """Get document counts per status, plus the overall total."""
        counts = db.count_documents_by_status(admin_token)

        # Statuses missing from the counts mapping are reported as zero.
        return DocumentStatsResponse(
            total=sum(counts.values()),
            pending=counts.get("pending", 0),
            auto_labeling=counts.get("auto_labeling", 0),
            labeled=counts.get("labeled", 0),
            exported=counts.get("exported", 0),
        )

    @router.get(
        "/{document_id}",
        response_model=DocumentDetailResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Get document detail",
        description="Get document details with annotations.",
    )
    async def get_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DocumentDetailResponse:
        """Get document details, annotations, image URLs and training history.

        Raises:
            HTTPException: 400 when ``document_id`` is not a valid UUID; 404
                when the lookup by id and token returns nothing.
        """
        _validate_uuid(document_id, "document_id")

        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Get annotations
        raw_annotations = db.get_annotations_for_document(document_id)
        annotations = [
            AnnotationItem(
                annotation_id=str(ann.annotation_id),
                page_number=ann.page_number,
                class_id=ann.class_id,
                class_name=ann.class_name,
                bbox=BoundingBox(
                    x=ann.bbox_x,
                    y=ann.bbox_y,
                    width=ann.bbox_width,
                    height=ann.bbox_height,
                ),
                normalized_bbox={
                    "x_center": ann.x_center,
                    "y_center": ann.y_center,
                    "width": ann.width,
                    "height": ann.height,
                },
                text_value=ann.text_value,
                confidence=ann.confidence,
                source=AnnotationSource(ann.source),
                created_at=ann.created_at,
            )
            for ann in raw_annotations
        ]

        # Generate image URLs (pages are numbered from 1)
        image_urls = [
            f"/api/v1/admin/documents/{document_id}/images/{page}"
            for page in range(1, document.page_count + 1)
        ]

        # Determine if document can be annotated (not locked)
        can_annotate, annotation_lock_until = _annotation_lock_state(document)

        # Get CSV field values if available (empty values are reported as None)
        csv_field_values = getattr(document, "csv_field_values", None) or None

        # Get training history (Phase 5)
        training_history = []
        training_links = db.get_document_training_tasks(document.document_id)
        for link in training_links:
            # Get task details; links whose task cannot be loaded are skipped
            task = db.get_training_task(str(link.task_id))
            if not task:
                continue

            # Build metrics. Compare against None rather than truthiness so
            # that a legitimate score of 0.0 is still reported.
            metrics = None
            metric_values = (task.metrics_mAP, task.metrics_precision, task.metrics_recall)
            if any(value is not None for value in metric_values):
                metrics = ModelMetrics(
                    mAP=task.metrics_mAP,
                    precision=task.metrics_precision,
                    recall=task.metrics_recall,
                )

            training_history.append(
                TrainingHistoryItem(
                    task_id=str(link.task_id),
                    name=task.name,
                    trained_at=link.created_at,
                    model_metrics=metrics,
                )
            )

        batch_ref = getattr(document, "batch_id", None)

        return DocumentDetailResponse(
            document_id=str(document.document_id),
            filename=document.filename,
            file_size=document.file_size,
            content_type=document.content_type,
            page_count=document.page_count,
            status=DocumentStatus(document.status),
            auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
            auto_label_error=document.auto_label_error,
            upload_source=getattr(document, "upload_source", "ui"),
            batch_id=str(batch_ref) if batch_ref else None,
            group_key=getattr(document, "group_key", None),
            csv_field_values=csv_field_values,
            can_annotate=can_annotate,
            annotation_lock_until=annotation_lock_until,
            annotations=annotations,
            image_urls=image_urls,
            training_history=training_history,
            created_at=document.created_at,
            updated_at=document.updated_at,
        )

    @router.delete(
        "/{document_id}",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Delete document",
        description="Delete a document and its annotations.",
    )
    async def delete_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Delete a document's stored file, rendered images and database record.

        Raises:
            HTTPException: 400 when ``document_id`` is not a valid UUID; 404
                when the lookup by id and token returns nothing.
        """
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Delete file. An empty file_path must be skipped explicitly:
        # Path("") is ".", which exists but is a directory and cannot be
        # unlinked.
        if document.file_path:
            file_path = Path(document.file_path)
            if file_path.is_file():
                file_path.unlink()

        # Delete images (same location upload_document renders them into)
        images_dir = storage_config.admin_images_dir / document_id
        if images_dir.exists():
            import shutil

            shutil.rmtree(images_dir)

        # Delete from database
        db.delete_document(document_id)

        return {
            "status": "deleted",
            "document_id": document_id,
            "message": "Document deleted successfully",
        }

    @router.patch(
        "/{document_id}/status",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Update document status",
        description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
    )
    async def update_document_status(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str,
            Query(description="New status"),
        ],
    ) -> dict:
        """Update document status.

        When status is set to 'labeled', the annotations are automatically
        saved to PostgreSQL documents/field_results tables for consistency
        with CLI auto-label workflow.

        Raises:
            HTTPException: 400 when ``document_id`` is not a valid UUID or
                ``status`` is not one of ``pending``, ``labeled``,
                ``exported``; 404 when the lookup by id and token returns
                nothing.
        """
        _validate_uuid(document_id, "document_id")

        # Validate status
        if status not in ("pending", "labeled", "exported"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}",
            )

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # If marking as labeled, save annotations to PostgreSQL DocumentDB
        db_save_result = None
        if status == "labeled":
            from inference.web.services.db_autolabel import save_manual_annotations_to_document_db

            # Get all annotations for this document; nothing is saved when
            # there are none.
            annotations = db.get_annotations_for_document(document_id)
            if annotations:
                db_save_result = save_manual_annotations_to_document_db(
                    document=document,
                    annotations=annotations,
                    db=db,
                )

        # The status is updated regardless of the save result above.
        db.update_document_status(document_id, status)

        response = {
            "status": "updated",
            "document_id": document_id,
            "new_status": status,
            "message": "Document status updated",
        }

        # Include PostgreSQL save result if applicable
        if db_save_result:
            response["document_db_saved"] = db_save_result.get("success", False)
            response["fields_saved"] = db_save_result.get("fields_saved", 0)

        return response

    @router.patch(
        "/{document_id}/group-key",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Update document group key",
        description="Update the group key for a document.",
    )
    async def update_document_group_key(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        group_key: Annotated[
            str | None,
            Query(description="New group key (null to clear)"),
        ] = None,
    ) -> dict:
        """Update (or clear, when omitted) the group key of a document.

        Raises:
            HTTPException: 400 when ``document_id`` is not a valid UUID or the
                group key exceeds 255 characters; 404 when the lookup by id
                and token returns nothing.
        """
        _validate_uuid(document_id, "document_id")

        # Validate group_key length
        if group_key and len(group_key) > 255:
            raise HTTPException(
                status_code=400,
                detail="Group key must be 255 characters or less",
            )

        # Verify document exists
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Update group key
        db.update_document_group_key(document_id, group_key)

        return {
            "status": "updated",
            "document_id": document_id,
            "group_key": group_key,
            "message": "Document group key updated",
        }

    return router
|