""" Admin Training API Routes FastAPI endpoints for training task management and scheduling. """ import logging from datetime import datetime from typing import Annotated, Any from uuid import UUID from fastapi import APIRouter, HTTPException, Query from src.data.admin_db import AdminDB from src.web.core.auth import AdminTokenDep, AdminDBDep from src.web.schemas.admin import ( ExportRequest, ExportResponse, ModelMetrics, TrainingConfig, TrainingDocumentItem, TrainingDocumentsResponse, TrainingHistoryItem, TrainingLogItem, TrainingLogsResponse, TrainingModelItem, TrainingModelsResponse, TrainingStatus, TrainingTaskCreate, TrainingTaskDetailResponse, TrainingTaskItem, TrainingTaskListResponse, TrainingTaskResponse, TrainingType, ) from src.web.schemas.common import ErrorResponse logger = logging.getLogger(__name__) def _validate_uuid(value: str, name: str = "ID") -> None: """Validate UUID format.""" try: UUID(value) except ValueError: raise HTTPException( status_code=400, detail=f"Invalid {name} format. Must be a valid UUID.", ) def create_training_router() -> APIRouter: """Create training API router.""" router = APIRouter(prefix="/admin/training", tags=["Admin Training"]) # ========================================================================= # Training Task Endpoints # ========================================================================= @router.post( "/tasks", response_model=TrainingTaskResponse, responses={ 400: {"model": ErrorResponse, "description": "Invalid request"}, 401: {"model": ErrorResponse, "description": "Invalid token"}, }, summary="Create training task", description="Create a new training task.", ) async def create_training_task( request: TrainingTaskCreate, admin_token: AdminTokenDep, db: AdminDBDep, ) -> TrainingTaskResponse: """Create a new training task.""" # Convert config to dict config_dict = request.config.model_dump() if request.config else {} # Create task task_id = db.create_training_task( admin_token=admin_token, name=request.name, task_type=request.task_type.value, description=request.description, config=config_dict, scheduled_at=request.scheduled_at, cron_expression=request.cron_expression, is_recurring=bool(request.cron_expression), ) return TrainingTaskResponse( task_id=task_id, status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING, message="Training task created successfully", ) @router.get( "/tasks", response_model=TrainingTaskListResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, }, summary="List training tasks", description="List all training tasks.", ) async def list_training_tasks( admin_token: AdminTokenDep, db: AdminDBDep, status: Annotated[ str | None, Query(description="Filter by status"), ] = None, limit: Annotated[ int, Query(ge=1, le=100, description="Page size"), ] = 20, offset: Annotated[ int, Query(ge=0, description="Offset"), ] = 0, ) -> TrainingTaskListResponse: """List training tasks.""" # Validate status valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled") if status and status not in valid_statuses: raise HTTPException( status_code=400, detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}", ) tasks, total = db.get_training_tasks_by_token( admin_token=admin_token, status=status, limit=limit, offset=offset, ) items = [ TrainingTaskItem( task_id=str(task.task_id), name=task.name, task_type=TrainingType(task.task_type), status=TrainingStatus(task.status), scheduled_at=task.scheduled_at, is_recurring=task.is_recurring, started_at=task.started_at, completed_at=task.completed_at, created_at=task.created_at, ) for task in tasks ] return TrainingTaskListResponse( total=total, limit=limit, offset=offset, tasks=items, ) @router.get( "/tasks/{task_id}", response_model=TrainingTaskDetailResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, 404: {"model": ErrorResponse, "description": "Task not found"}, }, summary="Get training task detail", description="Get training task details.", ) async def get_training_task( task_id: str, admin_token: AdminTokenDep, db: AdminDBDep, ) -> TrainingTaskDetailResponse: """Get training task details.""" _validate_uuid(task_id, "task_id") task = db.get_training_task_by_token(task_id, admin_token) if task is None: raise HTTPException( status_code=404, detail="Training task not found or does not belong to this token", ) return TrainingTaskDetailResponse( task_id=str(task.task_id), name=task.name, description=task.description, task_type=TrainingType(task.task_type), status=TrainingStatus(task.status), config=task.config, scheduled_at=task.scheduled_at, cron_expression=task.cron_expression, is_recurring=task.is_recurring, started_at=task.started_at, completed_at=task.completed_at, error_message=task.error_message, result_metrics=task.result_metrics, model_path=task.model_path, created_at=task.created_at, ) @router.post( "/tasks/{task_id}/cancel", response_model=TrainingTaskResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, 404: {"model": ErrorResponse, "description": "Task not found"}, 409: {"model": ErrorResponse, "description": "Cannot cancel task"}, }, summary="Cancel training task", description="Cancel a pending or scheduled training task.", ) async def cancel_training_task( task_id: str, admin_token: AdminTokenDep, db: AdminDBDep, ) -> TrainingTaskResponse: """Cancel a training task.""" _validate_uuid(task_id, "task_id") # Verify ownership task = db.get_training_task_by_token(task_id, admin_token) if task is None: raise HTTPException( status_code=404, detail="Training task not found or does not belong to this token", ) # Check if can be cancelled if task.status not in ("pending", "scheduled"): raise HTTPException( status_code=409, detail=f"Cannot cancel task with status: {task.status}", ) # Cancel task success = db.cancel_training_task(task_id) if not success: raise HTTPException( status_code=500, detail="Failed to cancel training task", ) return TrainingTaskResponse( task_id=task_id, status=TrainingStatus.CANCELLED, message="Training task cancelled successfully", ) @router.get( "/tasks/{task_id}/logs", response_model=TrainingLogsResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, 404: {"model": ErrorResponse, "description": "Task not found"}, }, summary="Get training logs", description="Get training task logs.", ) async def get_training_logs( task_id: str, admin_token: AdminTokenDep, db: AdminDBDep, limit: Annotated[ int, Query(ge=1, le=500, description="Maximum logs to return"), ] = 100, offset: Annotated[ int, Query(ge=0, description="Offset"), ] = 0, ) -> TrainingLogsResponse: """Get training logs.""" _validate_uuid(task_id, "task_id") # Verify ownership task = db.get_training_task_by_token(task_id, admin_token) if task is None: raise HTTPException( status_code=404, detail="Training task not found or does not belong to this token", ) # Get logs logs = db.get_training_logs(task_id, limit, offset) items = [ TrainingLogItem( level=log.level, message=log.message, details=log.details, created_at=log.created_at, ) for log in logs ] return TrainingLogsResponse( task_id=task_id, logs=items, ) # ========================================================================= # Phase 4: Training Data Management # ========================================================================= @router.get( "/documents", response_model=TrainingDocumentsResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, }, summary="Get documents for training", description="Get labeled documents available for training with filtering options.", ) async def get_training_documents( admin_token: AdminTokenDep, db: AdminDBDep, has_annotations: Annotated[ bool, Query(description="Only include documents with annotations"), ] = True, min_annotation_count: Annotated[ int | None, Query(ge=1, description="Minimum annotation count"), ] = None, exclude_used_in_training: Annotated[ bool, Query(description="Exclude documents already used in training"), ] = False, limit: Annotated[ int, Query(ge=1, le=100, description="Page size"), ] = 100, offset: Annotated[ int, Query(ge=0, description="Offset"), ] = 0, ) -> TrainingDocumentsResponse: """Get documents available for training.""" # Get documents documents, total = db.get_documents_for_training( admin_token=admin_token, status="labeled", has_annotations=has_annotations, min_annotation_count=min_annotation_count, exclude_used_in_training=exclude_used_in_training, limit=limit, offset=offset, ) # Build response items with annotation details and training history items = [] for doc in documents: # Get annotations for this document annotations = db.get_annotations_for_document(str(doc.document_id)) # Count annotations by source sources = {"manual": 0, "auto": 0} for ann in annotations: if ann.source in sources: sources[ann.source] += 1 # Get training history training_links = db.get_document_training_tasks(doc.document_id) used_in_training = [str(link.task_id) for link in training_links] items.append( TrainingDocumentItem( document_id=str(doc.document_id), filename=doc.filename, annotation_count=len(annotations), annotation_sources=sources, used_in_training=used_in_training, last_modified=doc.updated_at, ) ) return TrainingDocumentsResponse( total=total, limit=limit, offset=offset, documents=items, ) @router.get( "/models/{task_id}/download", responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, 404: {"model": ErrorResponse, "description": "Model not found"}, }, summary="Download trained model", description="Download trained model weights file.", ) async def download_model( task_id: str, admin_token: AdminTokenDep, db: AdminDBDep, ): """Download trained model.""" from fastapi.responses import FileResponse from pathlib import Path _validate_uuid(task_id, "task_id") # Verify ownership task = db.get_training_task_by_token(task_id, admin_token) if task is None: raise HTTPException( status_code=404, detail="Training task not found or does not belong to this token", ) # Check if model exists if not task.model_path: raise HTTPException( status_code=404, detail="Model file not available for this task", ) model_path = Path(task.model_path) if not model_path.exists(): raise HTTPException( status_code=404, detail="Model file not found on disk", ) return FileResponse( path=str(model_path), media_type="application/octet-stream", filename=f"{task.name}_model.pt", ) @router.get( "/models", response_model=TrainingModelsResponse, responses={ 401: {"model": ErrorResponse, "description": "Invalid token"}, }, summary="Get trained models", description="Get list of trained models with metrics and download links.", ) async def get_training_models( admin_token: AdminTokenDep, db: AdminDBDep, status: Annotated[ str | None, Query(description="Filter by status (completed, failed, etc.)"), ] = None, limit: Annotated[ int, Query(ge=1, le=100, description="Page size"), ] = 20, offset: Annotated[ int, Query(ge=0, description="Offset"), ] = 0, ) -> TrainingModelsResponse: """Get list of trained models.""" # Get training tasks tasks, total = db.get_training_tasks_by_token( admin_token=admin_token, status=status if status else "completed", limit=limit, offset=offset, ) # Build response items items = [] for task in tasks: # Build metrics metrics = ModelMetrics( mAP=task.metrics_mAP, precision=task.metrics_precision, recall=task.metrics_recall, ) # Build download URL if model exists download_url = None if task.model_path and task.status == "completed": download_url = f"/api/v1/admin/training/models/{task.task_id}/download" items.append( TrainingModelItem( task_id=str(task.task_id), name=task.name, status=TrainingStatus(task.status), document_count=task.document_count, created_at=task.created_at, completed_at=task.completed_at, metrics=metrics, model_path=task.model_path, download_url=download_url, ) ) return TrainingModelsResponse( total=total, limit=limit, offset=offset, models=items, ) # ========================================================================= # Export Endpoints # ========================================================================= @router.post( "/export", response_model=ExportResponse, responses={ 400: {"model": ErrorResponse, "description": "Invalid request"}, 401: {"model": ErrorResponse, "description": "Invalid token"}, }, summary="Export annotations", description="Export annotations in YOLO format for training.", ) async def export_annotations( request: ExportRequest, admin_token: AdminTokenDep, db: AdminDBDep, ) -> ExportResponse: """Export annotations for training.""" from pathlib import Path import shutil # Validate format if request.format not in ("yolo", "coco", "voc"): raise HTTPException( status_code=400, detail=f"Unsupported export format: {request.format}", ) # Get labeled documents documents = db.get_labeled_documents_for_export(admin_token) if not documents: raise HTTPException( status_code=400, detail="No labeled documents available for export", ) # Create export directory export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}" export_dir.mkdir(parents=True, exist_ok=True) # YOLO format directories (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True) (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True) (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True) (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True) # Calculate train/val split total_docs = len(documents) train_count = int(total_docs * request.split_ratio) train_docs = documents[:train_count] val_docs = documents[train_count:] total_images = 0 total_annotations = 0 # Export documents for split, docs in [("train", train_docs), ("val", val_docs)]: for doc in docs: # Get annotations annotations = db.get_annotations_for_document(str(doc.document_id)) if not annotations: continue # Export each page for page_num in range(1, doc.page_count + 1): page_annotations = [a for a in annotations if a.page_number == page_num] if not page_annotations and not request.include_images: continue # Copy image src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png" if not src_image.exists(): continue image_name = f"{doc.document_id}_page{page_num}.png" dst_image = export_dir / "images" / split / image_name shutil.copy(src_image, dst_image) total_images += 1 # Write YOLO label file label_name = f"{doc.document_id}_page{page_num}.txt" label_path = export_dir / "labels" / split / label_name with open(label_path, "w") as f: for ann in page_annotations: # YOLO format: class_id x_center y_center width height line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n" f.write(line) total_annotations += 1 # Create data.yaml from src.data.admin_models import FIELD_CLASSES yaml_content = f"""# Auto-generated YOLO dataset config path: {export_dir.absolute()} train: images/train val: images/val nc: {len(FIELD_CLASSES)} names: {list(FIELD_CLASSES.values())} """ (export_dir / "data.yaml").write_text(yaml_content) return ExportResponse( status="completed", export_path=str(export_dir), total_images=total_images, total_annotations=total_annotations, train_count=len(train_docs), val_count=len(val_docs), message=f"Exported {total_images} images with {total_annotations} annotations", ) return router