WIP
@@ -25,6 +25,7 @@ from inference.data.admin_models import (
|
||||
AnnotationHistory,
|
||||
TrainingDataset,
|
||||
DatasetDocument,
|
||||
ModelVersion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -110,6 +111,7 @@ class AdminDB:
|
||||
page_count: int = 1,
|
||||
upload_source: str = "ui",
|
||||
csv_field_values: dict[str, Any] | None = None,
|
||||
group_key: str | None = None,
|
||||
admin_token: str | None = None, # Deprecated, kept for compatibility
|
||||
) -> str:
|
||||
"""Create a new document record."""
|
||||
@@ -122,6 +124,7 @@ class AdminDB:
|
||||
page_count=page_count,
|
||||
upload_source=upload_source,
|
||||
csv_field_values=csv_field_values,
|
||||
group_key=group_key,
|
||||
)
|
||||
session.add(document)
|
||||
session.flush()
|
||||
@@ -253,6 +256,17 @@ class AdminDB:
|
||||
document.updated_at = datetime.utcnow()
|
||||
session.add(document)
|
||||
|
||||
def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
|
||||
"""Update document group key."""
|
||||
with get_session_context() as session:
|
||||
document = session.get(AdminDocument, UUID(document_id))
|
||||
if document:
|
||||
document.group_key = group_key
|
||||
document.updated_at = datetime.utcnow()
|
||||
session.add(document)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_document(self, document_id: str) -> bool:
|
||||
"""Delete a document and its annotations."""
|
||||
with get_session_context() as session:
|
||||
@@ -1215,6 +1229,39 @@ class AdminDB:
|
||||
session.expunge(d)
|
||||
return list(datasets), total
|
||||
|
||||
def get_active_training_tasks_for_datasets(
|
||||
self, dataset_ids: list[str]
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Get active (pending/scheduled/running) training tasks for datasets.
|
||||
|
||||
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
|
||||
"""
|
||||
if not dataset_ids:
|
||||
return {}
|
||||
|
||||
# Validate UUIDs before query
|
||||
valid_uuids = []
|
||||
for d in dataset_ids:
|
||||
try:
|
||||
valid_uuids.append(UUID(d))
|
||||
except ValueError:
|
||||
logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
|
||||
continue
|
||||
|
||||
if not valid_uuids:
|
||||
return {}
|
||||
|
||||
with get_session_context() as session:
|
||||
statement = select(TrainingTask).where(
|
||||
TrainingTask.dataset_id.in_(valid_uuids),
|
||||
TrainingTask.status.in_(["pending", "scheduled", "running"]),
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
return {
|
||||
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
|
||||
for t in results
|
||||
}
|
||||
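A minimal usage sketch for the helper above. It assumes an AdminDB instance named db and dataset rows exposing .dataset_id (as in the dataset-listing code later in this commit); neither name is defined here.

# Usage sketch (hypothetical caller):
dataset_ids = [str(d.dataset_id) for d in datasets]
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
for d in datasets:
    info = active_tasks.get(str(d.dataset_id), {})
    print(d.dataset_id, info.get("status"), info.get("task_id"))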
|
||||
def update_dataset_status(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
@@ -1314,3 +1361,182 @@ class AdminDB:
|
||||
session.delete(dataset)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# ==========================================================================
|
||||
# Model Version Operations
|
||||
# ==========================================================================
|
||||
|
||||
def create_model_version(
|
||||
self,
|
||||
version: str,
|
||||
name: str,
|
||||
model_path: str,
|
||||
description: str | None = None,
|
||||
task_id: str | UUID | None = None,
|
||||
dataset_id: str | UUID | None = None,
|
||||
metrics_mAP: float | None = None,
|
||||
metrics_precision: float | None = None,
|
||||
metrics_recall: float | None = None,
|
||||
document_count: int = 0,
|
||||
training_config: dict[str, Any] | None = None,
|
||||
file_size: int | None = None,
|
||||
trained_at: datetime | None = None,
|
||||
) -> ModelVersion:
|
||||
"""Create a new model version."""
|
||||
with get_session_context() as session:
|
||||
model = ModelVersion(
|
||||
version=version,
|
||||
name=name,
|
||||
model_path=model_path,
|
||||
description=description,
|
||||
task_id=UUID(str(task_id)) if task_id else None,
|
||||
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
|
||||
metrics_mAP=metrics_mAP,
|
||||
metrics_precision=metrics_precision,
|
||||
metrics_recall=metrics_recall,
|
||||
document_count=document_count,
|
||||
training_config=training_config,
|
||||
file_size=file_size,
|
||||
trained_at=trained_at,
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Get a model version by ID."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if model:
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def get_model_versions(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ModelVersion], int]:
|
||||
"""List model versions with optional status filter."""
|
||||
with get_session_context() as session:
|
||||
query = select(ModelVersion)
|
||||
count_query = select(func.count()).select_from(ModelVersion)
|
||||
if status:
|
||||
query = query.where(ModelVersion.status == status)
|
||||
count_query = count_query.where(ModelVersion.status == status)
|
||||
total = session.exec(count_query).one()
|
||||
models = session.exec(
|
||||
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
|
||||
).all()
|
||||
for m in models:
|
||||
session.expunge(m)
|
||||
return list(models), total
|
||||
|
||||
def get_active_model_version(self) -> ModelVersion | None:
|
||||
"""Get the currently active model version for inference."""
|
||||
with get_session_context() as session:
|
||||
result = session.exec(
|
||||
select(ModelVersion).where(ModelVersion.is_active == True)
|
||||
).first()
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Activate a model version for inference (deactivates all others)."""
|
||||
with get_session_context() as session:
|
||||
# Deactivate all versions
|
||||
all_versions = session.exec(
|
||||
select(ModelVersion).where(ModelVersion.is_active == True)
|
||||
).all()
|
||||
for v in all_versions:
|
||||
v.is_active = False
|
||||
v.status = "inactive"
|
||||
v.updated_at = datetime.utcnow()
|
||||
session.add(v)
|
||||
|
||||
# Activate the specified version
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
model.is_active = True
|
||||
model.status = "active"
|
||||
model.activated_at = datetime.utcnow()
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
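A small sketch of the single-active invariant that activate_model_version enforces: activating one version flips every previously active row to inactive. The version strings and model paths are placeholders, not values from this commit.

# Sketch, assuming the AdminDB methods defined above:
db = AdminDB()
v1 = db.create_model_version(version="1.0.0", name="baseline", model_path="/models/v1/best.pt")
v2 = db.create_model_version(version="1.1.0", name="retrain", model_path="/models/v2/best.pt")
db.activate_model_version(v1.version_id)
db.activate_model_version(v2.version_id)
active = db.get_active_model_version()
assert active is not None and active.version_id == v2.version_id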
|
||||
def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Deactivate a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
model.is_active = False
|
||||
model.status = "inactive"
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def update_model_version(
|
||||
self,
|
||||
version_id: str | UUID,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> ModelVersion | None:
|
||||
"""Update model version metadata."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
if name is not None:
|
||||
model.name = name
|
||||
if description is not None:
|
||||
model.description = description
|
||||
if status is not None:
|
||||
model.status = status
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
|
||||
"""Archive a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return None
|
||||
# Cannot archive active model
|
||||
if model.is_active:
|
||||
return None
|
||||
model.status = "archived"
|
||||
model.updated_at = datetime.utcnow()
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
session.expunge(model)
|
||||
return model
|
||||
|
||||
def delete_model_version(self, version_id: str | UUID) -> bool:
|
||||
"""Delete a model version."""
|
||||
with get_session_context() as session:
|
||||
model = session.get(ModelVersion, UUID(str(version_id)))
|
||||
if not model:
|
||||
return False
|
||||
# Cannot delete active model
|
||||
if model.is_active:
|
||||
return False
|
||||
session.delete(model)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
|
||||
# Upload source: ui, api
|
||||
batch_id: UUID | None = Field(default=None, index=True)
|
||||
# Link to batch upload (if uploaded via ZIP)
|
||||
group_key: str | None = Field(default=None, max_length=255, index=True)
|
||||
# User-defined grouping key for document organization
|
||||
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Original CSV values for reference
|
||||
auto_label_queued_at: datetime | None = Field(default=None)
|
||||
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Version Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ModelVersion(SQLModel, table=True):
|
||||
"""Model version for inference deployment."""
|
||||
|
||||
__tablename__ = "model_versions"
|
||||
|
||||
version_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
version: str = Field(max_length=50, index=True)
|
||||
# Semantic version e.g., "1.0.0", "2.1.0"
|
||||
name: str = Field(max_length=255)
|
||||
description: str | None = Field(default=None)
|
||||
model_path: str = Field(max_length=512)
|
||||
# Path to the model weights file
|
||||
status: str = Field(default="inactive", max_length=20, index=True)
|
||||
# Status: active, inactive, archived
|
||||
is_active: bool = Field(default=False, index=True)
|
||||
# Only one version can be active at a time for inference
|
||||
|
||||
# Training association
|
||||
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
|
||||
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
|
||||
|
||||
# Training metrics
|
||||
metrics_mAP: float | None = Field(default=None)
|
||||
metrics_precision: float | None = Field(default=None)
|
||||
metrics_recall: float | None = Field(default=None)
|
||||
document_count: int = Field(default=0)
|
||||
# Number of documents used in training
|
||||
|
||||
# Training configuration snapshot
|
||||
training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Snapshot of epochs, batch_size, etc.
|
||||
|
||||
# File info
|
||||
file_size: int | None = Field(default=None)
|
||||
# Model file size in bytes
|
||||
|
||||
# Timestamps
|
||||
trained_at: datetime | None = Field(default=None)
|
||||
# When training completed
|
||||
activated_at: datetime | None = Field(default=None)
|
||||
# When this version was last activated
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Annotation History (v2)
|
||||
# =============================================================================
|
||||
|
||||
@@ -49,6 +49,111 @@ def get_engine():
|
||||
return _engine
|
||||
|
||||
|
||||
def run_migrations() -> None:
|
||||
"""Run database migrations for new columns."""
|
||||
engine = get_engine()
|
||||
|
||||
migrations = [
|
||||
# Migration 004: Training datasets tables and dataset_id on training_tasks
|
||||
(
|
||||
"training_datasets_tables",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS training_datasets (
|
||||
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'building',
|
||||
train_ratio FLOAT NOT NULL DEFAULT 0.8,
|
||||
val_ratio FLOAT NOT NULL DEFAULT 0.1,
|
||||
seed INTEGER NOT NULL DEFAULT 42,
|
||||
total_documents INTEGER NOT NULL DEFAULT 0,
|
||||
total_images INTEGER NOT NULL DEFAULT 0,
|
||||
total_annotations INTEGER NOT NULL DEFAULT 0,
|
||||
dataset_path VARCHAR(512),
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
|
||||
""",
|
||||
),
|
||||
(
|
||||
"dataset_documents_table",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS dataset_documents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
|
||||
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
|
||||
split VARCHAR(10) NOT NULL,
|
||||
page_count INTEGER NOT NULL DEFAULT 0,
|
||||
annotation_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(dataset_id, document_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
|
||||
""",
|
||||
),
|
||||
(
|
||||
"training_tasks_dataset_id",
|
||||
"""
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
|
||||
""",
|
||||
),
|
||||
# Migration 005: Add group_key to admin_documents
|
||||
(
|
||||
"admin_documents_group_key",
|
||||
"""
|
||||
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
|
||||
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
|
||||
""",
|
||||
),
|
||||
# Migration 006: Model versions table
|
||||
(
|
||||
"model_versions_table",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS model_versions (
|
||||
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
version VARCHAR(50) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
model_path VARCHAR(512) NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
|
||||
is_active BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
task_id UUID REFERENCES training_tasks(task_id),
|
||||
dataset_id UUID REFERENCES training_datasets(dataset_id),
|
||||
metrics_mAP FLOAT,
|
||||
metrics_precision FLOAT,
|
||||
metrics_recall FLOAT,
|
||||
document_count INTEGER NOT NULL DEFAULT 0,
|
||||
training_config JSONB,
|
||||
file_size BIGINT,
|
||||
trained_at TIMESTAMP WITH TIME ZONE,
|
||||
activated_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
|
||||
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
with engine.connect() as conn:
|
||||
for name, sql in migrations:
|
||||
try:
|
||||
conn.execute(text(sql))
|
||||
conn.commit()
|
||||
logger.info(f"Migration '{name}' applied successfully")
|
||||
except Exception as e:
|
||||
# Log but don't fail - column may already exist
|
||||
logger.debug(f"Migration '{name}' skipped or failed: {e}")
|
||||
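A hedged verification sketch: SQLAlchemy's inspector can confirm that a migration actually added a column. inspect() and get_columns() are standard SQLAlchemy APIs; the table and column names below are just examples taken from the migrations above.

# Verification sketch (engine comes from get_engine()):
from sqlalchemy import inspect

def column_exists(engine, table_name: str, column_name: str) -> bool:
    columns = {col["name"] for col in inspect(engine).get_columns(table_name)}
    return column_name in columns

# e.g. column_exists(get_engine(), "admin_documents", "group_key") -> True after migration 005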
|
||||
|
||||
def create_db_and_tables() -> None:
|
||||
"""Create all database tables."""
|
||||
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
|
||||
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
logger.info("Database tables created/verified")
|
||||
|
||||
# Run migrations for new columns
|
||||
run_migrations()
|
||||
|
||||
|
||||
def get_session() -> Session:
|
||||
"""Get a new database session."""
|
||||
|
||||
@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
|
||||
"""
|
||||
|
||||
from inference.web.api.v1.admin.annotations import create_annotation_router
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.api.v1.admin.auth import create_auth_router
|
||||
from inference.web.api.v1.admin.documents import create_documents_router
|
||||
from inference.web.api.v1.admin.locks import create_locks_router
|
||||
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router
|
||||
|
||||
__all__ = [
|
||||
"create_annotation_router",
|
||||
"create_augmentation_router",
|
||||
"create_auth_router",
|
||||
"create_documents_router",
|
||||
"create_locks_router",
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Augmentation API module."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import register_augmentation_routes
|
||||
|
||||
|
||||
def create_augmentation_router() -> APIRouter:
|
||||
"""Create and configure the augmentation router."""
|
||||
router = APIRouter(prefix="/augmentation", tags=["augmentation"])
|
||||
register_augmentation_routes(router)
|
||||
return router
|
||||
|
||||
|
||||
__all__ = ["create_augmentation_router"]
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Augmentation API routes."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from inference.web.core.auth import AdminDBDep, AdminTokenDep
|
||||
from inference.web.schemas.admin.augmentation import (
|
||||
AugmentationBatchRequest,
|
||||
AugmentationBatchResponse,
|
||||
AugmentationConfigSchema,
|
||||
AugmentationPreviewRequest,
|
||||
AugmentationPreviewResponse,
|
||||
AugmentationTypeInfo,
|
||||
AugmentationTypesResponse,
|
||||
AugmentedDatasetItem,
|
||||
AugmentedDatasetListResponse,
|
||||
PresetInfo,
|
||||
PresetsResponse,
|
||||
)
|
||||
|
||||
|
||||
def register_augmentation_routes(router: APIRouter) -> None:
|
||||
"""Register augmentation endpoints on the router."""
|
||||
|
||||
@router.get(
|
||||
"/types",
|
||||
response_model=AugmentationTypesResponse,
|
||||
summary="List available augmentation types",
|
||||
)
|
||||
async def list_augmentation_types(
|
||||
admin_token: AdminTokenDep,
|
||||
) -> AugmentationTypesResponse:
|
||||
"""
|
||||
List all available augmentation types with descriptions and parameters.
|
||||
"""
|
||||
from shared.augmentation.pipeline import (
|
||||
AUGMENTATION_REGISTRY,
|
||||
AugmentationPipeline,
|
||||
)
|
||||
|
||||
types = []
|
||||
for name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||
# Create instance with empty params to get preview params
|
||||
aug = aug_class({})
|
||||
types.append(
|
||||
AugmentationTypeInfo(
|
||||
name=name,
|
||||
description=(aug_class.__doc__ or "").strip(),
|
||||
affects_geometry=aug_class.affects_geometry,
|
||||
stage=AugmentationPipeline.STAGE_MAPPING[name],
|
||||
default_params=aug.get_preview_params(),
|
||||
)
|
||||
)
|
||||
|
||||
return AugmentationTypesResponse(augmentation_types=types)
|
||||
|
||||
@router.get(
|
||||
"/presets",
|
||||
response_model=PresetsResponse,
|
||||
summary="Get augmentation presets",
|
||||
)
|
||||
async def get_presets(
|
||||
admin_token: AdminTokenDep,
|
||||
) -> PresetsResponse:
|
||||
"""Get predefined augmentation presets for common use cases."""
|
||||
from shared.augmentation.presets import list_presets
|
||||
|
||||
presets = [PresetInfo(**p) for p in list_presets()]
|
||||
return PresetsResponse(presets=presets)
|
||||
|
||||
@router.post(
|
||||
"/preview/{document_id}",
|
||||
response_model=AugmentationPreviewResponse,
|
||||
summary="Preview augmentation on document image",
|
||||
)
|
||||
async def preview_augmentation(
|
||||
document_id: str,
|
||||
request: AugmentationPreviewRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
page: int = Query(default=1, ge=1, description="Page number"),
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""
|
||||
Preview a single augmentation on a document page.
|
||||
|
||||
Returns URLs to original and augmented preview images.
|
||||
"""
|
||||
from inference.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(db=db)
|
||||
return await service.preview_single(
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
augmentation_type=request.augmentation_type,
|
||||
params=request.params,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/preview-config/{document_id}",
|
||||
response_model=AugmentationPreviewResponse,
|
||||
summary="Preview full augmentation config on document",
|
||||
)
|
||||
async def preview_config(
|
||||
document_id: str,
|
||||
config: AugmentationConfigSchema,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
page: int = Query(default=1, ge=1, description="Page number"),
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""Preview complete augmentation pipeline on a document page."""
|
||||
from inference.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(db=db)
|
||||
return await service.preview_config(
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/batch",
|
||||
response_model=AugmentationBatchResponse,
|
||||
summary="Create augmented dataset (offline preprocessing)",
|
||||
)
|
||||
async def create_augmented_dataset(
|
||||
request: AugmentationBatchRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AugmentationBatchResponse:
|
||||
"""
|
||||
Create a new augmented dataset from an existing dataset.
|
||||
|
||||
This runs as a background task. The augmented images are stored
|
||||
alongside the original dataset for training.
|
||||
"""
|
||||
from inference.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(db=db)
|
||||
return await service.create_augmented_dataset(
|
||||
source_dataset_id=request.dataset_id,
|
||||
config=request.config,
|
||||
output_name=request.output_name,
|
||||
multiplier=request.multiplier,
|
||||
)
|
||||
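An illustrative batch request body for the endpoint above, inferred from how the handler forwards fields to AugmentationService; the exact validation lives in AugmentationBatchRequest, which is not shown in this diff, so treat the shape as a sketch.

# Illustrative batch request body (field names taken from the handler's usage):
batch_request = {
    "dataset_id": "<source dataset UUID>",
    "output_name": "invoices-augmented-x3",
    "multiplier": 3,
    "config": {
        "gaussian_blur": {"enabled": True, "probability": 0.3},
        "shadow": {"enabled": True, "probability": 0.5},
    },
}
# POST /api/v1/admin/augmentation/batch with this body schedules offline augmentation.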
|
||||
@router.get(
|
||||
"/datasets",
|
||||
response_model=AugmentedDatasetListResponse,
|
||||
summary="List augmented datasets",
|
||||
)
|
||||
async def list_augmented_datasets(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
|
||||
offset: int = Query(default=0, ge=0, description="Offset"),
|
||||
) -> AugmentedDatasetListResponse:
|
||||
"""List all augmented datasets."""
|
||||
from inference.web.services.augmentation_service import AugmentationService
|
||||
|
||||
service = AugmentationService(db=db)
|
||||
return await service.list_augmented_datasets(limit=limit, offset=offset)
|
||||
@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
bool,
|
||||
Query(description="Trigger auto-labeling after upload"),
|
||||
] = True,
|
||||
group_key: Annotated[
|
||||
str | None,
|
||||
Query(description="Optional group key for document organization", max_length=255),
|
||||
] = None,
|
||||
) -> DocumentUploadResponse:
|
||||
"""Upload a document for labeling."""
|
||||
# Validate group_key length
|
||||
if group_key and len(group_key) > 255:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Group key must be 255 characters or less",
|
||||
)
|
||||
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
group_key=group_key,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
file_size=len(content),
|
||||
page_count=page_count,
|
||||
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
|
||||
group_key=group_key,
|
||||
auto_label_started=auto_label_started,
|
||||
message="Document uploaded successfully",
|
||||
)
|
||||
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
annotation_count=len(annotations),
|
||||
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
|
||||
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
|
||||
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
|
||||
can_annotate=can_annotate,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
auto_label_error=document.auto_label_error,
|
||||
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
|
||||
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
|
||||
group_key=document.group_key if hasattr(document, 'group_key') else None,
|
||||
csv_field_values=csv_field_values,
|
||||
can_annotate=can_annotate,
|
||||
annotation_lock_until=annotation_lock_until,
|
||||
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
|
||||
return response
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/group-key",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document group key",
|
||||
description="Update the group key for a document.",
|
||||
)
|
||||
async def update_document_group_key(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
group_key: Annotated[
|
||||
str | None,
|
||||
Query(description="New group key (null to clear)"),
|
||||
] = None,
|
||||
) -> dict:
|
||||
"""Update document group key."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate group_key length
|
||||
if group_key and len(group_key) > 255:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Group key must be 255 characters or less",
|
||||
)
|
||||
|
||||
# Verify document exists
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Update group key
|
||||
db.update_document_group_key(document_id, group_key)
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"group_key": group_key,
|
||||
"message": "Document group key updated",
|
||||
}
|
||||
|
||||
return router
|
||||
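A client-side sketch for the new PATCH group-key endpoint. The base URL, route prefix, and admin-token header name are assumptions for illustration and may differ in the actual deployment.

# Hypothetical client call (httpx); adjust URL prefix and auth header to your setup.
import httpx

resp = httpx.patch(
    "http://localhost:8000/api/v1/admin/documents/<document_id>/group-key",
    params={"group_key": "invoices-2024-q1"},
    headers={"X-Admin-Token": "<admin-token>"},
)
resp.raise_for_status()
print(resp.json())  # {"status": "updated", "document_id": ..., "group_key": ..., "message": ...}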
|
||||
@@ -11,6 +11,7 @@ from .tasks import register_task_routes
|
||||
from .documents import register_document_routes
|
||||
from .export import register_export_routes
|
||||
from .datasets import register_dataset_routes
|
||||
from .models import register_model_routes
|
||||
|
||||
|
||||
def create_training_router() -> APIRouter:
|
||||
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
|
||||
register_document_routes(router)
|
||||
register_export_routes(router)
|
||||
register_dataset_routes(router)
|
||||
register_model_routes(router)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
from pathlib import Path
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
# Validate minimum document count for proper train/val/test split
|
||||
if len(request.document_ids) < 10:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
|
||||
)
|
||||
|
||||
dataset = db.create_dataset(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
) -> DatasetListResponse:
|
||||
"""List training datasets."""
|
||||
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
|
||||
|
||||
# Get active training tasks for each dataset (graceful degradation on error)
|
||||
dataset_ids = [str(d.dataset_id) for d in datasets]
|
||||
try:
|
||||
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch active training tasks")
|
||||
active_tasks = {}
|
||||
|
||||
return DatasetListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
name=d.name,
|
||||
description=d.description,
|
||||
status=d.status,
|
||||
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
|
||||
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
|
||||
total_documents=d.total_documents,
|
||||
total_images=d.total_images,
|
||||
total_annotations=d.total_annotations,
|
||||
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
"/datasets/{dataset_id}/train",
|
||||
response_model=TrainingTaskResponse,
|
||||
summary="Start training from dataset",
|
||||
description="Create a training task. Set base_model_version_id in config for incremental training.",
|
||||
)
|
||||
async def train_from_dataset(
|
||||
dataset_id: str,
|
||||
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a training task from a dataset."""
|
||||
"""Create a training task from a dataset.
|
||||
|
||||
For incremental training, set config.base_model_version_id to a model version UUID.
|
||||
The training will use that model as the starting point instead of a pretrained model.
|
||||
"""
|
||||
_validate_uuid(dataset_id, "dataset_id")
|
||||
dataset = db.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
)
|
||||
|
||||
config_dict = request.config.model_dump()
|
||||
|
||||
# Resolve base_model_version_id to actual model path for incremental training
|
||||
base_model_version_id = config_dict.get("base_model_version_id")
|
||||
if base_model_version_id:
|
||||
_validate_uuid(base_model_version_id, "base_model_version_id")
|
||||
base_model = db.get_model_version(base_model_version_id)
|
||||
if not base_model:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Base model version not found: {base_model_version_id}",
|
||||
)
|
||||
# Store the resolved model path for the training worker
|
||||
config_dict["base_model_path"] = base_model.model_path
|
||||
config_dict["base_model_version"] = base_model.version
|
||||
logger.info(
|
||||
"Incremental training: using model %s (%s) as base",
|
||||
base_model.version,
|
||||
base_model.model_path,
|
||||
)
|
||||
|
||||
task_id = db.create_training_task(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type="train",
|
||||
task_type="finetune" if base_model_version_id else "train",
|
||||
config=config_dict,
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
|
||||
if base_model_version_id
|
||||
else "Training task created from dataset"
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.PENDING,
|
||||
message="Training task created from dataset",
|
||||
message=message,
|
||||
)
|
||||
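A sketch of a request body that would take the incremental path in the handler above. Field names follow how the handler reads request.name and request.config; the full request schema is not part of this diff, and the UUID is a placeholder.

payload = {
    "name": "finetune-on-q1-invoices",
    "config": {
        "base_model_version_id": "<existing ModelVersion UUID>",
        "epochs": 50,
        "batch_size": 16,
        "image_size": 640,
    },
}
# POST .../datasets/{dataset_id}/train with this body creates a "finetune" task whose
# config carries the resolved base_model_path and base_model_version.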
|
||||
@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
"/completed-tasks",
|
||||
response_model=TrainingModelsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get trained models",
|
||||
description="Get list of trained models with metrics and download links.",
|
||||
summary="Get completed training tasks",
|
||||
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
|
||||
)
|
||||
async def get_training_models(
|
||||
async def get_completed_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
|
||||
packages/inference/inference/web/api/v1/admin/training/models.py (new file, 333 lines)
@@ -0,0 +1,333 @@
|
||||
"""Model Version Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
ModelVersionItem,
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
)
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_model_routes(router: APIRouter) -> None:
|
||||
"""Register model version endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Create model version",
|
||||
description="Register a new model version for deployment.",
|
||||
)
|
||||
async def create_model_version(
|
||||
request: ModelVersionCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Create a new model version."""
|
||||
if request.task_id:
|
||||
_validate_uuid(request.task_id, "task_id")
|
||||
if request.dataset_id:
|
||||
_validate_uuid(request.dataset_id, "dataset_id")
|
||||
|
||||
model = db.create_model_version(
|
||||
version=request.version,
|
||||
name=request.name,
|
||||
model_path=request.model_path,
|
||||
description=request.description,
|
||||
task_id=request.task_id,
|
||||
dataset_id=request.dataset_id,
|
||||
metrics_mAP=request.metrics_mAP,
|
||||
metrics_precision=request.metrics_precision,
|
||||
metrics_recall=request.metrics_recall,
|
||||
document_count=request.document_count,
|
||||
training_config=request.training_config,
|
||||
file_size=request.file_size,
|
||||
trained_at=request.trained_at,
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=ModelVersionListResponse,
|
||||
summary="List model versions",
|
||||
)
|
||||
async def list_model_versions(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> ModelVersionListResponse:
|
||||
"""List model versions with optional status filter."""
|
||||
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
|
||||
return ModelVersionListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=[
|
||||
ModelVersionItem(
|
||||
version_id=str(m.version_id),
|
||||
version=m.version,
|
||||
name=m.name,
|
||||
status=m.status,
|
||||
is_active=m.is_active,
|
||||
metrics_mAP=m.metrics_mAP,
|
||||
document_count=m.document_count,
|
||||
trained_at=m.trained_at,
|
||||
activated_at=m.activated_at,
|
||||
created_at=m.created_at,
|
||||
)
|
||||
for m in models
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/active",
|
||||
response_model=ActiveModelResponse,
|
||||
summary="Get active model",
|
||||
description="Get the currently active model for inference.",
|
||||
)
|
||||
async def get_active_model(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ActiveModelResponse:
|
||||
"""Get the currently active model version."""
|
||||
model = db.get_active_model_version()
|
||||
if not model:
|
||||
return ActiveModelResponse(has_active_model=False, model=None)
|
||||
|
||||
return ActiveModelResponse(
|
||||
has_active_model=True,
|
||||
model=ModelVersionItem(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
document_count=model.document_count,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionDetailResponse,
|
||||
summary="Get model version detail",
|
||||
)
|
||||
async def get_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionDetailResponse:
|
||||
"""Get detailed model version information."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.get_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionDetailResponse(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
model_path=model.model_path,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
task_id=str(model.task_id) if model.task_id else None,
|
||||
dataset_id=str(model.dataset_id) if model.dataset_id else None,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
metrics_precision=model.metrics_precision,
|
||||
metrics_recall=model.metrics_recall,
|
||||
document_count=model.document_count,
|
||||
training_config=model.training_config,
|
||||
file_size=model.file_size,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
updated_at=model.updated_at,
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Update model version",
|
||||
)
|
||||
async def update_model_version(
|
||||
version_id: str,
|
||||
request: ModelVersionUpdateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Update model version metadata."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.update_model_version(
|
||||
version_id=version_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
status=request.status,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version updated successfully",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/activate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Activate model version",
|
||||
description="Activate a model version for inference (deactivates all others).",
|
||||
)
|
||||
async def activate_model_version(
|
||||
version_id: str,
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Activate a model version for inference."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.activate_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Trigger model reload in inference service
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
model_reloaded = False
|
||||
if inference_service:
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info(f"Inference model reloaded to version {model.version}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reload inference model: {e}")
|
||||
|
||||
message = "Model version activated for inference"
|
||||
if model_reloaded:
|
||||
message += " (model reloaded)"
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/deactivate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Deactivate model version",
|
||||
)
|
||||
async def deactivate_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Deactivate a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.deactivate_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version deactivated",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/archive",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Archive model version",
|
||||
)
|
||||
async def archive_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Archive a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.archive_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot archive active model",
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version archived",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/models/{version_id}",
|
||||
summary="Delete model version",
|
||||
)
|
||||
async def delete_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
success = db.delete_model_version(version_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot delete active model",
|
||||
)
|
||||
|
||||
return {"message": "Model version deleted"}
|
||||
|
||||
@router.post(
|
||||
"/models/reload",
|
||||
summary="Reload inference model",
|
||||
description="Reload the inference model from the currently active model version.",
|
||||
)
|
||||
async def reload_inference_model(
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
) -> dict:
|
||||
"""Reload the inference model from active version."""
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
if not inference_service:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Inference service not available",
|
||||
)
|
||||
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info("Inference model manually reloaded")
|
||||
return {"message": "Model reloaded successfully", "reloaded": True}
|
||||
else:
|
||||
return {"message": "Model already up to date", "reloaded": False}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to reload model: {e}",
|
||||
)
|
||||
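A hedged end-to-end sketch of the activate-then-serve flow these endpoints enable. The route prefix and auth header name are assumptions (the training router's prefix is not shown in this diff); adjust them to the actual deployment.

# Hypothetical operator flow (httpx):
import httpx

BASE = "http://localhost:8000/api/v1/admin/training"
HEADERS = {"X-Admin-Token": "<admin-token>"}

# 1. Pick a version from the list
versions = httpx.get(f"{BASE}/models", headers=HEADERS).json()["models"]
version_id = versions[0]["version_id"]

# 2. Activate it (this also attempts an in-process model reload)
httpx.post(f"{BASE}/models/{version_id}/activate", headers=HEADERS).raise_for_status()

# 3. Confirm which model is now serving inference
print(httpx.get(f"{BASE}/models/active", headers=HEADERS).json())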
@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
|
||||
# Admin API imports
|
||||
from inference.web.api.v1.admin import (
|
||||
create_annotation_router,
|
||||
create_augmentation_router,
|
||||
create_auth_router,
|
||||
create_documents_router,
|
||||
create_locks_router,
|
||||
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
"""
|
||||
config = config or default_config
|
||||
|
||||
# Create inference service
|
||||
# Create model path resolver that reads from database
|
||||
def get_active_model_path():
|
||||
"""Resolve active model path from database."""
|
||||
try:
|
||||
db = AdminDB()
|
||||
active_model = db.get_active_model_version()
|
||||
if active_model and active_model.model_path:
|
||||
return active_model.model_path
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get active model from database: {e}")
|
||||
return None
|
||||
|
||||
# Create inference service with database model resolver
|
||||
inference_service = InferenceService(
|
||||
model_config=config.model,
|
||||
storage_config=config.storage,
|
||||
model_path_resolver=get_active_model_path,
|
||||
)
|
||||
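A sketch of how a model_path_resolver callback could be consumed. InferenceService internals are not part of this commit, so this is only an assumed shape that matches the reload_model semantics used by the model routes (returns True when weights were actually swapped, False when already up to date).

# Assumed consumer shape for model_path_resolver (illustrative only):
class ResolverBackedLoader:
    def __init__(self, model_path_resolver=None, fallback_path: str = "models/best.pt"):
        self._resolver = model_path_resolver
        self._fallback = fallback_path
        self._current_path: str | None = None

    def reload_model(self) -> bool:
        path = (self._resolver() if self._resolver else None) or self._fallback
        if path == self._current_path:
            return False  # already up to date
        # ... load weights from `path` here ...
        self._current_path = path
        return True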
|
||||
# Create async processing components
|
||||
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
logger.error(f"Error closing database: {e}")
|
||||
|
||||
# Create FastAPI app
|
||||
# Store inference service for access by routes (e.g., model reload)
|
||||
# This will be set after app creation
|
||||
|
||||
app = FastAPI(
|
||||
title="Invoice Field Extraction API",
|
||||
description="""
|
||||
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
training_router = create_training_router()
|
||||
app.include_router(training_router, prefix="/api/v1")
|
||||
|
||||
augmentation_router = create_augmentation_router()
|
||||
app.include_router(augmentation_router, prefix="/api/v1/admin")
|
||||
|
||||
# Include batch upload routes
|
||||
app.include_router(batch_upload_router)
|
||||
|
||||
# Store inference service in app state for access by routes
|
||||
app.state.inference_service = inference_service
|
||||
|
||||
# Root endpoint - serve HTML UI
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root() -> str:
|
||||
|
||||
@@ -110,6 +110,7 @@ class TrainingScheduler:
|
||||
try:
|
||||
# Get training configuration
|
||||
model_name = config.get("model_name", "yolo11n.pt")
|
||||
base_model_path = config.get("base_model_path") # For incremental training
|
||||
epochs = config.get("epochs", 100)
|
||||
batch_size = config.get("batch_size", 16)
|
||||
image_size = config.get("image_size", 640)
|
||||
@@ -117,12 +118,31 @@ class TrainingScheduler:
|
||||
device = config.get("device", "0")
|
||||
project_name = config.get("project_name", "invoice_fields")
|
||||
|
||||
# Get augmentation config if present
|
||||
augmentation_config = config.get("augmentation")
|
||||
augmentation_multiplier = config.get("augmentation_multiplier", 2)
|
||||
|
||||
# Determine which model to use as base
|
||||
if base_model_path:
|
||||
# Incremental training: use existing trained model
|
||||
if not Path(base_model_path).exists():
|
||||
raise ValueError(f"Base model not found: {base_model_path}")
|
||||
effective_model = base_model_path
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Incremental training from: {base_model_path}",
|
||||
)
|
||||
else:
|
||||
# Train from pretrained model
|
||||
effective_model = model_name
|
||||
|
||||
# Use dataset if available, otherwise export from scratch
|
||||
if dataset_id:
|
||||
dataset = self._db.get_dataset(dataset_id)
|
||||
if not dataset or not dataset.dataset_path:
|
||||
raise ValueError(f"Dataset {dataset_id} not found or has no path")
|
||||
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
|
||||
dataset_path = Path(dataset.dataset_path)
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
|
||||
@@ -132,15 +152,28 @@ class TrainingScheduler:
|
||||
if not export_result:
|
||||
raise ValueError("Failed to export training data")
|
||||
data_yaml = export_result["data_yaml"]
|
||||
dataset_path = Path(data_yaml).parent
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Exported {export_result['total_images']} images for training",
|
||||
)
|
||||
|
||||
# Apply augmentation if config is provided
|
||||
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
|
||||
aug_result = self._apply_augmentation(
|
||||
task_id, dataset_path, augmentation_config, augmentation_multiplier
|
||||
)
|
||||
if aug_result:
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Augmentation complete: {aug_result['augmented_images']} new images "
|
||||
f"(total: {aug_result['total_images']})",
|
||||
)
|
||||
|
||||
# Run YOLO training
|
||||
result = self._run_yolo_training(
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
model_name=effective_model, # Use base model or pretrained model
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
@@ -159,11 +192,94 @@ class TrainingScheduler:
|
||||
)
|
||||
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
|
||||
|
||||
# Auto-create model version for the completed training
|
||||
self._create_model_version_from_training(
|
||||
task_id=task_id,
|
||||
config=config,
|
||||
dataset_id=dataset_id,
|
||||
result=result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
raise
|
||||
|
||||
def _create_model_version_from_training(
|
||||
self,
|
||||
task_id: str,
|
||||
config: dict[str, Any],
|
||||
dataset_id: str | None,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Create a model version entry from completed training."""
|
||||
try:
|
||||
model_path = result.get("model_path")
|
||||
if not model_path:
|
||||
logger.warning(f"No model path in training result for task {task_id}")
|
||||
return
|
||||
|
||||
# Get task info for name
|
||||
task = self._db.get_training_task(task_id)
|
||||
task_name = task.name if task else f"Task {task_id[:8]}"
|
||||
|
||||
# Generate version number based on existing versions
|
||||
existing_versions = self._db.get_model_versions(limit=1, offset=0)
|
||||
version_count = existing_versions[1] if existing_versions else 0
|
||||
version = f"v{version_count + 1}.0"
|
||||
|
||||
# Extract metrics from result
|
||||
metrics = result.get("metrics", {})
|
||||
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
|
||||
metrics_precision = metrics.get("precision")
|
||||
metrics_recall = metrics.get("recall")
|
||||
|
||||
# Get file size if possible
|
||||
file_size = None
|
||||
model_file = Path(model_path)
|
||||
if model_file.exists():
|
||||
file_size = model_file.stat().st_size
|
||||
|
||||
# Get document count from dataset if available
|
||||
document_count = 0
|
||||
if dataset_id:
|
||||
dataset = self._db.get_dataset(dataset_id)
|
||||
if dataset:
|
||||
document_count = dataset.total_documents
|
||||
|
||||
# Create model version
|
||||
model_version = self._db.create_model_version(
|
||||
version=version,
|
||||
name=task_name,
|
||||
model_path=str(model_path),
|
||||
description=f"Auto-created from training task {task_id[:8]}",
|
||||
task_id=task_id,
|
||||
dataset_id=dataset_id,
|
||||
metrics_mAP=metrics_mAP,
|
||||
metrics_precision=metrics_precision,
|
||||
metrics_recall=metrics_recall,
|
||||
document_count=document_count,
|
||||
training_config=config,
|
||||
file_size=file_size,
|
||||
trained_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created model version {version} (ID: {model_version.version_id}) "
|
||||
f"from training task {task_id}"
|
||||
)
|
||||
self._db.add_training_log(
|
||||
task_id, "INFO",
|
||||
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model version for task {task_id}: {e}")
|
||||
self._db.add_training_log(
|
||||
task_id, "WARNING",
|
||||
f"Failed to auto-create model version: {e}",
|
||||
)
|
||||
|
||||
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
from pathlib import Path
|
||||
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
|
||||
device: str,
|
||||
project_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run YOLO training."""
|
||||
"""Run YOLO training using shared trainer."""
|
||||
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
|
||||
|
||||
# Create log callback that writes to DB
|
||||
def log_callback(level: str, message: str) -> None:
|
||||
self._db.add_training_log(task_id, level, message)
|
||||
|
||||
# Create shared training config
|
||||
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
|
||||
config = SharedTrainingConfig(
|
||||
model_path=model_name,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
image_size=image_size,
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project="runs/train",
|
||||
name=f"{project_name}/task_{task_id[:8]}",
|
||||
workers=0,
|
||||
)
|
||||
|
||||
# Run training using shared trainer
|
||||
trainer = YOLOTrainer(config=config, log_callback=log_callback)
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
raise ValueError(result.error or "Training failed")
|
||||
|
||||
return {
|
||||
"model_path": result.model_path,
|
||||
"metrics": result.metrics,
|
||||
}
|
||||
|
||||
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
|
||||
"""Check if any augmentations are enabled in the config."""
|
||||
augmentation_fields = [
|
||||
"perspective_warp", "wrinkle", "edge_damage", "stain",
|
||||
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
|
||||
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
|
||||
]
|
||||
for field in augmentation_fields:
|
||||
if field in aug_config:
|
||||
field_config = aug_config[field]
|
||||
if isinstance(field_config, dict) and field_config.get("enabled", False):
|
||||
return True
|
||||
return False
|
||||
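An example of the config shape that _has_enabled_augmentations inspects; the keys mirror the field list above, while the values inside "params" are placeholders.

# Example augmentation config:
aug_config = {
    "gaussian_blur": {"enabled": True, "probability": 0.3, "params": {"kernel_size": 3}},
    "stain": {"enabled": False},
    "shadow": {"probability": 0.5},  # no "enabled": True -> ignored
}
# self._has_enabled_augmentations(aug_config) -> True (gaussian_blur is enabled)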
|
||||
    def _apply_augmentation(
        self,
        task_id: str,
        dataset_path: Path,
        aug_config: dict[str, Any],
        multiplier: int,
    ) -> dict[str, int] | None:
        """Apply augmentation to dataset before training."""
        try:
            from shared.augmentation import DatasetAugmenter

            self._db.add_training_log(
                task_id, "INFO",
                f"Applying augmentation with multiplier={multiplier}",
            )

            augmenter = DatasetAugmenter(aug_config)
            result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)

            return result

        except Exception as e:
            logger.error(f"Augmentation failed for task {task_id}: {e}")
            self._db.add_training_log(
                task_id, "WARNING",
                f"Augmentation failed: {e}. Continuing with original dataset.",
            )
            return None

# Global scheduler instance

@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403

# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
packages/inference/inference/web/schemas/admin/augmentation.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AugmentationParamsSchema(BaseModel):
|
||||
"""Single augmentation parameters."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
|
||||
probability: float = Field(
|
||||
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
|
||||
)
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Type-specific parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationConfigSchema(BaseModel):
|
||||
"""Complete augmentation configuration."""
|
||||
|
||||
# Geometric transforms
|
||||
perspective_warp: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Degradation effects
|
||||
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
edge_damage: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
|
||||
# Lighting effects
|
||||
lighting_variation: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
|
||||
# Blur effects
|
||||
gaussian_blur: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
motion_blur: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Noise effects
|
||||
gaussian_noise: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
salt_pepper: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Texture effects
|
||||
paper_texture: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
scanner_artifacts: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Global settings
|
||||
preserve_bboxes: bool = Field(
|
||||
default=True, description="Whether to adjust bboxes for geometric transforms"
|
||||
)
|
||||
seed: int | None = Field(default=None, description="Random seed for reproducibility")
|
||||
|
||||
|
||||
class AugmentationTypeInfo(BaseModel):
|
||||
"""Information about an augmentation type."""
|
||||
|
||||
name: str = Field(..., description="Augmentation name")
|
||||
description: str = Field(..., description="Augmentation description")
|
||||
affects_geometry: bool = Field(
|
||||
..., description="Whether this augmentation affects bbox coordinates"
|
||||
)
|
||||
stage: str = Field(..., description="Processing stage")
|
||||
default_params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Default parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationTypesResponse(BaseModel):
|
||||
"""Response for listing augmentation types."""
|
||||
|
||||
augmentation_types: list[AugmentationTypeInfo] = Field(
|
||||
..., description="Available augmentation types"
|
||||
)
|
||||
|
||||
|
||||
class PresetInfo(BaseModel):
|
||||
"""Information about a preset."""
|
||||
|
||||
name: str = Field(..., description="Preset name")
|
||||
description: str = Field(..., description="Preset description")
|
||||
|
||||
|
||||
class PresetsResponse(BaseModel):
|
||||
"""Response for listing presets."""
|
||||
|
||||
presets: list[PresetInfo] = Field(..., description="Available presets")
|
||||
|
||||
|
||||
class AugmentationPreviewRequest(BaseModel):
|
||||
"""Request to preview augmentation on an image."""
|
||||
|
||||
augmentation_type: str = Field(..., description="Type of augmentation to preview")
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Override parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationPreviewResponse(BaseModel):
|
||||
"""Response with preview image data."""
|
||||
|
||||
preview_url: str = Field(..., description="URL to preview image")
|
||||
original_url: str = Field(..., description="URL to original image")
|
||||
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
|
||||
|
||||
|
||||
class AugmentationBatchRequest(BaseModel):
|
||||
"""Request to augment a dataset offline."""
|
||||
|
||||
dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
|
||||
output_name: str = Field(
|
||||
..., min_length=1, max_length=255, description="Output dataset name"
|
||||
)
|
||||
multiplier: int = Field(
|
||||
default=2, ge=1, le=10, description="Augmented copies per image"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationBatchResponse(BaseModel):
|
||||
"""Response for batch augmentation."""
|
||||
|
||||
task_id: str = Field(..., description="Background task UUID")
|
||||
status: str = Field(..., description="Task status")
|
||||
message: str = Field(..., description="Status message")
|
||||
estimated_images: int = Field(..., description="Estimated total images")
|
||||
|
||||
|
||||
class AugmentedDatasetItem(BaseModel):
|
||||
"""Single augmented dataset in list."""
|
||||
|
||||
dataset_id: str = Field(..., description="Dataset UUID")
|
||||
source_dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
name: str = Field(..., description="Dataset name")
|
||||
status: str = Field(..., description="Dataset status")
|
||||
multiplier: int = Field(..., description="Augmentation multiplier")
|
||||
total_original_images: int = Field(..., description="Original image count")
|
||||
total_augmented_images: int = Field(..., description="Augmented image count")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class AugmentedDatasetListResponse(BaseModel):
|
||||
"""Response for listing augmented datasets."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total datasets")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
datasets: list[AugmentedDatasetItem] = Field(
|
||||
default_factory=list, description="Dataset list"
|
||||
)
|
||||
|
||||
|
||||
class AugmentedDatasetDetailResponse(BaseModel):
|
||||
"""Detailed augmented dataset response."""
|
||||
|
||||
dataset_id: str = Field(..., description="Dataset UUID")
|
||||
source_dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
name: str = Field(..., description="Dataset name")
|
||||
status: str = Field(..., description="Dataset status")
|
||||
config: AugmentationConfigSchema | None = Field(
|
||||
None, description="Augmentation config used"
|
||||
)
|
||||
multiplier: int = Field(..., description="Augmentation multiplier")
|
||||
total_original_images: int = Field(..., description="Original image count")
|
||||
total_augmented_images: int = Field(..., description="Augmented image count")
|
||||
dataset_path: str | None = Field(None, description="Dataset path on disk")
|
||||
error_message: str | None = Field(None, description="Error message if failed")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
completed_at: datetime | None = Field(None, description="Completion timestamp")
|
||||
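A minimal construction sketch (illustrative only; the field values and the "kernel_size" parameter name are assumptions, not taken from this commit):

    config = AugmentationConfigSchema(
        wrinkle=AugmentationParamsSchema(enabled=True, probability=0.7),
        gaussian_blur=AugmentationParamsSchema(enabled=True, params={"kernel_size": 3}),
        seed=123,
    )
    payload = config.model_dump()  # plain dict, e.g. for TrainingConfig.augmentation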
@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
    name: str
    description: str | None
    status: str
    training_status: str | None = None
    active_training_task_id: str | None = None
    total_documents: int
    total_images: int
    total_annotations: int
@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    group_key: str | None = Field(None, description="User-defined group key")
    auto_label_started: bool = Field(
        default=False, description="Whether auto-labeling was started"
    )
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
    annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    group_key: str | None = Field(None, description="User-defined group key")
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
    auto_label_error: str | None = Field(None, description="Auto-labeling error")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    group_key: str | None = Field(None, description="User-defined group key")
    csv_field_values: dict[str, str] | None = Field(
        None, description="CSV field values if uploaded via batch"
    )
packages/inference/inference/web/schemas/admin/models.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelVersionCreateRequest(BaseModel):
|
||||
"""Request to create a model version."""
|
||||
|
||||
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Model name")
|
||||
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
|
||||
description: str | None = Field(None, description="Optional description")
|
||||
task_id: str | None = Field(None, description="Training task UUID")
|
||||
dataset_id: str | None = Field(None, description="Dataset UUID")
|
||||
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
|
||||
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
|
||||
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
|
||||
document_count: int = Field(0, ge=0, description="Documents used in training")
|
||||
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
|
||||
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
|
||||
|
||||
class ModelVersionUpdateRequest(BaseModel):
|
||||
"""Request to update a model version."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
|
||||
description: str | None = Field(None, description="Description")
|
||||
status: str | None = Field(None, description="Status (inactive, archived)")
|
||||
|
||||
|
||||
class ModelVersionItem(BaseModel):
|
||||
"""Model version in list view."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
name: str = Field(..., description="Model name")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
document_count: int = Field(..., description="Documents used in training")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
activated_at: datetime | None = Field(None, description="Last activation time")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class ModelVersionListResponse(BaseModel):
|
||||
"""Paginated model version list."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total model versions")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
|
||||
|
||||
|
||||
class ModelVersionDetailResponse(BaseModel):
|
||||
"""Detailed model version info."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
name: str = Field(..., description="Model name")
|
||||
description: str | None = Field(None, description="Description")
|
||||
model_path: str = Field(..., description="Path to model file")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
task_id: str | None = Field(None, description="Training task UUID")
|
||||
dataset_id: str | None = Field(None, description="Dataset UUID")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
metrics_precision: float | None = Field(None, description="Precision")
|
||||
metrics_recall: float | None = Field(None, description="Recall")
|
||||
document_count: int = Field(..., description="Documents used in training")
|
||||
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
|
||||
file_size: int | None = Field(None, description="Model file size in bytes")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
activated_at: datetime | None = Field(None, description="Last activation time")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class ModelVersionResponse(BaseModel):
|
||||
"""Response for model version operation."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
status: str = Field(..., description="Model status")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class ActiveModelResponse(BaseModel):
|
||||
"""Response for active model query."""
|
||||
|
||||
has_active_model: bool = Field(..., description="Whether an active model exists")
|
||||
model: ModelVersionItem | None = Field(None, description="Active model if exists")
|
||||
@@ -5,13 +5,18 @@ from typing import Any

from pydantic import BaseModel, Field

from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType


class TrainingConfig(BaseModel):
    """Training configuration."""

    model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
    base_model_version_id: str | None = Field(
        default=None,
        description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
    )
    epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
    batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
    image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
        default="invoice_fields", description="Training project name"
    )

    # Data augmentation settings
    augmentation: AugmentationConfigSchema | None = Field(
        default=None,
        description="Augmentation configuration. If provided, augments dataset before training.",
    )
    augmentation_multiplier: int = Field(
        default=2,
        ge=1,
        le=10,
        description="Number of augmented copies per original image",
    )


class TrainingTaskCreate(BaseModel):
    """Request to create a training task."""
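A hedged example of how the new fields might be combined in an incremental-training request (the UUID and values are placeholders, not from this commit):

    config = TrainingConfig(
        base_model_version_id="00000000-0000-0000-0000-000000000000",
        epochs=50,
        augmentation=AugmentationConfigSchema(
            wrinkle=AugmentationParamsSchema(enabled=True, probability=0.5),
        ),
        augmentation_multiplier=3,
    )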
@@ -0,0 +1,317 @@
"""Augmentation service for handling augmentation operations."""

import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any

import numpy as np
from fastapi import HTTPException
from PIL import Image

from inference.data.admin_db import AdminDB
from inference.web.schemas.admin.augmentation import (
    AugmentationBatchResponse,
    AugmentationConfigSchema,
    AugmentationPreviewResponse,
    AugmentedDatasetItem,
    AugmentedDatasetListResponse,
)

# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


class AugmentationService:
    """Service for augmentation operations."""

    def __init__(self, db: AdminDB) -> None:
        """Initialize service with database connection."""
        self.db = db

    def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
        """
        Validate UUID format to prevent path traversal.

        Args:
            value: Value to validate.
            field_name: Field name for error message.

        Raises:
            HTTPException: If value is not a valid UUID.
        """
        if not UUID_PATTERN.match(value):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid {field_name} format: {value}",
            )

    async def preview_single(
        self,
        document_id: str,
        page: int,
        augmentation_type: str,
        params: dict[str, Any],
    ) -> AugmentationPreviewResponse:
        """
        Preview a single augmentation on a document page.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).
            augmentation_type: Name of augmentation to apply.
            params: Override parameters.

        Returns:
            Preview response with image URLs.

        Raises:
            HTTPException: If document not found or augmentation invalid.
        """
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline

        # Validate augmentation type
        if augmentation_type not in AUGMENTATION_REGISTRY:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown augmentation type: {augmentation_type}. "
                f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
            )

        # Get document and load image
        image = await self._load_document_page(document_id, page)

        # Create config with only this augmentation enabled
        config_kwargs = {
            augmentation_type: AugmentationParams(
                enabled=True,
                probability=1.0,  # Always apply for preview
                params=params,
            ),
            "seed": PREVIEW_SEED,  # Deterministic preview
        }
        config = AugmentationConfig(**config_kwargs)
        pipeline = AugmentationPipeline(config)

        # Apply augmentation
        result = pipeline.apply(image)

        # Convert to base64 URLs
        original_url = self._image_to_data_url(image)
        preview_url = self._image_to_data_url(result.image)

        return AugmentationPreviewResponse(
            preview_url=preview_url,
            original_url=original_url,
            applied_params=params,
        )

    async def preview_config(
        self,
        document_id: str,
        page: int,
        config: AugmentationConfigSchema,
    ) -> AugmentationPreviewResponse:
        """
        Preview full augmentation config on a document page.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).
            config: Full augmentation configuration.

        Returns:
            Preview response with image URLs.
        """
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        # Load image
        image = await self._load_document_page(document_id, page)

        # Convert Pydantic model to internal config
        config_dict = config.model_dump()
        internal_config = AugmentationConfig.from_dict(config_dict)
        pipeline = AugmentationPipeline(internal_config)

        # Apply augmentation
        result = pipeline.apply(image)

        # Convert to base64 URLs
        original_url = self._image_to_data_url(image)
        preview_url = self._image_to_data_url(result.image)

        return AugmentationPreviewResponse(
            preview_url=preview_url,
            original_url=original_url,
            applied_params=config_dict,
        )

    async def create_augmented_dataset(
        self,
        source_dataset_id: str,
        config: AugmentationConfigSchema,
        output_name: str,
        multiplier: int,
    ) -> AugmentationBatchResponse:
        """
        Create a new augmented dataset from an existing dataset.

        Args:
            source_dataset_id: Source dataset UUID.
            config: Augmentation configuration.
            output_name: Name for the new dataset.
            multiplier: Number of augmented copies per image.

        Returns:
            Batch response with task ID.

        Raises:
            HTTPException: If source dataset not found.
        """
        # Validate source dataset exists
        try:
            source_dataset = self.db.get_dataset(source_dataset_id)
            if source_dataset is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"Source dataset not found: {source_dataset_id}",
                )
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=f"Source dataset not found: {source_dataset_id}",
            ) from e

        # Create task ID for background processing
        task_id = str(uuid.uuid4())

        # Estimate total images
        estimated_images = (
            source_dataset.total_images * multiplier
            if hasattr(source_dataset, "total_images")
            else 0
        )

        # TODO: Queue background task for actual augmentation
        # For now, return pending status

        return AugmentationBatchResponse(
            task_id=task_id,
            status="pending",
            message=f"Augmentation task queued for dataset '{output_name}'",
            estimated_images=estimated_images,
        )

    async def list_augmented_datasets(
        self,
        limit: int = 20,
        offset: int = 0,
    ) -> AugmentedDatasetListResponse:
        """
        List augmented datasets.

        Args:
            limit: Maximum number of datasets to return.
            offset: Number of datasets to skip.

        Returns:
            List response with datasets.
        """
        # TODO: Implement actual database query for augmented datasets
        # For now, return empty list

        return AugmentedDatasetListResponse(
            total=0,
            limit=limit,
            offset=offset,
            datasets=[],
        )

    async def _load_document_page(
        self,
        document_id: str,
        page: int,
    ) -> np.ndarray:
        """
        Load a document page as numpy array.

        Args:
            document_id: Document UUID.
            page: Page number (1-indexed).

        Returns:
            Image as numpy array (H, W, C) with dtype uint8.

        Raises:
            HTTPException: If document or page not found.
        """
        # Validate document_id format to prevent path traversal
        self._validate_uuid(document_id, "document_id")

        # Get document from database
        try:
            document = self.db.get_document(document_id)
            if document is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"Document not found: {document_id}",
                )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=f"Document not found: {document_id}",
            ) from e

        # Get image path for page
        if hasattr(document, "images_dir"):
            images_dir = Path(document.images_dir)
        else:
            # Fallback to constructed path
            from inference.web.core.config import get_settings

            settings = get_settings()
            images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"

        # Find image for page
        page_idx = page - 1  # Convert to 0-indexed
        image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))

        if page_idx >= len(image_files):
            raise HTTPException(
                status_code=404,
                detail=f"Page {page} not found for document {document_id}",
            )

        # Load image
        image_path = image_files[page_idx]
        pil_image = Image.open(image_path).convert("RGB")
        return np.array(pil_image)

    def _image_to_data_url(self, image: np.ndarray) -> str:
        """Convert numpy image to base64 data URL."""
        pil_image = Image.fromarray(image)

        # Resize for preview if too large
        max_size = PREVIEW_MAX_SIZE
        if max(pil_image.size) > max_size:
            ratio = max_size / max(pil_image.size)
            new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)

        # Convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return f"data:image/png;base64,{base64_data}"
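A rough usage sketch (assumptions: the service is constructed from an existing AdminDB, and the UUID and "kernel_size" parameter are placeholders, not from this commit):

    service = AugmentationService(db=AdminDB())
    preview = await service.preview_single(
        document_id="00000000-0000-0000-0000-000000000000",
        page=1,
        augmentation_type="gaussian_blur",
        params={"kernel_size": 3},
    )
    # preview.preview_url and preview.original_url are data: URLs suitable for an <img> src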
@@ -81,29 +81,18 @@ class DatasetBuilder:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Group documents by group_key and assign splits
        doc_list = list(documents)
        doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []

        for doc in doc_list:
            doc_id = str(doc.document_id)
            split = doc_splits[doc_id]
            annotations = self._db.get_annotations_for_document(doc.document_id)

            # Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
            "total_annotations": total_annotations,
        }
    def _assign_splits_by_group(
        self,
        documents: list,
        train_ratio: float,
        val_ratio: float,
        seed: int,
    ) -> dict[str, str]:
        """Assign splits based on group_key.

        Logic:
        - Documents with the same group_key stay together in the same split
        - Documents without a group_key are treated as single-document groups
        - All groups are shuffled together and assigned to train/val/test by ratio
          (at least one group goes to train, and to val when more than one group exists)

        Args:
            documents: List of AdminDocument objects
            train_ratio: Fraction for training set
            val_ratio: Fraction for validation set
            seed: Random seed for reproducibility

        Returns:
            Dict mapping document_id (str) -> split ("train"/"val"/"test")
        """
        # Group documents by group_key
        # None/empty group_key treated as unique (each doc is its own group)
        groups: dict[str | None, list] = {}
        for doc in documents:
            key = doc.group_key if doc.group_key else None
            if key is None:
                # Treat each ungrouped doc as its own unique group
                # Use document_id as pseudo-key
                key = f"__ungrouped_{doc.document_id}"
            groups.setdefault(key, []).append(doc)

        # Separate single-doc groups from multi-doc groups
        single_doc_groups: list[tuple[str | None, list]] = []
        multi_doc_groups: list[tuple[str | None, list]] = []

        for key, docs in groups.items():
            if len(docs) == 1:
                single_doc_groups.append((key, docs))
            else:
                multi_doc_groups.append((key, docs))

        # Initialize result mapping
        doc_splits: dict[str, str] = {}

        # Combine all groups for splitting
        all_groups = single_doc_groups + multi_doc_groups

        # Shuffle all groups and assign splits
        if all_groups:
            rng = random.Random(seed)
            rng.shuffle(all_groups)

            n_groups = len(all_groups)
            n_train = max(1, round(n_groups * train_ratio))
            # Ensure at least 1 in val if we have more than 1 group
            n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))

            for i, (_key, docs) in enumerate(all_groups):
                if i < n_train:
                    split = "train"
                elif i < n_train + n_val:
                    split = "val"
                else:
                    split = "test"

                for doc in docs:
                    doc_splits[str(doc.document_id)] = split

        logger.info(
            "Split assignment: %d total groups shuffled (train=%d, val=%d)",
            len(all_groups),
            sum(1 for s in doc_splits.values() if s == "train"),
            sum(1 for s in doc_splits.values() if s == "val"),
        )

        return doc_splits
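    # Toy illustration (hypothetical documents, not part of the commit): if doc1 and doc2
    # share group_key="invoice_A" while doc3 has no group_key, doc1 and doc2 are guaranteed
    # to land in the same split, so related pages never leak across train and val:
    #
    #     splits = builder._assign_splits_by_group(
    #         [doc1, doc2, doc3], train_ratio=0.8, val_ratio=0.2, seed=42
    #     )
    #     assert splits[str(doc1.document_id)] == splits[str(doc2.document_id)]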
    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate YOLO data.yaml configuration file."""
        data = {
@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable

import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]


@dataclass
class ServiceResult:
    """Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
    Service for running invoice field extraction.

    Encapsulates YOLO detection and OCR extraction logic.
    Supports dynamic model loading from database.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        storage_config: StorageConfig,
        model_path_resolver: ModelPathResolver | None = None,
    ) -> None:
        """
        Initialize inference service.

        Args:
            model_config: Model configuration (default model settings)
            storage_config: Storage configuration
            model_path_resolver: Optional function to resolve model path from database.
                If provided, will be called to get active model path.
                If returns None, falls back to model_config.model_path.
        """
        self.model_config = model_config
        self.storage_config = storage_config
        self._model_path_resolver = model_path_resolver
        self._pipeline = None
        self._detector = None
        self._is_initialized = False
        self._current_model_path: Path | None = None

    def _resolve_model_path(self) -> Path:
        """Resolve the model path to use for inference.

        Priority:
        1. Active model from database (via resolver)
        2. Default model from config
        """
        if self._model_path_resolver:
            try:
                db_model_path = self._model_path_resolver()
                if db_model_path and Path(db_model_path).exists():
                    logger.info(f"Using active model from database: {db_model_path}")
                    return Path(db_model_path)
                elif db_model_path:
                    logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
            except Exception as e:
                logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")

        return self.model_config.model_path
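    # Wiring sketch (assumption, not part of this diff): the resolver is any zero-argument
    # callable returning a Path or None; the path below is a placeholder, and in the app it
    # would come from a lookup of the active ModelVersion.
    #
    #     service = InferenceService(
    #         model_config=model_config,
    #         storage_config=storage_config,
    #         model_path_resolver=lambda: Path("/models/active/best.pt"),
    #     )
    #     service.initialize()    # loads whatever path the resolver yields (or the config default)
    #     service.reload_model()  # later: re-resolves and reloads only if the path changed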
    def initialize(self) -> None:
        """Initialize the inference pipeline (lazy loading)."""
@@ -74,16 +105,20 @@ class InferenceService:
            from inference.pipeline.pipeline import InferencePipeline
            from inference.pipeline.yolo_detector import YOLODetector

            # Resolve model path (from DB or config)
            model_path = self._resolve_model_path()
            self._current_model_path = model_path

            # Initialize YOLO detector for visualization
            self._detector = YOLODetector(
                str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                device="cuda" if self.model_config.use_gpu else "cpu",
            )

            # Initialize full pipeline
            self._pipeline = InferencePipeline(
                model_path=str(model_path),
                confidence_threshold=self.model_config.confidence_threshold,
                use_gpu=self.model_config.use_gpu,
                dpi=self.model_config.dpi,
@@ -92,12 +127,36 @@ class InferenceService:

            self._is_initialized = True
            elapsed = time.time() - start_time
            logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")

        except Exception as e:
            logger.error(f"Failed to initialize inference service: {e}")
            raise

    def reload_model(self) -> bool:
        """Reload the model if active model has changed.

        Returns:
            True if model was reloaded, False if no change needed.
        """
        new_model_path = self._resolve_model_path()

        if self._current_model_path == new_model_path:
            logger.debug("Model unchanged, no reload needed")
            return False

        logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
        self._is_initialized = False
        self._pipeline = None
        self._detector = None
        self.initialize()
        return True

    @property
    def current_model_path(self) -> Path | None:
        """Get the currently loaded model path."""
        return self._current_model_path

    @property
    def is_initialized(self) -> bool:
        """Check if service is initialized."""