Remove unused file
This commit is contained in:
@@ -1,622 +0,0 @@
|
||||
"""
|
||||
Admin Training API Routes
|
||||
|
||||
FastAPI endpoints for training task management and scheduling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from src.web.schemas.admin import (
|
||||
ExportRequest,
|
||||
ExportResponse,
|
||||
ModelMetrics,
|
||||
TrainingConfig,
|
||||
TrainingDocumentItem,
|
||||
TrainingDocumentsResponse,
|
||||
TrainingHistoryItem,
|
||||
TrainingLogItem,
|
||||
TrainingLogsResponse,
|
||||
TrainingModelItem,
|
||||
TrainingModelsResponse,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
TrainingTaskDetailResponse,
|
||||
TrainingTaskItem,
|
||||
TrainingTaskListResponse,
|
||||
TrainingTaskResponse,
|
||||
TrainingType,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_training_router() -> APIRouter:
|
||||
"""Create training API router."""
|
||||
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
|
||||
|
||||
# =========================================================================
|
||||
# Training Task Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
    "/tasks",
    response_model=TrainingTaskResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        401: {"model": ErrorResponse, "description": "Invalid token"},
    },
    summary="Create training task",
    description="Create a new training task.",
)
async def create_training_task(
    request: TrainingTaskCreate,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> TrainingTaskResponse:
    """Store a new training task for the calling admin token.

    The task is flagged as recurring whenever a cron expression is supplied.
    The reported status is ``SCHEDULED`` when ``scheduled_at`` is set and
    ``PENDING`` otherwise.
    """
    # A missing config is stored as an empty dict.
    if request.config:
        config_payload = request.config.model_dump()
    else:
        config_payload = {}

    new_task_id = db.create_training_task(
        admin_token=admin_token,
        name=request.name,
        task_type=request.task_type.value,
        description=request.description,
        config=config_payload,
        scheduled_at=request.scheduled_at,
        cron_expression=request.cron_expression,
        is_recurring=bool(request.cron_expression),
    )

    if request.scheduled_at:
        initial_status = TrainingStatus.SCHEDULED
    else:
        initial_status = TrainingStatus.PENDING

    return TrainingTaskResponse(
        task_id=new_task_id,
        status=initial_status,
        message="Training task created successfully",
    )
|
||||
|
||||
@router.get(
    "/tasks",
    response_model=TrainingTaskListResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
    },
    summary="List training tasks",
    description="List all training tasks.",
)
async def list_training_tasks(
    admin_token: AdminTokenDep,
    db: AdminDBDep,
    status: Annotated[str | None, Query(description="Filter by status")] = None,
    limit: Annotated[int, Query(ge=1, le=100, description="Page size")] = 20,
    offset: Annotated[int, Query(ge=0, description="Offset")] = 0,
) -> TrainingTaskListResponse:
    """Return one page of the caller's training tasks.

    Raises:
        HTTPException: 400 when a non-empty ``status`` filter is not one of
            the known task statuses.
    """
    valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
    # An empty or missing filter skips validation and is passed through as-is.
    if status and status not in valid_statuses:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
        )

    rows, total = db.get_training_tasks_by_token(
        admin_token=admin_token,
        status=status,
        limit=limit,
        offset=offset,
    )

    # Convert stored rows into response items.
    page = []
    for row in rows:
        page.append(
            TrainingTaskItem(
                task_id=str(row.task_id),
                name=row.name,
                task_type=TrainingType(row.task_type),
                status=TrainingStatus(row.status),
                scheduled_at=row.scheduled_at,
                is_recurring=row.is_recurring,
                started_at=row.started_at,
                completed_at=row.completed_at,
                created_at=row.created_at,
            )
        )

    return TrainingTaskListResponse(total=total, limit=limit, offset=offset, tasks=page)
|
||||
|
||||
@router.get(
    "/tasks/{task_id}",
    response_model=TrainingTaskDetailResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
        404: {"model": ErrorResponse, "description": "Task not found"},
    },
    summary="Get training task detail",
    description="Get training task details.",
)
async def get_training_task(
    task_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> TrainingTaskDetailResponse:
    """Return the full record of one training task.

    Raises:
        HTTPException: 400 for a malformed ``task_id``; 404 when the lookup
            by id and token returns nothing.
    """
    _validate_uuid(task_id, "task_id")

    record = db.get_training_task_by_token(task_id, admin_token)
    if record is None:
        raise HTTPException(
            status_code=404,
            detail="Training task not found or does not belong to this token",
        )

    detail_fields = {
        "task_id": str(record.task_id),
        "name": record.name,
        "description": record.description,
        "task_type": TrainingType(record.task_type),
        "status": TrainingStatus(record.status),
        "config": record.config,
        "scheduled_at": record.scheduled_at,
        "cron_expression": record.cron_expression,
        "is_recurring": record.is_recurring,
        "started_at": record.started_at,
        "completed_at": record.completed_at,
        "error_message": record.error_message,
        "result_metrics": record.result_metrics,
        "model_path": record.model_path,
        "created_at": record.created_at,
    }
    return TrainingTaskDetailResponse(**detail_fields)
|
||||
|
||||
@router.post(
    "/tasks/{task_id}/cancel",
    response_model=TrainingTaskResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
        404: {"model": ErrorResponse, "description": "Task not found"},
        409: {"model": ErrorResponse, "description": "Cannot cancel task"},
    },
    summary="Cancel training task",
    description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
    task_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> TrainingTaskResponse:
    """Cancel a task that has not started running yet.

    Raises:
        HTTPException: 400 for a malformed ``task_id``; 404 when the task is
            not found for this token; 409 when its status is neither
            ``pending`` nor ``scheduled``; 500 when the database reports
            that the cancellation did not succeed.
    """
    _validate_uuid(task_id, "task_id")

    # Ownership check: the lookup is scoped to the caller's token.
    existing = db.get_training_task_by_token(task_id, admin_token)
    if existing is None:
        raise HTTPException(
            status_code=404,
            detail="Training task not found or does not belong to this token",
        )

    cancellable = existing.status in ("pending", "scheduled")
    if not cancellable:
        raise HTTPException(
            status_code=409,
            detail=f"Cannot cancel task with status: {existing.status}",
        )

    if not db.cancel_training_task(task_id):
        raise HTTPException(
            status_code=500,
            detail="Failed to cancel training task",
        )

    return TrainingTaskResponse(
        task_id=task_id,
        status=TrainingStatus.CANCELLED,
        message="Training task cancelled successfully",
    )
|
||||
|
||||
@router.get(
    "/tasks/{task_id}/logs",
    response_model=TrainingLogsResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
        404: {"model": ErrorResponse, "description": "Task not found"},
    },
    summary="Get training logs",
    description="Get training task logs.",
)
async def get_training_logs(
    task_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
    limit: Annotated[int, Query(ge=1, le=500, description="Maximum logs to return")] = 100,
    offset: Annotated[int, Query(ge=0, description="Offset")] = 0,
) -> TrainingLogsResponse:
    """Return a window of log entries for one training task.

    Raises:
        HTTPException: 400 for a malformed ``task_id``; 404 when the task is
            not found for this token.
    """
    _validate_uuid(task_id, "task_id")

    # The token-scoped lookup serves as the ownership check.
    owner_check = db.get_training_task_by_token(task_id, admin_token)
    if owner_check is None:
        raise HTTPException(
            status_code=404,
            detail="Training task not found or does not belong to this token",
        )

    entries = []
    for entry in db.get_training_logs(task_id, limit, offset):
        entries.append(
            TrainingLogItem(
                level=entry.level,
                message=entry.message,
                details=entry.details,
                created_at=entry.created_at,
            )
        )

    return TrainingLogsResponse(task_id=task_id, logs=entries)
|
||||
|
||||
# =========================================================================
|
||||
# Phase 4: Training Data Management
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
    "/documents",
    response_model=TrainingDocumentsResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
    },
    summary="Get documents for training",
    description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
    admin_token: AdminTokenDep,
    db: AdminDBDep,
    has_annotations: Annotated[
        bool, Query(description="Only include documents with annotations")
    ] = True,
    min_annotation_count: Annotated[
        int | None, Query(ge=1, description="Minimum annotation count")
    ] = None,
    exclude_used_in_training: Annotated[
        bool, Query(description="Exclude documents already used in training")
    ] = False,
    limit: Annotated[int, Query(ge=1, le=100, description="Page size")] = 100,
    offset: Annotated[int, Query(ge=0, description="Offset")] = 0,
) -> TrainingDocumentsResponse:
    """Return one page of labeled documents that can be used for training.

    Each item carries the document's annotation count, a per-source tally
    restricted to the ``manual`` and ``auto`` sources, and the ids of the
    training tasks linked to the document.
    """
    page, total = db.get_documents_for_training(
        admin_token=admin_token,
        status="labeled",
        has_annotations=has_annotations,
        min_annotation_count=min_annotation_count,
        exclude_used_in_training=exclude_used_in_training,
        limit=limit,
        offset=offset,
    )

    result_items = []
    for document in page:
        # Annotation and training-link lookups run once per document.
        doc_annotations = db.get_annotations_for_document(str(document.document_id))

        # Only the two tracked sources are tallied; others are ignored.
        source_tally = {"manual": 0, "auto": 0}
        for annotation in doc_annotations:
            if annotation.source in source_tally:
                source_tally[annotation.source] += 1

        links = db.get_document_training_tasks(document.document_id)
        linked_task_ids = [str(link.task_id) for link in links]

        item = TrainingDocumentItem(
            document_id=str(document.document_id),
            filename=document.filename,
            annotation_count=len(doc_annotations),
            annotation_sources=source_tally,
            used_in_training=linked_task_ids,
            last_modified=document.updated_at,
        )
        result_items.append(item)

    return TrainingDocumentsResponse(
        total=total,
        limit=limit,
        offset=offset,
        documents=result_items,
    )
|
||||
|
||||
@router.get(
    "/models/{task_id}/download",
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
        404: {"model": ErrorResponse, "description": "Model not found"},
    },
    summary="Download trained model",
    description="Download trained model weights file.",
)
async def download_model(
    task_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
):
    """Stream the weights file produced by a training task.

    Returns:
        A ``FileResponse`` for the file at the task's ``model_path``, offered
        for download as ``<task name>_model.pt``.

    Raises:
        HTTPException: 400 for a malformed ``task_id``; 404 when the task is
            not found for this token, has no ``model_path``, or the path is
            not a regular file on disk.
    """
    from fastapi.responses import FileResponse
    from pathlib import Path

    _validate_uuid(task_id, "task_id")

    # The token-scoped lookup serves as the ownership check.
    task = db.get_training_task_by_token(task_id, admin_token)
    if task is None:
        raise HTTPException(
            status_code=404,
            detail="Training task not found or does not belong to this token",
        )

    if not task.model_path:
        raise HTTPException(
            status_code=404,
            detail="Model file not available for this task",
        )

    model_path = Path(task.model_path)
    # ``is_file`` rather than ``exists``: a directory at this path would pass
    # an existence check and then fail inside FileResponse.
    if not model_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Model file not found on disk",
        )

    return FileResponse(
        path=str(model_path),
        media_type="application/octet-stream",
        filename=f"{task.name}_model.pt",
    )
|
||||
|
||||
@router.get(
    "/models",
    response_model=TrainingModelsResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        401: {"model": ErrorResponse, "description": "Invalid token"},
    },
    summary="Get trained models",
    description="Get list of trained models with metrics and download links.",
)
async def get_training_models(
    admin_token: AdminTokenDep,
    db: AdminDBDep,
    status: Annotated[
        str | None,
        Query(description="Filter by status (completed, failed, etc.)"),
    ] = None,
    limit: Annotated[
        int,
        Query(ge=1, le=100, description="Page size"),
    ] = 20,
    offset: Annotated[
        int,
        Query(ge=0, description="Offset"),
    ] = 0,
) -> TrainingModelsResponse:
    """Return one page of training tasks presented as trained models.

    When no ``status`` filter is given, only ``completed`` tasks are listed.
    A download URL is attached only to completed tasks that have a
    ``model_path``.

    Raises:
        HTTPException: 400 when a non-empty ``status`` filter is not one of
            the known task statuses (same rule as the task list endpoint).
    """
    valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
    if status and status not in valid_statuses:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
        )

    tasks, total = db.get_training_tasks_by_token(
        admin_token=admin_token,
        status=status or "completed",
        limit=limit,
        offset=offset,
    )

    items = []
    for task in tasks:
        metrics = ModelMetrics(
            mAP=task.metrics_mAP,
            precision=task.metrics_precision,
            recall=task.metrics_recall,
        )

        # NOTE(review): the "/api/v1" prefix is hard-coded here; confirm it
        # matches where this router is actually mounted.
        download_url = None
        if task.model_path and task.status == "completed":
            download_url = f"/api/v1/admin/training/models/{task.task_id}/download"

        items.append(
            TrainingModelItem(
                task_id=str(task.task_id),
                name=task.name,
                status=TrainingStatus(task.status),
                document_count=task.document_count,
                created_at=task.created_at,
                completed_at=task.completed_at,
                metrics=metrics,
                model_path=task.model_path,
                download_url=download_url,
            )
        )

    return TrainingModelsResponse(
        total=total,
        limit=limit,
        offset=offset,
        models=items,
    )
|
||||
|
||||
# =========================================================================
|
||||
# Export Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
    "/export",
    response_model=ExportResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        401: {"model": ErrorResponse, "description": "Invalid token"},
    },
    summary="Export annotations",
    description="Export annotations in YOLO format for training.",
)
async def export_annotations(
    request: ExportRequest,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> ExportResponse:
    """Write the caller's labeled documents to disk as a YOLO dataset.

    Creates ``data/exports/export_<UTC timestamp>/`` with ``images`` and
    ``labels`` folders (each split into ``train`` and ``val``) plus a
    ``data.yaml``. Documents are split in listing order: the first
    ``int(total * split_ratio)`` go to ``train``, the rest to ``val``.
    Documents without annotations, and pages whose source image is missing,
    are skipped.

    Returns:
        An ``ExportResponse`` with the export path, the number of images and
        annotation lines written, and the per-split document counts.

    Raises:
        HTTPException: 400 for an unsupported ``format`` or when there are
            no labeled documents to export.
    """
    from collections import defaultdict
    from datetime import timezone
    from pathlib import Path
    import shutil

    # NOTE(review): "coco" and "voc" pass this check, but the code below
    # always writes YOLO layout and labels. Confirm whether those formats
    # should be rejected or implemented.
    if request.format not in ("yolo", "coco", "voc"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported export format: {request.format}",
        )

    documents = db.get_labeled_documents_for_export(admin_token)
    if not documents:
        raise HTTPException(
            status_code=400,
            detail="No labeled documents available for export",
        )

    # Timezone-aware replacement for the deprecated ``datetime.utcnow()``;
    # the formatted directory name is unchanged.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    export_dir = Path("data/exports") / f"export_{stamp}"
    export_dir.mkdir(parents=True, exist_ok=True)

    # YOLO directory layout.
    for kind in ("images", "labels"):
        for split in ("train", "val"):
            (export_dir / kind / split).mkdir(parents=True, exist_ok=True)

    # Train/val split by position in the document list.
    total_docs = len(documents)
    train_count = int(total_docs * request.split_ratio)
    train_docs = documents[:train_count]
    val_docs = documents[train_count:]

    total_images = 0
    total_annotations = 0

    for split, docs in (("train", train_docs), ("val", val_docs)):
        for doc in docs:
            annotations = db.get_annotations_for_document(str(doc.document_id))
            if not annotations:
                continue

            # Group once per document instead of rescanning the full list
            # for every page.
            by_page = defaultdict(list)
            for ann in annotations:
                by_page[ann.page_number].append(ann)

            for page_num in range(1, doc.page_count + 1):
                page_annotations = by_page.get(page_num, [])

                # Pages without annotations are exported only when images
                # were requested.
                if not page_annotations and not request.include_images:
                    continue

                src_image = (
                    Path("data/admin_images")
                    / str(doc.document_id)
                    / f"page_{page_num}.png"
                )
                if not src_image.exists():
                    continue

                image_name = f"{doc.document_id}_page{page_num}.png"
                shutil.copy(src_image, export_dir / "images" / split / image_name)
                total_images += 1

                # YOLO label line: class_id x_center y_center width height
                label_name = f"{doc.document_id}_page{page_num}.txt"
                label_path = export_dir / "labels" / split / label_name
                with open(label_path, "w", encoding="utf-8") as f:
                    f.writelines(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}\n"
                        for ann in page_annotations
                    )
                total_annotations += len(page_annotations)

    # Dataset config consumed by the YOLO trainer.
    from src.data.admin_models import FIELD_CLASSES

    yaml_lines = [
        "# Auto-generated YOLO dataset config",
        f"path: {export_dir.absolute()}",
        "train: images/train",
        "val: images/val",
        "",
        f"nc: {len(FIELD_CLASSES)}",
        f"names: {list(FIELD_CLASSES.values())}",
    ]
    (export_dir / "data.yaml").write_text(
        "\n".join(yaml_lines) + "\n", encoding="utf-8"
    )

    return ExportResponse(
        status="completed",
        export_path=str(export_dir),
        total_images=total_images,
        total_annotations=total_annotations,
        train_count=len(train_docs),
        val_count=len(val_docs),
        message=f"Exported {total_images} images with {total_annotations} annotations",
    )
|
||||
|
||||
return router
|
||||
@@ -1,539 +0,0 @@
|
||||
"""
|
||||
Admin API Request/Response Schemas
|
||||
|
||||
Pydantic models for admin API validation and serialization.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Enums
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
    """Status values a document can carry; members are also plain strings."""

    PENDING = "pending"
    AUTO_LABELING = "auto_labeling"
    LABELED = "labeled"
    EXPORTED = "exported"
