623 lines
20 KiB
Python
623 lines
20 KiB
Python
"""
Admin Training API Routes

FastAPI endpoints for training task management and scheduling.
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Annotated, Any
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
|
|
from src.data.admin_db import AdminDB
|
|
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
|
from src.web.schemas.admin import (
|
|
ExportRequest,
|
|
ExportResponse,
|
|
ModelMetrics,
|
|
TrainingConfig,
|
|
TrainingDocumentItem,
|
|
TrainingDocumentsResponse,
|
|
TrainingHistoryItem,
|
|
TrainingLogItem,
|
|
TrainingLogsResponse,
|
|
TrainingModelItem,
|
|
TrainingModelsResponse,
|
|
TrainingStatus,
|
|
TrainingTaskCreate,
|
|
TrainingTaskDetailResponse,
|
|
TrainingTaskItem,
|
|
TrainingTaskListResponse,
|
|
TrainingTaskResponse,
|
|
TrainingType,
|
|
)
|
|
from src.web.schemas.common import ErrorResponse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _validate_uuid(value: str, name: str = "ID") -> None:
|
|
"""Validate UUID format."""
|
|
try:
|
|
UUID(value)
|
|
except ValueError:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Invalid {name} format. Must be a valid UUID.",
|
|
)
|
|
|
|
|
|
def create_training_router() -> APIRouter:
    """Create the admin training API router.

    Builds an ``APIRouter`` with prefix ``/admin/training`` and registers the
    training task, training data, model and export endpoints on it.

    Returns:
        The router with all endpoints registered.
    """
    router = APIRouter(prefix="/admin/training", tags=["Admin Training"])

    # =========================================================================
    # Training Task Endpoints
    # =========================================================================

    @router.post(
        "/tasks",
        response_model=TrainingTaskResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Create training task",
        description="Create a new training task.",
    )
    async def create_training_task(
        request: TrainingTaskCreate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Create a new training task.

        The optional config is dumped to a plain dict before being stored. The
        task is flagged as recurring when a cron expression is supplied, and
        the reported status is ``SCHEDULED`` when ``scheduled_at`` is set,
        otherwise ``PENDING``.
        """
        # Convert config to dict
        config_dict = request.config.model_dump() if request.config else {}

        # Create task
        task_id = db.create_training_task(
            admin_token=admin_token,
            name=request.name,
            task_type=request.task_type.value,
            description=request.description,
            config=config_dict,
            scheduled_at=request.scheduled_at,
            cron_expression=request.cron_expression,
            is_recurring=bool(request.cron_expression),
        )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
            message="Training task created successfully",
        )

    @router.get(
        "/tasks",
        response_model=TrainingTaskListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="List training tasks",
        description="List all training tasks.",
    )
    async def list_training_tasks(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingTaskListResponse:
        """List training tasks for the calling token, optionally by status.

        Raises:
            HTTPException: 400 when ``status`` is given but is not one of the
                accepted status strings.
        """
        # Validate status
        valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
        if status and status not in valid_statuses:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
            )

        tasks, total = db.get_training_tasks_by_token(
            admin_token=admin_token,
            status=status,
            limit=limit,
            offset=offset,
        )

        items = [
            TrainingTaskItem(
                task_id=str(task.task_id),
                name=task.name,
                task_type=TrainingType(task.task_type),
                status=TrainingStatus(task.status),
                scheduled_at=task.scheduled_at,
                is_recurring=task.is_recurring,
                started_at=task.started_at,
                completed_at=task.completed_at,
                created_at=task.created_at,
            )
            for task in tasks
        ]

        return TrainingTaskListResponse(
            total=total,
            limit=limit,
            offset=offset,
            tasks=items,
        )

    @router.get(
        "/tasks/{task_id}",
        response_model=TrainingTaskDetailResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
        },
        summary="Get training task detail",
        description="Get training task details.",
    )
    async def get_training_task(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskDetailResponse:
        """Get training task details.

        Raises:
            HTTPException: 400 for a malformed ``task_id``; 404 when the
                token-scoped lookup returns ``None``.
        """
        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        return TrainingTaskDetailResponse(
            task_id=str(task.task_id),
            name=task.name,
            description=task.description,
            task_type=TrainingType(task.task_type),
            status=TrainingStatus(task.status),
            config=task.config,
            scheduled_at=task.scheduled_at,
            cron_expression=task.cron_expression,
            is_recurring=task.is_recurring,
            started_at=task.started_at,
            completed_at=task.completed_at,
            error_message=task.error_message,
            result_metrics=task.result_metrics,
            model_path=task.model_path,
            created_at=task.created_at,
        )

    @router.post(
        "/tasks/{task_id}/cancel",
        response_model=TrainingTaskResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
            409: {"model": ErrorResponse, "description": "Cannot cancel task"},
        },
        summary="Cancel training task",
        description="Cancel a pending or scheduled training task.",
    )
    async def cancel_training_task(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Cancel a training task that is still pending or scheduled.

        Raises:
            HTTPException: 400 for a malformed ``task_id``; 404 when the
                token-scoped lookup returns ``None``; 409 when the task status
                is neither ``pending`` nor ``scheduled``; 500 when the
                database reports that the cancellation did not succeed.
        """
        _validate_uuid(task_id, "task_id")

        # Verify ownership
        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        # Check if can be cancelled
        if task.status not in ("pending", "scheduled"):
            raise HTTPException(
                status_code=409,
                detail=f"Cannot cancel task with status: {task.status}",
            )

        # Cancel task
        success = db.cancel_training_task(task_id)
        if not success:
            raise HTTPException(
                status_code=500,
                detail="Failed to cancel training task",
            )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.CANCELLED,
            message="Training task cancelled successfully",
        )

    @router.get(
        "/tasks/{task_id}/logs",
        response_model=TrainingLogsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
        },
        summary="Get training logs",
        description="Get training task logs.",
    )
    async def get_training_logs(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        limit: Annotated[
            int,
            Query(ge=1, le=500, description="Maximum logs to return"),
        ] = 100,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingLogsResponse:
        """Get a page of log entries for a training task.

        Raises:
            HTTPException: 400 for a malformed ``task_id``; 404 when the
                token-scoped lookup returns ``None``.
        """
        _validate_uuid(task_id, "task_id")

        # Verify ownership
        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        # Get logs
        logs = db.get_training_logs(task_id, limit, offset)

        items = [
            TrainingLogItem(
                level=log.level,
                message=log.message,
                details=log.details,
                created_at=log.created_at,
            )
            for log in logs
        ]

        return TrainingLogsResponse(
            task_id=task_id,
            logs=items,
        )

    # =========================================================================
    # Phase 4: Training Data Management
    # =========================================================================

    @router.get(
        "/documents",
        response_model=TrainingDocumentsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get documents for training",
        description="Get labeled documents available for training with filtering options.",
    )
    async def get_training_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        has_annotations: Annotated[
            bool,
            Query(description="Only include documents with annotations"),
        ] = True,
        min_annotation_count: Annotated[
            int | None,
            Query(ge=1, description="Minimum annotation count"),
        ] = None,
        exclude_used_in_training: Annotated[
            bool,
            Query(description="Exclude documents already used in training"),
        ] = False,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 100,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingDocumentsResponse:
        """Get documents available for training.

        Queries documents with status ``labeled`` using the given filters and,
        for each one, attaches the annotation count, a per-source tally and
        the IDs of the training tasks linked to it.
        """
        # Get documents
        documents, total = db.get_documents_for_training(
            admin_token=admin_token,
            status="labeled",
            has_annotations=has_annotations,
            min_annotation_count=min_annotation_count,
            exclude_used_in_training=exclude_used_in_training,
            limit=limit,
            offset=offset,
        )

        # Build response items with annotation details and training history
        items = []
        for doc in documents:
            # Get annotations for this document
            annotations = db.get_annotations_for_document(str(doc.document_id))

            # Count annotations by source; sources other than these two are
            # not tallied, but still count towards annotation_count below.
            sources = {"manual": 0, "auto": 0}
            for ann in annotations:
                if ann.source in sources:
                    sources[ann.source] += 1

            # Get training history
            training_links = db.get_document_training_tasks(doc.document_id)
            used_in_training = [str(link.task_id) for link in training_links]

            items.append(
                TrainingDocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    annotation_count=len(annotations),
                    annotation_sources=sources,
                    used_in_training=used_in_training,
                    last_modified=doc.updated_at,
                )
            )

        return TrainingDocumentsResponse(
            total=total,
            limit=limit,
            offset=offset,
            documents=items,
        )

    @router.get(
        "/models/{task_id}/download",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Model not found"},
        },
        summary="Download trained model",
        description="Download trained model weights file.",
    )
    async def download_model(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ):
        """Download the trained model file recorded for a task.

        Raises:
            HTTPException: 400 for a malformed ``task_id``; 404 when the
                token-scoped lookup returns ``None``, when the task has no
                ``model_path``, or when that path is not a regular file.
        """
        from pathlib import Path

        from fastapi.responses import FileResponse

        _validate_uuid(task_id, "task_id")

        # Verify ownership
        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        # Check if model exists
        if not task.model_path:
            raise HTTPException(
                status_code=404,
                detail="Model file not available for this task",
            )

        model_path = Path(task.model_path)
        # is_file() rather than exists(): a directory at this path cannot be
        # served as a file, so it is reported as 404 here instead of failing
        # later while the response is being produced.
        if not model_path.is_file():
            raise HTTPException(
                status_code=404,
                detail="Model file not found on disk",
            )

        return FileResponse(
            path=str(model_path),
            media_type="application/octet-stream",
            filename=f"{task.name}_model.pt",
        )

    @router.get(
        "/models",
        response_model=TrainingModelsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get trained models",
        description="Get list of trained models with metrics and download links.",
    )
    async def get_training_models(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status (completed, failed, etc.)"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingModelsResponse:
        """Get list of trained models.

        Lists training tasks for the calling token, filtered by ``status``
        (``completed`` when no status is supplied), together with their
        metrics and, for completed tasks that have a model path, a download
        URL.
        """
        # Get training tasks
        tasks, total = db.get_training_tasks_by_token(
            admin_token=admin_token,
            status=status if status else "completed",
            limit=limit,
            offset=offset,
        )

        # Build response items
        items = []
        for task in tasks:
            # Build metrics
            metrics = ModelMetrics(
                mAP=task.metrics_mAP,
                precision=task.metrics_precision,
                recall=task.metrics_recall,
            )

            # Build download URL if model exists
            # NOTE(review): the URL hard-codes an "/api/v1" prefix that is not
            # part of this router's own prefix - confirm it matches how the
            # application mounts the router.
            download_url = None
            if task.model_path and task.status == "completed":
                download_url = f"/api/v1/admin/training/models/{task.task_id}/download"

            items.append(
                TrainingModelItem(
                    task_id=str(task.task_id),
                    name=task.name,
                    status=TrainingStatus(task.status),
                    document_count=task.document_count,
                    created_at=task.created_at,
                    completed_at=task.completed_at,
                    metrics=metrics,
                    model_path=task.model_path,
                    download_url=download_url,
                )
            )

        return TrainingModelsResponse(
            total=total,
            limit=limit,
            offset=offset,
            models=items,
        )

    # =========================================================================
    # Export Endpoints
    # =========================================================================

    @router.post(
        "/export",
        response_model=ExportResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Export annotations",
        description="Export annotations in YOLO format for training.",
    )
    async def export_annotations(
        request: ExportRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ExportResponse:
        """Export annotations for training.

        Creates a timestamped directory under ``data/exports`` with
        ``images/{train,val}`` and ``labels/{train,val}`` sub-directories,
        splits the labeled documents by ``request.split_ratio``, copies each
        available page image, writes one label file per exported page and
        finally writes ``data.yaml``.

        Raises:
            HTTPException: 400 when ``request.format`` is not an accepted
                value or when there are no labeled documents to export.
        """
        import shutil
        from datetime import timezone
        from pathlib import Path

        # Validate format
        # NOTE(review): "coco" and "voc" pass this check, but only the YOLO
        # layout is written below - confirm whether that is intended.
        if request.format not in ("yolo", "coco", "voc"):
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported export format: {request.format}",
            )

        # Get labeled documents
        documents = db.get_labeled_documents_for_export(admin_token)

        if not documents:
            raise HTTPException(
                status_code=400,
                detail="No labeled documents available for export",
            )

        # Create export directory. An aware UTC timestamp replaces the
        # deprecated naive datetime.utcnow(); the formatted name is unchanged.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        export_dir = Path("data/exports") / f"export_{timestamp}"
        export_dir.mkdir(parents=True, exist_ok=True)

        # YOLO format directories
        for kind in ("images", "labels"):
            for split_name in ("train", "val"):
                (export_dir / kind / split_name).mkdir(parents=True, exist_ok=True)

        # Calculate train/val split
        total_docs = len(documents)
        train_count = int(total_docs * request.split_ratio)
        train_docs = documents[:train_count]
        val_docs = documents[train_count:]

        total_images = 0
        total_annotations = 0

        # Export documents
        for split, docs in (("train", train_docs), ("val", val_docs)):
            for doc in docs:
                # Get annotations
                annotations = db.get_annotations_for_document(str(doc.document_id))

                if not annotations:
                    continue

                # Group annotations by page once per document instead of
                # rescanning the full list for every page.
                annotations_by_page: dict = {}
                for ann in annotations:
                    annotations_by_page.setdefault(ann.page_number, []).append(ann)

                # Export each page
                for page_num in range(1, doc.page_count + 1):
                    page_annotations = annotations_by_page.get(page_num, [])

                    if not page_annotations and not request.include_images:
                        continue

                    # Copy image
                    src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
                    if not src_image.exists():
                        continue

                    image_name = f"{doc.document_id}_page{page_num}.png"
                    dst_image = export_dir / "images" / split / image_name
                    shutil.copy(src_image, dst_image)
                    total_images += 1

                    # Write YOLO label file
                    label_name = f"{doc.document_id}_page{page_num}.txt"
                    label_path = export_dir / "labels" / split / label_name

                    # Explicit encoding so the output does not depend on the
                    # platform default.
                    with open(label_path, "w", encoding="utf-8") as f:
                        # YOLO format: class_id x_center y_center width height
                        f.writelines(
                            f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
                            for ann in page_annotations
                        )
                    total_annotations += len(page_annotations)

        # Create data.yaml
        from src.data.admin_models import FIELD_CLASSES

        yaml_content = (
            "# Auto-generated YOLO dataset config\n"
            f"path: {export_dir.absolute()}\n"
            "train: images/train\n"
            "val: images/val\n"
            "\n"
            f"nc: {len(FIELD_CLASSES)}\n"
            f"names: {list(FIELD_CLASSES.values())}\n"
        )
        (export_dir / "data.yaml").write_text(yaml_content, encoding="utf-8")

        # train_count / val_count report the document split sizes, not the
        # number of documents that actually produced exported pages.
        return ExportResponse(
            status="completed",
            export_path=str(export_dir),
            total_images=total_images,
            total_annotations=total_annotations,
            train_count=len(train_docs),
            val_count=len(val_docs),
            message=f"Exported {total_images} images with {total_annotations} annotations",
        )

    return router
|