212 lines
6.8 KiB
Python
212 lines
6.8 KiB
Python
"""Training Documents and Models Endpoints."""
|
|
|
|
import logging
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
|
|
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
|
from inference.web.schemas.admin import (
|
|
ModelMetrics,
|
|
TrainingDocumentItem,
|
|
TrainingDocumentsResponse,
|
|
TrainingModelItem,
|
|
TrainingModelsResponse,
|
|
TrainingStatus,
|
|
)
|
|
from inference.web.schemas.common import ErrorResponse
|
|
|
|
from ._utils import _validate_uuid
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def register_document_routes(router: APIRouter) -> None:
    """Register training document and model endpoints on the router.

    Adds three GET routes to ``router``:

    * ``/documents`` -- labeled documents that can be used for training.
    * ``/models/{task_id}/download`` -- the weights file of a training task.
    * ``/completed-tasks`` -- training tasks (``completed`` by default) with
      their metrics and, when available, a download link.

    Args:
        router: Router the endpoints are attached to. Registration is a
            side effect on ``router``; nothing is returned.
    """

    @router.get(
        "/documents",
        response_model=TrainingDocumentsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get documents for training",
        description="Get labeled documents available for training with filtering options.",
    )
    async def get_training_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        has_annotations: Annotated[
            bool,
            Query(description="Only include documents with annotations"),
        ] = True,
        min_annotation_count: Annotated[
            int | None,
            Query(ge=1, description="Minimum annotation count"),
        ] = None,
        exclude_used_in_training: Annotated[
            bool,
            Query(description="Exclude documents already used in training"),
        ] = False,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 100,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingDocumentsResponse:
        """Get documents available for training.

        Only documents with status ``"labeled"`` are requested. The other
        query parameters are forwarded to ``db.get_documents_for_training``
        together with ``admin_token``.

        Each returned item carries:

        * the total annotation count;
        * a per-source breakdown limited to ``manual`` and ``auto``.
          Annotations with any other ``source`` still count toward
          ``annotation_count`` but are not listed in the breakdown;
        * the IDs of the training tasks the document is linked to.
        """
        documents, total = db.get_documents_for_training(
            admin_token=admin_token,
            status="labeled",
            has_annotations=has_annotations,
            min_annotation_count=min_annotation_count,
            exclude_used_in_training=exclude_used_in_training,
            limit=limit,
            offset=offset,
        )

        # NOTE(review): two extra DB queries per document (N+1 pattern).
        # Bounded by `limit` (<= 100); batching would need a bulk DB API.
        items = []
        for doc in documents:
            annotations = db.get_annotations_for_document(str(doc.document_id))

            sources = {"manual": 0, "auto": 0}
            for ann in annotations:
                if ann.source in sources:
                    sources[ann.source] += 1

            training_links = db.get_document_training_tasks(doc.document_id)
            used_in_training = [str(link.task_id) for link in training_links]

            items.append(
                TrainingDocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    annotation_count=len(annotations),
                    annotation_sources=sources,
                    used_in_training=used_in_training,
                    last_modified=doc.updated_at,
                )
            )

        return TrainingDocumentsResponse(
            total=total,
            limit=limit,
            offset=offset,
            documents=items,
        )

    @router.get(
        "/models/{task_id}/download",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Model not found"},
        },
        summary="Download trained model",
        description="Download trained model weights file.",
    )
    async def download_model(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ):
        """Return the trained model weights of a training task as a file.

        Args:
            task_id: Training task ID; checked by ``_validate_uuid`` before
                any DB access.
            admin_token: Token the task must belong to.
            db: Admin database dependency.

        Returns:
            A ``FileResponse`` with media type ``application/octet-stream``,
            offered for download as ``<task name>_model.pt``.

        Raises:
            HTTPException: 404 when the task does not exist or belongs to
                another token, has no ``model_path``, or the path is not a
                regular file on disk. An invalid ``task_id`` is rejected by
                ``_validate_uuid`` (status code defined there).
        """
        from pathlib import Path

        from fastapi.responses import FileResponse

        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        if not task.model_path:
            raise HTTPException(
                status_code=404,
                detail="Model file not available for this task",
            )

        model_path = Path(task.model_path)
        # Use is_file() rather than exists(). A directory at this path would
        # pass exists() and then make FileResponse fail while sending,
        # which surfaces as a 500 instead of a 404.
        if not model_path.is_file():
            # The DB row points at a missing or non-regular file, which is
            # worth surfacing to operators.
            logger.warning(
                "Model file for training task %s not found at %s",
                task_id,
                model_path,
            )
            raise HTTPException(
                status_code=404,
                detail="Model file not found on disk",
            )

        return FileResponse(
            path=str(model_path),
            media_type="application/octet-stream",
            filename=f"{task.name}_model.pt",
        )

    @router.get(
        "/completed-tasks",
        response_model=TrainingModelsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get completed training tasks",
        description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
    )
    async def get_completed_training_tasks(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status (completed, failed, etc.)"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingModelsResponse:
        """List training tasks owned by ``admin_token`` with their metrics.

        Tasks are filtered by ``status``. A missing or empty value falls
        back to ``"completed"``. A ``download_url`` is set only for tasks
        that are ``completed`` and have a ``model_path``.
        """
        tasks, total = db.get_training_tasks_by_token(
            admin_token=admin_token,
            status=status or "completed",
            limit=limit,
            offset=offset,
        )

        items = []
        for task in tasks:
            metrics = ModelMetrics(
                mAP=task.metrics_mAP,
                precision=task.metrics_precision,
                recall=task.metrics_recall,
            )

            download_url = None
            if task.model_path and task.status == "completed":
                # NOTE(review): the prefix is hard-coded and must match
                # wherever `router` is mounted -- confirm against the app
                # setup.
                download_url = f"/api/v1/admin/training/models/{task.task_id}/download"

            items.append(
                TrainingModelItem(
                    task_id=str(task.task_id),
                    name=task.name,
                    # Raises ValueError if the stored status is not a
                    # TrainingStatus value. Presumably the DB only stores
                    # valid values -- verify.
                    status=TrainingStatus(task.status),
                    document_count=task.document_count,
                    created_at=task.created_at,
                    completed_at=task.completed_at,
                    metrics=metrics,
                    model_path=task.model_path,
                    download_url=download_url,
                )
            )

        return TrainingModelsResponse(
            total=total,
            limit=limit,
            offset=offset,
            models=items,
        )
|