|
||||
|
||||
|
||||
class AutoLabelStatus(str, Enum):
    """Status values for an auto-labeling run; members are also plain strings."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class TrainingStatus(str, Enum):
    """Status values for a training task; members are also plain strings."""

    PENDING = "pending"
    SCHEDULED = "scheduled"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class TrainingType(str, Enum):
    """Kinds of training task; members are also plain strings."""

    TRAIN = "train"
    FINETUNE = "finetune"
|
||||
|
||||
|
||||
class AnnotationSource(str, Enum):
    """Origins an annotation can have; members are also plain strings."""

    MANUAL = "manual"
    AUTO = "auto"
    IMPORTED = "imported"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auth Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AdminTokenCreate(BaseModel):
    """Payload for issuing a new admin token.

    ``name`` is required (1-255 characters). ``expires_in_days`` is optional
    and, when given, must be between 1 and 365.
    """

    name: str = Field(
        ...,
        min_length=1,
        max_length=255,
        description="Token name",
    )
    expires_in_days: int | None = Field(
        default=None,
        ge=1,
        le=365,
        description="Token expiration in days (optional)",
    )
|
||||
|
||||
|
||||
class AdminTokenResponse(BaseModel):
    """Result of creating an admin token.

    ``expires_at`` defaults to ``None``; all other fields are required.
    """

    token: str = Field(..., description="Admin token")
    name: str = Field(..., description="Token name")
    expires_at: datetime | None = Field(
        default=None,
        description="Token expiration time",
    )
    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
    """Result of uploading a document.

    ``file_size`` must be non-negative and ``page_count`` at least 1.
    ``auto_label_started`` defaults to ``False``.
    """

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_started: bool = Field(
        default=False,
        description="Whether auto-labeling was started",
    )
    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class DocumentItem(BaseModel):
    """One row of a document listing.

    Identity, size, status and both timestamps are required. The remaining
    fields have defaults: no auto-label status, zero annotations, upload
    source ``"ui"``, no batch, and ``can_annotate`` true.
    """

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        default=None,
        description="Auto-labeling status",
    )
    annotation_count: int = Field(
        default=0,
        ge=0,
        description="Number of annotations",
    )
    upload_source: str = Field(
        default="ui",
        description="Upload source (ui or api)",
    )
    batch_id: str | None = Field(
        default=None,
        description="Batch ID if uploaded via batch",
    )
    can_annotate: bool = Field(
        default=True,
        description="Whether document can be annotated",
    )
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class DocumentListResponse(BaseModel):
    """One page of documents together with paging information.

    ``documents`` defaults to an empty list.
    """

    total: int = Field(..., ge=0, description="Total documents")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    documents: list[DocumentItem] = Field(
        default_factory=list,
        description="Document list",
    )
|
||||
|
||||
|
||||
class DocumentDetailResponse(BaseModel):
    """Full detail view of a single document.

    Beyond the listing fields, this carries the MIME type, auto-label error,
    batch CSV values, the annotation lock expiry, and three list fields
    (annotations, page image URLs, training history) that default to empty.
    ``AnnotationItem`` and ``TrainingHistoryItem`` are referenced by name as
    forward references.
    """

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    content_type: str = Field(..., description="MIME type")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        default=None,
        description="Auto-labeling status",
    )
    auto_label_error: str | None = Field(
        default=None,
        description="Auto-labeling error",
    )
    upload_source: str = Field(
        default="ui",
        description="Upload source (ui or api)",
    )
    batch_id: str | None = Field(
        default=None,
        description="Batch ID if uploaded via batch",
    )
    csv_field_values: dict[str, str] | None = Field(
        default=None,
        description="CSV field values if uploaded via batch",
    )
    can_annotate: bool = Field(
        default=True,
        description="Whether document can be annotated",
    )
    annotation_lock_until: datetime | None = Field(
        default=None,
        description="Lock expiration time if document is locked",
    )
    annotations: list["AnnotationItem"] = Field(
        default_factory=list,
        description="Document annotations",
    )
    image_urls: list[str] = Field(
        default_factory=list,
        description="URLs to page images",
    )
    training_history: list["TrainingHistoryItem"] = Field(
        default_factory=list,
        description="Training tasks that used this document",
    )
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class DocumentStatsResponse(BaseModel):
    """Document counts, overall and per status.

    ``total`` is required; each per-status count defaults to 0. All counts
    must be non-negative.
    """

    total: int = Field(..., ge=0, description="Total documents")
    pending: int = Field(
        default=0,
        ge=0,
        description="Pending documents",
    )
    auto_labeling: int = Field(
        default=0,
        ge=0,
        description="Auto-labeling documents",
    )
    labeled: int = Field(
        default=0,
        ge=0,
        description="Labeled documents",
    )
    exported: int = Field(
        default=0,
        ge=0,
        description="Exported documents",
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Annotation Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
    """Pixel-space rectangle.

    ``x`` and ``y`` must be non-negative; ``width`` and ``height`` must be
    at least 1.
    """

    x: int = Field(
        ...,
        ge=0,
        description="X coordinate (pixels)",
    )
    y: int = Field(
        ...,
        ge=0,
        description="Y coordinate (pixels)",
    )
    width: int = Field(
        ...,
        ge=1,
        description="Width (pixels)",
    )
    height: int = Field(
        ...,
        ge=1,
        description="Height (pixels)",
    )
|
||||
|
||||
|
||||
class AnnotationCreate(BaseModel):
    """Payload for adding an annotation to a document page.

    ``page_number`` is 1-indexed and defaults to 1. ``class_id`` must be
    between 0 and 9. ``text_value`` is optional.
    """

    page_number: int = Field(
        default=1,
        ge=1,
        description="Page number (1-indexed)",
    )
    class_id: int = Field(
        ...,
        ge=0,
        le=9,
        description="Class ID (0-9)",
    )
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    text_value: str | None = Field(
        default=None,
        description="Text value (optional)",
    )
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
    """Payload for modifying an annotation; every field defaults to ``None``."""

    class_id: int | None = Field(
        default=None,
        ge=0,
        le=9,
        description="New class ID",
    )
    bbox: BoundingBox | None = Field(
        default=None,
        description="New bounding box",
    )
    text_value: str | None = Field(
        default=None,
        description="New text value",
    )
|
||||
|
||||
|
||||
class AnnotationItem(BaseModel):
    """One annotation as returned by the API.

    Carries both the pixel bounding box and a normalized box given as a
    string-to-float mapping. ``text_value`` and ``confidence`` are optional;
    a supplied ``confidence`` must lie in [0, 1].
    """

    annotation_id: str = Field(..., description="Annotation UUID")
    page_number: int = Field(..., ge=1, description="Page number")
    class_id: int = Field(..., ge=0, le=9, description="Class ID")
    class_name: str = Field(..., description="Class name")
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    normalized_bbox: dict[str, float] = Field(
        ...,
        description="Normalized bbox (x_center, y_center, width, height)",
    )
    text_value: str | None = Field(
        default=None,
        description="Text value",
    )
    confidence: float | None = Field(
        default=None,
        ge=0,
        le=1,
        description="Confidence score",
    )
    source: AnnotationSource = Field(..., description="Annotation source")
    created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class AnnotationResponse(BaseModel):
    """Result of an annotation operation: the annotation id and a message."""

    annotation_id: str = Field(
        ...,
        description="Annotation UUID",
    )
    message: str = Field(
        ...,
        description="Status message",
    )
|
||||
|
||||
|
||||
class AnnotationListResponse(BaseModel):
    """Annotations of one document plus page and annotation totals.

    ``annotations`` defaults to an empty list.
    """

    document_id: str = Field(..., description="Document UUID")
    page_count: int = Field(..., ge=1, description="Total pages")
    total_annotations: int = Field(
        ...,
        ge=0,
        description="Total annotations",
    )
    annotations: list[AnnotationItem] = Field(
        default_factory=list,
        description="Annotation list",
    )
|
||||
|
||||
|
||||
class AnnotationLockRequest(BaseModel):
    """Payload for acquiring an annotation lock.

    ``duration_seconds`` defaults to 300 and must be between 60 and 3600.
    """

    duration_seconds: int = Field(
        default=300, ge=60, le=3600, description="Lock duration in seconds (60-3600)"
    )
|
||||
|
||||
|
||||
class AnnotationLockResponse(BaseModel):
    """Result of acquiring or releasing an annotation lock.

    ``lock_expires_at`` defaults to ``None``.
    """

    document_id: str = Field(..., description="Document UUID")
    locked: bool = Field(
        ...,
        description="Whether lock was acquired/released",
    )
    lock_expires_at: datetime | None = Field(
        default=None,
        description="Lock expiration time",
    )
    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AutoLabelRequest(BaseModel):
    """Payload for triggering auto-labeling on a document.

    ``field_values`` is a required string-to-string mapping.
    ``replace_existing`` defaults to ``False``.
    """

    field_values: dict[str, str] = Field(
        ..., description="Field values to match (e.g., {'invoice_number': '12345'})"
    )
    replace_existing: bool = Field(
        default=False,
        description="Replace existing auto annotations",
    )
|
||||
replace_existing: bool = Field(
|
||||
default=False, description="Replace existing auto annotations"
|
||||
)
|
||||
|
||||
|
||||
class AutoLabelResponse(BaseModel):
    """Result of an auto-labeling request.

    ``status`` is a free-form string here. ``annotations_created`` defaults
    to 0 and must be non-negative.
    """

    document_id: str = Field(..., description="Document UUID")
    status: str = Field(..., description="Auto-labeling status")
    annotations_created: int = Field(
        default=0,
        ge=0,
        description="Number of annotations created",
    )
    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Training Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
    """Training configuration."""

    # Pydantic v2 reserves the ``model_`` prefix for its own API and (on
    # 2.0-2.9) emits a UserWarning for fields that use it. ``model_name`` is
    # an intentional, public field name, so opt out of the namespace check
    # instead of renaming the field. A plain dict is used so that no extra
    # import is required.
    model_config = {"protected_namespaces": ()}

    # Every field is optional; the defaults below describe a complete run.
    model_name: str = Field(default="yolo11n.pt", description="Base model name")
    # Bounded hyper-parameters: validation rejects values outside each range.
    epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
    batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
    image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
    # Strictly positive (gt=0) and at most 1.
    learning_rate: float = Field(
        default=0.01, gt=0, le=1, description="Learning rate"
    )
    # Kept as a string so both a GPU index ("0") and "cpu" are representable.
    device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
    project_name: str = Field(
        default="invoice_fields", description="Training project name"
    )
|
||||
|
||||
|
||||
class TrainingTaskCreate(BaseModel):
    """Request to create a training task."""

    # Only required field: a non-empty name of at most 255 characters.
    name: str = Field(..., min_length=1, max_length=255, description="Task name")
    description: str | None = Field(
        default=None,
        max_length=1000,
        description="Description",
    )
    task_type: TrainingType = Field(
        default=TrainingType.TRAIN,
        description="Task type",
    )
    # When omitted, a fresh TrainingConfig (all defaults) is built per instance.
    config: TrainingConfig = Field(
        default_factory=TrainingConfig,
        description="Training configuration",
    )
    # Both scheduling fields are optional and are not cross-validated here.
    scheduled_at: datetime | None = Field(
        default=None,
        description="Scheduled execution time",
    )
    cron_expression: str | None = Field(
        default=None,
        max_length=50,
        description="Cron expression for recurring tasks",
    )
|
||||
|
||||
|
||||
class TrainingTaskItem(BaseModel):
    """Single training task in list."""

    # Identity and classification (all required).
    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")

    # Scheduling information (optional).
    scheduled_at: datetime | None = Field(default=None, description="Scheduled time")
    is_recurring: bool = Field(default=False, description="Is recurring task")

    # Lifecycle timestamps; only the creation time is mandatory.
    started_at: datetime | None = Field(default=None, description="Start time")
    completed_at: datetime | None = Field(default=None, description="Completion time")
    created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class TrainingTaskListResponse(BaseModel):
    """Response for training task list."""

    # Pagination envelope: all three counters are required.
    total: int = Field(..., ge=0, description="Total tasks")
    # Note: only a lower bound is enforced on the page size here.
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    # Page payload; an empty list is created per instance when omitted.
    tasks: list[TrainingTaskItem] = Field(
        default_factory=list,
        description="Task list",
    )
|
||||
|
||||
|
||||
class TrainingTaskDetailResponse(BaseModel):
    """Response for training task detail."""

    # Pydantic v2 reserves the ``model_`` prefix for its own API and (on
    # 2.0-2.9) emits a UserWarning for fields that use it. ``model_path`` is
    # part of the public response shape, so opt out of the namespace check
    # instead of renaming the field. A plain dict is used so that no extra
    # import is required.
    model_config = {"protected_namespaces": ()}

    # Identity and classification (required unless noted).
    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    description: str | None = Field(default=None, description="Description")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")

    # Configuration is exposed as an untyped mapping rather than a
    # TrainingConfig instance.
    config: dict[str, Any] | None = Field(
        default=None, description="Training configuration"
    )

    # Scheduling information (optional).
    scheduled_at: datetime | None = Field(default=None, description="Scheduled time")
    cron_expression: str | None = Field(default=None, description="Cron expression")
    is_recurring: bool = Field(default=False, description="Is recurring task")

    # Execution outcome; every field is optional and defaults to None.
    started_at: datetime | None = Field(default=None, description="Start time")
    completed_at: datetime | None = Field(default=None, description="Completion time")
    error_message: str | None = Field(default=None, description="Error message")
    result_metrics: dict[str, Any] | None = Field(
        default=None, description="Result metrics"
    )
    model_path: str | None = Field(default=None, description="Trained model path")

    created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class TrainingTaskResponse(BaseModel):
    """Response for training task operation."""

    # All three fields are required: which task, its state, and a status text.
    task_id: str = Field(
        ...,
        description="Task UUID",
    )
    status: TrainingStatus = Field(
        ...,
        description="Task status",
    )
    message: str = Field(
        ...,
        description="Status message",
    )
|
||||
|
||||
|
||||
class TrainingLogItem(BaseModel):
    """Single training log entry."""

    # Level is a plain string; this schema does not restrict it to known names.
    level: str = Field(..., description="Log level")
    message: str = Field(..., description="Log message")
    # Optional structured payload with arbitrary values.
    details: dict[str, Any] | None = Field(
        default=None,
        description="Additional details",
    )
    created_at: datetime = Field(..., description="Timestamp")
|
||||
|
||||
|
||||
class TrainingLogsResponse(BaseModel):
    """Response for training logs."""

    # Task the entries belong to (required).
    task_id: str = Field(..., description="Task UUID")
    # Entries for that task; an empty list is created per instance when omitted.
    logs: list[TrainingLogItem] = Field(
        default_factory=list,
        description="Log entries",
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
    """Request to export annotations."""

    # NOTE(review): typed as a plain str, so the values named in the
    # description are not enforced by this schema.
    format: str = Field(default="yolo", description="Export format (yolo, coco, voc)")
    include_images: bool = Field(default=True, description="Include images in export")
    # Accepted range is 0.5..1.0 inclusive, defaulting to 0.8.
    split_ratio: float = Field(
        default=0.8,
        ge=0.5,
        le=1.0,
        description="Train/val split ratio",
    )
|
||||
|
||||
|
||||
class ExportResponse(BaseModel):
    """Response for export operation."""

    # Every field is required.
    status: str = Field(..., description="Export status")
    export_path: str = Field(..., description="Path to exported dataset")

    # Counters are validated as non-negative. No cross-field check ties
    # train_count + val_count to total_images.
    total_images: int = Field(..., ge=0, description="Total images exported")
    total_annotations: int = Field(..., ge=0, description="Total annotations")
    train_count: int = Field(..., ge=0, description="Training set count")
    val_count: int = Field(..., ge=0, description="Validation set count")

    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Phase 4 & 5: Training Data Management and Annotation Enhancement
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TrainingDocumentItem(BaseModel):
    """Document item for training page."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Filename")
    annotation_count: int = Field(..., ge=0, description="Total annotations")
    # Required mapping of source name -> count.
    annotation_sources: dict[str, int] = Field(
        ...,
        description="Annotation counts by source (manual, auto)",
    )
    # Task IDs as strings; an empty list is created per instance when omitted.
    used_in_training: list[str] = Field(
        default_factory=list,
        description="List of training task IDs that used this document",
    )
    last_modified: datetime = Field(..., description="Last modification time")
