This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -25,6 +25,7 @@ from inference.data.admin_models import (
AnnotationHistory,
TrainingDataset,
DatasetDocument,
ModelVersion,
)
logger = logging.getLogger(__name__)
@@ -110,6 +111,7 @@ class AdminDB:
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
admin_token: str | None = None, # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
)
session.add(document)
session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
document.updated_at = datetime.utcnow()
session.add(document)
def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
"""Update document group key."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.group_key = group_key
document.updated_at = datetime.utcnow()
session.add(document)
return True
return False
def delete_document(self, document_id: str) -> bool:
"""Delete a document and its annotations."""
with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
session.expunge(d)
return list(datasets), total
def get_active_training_tasks_for_datasets(
self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
"""Get active (pending/scheduled/running) training tasks for datasets.
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
"""
if not dataset_ids:
return {}
# Validate UUIDs before query
valid_uuids = []
for d in dataset_ids:
try:
valid_uuids.append(UUID(d))
except ValueError:
logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
continue
if not valid_uuids:
return {}
with get_session_context() as session:
statement = select(TrainingTask).where(
TrainingTask.dataset_id.in_(valid_uuids),
TrainingTask.status.in_(["pending", "scheduled", "running"]),
)
results = session.exec(statement).all()
return {
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
for t in results
}
def update_dataset_status(
self,
dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
session.delete(dataset)
session.commit()
return True
# ==========================================================================
# Model Version Operations
# ==========================================================================
def create_model_version(
self,
version: str,
name: str,
model_path: str,
description: str | None = None,
task_id: str | UUID | None = None,
dataset_id: str | UUID | None = None,
metrics_mAP: float | None = None,
metrics_precision: float | None = None,
metrics_recall: float | None = None,
document_count: int = 0,
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
model = ModelVersion(
version=version,
name=name,
model_path=model_path,
description=description,
task_id=UUID(str(task_id)) if task_id else None,
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
)
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Get a model version by ID."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if model:
session.expunge(model)
return model
def get_model_versions(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[ModelVersion], int]:
"""List model versions with optional status filter."""
with get_session_context() as session:
query = select(ModelVersion)
count_query = select(func.count()).select_from(ModelVersion)
if status:
query = query.where(ModelVersion.status == status)
count_query = count_query.where(ModelVersion.status == status)
total = session.exec(count_query).one()
models = session.exec(
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
).all()
for m in models:
session.expunge(m)
return list(models), total
def get_active_model_version(self) -> ModelVersion | None:
"""Get the currently active model version for inference."""
with get_session_context() as session:
result = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).first()
if result:
session.expunge(result)
return result
def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Activate a model version for inference (deactivates all others)."""
with get_session_context() as session:
# Deactivate all versions
all_versions = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).all()
for v in all_versions:
v.is_active = False
v.status = "inactive"
v.updated_at = datetime.utcnow()
session.add(v)
# Activate the specified version
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = True
model.status = "active"
model.activated_at = datetime.utcnow()
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Deactivate a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = False
model.status = "inactive"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def update_model_version(
self,
version_id: str | UUID,
name: str | None = None,
description: str | None = None,
status: str | None = None,
) -> ModelVersion | None:
"""Update model version metadata."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if name is not None:
model.name = name
if description is not None:
model.description = description
if status is not None:
model.status = status
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Archive a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
# Cannot archive active model
if model.is_active:
return None
model.status = "archived"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def delete_model_version(self, version_id: str | UUID) -> bool:
"""Delete a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return False
# Cannot delete active model
if model.is_active:
return False
session.delete(model)
session.commit()
return True
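A minimal usage sketch of the version lifecycle implemented above; the import path for AdminDB and the version/path values are assumptions for illustration only:
# Hypothetical import path; AdminDB is the admin data-access class shown above.
from inference.data.admin_db import AdminDB

db = AdminDB()
# Register a trained model (values are illustrative).
created = db.create_model_version(
    version="1.0.0",
    name="baseline-invoice-fields",
    model_path="/models/invoice_fields/best.pt",
)
# Activating a version deactivates every other version first.
active = db.activate_model_version(created.version_id)
assert active is not None and active.is_active
# The web app resolves the inference weights from this single active row.
assert db.get_active_model_version().version_id == created.version_id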

View File

@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Model Version Management
# =============================================================================
class ModelVersion(SQLModel, table=True):
"""Model version for inference deployment."""
__tablename__ = "model_versions"
version_id: UUID = Field(default_factory=uuid4, primary_key=True)
version: str = Field(max_length=50, index=True)
# Semantic version e.g., "1.0.0", "2.1.0"
name: str = Field(max_length=255)
description: str | None = Field(default=None)
model_path: str = Field(max_length=512)
# Path to the model weights file
status: str = Field(default="inactive", max_length=20, index=True)
# Status: active, inactive, archived
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training metrics
metrics_mAP: float | None = Field(default=None)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
document_count: int = Field(default=0)
# Number of documents used in training
# Training configuration snapshot
training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of epochs, batch_size, etc.
# File info
file_size: int | None = Field(default=None)
# Model file size in bytes
# Timestamps
trained_at: datetime | None = Field(default=None)
# When training completed
activated_at: datetime | None = Field(default=None)
# When this version was last activated
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
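The training_config column stores a JSON snapshot of the run parameters. A hedged example of what such a snapshot might contain, using the keys the training scheduler reads elsewhere in this commit; exact contents depend on the request config:
# Illustrative snapshot only; keys mirror config.get(...) calls in the scheduler.
training_config = {
    "model_name": "yolo11n.pt",
    "epochs": 100,
    "batch_size": 16,
    "image_size": 640,
    "learning_rate": 0.01,
    "device": "0",
    "project_name": "invoice_fields",
    # Present only for incremental training:
    "base_model_version_id": "<model-version-uuid>",
    "base_model_path": "/models/invoice_fields/v1/best.pt",
    "base_model_version": "v1.0",
}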
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -49,6 +49,111 @@ def get_engine():
return _engine
def run_migrations() -> None:
"""Run database migrations for new columns."""
engine = get_engine()
migrations = [
# Migration 004: Training datasets tables and dataset_id on training_tasks
(
"training_datasets_tables",
"""
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
""",
),
(
"dataset_documents_table",
"""
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
""",
),
(
"training_tasks_dataset_id",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
""",
),
# Migration 005: Add group_key to admin_documents
(
"admin_documents_group_key",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
""",
),
# Migration 006: Model versions table
(
"model_versions_table",
"""
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
task_id UUID REFERENCES training_tasks(task_id),
dataset_id UUID REFERENCES training_datasets(dataset_id),
metrics_mAP FLOAT,
metrics_precision FLOAT,
metrics_recall FLOAT,
document_count INTEGER NOT NULL DEFAULT 0,
training_config JSONB,
file_size BIGINT,
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
]
with engine.connect() as conn:
for name, sql in migrations:
try:
conn.execute(text(sql))
conn.commit()
logger.info(f"Migration '{name}' applied successfully")
except Exception as e:
# Log but don't fail - the table/column/index may already exist
logger.debug(f"Migration '{name}' skipped or failed: {e}")
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
# Run migrations for new columns
run_migrations()
def get_session() -> Session:
"""Get a new database session."""

View File

@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",

View File

@@ -0,0 +1,15 @@
"""Augmentation API module."""
from fastapi import APIRouter
from .routes import register_augmentation_routes
def create_augmentation_router() -> APIRouter:
"""Create and configure the augmentation router."""
router = APIRouter(prefix="/augmentation", tags=["augmentation"])
register_augmentation_routes(router)
return router
__all__ = ["create_augmentation_router"]

View File

@@ -0,0 +1,162 @@
"""Augmentation API routes."""
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewRequest,
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
)
def register_augmentation_routes(router: APIRouter) -> None:
"""Register augmentation endpoints on the router."""
@router.get(
"/types",
response_model=AugmentationTypesResponse,
summary="List available augmentation types",
)
async def list_augmentation_types(
admin_token: AdminTokenDep,
) -> AugmentationTypesResponse:
"""
List all available augmentation types with descriptions and parameters.
"""
from shared.augmentation.pipeline import (
AUGMENTATION_REGISTRY,
AugmentationPipeline,
)
types = []
for name, aug_class in AUGMENTATION_REGISTRY.items():
# Create instance with empty params to get preview params
aug = aug_class({})
types.append(
AugmentationTypeInfo(
name=name,
description=(aug_class.__doc__ or "").strip(),
affects_geometry=aug_class.affects_geometry,
stage=AugmentationPipeline.STAGE_MAPPING[name],
default_params=aug.get_preview_params(),
)
)
return AugmentationTypesResponse(augmentation_types=types)
@router.get(
"/presets",
response_model=PresetsResponse,
summary="Get augmentation presets",
)
async def get_presets(
admin_token: AdminTokenDep,
) -> PresetsResponse:
"""Get predefined augmentation presets for common use cases."""
from shared.augmentation.presets import list_presets
presets = [PresetInfo(**p) for p in list_presets()]
return PresetsResponse(presets=presets)
@router.post(
"/preview/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview augmentation on document image",
)
async def preview_augmentation(
document_id: str,
request: AugmentationPreviewRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Returns URLs to original and augmented preview images.
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.preview_single(
document_id=document_id,
page=page,
augmentation_type=request.augmentation_type,
params=request.params,
)
@router.post(
"/preview-config/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview full augmentation config on document",
)
async def preview_config(
document_id: str,
config: AugmentationConfigSchema,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""Preview complete augmentation pipeline on a document page."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.preview_config(
document_id=document_id,
page=page,
config=config,
)
@router.post(
"/batch",
response_model=AugmentationBatchResponse,
summary="Create augmented dataset (offline preprocessing)",
)
async def create_augmented_dataset(
request: AugmentationBatchRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
This runs as a background task. The augmented images are stored
alongside the original dataset for training.
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.create_augmented_dataset(
source_dataset_id=request.dataset_id,
config=request.config,
output_name=request.output_name,
multiplier=request.multiplier,
)
@router.get(
"/datasets",
response_model=AugmentedDatasetListResponse,
summary="List augmented datasets",
)
async def list_augmented_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
offset: int = Query(default=0, ge=0, description="Offset"),
) -> AugmentedDatasetListResponse:
"""List all augmented datasets."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.list_augmented_datasets(limit=limit, offset=offset)
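A hedged client-side sketch of the preview flow, assuming the router is mounted under /api/v1/admin as in the app wiring later in this commit, and that the admin token is passed via a bearer-style header (the exact header is defined by AdminTokenDep, which is not shown here):
import httpx

BASE = "http://localhost:8000/api/v1/admin/augmentation"  # host/port are assumptions
HEADERS = {"Authorization": "Bearer <admin-token>"}        # header name is an assumption

with httpx.Client(base_url=BASE, headers=HEADERS) as client:
    # Discover available augmentation types and their default parameters.
    types = client.get("/types").json()["augmentation_types"]
    # Preview a single augmentation on page 1 of a document.
    doc_id = "00000000-0000-0000-0000-000000000000"  # placeholder document UUID
    preview = client.post(
        f"/preview/{doc_id}",
        params={"page": 1},
        json={"augmentation_type": "gaussian_blur", "params": {}},
    ).json()
    print(preview["preview_url"], preview["original_url"])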

View File

@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
group_key: Annotated[
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
group_key=group_key,
)
# Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
return response
@router.patch(
"/{document_id}/group-key",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document group key",
description="Update the group key for a document.",
)
async def update_document_group_key(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
group_key: Annotated[
str | None,
Query(description="New group key (null to clear)"),
] = None,
) -> dict:
"""Update document group key."""
_validate_uuid(document_id, "document_id")
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update group key
db.update_document_group_key(document_id, group_key)
return {
"status": "updated",
"document_id": document_id,
"group_key": group_key,
"message": "Document group key updated",
}
return router
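For reference, the new group-key endpoint takes the value as a query parameter, so clearing a key means omitting it. A sketch assuming the documents router is mounted under the admin prefix and the same bearer-style header as above (both assumptions):
import httpx

admin = httpx.Client(
    base_url="http://localhost:8000/api/v1/admin",        # mounting prefix is an assumption
    headers={"Authorization": "Bearer <admin-token>"},     # header name is an assumption
)
doc_id = "00000000-0000-0000-0000-000000000000"            # placeholder document UUID

# Set a group key; it is passed as a query parameter, not a JSON body.
admin.patch(f"/documents/{doc_id}/group-key", params={"group_key": "2024-Q4-invoices"})
# Clear the key by omitting the parameter (group_key defaults to null).
admin.patch(f"/documents/{doc_id}/group-key")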

View File

@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
return router

View File

@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
# Validate minimum document count for proper train/val/test split
if len(request.document_ids) < 10:
raise HTTPException(
status_code=400,
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)
dataset = db.create_dataset(
name=request.name,
description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets]
try:
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}
return DatasetListResponse(
total=total,
limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
name=d.name,
description=d.description,
status=d.status,
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
"""Create a training task from a dataset.
For incremental training, set config.base_model_version_id to a model version UUID.
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
)
config_dict = request.config.model_dump()
# Resolve base_model_version_id to actual model path for incremental training
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = db.get_model_version(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
base_model.version,
base_model.model_path,
)
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
message = (
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
if base_model_version_id
else "Training task created from dataset"
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
message=message,
)
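A hedged example of the request body that triggers incremental training through this endpoint; field names follow the handler above, the config keys mirror what the scheduler reads, and the UUID is a placeholder:
# Illustrative payload for POST .../datasets/{dataset_id}/train.
# base_model_version_id must reference an existing model version; the handler
# resolves it to base_model_path/base_model_version before queuing the task.
payload = {
    "name": "invoice-fields-incremental-run",
    "config": {
        "epochs": 50,
        "batch_size": 16,
        "base_model_version_id": "00000000-0000-0000-0000-000000000000",
    },
}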

View File

@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
)
@router.get(
"/models",
"/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
summary="Get completed training tasks",
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
async def get_training_models(
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[

View File

@@ -0,0 +1,333 @@
"""Model Version Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, Request
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
ModelVersionItem,
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_model_routes(router: APIRouter) -> None:
"""Register model version endpoints on the router."""
@router.post(
"/models",
response_model=ModelVersionResponse,
summary="Create model version",
description="Register a new model version for deployment.",
)
async def create_model_version(
request: ModelVersionCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Create a new model version."""
if request.task_id:
_validate_uuid(request.task_id, "task_id")
if request.dataset_id:
_validate_uuid(request.dataset_id, "dataset_id")
model = db.create_model_version(
version=request.version,
name=request.name,
model_path=request.model_path,
description=request.description,
task_id=request.task_id,
dataset_id=request.dataset_id,
metrics_mAP=request.metrics_mAP,
metrics_precision=request.metrics_precision,
metrics_recall=request.metrics_recall,
document_count=request.document_count,
training_config=request.training_config,
file_size=request.file_size,
trained_at=request.trained_at,
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version created successfully",
)
@router.get(
"/models",
response_model=ModelVersionListResponse,
summary="List model versions",
)
async def list_model_versions(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> ModelVersionListResponse:
"""List model versions with optional status filter."""
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
return ModelVersionListResponse(
total=total,
limit=limit,
offset=offset,
models=[
ModelVersionItem(
version_id=str(m.version_id),
version=m.version,
name=m.name,
status=m.status,
is_active=m.is_active,
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
activated_at=m.activated_at,
created_at=m.created_at,
)
for m in models
],
)
@router.get(
"/models/active",
response_model=ActiveModelResponse,
summary="Get active model",
description="Get the currently active model for inference.",
)
async def get_active_model(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ActiveModelResponse:
"""Get the currently active model version."""
model = db.get_active_model_version()
if not model:
return ActiveModelResponse(has_active_model=False, model=None)
return ActiveModelResponse(
has_active_model=True,
model=ModelVersionItem(
version_id=str(model.version_id),
version=model.version,
name=model.name,
status=model.status,
is_active=model.is_active,
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
),
)
@router.get(
"/models/{version_id}",
response_model=ModelVersionDetailResponse,
summary="Get model version detail",
)
async def get_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionDetailResponse:
"""Get detailed model version information."""
_validate_uuid(version_id, "version_id")
model = db.get_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionDetailResponse(
version_id=str(model.version_id),
version=model.version,
name=model.name,
description=model.description,
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
metrics_precision=model.metrics_precision,
metrics_recall=model.metrics_recall,
document_count=model.document_count,
training_config=model.training_config,
file_size=model.file_size,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
updated_at=model.updated_at,
)
@router.patch(
"/models/{version_id}",
response_model=ModelVersionResponse,
summary="Update model version",
)
async def update_model_version(
version_id: str,
request: ModelVersionUpdateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Update model version metadata."""
_validate_uuid(version_id, "version_id")
model = db.update_model_version(
version_id=version_id,
name=request.name,
description=request.description,
status=request.status,
)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version updated successfully",
)
@router.post(
"/models/{version_id}/activate",
response_model=ModelVersionResponse,
summary="Activate model version",
description="Activate a model version for inference (deactivates all others).",
)
async def activate_model_version(
version_id: str,
request: Request,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
model = db.activate_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
# Trigger model reload in inference service
inference_service = getattr(request.app.state, "inference_service", None)
model_reloaded = False
if inference_service:
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info(f"Inference model reloaded to version {model.version}")
except Exception as e:
logger.warning(f"Failed to reload inference model: {e}")
message = "Model version activated for inference"
if model_reloaded:
message += " (model reloaded)"
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message=message,
)
@router.post(
"/models/{version_id}/deactivate",
response_model=ModelVersionResponse,
summary="Deactivate model version",
)
async def deactivate_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Deactivate a model version."""
_validate_uuid(version_id, "version_id")
model = db.deactivate_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version deactivated",
)
@router.post(
"/models/{version_id}/archive",
response_model=ModelVersionResponse,
summary="Archive model version",
)
async def archive_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Archive a model version."""
_validate_uuid(version_id, "version_id")
model = db.archive_model_version(version_id)
if not model:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot archive active model",
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version archived",
)
@router.delete(
"/models/{version_id}",
summary="Delete model version",
)
async def delete_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a model version."""
_validate_uuid(version_id, "version_id")
success = db.delete_model_version(version_id)
if not success:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot delete active model",
)
return {"message": "Model version deleted"}
@router.post(
"/models/reload",
summary="Reload inference model",
description="Reload the inference model from the currently active model version.",
)
async def reload_inference_model(
request: Request,
admin_token: AdminTokenDep,
) -> dict:
"""Reload the inference model from active version."""
inference_service = getattr(request.app.state, "inference_service", None)
if not inference_service:
raise HTTPException(
status_code=500,
detail="Inference service not available",
)
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info("Inference model manually reloaded")
return {"message": "Model reloaded successfully", "reloaded": True}
else:
return {"message": "Model already up to date", "reloaded": False}
except Exception as e:
logger.error(f"Failed to reload model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to reload model: {e}",
)
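Putting the endpoints above together, a hedged operator flow for promoting a new model. The training router's own prefix is not shown in this diff, so the path below is a guess; host, port, and the auth header are likewise assumptions:
import httpx

api = httpx.Client(
    base_url="http://localhost:8000/api/v1",               # host/port are assumptions
    headers={"Authorization": "Bearer <admin-token>"},      # header name is an assumption
)
PREFIX = "/admin/training"                                   # router prefix is a guess

# List versions (ordered by created_at descending), activate the newest one.
versions = api.get(f"{PREFIX}/models", params={"limit": 5}).json()["models"]
newest = versions[0]
api.post(f"{PREFIX}/models/{newest['version_id']}/activate")
# Activation already attempts a reload; the explicit call is a manual fallback.
api.post(f"{PREFIX}/models/reload")
print(api.get(f"{PREFIX}/models/active").json())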

View File

@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""
config = config or default_config
# Create inference service
# Create model path resolver that reads from database
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
logger.warning(f"Failed to get active model from database: {e}")
return None
# Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
model_path_resolver=get_active_model_path,
)
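The InferenceService implementation is not part of this diff; the following is a hedged sketch of how a model_path_resolver callable might be consumed, purely to illustrate the intent of the wiring above (not the project's actual class):
from typing import Callable

class ResolverAwareService:
    """Hypothetical consumer of a model_path_resolver callable."""

    def __init__(
        self,
        default_model_path: str,
        model_path_resolver: Callable[[], str | None] | None = None,
    ) -> None:
        self._default = default_model_path
        self._resolver = model_path_resolver
        self._loaded_path: str | None = None

    def reload_model(self) -> bool:
        """Reload weights if the resolved path changed; return True when a reload happened."""
        target = (self._resolver() if self._resolver else None) or self._default
        if target == self._loaded_path:
            return False
        # ... load weights from `target` here ...
        self._loaded_path = target
        return True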
# Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
# Store inference service for access by routes (e.g., model reload)
# This will be set after app creation
app = FastAPI(
title="Invoice Field Extraction API",
description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include batch upload routes
app.include_router(batch_upload_router)
# Store inference service in app state for access by routes
app.state.inference_service = inference_service
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:

View File

@@ -110,6 +110,7 @@ class TrainingScheduler:
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@ class TrainingScheduler:
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Get augmentation config if present
augmentation_config = config.get("augmentation")
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._db.add_training_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
else:
# Train from pretrained model
effective_model = model_name
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@ class TrainingScheduler:
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Apply augmentation if config is provided
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
aug_result = self._apply_augmentation(
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._db.add_training_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
model_name=effective_model, # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -159,11 +192,94 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _create_model_version_from_training(
self,
task_id: str,
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return
# Get task info for name
task = self._db.get_training_task(task_id)
task_name = task.name if task else f"Task {task_id[:8]}"
# Generate version number based on existing versions
_, version_count = self._db.get_model_versions(limit=1, offset=0)
version = f"v{version_count + 1}.0"
# Extract metrics from result
metrics = result.get("metrics", {})
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
metrics_precision = metrics.get("precision")
metrics_recall = metrics.get("recall")
# Get file size if possible
file_size = None
model_file = Path(model_path)
if model_file.exists():
file_size = model_file.stat().st_size
# Get document count from dataset if available
document_count = 0
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if dataset:
document_count = dataset.total_documents
# Create model version
model_version = self._db.create_model_version(
version=version,
name=task_name,
model_path=str(model_path),
description=f"Auto-created from training task {task_id[:8]}",
task_id=task_id,
dataset_id=dataset_id,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
)
logger.info(
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
self._db.add_training_log(
task_id, "INFO",
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
)
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
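_create_model_version_from_training only needs a small contract from the trainer result. A hedged sketch of the expected shape, based on the keys read above and the metrics produced elsewhere in this file; the path is illustrative:
# Illustrative trainer result; only "model_path" and the metric keys below are consumed.
result = {
    "model_path": "runs/train/invoice_fields/task_ab12cd34/weights/best.pt",
    "metrics": {
        "mAP50": 0.91,
        "mAP50-95": 0.78,
        "precision": 0.89,
        "recall": 0.87,
    },
}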
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
# Create log callback that writes to DB
def log_callback(level: str, message: str) -> None:
self._db.add_training_log(task_id, level, message)
# Create shared training config
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
)
# Run training using shared trainer
trainer = YOLOTrainer(config=config, log_callback=log_callback)
result = trainer.train()
if not result.success:
raise ValueError(result.error or "Training failed")
return {
"model_path": result.model_path,
"metrics": result.metrics,
}
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
"""Check if any augmentations are enabled in the config."""
augmentation_fields = [
"perspective_warp", "wrinkle", "edge_damage", "stain",
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
]
for field in augmentation_fields:
if field in aug_config:
field_config = aug_config[field]
if isinstance(field_config, dict) and field_config.get("enabled", False):
return True
return False
def _apply_augmentation(
self,
task_id: str,
dataset_path: Path,
aug_config: dict[str, Any],
multiplier: int,
) -> dict[str, int] | None:
"""Apply augmentation to dataset before training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
from shared.augmentation import DatasetAugmenter
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
f"Applying augmentation with multiplier={multiplier}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
augmenter = DatasetAugmenter(aug_config)
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
return result
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
return None
# Global scheduler instance

View File

@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse

View File

@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AugmentationParamsSchema(BaseModel):
"""Single augmentation parameters."""
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
probability: float = Field(
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
)
params: dict[str, Any] = Field(
default_factory=dict, description="Type-specific parameters"
)
class AugmentationConfigSchema(BaseModel):
"""Complete augmentation configuration."""
# Geometric transforms
perspective_warp: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Degradation effects
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
edge_damage: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Lighting effects
lighting_variation: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Blur effects
gaussian_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
motion_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Noise effects
gaussian_noise: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
salt_pepper: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Texture effects
paper_texture: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
scanner_artifacts: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Global settings
preserve_bboxes: bool = Field(
default=True, description="Whether to adjust bboxes for geometric transforms"
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
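A short construction sketch for the config schema above; every augmentation defaults to disabled, so only the blocks you switch on matter:
config = AugmentationConfigSchema(
    gaussian_blur=AugmentationParamsSchema(enabled=True, probability=0.3),
    paper_texture=AugmentationParamsSchema(enabled=True, probability=0.5),
    seed=42,
)
# Suitable as the JSON body for the preview-config and batch endpoints.
body = config.model_dump()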
class AugmentationTypeInfo(BaseModel):
"""Information about an augmentation type."""
name: str = Field(..., description="Augmentation name")
description: str = Field(..., description="Augmentation description")
affects_geometry: bool = Field(
..., description="Whether this augmentation affects bbox coordinates"
)
stage: str = Field(..., description="Processing stage")
default_params: dict[str, Any] = Field(
default_factory=dict, description="Default parameters"
)
class AugmentationTypesResponse(BaseModel):
"""Response for listing augmentation types."""
augmentation_types: list[AugmentationTypeInfo] = Field(
..., description="Available augmentation types"
)
class PresetInfo(BaseModel):
"""Information about a preset."""
name: str = Field(..., description="Preset name")
description: str = Field(..., description="Preset description")
class PresetsResponse(BaseModel):
"""Response for listing presets."""
presets: list[PresetInfo] = Field(..., description="Available presets")
class AugmentationPreviewRequest(BaseModel):
"""Request to preview augmentation on an image."""
augmentation_type: str = Field(..., description="Type of augmentation to preview")
params: dict[str, Any] = Field(
default_factory=dict, description="Override parameters"
)
class AugmentationPreviewResponse(BaseModel):
"""Response with preview image data."""
preview_url: str = Field(..., description="URL to preview image")
original_url: str = Field(..., description="URL to original image")
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
class AugmentationBatchRequest(BaseModel):
"""Request to augment a dataset offline."""
dataset_id: str = Field(..., description="Source dataset UUID")
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
output_name: str = Field(
..., min_length=1, max_length=255, description="Output dataset name"
)
multiplier: int = Field(
default=2, ge=1, le=10, description="Augmented copies per image"
)
class AugmentationBatchResponse(BaseModel):
"""Response for batch augmentation."""
task_id: str = Field(..., description="Background task UUID")
status: str = Field(..., description="Task status")
message: str = Field(..., description="Status message")
estimated_images: int = Field(..., description="Estimated total images")
class AugmentedDatasetItem(BaseModel):
"""Single augmented dataset in list."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
created_at: datetime = Field(..., description="Creation timestamp")
class AugmentedDatasetListResponse(BaseModel):
"""Response for listing augmented datasets."""
total: int = Field(..., ge=0, description="Total datasets")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
datasets: list[AugmentedDatasetItem] = Field(
default_factory=list, description="Dataset list"
)
class AugmentedDatasetDetailResponse(BaseModel):
"""Detailed augmented dataset response."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
config: AugmentationConfigSchema | None = Field(
None, description="Augmentation config used"
)
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
dataset_path: str | None = Field(None, description="Dataset path on disk")
error_message: str | None = Field(None, description="Error message if failed")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")

View File

@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
total_documents: int
total_images: int
total_annotations: int

View File

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)

View File

@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class ModelVersionCreateRequest(BaseModel):
"""Request to create a model version."""
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
name: str = Field(..., min_length=1, max_length=255, description="Model name")
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
description: str | None = Field(None, description="Optional description")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
document_count: int = Field(0, ge=0, description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
class ModelVersionUpdateRequest(BaseModel):
"""Request to update a model version."""
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
description: str | None = Field(None, description="Description")
status: str | None = Field(None, description="Status (inactive, archived)")
class ModelVersionItem(BaseModel):
"""Model version in list view."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
class ModelVersionListResponse(BaseModel):
"""Paginated model version list."""
total: int = Field(..., ge=0, description="Total model versions")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
class ModelVersionDetailResponse(BaseModel):
"""Detailed model version info."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
description: str | None = Field(None, description="Description")
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision")
metrics_recall: float | None = Field(None, description="Recall")
document_count: int = Field(..., description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class ModelVersionResponse(BaseModel):
"""Response for model version operation."""
version_id: str = Field(..., description="Version UUID")
status: str = Field(..., description="Model status")
message: str = Field(..., description="Status message")
class ActiveModelResponse(BaseModel):
"""Response for active model query."""
has_active_model: bool = Field(..., description="Whether an active model exists")
model: ModelVersionItem | None = Field(None, description="Active model if exists")

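As a quick illustration of the schemas above, the sketch below builds a ModelVersionCreateRequest and serializes it for a create call; the version string, paths, and metric values are made-up placeholders, not values from this commit, and the import of the request class is omitted because the schema file's path is not shown in this diff.

# Hedged usage sketch for the model version schemas; all values are illustrative placeholders.
from datetime import datetime, timezone

req = ModelVersionCreateRequest(
    version="1.2.0",
    name="invoice-fields-yolo11n",
    model_path="/models/invoice_fields/1.2.0/best.pt",  # placeholder path
    description="Retrained on the latest annotation batch",
    metrics_mAP=0.91,
    metrics_precision=0.93,
    metrics_recall=0.89,
    document_count=350,
    trained_at=datetime.now(timezone.utc),
)
payload = req.model_dump_json(exclude_none=True)  # JSON body for a create request to the admin API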
View File

@@ -5,13 +5,18 @@ from typing import Any
from pydantic import BaseModel, Field
from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
)
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
default="invoice_fields", description="Training project name"
)
# Data augmentation settings
augmentation: AugmentationConfigSchema | None = Field(
default=None,
description="Augmentation configuration. If provided, augments dataset before training.",
)
augmentation_multiplier: int = Field(
default=2,
ge=1,
le=10,
description="Number of augmented copies per original image",
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""

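A minimal sketch of how the two new knobs might be combined: point base_model_version_id at an existing model version for incremental training and request two augmented copies per image. The UUID is a placeholder, augmentation is left at None because AugmentationConfigSchema's fields are not shown in this hunk, and the sketch assumes the TrainingConfig fields outside this hunk keep their defaults.

# Hedged sketch: incremental training on top of an existing model version.
config = TrainingConfig(
    base_model_version_id="4f4f6a2e-8d3b-4b7e-9c11-2a9f0d5e7c01",  # placeholder model version UUID
    epochs=50,
    batch_size=8,
    augmentation=None,           # an AugmentationConfigSchema could be supplied here
    augmentation_multiplier=2,   # two augmented copies per original image
)
# model_name ("yolo11n.pt") is only used when base_model_version_id is not set.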
View File

@@ -0,0 +1,317 @@
"""Augmentation service for handling augmentation operations."""
import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any
import numpy as np
from fastapi import HTTPException
from PIL import Image
from inference.data.admin_db import AdminDB
from inference.web.schemas.admin.augmentation import (
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
)
# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
re.IGNORECASE,
)
class AugmentationService:
"""Service for augmentation operations."""
def __init__(self, db: AdminDB) -> None:
"""Initialize service with database connection."""
self.db = db
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
"""
Validate UUID format to prevent path traversal.
Args:
value: Value to validate.
field_name: Field name for error message.
Raises:
HTTPException: If value is not a valid UUID.
"""
if not UUID_PATTERN.match(value):
raise HTTPException(
status_code=400,
detail=f"Invalid {field_name} format: {value}",
)
async def preview_single(
self,
document_id: str,
page: int,
augmentation_type: str,
params: dict[str, Any],
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
augmentation_type: Name of augmentation to apply.
params: Override parameters.
Returns:
Preview response with image URLs.
Raises:
HTTPException: If document not found or augmentation invalid.
"""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
# Validate augmentation type
if augmentation_type not in AUGMENTATION_REGISTRY:
raise HTTPException(
status_code=400,
detail=f"Unknown augmentation type: {augmentation_type}. "
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
)
# Get document and load image
image = await self._load_document_page(document_id, page)
# Create config with only this augmentation enabled
config_kwargs = {
augmentation_type: AugmentationParams(
enabled=True,
probability=1.0, # Always apply for preview
params=params,
),
"seed": PREVIEW_SEED, # Deterministic preview
}
config = AugmentationConfig(**config_kwargs)
pipeline = AugmentationPipeline(config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=params,
)
async def preview_config(
self,
document_id: str,
page: int,
config: AugmentationConfigSchema,
) -> AugmentationPreviewResponse:
"""
Preview full augmentation config on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
config: Full augmentation configuration.
Returns:
Preview response with image URLs.
"""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
# Load image
image = await self._load_document_page(document_id, page)
# Convert Pydantic model to internal config
config_dict = config.model_dump()
internal_config = AugmentationConfig.from_dict(config_dict)
pipeline = AugmentationPipeline(internal_config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=config_dict,
)
async def create_augmented_dataset(
self,
source_dataset_id: str,
config: AugmentationConfigSchema,
output_name: str,
multiplier: int,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
Args:
source_dataset_id: Source dataset UUID.
config: Augmentation configuration.
output_name: Name for the new dataset.
multiplier: Number of augmented copies per image.
Returns:
Batch response with task ID.
Raises:
HTTPException: If source dataset not found.
"""
# Validate source dataset exists
        try:
            source_dataset = self.db.get_dataset(source_dataset_id)
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=f"Source dataset not found: {source_dataset_id}",
            ) from e
        if source_dataset is None:
            raise HTTPException(
                status_code=404,
                detail=f"Source dataset not found: {source_dataset_id}",
            )
# Create task ID for background processing
task_id = str(uuid.uuid4())
# Estimate total images
estimated_images = (
source_dataset.total_images * multiplier
if hasattr(source_dataset, "total_images")
else 0
)
# TODO: Queue background task for actual augmentation
# For now, return pending status
return AugmentationBatchResponse(
task_id=task_id,
status="pending",
message=f"Augmentation task queued for dataset '{output_name}'",
estimated_images=estimated_images,
)
async def list_augmented_datasets(
self,
limit: int = 20,
offset: int = 0,
) -> AugmentedDatasetListResponse:
"""
List augmented datasets.
Args:
limit: Maximum number of datasets to return.
offset: Number of datasets to skip.
Returns:
List response with datasets.
"""
# TODO: Implement actual database query for augmented datasets
# For now, return empty list
return AugmentedDatasetListResponse(
total=0,
limit=limit,
offset=offset,
datasets=[],
)
async def _load_document_page(
self,
document_id: str,
page: int,
) -> np.ndarray:
"""
Load a document page as numpy array.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
Returns:
Image as numpy array (H, W, C) with dtype uint8.
Raises:
HTTPException: If document or page not found.
"""
# Validate document_id format to prevent path traversal
self._validate_uuid(document_id, "document_id")
# Get document from database
try:
document = self.db.get_document(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
) from e
# Get image path for page
if hasattr(document, "images_dir"):
images_dir = Path(document.images_dir)
else:
# Fallback to constructed path
from inference.web.core.config import get_settings
settings = get_settings()
images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
# Find image for page
        page_idx = page - 1  # Convert to 0-indexed (page < 1 is rejected below)
        image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
        if page_idx < 0 or page_idx >= len(image_files):
            raise HTTPException(
                status_code=404,
                detail=f"Page {page} not found for document {document_id}",
            )
# Load image
image_path = image_files[page_idx]
pil_image = Image.open(image_path).convert("RGB")
return np.array(pil_image)
def _image_to_data_url(self, image: np.ndarray) -> str:
"""Convert numpy image to base64 data URL."""
pil_image = Image.fromarray(image)
# Resize for preview if too large
max_size = PREVIEW_MAX_SIZE
if max(pil_image.size) > max_size:
ratio = max_size / max(pil_image.size)
new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# Convert to base64
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{base64_data}"

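A hedged sketch of how this service might be driven from an async context such as a FastAPI route or a script. The document UUID, the "rotation" registry key, and the params keys are assumptions for illustration; the AugmentationService import is omitted because its module path is not shown in this diff.

# Hedged usage sketch; IDs and augmentation parameters are placeholders.
import asyncio

from inference.data.admin_db import AdminDB

async def demo_preview() -> None:
    service = AugmentationService(db=AdminDB())  # assumes AdminDB() needs no constructor args here
    preview = await service.preview_single(
        document_id="0b6f3c2a-1d4e-4f5a-8b9c-7e6d5c4b3a21",  # placeholder document UUID
        page=1,
        augmentation_type="rotation",   # assumed key in AUGMENTATION_REGISTRY
        params={"max_angle": 5},        # assumed parameter name
    )
    # Both URLs are base64 PNG data URLs, downscaled to at most PREVIEW_MAX_SIZE on the longest side.
    print(preview.preview_url[:48], "...")

if __name__ == "__main__":
    asyncio.run(demo_preview())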
View File

@@ -81,29 +81,18 @@ class DatasetBuilder:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
# 3. Group documents by group_key and assign splits
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
for doc in doc_list:
doc_id = str(doc.document_id)
split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
"total_annotations": total_annotations,
}
def _assign_splits_by_group(
self,
documents: list,
train_ratio: float,
val_ratio: float,
seed: int,
) -> dict[str, str]:
"""Assign splits based on group_key.
Logic:
- Documents with same group_key stay together in the same split
- Groups with only 1 document go directly to train
- Groups with 2+ documents participate in shuffle & split
Args:
documents: List of AdminDocument objects
train_ratio: Fraction for training set
val_ratio: Fraction for validation set
seed: Random seed for reproducibility
Returns:
Dict mapping document_id (str) -> split ("train"/"val"/"test")
"""
# Group documents by group_key
# None/empty group_key treated as unique (each doc is its own group)
groups: dict[str | None, list] = {}
for doc in documents:
key = doc.group_key if doc.group_key else None
if key is None:
# Treat each ungrouped doc as its own unique group
# Use document_id as pseudo-key
key = f"__ungrouped_{doc.document_id}"
groups.setdefault(key, []).append(doc)
# Separate single-doc groups from multi-doc groups
single_doc_groups: list[tuple[str | None, list]] = []
multi_doc_groups: list[tuple[str | None, list]] = []
for key, docs in groups.items():
if len(docs) == 1:
single_doc_groups.append((key, docs))
else:
multi_doc_groups.append((key, docs))
# Initialize result mapping
doc_splits: dict[str, str] = {}
# Combine all groups for splitting
all_groups = single_doc_groups + multi_doc_groups
# Shuffle all groups and assign splits
if all_groups:
rng = random.Random(seed)
rng.shuffle(all_groups)
n_groups = len(all_groups)
n_train = max(1, round(n_groups * train_ratio))
# Ensure at least 1 in val if we have more than 1 group
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
for i, (_key, docs) in enumerate(all_groups):
if i < n_train:
split = "train"
elif i < n_train + n_val:
split = "val"
else:
split = "test"
for doc in docs:
doc_splits[str(doc.document_id)] = split
        logger.info(
            "Split assignment: %d groups shuffled (train_docs=%d, val_docs=%d)",
            len(all_groups),
            sum(1 for s in doc_splits.values() if s == "train"),
            sum(1 for s in doc_splits.values() if s == "val"),
        )
return doc_splits
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {

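To make the grouping behaviour concrete, here is a small self-contained sketch that mirrors the split logic above with mock documents. It reimplements the idea rather than calling the private method, and the group keys and ratios are invented for illustration.

# Hedged, self-contained illustration of group-aware splitting: documents that share a
# group_key always land in the same split; ungrouped docs become one-document groups.
import random
from dataclasses import dataclass

@dataclass
class MockDoc:
    document_id: str
    group_key: str | None

docs = [
    MockDoc("d1", "supplier-acme"), MockDoc("d2", "supplier-acme"),
    MockDoc("d3", "supplier-beta"), MockDoc("d4", None), MockDoc("d5", None),
]

groups: dict[str, list[MockDoc]] = {}
for doc in docs:
    key = doc.group_key or f"__ungrouped_{doc.document_id}"
    groups.setdefault(key, []).append(doc)

group_items = list(groups.items())
random.Random(42).shuffle(group_items)  # deterministic shuffle of whole groups

n_groups = len(group_items)
n_train = max(1, round(n_groups * 0.8))
n_val = max(1 if n_groups > 1 else 0, round(n_groups * 0.1))

splits: dict[str, str] = {}
for i, (_key, members) in enumerate(group_items):
    split = "train" if i < n_train else "val" if i < n_train + n_val else "test"
    for doc in members:
        splits[doc.document_id] = split

# d1 and d2 share "supplier-acme", so they are guaranteed to share a split.
assert splits["d1"] == splits["d2"]
print(splits)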
View File

@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]
@dataclass
class ServiceResult:
"""Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
    Supports dynamic model loading from the database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
model_config: Model configuration (default model settings)
storage_config: Storage configuration
            model_path_resolver: Optional function that resolves the model path
                from the database. If provided, it is called to get the active
                model path; if it returns None, the service falls back to
                model_config.model_path.
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
self._pipeline = None
self._detector = None
self._is_initialized = False
self._current_model_path: Path | None = None
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
Priority:
1. Active model from database (via resolver)
2. Default model from config
"""
if self._model_path_resolver:
try:
db_model_path = self._model_path_resolver()
if db_model_path and Path(db_model_path).exists():
logger.info(f"Using active model from database: {db_model_path}")
return Path(db_model_path)
elif db_model_path:
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
except Exception as e:
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
@@ -74,16 +105,20 @@ class InferenceService:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
# Resolve model path (from DB or config)
model_path = self._resolve_model_path()
self._current_model_path = model_path
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
@@ -92,12 +127,36 @@ class InferenceService:
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
def reload_model(self) -> bool:
"""Reload the model if active model has changed.
Returns:
True if model was reloaded, False if no change needed.
"""
new_model_path = self._resolve_model_path()
if self._current_model_path == new_model_path:
logger.debug("Model unchanged, no reload needed")
return False
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
self._is_initialized = False
self._pipeline = None
self._detector = None
self.initialize()
return True
@property
def current_model_path(self) -> Path | None:
"""Get the currently loaded model path."""
return self._current_model_path
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""