WIP
@@ -25,6 +25,7 @@ from inference.data.admin_models import (
AnnotationHistory,
TrainingDataset,
DatasetDocument,
ModelVersion,
)

logger = logging.getLogger(__name__)

@@ -110,6 +111,7 @@ class AdminDB:
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
admin_token: str | None = None,  # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
)
session.add(document)
session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
document.updated_at = datetime.utcnow()
session.add(document)

def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
"""Update document group key."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.group_key = group_key
document.updated_at = datetime.utcnow()
session.add(document)
return True
return False

def delete_document(self, document_id: str) -> bool:
"""Delete a document and its annotations."""
with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
session.expunge(d)
return list(datasets), total

def get_active_training_tasks_for_datasets(
self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
"""Get active (pending/scheduled/running) training tasks for datasets.

Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
"""
if not dataset_ids:
return {}

# Validate UUIDs before query
valid_uuids = []
for d in dataset_ids:
try:
valid_uuids.append(UUID(d))
except ValueError:
logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
continue

if not valid_uuids:
return {}

with get_session_context() as session:
statement = select(TrainingTask).where(
TrainingTask.dataset_id.in_(valid_uuids),
TrainingTask.status.in_(["pending", "scheduled", "running"]),
)
results = session.exec(statement).all()
return {
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
for t in results
}
def update_dataset_status(
self,
dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
session.delete(dataset)
session.commit()
return True

# ==========================================================================
# Model Version Operations
# ==========================================================================

def create_model_version(
self,
version: str,
name: str,
model_path: str,
description: str | None = None,
task_id: str | UUID | None = None,
dataset_id: str | UUID | None = None,
metrics_mAP: float | None = None,
metrics_precision: float | None = None,
metrics_recall: float | None = None,
document_count: int = 0,
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
model = ModelVersion(
version=version,
name=name,
model_path=model_path,
description=description,
task_id=UUID(str(task_id)) if task_id else None,
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
)
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model

def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Get a model version by ID."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if model:
session.expunge(model)
return model

def get_model_versions(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[ModelVersion], int]:
"""List model versions with optional status filter."""
with get_session_context() as session:
query = select(ModelVersion)
count_query = select(func.count()).select_from(ModelVersion)
if status:
query = query.where(ModelVersion.status == status)
count_query = count_query.where(ModelVersion.status == status)
total = session.exec(count_query).one()
models = session.exec(
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
).all()
for m in models:
session.expunge(m)
return list(models), total

def get_active_model_version(self) -> ModelVersion | None:
"""Get the currently active model version for inference."""
with get_session_context() as session:
result = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).first()
if result:
session.expunge(result)
return result

def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Activate a model version for inference (deactivates all others)."""
with get_session_context() as session:
# Deactivate all versions
all_versions = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).all()
for v in all_versions:
v.is_active = False
v.status = "inactive"
v.updated_at = datetime.utcnow()
session.add(v)

# Activate the specified version
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = True
model.status = "active"
model.activated_at = datetime.utcnow()
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model

def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Deactivate a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = False
model.status = "inactive"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model

def update_model_version(
self,
version_id: str | UUID,
name: str | None = None,
description: str | None = None,
status: str | None = None,
) -> ModelVersion | None:
"""Update model version metadata."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if name is not None:
model.name = name
if description is not None:
model.description = description
if status is not None:
model.status = status
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model

def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Archive a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
# Cannot archive active model
if model.is_active:
return None
model.status = "archived"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model

def delete_model_version(self, version_id: str | UUID) -> bool:
"""Delete a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return False
# Cannot delete active model
if model.is_active:
return False
session.delete(model)
session.commit()
return True
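The AdminDB helpers above each open and close their own session, so callers can chain them without transaction bookkeeping. A minimal usage sketch follows; the AdminDB import path, the model path, and a reachable database are assumptions, not shown in this diff.

# Sketch only: module path and DB connectivity are assumed.
from inference.data.admin_db import AdminDB

db = AdminDB()
model = db.create_model_version(
    version="1.0.0",
    name="baseline",
    model_path="/models/invoice_fields/best.pt",  # placeholder path
    metrics_mAP=0.91,
)
db.activate_model_version(model.version_id)  # flips every other version to inactive
assert db.get_active_model_version().version == "1.0.0"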
@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Model Version Management
# =============================================================================


class ModelVersion(SQLModel, table=True):
"""Model version for inference deployment."""

__tablename__ = "model_versions"

version_id: UUID = Field(default_factory=uuid4, primary_key=True)
version: str = Field(max_length=50, index=True)
# Semantic version e.g., "1.0.0", "2.1.0"
name: str = Field(max_length=255)
description: str | None = Field(default=None)
model_path: str = Field(max_length=512)
# Path to the model weights file
status: str = Field(default="inactive", max_length=20, index=True)
# Status: active, inactive, archived
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference

# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)

# Training metrics
metrics_mAP: float | None = Field(default=None)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
document_count: int = Field(default=0)
# Number of documents used in training

# Training configuration snapshot
training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of epochs, batch_size, etc.

# File info
file_size: int | None = Field(default=None)
# Model file size in bytes

# Timestamps
trained_at: datetime | None = Field(default=None)
# When training completed
activated_at: datetime | None = Field(default=None)
# When this version was last activated
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Annotation History (v2)
# =============================================================================

@@ -49,6 +49,111 @@ def get_engine():
return _engine


def run_migrations() -> None:
"""Run database migrations for new columns."""
engine = get_engine()

migrations = [
# Migration 004: Training datasets tables and dataset_id on training_tasks
(
"training_datasets_tables",
"""
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
""",
),
(
"dataset_documents_table",
"""
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
""",
),
(
"training_tasks_dataset_id",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
""",
),
# Migration 005: Add group_key to admin_documents
(
"admin_documents_group_key",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
""",
),
# Migration 006: Model versions table
(
"model_versions_table",
"""
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
task_id UUID REFERENCES training_tasks(task_id),
dataset_id UUID REFERENCES training_datasets(dataset_id),
metrics_mAP FLOAT,
metrics_precision FLOAT,
metrics_recall FLOAT,
document_count INTEGER NOT NULL DEFAULT 0,
training_config JSONB,
file_size BIGINT,
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
]

with engine.connect() as conn:
for name, sql in migrations:
try:
conn.execute(text(sql))
conn.commit()
logger.info(f"Migration '{name}' applied successfully")
except Exception as e:
# Log but don't fail - column may already exist
logger.debug(f"Migration '{name}' skipped or failed: {e}")


def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")

# Run migrations for new columns
run_migrations()


def get_session() -> Session:
"""Get a new database session."""

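run_migrations() is deliberately append-only and idempotent: every statement is guarded with IF NOT EXISTS / ADD COLUMN IF NOT EXISTS, so reruns against an already-migrated database are no-ops. A later migration would be added as one more (name, sql) tuple; the column below is purely illustrative and not part of this commit.

# Illustrative follow-up migration entry (not in this commit); it would be
# appended to the `migrations` list inside run_migrations().
NEW_MIGRATION = (
    "model_versions_notes_column",
    "ALTER TABLE model_versions ADD COLUMN IF NOT EXISTS notes TEXT;",
)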
@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""

from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router

__all__ = [
"create_annotation_router",
"create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",

@@ -0,0 +1,15 @@
"""Augmentation API module."""

from fastapi import APIRouter

from .routes import register_augmentation_routes


def create_augmentation_router() -> APIRouter:
"""Create and configure the augmentation router."""
router = APIRouter(prefix="/augmentation", tags=["augmentation"])
register_augmentation_routes(router)
return router


__all__ = ["create_augmentation_router"]
@@ -0,0 +1,162 @@
"""Augmentation API routes."""

from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewRequest,
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
)


def register_augmentation_routes(router: APIRouter) -> None:
"""Register augmentation endpoints on the router."""

@router.get(
"/types",
response_model=AugmentationTypesResponse,
summary="List available augmentation types",
)
async def list_augmentation_types(
admin_token: AdminTokenDep,
) -> AugmentationTypesResponse:
"""
List all available augmentation types with descriptions and parameters.
"""
from shared.augmentation.pipeline import (
AUGMENTATION_REGISTRY,
AugmentationPipeline,
)

types = []
for name, aug_class in AUGMENTATION_REGISTRY.items():
# Create instance with empty params to get preview params
aug = aug_class({})
types.append(
AugmentationTypeInfo(
name=name,
description=(aug_class.__doc__ or "").strip(),
affects_geometry=aug_class.affects_geometry,
stage=AugmentationPipeline.STAGE_MAPPING[name],
default_params=aug.get_preview_params(),
)
)

return AugmentationTypesResponse(augmentation_types=types)

@router.get(
"/presets",
response_model=PresetsResponse,
summary="Get augmentation presets",
)
async def get_presets(
admin_token: AdminTokenDep,
) -> PresetsResponse:
"""Get predefined augmentation presets for common use cases."""
from shared.augmentation.presets import list_presets

presets = [PresetInfo(**p) for p in list_presets()]
return PresetsResponse(presets=presets)

@router.post(
"/preview/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview augmentation on document image",
)
async def preview_augmentation(
document_id: str,
request: AugmentationPreviewRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.

Returns URLs to original and augmented preview images.
"""
from inference.web.services.augmentation_service import AugmentationService

service = AugmentationService(db=db)
return await service.preview_single(
document_id=document_id,
page=page,
augmentation_type=request.augmentation_type,
params=request.params,
)

@router.post(
"/preview-config/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview full augmentation config on document",
)
async def preview_config(
document_id: str,
config: AugmentationConfigSchema,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""Preview complete augmentation pipeline on a document page."""
from inference.web.services.augmentation_service import AugmentationService

service = AugmentationService(db=db)
return await service.preview_config(
document_id=document_id,
page=page,
config=config,
)

@router.post(
"/batch",
response_model=AugmentationBatchResponse,
summary="Create augmented dataset (offline preprocessing)",
)
async def create_augmented_dataset(
request: AugmentationBatchRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.

This runs as a background task. The augmented images are stored
alongside the original dataset for training.
"""
from inference.web.services.augmentation_service import AugmentationService

service = AugmentationService(db=db)
return await service.create_augmented_dataset(
source_dataset_id=request.dataset_id,
config=request.config,
output_name=request.output_name,
multiplier=request.multiplier,
)

@router.get(
"/datasets",
response_model=AugmentedDatasetListResponse,
summary="List augmented datasets",
)
async def list_augmented_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
offset: int = Query(default=0, ge=0, description="Offset"),
) -> AugmentedDatasetListResponse:
"""List all augmented datasets."""
from inference.web.services.augmentation_service import AugmentationService

service = AugmentationService(db=db)
return await service.list_augmented_datasets(limit=limit, offset=offset)
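Once the router is mounted (the app wires it under /api/v1/admin later in this commit), the preview route can be exercised with any HTTP client. The host, the admin-token header name, and the empty gaussian_blur params below are illustrative assumptions only.

# Hypothetical client call for the preview endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/admin/augmentation/preview/<document_id>",
    params={"page": 1},
    headers={"X-Admin-Token": "<token>"},  # header name is an assumption
    json={"augmentation_type": "gaussian_blur", "params": {}},
)
print(resp.status_code, resp.json())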
@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
group_key: Annotated[
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)

# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
content_type=file.content_type or "application/octet-stream",
file_path="",  # Will update after saving
page_count=page_count,
group_key=group_key,
)

# Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:

return response

@router.patch(
"/{document_id}/group-key",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document group key",
description="Update the group key for a document.",
)
async def update_document_group_key(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
group_key: Annotated[
str | None,
Query(description="New group key (null to clear)"),
] = None,
) -> dict:
"""Update document group key."""
_validate_uuid(document_id, "document_id")

# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)

# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)

# Update group key
db.update_document_group_key(document_id, group_key)

return {
"status": "updated",
"document_id": document_id,
"group_key": group_key,
"message": "Document group key updated",
}

return router

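The new group-key route takes the value as a query parameter, so clearing it is just a matter of omitting the parameter. A sketch of the call; the mount path and token header are assumptions, as above.

# Hypothetical client call for PATCH /{document_id}/group-key.
import requests

resp = requests.patch(
    "http://localhost:8000/api/v1/admin/documents/<document_id>/group-key",
    params={"group_key": "2024-Q3-invoices"},  # omit to clear the key
    headers={"X-Admin-Token": "<token>"},
)
print(resp.json())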
@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes


def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)

return router

@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder

# Validate minimum document count for proper train/val/test split
if len(request.document_ids) < 10:
raise HTTPException(
status_code=400,
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)

dataset = db.create_dataset(
name=request.name,
description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)

# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets]
try:
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}

return DatasetListResponse(
total=total,
limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
name=d.name,
description=d.description,
status=d.status,
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
"""Create a training task from a dataset.

For incremental training, set config.base_model_version_id to a model version UUID.
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
)

config_dict = request.config.model_dump()

# Resolve base_model_version_id to actual model path for incremental training
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = db.get_model_version(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
base_model.version,
base_model.model_path,
)

task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)

message = (
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
if base_model_version_id
else "Training task created from dataset"
)

return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
message=message,
)

@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
)

@router.get(
"/models",
"/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
summary="Get completed training tasks",
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
async def get_training_models(
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[

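Incremental training is driven purely by the request config: when config.base_model_version_id is set, the train-from-dataset route resolves it to base_model_path / base_model_version and creates the task as "finetune" instead of "train". A request payload along these lines would exercise that path; field names other than base_model_version_id follow the existing training config schema and are assumptions here.

# Illustrative payload for POST /datasets/{dataset_id}/train.
payload = {
    "name": "invoices-finetune-round2",
    "config": {
        "base_model_version_id": "<model-version-uuid>",  # enables the finetune path
        "epochs": 50,
        "batch_size": 16,
    },
}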
packages/inference/inference/web/api/v1/admin/training/models.py (new file, 333 lines)
@@ -0,0 +1,333 @@
"""Model Version Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from inference.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from inference.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
ModelVersionItem,
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
)
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_model_routes(router: APIRouter) -> None:
|
||||
"""Register model version endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Create model version",
|
||||
description="Register a new model version for deployment.",
|
||||
)
|
||||
async def create_model_version(
|
||||
request: ModelVersionCreateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Create a new model version."""
|
||||
if request.task_id:
|
||||
_validate_uuid(request.task_id, "task_id")
|
||||
if request.dataset_id:
|
||||
_validate_uuid(request.dataset_id, "dataset_id")
|
||||
|
||||
model = db.create_model_version(
|
||||
version=request.version,
|
||||
name=request.name,
|
||||
model_path=request.model_path,
|
||||
description=request.description,
|
||||
task_id=request.task_id,
|
||||
dataset_id=request.dataset_id,
|
||||
metrics_mAP=request.metrics_mAP,
|
||||
metrics_precision=request.metrics_precision,
|
||||
metrics_recall=request.metrics_recall,
|
||||
document_count=request.document_count,
|
||||
training_config=request.training_config,
|
||||
file_size=request.file_size,
|
||||
trained_at=request.trained_at,
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=ModelVersionListResponse,
|
||||
summary="List model versions",
|
||||
)
|
||||
async def list_model_versions(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> ModelVersionListResponse:
|
||||
"""List model versions with optional status filter."""
|
||||
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
|
||||
return ModelVersionListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=[
|
||||
ModelVersionItem(
|
||||
version_id=str(m.version_id),
|
||||
version=m.version,
|
||||
name=m.name,
|
||||
status=m.status,
|
||||
is_active=m.is_active,
|
||||
metrics_mAP=m.metrics_mAP,
|
||||
document_count=m.document_count,
|
||||
trained_at=m.trained_at,
|
||||
activated_at=m.activated_at,
|
||||
created_at=m.created_at,
|
||||
)
|
||||
for m in models
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/active",
|
||||
response_model=ActiveModelResponse,
|
||||
summary="Get active model",
|
||||
description="Get the currently active model for inference.",
|
||||
)
|
||||
async def get_active_model(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ActiveModelResponse:
|
||||
"""Get the currently active model version."""
|
||||
model = db.get_active_model_version()
|
||||
if not model:
|
||||
return ActiveModelResponse(has_active_model=False, model=None)
|
||||
|
||||
return ActiveModelResponse(
|
||||
has_active_model=True,
|
||||
model=ModelVersionItem(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
document_count=model.document_count,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionDetailResponse,
|
||||
summary="Get model version detail",
|
||||
)
|
||||
async def get_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionDetailResponse:
|
||||
"""Get detailed model version information."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.get_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionDetailResponse(
|
||||
version_id=str(model.version_id),
|
||||
version=model.version,
|
||||
name=model.name,
|
||||
description=model.description,
|
||||
model_path=model.model_path,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
task_id=str(model.task_id) if model.task_id else None,
|
||||
dataset_id=str(model.dataset_id) if model.dataset_id else None,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
metrics_precision=model.metrics_precision,
|
||||
metrics_recall=model.metrics_recall,
|
||||
document_count=model.document_count,
|
||||
training_config=model.training_config,
|
||||
file_size=model.file_size,
|
||||
trained_at=model.trained_at,
|
||||
activated_at=model.activated_at,
|
||||
created_at=model.created_at,
|
||||
updated_at=model.updated_at,
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/models/{version_id}",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Update model version",
|
||||
)
|
||||
async def update_model_version(
|
||||
version_id: str,
|
||||
request: ModelVersionUpdateRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Update model version metadata."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.update_model_version(
|
||||
version_id=version_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
status=request.status,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version updated successfully",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/activate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Activate model version",
|
||||
description="Activate a model version for inference (deactivates all others).",
|
||||
)
|
||||
async def activate_model_version(
|
||||
version_id: str,
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Activate a model version for inference."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.activate_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
# Trigger model reload in inference service
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
model_reloaded = False
|
||||
if inference_service:
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info(f"Inference model reloaded to version {model.version}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reload inference model: {e}")
|
||||
|
||||
message = "Model version activated for inference"
|
||||
if model_reloaded:
|
||||
message += " (model reloaded)"
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message=message,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/deactivate",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Deactivate model version",
|
||||
)
|
||||
async def deactivate_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Deactivate a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.deactivate_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version deactivated",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/models/{version_id}/archive",
|
||||
response_model=ModelVersionResponse,
|
||||
summary="Archive model version",
|
||||
)
|
||||
async def archive_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ModelVersionResponse:
|
||||
"""Archive a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
model = db.archive_model_version(version_id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot archive active model",
|
||||
)
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
status=model.status,
|
||||
message="Model version archived",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/models/{version_id}",
|
||||
summary="Delete model version",
|
||||
)
|
||||
async def delete_model_version(
|
||||
version_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete a model version."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
success = db.delete_model_version(version_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model version not found or cannot delete active model",
|
||||
)
|
||||
|
||||
return {"message": "Model version deleted"}
|
||||
|
||||
@router.post(
|
||||
"/models/reload",
|
||||
summary="Reload inference model",
|
||||
description="Reload the inference model from the currently active model version.",
|
||||
)
|
||||
async def reload_inference_model(
|
||||
request: Request,
|
||||
admin_token: AdminTokenDep,
|
||||
) -> dict:
|
||||
"""Reload the inference model from active version."""
|
||||
inference_service = getattr(request.app.state, "inference_service", None)
|
||||
if not inference_service:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Inference service not available",
|
||||
)
|
||||
|
||||
try:
|
||||
model_reloaded = inference_service.reload_model()
|
||||
if model_reloaded:
|
||||
logger.info("Inference model manually reloaded")
|
||||
return {"message": "Model reloaded successfully", "reloaded": True}
|
||||
else:
|
||||
return {"message": "Model already up to date", "reloaded": False}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to reload model: {e}",
|
||||
)
|
||||
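Activation and reload are intentionally decoupled: activating flips the database flag and makes a best-effort reload, while /models/reload lets an operator resync the in-process model later if it ever drifts from the active version. The mount prefix and token header in this sketch are assumptions.

# Hypothetical operator calls.
import requests

headers = {"X-Admin-Token": "<token>"}
base = "http://localhost:8000/api/v1"  # training router prefix is assumed

requests.post(f"{base}/models/<version_id>/activate", headers=headers)
requests.post(f"{base}/models/reload", headers=headers)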
@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""
config = config or default_config

# Create inference service
# Create model path resolver that reads from database
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
logger.warning(f"Failed to get active model from database: {e}")
return None

# Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
model_path_resolver=get_active_model_path,
)

# Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Error closing database: {e}")

# Create FastAPI app
# Store inference service for access by routes (e.g., model reload)
# This will be set after app creation

app = FastAPI(
title="Invoice Field Extraction API",
description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")

augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")

# Include batch upload routes
app.include_router(batch_upload_router)

# Store inference service in app state for access by routes
app.state.inference_service = inference_service

# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:

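model_path_resolver is handed to InferenceService as a callable so the active path can be re-read from the database on every reload instead of being frozen at startup. The InferenceService side is not part of this diff; a minimal sketch of what reload_model() could look like under that assumption:

# Sketch only: InferenceService internals are not shown in this commit.
from typing import Callable


class ResolverBackedModel:
    def __init__(self, model_path_resolver: Callable[[], str | None]) -> None:
        self._resolver = model_path_resolver
        self._current_path: str | None = None

    def reload_model(self) -> bool:
        """Swap weights when the DB-active path changes; True if a reload happened."""
        path = self._resolver()
        if not path or path == self._current_path:
            return False
        # actual weight loading (e.g., YOLO(path)) would happen here
        self._current_path = path
        return True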
@@ -110,6 +110,7 @@ class TrainingScheduler:
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
base_model_path = config.get("base_model_path")  # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@ class TrainingScheduler:
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")

# Get augmentation config if present
augmentation_config = config.get("augmentation")
augmentation_multiplier = config.get("augmentation_multiplier", 2)

# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._db.add_training_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
else:
# Train from pretrained model
effective_model = model_name

# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@ class TrainingScheduler:
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)

# Apply augmentation if config is provided
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
aug_result = self._apply_augmentation(
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._db.add_training_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
)

# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
model_name=effective_model,  # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -159,11 +192,94 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")

# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)

except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise

def _create_model_version_from_training(
self,
task_id: str,
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return

# Get task info for name
task = self._db.get_training_task(task_id)
task_name = task.name if task else f"Task {task_id[:8]}"

# Generate version number based on existing versions
existing_versions = self._db.get_model_versions(limit=1, offset=0)
version_count = existing_versions[1] if existing_versions else 0
version = f"v{version_count + 1}.0"

# Extract metrics from result
metrics = result.get("metrics", {})
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
metrics_precision = metrics.get("precision")
metrics_recall = metrics.get("recall")

# Get file size if possible
file_size = None
model_file = Path(model_path)
if model_file.exists():
file_size = model_file.stat().st_size

# Get document count from dataset if available
document_count = 0
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if dataset:
document_count = dataset.total_documents

# Create model version
model_version = self._db.create_model_version(
version=version,
name=task_name,
model_path=str(model_path),
description=f"Auto-created from training task {task_id[:8]}",
task_id=task_id,
dataset_id=dataset_id,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
)

logger.info(
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
self._db.add_training_log(
task_id, "INFO",
f"Model version {version} created "
f"(mAP: {f'{metrics_mAP:.3f}' if metrics_mAP is not None else 'N/A'})",
)

except Exception as e:
|
||||
logger.error(f"Failed to create model version for task {task_id}: {e}")
|
||||
self._db.add_training_log(
|
||||
task_id, "WARNING",
|
||||
f"Failed to auto-create model version: {e}",
|
||||
)
|
||||
|
||||
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
from pathlib import Path
|
||||
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
|
||||
device: str,
|
||||
project_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run YOLO training."""
|
||||
"""Run YOLO training using shared trainer."""
|
||||
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
|
||||
|
||||
# Create log callback that writes to DB
|
||||
def log_callback(level: str, message: str) -> None:
|
||||
self._db.add_training_log(task_id, level, message)
|
||||
|
||||
# Create shared training config
|
||||
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
|
||||
config = SharedTrainingConfig(
|
||||
model_path=model_name,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
image_size=image_size,
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project="runs/train",
|
||||
name=f"{project_name}/task_{task_id[:8]}",
|
||||
workers=0,
|
||||
)
|
||||
|
||||
# Run training using shared trainer
|
||||
trainer = YOLOTrainer(config=config, log_callback=log_callback)
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
raise ValueError(result.error or "Training failed")
|
||||
|
||||
return {
|
||||
"model_path": result.model_path,
|
||||
"metrics": result.metrics,
|
||||
}
|
||||
|
||||
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
|
||||
"""Check if any augmentations are enabled in the config."""
|
||||
augmentation_fields = [
|
||||
"perspective_warp", "wrinkle", "edge_damage", "stain",
|
||||
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
|
||||
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
|
||||
]
|
||||
for field in augmentation_fields:
|
||||
if field in aug_config:
|
||||
field_config = aug_config[field]
|
||||
if isinstance(field_config, dict) and field_config.get("enabled", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _apply_augmentation(
self,
task_id: str,
dataset_path: Path,
aug_config: dict[str, Any],
multiplier: int,
) -> dict[str, int] | None:
"""Apply augmentation to dataset before training."""
try:
from shared.augmentation import DatasetAugmenter

self._db.add_training_log(
task_id, "INFO",
f"Applying augmentation with multiplier={multiplier}",
)

augmenter = DatasetAugmenter(aug_config)
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)

return result

except Exception as e:
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
return None

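For orientation (this example is not part of the commit), a minimal aug_config of the shape these two helpers expect; the keys mirror the augmentation field names listed in _has_enabled_augmentations, and the values are made up:

    # Hypothetical config passed to _apply_augmentation / DatasetAugmenter.
    aug_config = {
        "gaussian_blur": {"enabled": True, "probability": 0.2, "params": {"kernel_size": (3, 5)}},
        "stain": {"enabled": False, "probability": 0.2, "params": {}},
    }
    # _has_enabled_augmentations(aug_config) returns True here because
    # gaussian_blur is enabled; _apply_augmentation would then hand the same
    # dict to DatasetAugmenter before training starts.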
# Global scheduler instance
@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403

# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse

packages/inference/inference/web/schemas/admin/augmentation.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AugmentationParamsSchema(BaseModel):
|
||||
"""Single augmentation parameters."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
|
||||
probability: float = Field(
|
||||
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
|
||||
)
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Type-specific parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationConfigSchema(BaseModel):
|
||||
"""Complete augmentation configuration."""
|
||||
|
||||
# Geometric transforms
|
||||
perspective_warp: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Degradation effects
|
||||
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
edge_damage: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
|
||||
# Lighting effects
|
||||
lighting_variation: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
|
||||
|
||||
# Blur effects
|
||||
gaussian_blur: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
motion_blur: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Noise effects
|
||||
gaussian_noise: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
salt_pepper: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Texture effects
|
||||
paper_texture: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
scanner_artifacts: AugmentationParamsSchema = Field(
|
||||
default_factory=AugmentationParamsSchema
|
||||
)
|
||||
|
||||
# Global settings
|
||||
preserve_bboxes: bool = Field(
|
||||
default=True, description="Whether to adjust bboxes for geometric transforms"
|
||||
)
|
||||
seed: int | None = Field(default=None, description="Random seed for reproducibility")
|
||||
|
||||
|
||||
class AugmentationTypeInfo(BaseModel):
|
||||
"""Information about an augmentation type."""
|
||||
|
||||
name: str = Field(..., description="Augmentation name")
|
||||
description: str = Field(..., description="Augmentation description")
|
||||
affects_geometry: bool = Field(
|
||||
..., description="Whether this augmentation affects bbox coordinates"
|
||||
)
|
||||
stage: str = Field(..., description="Processing stage")
|
||||
default_params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Default parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationTypesResponse(BaseModel):
|
||||
"""Response for listing augmentation types."""
|
||||
|
||||
augmentation_types: list[AugmentationTypeInfo] = Field(
|
||||
..., description="Available augmentation types"
|
||||
)
|
||||
|
||||
|
||||
class PresetInfo(BaseModel):
|
||||
"""Information about a preset."""
|
||||
|
||||
name: str = Field(..., description="Preset name")
|
||||
description: str = Field(..., description="Preset description")
|
||||
|
||||
|
||||
class PresetsResponse(BaseModel):
|
||||
"""Response for listing presets."""
|
||||
|
||||
presets: list[PresetInfo] = Field(..., description="Available presets")
|
||||
|
||||
|
||||
class AugmentationPreviewRequest(BaseModel):
|
||||
"""Request to preview augmentation on an image."""
|
||||
|
||||
augmentation_type: str = Field(..., description="Type of augmentation to preview")
|
||||
params: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Override parameters"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationPreviewResponse(BaseModel):
|
||||
"""Response with preview image data."""
|
||||
|
||||
preview_url: str = Field(..., description="URL to preview image")
|
||||
original_url: str = Field(..., description="URL to original image")
|
||||
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
|
||||
|
||||
|
||||
class AugmentationBatchRequest(BaseModel):
|
||||
"""Request to augment a dataset offline."""
|
||||
|
||||
dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
|
||||
output_name: str = Field(
|
||||
..., min_length=1, max_length=255, description="Output dataset name"
|
||||
)
|
||||
multiplier: int = Field(
|
||||
default=2, ge=1, le=10, description="Augmented copies per image"
|
||||
)
|
||||
|
||||
|
||||
class AugmentationBatchResponse(BaseModel):
|
||||
"""Response for batch augmentation."""
|
||||
|
||||
task_id: str = Field(..., description="Background task UUID")
|
||||
status: str = Field(..., description="Task status")
|
||||
message: str = Field(..., description="Status message")
|
||||
estimated_images: int = Field(..., description="Estimated total images")
|
||||
|
||||
|
||||
class AugmentedDatasetItem(BaseModel):
|
||||
"""Single augmented dataset in list."""
|
||||
|
||||
dataset_id: str = Field(..., description="Dataset UUID")
|
||||
source_dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
name: str = Field(..., description="Dataset name")
|
||||
status: str = Field(..., description="Dataset status")
|
||||
multiplier: int = Field(..., description="Augmentation multiplier")
|
||||
total_original_images: int = Field(..., description="Original image count")
|
||||
total_augmented_images: int = Field(..., description="Augmented image count")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class AugmentedDatasetListResponse(BaseModel):
|
||||
"""Response for listing augmented datasets."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total datasets")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
datasets: list[AugmentedDatasetItem] = Field(
|
||||
default_factory=list, description="Dataset list"
|
||||
)
|
||||
|
||||
|
||||
class AugmentedDatasetDetailResponse(BaseModel):
|
||||
"""Detailed augmented dataset response."""
|
||||
|
||||
dataset_id: str = Field(..., description="Dataset UUID")
|
||||
source_dataset_id: str = Field(..., description="Source dataset UUID")
|
||||
name: str = Field(..., description="Dataset name")
|
||||
status: str = Field(..., description="Dataset status")
|
||||
config: AugmentationConfigSchema | None = Field(
|
||||
None, description="Augmentation config used"
|
||||
)
|
||||
multiplier: int = Field(..., description="Augmentation multiplier")
|
||||
total_original_images: int = Field(..., description="Original image count")
|
||||
total_augmented_images: int = Field(..., description="Augmented image count")
|
||||
dataset_path: str | None = Field(None, description="Dataset path on disk")
|
||||
error_message: str | None = Field(None, description="Error message if failed")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
completed_at: datetime | None = Field(None, description="Completion timestamp")
|
||||
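A minimal sketch (not part of the commit) of building the schema above with a single augmentation enabled; the parameter values are illustrative only:

    # Enable one augmentation via the Pydantic schema and serialize it.
    config = AugmentationConfigSchema(
        lighting_variation=AugmentationParamsSchema(
            enabled=True,
            probability=0.5,
            params={"brightness_range": (-0.1, 0.1)},
        ),
        seed=42,
    )
    payload = config.model_dump()  # ready to use as the `config` field of a request body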
@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
training_status: str | None = None
|
||||
active_training_task_id: str | None = None
|
||||
total_documents: int
|
||||
total_images: int
|
||||
total_annotations: int
|
||||
|
||||
@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
|
||||
file_size: int = Field(..., ge=0, description="File size in bytes")
|
||||
page_count: int = Field(..., ge=1, description="Number of pages")
|
||||
status: DocumentStatus = Field(..., description="Document status")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
auto_label_started: bool = Field(
|
||||
default=False, description="Whether auto-labeling was started"
|
||||
)
|
||||
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
|
||||
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
|
||||
auto_label_error: str | None = Field(None, description="Auto-labeling error")
|
||||
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
|
||||
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
|
||||
group_key: str | None = Field(None, description="User-defined group key")
|
||||
csv_field_values: dict[str, str] | None = Field(
|
||||
None, description="CSV field values if uploaded via batch"
|
||||
)
|
||||
|
||||
packages/inference/inference/web/schemas/admin/models.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelVersionCreateRequest(BaseModel):
|
||||
"""Request to create a model version."""
|
||||
|
||||
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Model name")
|
||||
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
|
||||
description: str | None = Field(None, description="Optional description")
|
||||
task_id: str | None = Field(None, description="Training task UUID")
|
||||
dataset_id: str | None = Field(None, description="Dataset UUID")
|
||||
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
|
||||
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
|
||||
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
|
||||
document_count: int = Field(0, ge=0, description="Documents used in training")
|
||||
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
|
||||
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
|
||||
|
||||
class ModelVersionUpdateRequest(BaseModel):
|
||||
"""Request to update a model version."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
|
||||
description: str | None = Field(None, description="Description")
|
||||
status: str | None = Field(None, description="Status (inactive, archived)")
|
||||
|
||||
|
||||
class ModelVersionItem(BaseModel):
|
||||
"""Model version in list view."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
name: str = Field(..., description="Model name")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
document_count: int = Field(..., description="Documents used in training")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
activated_at: datetime | None = Field(None, description="Last activation time")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class ModelVersionListResponse(BaseModel):
|
||||
"""Paginated model version list."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total model versions")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
|
||||
|
||||
|
||||
class ModelVersionDetailResponse(BaseModel):
|
||||
"""Detailed model version info."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
version: str = Field(..., description="Semantic version")
|
||||
name: str = Field(..., description="Model name")
|
||||
description: str | None = Field(None, description="Description")
|
||||
model_path: str = Field(..., description="Path to model file")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
task_id: str | None = Field(None, description="Training task UUID")
|
||||
dataset_id: str | None = Field(None, description="Dataset UUID")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
metrics_precision: float | None = Field(None, description="Precision")
|
||||
metrics_recall: float | None = Field(None, description="Recall")
|
||||
document_count: int = Field(..., description="Documents used in training")
|
||||
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
|
||||
file_size: int | None = Field(None, description="Model file size in bytes")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
activated_at: datetime | None = Field(None, description="Last activation time")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class ModelVersionResponse(BaseModel):
|
||||
"""Response for model version operation."""
|
||||
|
||||
version_id: str = Field(..., description="Version UUID")
|
||||
status: str = Field(..., description="Model status")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class ActiveModelResponse(BaseModel):
|
||||
"""Response for active model query."""
|
||||
|
||||
has_active_model: bool = Field(..., description="Whether an active model exists")
|
||||
model: ModelVersionItem | None = Field(None, description="Active model if exists")
|
||||
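For context, a hypothetical request body for registering a trained model with these schemas; every value below is a placeholder, not taken from the diff:

    req = ModelVersionCreateRequest(
        version="1.2.0",
        name="invoice-fields-yolo11n",
        model_path="runs/train/invoice_fields/task_ab12cd34/weights/best.pt",
        task_id="6f1b2c3d-4e5f-6071-8293-a4b5c6d7e8f9",  # placeholder training task UUID
        metrics_mAP=0.912,
        metrics_precision=0.95,
        metrics_recall=0.89,
        document_count=1200,
    )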
@@ -5,13 +5,18 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .augmentation import AugmentationConfigSchema
|
||||
from .enums import TrainingStatus, TrainingType
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Training configuration."""
|
||||
|
||||
model_name: str = Field(default="yolo11n.pt", description="Base model name")
|
||||
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
|
||||
base_model_version_id: str | None = Field(
|
||||
default=None,
|
||||
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
|
||||
)
|
||||
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
|
||||
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
|
||||
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
|
||||
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
|
||||
default="invoice_fields", description="Training project name"
|
||||
)
|
||||
|
||||
# Data augmentation settings
|
||||
augmentation: AugmentationConfigSchema | None = Field(
|
||||
default=None,
|
||||
description="Augmentation configuration. If provided, augments dataset before training.",
|
||||
)
|
||||
augmentation_multiplier: int = Field(
|
||||
default=2,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Number of augmented copies per original image",
|
||||
)
|
||||
|
||||
|
||||
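A hedged example (not from the commit) of how the new incremental-training and augmentation fields might be combined in a TrainingConfig; the UUID is a placeholder:

    config = TrainingConfig(
        base_model_version_id="0b0d6a52-1c1e-4f7a-9c2d-3e4f5a6b7c8d",  # reuse a prior model as base
        epochs=50,
        batch_size=16,
        augmentation=AugmentationConfigSchema(
            gaussian_noise=AugmentationParamsSchema(enabled=True, probability=0.3),
        ),
        augmentation_multiplier=2,
    )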
class TrainingTaskCreate(BaseModel):
|
||||
"""Request to create a training task."""
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Augmentation service for handling augmentation operations."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.schemas.admin.augmentation import (
|
||||
AugmentationBatchResponse,
|
||||
AugmentationConfigSchema,
|
||||
AugmentationPreviewResponse,
|
||||
AugmentedDatasetItem,
|
||||
AugmentedDatasetListResponse,
|
||||
)
|
||||
|
||||
# Constants
|
||||
PREVIEW_MAX_SIZE = 800
|
||||
PREVIEW_SEED = 42
|
||||
UUID_PATTERN = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class AugmentationService:
|
||||
"""Service for augmentation operations."""
|
||||
|
||||
def __init__(self, db: AdminDB) -> None:
|
||||
"""Initialize service with database connection."""
|
||||
self.db = db
|
||||
|
||||
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
|
||||
"""
|
||||
Validate UUID format to prevent path traversal.
|
||||
|
||||
Args:
|
||||
value: Value to validate.
|
||||
field_name: Field name for error message.
|
||||
|
||||
Raises:
|
||||
HTTPException: If value is not a valid UUID.
|
||||
"""
|
||||
if not UUID_PATTERN.match(value):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {field_name} format: {value}",
|
||||
)
|
||||
|
||||
async def preview_single(
|
||||
self,
|
||||
document_id: str,
|
||||
page: int,
|
||||
augmentation_type: str,
|
||||
params: dict[str, Any],
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""
|
||||
Preview a single augmentation on a document page.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID.
|
||||
page: Page number (1-indexed).
|
||||
augmentation_type: Name of augmentation to apply.
|
||||
params: Override parameters.
|
||||
|
||||
Returns:
|
||||
Preview response with image URLs.
|
||||
|
||||
Raises:
|
||||
HTTPException: If document not found or augmentation invalid.
|
||||
"""
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
|
||||
|
||||
# Validate augmentation type
|
||||
if augmentation_type not in AUGMENTATION_REGISTRY:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unknown augmentation type: {augmentation_type}. "
|
||||
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
|
||||
)
|
||||
|
||||
# Get document and load image
|
||||
image = await self._load_document_page(document_id, page)
|
||||
|
||||
# Create config with only this augmentation enabled
|
||||
config_kwargs = {
|
||||
augmentation_type: AugmentationParams(
|
||||
enabled=True,
|
||||
probability=1.0, # Always apply for preview
|
||||
params=params,
|
||||
),
|
||||
"seed": PREVIEW_SEED, # Deterministic preview
|
||||
}
|
||||
config = AugmentationConfig(**config_kwargs)
|
||||
pipeline = AugmentationPipeline(config)
|
||||
|
||||
# Apply augmentation
|
||||
result = pipeline.apply(image)
|
||||
|
||||
# Convert to base64 URLs
|
||||
original_url = self._image_to_data_url(image)
|
||||
preview_url = self._image_to_data_url(result.image)
|
||||
|
||||
return AugmentationPreviewResponse(
|
||||
preview_url=preview_url,
|
||||
original_url=original_url,
|
||||
applied_params=params,
|
||||
)
|
||||
|
||||
async def preview_config(
|
||||
self,
|
||||
document_id: str,
|
||||
page: int,
|
||||
config: AugmentationConfigSchema,
|
||||
) -> AugmentationPreviewResponse:
|
||||
"""
|
||||
Preview full augmentation config on a document page.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID.
|
||||
page: Page number (1-indexed).
|
||||
config: Full augmentation configuration.
|
||||
|
||||
Returns:
|
||||
Preview response with image URLs.
|
||||
"""
|
||||
from shared.augmentation.config import AugmentationConfig
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
# Load image
|
||||
image = await self._load_document_page(document_id, page)
|
||||
|
||||
# Convert Pydantic model to internal config
|
||||
config_dict = config.model_dump()
|
||||
internal_config = AugmentationConfig.from_dict(config_dict)
|
||||
pipeline = AugmentationPipeline(internal_config)
|
||||
|
||||
# Apply augmentation
|
||||
result = pipeline.apply(image)
|
||||
|
||||
# Convert to base64 URLs
|
||||
original_url = self._image_to_data_url(image)
|
||||
preview_url = self._image_to_data_url(result.image)
|
||||
|
||||
return AugmentationPreviewResponse(
|
||||
preview_url=preview_url,
|
||||
original_url=original_url,
|
||||
applied_params=config_dict,
|
||||
)
|
||||
|
||||
async def create_augmented_dataset(
|
||||
self,
|
||||
source_dataset_id: str,
|
||||
config: AugmentationConfigSchema,
|
||||
output_name: str,
|
||||
multiplier: int,
|
||||
) -> AugmentationBatchResponse:
|
||||
"""
|
||||
Create a new augmented dataset from an existing dataset.
|
||||
|
||||
Args:
|
||||
source_dataset_id: Source dataset UUID.
|
||||
config: Augmentation configuration.
|
||||
output_name: Name for the new dataset.
|
||||
multiplier: Number of augmented copies per image.
|
||||
|
||||
Returns:
|
||||
Batch response with task ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If source dataset not found.
|
||||
"""
|
||||
# Validate source dataset exists
|
||||
try:
|
||||
source_dataset = self.db.get_dataset(source_dataset_id)
|
||||
if source_dataset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source dataset not found: {source_dataset_id}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source dataset not found: {source_dataset_id}",
|
||||
) from e
|
||||
|
||||
# Create task ID for background processing
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Estimate total images
|
||||
estimated_images = (
|
||||
source_dataset.total_images * multiplier
|
||||
if hasattr(source_dataset, "total_images")
|
||||
else 0
|
||||
)
|
||||
|
||||
# TODO: Queue background task for actual augmentation
|
||||
# For now, return pending status
|
||||
|
||||
return AugmentationBatchResponse(
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
message=f"Augmentation task queued for dataset '{output_name}'",
|
||||
estimated_images=estimated_images,
|
||||
)
|
||||
|
||||
async def list_augmented_datasets(
|
||||
self,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> AugmentedDatasetListResponse:
|
||||
"""
|
||||
List augmented datasets.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of datasets to return.
|
||||
offset: Number of datasets to skip.
|
||||
|
||||
Returns:
|
||||
List response with datasets.
|
||||
"""
|
||||
# TODO: Implement actual database query for augmented datasets
|
||||
# For now, return empty list
|
||||
|
||||
return AugmentedDatasetListResponse(
|
||||
total=0,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
datasets=[],
|
||||
)
|
||||
|
||||
async def _load_document_page(
|
||||
self,
|
||||
document_id: str,
|
||||
page: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Load a document page as numpy array.
|
||||
|
||||
Args:
|
||||
document_id: Document UUID.
|
||||
page: Page number (1-indexed).
|
||||
|
||||
Returns:
|
||||
Image as numpy array (H, W, C) with dtype uint8.
|
||||
|
||||
Raises:
|
||||
HTTPException: If document or page not found.
|
||||
"""
|
||||
# Validate document_id format to prevent path traversal
|
||||
self._validate_uuid(document_id, "document_id")
|
||||
|
||||
# Get document from database
|
||||
try:
|
||||
document = self.db.get_document(document_id)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document not found: {document_id}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document not found: {document_id}",
|
||||
) from e
|
||||
|
||||
# Get image path for page
|
||||
if hasattr(document, "images_dir"):
|
||||
images_dir = Path(document.images_dir)
|
||||
else:
|
||||
# Fallback to constructed path
|
||||
from inference.web.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
|
||||
|
||||
# Find image for page
|
||||
page_idx = page - 1 # Convert to 0-indexed
|
||||
image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
|
||||
|
||||
if page_idx >= len(image_files):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Page {page} not found for document {document_id}",
|
||||
)
|
||||
|
||||
# Load image
|
||||
image_path = image_files[page_idx]
|
||||
pil_image = Image.open(image_path).convert("RGB")
|
||||
return np.array(pil_image)
|
||||
|
||||
def _image_to_data_url(self, image: np.ndarray) -> str:
|
||||
"""Convert numpy image to base64 data URL."""
|
||||
pil_image = Image.fromarray(image)
|
||||
|
||||
# Resize for preview if too large
|
||||
max_size = PREVIEW_MAX_SIZE
|
||||
if max(pil_image.size) > max_size:
|
||||
ratio = max_size / max(pil_image.size)
|
||||
new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
return f"data:image/png;base64,{base64_data}"
|
||||
@@ -81,29 +81,18 @@ class DatasetBuilder:
|
||||
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. Shuffle and split documents
|
||||
# 3. Group documents by group_key and assign splits
|
||||
doc_list = list(documents)
|
||||
rng = random.Random(seed)
|
||||
rng.shuffle(doc_list)
|
||||
|
||||
n = len(doc_list)
|
||||
n_train = max(1, round(n * train_ratio))
|
||||
n_val = max(0, round(n * val_ratio))
|
||||
n_test = n - n_train - n_val
|
||||
|
||||
splits = (
|
||||
["train"] * n_train
|
||||
+ ["val"] * n_val
|
||||
+ ["test"] * n_test
|
||||
)
|
||||
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
|
||||
|
||||
# 4. Process each document
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
dataset_docs = []
|
||||
|
||||
for doc, split in zip(doc_list, splits):
|
||||
for doc in doc_list:
|
||||
doc_id = str(doc.document_id)
|
||||
split = doc_splits[doc_id]
|
||||
annotations = self._db.get_annotations_for_document(doc.document_id)
|
||||
|
||||
# Group annotations by page
|
||||
@@ -174,6 +163,86 @@ class DatasetBuilder:
|
||||
"total_annotations": total_annotations,
|
||||
}
|
||||
|
||||
def _assign_splits_by_group(
|
||||
self,
|
||||
documents: list,
|
||||
train_ratio: float,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
) -> dict[str, str]:
|
||||
"""Assign splits based on group_key.
|
||||
|
||||
Logic:
|
||||
- Documents with same group_key stay together in the same split
|
||||
- Groups with only 1 document go directly to train
|
||||
- Groups with 2+ documents participate in shuffle & split
|
||||
|
||||
Args:
|
||||
documents: List of AdminDocument objects
|
||||
train_ratio: Fraction for training set
|
||||
val_ratio: Fraction for validation set
|
||||
seed: Random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
Dict mapping document_id (str) -> split ("train"/"val"/"test")
|
||||
"""
|
||||
# Group documents by group_key
|
||||
# None/empty group_key treated as unique (each doc is its own group)
|
||||
groups: dict[str | None, list] = {}
|
||||
for doc in documents:
|
||||
key = doc.group_key if doc.group_key else None
|
||||
if key is None:
|
||||
# Treat each ungrouped doc as its own unique group
|
||||
# Use document_id as pseudo-key
|
||||
key = f"__ungrouped_{doc.document_id}"
|
||||
groups.setdefault(key, []).append(doc)
|
||||
|
||||
# Separate single-doc groups from multi-doc groups
|
||||
single_doc_groups: list[tuple[str | None, list]] = []
|
||||
multi_doc_groups: list[tuple[str | None, list]] = []
|
||||
|
||||
for key, docs in groups.items():
|
||||
if len(docs) == 1:
|
||||
single_doc_groups.append((key, docs))
|
||||
else:
|
||||
multi_doc_groups.append((key, docs))
|
||||
|
||||
# Initialize result mapping
|
||||
doc_splits: dict[str, str] = {}
|
||||
|
||||
# Combine all groups for splitting
|
||||
all_groups = single_doc_groups + multi_doc_groups
|
||||
|
||||
# Shuffle all groups and assign splits
|
||||
if all_groups:
|
||||
rng = random.Random(seed)
|
||||
rng.shuffle(all_groups)
|
||||
|
||||
n_groups = len(all_groups)
|
||||
n_train = max(1, round(n_groups * train_ratio))
|
||||
# Ensure at least 1 in val if we have more than 1 group
|
||||
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
|
||||
|
||||
for i, (_key, docs) in enumerate(all_groups):
|
||||
if i < n_train:
|
||||
split = "train"
|
||||
elif i < n_train + n_val:
|
||||
split = "val"
|
||||
else:
|
||||
split = "test"
|
||||
|
||||
for doc in docs:
|
||||
doc_splits[str(doc.document_id)] = split
|
||||
|
||||
logger.info(
|
||||
"Split assignment: %d total groups shuffled (train=%d, val=%d)",
|
||||
len(all_groups),
|
||||
sum(1 for s in doc_splits.values() if s == "train"),
|
||||
sum(1 for s in doc_splits.values() if s == "val"),
|
||||
)
|
||||
|
||||
return doc_splits
|
||||
|
||||
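To illustrate the grouping behaviour, a small hypothetical call (document objects simplified to namespaces, and `builder` assumed to be a DatasetBuilder instance):

    from types import SimpleNamespace

    docs = [
        SimpleNamespace(document_id="d1", group_key="invoice-A"),
        SimpleNamespace(document_id="d2", group_key="invoice-A"),
        SimpleNamespace(document_id="d3", group_key=None),
    ]
    splits = builder._assign_splits_by_group(docs, train_ratio=0.8, val_ratio=0.2, seed=7)
    # d1 and d2 share "invoice-A", so they always land in the same split;
    # d3 becomes its own pseudo-group ("__ungrouped_d3") and is shuffled independently.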
def _generate_data_yaml(self, dataset_dir: Path) -> None:
|
||||
"""Generate YOLO data.yaml configuration file."""
|
||||
data = {
|
||||
|
||||
@@ -11,7 +11,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Type alias for model path resolver function
|
||||
ModelPathResolver = Callable[[], Path | None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceResult:
|
||||
"""Result from inference service."""
|
||||
@@ -42,25 +46,52 @@ class InferenceService:
|
||||
Service for running invoice field extraction.
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration
|
||||
model_config: Model configuration (default model settings)
|
||||
storage_config: Storage configuration
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
|
||||
def _resolve_model_path(self) -> Path:
|
||||
"""Resolve the model path to use for inference.
|
||||
|
||||
Priority:
|
||||
1. Active model from database (via resolver)
|
||||
2. Default model from config
|
||||
"""
|
||||
if self._model_path_resolver:
|
||||
try:
|
||||
db_model_path = self._model_path_resolver()
|
||||
if db_model_path and Path(db_model_path).exists():
|
||||
logger.info(f"Using active model from database: {db_model_path}")
|
||||
return Path(db_model_path)
|
||||
elif db_model_path:
|
||||
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
|
||||
|
||||
return self.model_config.model_path
|
||||
|
||||
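As an illustration of the resolver hook, a minimal sketch; the AdminDB method name used here is assumed, not taken from the diff:

    def resolve_active_model() -> Path | None:
        """Hypothetical resolver: look up the active model version in the admin DB."""
        active = admin_db.get_active_model_version()  # assumed helper, returns a row or None
        return Path(active.model_path) if active else None

    service = InferenceService(
        model_config=model_config,
        storage_config=storage_config,
        model_path_resolver=resolve_active_model,
    )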
def initialize(self) -> None:
|
||||
"""Initialize the inference pipeline (lazy loading)."""
|
||||
@@ -74,16 +105,20 @@ class InferenceService:
|
||||
from inference.pipeline.pipeline import InferencePipeline
|
||||
from inference.pipeline.yolo_detector import YOLODetector
|
||||
|
||||
# Resolve model path (from DB or config)
|
||||
model_path = self._resolve_model_path()
|
||||
self._current_model_path = model_path
|
||||
|
||||
# Initialize YOLO detector for visualization
|
||||
self._detector = YOLODetector(
|
||||
str(self.model_config.model_path),
|
||||
str(model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
device="cuda" if self.model_config.use_gpu else "cpu",
|
||||
)
|
||||
|
||||
# Initialize full pipeline
|
||||
self._pipeline = InferencePipeline(
|
||||
model_path=str(self.model_config.model_path),
|
||||
model_path=str(model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
use_gpu=self.model_config.use_gpu,
|
||||
dpi=self.model_config.dpi,
|
||||
@@ -92,12 +127,36 @@ class InferenceService:
|
||||
|
||||
self._is_initialized = True
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Inference service initialized in {elapsed:.2f}s")
|
||||
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
raise
|
||||
|
||||
def reload_model(self) -> bool:
|
||||
"""Reload the model if active model has changed.
|
||||
|
||||
Returns:
|
||||
True if model was reloaded, False if no change needed.
|
||||
"""
|
||||
new_model_path = self._resolve_model_path()
|
||||
|
||||
if self._current_model_path == new_model_path:
|
||||
logger.debug("Model unchanged, no reload needed")
|
||||
return False
|
||||
|
||||
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
|
||||
self._is_initialized = False
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
self.initialize()
|
||||
return True
|
||||
|
||||
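A short usage note (not from the diff): after a different model version is activated, a caller can hot-swap without restarting the process, e.g.:

    if service.reload_model():
        logger.info("Switched to %s", service.current_model_path)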
@property
|
||||
def current_model_path(self) -> Path | None:
|
||||
"""Get the currently loaded model path."""
|
||||
return self._current_model_path
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if service is initialized."""
|
||||
|
||||
packages/shared/shared/augmentation/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
|
||||
Document Image Augmentation Module.
|
||||
|
||||
Provides augmentation transformations for training data enhancement,
|
||||
specifically designed for document images (invoices, forms, etc.).
|
||||
|
||||
Key features:
|
||||
- Document-safe augmentations that preserve text readability
|
||||
- Support for both offline preprocessing and runtime augmentation
|
||||
- Bbox-aware geometric transforms
|
||||
- Configurable augmentation pipeline
|
||||
"""
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.dataset_augmenter import DatasetAugmenter
|
||||
|
||||
__all__ = [
|
||||
"AugmentationConfig",
|
||||
"AugmentationParams",
|
||||
"AugmentationResult",
|
||||
"BaseAugmentation",
|
||||
"DatasetAugmenter",
|
||||
]
|
||||
packages/shared/shared/augmentation/base.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
|
||||
Base classes for augmentation transforms.
|
||||
|
||||
Provides abstract base class and result dataclass for all augmentation
|
||||
implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class AugmentationResult:
|
||||
"""
|
||||
Result of applying an augmentation.
|
||||
|
||||
Attributes:
|
||||
image: The augmented image as numpy array (H, W, C).
|
||||
bboxes: Updated bounding boxes if geometric transform was applied.
|
||||
Format: (N, 5) array with [class_id, x_center, y_center, width, height].
|
||||
transform_matrix: The transformation matrix if applicable (for bbox adjustment).
|
||||
applied: Whether the augmentation was actually applied.
|
||||
metadata: Additional metadata about the augmentation.
|
||||
"""
|
||||
|
||||
image: np.ndarray
|
||||
bboxes: np.ndarray | None = None
|
||||
transform_matrix: np.ndarray | None = None
|
||||
applied: bool = True
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class BaseAugmentation(ABC):
|
||||
"""
|
||||
Abstract base class for all augmentations.
|
||||
|
||||
Subclasses must implement:
|
||||
- _validate_params(): Validate augmentation parameters
|
||||
- apply(): Apply the augmentation to an image
|
||||
|
||||
Class attributes:
|
||||
name: Human-readable name of the augmentation.
|
||||
affects_geometry: True if this augmentation modifies bbox coordinates.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
affects_geometry: bool = False
|
||||
|
||||
def __init__(self, params: dict[str, Any]) -> None:
|
||||
"""
|
||||
Initialize augmentation with parameters.
|
||||
|
||||
Args:
|
||||
params: Dictionary of augmentation-specific parameters.
|
||||
"""
|
||||
self.params = params
|
||||
self._validate_params()
|
||||
|
||||
@abstractmethod
|
||||
def _validate_params(self) -> None:
|
||||
"""
|
||||
Validate augmentation parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
"""
|
||||
Apply augmentation to image.
|
||||
|
||||
IMPORTANT: Implementations must NOT modify the input image or bboxes.
|
||||
Always create copies before modifying.
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array (H, W, C) with dtype uint8.
|
||||
bboxes: Optional bounding boxes in YOLO format (N, 5) array.
|
||||
Each row: [class_id, x_center, y_center, width, height].
|
||||
Coordinates are normalized to 0-1 range.
|
||||
rng: Random number generator for reproducibility.
|
||||
If None, a new generator should be created.
|
||||
|
||||
Returns:
|
||||
AugmentationResult with augmented image and optionally updated bboxes.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_preview_params(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get parameters optimized for preview display.
|
||||
|
||||
Override this method to provide parameters that produce
|
||||
clearly visible effects for preview/demo purposes.
|
||||
|
||||
Returns:
|
||||
Dictionary of preview parameters.
|
||||
"""
|
||||
return dict(self.params)
|
||||
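A minimal subclass sketch (not part of the commit) showing the contract BaseAugmentation imposes; the transform itself is a toy example:

    import numpy as np

    class InvertColors(BaseAugmentation):
        """Toy augmentation: invert pixel intensities (hypothetical, for illustration)."""

        name = "invert_colors"
        affects_geometry = False  # bboxes are untouched

        def _validate_params(self) -> None:
            # Nothing to validate for this toy transform.
            return None

        def apply(self, image, bboxes=None, rng=None):
            # Never modify the input in place; work on a copy.
            augmented = 255 - image.copy()
            return AugmentationResult(image=augmented, bboxes=bboxes, applied=True)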
packages/shared/shared/augmentation/config.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""
|
||||
Augmentation configuration module.
|
||||
|
||||
Provides dataclasses for configuring document image augmentations.
|
||||
All default values are document-safe (conservative) to preserve text readability.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AugmentationParams:
|
||||
"""
|
||||
Parameters for a single augmentation type.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether this augmentation is enabled.
|
||||
probability: Probability of applying this augmentation (0.0 to 1.0).
|
||||
params: Type-specific parameters dictionary.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
probability: float = 0.5
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"probability": self.probability,
|
||||
"params": dict(self.params),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
probability=data.get("probability", 0.5),
|
||||
params=dict(data.get("params", {})),
|
||||
)
|
||||
|
||||
|
||||
def _default_perspective_warp() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.3,
|
||||
params={"max_warp": 0.02}, # Very conservative - 2% max distortion
|
||||
)
|
||||
|
||||
|
||||
def _default_wrinkle() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.3,
|
||||
params={"intensity": 0.3, "num_wrinkles": (2, 5)},
|
||||
)
|
||||
|
||||
|
||||
def _default_edge_damage() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={"max_damage_ratio": 0.05}, # Max 5% of edge damaged
|
||||
)
|
||||
|
||||
|
||||
def _default_stain() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={
|
||||
"num_stains": (1, 3),
|
||||
"max_radius_ratio": 0.1,
|
||||
"opacity": (0.1, 0.3),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _default_lighting_variation() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=True, # Safe default, commonly needed
|
||||
probability=0.5,
|
||||
params={
|
||||
"brightness_range": (-0.1, 0.1),
|
||||
"contrast_range": (0.9, 1.1),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _default_shadow() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.3,
|
||||
params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
|
||||
)
|
||||
|
||||
|
||||
def _default_gaussian_blur() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
|
||||
)
|
||||
|
||||
|
||||
def _default_motion_blur() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
|
||||
)
|
||||
|
||||
|
||||
def _default_gaussian_noise() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.3,
|
||||
params={"mean": 0, "std": (5, 15)}, # Conservative noise levels
|
||||
)
|
||||
|
||||
|
||||
def _default_salt_pepper() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={"amount": (0.001, 0.005)}, # Very sparse
|
||||
)
|
||||
|
||||
|
||||
def _default_paper_texture() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.3,
|
||||
params={"texture_type": "random", "intensity": (0.05, 0.15)},
|
||||
)
|
||||
|
||||
|
||||
def _default_scanner_artifacts() -> AugmentationParams:
|
||||
return AugmentationParams(
|
||||
enabled=False,
|
||||
probability=0.2,
|
||||
params={"line_probability": 0.3, "dust_probability": 0.4},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AugmentationConfig:
|
||||
"""
|
||||
Complete augmentation configuration.
|
||||
|
||||
All augmentation types have document-safe defaults that preserve
|
||||
text readability. Only lighting_variation is enabled by default.
|
||||
|
||||
Attributes:
|
||||
perspective_warp: Geometric perspective transform (affects bboxes).
|
||||
wrinkle: Paper wrinkle/crease simulation.
|
||||
edge_damage: Damaged/torn edge effects.
|
||||
stain: Coffee stain/smudge effects.
|
||||
lighting_variation: Brightness and contrast variation.
|
||||
shadow: Shadow overlay effects.
|
||||
gaussian_blur: Gaussian blur for focus issues.
|
||||
motion_blur: Motion blur simulation.
|
||||
gaussian_noise: Gaussian noise for sensor noise.
|
||||
salt_pepper: Salt and pepper noise.
|
||||
paper_texture: Paper texture overlay.
|
||||
scanner_artifacts: Scanner line and dust artifacts.
|
||||
preserve_bboxes: Whether to adjust bboxes for geometric transforms.
|
||||
seed: Random seed for reproducibility.
|
||||
"""
|
||||
|
||||
# Geometric transforms (affects bboxes)
|
||||
perspective_warp: AugmentationParams = field(
|
||||
default_factory=_default_perspective_warp
|
||||
)
|
||||
|
||||
# Degradation effects
|
||||
wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
|
||||
edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
|
||||
stain: AugmentationParams = field(default_factory=_default_stain)
|
||||
|
||||
# Lighting effects
|
||||
lighting_variation: AugmentationParams = field(
|
||||
default_factory=_default_lighting_variation
|
||||
)
|
||||
shadow: AugmentationParams = field(default_factory=_default_shadow)
|
||||
|
||||
# Blur effects
|
||||
gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
|
||||
motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)
|
||||
|
||||
# Noise effects
|
||||
gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
|
||||
salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)
|
||||
|
||||
# Texture effects
|
||||
paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
|
||||
scanner_artifacts: AugmentationParams = field(
|
||||
default_factory=_default_scanner_artifacts
|
||||
)
|
||||
|
||||
# Global settings
|
||||
preserve_bboxes: bool = True
|
||||
seed: int | None = None
|
||||
|
||||
# List of all augmentation field names
|
||||
_AUGMENTATION_FIELDS: tuple[str, ...] = (
|
||||
"perspective_warp",
|
||||
"wrinkle",
|
||||
"edge_damage",
|
||||
"stain",
|
||||
"lighting_variation",
|
||||
"shadow",
|
||||
"gaussian_blur",
|
||||
"motion_blur",
|
||||
"gaussian_noise",
|
||||
"salt_pepper",
|
||||
"paper_texture",
|
||||
"scanner_artifacts",
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
result: dict[str, Any] = {
|
||||
"preserve_bboxes": self.preserve_bboxes,
|
||||
"seed": self.seed,
|
||||
}
|
||||
|
||||
for field_name in self._AUGMENTATION_FIELDS:
|
||||
params: AugmentationParams = getattr(self, field_name)
|
||||
result[field_name] = params.to_dict()
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
|
||||
"""Create from dictionary."""
|
||||
kwargs: dict[str, Any] = {
|
||||
"preserve_bboxes": data.get("preserve_bboxes", True),
|
||||
"seed": data.get("seed"),
|
||||
}
|
||||
|
||||
for field_name in cls._AUGMENTATION_FIELDS:
|
||||
if field_name in data:
|
||||
field_data = data[field_name]
|
||||
if isinstance(field_data, dict):
|
||||
kwargs[field_name] = AugmentationParams.from_dict(field_data)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
def get_enabled_augmentations(self) -> list[str]:
|
||||
"""Get list of enabled augmentation names."""
|
||||
enabled = []
|
||||
for field_name in self._AUGMENTATION_FIELDS:
|
||||
params: AugmentationParams = getattr(self, field_name)
|
||||
if params.enabled:
|
||||
enabled.append(field_name)
|
||||
return enabled
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Validate configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If any configuration value is invalid.
|
||||
"""
|
||||
for field_name in self._AUGMENTATION_FIELDS:
|
||||
params: AugmentationParams = getattr(self, field_name)
|
||||
if not (0.0 <= params.probability <= 1.0):
|
||||
raise ValueError(
|
||||
f"{field_name}.probability must be between 0 and 1, "
|
||||
f"got {params.probability}"
|
||||
)
|
||||
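A small round-trip sketch (not from the commit) of the dict conversion and validation helpers defined above; values are illustrative:

    cfg = AugmentationConfig.from_dict({
        "gaussian_blur": {"enabled": True, "probability": 0.2, "params": {"sigma": (0.5, 1.0)}},
        "seed": 123,
    })
    cfg.validate()                          # raises ValueError on out-of-range probabilities
    print(cfg.get_enabled_augmentations())  # ['lighting_variation', 'gaussian_blur'] (lighting is on by default)
    assert cfg.to_dict()["gaussian_blur"]["enabled"] is True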
packages/shared/shared/augmentation/dataset_augmenter.py (new file, 206 lines)
@@ -0,0 +1,206 @@
"""
|
||||
Dataset Augmenter Module.
|
||||
|
||||
Applies augmentation pipeline to YOLO datasets,
|
||||
creating new augmented images and label files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetAugmenter:
|
||||
"""
|
||||
Augments YOLO datasets by creating new images and label files.
|
||||
|
||||
Reads images from dataset/images/train/ and labels from dataset/labels/train/,
|
||||
applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
seed: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize augmenter with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary mapping augmentation names to their settings.
|
||||
Each augmentation should have 'enabled', 'probability', and 'params'.
|
||||
seed: Random seed for reproducibility.
|
||||
"""
|
||||
self._config_dict = config
|
||||
self._seed = seed
|
||||
self._config = self._build_config(config, seed)
|
||||
|
||||
def _build_config(
|
||||
self,
|
||||
config_dict: dict[str, Any],
|
||||
seed: int | None,
|
||||
) -> AugmentationConfig:
|
||||
"""Build AugmentationConfig from dictionary."""
|
||||
kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
|
||||
|
||||
for aug_name, aug_settings in config_dict.items():
|
||||
if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
|
||||
kwargs[aug_name] = AugmentationParams(
|
||||
enabled=aug_settings.get("enabled", False),
|
||||
probability=aug_settings.get("probability", 0.5),
|
||||
params=aug_settings.get("params", {}),
|
||||
)
|
||||
|
||||
return AugmentationConfig(**kwargs)
|
||||
|
||||
def augment_dataset(
|
||||
self,
|
||||
dataset_path: Path,
|
||||
multiplier: int = 1,
|
||||
split: str = "train",
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Augment a YOLO dataset.
|
||||
|
||||
Args:
|
||||
dataset_path: Path to dataset root (containing images/ and labels/).
|
||||
multiplier: Number of augmented copies per original image.
|
||||
split: Which split to augment (default: "train").
|
||||
|
||||
Returns:
|
||||
Summary dict with original_images, augmented_images, total_images.
|
||||
"""
|
||||
images_dir = dataset_path / "images" / split
|
||||
labels_dir = dataset_path / "labels" / split
|
||||
|
||||
if not images_dir.exists():
|
||||
raise ValueError(f"Images directory not found: {images_dir}")
|
||||
|
||||
# Find all images
|
||||
image_extensions = ("*.png", "*.jpg", "*.jpeg")
|
||||
image_files: list[Path] = []
|
||||
for ext in image_extensions:
|
||||
image_files.extend(images_dir.glob(ext))
|
||||
|
||||
original_count = len(image_files)
|
||||
augmented_count = 0
|
||||
|
||||
if multiplier <= 0:
|
||||
return {
|
||||
"original_images": original_count,
|
||||
"augmented_images": 0,
|
||||
"total_images": original_count,
|
||||
}
|
||||
|
||||
# Process each image
|
||||
for img_path in image_files:
|
||||
# Load image
|
||||
pil_image = Image.open(img_path).convert("RGB")
|
||||
image = np.array(pil_image)
|
||||
|
||||
# Load corresponding label
|
||||
label_path = labels_dir / f"{img_path.stem}.txt"
|
||||
bboxes = self._load_bboxes(label_path) if label_path.exists() else None
|
||||
|
||||
# Create multiple augmented versions
|
||||
for aug_idx in range(multiplier):
|
||||
# Create pipeline with adjusted seed for each augmentation
|
||||
aug_seed = None
|
||||
if self._seed is not None:
|
||||
aug_seed = self._seed + aug_idx + hash(img_path.stem) % 10000
|
||||
|
||||
pipeline = AugmentationPipeline(
|
||||
self._build_config(self._config_dict, aug_seed)
|
||||
)
|
||||
|
||||
# Apply augmentation
|
||||
result = pipeline.apply(image, bboxes)
|
||||
|
||||
# Save augmented image
|
||||
aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
|
||||
aug_img_path = images_dir / aug_name
|
||||
aug_pil = Image.fromarray(result.image)
|
||||
aug_pil.save(aug_img_path)
|
||||
|
||||
# Save augmented label
|
||||
aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
|
||||
self._save_bboxes(aug_label_path, result.bboxes)
|
||||
|
||||
augmented_count += 1
|
||||
|
||||
logger.info(
|
||||
"Dataset augmentation complete: %d original, %d augmented",
|
||||
original_count,
|
||||
augmented_count,
|
||||
)
|
||||
|
||||
return {
|
||||
"original_images": original_count,
|
||||
"augmented_images": augmented_count,
|
||||
"total_images": original_count + augmented_count,
|
||||
}
|
||||
|
||||
def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
|
||||
"""
|
||||
Load bounding boxes from YOLO label file.
|
||||
|
||||
Args:
|
||||
label_path: Path to label file.
|
||||
|
||||
Returns:
|
||||
Array of shape (N, 5) with class_id, x_center, y_center, width, height.
|
||||
Returns None if file is empty or doesn't exist.
|
||||
"""
|
||||
if not label_path.exists():
|
||||
return None
|
||||
|
||||
content = label_path.read_text().strip()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
bboxes = []
|
||||
for line in content.split("\n"):
|
||||
parts = line.strip().split()
|
||||
if len(parts) == 5:
|
||||
class_id = int(parts[0])
|
||||
x_center = float(parts[1])
|
||||
y_center = float(parts[2])
|
||||
width = float(parts[3])
|
||||
height = float(parts[4])
|
||||
bboxes.append([class_id, x_center, y_center, width, height])
|
||||
|
||||
if not bboxes:
|
||||
return None
|
||||
|
||||
return np.array(bboxes, dtype=np.float32)
|
||||
|
||||
def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
|
||||
"""
|
||||
Save bounding boxes to YOLO label file.
|
||||
|
||||
Args:
|
||||
label_path: Path to save label file.
|
||||
bboxes: Array of shape (N, 5) or None for empty labels.
|
||||
"""
|
||||
if bboxes is None or len(bboxes) == 0:
|
||||
label_path.write_text("")
|
||||
return
|
||||
|
||||
lines = []
|
||||
for bbox in bboxes:
|
||||
class_id = int(bbox[0])
|
||||
x_center = bbox[1]
|
||||
y_center = bbox[2]
|
||||
width = bbox[3]
|
||||
height = bbox[4]
|
||||
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
||||
|
||||
label_path.write_text("\n".join(lines))
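
For reference, a minimal sketch of what augment_dataset expects and produces on disk. Label files use the YOLO format that _load_bboxes/_save_bboxes round-trip (one "class_id x_center y_center width height" line per box, normalized to [0, 1]), and each augmented copy is written next to the original as "<stem>_aug<N>" with a matching label. The DatasetAugmenter name below is only an assumption for illustration; the class that defines these methods appears earlier in this diff, and the constructor signature is assumed.

from pathlib import Path

dataset = Path("datasets/invoices")          # layout assumed: images/train, labels/train
(dataset / "labels" / "train" / "doc_001.txt").write_text(
    "0 0.50 0.12 0.30 0.05\n"                # class 0, box centered near the top of the page
    "3 0.75 0.90 0.20 0.04\n"
)

# Given images/train/doc_001.png alongside that label file:
# augmenter = DatasetAugmenter(config_dict, seed=42)   # class/constructor assumed
# summary = augmenter.augment_dataset(dataset, multiplier=2, split="train")
# -> images/train/doc_001_aug0.png, doc_001_aug1.png
#    labels/train/doc_001_aug0.txt, doc_001_aug1.txt
# summary == {"original_images": 1, "augmented_images": 2, "total_images": 3}
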
|
||||
packages/shared/shared/augmentation/pipeline.py (new file, 184 lines)
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Augmentation pipeline module.
|
||||
|
||||
Orchestrates multiple augmentations with proper ordering and
|
||||
provides preview functionality.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
|
||||
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
|
||||
from shared.augmentation.transforms.geometric import PerspectiveWarp
|
||||
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
|
||||
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
|
||||
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
|
||||
|
||||
# Registry of augmentation classes
|
||||
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
|
||||
"perspective_warp": PerspectiveWarp,
|
||||
"wrinkle": Wrinkle,
|
||||
"edge_damage": EdgeDamage,
|
||||
"stain": Stain,
|
||||
"lighting_variation": LightingVariation,
|
||||
"shadow": Shadow,
|
||||
"gaussian_blur": GaussianBlur,
|
||||
"motion_blur": MotionBlur,
|
||||
"gaussian_noise": GaussianNoise,
|
||||
"salt_pepper": SaltPepper,
|
||||
"paper_texture": PaperTexture,
|
||||
"scanner_artifacts": ScannerArtifacts,
|
||||
}
|
||||
|
||||
|
||||
class AugmentationPipeline:
|
||||
"""
|
||||
Orchestrates multiple augmentations with proper ordering.
|
||||
|
||||
Augmentations are applied in the following order:
|
||||
1. Geometric (perspective_warp) - affects bboxes
|
||||
2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
|
||||
3. Lighting (lighting_variation, shadow)
|
||||
4. Texture (paper_texture, scanner_artifacts)
|
||||
5. Blur (gaussian_blur, motion_blur)
|
||||
6. Noise (gaussian_noise, salt_pepper) - applied last
|
||||
"""
|
||||
|
||||
STAGE_ORDER = [
|
||||
"geometric",
|
||||
"degradation",
|
||||
"lighting",
|
||||
"texture",
|
||||
"blur",
|
||||
"noise",
|
||||
]
|
||||
|
||||
STAGE_MAPPING = {
|
||||
"perspective_warp": "geometric",
|
||||
"wrinkle": "degradation",
|
||||
"edge_damage": "degradation",
|
||||
"stain": "degradation",
|
||||
"lighting_variation": "lighting",
|
||||
"shadow": "lighting",
|
||||
"paper_texture": "texture",
|
||||
"scanner_artifacts": "texture",
|
||||
"gaussian_blur": "blur",
|
||||
"motion_blur": "blur",
|
||||
"gaussian_noise": "noise",
|
||||
"salt_pepper": "noise",
|
||||
}
|
||||
|
||||
def __init__(self, config: AugmentationConfig) -> None:
|
||||
"""
|
||||
Initialize pipeline with configuration.
|
||||
|
||||
Args:
|
||||
config: Augmentation configuration.
|
||||
"""
|
||||
self.config = config
|
||||
self._rng = np.random.default_rng(config.seed)
|
||||
self._augmentations = self._build_augmentations()
|
||||
|
||||
def _build_augmentations(
|
||||
self,
|
||||
) -> list[tuple[str, BaseAugmentation, float]]:
|
||||
"""Build ordered list of (name, augmentation, probability) tuples."""
|
||||
augmentations: list[tuple[str, BaseAugmentation, float]] = []
|
||||
|
||||
for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||
params: AugmentationParams = getattr(self.config, aug_name)
|
||||
if params.enabled:
|
||||
aug = aug_class(params.params)
|
||||
augmentations.append((aug_name, aug, params.probability))
|
||||
|
||||
# Sort by stage order
|
||||
def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
|
||||
name, _, _ = item
|
||||
stage = self.STAGE_MAPPING[name]
|
||||
return self.STAGE_ORDER.index(stage)
|
||||
|
||||
return sorted(augmentations, key=sort_key)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
) -> AugmentationResult:
|
||||
"""
|
||||
Apply augmentation pipeline to image.
|
||||
|
||||
Args:
|
||||
image: Input image (H, W, C) as numpy array with dtype uint8.
|
||||
bboxes: Optional bounding boxes in YOLO format (N, 5).
|
||||
|
||||
Returns:
|
||||
AugmentationResult with augmented image and optionally adjusted bboxes.
|
||||
"""
|
||||
current_image = image.copy()
|
||||
current_bboxes = bboxes.copy() if bboxes is not None else None
|
||||
applied_augmentations: list[str] = []
|
||||
|
||||
for name, aug, probability in self._augmentations:
|
||||
if self._rng.random() < probability:
|
||||
result = aug.apply(current_image, current_bboxes, self._rng)
|
||||
current_image = result.image
|
||||
if result.bboxes is not None and self.config.preserve_bboxes:
|
||||
current_bboxes = result.bboxes
|
||||
applied_augmentations.append(name)
|
||||
|
||||
return AugmentationResult(
|
||||
image=current_image,
|
||||
bboxes=current_bboxes,
|
||||
metadata={"applied_augmentations": applied_augmentations},
|
||||
)
|
||||
|
||||
def preview(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
augmentation_name: str,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Preview a single augmentation deterministically.
|
||||
|
||||
Args:
|
||||
image: Input image.
|
||||
augmentation_name: Name of augmentation to preview.
|
||||
|
||||
Returns:
|
||||
Augmented image.
|
||||
|
||||
Raises:
|
||||
ValueError: If augmentation_name is not recognized.
|
||||
"""
|
||||
if augmentation_name not in AUGMENTATION_REGISTRY:
|
||||
raise ValueError(f"Unknown augmentation: {augmentation_name}")
|
||||
|
||||
params: AugmentationParams = getattr(self.config, augmentation_name)
|
||||
aug = AUGMENTATION_REGISTRY[augmentation_name](params.params)
|
||||
|
||||
# Use deterministic RNG for preview
|
||||
preview_rng = np.random.default_rng(42)
|
||||
result = aug.apply(image.copy(), rng=preview_rng)
|
||||
return result.image
|
||||
|
||||
|
||||
def get_available_augmentations() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get list of available augmentations with metadata.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with augmentation info.
|
||||
"""
|
||||
augmentations = []
|
||||
for name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||
augmentations.append({
|
||||
"name": name,
|
||||
"description": aug_class.__doc__ or "",
|
||||
"affects_geometry": aug_class.affects_geometry,
|
||||
"stage": AugmentationPipeline.STAGE_MAPPING[name],
|
||||
})
|
||||
return augmentations
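
A minimal usage sketch of the pipeline above. It assumes AugmentationConfig.from_dict accepts the same nested {"<name>": {"enabled", "probability", "params"}} shape used by the presets later in this diff, and that unspecified fields such as seed and preserve_bboxes get sensible defaults; adjust to the real config API if it differs.

import numpy as np

from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import (
    AugmentationPipeline,
    get_available_augmentations,
)

config = AugmentationConfig.from_dict({
    "gaussian_noise": {"enabled": True, "probability": 1.0, "params": {"std": (5, 10)}},
    "shadow": {"enabled": True, "probability": 0.5, "params": {}},
})
pipeline = AugmentationPipeline(config)

image = np.full((640, 480, 3), 255, dtype=np.uint8)   # blank stand-in for a document page
result = pipeline.apply(image)                        # bboxes are optional
print(result.metadata["applied_augmentations"])       # e.g. ["shadow", "gaussian_noise"]

preview = pipeline.preview(image, "gaussian_noise")   # deterministic single-transform preview

for info in get_available_augmentations():
    print(info["name"], info["stage"], info["affects_geometry"])
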
|
||||
packages/shared/shared/augmentation/presets.py (new file, 212 lines)
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Predefined augmentation presets for common document scenarios.
|
||||
|
||||
Presets provide ready-to-use configurations optimized for different
|
||||
use cases, from conservative (preserves text readability) to aggressive
|
||||
(simulates poor document quality).
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
|
||||
PRESETS: dict[str, dict[str, Any]] = {
|
||||
"conservative": {
|
||||
"description": "Safe augmentations that preserve text readability",
|
||||
"config": {
|
||||
"lighting_variation": {
|
||||
"enabled": True,
|
||||
"probability": 0.5,
|
||||
"params": {
|
||||
"brightness_range": (-0.1, 0.1),
|
||||
"contrast_range": (0.9, 1.1),
|
||||
},
|
||||
},
|
||||
"gaussian_noise": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"std": (3, 10)},
|
||||
},
|
||||
},
|
||||
},
|
||||
"moderate": {
|
||||
"description": "Balanced augmentations for typical document degradation",
|
||||
"config": {
|
||||
"lighting_variation": {
|
||||
"enabled": True,
|
||||
"probability": 0.5,
|
||||
"params": {
|
||||
"brightness_range": (-0.15, 0.15),
|
||||
"contrast_range": (0.85, 1.15),
|
||||
},
|
||||
},
|
||||
"shadow": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
|
||||
},
|
||||
"gaussian_noise": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"std": (5, 12)},
|
||||
},
|
||||
"gaussian_blur": {
|
||||
"enabled": True,
|
||||
"probability": 0.2,
|
||||
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
|
||||
},
|
||||
"paper_texture": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"intensity": (0.05, 0.12)},
|
||||
},
|
||||
},
|
||||
},
|
||||
"aggressive": {
|
||||
"description": "Heavy augmentations simulating poor scan quality",
|
||||
"config": {
|
||||
"perspective_warp": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"max_warp": 0.02},
|
||||
},
|
||||
"wrinkle": {
|
||||
"enabled": True,
|
||||
"probability": 0.4,
|
||||
"params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
|
||||
},
|
||||
"stain": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {
|
||||
"num_stains": (1, 2),
|
||||
"max_radius_ratio": 0.08,
|
||||
"opacity": (0.1, 0.25),
|
||||
},
|
||||
},
|
||||
"lighting_variation": {
|
||||
"enabled": True,
|
||||
"probability": 0.6,
|
||||
"params": {
|
||||
"brightness_range": (-0.2, 0.2),
|
||||
"contrast_range": (0.8, 1.2),
|
||||
},
|
||||
},
|
||||
"shadow": {
|
||||
"enabled": True,
|
||||
"probability": 0.4,
|
||||
"params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
|
||||
},
|
||||
"gaussian_blur": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
|
||||
},
|
||||
"motion_blur": {
|
||||
"enabled": True,
|
||||
"probability": 0.2,
|
||||
"params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
|
||||
},
|
||||
"gaussian_noise": {
|
||||
"enabled": True,
|
||||
"probability": 0.4,
|
||||
"params": {"std": (8, 18)},
|
||||
},
|
||||
"paper_texture": {
|
||||
"enabled": True,
|
||||
"probability": 0.4,
|
||||
"params": {"intensity": (0.08, 0.15)},
|
||||
},
|
||||
"scanner_artifacts": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"line_probability": 0.4, "dust_probability": 0.5},
|
||||
},
|
||||
"edge_damage": {
|
||||
"enabled": True,
|
||||
"probability": 0.2,
|
||||
"params": {"max_damage_ratio": 0.04},
|
||||
},
|
||||
},
|
||||
},
|
||||
"scanned_document": {
|
||||
"description": "Simulates typical scanned document artifacts",
|
||||
"config": {
|
||||
"scanner_artifacts": {
|
||||
"enabled": True,
|
||||
"probability": 0.5,
|
||||
"params": {"line_probability": 0.4, "dust_probability": 0.5},
|
||||
},
|
||||
"paper_texture": {
|
||||
"enabled": True,
|
||||
"probability": 0.4,
|
||||
"params": {"intensity": (0.05, 0.12)},
|
||||
},
|
||||
"lighting_variation": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {
|
||||
"brightness_range": (-0.1, 0.1),
|
||||
"contrast_range": (0.9, 1.1),
|
||||
},
|
||||
},
|
||||
"gaussian_noise": {
|
||||
"enabled": True,
|
||||
"probability": 0.3,
|
||||
"params": {"std": (5, 12)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_preset_config(preset_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get the configuration dictionary for a preset.
|
||||
|
||||
Args:
|
||||
preset_name: Name of the preset.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If preset is not found.
|
||||
"""
|
||||
if preset_name not in PRESETS:
|
||||
raise ValueError(
|
||||
f"Unknown preset: {preset_name}. "
|
||||
f"Available presets: {list(PRESETS.keys())}"
|
||||
)
|
||||
return PRESETS[preset_name]["config"]
|
||||
|
||||
|
||||
def create_config_from_preset(preset_name: str) -> AugmentationConfig:
|
||||
"""
|
||||
Create an AugmentationConfig from a preset.
|
||||
|
||||
Args:
|
||||
preset_name: Name of the preset.
|
||||
|
||||
Returns:
|
||||
AugmentationConfig instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If preset is not found.
|
||||
"""
|
||||
config_dict = get_preset_config(preset_name)
|
||||
return AugmentationConfig.from_dict(config_dict)
|
||||
|
||||
|
||||
def list_presets() -> list[dict[str, str]]:
|
||||
"""
|
||||
List all available presets.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with name and description.
|
||||
"""
|
||||
return [
|
||||
{"name": name, "description": preset["description"]}
|
||||
for name, preset in PRESETS.items()
|
||||
]
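
A short sketch tying the presets to the pipeline: list what is available, build a config from a preset name, and run it. Only names defined in this diff (list_presets, create_config_from_preset, AugmentationPipeline) are used; the blank image is a stand-in.

import numpy as np

from shared.augmentation.pipeline import AugmentationPipeline
from shared.augmentation.presets import create_config_from_preset, list_presets

for preset in list_presets():
    print(f"{preset['name']}: {preset['description']}")

config = create_config_from_preset("moderate")
pipeline = AugmentationPipeline(config)

page = np.full((640, 480, 3), 255, dtype=np.uint8)    # stand-in document image
augmented = pipeline.apply(page).image
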
|
||||
packages/shared/shared/augmentation/transforms/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Augmentation transform implementations.
|
||||
|
||||
Each module contains related augmentation classes:
|
||||
- geometric.py: Perspective warp and other geometric transforms
|
||||
- degradation.py: Wrinkle, edge damage, stain effects
|
||||
- lighting.py: Lighting variation and shadow effects
|
||||
- blur.py: Gaussian and motion blur
|
||||
- noise.py: Gaussian and salt-pepper noise
|
||||
- texture.py: Paper texture and scanner artifacts
|
||||
"""
|
||||
|
||||
# Will be populated as transforms are implemented
|
||||
packages/shared/shared/augmentation/transforms/blur.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Blur augmentation transforms.
|
||||
|
||||
Provides blur effects for document image augmentation:
|
||||
- GaussianBlur: Simulates out-of-focus capture
|
||||
- MotionBlur: Simulates camera/document movement during capture
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class GaussianBlur(BaseAugmentation):
|
||||
"""
|
||||
Applies Gaussian blur to the image.
|
||||
|
||||
Simulates out-of-focus capture or low-quality optics.
|
||||
Conservative defaults to preserve text readability.
|
||||
|
||||
Parameters:
|
||||
kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
|
||||
sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
|
||||
"""
|
||||
|
||||
name = "gaussian_blur"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
kernel_size = self.params.get("kernel_size", (3, 5))
|
||||
if isinstance(kernel_size, int):
|
||||
if kernel_size < 1 or kernel_size % 2 == 0:
|
||||
raise ValueError("kernel_size must be a positive odd integer")
|
||||
elif isinstance(kernel_size, tuple):
|
||||
if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
|
||||
raise ValueError("kernel_size tuple must be (min, max) with min >= 1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
kernel_size = self.params.get("kernel_size", (3, 5))
|
||||
sigma = self.params.get("sigma", (0.5, 1.5))
|
||||
|
||||
if isinstance(kernel_size, tuple):
|
||||
# Choose random odd kernel size
|
||||
min_k, max_k = kernel_size
|
||||
possible_sizes = [k for k in range(min_k, max_k + 1) if k % 2 == 1]
|
||||
if not possible_sizes:
|
||||
possible_sizes = [min_k if min_k % 2 == 1 else min_k + 1]
|
||||
kernel_size = rng.choice(possible_sizes)
|
||||
|
||||
if isinstance(sigma, tuple):
|
||||
sigma = rng.uniform(sigma[0], sigma[1])
|
||||
|
||||
# Ensure kernel size is odd
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
|
||||
# Apply Gaussian blur
|
||||
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||
|
||||
return AugmentationResult(
|
||||
image=blurred,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"kernel_size": kernel_size, "sigma": sigma},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"kernel_size": 5, "sigma": 1.5}
|
||||
|
||||
|
||||
class MotionBlur(BaseAugmentation):
|
||||
"""
|
||||
Applies motion blur to the image.
|
||||
|
||||
Simulates camera shake or document movement during capture.
|
||||
|
||||
Parameters:
|
||||
kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
|
||||
angle_range: Motion angle range in degrees (default: (-45, 45)).
|
||||
"""
|
||||
|
||||
name = "motion_blur"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
kernel_size = self.params.get("kernel_size", (5, 9))
|
||||
if isinstance(kernel_size, int):
|
||||
if kernel_size < 3:
|
||||
raise ValueError("kernel_size must be at least 3")
|
||||
elif isinstance(kernel_size, tuple):
|
||||
if kernel_size[0] < 3:
|
||||
raise ValueError("kernel_size min must be at least 3")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
kernel_size = self.params.get("kernel_size", (5, 9))
|
||||
angle_range = self.params.get("angle_range", (-45, 45))
|
||||
|
||||
if isinstance(kernel_size, tuple):
|
||||
kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)
|
||||
|
||||
angle = rng.uniform(angle_range[0], angle_range[1])
|
||||
|
||||
# Create motion blur kernel
|
||||
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
|
||||
|
||||
# Draw a line in the center of the kernel
|
||||
center = kernel_size // 2
|
||||
angle_rad = np.deg2rad(angle)
|
||||
|
||||
for i in range(kernel_size):
|
||||
offset = i - center
|
||||
x = int(center + offset * np.cos(angle_rad))
|
||||
y = int(center + offset * np.sin(angle_rad))
|
||||
if 0 <= x < kernel_size and 0 <= y < kernel_size:
|
||||
kernel[y, x] = 1.0
|
||||
|
||||
# Normalize kernel
|
||||
kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel
|
||||
|
||||
# Apply motion blur
|
||||
blurred = cv2.filter2D(image, -1, kernel)
|
||||
|
||||
return AugmentationResult(
|
||||
image=blurred,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"kernel_size": kernel_size, "angle": angle},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"kernel_size": 7, "angle_range": (-30, 30)}
|
||||
packages/shared/shared/augmentation/transforms/degradation.py (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Degradation augmentation transforms.
|
||||
|
||||
Provides degradation effects for document image augmentation:
|
||||
- Wrinkle: Paper wrinkle/crease simulation
|
||||
- EdgeDamage: Damaged/torn edge effects
|
||||
- Stain: Coffee stain/smudge effects
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class Wrinkle(BaseAugmentation):
|
||||
"""
|
||||
Simulates paper wrinkles/creases using displacement mapping.
|
||||
|
||||
Document-friendly: Uses subtle displacement to preserve text readability.
|
||||
|
||||
Parameters:
|
||||
intensity: Wrinkle intensity (0-1) (default: 0.3).
|
||||
num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
|
||||
"""
|
||||
|
||||
name = "wrinkle"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
intensity = self.params.get("intensity", 0.3)
|
||||
if not (0 < intensity <= 1):
|
||||
raise ValueError("intensity must be between 0 and 1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
intensity = self.params.get("intensity", 0.3)
|
||||
num_wrinkles = self.params.get("num_wrinkles", (2, 5))
|
||||
|
||||
if isinstance(num_wrinkles, tuple):
|
||||
num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)
|
||||
|
||||
# Create displacement maps
|
||||
displacement_x = np.zeros((h, w), dtype=np.float32)
|
||||
displacement_y = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
for _ in range(num_wrinkles):
|
||||
# Random wrinkle parameters
|
||||
angle = rng.uniform(0, np.pi)
|
||||
x0 = rng.uniform(0, w)
|
||||
y0 = rng.uniform(0, h)
|
||||
length = rng.uniform(0.3, 0.8) * min(h, w)
|
||||
width = rng.uniform(0.02, 0.05) * min(h, w)
|
||||
|
||||
# Create coordinate grids
|
||||
xx, yy = np.meshgrid(np.arange(w), np.arange(h))
|
||||
|
||||
# Distance from wrinkle line
|
||||
dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
|
||||
dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)
|
||||
|
||||
# Gaussian falloff perpendicular to wrinkle
|
||||
mask = np.exp(-dy**2 / (2 * width**2))
|
||||
mask *= (np.abs(dx) < length / 2).astype(np.float32)
|
||||
|
||||
# Displacement perpendicular to wrinkle
|
||||
disp_amount = intensity * rng.uniform(2, 8)
|
||||
displacement_x += mask * disp_amount * np.sin(angle)
|
||||
displacement_y += mask * disp_amount * np.cos(angle)
|
||||
|
||||
# Create remap coordinates
|
||||
map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
|
||||
map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)
|
||||
|
||||
# Apply displacement
|
||||
augmented = cv2.remap(
|
||||
image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
|
||||
)
|
||||
|
||||
# Add subtle shading along wrinkles
|
||||
max_disp = np.max(np.abs(displacement_y)) + 1e-6
|
||||
shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
|
||||
shading = shading[:, :, np.newaxis]
|
||||
augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=augmented,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"intensity": 0.5, "num_wrinkles": 3}
|
||||
|
||||
|
||||
class EdgeDamage(BaseAugmentation):
|
||||
"""
|
||||
Adds damaged/torn edge effects to the image.
|
||||
|
||||
Simulates worn or torn document edges.
|
||||
|
||||
Parameters:
|
||||
max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
|
||||
edges: Which edges to potentially damage (default: all).
|
||||
"""
|
||||
|
||||
name = "edge_damage"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
|
||||
if not (0 < max_damage_ratio <= 0.2):
|
||||
raise ValueError("max_damage_ratio must be between 0 and 0.2")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
|
||||
edges = self.params.get("edges", ["top", "bottom", "left", "right"])
|
||||
|
||||
output = image.copy()
|
||||
|
||||
# Select random edge to damage
|
||||
edge = rng.choice(edges)
|
||||
damage_size = int(max_damage_ratio * min(h, w))
|
||||
|
||||
if edge == "top":
|
||||
# Create irregular top edge
|
||||
for x in range(w):
|
||||
depth = rng.integers(0, damage_size + 1)
|
||||
if depth > 0:
|
||||
# Random color (white or darker)
|
||||
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
|
||||
output[:depth, x] = color
|
||||
|
||||
elif edge == "bottom":
|
||||
for x in range(w):
|
||||
depth = rng.integers(0, damage_size + 1)
|
||||
if depth > 0:
|
||||
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
|
||||
output[h - depth:, x] = color
|
||||
|
||||
elif edge == "left":
|
||||
for y in range(h):
|
||||
depth = rng.integers(0, damage_size + 1)
|
||||
if depth > 0:
|
||||
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
|
||||
output[y, :depth] = color
|
||||
|
||||
else: # right
|
||||
for y in range(h):
|
||||
depth = rng.integers(0, damage_size + 1)
|
||||
if depth > 0:
|
||||
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
|
||||
output[y, w - depth:] = color
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"edge": edge, "damage_size": damage_size},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"max_damage_ratio": 0.08}
|
||||
|
||||
|
||||
class Stain(BaseAugmentation):
|
||||
"""
|
||||
Adds coffee stain/smudge effects to the image.
|
||||
|
||||
Simulates accidental stains on documents.
|
||||
|
||||
Parameters:
|
||||
num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
|
||||
max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
|
||||
opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
|
||||
"""
|
||||
|
||||
name = "stain"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
opacity = self.params.get("opacity", (0.1, 0.3))
|
||||
if isinstance(opacity, (int, float)):
|
||||
if not (0 < opacity <= 1):
|
||||
raise ValueError("opacity must be between 0 and 1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
num_stains = self.params.get("num_stains", (1, 3))
|
||||
max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
|
||||
opacity = self.params.get("opacity", (0.1, 0.3))
|
||||
|
||||
if isinstance(num_stains, tuple):
|
||||
num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
|
||||
if isinstance(opacity, tuple):
|
||||
opacity = rng.uniform(opacity[0], opacity[1])
|
||||
|
||||
output = image.astype(np.float32)
|
||||
max_radius = int(max_radius_ratio * min(h, w))
|
||||
|
||||
for _ in range(num_stains):
|
||||
# Random stain position and size
|
||||
cx = rng.integers(max_radius, w - max_radius)
|
||||
cy = rng.integers(max_radius, h - max_radius)
|
||||
radius = rng.integers(max_radius // 3, max_radius)
|
||||
|
||||
# Create stain mask with irregular edges
|
||||
yy, xx = np.ogrid[:h, :w]
|
||||
dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
|
||||
|
||||
# Add noise to make edges irregular
|
||||
noise = rng.uniform(0.8, 1.2, (h, w))
|
||||
mask = (dist < radius * noise).astype(np.float32)
|
||||
|
||||
# Blur for soft edges
|
||||
mask = cv2.GaussianBlur(mask, (21, 21), 0)
|
||||
|
||||
# Random stain color (brownish/yellowish)
|
||||
stain_color = np.array([
|
||||
rng.integers(180, 220), # R
|
||||
rng.integers(160, 200), # G
|
||||
rng.integers(120, 160), # B
|
||||
], dtype=np.float32)
|
||||
|
||||
# Apply stain
|
||||
mask_3d = mask[:, :, np.newaxis]
|
||||
output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity
|
||||
|
||||
output = np.clip(output, 0, 255).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"num_stains": num_stains, "opacity": opacity},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}
|
||||
packages/shared/shared/augmentation/transforms/geometric.py (new file, 145 lines)
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Geometric augmentation transforms.
|
||||
|
||||
Provides geometric transforms for document image augmentation:
|
||||
- PerspectiveWarp: Subtle perspective distortion
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class PerspectiveWarp(BaseAugmentation):
|
||||
"""
|
||||
Applies subtle perspective transformation to the image.
|
||||
|
||||
Simulates viewing document at slight angle. Very conservative
|
||||
by default to preserve text readability.
|
||||
|
||||
IMPORTANT: This transform affects bounding box coordinates.
|
||||
|
||||
Parameters:
|
||||
max_warp: Maximum warp as proportion of image size (default: 0.02).
|
||||
"""
|
||||
|
||||
name = "perspective_warp"
|
||||
affects_geometry = True
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
max_warp = self.params.get("max_warp", 0.02)
|
||||
if not (0 < max_warp <= 0.1):
|
||||
raise ValueError("max_warp must be between 0 and 0.1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
max_warp = self.params.get("max_warp", 0.02)
|
||||
|
||||
# Original corners
|
||||
src_pts = np.float32([
|
||||
[0, 0],
|
||||
[w, 0],
|
||||
[w, h],
|
||||
[0, h],
|
||||
])
|
||||
|
||||
# Add random perturbations to corners
|
||||
max_offset = max_warp * min(h, w)
|
||||
dst_pts = src_pts.copy()
|
||||
for i in range(4):
|
||||
dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
|
||||
dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)
|
||||
|
||||
# Compute perspective transform matrix
|
||||
transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
|
||||
|
||||
# Apply perspective transform
|
||||
warped = cv2.warpPerspective(
|
||||
image, transform_matrix, (w, h),
|
||||
borderMode=cv2.BORDER_REPLICATE
|
||||
)
|
||||
|
||||
# Transform bounding boxes if present
|
||||
transformed_bboxes = None
|
||||
if bboxes is not None:
|
||||
transformed_bboxes = self._transform_bboxes(
|
||||
bboxes, transform_matrix, w, h
|
||||
)
|
||||
|
||||
return AugmentationResult(
|
||||
image=warped,
|
||||
bboxes=transformed_bboxes,
|
||||
transform_matrix=transform_matrix,
|
||||
metadata={"max_warp": max_warp},
|
||||
)
|
||||
|
||||
def _transform_bboxes(
|
||||
self,
|
||||
bboxes: np.ndarray,
|
||||
transform_matrix: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> np.ndarray:
|
||||
"""Transform bounding boxes using perspective matrix."""
|
||||
if len(bboxes) == 0:
|
||||
return bboxes.copy()
|
||||
|
||||
transformed = []
|
||||
for bbox in bboxes:
|
||||
class_id, x_center, y_center, width, height = bbox
|
||||
|
||||
# Convert normalized coords to pixel coords
|
||||
x_center_px = x_center * w
|
||||
y_center_px = y_center * h
|
||||
width_px = width * w
|
||||
height_px = height * h
|
||||
|
||||
# Get corner points
|
||||
x1 = x_center_px - width_px / 2
|
||||
y1 = y_center_px - height_px / 2
|
||||
x2 = x_center_px + width_px / 2
|
||||
y2 = y_center_px + height_px / 2
|
||||
|
||||
# Transform all 4 corners
|
||||
corners = np.float32([
|
||||
[x1, y1],
|
||||
[x2, y1],
|
||||
[x2, y2],
|
||||
[x1, y2],
|
||||
]).reshape(-1, 1, 2)
|
||||
|
||||
transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
|
||||
transformed_corners = transformed_corners.reshape(-1, 2)
|
||||
|
||||
# Get bounding box of transformed corners
|
||||
new_x1 = np.min(transformed_corners[:, 0])
|
||||
new_y1 = np.min(transformed_corners[:, 1])
|
||||
new_x2 = np.max(transformed_corners[:, 0])
|
||||
new_y2 = np.max(transformed_corners[:, 1])
|
||||
|
||||
# Convert back to normalized center format
|
||||
new_width = (new_x2 - new_x1) / w
|
||||
new_height = (new_y2 - new_y1) / h
|
||||
new_x_center = ((new_x1 + new_x2) / 2) / w
|
||||
new_y_center = ((new_y1 + new_y2) / 2) / h
|
||||
|
||||
# Clamp to valid range
|
||||
new_x_center = np.clip(new_x_center, 0, 1)
|
||||
new_y_center = np.clip(new_y_center, 0, 1)
|
||||
new_width = np.clip(new_width, 0, 1)
|
||||
new_height = np.clip(new_height, 0, 1)
|
||||
|
||||
transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])
|
||||
|
||||
return np.array(transformed, dtype=np.float32)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"max_warp": 0.03}
|
||||
packages/shared/shared/augmentation/transforms/lighting.py (new file, 167 lines)
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Lighting augmentation transforms.
|
||||
|
||||
Provides lighting effects for document image augmentation:
|
||||
- LightingVariation: Adjusts brightness and contrast
|
||||
- Shadow: Adds shadow overlay effects
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class LightingVariation(BaseAugmentation):
|
||||
"""
|
||||
Adjusts image brightness and contrast.
|
||||
|
||||
Simulates different lighting conditions during document capture.
|
||||
Safe for documents with conservative default parameters.
|
||||
|
||||
Parameters:
|
||||
brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
|
||||
contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
|
||||
"""
|
||||
|
||||
name = "lighting_variation"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
brightness = self.params.get("brightness_range", (-0.1, 0.1))
|
||||
contrast = self.params.get("contrast_range", (0.9, 1.1))
|
||||
|
||||
if not isinstance(brightness, tuple) or len(brightness) != 2:
|
||||
raise ValueError("brightness_range must be a (min, max) tuple")
|
||||
if not isinstance(contrast, tuple) or len(contrast) != 2:
|
||||
raise ValueError("contrast_range must be a (min, max) tuple")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
|
||||
contrast_range = self.params.get("contrast_range", (0.9, 1.1))
|
||||
|
||||
# Random brightness and contrast
|
||||
brightness = rng.uniform(brightness_range[0], brightness_range[1])
|
||||
contrast = rng.uniform(contrast_range[0], contrast_range[1])
|
||||
|
||||
# Apply adjustments
|
||||
adjusted = image.astype(np.float32)
|
||||
|
||||
# Contrast adjustment (multiply around mean)
|
||||
mean = adjusted.mean()
|
||||
adjusted = (adjusted - mean) * contrast + mean
|
||||
|
||||
# Brightness adjustment (add offset)
|
||||
adjusted = adjusted + brightness * 255
|
||||
|
||||
# Clip and convert back
|
||||
adjusted = np.clip(adjusted, 0, 255).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=adjusted,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"brightness": brightness, "contrast": contrast},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
|
||||
|
||||
|
||||
class Shadow(BaseAugmentation):
|
||||
"""
|
||||
Adds shadow overlay effects to the image.
|
||||
|
||||
Simulates shadows from objects or hands during document capture.
|
||||
|
||||
Parameters:
|
||||
num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
|
||||
opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
|
||||
"""
|
||||
|
||||
name = "shadow"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
opacity = self.params.get("opacity", (0.2, 0.4))
|
||||
if isinstance(opacity, (int, float)):
|
||||
if not (0 <= opacity <= 1):
|
||||
raise ValueError("opacity must be between 0 and 1")
|
||||
elif isinstance(opacity, tuple):
|
||||
if not (0 <= opacity[0] <= opacity[1] <= 1):
|
||||
raise ValueError("opacity tuple must be in range [0, 1]")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
num_shadows = self.params.get("num_shadows", (1, 2))
|
||||
opacity = self.params.get("opacity", (0.2, 0.4))
|
||||
|
||||
if isinstance(num_shadows, tuple):
|
||||
num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
|
||||
if isinstance(opacity, tuple):
|
||||
opacity = rng.uniform(opacity[0], opacity[1])
|
||||
|
||||
h, w = image.shape[:2]
|
||||
output = image.astype(np.float32)
|
||||
|
||||
for _ in range(num_shadows):
|
||||
# Generate random shadow polygon
|
||||
num_vertices = rng.integers(3, 6)
|
||||
vertices = []
|
||||
|
||||
# Start from a random edge
|
||||
edge = rng.integers(0, 4)
|
||||
if edge == 0: # Top
|
||||
start = (rng.integers(0, w), 0)
|
||||
elif edge == 1: # Right
|
||||
start = (w, rng.integers(0, h))
|
||||
elif edge == 2: # Bottom
|
||||
start = (rng.integers(0, w), h)
|
||||
else: # Left
|
||||
start = (0, rng.integers(0, h))
|
||||
|
||||
vertices.append(start)
|
||||
|
||||
# Add random vertices
|
||||
for _ in range(num_vertices - 1):
|
||||
x = rng.integers(0, w)
|
||||
y = rng.integers(0, h)
|
||||
vertices.append((x, y))
|
||||
|
||||
# Create shadow mask
|
||||
mask = np.zeros((h, w), dtype=np.float32)
|
||||
pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
|
||||
cv2.fillPoly(mask, [pts], 1.0)
|
||||
|
||||
# Blur the mask for soft edges
|
||||
blur_size = max(31, min(h, w) // 10)
|
||||
if blur_size % 2 == 0:
|
||||
blur_size += 1
|
||||
mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
|
||||
|
||||
# Apply shadow
|
||||
shadow_factor = 1 - opacity * mask[:, :, np.newaxis]
|
||||
output = output * shadow_factor
|
||||
|
||||
output = np.clip(output, 0, 255).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"num_shadows": num_shadows, "opacity": opacity},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"num_shadows": 1, "opacity": 0.3}
|
||||
packages/shared/shared/augmentation/transforms/noise.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Noise augmentation transforms.
|
||||
|
||||
Provides noise effects for document image augmentation:
|
||||
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
|
||||
- SaltPepper: Adds salt and pepper noise for impulse noise effects
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class GaussianNoise(BaseAugmentation):
|
||||
"""
|
||||
Adds Gaussian noise to the image.
|
||||
|
||||
Simulates sensor noise from cameras or scanners.
|
||||
Document-safe with conservative default parameters.
|
||||
|
||||
Parameters:
|
||||
mean: Mean of the Gaussian noise (default: 0).
|
||||
std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
|
||||
"""
|
||||
|
||||
name = "gaussian_noise"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
std = self.params.get("std", (5, 15))
|
||||
if isinstance(std, (int, float)):
|
||||
if std < 0:
|
||||
raise ValueError("std must be non-negative")
|
||||
elif isinstance(std, tuple):
|
||||
if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
|
||||
raise ValueError("std tuple must be (min, max) with min <= max >= 0")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
mean = self.params.get("mean", 0)
|
||||
std = self.params.get("std", (5, 15))
|
||||
|
||||
if isinstance(std, tuple):
|
||||
std = rng.uniform(std[0], std[1])
|
||||
|
||||
# Generate noise
|
||||
noise = rng.normal(mean, std, image.shape).astype(np.float32)
|
||||
|
||||
# Apply noise
|
||||
noisy = image.astype(np.float32) + noise
|
||||
noisy = np.clip(noisy, 0, 255).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=noisy,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"applied_std": std},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict[str, Any]:
|
||||
return {"mean": 0, "std": 15}
|
||||
|
||||
|
||||
class SaltPepper(BaseAugmentation):
|
||||
"""
|
||||
Adds salt and pepper (impulse) noise to the image.
|
||||
|
||||
Simulates defects from damaged sensors or transmission errors.
|
||||
Very sparse by default to preserve document readability.
|
||||
|
||||
Parameters:
|
||||
amount: Proportion of pixels to affect, can be float or (min, max) tuple.
|
||||
Default: (0.001, 0.005) for very sparse noise.
|
||||
salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
|
||||
"""
|
||||
|
||||
name = "salt_pepper"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
amount = self.params.get("amount", (0.001, 0.005))
|
||||
if isinstance(amount, (int, float)):
|
||||
if not (0 <= amount <= 1):
|
||||
raise ValueError("amount must be between 0 and 1")
|
||||
elif isinstance(amount, tuple):
|
||||
if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
|
||||
raise ValueError("amount tuple must be (min, max) in range [0, 1]")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
amount = self.params.get("amount", (0.001, 0.005))
|
||||
salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
|
||||
|
||||
if isinstance(amount, tuple):
|
||||
amount = rng.uniform(amount[0], amount[1])
|
||||
|
||||
# Copy image
|
||||
output = image.copy()
|
||||
h, w = image.shape[:2]
|
||||
total_pixels = h * w
|
||||
|
||||
# Calculate number of salt and pepper pixels
|
||||
num_salt = int(total_pixels * amount * salt_vs_pepper)
|
||||
num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))
|
||||
|
||||
# Add salt (white pixels)
|
||||
if num_salt > 0:
|
||||
salt_coords = (
|
||||
rng.integers(0, h, num_salt),
|
||||
rng.integers(0, w, num_salt),
|
||||
)
|
||||
output[salt_coords] = 255
|
||||
|
||||
# Add pepper (black pixels)
|
||||
if num_pepper > 0:
|
||||
pepper_coords = (
|
||||
rng.integers(0, h, num_pepper),
|
||||
rng.integers(0, w, num_pepper),
|
||||
)
|
||||
output[pepper_coords] = 0
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"applied_amount": amount},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict[str, Any]:
|
||||
return {"amount": 0.01, "salt_vs_pepper": 0.5}
|
||||
packages/shared/shared/augmentation/transforms/texture.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Texture augmentation transforms.
|
||||
|
||||
Provides texture effects for document image augmentation:
|
||||
- PaperTexture: Adds paper grain/texture
|
||||
- ScannerArtifacts: Adds scanner line and dust artifacts
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class PaperTexture(BaseAugmentation):
|
||||
"""
|
||||
Adds paper texture/grain to the image.
|
||||
|
||||
Simulates different paper types and ages.
|
||||
|
||||
Parameters:
|
||||
texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
|
||||
intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
|
||||
"""
|
||||
|
||||
name = "paper_texture"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
intensity = self.params.get("intensity", (0.05, 0.15))
|
||||
if isinstance(intensity, (int, float)):
|
||||
if not (0 < intensity <= 1):
|
||||
raise ValueError("intensity must be between 0 and 1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
texture_type = self.params.get("texture_type", "random")
|
||||
intensity = self.params.get("intensity", (0.05, 0.15))
|
||||
|
||||
if texture_type == "random":
|
||||
texture_type = rng.choice(["fine", "coarse"])
|
||||
|
||||
if isinstance(intensity, tuple):
|
||||
intensity = rng.uniform(intensity[0], intensity[1])
|
||||
|
||||
# Generate base noise
|
||||
if texture_type == "fine":
|
||||
# Fine grain texture
|
||||
noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
|
||||
noise = cv2.GaussianBlur(noise, (3, 3), 0)
|
||||
else:
|
||||
# Coarse texture
|
||||
# Generate at lower resolution and upscale
|
||||
small_h, small_w = h // 4, w // 4
|
||||
noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
|
||||
noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
noise = cv2.GaussianBlur(noise, (5, 5), 0)
|
||||
|
||||
# Apply texture
|
||||
output = image.astype(np.float32)
|
||||
noise_3d = noise[:, :, np.newaxis] * intensity * 255
|
||||
output = output + noise_3d
|
||||
|
||||
output = np.clip(output, 0, 255).astype(np.uint8)
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={"texture_type": texture_type, "intensity": intensity},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"texture_type": "coarse", "intensity": 0.15}
|
||||
|
||||
|
||||
class ScannerArtifacts(BaseAugmentation):
|
||||
"""
|
||||
Adds scanner artifacts to the image.
|
||||
|
||||
Simulates scanner imperfections like lines and dust spots.
|
||||
|
||||
Parameters:
|
||||
line_probability: Probability of adding scan lines (default: 0.3).
|
||||
dust_probability: Probability of adding dust spots (default: 0.4).
|
||||
"""
|
||||
|
||||
name = "scanner_artifacts"
|
||||
affects_geometry = False
|
||||
|
||||
def _validate_params(self) -> None:
|
||||
line_prob = self.params.get("line_probability", 0.3)
|
||||
dust_prob = self.params.get("dust_probability", 0.4)
|
||||
if not (0 <= line_prob <= 1):
|
||||
raise ValueError("line_probability must be between 0 and 1")
|
||||
if not (0 <= dust_prob <= 1):
|
||||
raise ValueError("dust_probability must be between 0 and 1")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
bboxes: np.ndarray | None = None,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> AugmentationResult:
|
||||
rng = rng or np.random.default_rng()
|
||||
|
||||
h, w = image.shape[:2]
|
||||
line_probability = self.params.get("line_probability", 0.3)
|
||||
dust_probability = self.params.get("dust_probability", 0.4)
|
||||
|
||||
output = image.copy()
|
||||
|
||||
# Add scan lines
|
||||
if rng.random() < line_probability:
|
||||
num_lines = rng.integers(1, 4)
|
||||
for _ in range(num_lines):
|
||||
y = rng.integers(0, h)
|
||||
thickness = rng.integers(1, 3)
|
||||
# Light or dark line
|
||||
color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)
|
||||
|
||||
# Make line partially transparent
|
||||
alpha = rng.uniform(0.3, 0.6)
|
||||
for dy in range(thickness):
|
||||
if y + dy < h:
|
||||
output[y + dy, :] = (
|
||||
output[y + dy, :].astype(np.float32) * (1 - alpha) +
|
||||
color * alpha
|
||||
).astype(np.uint8)
|
||||
|
||||
# Add dust spots
|
||||
if rng.random() < dust_probability:
|
||||
num_dust = rng.integers(5, 20)
|
||||
for _ in range(num_dust):
|
||||
x = rng.integers(0, w)
|
||||
y = rng.integers(0, h)
|
||||
radius = rng.integers(1, 3)
|
||||
|
||||
# Dark dust spot
|
||||
color = rng.integers(50, 120)
|
||||
cv2.circle(output, (x, y), radius, int(color), -1)
|
||||
|
||||
return AugmentationResult(
|
||||
image=output,
|
||||
bboxes=bboxes.copy() if bboxes is not None else None,
|
||||
metadata={
|
||||
"line_probability": line_probability,
|
||||
"dust_probability": dust_probability,
|
||||
},
|
||||
)
|
||||
|
||||
def get_preview_params(self) -> dict:
|
||||
return {"line_probability": 0.8, "dust_probability": 0.8}
|
||||
packages/shared/shared/training/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
|
||||
"""Shared training utilities."""
|
||||
|
||||
from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult
|
||||
|
||||
__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]
|
||||
packages/shared/shared/training/yolo_trainer.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Shared YOLO Training Module
|
||||
|
||||
Unified training logic for both CLI and Web API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Training configuration."""
|
||||
|
||||
# Model settings
|
||||
model_path: str = "yolo11n.pt" # Base model or path to trained model
|
||||
data_yaml: str = "" # Path to data.yaml
|
||||
|
||||
# Training hyperparameters
|
||||
epochs: int = 100
|
||||
batch_size: int = 16
|
||||
image_size: int = 640
|
||||
learning_rate: float = 0.01
|
||||
device: str = "0"
|
||||
|
||||
# Output settings
|
||||
project: str = "runs/train"
|
||||
name: str = "invoice_fields"
|
||||
|
||||
# Performance settings
|
||||
workers: int = 4
|
||||
cache: bool = False
|
||||
|
||||
# Resume settings
|
||||
resume: bool = False
|
||||
resume_from: str | None = None # Path to checkpoint
|
||||
|
||||
# Document-specific augmentation (optimized for invoices)
|
||||
augmentation: dict[str, Any] = field(default_factory=lambda: {
|
||||
"degrees": 5.0,
|
||||
"translate": 0.05,
|
||||
"scale": 0.2,
|
||||
"shear": 0.0,
|
||||
"perspective": 0.0,
|
||||
"flipud": 0.0,
|
||||
"fliplr": 0.0,
|
||||
"mosaic": 0.0,
|
||||
"mixup": 0.0,
|
||||
"hsv_h": 0.0,
|
||||
"hsv_s": 0.1,
|
||||
"hsv_v": 0.2,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingResult:
|
||||
"""Training result."""
|
||||
|
||||
success: bool
|
||||
model_path: str | None = None
|
||||
metrics: dict[str, float] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
save_dir: str | None = None
|
||||
|
||||
|
||||
class YOLOTrainer:
|
||||
"""Unified YOLO trainer for CLI and Web API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TrainingConfig,
|
||||
log_callback: Callable[[str, str], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize trainer.
|
||||
|
||||
Args:
|
||||
config: Training configuration
|
||||
log_callback: Optional callback for logging (level, message)
|
||||
"""
|
||||
self.config = config
|
||||
self._log_callback = log_callback
|
||||
|
||||
def _log(self, level: str, message: str) -> None:
|
||||
"""Log a message."""
|
||||
if self._log_callback:
|
||||
self._log_callback(level, message)
|
||||
if level == "INFO":
|
||||
logger.info(message)
|
||||
elif level == "ERROR":
|
||||
logger.error(message)
|
||||
elif level == "WARNING":
|
||||
logger.warning(message)
|
||||
|
||||
def validate_config(self) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate training configuration.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
# Check model path. Bare aliases such as "yolo11n.pt" are downloaded by
# ultralytics on demand, so only explicit checkpoint paths must already exist.
model_path = Path(self.config.model_path)
if model_path.suffix != ".pt":
if not model_path.name.startswith("yolo"):
return False, f"Invalid model: {self.config.model_path}"
elif not model_path.exists() and not model_path.name.startswith("yolo"):
return False, f"Model file not found: {self.config.model_path}"
|
||||
|
||||
# Check data.yaml
|
||||
if not self.config.data_yaml:
|
||||
return False, "data_yaml is required"
|
||||
data_yaml = Path(self.config.data_yaml)
|
||||
if not data_yaml.exists():
|
||||
return False, f"data.yaml not found: {self.config.data_yaml}"
|
||||
|
||||
return True, None
|
||||
|
||||
def train(self) -> TrainingResult:
|
||||
"""
|
||||
Run YOLO training.
|
||||
|
||||
Returns:
|
||||
TrainingResult with model path and metrics
|
||||
"""
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
except ImportError:
|
||||
return TrainingResult(
|
||||
success=False,
|
||||
error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
|
||||
)
|
||||
|
||||
# Validate config
|
||||
is_valid, error = self.validate_config()
|
||||
if not is_valid:
|
||||
return TrainingResult(success=False, error=error)
|
||||
|
||||
self._log("INFO", f"Starting YOLO training")
|
||||
self._log("INFO", f" Model: {self.config.model_path}")
|
||||
self._log("INFO", f" Data: {self.config.data_yaml}")
|
||||
self._log("INFO", f" Epochs: {self.config.epochs}")
|
||||
self._log("INFO", f" Batch size: {self.config.batch_size}")
|
||||
self._log("INFO", f" Image size: {self.config.image_size}")
|
||||
|
||||
try:
|
||||
# Load model
|
||||
if self.config.resume and self.config.resume_from:
|
||||
resume_path = Path(self.config.resume_from)
|
||||
if resume_path.exists():
|
||||
self._log("INFO", f"Resuming from: {resume_path}")
|
||||
model = YOLO(str(resume_path))
|
||||
else:
|
||||
model = YOLO(self.config.model_path)
|
||||
else:
|
||||
model = YOLO(self.config.model_path)
|
||||
|
||||
# Build training arguments
|
||||
train_args = {
|
||||
"data": str(Path(self.config.data_yaml).absolute()),
|
||||
"epochs": self.config.epochs,
|
||||
"batch": self.config.batch_size,
|
||||
"imgsz": self.config.image_size,
|
||||
"lr0": self.config.learning_rate,
|
||||
"device": self.config.device,
|
||||
"project": self.config.project,
|
||||
"name": self.config.name,
|
||||
"exist_ok": True,
|
||||
"pretrained": True,
|
||||
"verbose": True,
|
||||
"workers": self.config.workers,
|
||||
"cache": self.config.cache,
|
||||
"resume": self.config.resume and self.config.resume_from is not None,
|
||||
}
|
||||
|
||||
# Add augmentation settings
|
||||
train_args.update(self.config.augmentation)
|
||||
|
||||
# Train
|
||||
results = model.train(**train_args)
|
||||
|
||||
# Get best model path
|
||||
best_model = Path(results.save_dir) / "weights" / "best.pt"
|
||||
|
||||
# Extract metrics
|
||||
metrics = {}
|
||||
if hasattr(results, "results_dict"):
|
||||
metrics = {
|
||||
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
|
||||
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
|
||||
"precision": results.results_dict.get("metrics/precision(B)", 0),
|
||||
"recall": results.results_dict.get("metrics/recall(B)", 0),
|
||||
}
|
||||
|
||||
self._log("INFO", f"Training completed successfully")
|
||||
self._log("INFO", f" Best model: {best_model}")
|
||||
self._log("INFO", f" mAP@0.5: {metrics.get('mAP50', 'N/A')}")
|
||||
|
||||
return TrainingResult(
|
||||
success=True,
|
||||
model_path=str(best_model) if best_model.exists() else None,
|
||||
metrics=metrics,
|
||||
save_dir=str(results.save_dir),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._log("ERROR", f"Training failed: {e}")
|
||||
return TrainingResult(success=False, error=str(e))
|
||||
|
||||
def validate(self, split: str = "val") -> dict[str, float]:
|
||||
"""
|
||||
Run validation on trained model.
|
||||
|
||||
Args:
|
||||
split: Dataset split to validate on ("val" or "test")
|
||||
|
||||
Returns:
|
||||
Validation metrics
|
||||
"""
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO(self.config.model_path)
|
||||
metrics = model.val(data=self.config.data_yaml, split=split)
|
||||
|
||||
return {
|
||||
"mAP50": metrics.box.map50,
|
||||
"mAP50-95": metrics.box.map,
|
||||
"precision": metrics.box.mp,
|
||||
"recall": metrics.box.mr,
|
||||
}
|
||||
except Exception as e:
|
||||
self._log("ERROR", f"Validation failed: {e}")
|
||||
return {}
|
||||
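For orientation, the trainer above only reads and returns plain config/result objects. The definitions below are an editorial sketch inferred from the attribute accesses in train() and validate(); they are not the committed shared.training code, and the defaults shown are assumptions.

# Hypothetical sketch of the types YOLOTrainer relies on (inferred from usage, not from this commit).
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TrainingConfig:
    model_path: str                      # base weights or checkpoint, e.g. a yolov8 .pt file
    data_yaml: str                       # path to the dataset.yaml
    epochs: int = 100
    batch_size: int = 16
    image_size: int = 640
    learning_rate: float = 0.01          # forwarded to YOLO as lr0
    device: str = "0"
    project: str = "runs/train"
    name: str = "exp"
    workers: int = 8
    cache: bool = False
    resume: bool = False
    resume_from: str | None = None
    augmentation: dict[str, Any] = field(default_factory=dict)  # merged into train_args


@dataclass
class TrainingResult:
    success: bool
    model_path: str | None = None        # path to best.pt when training succeeded
    metrics: dict[str, float] = field(default_factory=dict)
    save_dir: str | None = None
    error: str | None = None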
@@ -199,67 +199,63 @@ def main():
        db.close()
        return

    # Start training
    # Start training using shared trainer
    print("\n" + "=" * 60)
    print("Starting YOLO Training")
    print("=" * 60)

    from ultralytics import YOLO
    from shared.training import YOLOTrainer, TrainingConfig

    # Load model
    # Determine resume checkpoint
    last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
    if args.resume and last_checkpoint.exists():
        print(f"Resuming from: {last_checkpoint}")
        model = YOLO(str(last_checkpoint))
    else:
        model = YOLO(args.model)
    resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None

    # Training arguments
    # Create training config
    data_yaml = dataset_dir / 'dataset.yaml'
    train_args = {
        'data': str(data_yaml.absolute()),
        'epochs': args.epochs,
        'batch': args.batch,
        'imgsz': args.imgsz,
        'project': args.project,
        'name': args.name,
        'device': args.device,
        'exist_ok': True,
        'pretrained': True,
        'verbose': True,
        'workers': args.workers,
        'cache': args.cache,
        'resume': args.resume and last_checkpoint.exists(),
        # Document-specific augmentation settings
        'degrees': 5.0,
        'translate': 0.05,
        'scale': 0.2,
        'shear': 0.0,
        'perspective': 0.0,
        'flipud': 0.0,
        'fliplr': 0.0,
        'mosaic': 0.0,
        'mixup': 0.0,
        'hsv_h': 0.0,
        'hsv_s': 0.1,
        'hsv_v': 0.2,
    }
    config = TrainingConfig(
        model_path=args.model,
        data_yaml=str(data_yaml),
        epochs=args.epochs,
        batch_size=args.batch,
        image_size=args.imgsz,
        device=args.device,
        project=args.project,
        name=args.name,
        workers=args.workers,
        cache=args.cache,
        resume=args.resume,
        resume_from=resume_from,
    )

    # Train
    results = model.train(**train_args)
    # Run training
    trainer = YOLOTrainer(config=config)
    result = trainer.train()

    if not result.success:
        print(f"\nError: Training failed - {result.error}")
        db.close()
        sys.exit(1)

    # Print results
    print("\n" + "=" * 60)
    print("Training Complete")
    print("=" * 60)
    print(f"Best model: {args.project}/{args.name}/weights/best.pt")
    print(f"Last model: {args.project}/{args.name}/weights/last.pt")
    print(f"Best model: {result.model_path}")
    print(f"Save directory: {result.save_dir}")
    if result.metrics:
        print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
        print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")

    # Validate on test set
    print("\nRunning validation on test set...")
    metrics = model.val(split='test')
    print(f"mAP50: {metrics.box.map50:.4f}")
    print(f"mAP50-95: {metrics.box.map:.4f}")
    if result.model_path:
        config.model_path = result.model_path
        config.data_yaml = str(data_yaml)
        test_trainer = YOLOTrainer(config=config)
        test_metrics = test_trainer.validate(split='test')
        if test_metrics:
            print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
            print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")

    # Close database
    db.close()
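Stripped of the CLI plumbing, the refactored main() reduces to the sequence sketched below: build a TrainingConfig, run YOLOTrainer.train(), then re-point the config at the best checkpoint and call validate(). This is a minimal sketch using only the names shown in the diff; the paths and hyperparameter values are placeholders, not values from this commit.

# Minimal sketch of driving the shared trainer directly (placeholder paths and values).
from shared.training import YOLOTrainer, TrainingConfig

config = TrainingConfig(
    model_path="yolov8n.pt",                  # placeholder base weights
    data_yaml="datasets/example/dataset.yaml",  # placeholder dataset.yaml
    epochs=50,
    batch_size=16,
    image_size=960,
    device="0",
    project="runs/train",
    name="doc-layout",
    workers=4,
    cache=False,
    resume=False,
    resume_from=None,
)

trainer = YOLOTrainer(config=config)
result = trainer.train()

if result.success and result.model_path:
    # Re-point the config at the best checkpoint and score it on the test split,
    # mirroring what the updated main() does after training.
    config.model_path = result.model_path
    test_metrics = YOLOTrainer(config=config).validate(split="test")
    print(f"test mAP50: {test_metrics.get('mAP50', 0):.4f}")
else:
    print(f"Training failed: {result.error}")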