|
||||
|
||||
|
||||
class TrainingDocumentsResponse(BaseModel):
    """Response for GET /admin/training/documents."""

    # Pagination envelope; the page size is capped at 100.
    total: int = Field(..., ge=0, description="Total document count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    # Page payload; an empty list is created per instance when omitted.
    documents: list[TrainingDocumentItem] = Field(
        default_factory=list,
        description="Documents available for training",
    )
|
||||
|
||||
|
||||
class ModelMetrics(BaseModel):
    """Training model metrics."""

    # Each metric is optional (None when absent) and, when present, must lie
    # in the closed interval [0.0, 1.0].
    mAP: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Mean Average Precision"
    )
    precision: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Precision"
    )
    recall: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Recall"
    )
|
||||
|
||||
|
||||
class TrainingModelItem(BaseModel):
    """Trained model item for model list."""

    # Pydantic v2 reserves the ``model_`` prefix for its own API and (on
    # 2.0-2.9) emits a UserWarning for fields that use it. ``model_path`` is
    # part of the public response shape, so opt out of the namespace check
    # instead of renaming the field. A plain dict is used so that no extra
    # import is required.
    model_config = {"protected_namespaces": ()}

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Model name")
    status: TrainingStatus = Field(..., description="Training status")
    document_count: int = Field(..., ge=0, description="Documents used in training")

    created_at: datetime = Field(..., description="Creation timestamp")
    completed_at: datetime | None = Field(
        default=None, description="Completion timestamp"
    )

    # The metrics object itself is required, although each value inside it
    # may be None.
    metrics: ModelMetrics = Field(..., description="Model metrics")

    # Artifact locations; both are optional and default to None.
    model_path: str | None = Field(default=None, description="Path to model weights")
    download_url: str | None = Field(
        default=None, description="Download URL for model"
    )
|
||||
|
||||
|
||||
class TrainingModelsResponse(BaseModel):
    """Response for GET /admin/training/models."""

    # Pagination envelope; the page size is capped at 100.
    total: int = Field(..., ge=0, description="Total model count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    # Page payload; an empty list is created per instance when omitted.
    models: list[TrainingModelItem] = Field(
        default_factory=list,
        description="Trained models",
    )
|
||||
|
||||
|
||||
class AnnotationVerifyRequest(BaseModel):
    """Request to verify an annotation."""

    # Intentionally field-less: no body needed, just POST to verify.
|
||||
|
||||
|
||||
class AnnotationVerifyResponse(BaseModel):
    """Response for annotation verification."""

    # Every field is required.
    annotation_id: str = Field(..., description="Annotation UUID")
    is_verified: bool = Field(..., description="Verification status")

    # Audit information: when the verification happened and who performed it.
    verified_at: datetime = Field(..., description="Verification timestamp")
    verified_by: str = Field(..., description="Admin token who verified")

    message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class AnnotationOverrideRequest(BaseModel):
    """Request to override an annotation."""

    # Every field is optional and defaults to None.
    # The bbox mapping holds integer values; its keys are not validated here.
    bbox: dict[str, int] | None = Field(
        default=None,
        description="Updated bounding box {x, y, width, height}",
    )
    text_value: str | None = Field(default=None, description="Updated text value")
    # Class IDs are restricted to the inclusive range 0..9.
    class_id: int | None = Field(
        default=None,
        ge=0,
        le=9,
        description="Updated class ID",
    )
    class_name: str | None = Field(default=None, description="Updated class name")
    reason: str | None = Field(default=None, description="Reason for override")
|
||||
|
||||
|
||||
class AnnotationOverrideResponse(BaseModel):
    """Response for annotation override."""

    annotation_id: str = Field(..., description="Annotation UUID")
    source: str = Field(..., description="New source (manual)")
    # Provenance of the overridden annotation; both default to None.
    override_source: str | None = Field(
        default=None,
        description="Original source (auto)",
    )
    original_annotation_id: str | None = Field(
        default=None,
        description="Original annotation ID",
    )
    message: str = Field(..., description="Status message")
    # Required identifier of the history record.
    history_id: str = Field(..., description="History record UUID")
|
||||
|
||||
|
||||
class TrainingHistoryItem(BaseModel):
    """Training history for a document."""

    # Pydantic v2 reserves the ``model_`` prefix for its own API and (on
    # 2.0-2.9) emits a UserWarning for fields that use it. ``model_metrics``
    # is part of the public response shape, so opt out of the namespace check
    # instead of renaming the field. A plain dict is used so that no extra
    # import is required.
    model_config = {"protected_namespaces": ()}

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Training task name")
    trained_at: datetime = Field(..., description="Training timestamp")
    # Optional: None when no metrics are supplied.
    model_metrics: ModelMetrics | None = Field(
        default=None, description="Model metrics"
    )
|
||||
|
||||
|
||||
# Forward reference update
|
||||
# Rebuild the schema of DocumentDetailResponse at import time, once every
# class in this module has been defined.
# NOTE(review): presumably one of its fields is annotated with a model that is
# declared later in the file — confirm before removing or moving this call.
DocumentDetailResponse.model_rebuild()
|
||||
Reference in New Issue
Block a user