This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -25,6 +25,7 @@ from inference.data.admin_models import (
AnnotationHistory,
TrainingDataset,
DatasetDocument,
ModelVersion,
)
logger = logging.getLogger(__name__)
@@ -110,6 +111,7 @@ class AdminDB:
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
admin_token: str | None = None, # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
)
session.add(document)
session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
document.updated_at = datetime.utcnow()
session.add(document)
def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
    """Set or clear the group key on a document.

    Args:
        document_id: Document UUID as a string.
        group_key: New group key, or None to clear it.

    Returns:
        True when the document exists and was updated, False otherwise.
    """
    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if doc is None:
            return False
        doc.group_key = group_key
        doc.updated_at = datetime.utcnow()
        session.add(doc)
        return True
def delete_document(self, document_id: str) -> bool:
"""Delete a document and its annotations."""
with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
session.expunge(d)
return list(datasets), total
def get_active_training_tasks_for_datasets(
    self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
    """Look up in-flight (pending/scheduled/running) training tasks.

    Args:
        dataset_ids: Dataset UUIDs as strings; malformed entries are
            logged and ignored rather than aborting the whole lookup.

    Returns:
        Mapping of dataset_id -> {"task_id": ..., "status": ...}.
    """
    if not dataset_ids:
        return {}

    # Filter out malformed IDs up front so one bad value cannot
    # poison the IN(...) query below.
    valid_uuids: list[UUID] = []
    for raw_id in dataset_ids:
        try:
            valid_uuids.append(UUID(raw_id))
        except ValueError:
            logger.warning(
                "Invalid UUID in get_active_training_tasks_for_datasets: %s", raw_id
            )

    if not valid_uuids:
        return {}

    with get_session_context() as session:
        rows = session.exec(
            select(TrainingTask).where(
                TrainingTask.dataset_id.in_(valid_uuids),
                TrainingTask.status.in_(["pending", "scheduled", "running"]),
            )
        ).all()
        return {
            str(row.dataset_id): {"task_id": str(row.task_id), "status": row.status}
            for row in rows
        }
def update_dataset_status(
self,
dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
session.delete(dataset)
session.commit()
return True
# ==========================================================================
# Model Version Operations
# ==========================================================================
def create_model_version(
    self,
    version: str,
    name: str,
    model_path: str,
    description: str | None = None,
    task_id: str | UUID | None = None,
    dataset_id: str | UUID | None = None,
    metrics_mAP: float | None = None,
    metrics_precision: float | None = None,
    metrics_recall: float | None = None,
    document_count: int = 0,
    training_config: dict[str, Any] | None = None,
    file_size: int | None = None,
    trained_at: datetime | None = None,
) -> ModelVersion:
    """Create a new model version.

    Args:
        version: Semantic version string (e.g. "1.0.0").
        name: Human-readable model name.
        model_path: Path to the model weights file.
        description: Optional free-text description.
        task_id: Originating training task, if any (str or UUID).
        dataset_id: Dataset used for training, if any (str or UUID).
        metrics_mAP / metrics_precision / metrics_recall: Eval metrics.
        document_count: Number of documents used in training.
        training_config: Snapshot of training hyperparameters.
        file_size: Model file size in bytes.
        trained_at: When training completed.

    Returns:
        The persisted (and expunged, i.e. detached) ModelVersion row.
    """
    with get_session_context() as session:
        model = ModelVersion(
            version=version,
            name=name,
            model_path=model_path,
            description=description,
            # Normalize str/UUID inputs to UUID; None stays None.
            task_id=UUID(str(task_id)) if task_id else None,
            dataset_id=UUID(str(dataset_id)) if dataset_id else None,
            metrics_mAP=metrics_mAP,
            metrics_precision=metrics_precision,
            metrics_recall=metrics_recall,
            document_count=document_count,
            training_config=training_config,
            file_size=file_size,
            trained_at=trained_at,
        )
        session.add(model)
        session.commit()
        session.refresh(model)
        # Detach so the row remains usable after the session closes.
        session.expunge(model)
        return model
def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
    """Fetch a model version by ID, or None when it does not exist."""
    with get_session_context() as session:
        found = session.get(ModelVersion, UUID(str(version_id)))
        if found is None:
            return None
        # Detach so the caller can use the row after the session closes.
        session.expunge(found)
        return found
def get_model_versions(
    self,
    status: str | None = None,
    limit: int = 20,
    offset: int = 0,
) -> tuple[list[ModelVersion], int]:
    """List model versions, newest first, with an optional status filter.

    Returns:
        (models, total_count) where total_count ignores pagination.
    """
    with get_session_context() as session:
        base_query = select(ModelVersion)
        counter = select(func.count()).select_from(ModelVersion)
        if status:
            base_query = base_query.where(ModelVersion.status == status)
            counter = counter.where(ModelVersion.status == status)

        total = session.exec(counter).one()
        page = session.exec(
            base_query.order_by(ModelVersion.created_at.desc())
            .offset(offset)
            .limit(limit)
        ).all()
        # Detach all rows so they stay usable after the session closes.
        for item in page:
            session.expunge(item)
        return list(page), total
def get_active_model_version(self) -> ModelVersion | None:
    """Return the single model version currently active for inference.

    Returns None when no version is active.
    """
    with get_session_context() as session:
        active = session.exec(
            select(ModelVersion).where(ModelVersion.is_active == True)
        ).first()
        if active is None:
            return None
        session.expunge(active)
        return active
def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
    """Activate a model version for inference (deactivates all others).

    Args:
        version_id: Target model version (str or UUID).

    Returns:
        The activated ModelVersion, or None when version_id is unknown.
    """
    with get_session_context() as session:
        # Look up the target FIRST.  The previous implementation
        # deactivated every active version before checking the target
        # existed, so calling this with an unknown ID could leave the
        # system with no active model at all.
        model = session.get(ModelVersion, UUID(str(version_id)))
        if not model:
            return None
        # Deactivate all currently-active versions so that at most one
        # version is active at a time.
        all_versions = session.exec(
            select(ModelVersion).where(ModelVersion.is_active == True)
        ).all()
        for v in all_versions:
            v.is_active = False
            v.status = "inactive"
            v.updated_at = datetime.utcnow()
            session.add(v)
        # Activate the specified version.
        model.is_active = True
        model.status = "active"
        model.activated_at = datetime.utcnow()
        model.updated_at = datetime.utcnow()
        session.add(model)
        session.commit()
        session.refresh(model)
        # Detach so the row remains usable after the session closes.
        session.expunge(model)
        return model
def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
    """Deactivate a model version.

    Returns the updated ModelVersion, or None when version_id is unknown.
    """
    with get_session_context() as session:
        target = session.get(ModelVersion, UUID(str(version_id)))
        if target is None:
            return None
        target.is_active = False
        target.status = "inactive"
        target.updated_at = datetime.utcnow()
        session.add(target)
        session.commit()
        session.refresh(target)
        session.expunge(target)
        return target
def update_model_version(
    self,
    version_id: str | UUID,
    name: str | None = None,
    description: str | None = None,
    status: str | None = None,
) -> ModelVersion | None:
    """Update model version metadata; None fields are left untouched.

    Returns the updated ModelVersion, or None when version_id is unknown.
    """
    with get_session_context() as session:
        target = session.get(ModelVersion, UUID(str(version_id)))
        if target is None:
            return None
        # Apply only the fields the caller actually supplied.
        if name is not None:
            target.name = name
        if description is not None:
            target.description = description
        if status is not None:
            target.status = status
        target.updated_at = datetime.utcnow()
        session.add(target)
        session.commit()
        session.refresh(target)
        session.expunge(target)
        return target
def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
    """Archive a model version.

    Returns the archived ModelVersion; None when the version is unknown
    or is currently active (an active model cannot be archived).
    """
    with get_session_context() as session:
        target = session.get(ModelVersion, UUID(str(version_id)))
        # Missing or active models cannot be archived.
        if target is None or target.is_active:
            return None
        target.status = "archived"
        target.updated_at = datetime.utcnow()
        session.add(target)
        session.commit()
        session.refresh(target)
        session.expunge(target)
        return target
def delete_model_version(self, version_id: str | UUID) -> bool:
    """Delete a model version.

    Returns True on success; False when the version is unknown or is
    currently active (an active model cannot be deleted).
    """
    with get_session_context() as session:
        target = session.get(ModelVersion, UUID(str(version_id)))
        if target is None or target.is_active:
            return False
        session.delete(target)
        session.commit()
        return True

View File

@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Model Version Management
# =============================================================================
class ModelVersion(SQLModel, table=True):
    """Model version for inference deployment.

    Each row records one trained model artifact (weights file plus
    metrics and training provenance).  At most one version is expected
    to have is_active=True at a time; AdminDB.activate_model_version
    enforces this by deactivating the others.
    """
    __tablename__ = "model_versions"
    version_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Semantic version string, e.g. "1.0.0", "2.1.0".
    version: str = Field(max_length=50, index=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    # Path to the model weights file on disk.
    model_path: str = Field(max_length=512)
    # Lifecycle status: active, inactive, archived.
    status: str = Field(default="inactive", max_length=20, index=True)
    # Only one version should be active at a time for inference.
    is_active: bool = Field(default=False, index=True)
    # Training association (both optional).
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
    # Training metrics captured at registration time.
    metrics_mAP: float | None = Field(default=None)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    # Number of documents used in training.
    document_count: int = Field(default=0)
    # Snapshot of the training configuration (epochs, batch_size, etc.).
    training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Model file size in bytes.
    file_size: int | None = Field(default=None)
    # When training completed.
    trained_at: datetime | None = Field(default=None)
    # When this version was last activated for inference.
    activated_at: datetime | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -49,6 +49,111 @@ def get_engine():
return _engine
def run_migrations() -> None:
    """Run idempotent database migrations for new tables/columns.

    Each migration is a (name, sql) pair using IF NOT EXISTS /
    ADD COLUMN IF NOT EXISTS so re-running is safe.  Failures are
    logged and skipped rather than aborting startup.
    """
    engine = get_engine()
    migrations = [
        # Migration 004: Training datasets tables and dataset_id on training_tasks
        (
            "training_datasets_tables",
            """
            CREATE TABLE IF NOT EXISTS training_datasets (
                dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                name VARCHAR(255) NOT NULL,
                description TEXT,
                status VARCHAR(20) NOT NULL DEFAULT 'building',
                train_ratio FLOAT NOT NULL DEFAULT 0.8,
                val_ratio FLOAT NOT NULL DEFAULT 0.1,
                seed INTEGER NOT NULL DEFAULT 42,
                total_documents INTEGER NOT NULL DEFAULT 0,
                total_images INTEGER NOT NULL DEFAULT 0,
                total_annotations INTEGER NOT NULL DEFAULT 0,
                dataset_path VARCHAR(512),
                error_message TEXT,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
            """,
        ),
        (
            "dataset_documents_table",
            """
            CREATE TABLE IF NOT EXISTS dataset_documents (
                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
                document_id UUID NOT NULL REFERENCES admin_documents(document_id),
                split VARCHAR(10) NOT NULL,
                page_count INTEGER NOT NULL DEFAULT 0,
                annotation_count INTEGER NOT NULL DEFAULT 0,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                UNIQUE(dataset_id, document_id)
            );
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
            """,
        ),
        (
            "training_tasks_dataset_id",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
            """,
        ),
        # Migration 005: Add group_key to admin_documents
        (
            "admin_documents_group_key",
            """
            ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
            CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
            """,
        ),
        # Migration 006: Model versions table
        (
            "model_versions_table",
            """
            CREATE TABLE IF NOT EXISTS model_versions (
                version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                version VARCHAR(50) NOT NULL,
                name VARCHAR(255) NOT NULL,
                description TEXT,
                model_path VARCHAR(512) NOT NULL,
                status VARCHAR(20) NOT NULL DEFAULT 'inactive',
                is_active BOOLEAN NOT NULL DEFAULT FALSE,
                task_id UUID REFERENCES training_tasks(task_id),
                dataset_id UUID REFERENCES training_datasets(dataset_id),
                metrics_mAP FLOAT,
                metrics_precision FLOAT,
                metrics_recall FLOAT,
                document_count INTEGER NOT NULL DEFAULT 0,
                training_config JSONB,
                file_size BIGINT,
                trained_at TIMESTAMP WITH TIME ZONE,
                activated_at TIMESTAMP WITH TIME ZONE,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
            CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
            CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
            CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
            CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
            """,
        ),
    ]
    with engine.connect() as conn:
        for name, sql in migrations:
            try:
                conn.execute(text(sql))
                conn.commit()
                logger.info("Migration '%s' applied successfully", name)
            except Exception as e:
                # BUGFIX: roll back the failed transaction so the
                # connection stays usable; without this, SQLAlchemy 2.x
                # raises PendingRollbackError on every subsequent
                # migration after the first failure.
                conn.rollback()
                # Log but don't fail - the object may already exist
                logger.debug("Migration '%s' skipped or failed: %s", name, e)
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
# Run migrations for new columns
run_migrations()
def get_session() -> Session:
"""Get a new database session."""

View File

@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",

View File

@@ -0,0 +1,15 @@
"""Augmentation API module."""
from fastapi import APIRouter
from .routes import register_augmentation_routes
def create_augmentation_router() -> APIRouter:
    """Build the augmentation APIRouter with all endpoints registered.

    Returns:
        An APIRouter mounted under the "/augmentation" prefix.
    """
    augmentation_router = APIRouter(prefix="/augmentation", tags=["augmentation"])
    register_augmentation_routes(augmentation_router)
    return augmentation_router
__all__ = ["create_augmentation_router"]

View File

@@ -0,0 +1,162 @@
"""Augmentation API routes."""
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewRequest,
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
)
def register_augmentation_routes(router: APIRouter) -> None:
    """Register augmentation endpoints on the router.

    All handlers are defined as closures so they are attached to the
    supplied router instance.  Heavy imports (augmentation pipeline,
    service layer) are deferred into each handler to keep module
    import time low.
    """

    @router.get(
        "/types",
        response_model=AugmentationTypesResponse,
        summary="List available augmentation types",
    )
    async def list_augmentation_types(
        admin_token: AdminTokenDep,
    ) -> AugmentationTypesResponse:
        """
        List all available augmentation types with descriptions and parameters.
        """
        # Deferred import: pipeline registry is only needed here.
        from shared.augmentation.pipeline import (
            AUGMENTATION_REGISTRY,
            AugmentationPipeline,
        )
        types = []
        for name, aug_class in AUGMENTATION_REGISTRY.items():
            # Create instance with empty params to get preview params.
            # NOTE(review): assumes every registered class accepts an
            # empty param dict -- confirm against the registry.
            aug = aug_class({})
            types.append(
                AugmentationTypeInfo(
                    name=name,
                    # Class docstring doubles as the user-facing description.
                    description=(aug_class.__doc__ or "").strip(),
                    affects_geometry=aug_class.affects_geometry,
                    stage=AugmentationPipeline.STAGE_MAPPING[name],
                    default_params=aug.get_preview_params(),
                )
            )
        return AugmentationTypesResponse(augmentation_types=types)

    @router.get(
        "/presets",
        response_model=PresetsResponse,
        summary="Get augmentation presets",
    )
    async def get_presets(
        admin_token: AdminTokenDep,
    ) -> PresetsResponse:
        """Get predefined augmentation presets for common use cases."""
        from shared.augmentation.presets import list_presets
        presets = [PresetInfo(**p) for p in list_presets()]
        return PresetsResponse(presets=presets)

    @router.post(
        "/preview/{document_id}",
        response_model=AugmentationPreviewResponse,
        summary="Preview augmentation on document image",
    )
    async def preview_augmentation(
        document_id: str,
        request: AugmentationPreviewRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """
        Preview a single augmentation on a document page.
        Returns URLs to original and augmented preview images.
        """
        from inference.web.services.augmentation_service import AugmentationService
        service = AugmentationService(db=db)
        return await service.preview_single(
            document_id=document_id,
            page=page,
            augmentation_type=request.augmentation_type,
            params=request.params,
        )

    @router.post(
        "/preview-config/{document_id}",
        response_model=AugmentationPreviewResponse,
        summary="Preview full augmentation config on document",
    )
    async def preview_config(
        document_id: str,
        config: AugmentationConfigSchema,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """Preview complete augmentation pipeline on a document page."""
        from inference.web.services.augmentation_service import AugmentationService
        service = AugmentationService(db=db)
        return await service.preview_config(
            document_id=document_id,
            page=page,
            config=config,
        )

    @router.post(
        "/batch",
        response_model=AugmentationBatchResponse,
        summary="Create augmented dataset (offline preprocessing)",
    )
    async def create_augmented_dataset(
        request: AugmentationBatchRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AugmentationBatchResponse:
        """
        Create a new augmented dataset from an existing dataset.
        This runs as a background task. The augmented images are stored
        alongside the original dataset for training.
        """
        from inference.web.services.augmentation_service import AugmentationService
        service = AugmentationService(db=db)
        return await service.create_augmented_dataset(
            source_dataset_id=request.dataset_id,
            config=request.config,
            output_name=request.output_name,
            multiplier=request.multiplier,
        )

    @router.get(
        "/datasets",
        response_model=AugmentedDatasetListResponse,
        summary="List augmented datasets",
    )
    async def list_augmented_datasets(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        limit: int = Query(default=20, ge=1, le=100, description="Page size"),
        offset: int = Query(default=0, ge=0, description="Offset"),
    ) -> AugmentedDatasetListResponse:
        """List all augmented datasets."""
        from inference.web.services.augmentation_service import AugmentationService
        service = AugmentationService(db=db)
        return await service.list_augmented_datasets(limit=limit, offset=offset)

View File

@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
group_key: Annotated[
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
group_key=group_key,
)
# Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
return response
@router.patch(
    "/{document_id}/group-key",
    responses={
        401: {"model": ErrorResponse, "description": "Invalid token"},
        404: {"model": ErrorResponse, "description": "Document not found"},
    },
    summary="Update document group key",
    description="Update the group key for a document.",
)
async def update_document_group_key(
    document_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
    group_key: Annotated[
        str | None,
        Query(description="New group key (null to clear)"),
    ] = None,
) -> dict:
    """Set or clear a document's group key.

    Raises 400 for over-long keys, 404 when the document is missing or
    not owned by the supplied admin token.
    """
    _validate_uuid(document_id, "document_id")
    # Enforce the DB column limit (VARCHAR(255)) before touching storage.
    if group_key and len(group_key) > 255:
        raise HTTPException(
            status_code=400,
            detail="Group key must be 255 characters or less",
        )
    # Verify document exists and belongs to this token.
    document = db.get_document_by_token(document_id, admin_token)
    if document is None:
        raise HTTPException(
            status_code=404,
            detail="Document not found or does not belong to this token",
        )
    # Persist the new (possibly None) group key.
    db.update_document_group_key(document_id, group_key)
    return {
        "status": "updated",
        "document_id": document_id,
        "group_key": group_key,
        "message": "Document group key updated",
    }
return router

View File

@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
return router

View File

@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
# Validate minimum document count for proper train/val/test split
if len(request.document_ids) < 10:
raise HTTPException(
status_code=400,
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)
dataset = db.create_dataset(
name=request.name,
description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets]
try:
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}
return DatasetListResponse(
total=total,
limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
name=d.name,
description=d.description,
status=d.status,
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
"""Create a training task from a dataset.
For incremental training, set config.base_model_version_id to a model version UUID.
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
)
config_dict = request.config.model_dump()
# Resolve base_model_version_id to actual model path for incremental training
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = db.get_model_version(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
base_model.version,
base_model.model_path,
)
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
message = (
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
if base_model_version_id
else "Training task created from dataset"
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
message=message,
)

View File

@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
)
@router.get(
"/models",
"/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
summary="Get completed training tasks",
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
async def get_training_models(
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[

View File

@@ -0,0 +1,333 @@
"""Model Version Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, Request
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
ModelVersionItem,
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_model_routes(router: APIRouter) -> None:
"""Register model version endpoints on the router."""
@router.post(
"/models",
response_model=ModelVersionResponse,
summary="Create model version",
description="Register a new model version for deployment.",
)
async def create_model_version(
request: ModelVersionCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Create a new model version."""
if request.task_id:
_validate_uuid(request.task_id, "task_id")
if request.dataset_id:
_validate_uuid(request.dataset_id, "dataset_id")
model = db.create_model_version(
version=request.version,
name=request.name,
model_path=request.model_path,
description=request.description,
task_id=request.task_id,
dataset_id=request.dataset_id,
metrics_mAP=request.metrics_mAP,
metrics_precision=request.metrics_precision,
metrics_recall=request.metrics_recall,
document_count=request.document_count,
training_config=request.training_config,
file_size=request.file_size,
trained_at=request.trained_at,
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version created successfully",
)
@router.get(
"/models",
response_model=ModelVersionListResponse,
summary="List model versions",
)
async def list_model_versions(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> ModelVersionListResponse:
"""List model versions with optional status filter."""
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
return ModelVersionListResponse(
total=total,
limit=limit,
offset=offset,
models=[
ModelVersionItem(
version_id=str(m.version_id),
version=m.version,
name=m.name,
status=m.status,
is_active=m.is_active,
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
activated_at=m.activated_at,
created_at=m.created_at,
)
for m in models
],
)
@router.get(
"/models/active",
response_model=ActiveModelResponse,
summary="Get active model",
description="Get the currently active model for inference.",
)
async def get_active_model(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ActiveModelResponse:
"""Get the currently active model version."""
model = db.get_active_model_version()
if not model:
return ActiveModelResponse(has_active_model=False, model=None)
return ActiveModelResponse(
has_active_model=True,
model=ModelVersionItem(
version_id=str(model.version_id),
version=model.version,
name=model.name,
status=model.status,
is_active=model.is_active,
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
),
)
@router.get(
"/models/{version_id}",
response_model=ModelVersionDetailResponse,
summary="Get model version detail",
)
async def get_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionDetailResponse:
"""Get detailed model version information."""
_validate_uuid(version_id, "version_id")
model = db.get_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionDetailResponse(
version_id=str(model.version_id),
version=model.version,
name=model.name,
description=model.description,
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
metrics_precision=model.metrics_precision,
metrics_recall=model.metrics_recall,
document_count=model.document_count,
training_config=model.training_config,
file_size=model.file_size,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
updated_at=model.updated_at,
)
@router.patch(
    "/models/{version_id}",
    response_model=ModelVersionResponse,
    summary="Update model version",
)
async def update_model_version(
    version_id: str,
    request: ModelVersionUpdateRequest,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> ModelVersionResponse:
    """Update mutable metadata (name, description, status) on a model version.

    Raises:
        HTTPException: 400 for a malformed UUID, 404 if the version is unknown.
    """
    _validate_uuid(version_id, "version_id")
    updated = db.update_model_version(
        version_id=version_id,
        name=request.name,
        description=request.description,
        status=request.status,
    )
    if not updated:
        raise HTTPException(status_code=404, detail="Model version not found")
    return ModelVersionResponse(
        version_id=str(updated.version_id),
        status=updated.status,
        message="Model version updated successfully",
    )
@router.post(
    "/models/{version_id}/activate",
    response_model=ModelVersionResponse,
    summary="Activate model version",
    description="Activate a model version for inference (deactivates all others).",
)
async def activate_model_version(
    version_id: str,
    request: Request,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> ModelVersionResponse:
    """Mark a model version active and hot-reload the inference service.

    The reload is best-effort: a failure to reload is logged but does not
    undo the activation.

    Raises:
        HTTPException: 400 for a malformed UUID, 404 if the version is unknown.
    """
    _validate_uuid(version_id, "version_id")
    activated = db.activate_model_version(version_id)
    if not activated:
        raise HTTPException(status_code=404, detail="Model version not found")
    # Ask the running inference service (if attached) to pick up the new weights.
    service = getattr(request.app.state, "inference_service", None)
    reloaded = False
    if service:
        try:
            reloaded = service.reload_model()
            if reloaded:
                logger.info(f"Inference model reloaded to version {activated.version}")
        except Exception as e:
            logger.warning(f"Failed to reload inference model: {e}")
    message = "Model version activated for inference" + (
        " (model reloaded)" if reloaded else ""
    )
    return ModelVersionResponse(
        version_id=str(activated.version_id),
        status=activated.status,
        message=message,
    )
@router.post(
    "/models/{version_id}/deactivate",
    response_model=ModelVersionResponse,
    summary="Deactivate model version",
)
async def deactivate_model_version(
    version_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> ModelVersionResponse:
    """Take a model version out of the active slot.

    Raises:
        HTTPException: 400 for a malformed UUID, 404 if the version is unknown.
    """
    _validate_uuid(version_id, "version_id")
    deactivated = db.deactivate_model_version(version_id)
    if not deactivated:
        raise HTTPException(status_code=404, detail="Model version not found")
    return ModelVersionResponse(
        version_id=str(deactivated.version_id),
        status=deactivated.status,
        message="Model version deactivated",
    )
@router.post(
    "/models/{version_id}/archive",
    response_model=ModelVersionResponse,
    summary="Archive model version",
)
async def archive_model_version(
    version_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> ModelVersionResponse:
    """Archive a model version.

    The DB layer refuses to archive the active model, so both "not found"
    and "is active" surface as a single 400 here.

    Raises:
        HTTPException: 400 for a malformed UUID, a missing version, or an
            attempt to archive the active model.
    """
    _validate_uuid(version_id, "version_id")
    archived = db.archive_model_version(version_id)
    if not archived:
        raise HTTPException(
            status_code=400,
            detail="Model version not found or cannot archive active model",
        )
    return ModelVersionResponse(
        version_id=str(archived.version_id),
        status=archived.status,
        message="Model version archived",
    )
@router.delete(
    "/models/{version_id}",
    summary="Delete model version",
)
async def delete_model_version(
    version_id: str,
    admin_token: AdminTokenDep,
    db: AdminDBDep,
) -> dict:
    """Permanently delete a model version.

    Raises:
        HTTPException: 400 for a malformed UUID, a missing version, or an
            attempt to delete the active model.
    """
    _validate_uuid(version_id, "version_id")
    deleted = db.delete_model_version(version_id)
    if not deleted:
        raise HTTPException(
            status_code=400,
            detail="Model version not found or cannot delete active model",
        )
    return {"message": "Model version deleted"}
@router.post(
    "/models/reload",
    summary="Reload inference model",
    description="Reload the inference model from the currently active model version.",
)
async def reload_inference_model(
    request: Request,
    admin_token: AdminTokenDep,
) -> dict:
    """Manually reload the inference model from the active model version.

    Returns:
        A dict with a human-readable ``message`` and a ``reloaded`` flag
        (False means the service was already serving the active version).

    Raises:
        HTTPException: 500 if the inference service is not attached to the
            app or the reload itself fails.
    """
    inference_service = getattr(request.app.state, "inference_service", None)
    if not inference_service:
        raise HTTPException(
            status_code=500,
            detail="Inference service not available",
        )
    try:
        model_reloaded = inference_service.reload_model()
    except Exception as e:
        logger.error(f"Failed to reload model: {e}")
        # Chain the original exception so the traceback keeps the root cause
        # (the bare raise previously dropped the causal link).
        raise HTTPException(
            status_code=500,
            detail=f"Failed to reload model: {e}",
        ) from e
    if model_reloaded:
        logger.info("Inference model manually reloaded")
        return {"message": "Model reloaded successfully", "reloaded": True}
    return {"message": "Model already up to date", "reloaded": False}

View File

@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""
config = config or default_config
# Create inference service
# Create model path resolver that reads from database
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
logger.warning(f"Failed to get active model from database: {e}")
return None
# Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
model_path_resolver=get_active_model_path,
)
# Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
# Store inference service for access by routes (e.g., model reload)
# This will be set after app creation
app = FastAPI(
title="Invoice Field Extraction API",
description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include batch upload routes
app.include_router(batch_upload_router)
# Store inference service in app state for access by routes
app.state.inference_service = inference_service
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:

View File

@@ -110,6 +110,7 @@ class TrainingScheduler:
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@ class TrainingScheduler:
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Get augmentation config if present
augmentation_config = config.get("augmentation")
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._db.add_training_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
else:
# Train from pretrained model
effective_model = model_name
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@ class TrainingScheduler:
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Apply augmentation if config is provided
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
aug_result = self._apply_augmentation(
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._db.add_training_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
model_name=effective_model, # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -159,11 +192,94 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _create_model_version_from_training(
    self,
    task_id: str,
    config: dict[str, Any],
    dataset_id: str | None,
    result: dict[str, Any],
) -> None:
    """Create a model version entry from completed training.

    Best-effort: any failure is logged as a task WARNING instead of
    failing the (already successful) training run.
    """
    try:
        model_path = result.get("model_path")
        if not model_path:
            logger.warning(f"No model path in training result for task {task_id}")
            return
        # Get task info for name
        task = self._db.get_training_task(task_id)
        task_name = task.name if task else f"Task {task_id[:8]}"
        # Generate version number based on existing versions
        existing_versions = self._db.get_model_versions(limit=1, offset=0)
        version_count = existing_versions[1] if existing_versions else 0
        version = f"v{version_count + 1}.0"
        # Extract metrics from result
        metrics = result.get("metrics", {})
        metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
        metrics_precision = metrics.get("precision")
        metrics_recall = metrics.get("recall")
        # Get file size if possible
        file_size = None
        model_file = Path(model_path)
        if model_file.exists():
            file_size = model_file.stat().st_size
        # Get document count from dataset if available
        document_count = 0
        if dataset_id:
            dataset = self._db.get_dataset(dataset_id)
            if dataset:
                document_count = dataset.total_documents
        # Create model version
        model_version = self._db.create_model_version(
            version=version,
            name=task_name,
            model_path=str(model_path),
            description=f"Auto-created from training task {task_id[:8]}",
            task_id=task_id,
            dataset_id=dataset_id,
            metrics_mAP=metrics_mAP,
            metrics_precision=metrics_precision,
            metrics_recall=metrics_recall,
            document_count=document_count,
            training_config=config,
            file_size=file_size,
            trained_at=datetime.utcnow(),
        )
        logger.info(
            f"Created model version {version} (ID: {model_version.version_id}) "
            f"from training task {task_id}"
        )
        # BUG FIX: the previous f-string embedded a conditional expression in
        # the format spec ({metrics_mAP:.3f if metrics_mAP else 'N/A'}), which
        # raises ValueError at runtime and turned every successful creation
        # into a spurious warning. Format the value first instead.
        map_text = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
        self._db.add_training_log(
            task_id, "INFO",
            f"Model version {version} created (mAP: {map_text})",
        )
    except Exception as e:
        logger.error(f"Failed to create model version for task {task_id}: {e}")
        self._db.add_training_log(
            task_id, "WARNING",
            f"Failed to auto-create model version: {e}",
        )
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
# Create log callback that writes to DB
def log_callback(level: str, message: str) -> None:
self._db.add_training_log(task_id, level, message)
# Create shared training config
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
)
# Run training using shared trainer
trainer = YOLOTrainer(config=config, log_callback=log_callback)
result = trainer.train()
if not result.success:
raise ValueError(result.error or "Training failed")
return {
"model_path": result.model_path,
"metrics": result.metrics,
}
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
"""Check if any augmentations are enabled in the config."""
augmentation_fields = [
"perspective_warp", "wrinkle", "edge_damage", "stain",
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
]
for field in augmentation_fields:
if field in aug_config:
field_config = aug_config[field]
if isinstance(field_config, dict) and field_config.get("enabled", False):
return True
return False
def _apply_augmentation(
self,
task_id: str,
dataset_path: Path,
aug_config: dict[str, Any],
multiplier: int,
) -> dict[str, int] | None:
"""Apply augmentation to dataset before training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
from shared.augmentation import DatasetAugmenter
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
f"Applying augmentation with multiplier={multiplier}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
augmenter = DatasetAugmenter(aug_config)
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
return result
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
return None
# Global scheduler instance

View File

@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse

View File

@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AugmentationParamsSchema(BaseModel):
"""Single augmentation parameters."""
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
probability: float = Field(
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
)
params: dict[str, Any] = Field(
default_factory=dict, description="Type-specific parameters"
)
class AugmentationConfigSchema(BaseModel):
"""Complete augmentation configuration."""
# Geometric transforms
perspective_warp: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Degradation effects
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
edge_damage: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Lighting effects
lighting_variation: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Blur effects
gaussian_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
motion_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Noise effects
gaussian_noise: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
salt_pepper: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Texture effects
paper_texture: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
scanner_artifacts: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Global settings
preserve_bboxes: bool = Field(
default=True, description="Whether to adjust bboxes for geometric transforms"
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
class AugmentationTypeInfo(BaseModel):
"""Information about an augmentation type."""
name: str = Field(..., description="Augmentation name")
description: str = Field(..., description="Augmentation description")
affects_geometry: bool = Field(
..., description="Whether this augmentation affects bbox coordinates"
)
stage: str = Field(..., description="Processing stage")
default_params: dict[str, Any] = Field(
default_factory=dict, description="Default parameters"
)
class AugmentationTypesResponse(BaseModel):
"""Response for listing augmentation types."""
augmentation_types: list[AugmentationTypeInfo] = Field(
..., description="Available augmentation types"
)
class PresetInfo(BaseModel):
"""Information about a preset."""
name: str = Field(..., description="Preset name")
description: str = Field(..., description="Preset description")
class PresetsResponse(BaseModel):
"""Response for listing presets."""
presets: list[PresetInfo] = Field(..., description="Available presets")
class AugmentationPreviewRequest(BaseModel):
"""Request to preview augmentation on an image."""
augmentation_type: str = Field(..., description="Type of augmentation to preview")
params: dict[str, Any] = Field(
default_factory=dict, description="Override parameters"
)
class AugmentationPreviewResponse(BaseModel):
"""Response with preview image data."""
preview_url: str = Field(..., description="URL to preview image")
original_url: str = Field(..., description="URL to original image")
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
class AugmentationBatchRequest(BaseModel):
"""Request to augment a dataset offline."""
dataset_id: str = Field(..., description="Source dataset UUID")
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
output_name: str = Field(
..., min_length=1, max_length=255, description="Output dataset name"
)
multiplier: int = Field(
default=2, ge=1, le=10, description="Augmented copies per image"
)
class AugmentationBatchResponse(BaseModel):
"""Response for batch augmentation."""
task_id: str = Field(..., description="Background task UUID")
status: str = Field(..., description="Task status")
message: str = Field(..., description="Status message")
estimated_images: int = Field(..., description="Estimated total images")
class AugmentedDatasetItem(BaseModel):
"""Single augmented dataset in list."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
created_at: datetime = Field(..., description="Creation timestamp")
class AugmentedDatasetListResponse(BaseModel):
"""Response for listing augmented datasets."""
total: int = Field(..., ge=0, description="Total datasets")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
datasets: list[AugmentedDatasetItem] = Field(
default_factory=list, description="Dataset list"
)
class AugmentedDatasetDetailResponse(BaseModel):
"""Detailed augmented dataset response."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
config: AugmentationConfigSchema | None = Field(
None, description="Augmentation config used"
)
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
dataset_path: str | None = Field(None, description="Dataset path on disk")
error_message: str | None = Field(None, description="Error message if failed")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")

View File

@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
total_documents: int
total_images: int
total_annotations: int

View File

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)

View File

@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class ModelVersionCreateRequest(BaseModel):
"""Request to create a model version."""
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
name: str = Field(..., min_length=1, max_length=255, description="Model name")
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
description: str | None = Field(None, description="Optional description")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
document_count: int = Field(0, ge=0, description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
class ModelVersionUpdateRequest(BaseModel):
"""Request to update a model version."""
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
description: str | None = Field(None, description="Description")
status: str | None = Field(None, description="Status (inactive, archived)")
class ModelVersionItem(BaseModel):
"""Model version in list view."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
class ModelVersionListResponse(BaseModel):
"""Paginated model version list."""
total: int = Field(..., ge=0, description="Total model versions")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
class ModelVersionDetailResponse(BaseModel):
"""Detailed model version info."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
description: str | None = Field(None, description="Description")
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision")
metrics_recall: float | None = Field(None, description="Recall")
document_count: int = Field(..., description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class ModelVersionResponse(BaseModel):
"""Response for model version operation."""
version_id: str = Field(..., description="Version UUID")
status: str = Field(..., description="Model status")
message: str = Field(..., description="Status message")
class ActiveModelResponse(BaseModel):
"""Response for active model query."""
has_active_model: bool = Field(..., description="Whether an active model exists")
model: ModelVersionItem | None = Field(None, description="Active model if exists")

View File

@@ -5,13 +5,18 @@ from typing import Any
from pydantic import BaseModel, Field
from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
)
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
default="invoice_fields", description="Training project name"
)
# Data augmentation settings
augmentation: AugmentationConfigSchema | None = Field(
default=None,
description="Augmentation configuration. If provided, augments dataset before training.",
)
augmentation_multiplier: int = Field(
default=2,
ge=1,
le=10,
description="Number of augmented copies per original image",
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""

View File

@@ -0,0 +1,317 @@
"""Augmentation service for handling augmentation operations."""
import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any
import numpy as np
from fastapi import HTTPException
from PIL import Image
from inference.data.admin_db import AdminDB
from inference.web.schemas.admin.augmentation import (
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
)
# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
re.IGNORECASE,
)
class AugmentationService:
"""Service for augmentation operations."""
def __init__(self, db: AdminDB) -> None:
"""Initialize service with database connection."""
self.db = db
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
"""
Validate UUID format to prevent path traversal.
Args:
value: Value to validate.
field_name: Field name for error message.
Raises:
HTTPException: If value is not a valid UUID.
"""
if not UUID_PATTERN.match(value):
raise HTTPException(
status_code=400,
detail=f"Invalid {field_name} format: {value}",
)
async def preview_single(
self,
document_id: str,
page: int,
augmentation_type: str,
params: dict[str, Any],
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
augmentation_type: Name of augmentation to apply.
params: Override parameters.
Returns:
Preview response with image URLs.
Raises:
HTTPException: If document not found or augmentation invalid.
"""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
# Validate augmentation type
if augmentation_type not in AUGMENTATION_REGISTRY:
raise HTTPException(
status_code=400,
detail=f"Unknown augmentation type: {augmentation_type}. "
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
)
# Get document and load image
image = await self._load_document_page(document_id, page)
# Create config with only this augmentation enabled
config_kwargs = {
augmentation_type: AugmentationParams(
enabled=True,
probability=1.0, # Always apply for preview
params=params,
),
"seed": PREVIEW_SEED, # Deterministic preview
}
config = AugmentationConfig(**config_kwargs)
pipeline = AugmentationPipeline(config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=params,
)
async def preview_config(
    self,
    document_id: str,
    page: int,
    config: AugmentationConfigSchema,
) -> AugmentationPreviewResponse:
    """
    Preview the full augmentation config applied to one document page.

    Args:
        document_id: Document UUID.
        page: Page number (1-indexed).
        config: Full augmentation configuration.

    Returns:
        Preview response with image URLs.
    """
    from shared.augmentation.config import AugmentationConfig
    from shared.augmentation.pipeline import AugmentationPipeline

    image = await self._load_document_page(document_id, page)

    # Bridge the Pydantic schema to the internal config representation.
    raw_config = config.model_dump()
    pipeline = AugmentationPipeline(AugmentationConfig.from_dict(raw_config))
    augmented = pipeline.apply(image)

    return AugmentationPreviewResponse(
        preview_url=self._image_to_data_url(augmented.image),
        original_url=self._image_to_data_url(image),
        applied_params=raw_config,
    )
async def create_augmented_dataset(
    self,
    source_dataset_id: str,
    config: AugmentationConfigSchema,
    output_name: str,
    multiplier: int,
) -> AugmentationBatchResponse:
    """
    Create a new augmented dataset from an existing dataset.

    Args:
        source_dataset_id: Source dataset UUID.
        config: Augmentation configuration.
        output_name: Name for the new dataset.
        multiplier: Number of augmented copies per image.

    Returns:
        Batch response with task ID.

    Raises:
        HTTPException: If source dataset not found.
    """
    # Validate source dataset exists
    try:
        source_dataset = self.db.get_dataset(source_dataset_id)
        if source_dataset is None:
            raise HTTPException(
                status_code=404,
                detail=f"Source dataset not found: {source_dataset_id}",
            )
    except HTTPException:
        # Fix: previously the broad handler below caught this HTTPException
        # and re-wrapped it; re-raise as-is (same pattern as
        # _load_document_page).
        raise
    except Exception as e:
        # DB lookup failure (e.g. malformed id) is reported as not-found.
        raise HTTPException(
            status_code=404,
            detail=f"Source dataset not found: {source_dataset_id}",
        ) from e

    # Create task ID for background processing
    task_id = str(uuid.uuid4())

    # Estimate total images; 0 when the dataset doesn't expose a count.
    estimated_images = getattr(source_dataset, "total_images", 0) * multiplier

    # TODO: Queue background task for actual augmentation
    # For now, return pending status
    return AugmentationBatchResponse(
        task_id=task_id,
        status="pending",
        message=f"Augmentation task queued for dataset '{output_name}'",
        estimated_images=estimated_images,
    )
async def list_augmented_datasets(
    self,
    limit: int = 20,
    offset: int = 0,
) -> AugmentedDatasetListResponse:
    """
    List augmented datasets.

    Args:
        limit: Maximum number of datasets to return.
        offset: Number of datasets to skip.

    Returns:
        List response with datasets (currently always empty).
    """
    # TODO: Implement actual database query for augmented datasets
    # Stub: echoes the paging arguments with an empty result set.
    return AugmentedDatasetListResponse(
        total=0,
        limit=limit,
        offset=offset,
        datasets=[],
    )
async def _load_document_page(
    self,
    document_id: str,
    page: int,
) -> np.ndarray:
    """
    Load a document page as numpy array.

    Args:
        document_id: Document UUID.
        page: Page number (1-indexed).

    Returns:
        Image as numpy array (H, W, C) with dtype uint8.

    Raises:
        HTTPException: If document or page not found.
    """
    # Validate document_id format to prevent path traversal
    self._validate_uuid(document_id, "document_id")

    # Get document from database
    try:
        document = self.db.get_document(document_id)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail=f"Document not found: {document_id}",
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"Document not found: {document_id}",
        ) from e

    # Get image path for page
    if hasattr(document, "images_dir"):
        images_dir = Path(document.images_dir)
    else:
        # Fallback to constructed path
        from inference.web.core.config import get_settings

        settings = get_settings()
        images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"

    # Find image for page
    page_idx = page - 1  # Convert to 0-indexed
    image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
    # Fix: also reject page < 1. A negative index would otherwise silently
    # pick an image from the END of the list (e.g. page=0 -> last page).
    if page_idx < 0 or page_idx >= len(image_files):
        raise HTTPException(
            status_code=404,
            detail=f"Page {page} not found for document {document_id}",
        )

    # Load image
    image_path = image_files[page_idx]
    pil_image = Image.open(image_path).convert("RGB")
    return np.array(pil_image)
def _image_to_data_url(self, image: np.ndarray) -> str:
    """Encode a numpy image as a PNG base64 data URL, downscaled for preview."""
    pil_image = Image.fromarray(image)

    # Cap the longest edge at PREVIEW_MAX_SIZE to keep the payload small.
    longest_edge = max(pil_image.size)
    if longest_edge > PREVIEW_MAX_SIZE:
        scale = PREVIEW_MAX_SIZE / longest_edge
        resized = (int(pil_image.width * scale), int(pil_image.height * scale))
        pil_image = pil_image.resize(resized, Image.Resampling.LANCZOS)

    # Serialize as PNG and wrap in a data URL.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"

View File

@@ -81,29 +81,18 @@ class DatasetBuilder:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
# 3. Group documents by group_key and assign splits
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
for doc in doc_list:
doc_id = str(doc.document_id)
split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
"total_annotations": total_annotations,
}
def _assign_splits_by_group(
    self,
    documents: list,
    train_ratio: float,
    val_ratio: float,
    seed: int,
) -> dict[str, str]:
    """Assign train/val/test splits at group granularity.

    Logic (as implemented):
    - Documents sharing a group_key always land in the same split.
    - Documents without a group_key each form their own one-document group
      (pseudo-key derived from document_id).
    - ALL groups are shuffled together and partitioned by group count into
      train, then val, then test.

    NOTE(review): the docstring previously claimed single-document groups go
    "directly to train", but the code shuffles every group. The single/multi
    separation below only fixes the pre-shuffle ordering (singles first),
    which affects which permutation a fixed seed produces — confirm this
    matches the intended split policy.

    Args:
        documents: List of AdminDocument objects
        train_ratio: Fraction for training set
        val_ratio: Fraction for validation set
        seed: Random seed for reproducibility

    Returns:
        Dict mapping document_id (str) -> split ("train"/"val"/"test")
    """
    # Group documents by group_key
    # None/empty group_key treated as unique (each doc is its own group)
    groups: dict[str | None, list] = {}
    for doc in documents:
        key = doc.group_key if doc.group_key else None
        if key is None:
            # Treat each ungrouped doc as its own unique group
            # Use document_id as pseudo-key
            key = f"__ungrouped_{doc.document_id}"
        groups.setdefault(key, []).append(doc)
    # Separate single-doc groups from multi-doc groups
    single_doc_groups: list[tuple[str | None, list]] = []
    multi_doc_groups: list[tuple[str | None, list]] = []
    for key, docs in groups.items():
        if len(docs) == 1:
            single_doc_groups.append((key, docs))
        else:
            multi_doc_groups.append((key, docs))
    # Initialize result mapping
    doc_splits: dict[str, str] = {}
    # Combine all groups for splitting (deterministic order: singles, then multis)
    all_groups = single_doc_groups + multi_doc_groups
    # Shuffle all groups and assign splits
    if all_groups:
        rng = random.Random(seed)
        rng.shuffle(all_groups)
        n_groups = len(all_groups)
        n_train = max(1, round(n_groups * train_ratio))
        # Ensure at least 1 in val if we have more than 1 group
        n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
        for i, (_key, docs) in enumerate(all_groups):
            if i < n_train:
                split = "train"
            elif i < n_train + n_val:
                split = "val"
            else:
                split = "test"
            # Every document in the group inherits the group's split.
            for doc in docs:
                doc_splits[str(doc.document_id)] = split
    logger.info(
        "Split assignment: %d total groups shuffled (train=%d, val=%d)",
        len(all_groups),
        sum(1 for s in doc_splits.values() if s == "train"),
        sum(1 for s in doc_splits.values() if s == "val"),
    )
    return doc_splits
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {

View File

@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], Path | None]
@dataclass
class ServiceResult:
"""Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
model_config: Model configuration (default model settings)
storage_config: Storage configuration
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
self._pipeline = None
self._detector = None
self._is_initialized = False
self._current_model_path: Path | None = None
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
Priority:
1. Active model from database (via resolver)
2. Default model from config
"""
if self._model_path_resolver:
try:
db_model_path = self._model_path_resolver()
if db_model_path and Path(db_model_path).exists():
logger.info(f"Using active model from database: {db_model_path}")
return Path(db_model_path)
elif db_model_path:
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
except Exception as e:
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
@@ -74,16 +105,20 @@ class InferenceService:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
# Resolve model path (from DB or config)
model_path = self._resolve_model_path()
self._current_model_path = model_path
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
@@ -92,12 +127,36 @@ class InferenceService:
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
def reload_model(self) -> bool:
"""Reload the model if active model has changed.
Returns:
True if model was reloaded, False if no change needed.
"""
new_model_path = self._resolve_model_path()
if self._current_model_path == new_model_path:
logger.debug("Model unchanged, no reload needed")
return False
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
self._is_initialized = False
self._pipeline = None
self._detector = None
self.initialize()
return True
@property
def current_model_path(self) -> Path | None:
"""Get the currently loaded model path."""
return self._current_model_path
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""

View File

@@ -0,0 +1,24 @@
"""
Document Image Augmentation Module.
Provides augmentation transformations for training data enhancement,
specifically designed for document images (invoices, forms, etc.).
Key features:
- Document-safe augmentations that preserve text readability
- Support for both offline preprocessing and runtime augmentation
- Bbox-aware geometric transforms
- Configurable augmentation pipeline
"""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.dataset_augmenter import DatasetAugmenter
__all__ = [
"AugmentationConfig",
"AugmentationParams",
"AugmentationResult",
"BaseAugmentation",
"DatasetAugmenter",
]

View File

@@ -0,0 +1,108 @@
"""
Base classes for augmentation transforms.
Provides abstract base class and result dataclass for all augmentation
implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import numpy as np
@dataclass
class AugmentationResult:
    """
    Result of applying an augmentation.

    Attributes:
        image: The augmented image as numpy array (H, W, C).
        bboxes: Updated bounding boxes if geometric transform was applied.
            Format: (N, 5) array with [class_id, x_center, y_center, width, height].
            None means the transform did not produce/adjust boxes.
        transform_matrix: The transformation matrix if applicable (for bbox adjustment).
        applied: Whether the augmentation was actually applied.
        metadata: Additional metadata about the augmentation.
    """

    image: np.ndarray
    bboxes: np.ndarray | None = None
    transform_matrix: np.ndarray | None = None
    applied: bool = True
    metadata: dict[str, Any] | None = None
class BaseAugmentation(ABC):
    """
    Abstract base class for all augmentations.

    Subclasses must implement:
    - _validate_params(): Validate augmentation parameters
    - apply(): Apply the augmentation to an image

    Class attributes:
        name: Human-readable name of the augmentation.
        affects_geometry: True if this augmentation modifies bbox coordinates.
    """

    name: str = "base"
    affects_geometry: bool = False

    def __init__(self, params: dict[str, Any]) -> None:
        """
        Initialize augmentation with parameters.

        Args:
            params: Dictionary of augmentation-specific parameters.

        Raises:
            ValueError: Propagated from _validate_params(); validation runs
                eagerly at construction time, so a bad config fails fast.
        """
        self.params = params
        self._validate_params()

    @abstractmethod
    def _validate_params(self) -> None:
        """
        Validate augmentation parameters.

        Raises:
            ValueError: If parameters are invalid.
        """
        pass

    @abstractmethod
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """
        Apply augmentation to image.

        IMPORTANT: Implementations must NOT modify the input image or bboxes.
        Always create copies before modifying.

        Args:
            image: Input image as numpy array (H, W, C) with dtype uint8.
            bboxes: Optional bounding boxes in YOLO format (N, 5) array.
                Each row: [class_id, x_center, y_center, width, height].
                Coordinates are normalized to 0-1 range.
            rng: Random number generator for reproducibility.
                If None, a new generator should be created.

        Returns:
            AugmentationResult with augmented image and optionally updated bboxes.
        """
        pass

    def get_preview_params(self) -> dict[str, Any]:
        """
        Get parameters optimized for preview display.

        Override this method to provide parameters that produce
        clearly visible effects for preview/demo purposes.

        Returns:
            Dictionary of preview parameters (a shallow copy of self.params
            by default, so callers can tweak it safely).
        """
        return dict(self.params)

View File

@@ -0,0 +1,274 @@
"""
Augmentation configuration module.
Provides dataclasses for configuring document image augmentations.
All default values are document-safe (conservative) to preserve text readability.
"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class AugmentationParams:
    """
    Settings for a single augmentation type.

    Attributes:
        enabled: Whether this augmentation is enabled.
        probability: Probability of applying this augmentation (0.0 to 1.0).
        params: Type-specific parameters dictionary.
    """

    enabled: bool = False
    probability: float = 0.5
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; params is shallow-copied."""
        serialized: dict[str, Any] = {}
        serialized["enabled"] = self.enabled
        serialized["probability"] = self.probability
        serialized["params"] = dict(self.params)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
        """Deserialize from a dictionary, defaulting any missing keys."""
        enabled = data.get("enabled", False)
        probability = data.get("probability", 0.5)
        # Copy so the instance never aliases the caller's dict.
        type_params = dict(data.get("params", {}))
        return cls(enabled=enabled, probability=probability, params=type_params)
# --- Per-augmentation default factories --------------------------------------
# Each factory returns a fresh document-safe AugmentationParams. They are
# functions (used as dataclass default_factory) so every AugmentationConfig
# instance gets its own mutable AugmentationParams rather than a shared one.


def _default_perspective_warp() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.3,
        params={"max_warp": 0.02},  # Very conservative - 2% max distortion
    )


def _default_wrinkle() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.3,
        params={"intensity": 0.3, "num_wrinkles": (2, 5)},
    )


def _default_edge_damage() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={"max_damage_ratio": 0.05},  # Max 5% of edge damaged
    )


def _default_stain() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={
            "num_stains": (1, 3),
            "max_radius_ratio": 0.1,
            "opacity": (0.1, 0.3),
        },
    )


def _default_lighting_variation() -> AugmentationParams:
    return AugmentationParams(
        enabled=True,  # Safe default, commonly needed
        probability=0.5,
        params={
            "brightness_range": (-0.1, 0.1),
            "contrast_range": (0.9, 1.1),
        },
    )


def _default_shadow() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.3,
        params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
    )


def _default_gaussian_blur() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
    )


def _default_motion_blur() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
    )


def _default_gaussian_noise() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.3,
        params={"mean": 0, "std": (5, 15)},  # Conservative noise levels
    )


def _default_salt_pepper() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={"amount": (0.001, 0.005)},  # Very sparse
    )


def _default_paper_texture() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.3,
        params={"texture_type": "random", "intensity": (0.05, 0.15)},
    )


def _default_scanner_artifacts() -> AugmentationParams:
    return AugmentationParams(
        enabled=False,
        probability=0.2,
        params={"line_probability": 0.3, "dust_probability": 0.4},
    )
@dataclass
class AugmentationConfig:
    """
    Complete augmentation configuration.

    All augmentation types have document-safe defaults that preserve
    text readability. Only lighting_variation is enabled by default.

    Attributes:
        perspective_warp: Geometric perspective transform (affects bboxes).
        wrinkle: Paper wrinkle/crease simulation.
        edge_damage: Damaged/torn edge effects.
        stain: Coffee stain/smudge effects.
        lighting_variation: Brightness and contrast variation.
        shadow: Shadow overlay effects.
        gaussian_blur: Gaussian blur for focus issues.
        motion_blur: Motion blur simulation.
        gaussian_noise: Gaussian noise for sensor noise.
        salt_pepper: Salt and pepper noise.
        paper_texture: Paper texture overlay.
        scanner_artifacts: Scanner line and dust artifacts.
        preserve_bboxes: Whether to adjust bboxes for geometric transforms.
        seed: Random seed for reproducibility.
    """

    # Geometric transforms (affects bboxes)
    perspective_warp: AugmentationParams = field(
        default_factory=_default_perspective_warp
    )

    # Degradation effects
    wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
    edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
    stain: AugmentationParams = field(default_factory=_default_stain)

    # Lighting effects
    lighting_variation: AugmentationParams = field(
        default_factory=_default_lighting_variation
    )
    shadow: AugmentationParams = field(default_factory=_default_shadow)

    # Blur effects
    gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
    motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)

    # Noise effects
    gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
    salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)

    # Texture effects
    paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
    scanner_artifacts: AugmentationParams = field(
        default_factory=_default_scanner_artifacts
    )

    # Global settings
    preserve_bboxes: bool = True
    seed: int | None = None

    # List of all augmentation field names.
    # Fix: deliberately UN-annotated. With the previous annotation
    # (`_AUGMENTATION_FIELDS: tuple[str, ...] = (...)`), @dataclass treated
    # this as an instance field, leaking it into __init__, repr and
    # dataclasses.fields(). An unannotated class attribute is ignored by
    # @dataclass while `self._AUGMENTATION_FIELDS` lookups still resolve.
    _AUGMENTATION_FIELDS = (
        "perspective_warp",
        "wrinkle",
        "edge_damage",
        "stain",
        "lighting_variation",
        "shadow",
        "gaussian_blur",
        "motion_blur",
        "gaussian_noise",
        "salt_pepper",
        "paper_texture",
        "scanner_artifacts",
    )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        result: dict[str, Any] = {
            "preserve_bboxes": self.preserve_bboxes,
            "seed": self.seed,
        }
        for field_name in self._AUGMENTATION_FIELDS:
            params: AugmentationParams = getattr(self, field_name)
            result[field_name] = params.to_dict()
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
        """Create from dictionary; augmentations absent from *data* keep defaults."""
        kwargs: dict[str, Any] = {
            "preserve_bboxes": data.get("preserve_bboxes", True),
            "seed": data.get("seed"),
        }
        for field_name in cls._AUGMENTATION_FIELDS:
            if field_name in data:
                field_data = data[field_name]
                # Non-dict values are ignored so a malformed entry falls back
                # to the field's default.
                if isinstance(field_data, dict):
                    kwargs[field_name] = AugmentationParams.from_dict(field_data)
        return cls(**kwargs)

    def get_enabled_augmentations(self) -> list[str]:
        """Get list of enabled augmentation names (in declaration order)."""
        enabled = []
        for field_name in self._AUGMENTATION_FIELDS:
            params: AugmentationParams = getattr(self, field_name)
            if params.enabled:
                enabled.append(field_name)
        return enabled

    def validate(self) -> None:
        """
        Validate configuration.

        Raises:
            ValueError: If any probability is outside [0, 1].
        """
        for field_name in self._AUGMENTATION_FIELDS:
            params: AugmentationParams = getattr(self, field_name)
            if not (0.0 <= params.probability <= 1.0):
                raise ValueError(
                    f"{field_name}.probability must be between 0 and 1, "
                    f"got {params.probability}"
                )

View File

@@ -0,0 +1,206 @@
"""
Dataset Augmenter Module.
Applies augmentation pipeline to YOLO datasets,
creating new augmented images and label files.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
logger = logging.getLogger(__name__)
class DatasetAugmenter:
"""
Augments YOLO datasets by creating new images and label files.
Reads images from dataset/images/train/ and labels from dataset/labels/train/,
applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
"""
def __init__(
self,
config: dict[str, Any],
seed: int | None = None,
) -> None:
"""
Initialize augmenter with configuration.
Args:
config: Dictionary mapping augmentation names to their settings.
Each augmentation should have 'enabled', 'probability', and 'params'.
seed: Random seed for reproducibility.
"""
self._config_dict = config
self._seed = seed
self._config = self._build_config(config, seed)
def _build_config(
self,
config_dict: dict[str, Any],
seed: int | None,
) -> AugmentationConfig:
"""Build AugmentationConfig from dictionary."""
kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
for aug_name, aug_settings in config_dict.items():
if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
kwargs[aug_name] = AugmentationParams(
enabled=aug_settings.get("enabled", False),
probability=aug_settings.get("probability", 0.5),
params=aug_settings.get("params", {}),
)
return AugmentationConfig(**kwargs)
def augment_dataset(
self,
dataset_path: Path,
multiplier: int = 1,
split: str = "train",
) -> dict[str, int]:
"""
Augment a YOLO dataset.
Args:
dataset_path: Path to dataset root (containing images/ and labels/).
multiplier: Number of augmented copies per original image.
split: Which split to augment (default: "train").
Returns:
Summary dict with original_images, augmented_images, total_images.
"""
images_dir = dataset_path / "images" / split
labels_dir = dataset_path / "labels" / split
if not images_dir.exists():
raise ValueError(f"Images directory not found: {images_dir}")
# Find all images
image_extensions = ("*.png", "*.jpg", "*.jpeg")
image_files: list[Path] = []
for ext in image_extensions:
image_files.extend(images_dir.glob(ext))
original_count = len(image_files)
augmented_count = 0
if multiplier <= 0:
return {
"original_images": original_count,
"augmented_images": 0,
"total_images": original_count,
}
# Process each image
for img_path in image_files:
# Load image
pil_image = Image.open(img_path).convert("RGB")
image = np.array(pil_image)
# Load corresponding label
label_path = labels_dir / f"{img_path.stem}.txt"
bboxes = self._load_bboxes(label_path) if label_path.exists() else None
# Create multiple augmented versions
for aug_idx in range(multiplier):
# Create pipeline with adjusted seed for each augmentation
aug_seed = None
if self._seed is not None:
aug_seed = self._seed + aug_idx + hash(img_path.stem) % 10000
pipeline = AugmentationPipeline(
self._build_config(self._config_dict, aug_seed)
)
# Apply augmentation
result = pipeline.apply(image, bboxes)
# Save augmented image
aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
aug_img_path = images_dir / aug_name
aug_pil = Image.fromarray(result.image)
aug_pil.save(aug_img_path)
# Save augmented label
aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
self._save_bboxes(aug_label_path, result.bboxes)
augmented_count += 1
logger.info(
"Dataset augmentation complete: %d original, %d augmented",
original_count,
augmented_count,
)
return {
"original_images": original_count,
"augmented_images": augmented_count,
"total_images": original_count + augmented_count,
}
def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
"""
Load bounding boxes from YOLO label file.
Args:
label_path: Path to label file.
Returns:
Array of shape (N, 5) with class_id, x_center, y_center, width, height.
Returns None if file is empty or doesn't exist.
"""
if not label_path.exists():
return None
content = label_path.read_text().strip()
if not content:
return None
bboxes = []
for line in content.split("\n"):
parts = line.strip().split()
if len(parts) == 5:
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
bboxes.append([class_id, x_center, y_center, width, height])
if not bboxes:
return None
return np.array(bboxes, dtype=np.float32)
def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
"""
Save bounding boxes to YOLO label file.
Args:
label_path: Path to save label file.
bboxes: Array of shape (N, 5) or None for empty labels.
"""
if bboxes is None or len(bboxes) == 0:
label_path.write_text("")
return
lines = []
for bbox in bboxes:
class_id = int(bbox[0])
x_center = bbox[1]
y_center = bbox[2]
width = bbox[3]
height = bbox[4]
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
label_path.write_text("\n".join(lines))

View File

@@ -0,0 +1,184 @@
"""
Augmentation pipeline module.
Orchestrates multiple augmentations with proper ordering and
provides preview functionality.
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
from shared.augmentation.transforms.geometric import PerspectiveWarp
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
# Registry of augmentation classes
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
"perspective_warp": PerspectiveWarp,
"wrinkle": Wrinkle,
"edge_damage": EdgeDamage,
"stain": Stain,
"lighting_variation": LightingVariation,
"shadow": Shadow,
"gaussian_blur": GaussianBlur,
"motion_blur": MotionBlur,
"gaussian_noise": GaussianNoise,
"salt_pepper": SaltPepper,
"paper_texture": PaperTexture,
"scanner_artifacts": ScannerArtifacts,
}
class AugmentationPipeline:
    """
    Orchestrates multiple augmentations with proper ordering.

    Augmentations are applied in the following order:
    1. Geometric (perspective_warp) - affects bboxes
    2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
    3. Lighting (lighting_variation, shadow)
    4. Texture (paper_texture, scanner_artifacts)
    5. Blur (gaussian_blur, motion_blur)
    6. Noise (gaussian_noise, salt_pepper) - applied last
    """

    # Canonical stage ordering; _build_augmentations sorts by index herein.
    STAGE_ORDER = [
        "geometric",
        "degradation",
        "lighting",
        "texture",
        "blur",
        "noise",
    ]

    # Maps each registered augmentation name to its stage.
    STAGE_MAPPING = {
        "perspective_warp": "geometric",
        "wrinkle": "degradation",
        "edge_damage": "degradation",
        "stain": "degradation",
        "lighting_variation": "lighting",
        "shadow": "lighting",
        "paper_texture": "texture",
        "scanner_artifacts": "texture",
        "gaussian_blur": "blur",
        "motion_blur": "blur",
        "gaussian_noise": "noise",
        "salt_pepper": "noise",
    }

    def __init__(self, config: AugmentationConfig) -> None:
        """
        Initialize pipeline with configuration.

        Args:
            config: Augmentation configuration. config.seed seeds a single
                RNG shared across all apply() calls, so consecutive calls on
                the same pipeline instance draw different random values.
        """
        self.config = config
        self._rng = np.random.default_rng(config.seed)
        self._augmentations = self._build_augmentations()

    def _build_augmentations(
        self,
    ) -> list[tuple[str, BaseAugmentation, float]]:
        """Build ordered list of (name, augmentation, probability) tuples."""
        augmentations: list[tuple[str, BaseAugmentation, float]] = []
        # Only enabled augmentations are instantiated (construction also
        # validates their params — see BaseAugmentation.__init__).
        for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
            params: AugmentationParams = getattr(self.config, aug_name)
            if params.enabled:
                aug = aug_class(params.params)
                augmentations.append((aug_name, aug, params.probability))

        # Sort by stage order (stable sort keeps registry order within a stage)
        def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
            name, _, _ = item
            stage = self.STAGE_MAPPING[name]
            return self.STAGE_ORDER.index(stage)

        return sorted(augmentations, key=sort_key)

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
    ) -> AugmentationResult:
        """
        Apply augmentation pipeline to image.

        Args:
            image: Input image (H, W, C) as numpy array with dtype uint8.
            bboxes: Optional bounding boxes in YOLO format (N, 5).

        Returns:
            AugmentationResult with augmented image and optionally adjusted
            bboxes; metadata["applied_augmentations"] lists the names that
            actually fired. Inputs are copied, never mutated.
        """
        current_image = image.copy()
        current_bboxes = bboxes.copy() if bboxes is not None else None
        applied_augmentations: list[str] = []

        for name, aug, probability in self._augmentations:
            # Each enabled augmentation fires independently with its own probability.
            if self._rng.random() < probability:
                result = aug.apply(current_image, current_bboxes, self._rng)
                current_image = result.image
                # Only adopt adjusted boxes when bbox preservation is requested.
                if result.bboxes is not None and self.config.preserve_bboxes:
                    current_bboxes = result.bboxes
                applied_augmentations.append(name)

        return AugmentationResult(
            image=current_image,
            bboxes=current_bboxes,
            metadata={"applied_augmentations": applied_augmentations},
        )

    def preview(
        self,
        image: np.ndarray,
        augmentation_name: str,
    ) -> np.ndarray:
        """
        Preview a single augmentation deterministically.

        Unlike apply(), the probability gate is bypassed and a fixed-seed RNG
        is used, so repeated previews of the same image are identical.

        Args:
            image: Input image.
            augmentation_name: Name of augmentation to preview.

        Returns:
            Augmented image.

        Raises:
            ValueError: If augmentation_name is not recognized.
        """
        if augmentation_name not in AUGMENTATION_REGISTRY:
            raise ValueError(f"Unknown augmentation: {augmentation_name}")

        params: AugmentationParams = getattr(self.config, augmentation_name)
        aug = AUGMENTATION_REGISTRY[augmentation_name](params.params)

        # Use deterministic RNG for preview
        preview_rng = np.random.default_rng(42)
        result = aug.apply(image.copy(), rng=preview_rng)
        return result.image
def get_available_augmentations() -> list[dict[str, Any]]:
    """
    Get list of available augmentations with metadata.

    Returns:
        List of dictionaries with augmentation info (name, description,
        whether the transform moves pixels geometrically, and its stage).
    """
    return [
        {
            "name": name,
            "description": aug_class.__doc__ or "",
            "affects_geometry": aug_class.affects_geometry,
            "stage": AugmentationPipeline.STAGE_MAPPING[name],
        }
        for name, aug_class in AUGMENTATION_REGISTRY.items()
    ]

View File

@@ -0,0 +1,212 @@
"""
Predefined augmentation presets for common document scenarios.
Presets provide ready-to-use configurations optimized for different
use cases, from conservative (preserves text readability) to aggressive
(simulates poor document quality).
"""
from typing import Any
from shared.augmentation.config import AugmentationConfig, AugmentationParams
# Each preset maps augmentation name -> AugmentationParams-style dict
# (enabled / probability / params). Presets are ordered roughly by how
# aggressively they degrade the document image.
PRESETS: dict[str, dict[str, Any]] = {
    # Photometric-only tweaks; safest choice for OCR/detection training data.
    "conservative": {
        "description": "Safe augmentations that preserve text readability",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (3, 10)},
            },
        },
    },
    # Adds shadows, blur and paper grain on top of the photometric tweaks.
    "moderate": {
        "description": "Balanced augmentations for typical document degradation",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.15, 0.15),
                    "contrast_range": (0.85, 1.15),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.3,
                "params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.3,
                "params": {"intensity": (0.05, 0.12)},
            },
        },
    },
    # Everything enabled, including geometry-affecting perspective warp.
    "aggressive": {
        "description": "Heavy augmentations simulating poor scan quality",
        "config": {
            "perspective_warp": {
                "enabled": True,
                "probability": 0.3,
                "params": {"max_warp": 0.02},
            },
            "wrinkle": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
            },
            "stain": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "num_stains": (1, 2),
                    "max_radius_ratio": 0.08,
                    "opacity": (0.1, 0.25),
                },
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.6,
                "params": {
                    "brightness_range": (-0.2, 0.2),
                    "contrast_range": (0.8, 1.2),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.4,
                "params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.3,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
            },
            "motion_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.4,
                "params": {"std": (8, 18)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.08, 0.15)},
            },
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.3,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "edge_damage": {
                "enabled": True,
                "probability": 0.2,
                "params": {"max_damage_ratio": 0.04},
            },
        },
    },
    # Flatbed-scanner look: lines, dust, grain; no geometric distortion.
    "scanned_document": {
        "description": "Simulates typical scanned document artifacts",
        "config": {
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.5,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.05, 0.12)},
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
        },
    },
}
def get_preset_config(preset_name: str) -> dict[str, Any]:
    """
    Get the configuration dictionary for a preset.

    Args:
        preset_name: Name of the preset.

    Returns:
        Configuration dictionary.

    Raises:
        ValueError: If preset is not found.
    """
    preset = PRESETS.get(preset_name)
    if preset is None:
        raise ValueError(
            f"Unknown preset: {preset_name}. "
            f"Available presets: {list(PRESETS.keys())}"
        )
    return preset["config"]
def create_config_from_preset(preset_name: str) -> AugmentationConfig:
    """
    Create an AugmentationConfig from a preset.

    Args:
        preset_name: Name of the preset.

    Returns:
        AugmentationConfig instance.

    Raises:
        ValueError: If preset is not found.
    """
    # Delegate lookup (and the unknown-preset error) to get_preset_config.
    return AugmentationConfig.from_dict(get_preset_config(preset_name))
def list_presets() -> list[dict[str, str]]:
    """
    List all available presets.

    Returns:
        List of dictionaries with name and description.
    """
    catalog = []
    for name, preset in PRESETS.items():
        catalog.append({"name": name, "description": preset["description"]})
    return catalog

View File

@@ -0,0 +1,13 @@
"""
Augmentation transform implementations.
Each module contains related augmentation classes:
- geometric.py: Perspective warp and other geometric transforms
- degradation.py: Wrinkle, edge damage, stain effects
- lighting.py: Lighting variation and shadow effects
- blur.py: Gaussian and motion blur
- noise.py: Gaussian and salt-pepper noise
- texture.py: Paper texture and scanner artifacts
"""
# Will be populated as transforms are implemented

View File

@@ -0,0 +1,144 @@
"""
Blur augmentation transforms.
Provides blur effects for document image augmentation:
- GaussianBlur: Simulates out-of-focus capture
- MotionBlur: Simulates camera/document movement during capture
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianBlur(BaseAugmentation):
    """
    Applies Gaussian blur to the image.

    Simulates out-of-focus capture or low-quality optics.
    Conservative defaults to preserve text readability.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
        sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
    """

    name = "gaussian_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        # cv2.GaussianBlur requires a positive odd kernel size.
        kernel_size = self.params.get("kernel_size", (3, 5))
        if isinstance(kernel_size, int):
            if kernel_size < 1 or kernel_size % 2 == 0:
                raise ValueError("kernel_size must be a positive odd integer")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
                raise ValueError("kernel_size tuple must be (min, max) with min >= 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Blur the image with a randomly sampled odd kernel and sigma."""
        if rng is None:
            rng = np.random.default_rng()
        kernel_size = self.params.get("kernel_size", (3, 5))
        sigma = self.params.get("sigma", (0.5, 1.5))
        if isinstance(kernel_size, tuple):
            lo, hi = kernel_size
            odd_sizes = [k for k in range(lo, hi + 1) if k % 2 == 1]
            if not odd_sizes:
                # Range contained no odd value; round min up to the next odd.
                odd_sizes = [lo if lo % 2 == 1 else lo + 1]
            kernel_size = rng.choice(odd_sizes)
        if isinstance(sigma, tuple):
            sigma = rng.uniform(sigma[0], sigma[1])
        if kernel_size % 2 == 0:
            # cv2 rejects even kernels; bump up by one.
            kernel_size += 1
        blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
        return AugmentationResult(
            image=blurred,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"kernel_size": kernel_size, "sigma": sigma},
        )

    def get_preview_params(self) -> dict:
        # Strong fixed settings for a clearly visible UI preview.
        return {"kernel_size": 5, "sigma": 1.5}
class MotionBlur(BaseAugmentation):
    """
    Applies motion blur to the image.
    Simulates camera shake or document movement during capture.
    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
        angle_range: Motion angle range in degrees (default: (-45, 45)).
    """
    name = "motion_blur"
    # Blur does not move content, so bounding boxes stay valid.
    affects_geometry = False
    def _validate_params(self) -> None:
        # A kernel below 3 px would produce no visible motion streak.
        kernel_size = self.params.get("kernel_size", (5, 9))
        if isinstance(kernel_size, int):
            if kernel_size < 3:
                raise ValueError("kernel_size must be at least 3")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 3:
                raise ValueError("kernel_size min must be at least 3")
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Convolve the image with a randomly angled line kernel.
        Args:
            image: Input image array; convolved channel-wise by filter2D.
            bboxes: Optional boxes; returned as an unmodified copy.
            rng: Random generator; a fresh one is created when omitted.
        Returns:
            AugmentationResult with the blurred image and the sampled
            kernel size / angle recorded in metadata.
        """
        rng = rng or np.random.default_rng()
        kernel_size = self.params.get("kernel_size", (5, 9))
        angle_range = self.params.get("angle_range", (-45, 45))
        if isinstance(kernel_size, tuple):
            kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)
        angle = rng.uniform(angle_range[0], angle_range[1])
        # Create motion blur kernel
        kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
        # Draw a line through the center of the kernel at the sampled angle;
        # int() rounding may skip cells at steep angles, which is harmless.
        center = kernel_size // 2
        angle_rad = np.deg2rad(angle)
        for i in range(kernel_size):
            offset = i - center
            x = int(center + offset * np.cos(angle_rad))
            y = int(center + offset * np.sin(angle_rad))
            if 0 <= x < kernel_size and 0 <= y < kernel_size:
                kernel[y, x] = 1.0
        # Normalize the kernel so overall brightness is preserved.
        kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel
        # Apply motion blur
        blurred = cv2.filter2D(image, -1, kernel)
        return AugmentationResult(
            image=blurred,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"kernel_size": kernel_size, "angle": angle},
        )
    def get_preview_params(self) -> dict:
        # Fixed, strongly visible settings for UI preview.
        return {"kernel_size": 7, "angle_range": (-30, 30)}

View File

@@ -0,0 +1,259 @@
"""
Degradation augmentation transforms.
Provides degradation effects for document image augmentation:
- Wrinkle: Paper wrinkle/crease simulation
- EdgeDamage: Damaged/torn edge effects
- Stain: Coffee stain/smudge effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class Wrinkle(BaseAugmentation):
    """
    Simulates paper wrinkles/creases using displacement mapping.
    Document-friendly: Uses subtle displacement to preserve text readability.
    Parameters:
        intensity: Wrinkle intensity (0-1) (default: 0.3).
        num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
    """
    name = "wrinkle"
    # Displacement is only a few pixels, so boxes are left unadjusted.
    affects_geometry = False
    def _validate_params(self) -> None:
        intensity = self.params.get("intensity", 0.3)
        if not (0 < intensity <= 1):
            raise ValueError("intensity must be between 0 and 1")
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Warp the image along random crease lines via cv2.remap.
        Args:
            image: Input image array.
            bboxes: Optional boxes; returned as an unmodified copy.
            rng: Random generator; a fresh one is created when omitted.
        Returns:
            AugmentationResult with the wrinkled (and subtly shaded) image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        intensity = self.params.get("intensity", 0.3)
        num_wrinkles = self.params.get("num_wrinkles", (2, 5))
        if isinstance(num_wrinkles, tuple):
            num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)
        # Accumulate per-pixel displacement over all wrinkles.
        displacement_x = np.zeros((h, w), dtype=np.float32)
        displacement_y = np.zeros((h, w), dtype=np.float32)
        for _ in range(num_wrinkles):
            # Random wrinkle parameters: orientation, anchor point, extent.
            angle = rng.uniform(0, np.pi)
            x0 = rng.uniform(0, w)
            y0 = rng.uniform(0, h)
            length = rng.uniform(0.3, 0.8) * min(h, w)
            width = rng.uniform(0.02, 0.05) * min(h, w)
            # Create coordinate grids
            xx, yy = np.meshgrid(np.arange(w), np.arange(h))
            # dx runs along the wrinkle line, dy perpendicular to it.
            dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
            dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)
            # Gaussian falloff perpendicular to wrinkle
            mask = np.exp(-dy**2 / (2 * width**2))
            # Limit the crease to a finite segment along the line.
            mask *= (np.abs(dx) < length / 2).astype(np.float32)
            # Displacement perpendicular to wrinkle
            disp_amount = intensity * rng.uniform(2, 8)
            displacement_x += mask * disp_amount * np.sin(angle)
            displacement_y += mask * disp_amount * np.cos(angle)
        # Create remap coordinates
        map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
        map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)
        # Apply displacement
        augmented = cv2.remap(
            image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
        )
        # Add subtle shading along wrinkles
        # (the 1e-6 epsilon avoids division by zero when displacement is flat)
        max_disp = np.max(np.abs(displacement_y)) + 1e-6
        shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
        shading = shading[:, :, np.newaxis]
        augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)
        return AugmentationResult(
            image=augmented,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
        )
    def get_preview_params(self) -> dict:
        # Stronger-than-default settings so the preview is clearly visible.
        return {"intensity": 0.5, "num_wrinkles": 3}
class EdgeDamage(BaseAugmentation):
    """
    Adds damaged/torn edge effects to the image.

    Simulates worn or torn document edges.

    Parameters:
        max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
        edges: Which edges to potentially damage (default: all).
    """

    name = "edge_damage"
    affects_geometry = False

    def _validate_params(self) -> None:
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        if not (0 < max_damage_ratio <= 0.2):
            raise ValueError("max_damage_ratio must be between 0 and 0.2")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Overwrite a jagged strip along one randomly chosen border."""
        if rng is None:
            rng = np.random.default_rng()
        h, w = image.shape[:2]
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        edges = self.params.get("edges", ["top", "bottom", "left", "right"])
        output = image.copy()
        edge = rng.choice(edges)
        damage_size = int(max_damage_ratio * min(h, w))

        def tear_shade() -> int:
            # Half the time a near-white tear, otherwise a darker smudge.
            if rng.random() > 0.5:
                return rng.integers(200, 255)
            return rng.integers(100, 150)

        if edge == "top":
            for x in range(w):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    output[:depth, x] = tear_shade()
        elif edge == "bottom":
            for x in range(w):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    output[h - depth:, x] = tear_shade()
        elif edge == "left":
            for y in range(h):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    output[y, :depth] = tear_shade()
        else:  # right
            for y in range(h):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    output[y, w - depth:] = tear_shade()
        return AugmentationResult(
            image=output,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"edge": edge, "damage_size": damage_size},
        )

    def get_preview_params(self) -> dict:
        return {"max_damage_ratio": 0.08}
class Stain(BaseAugmentation):
    """
    Adds coffee stain/smudge effects to the image.

    Simulates accidental stains on documents.

    Parameters:
        num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
        max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
        opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
    """

    name = "stain"
    affects_geometry = False

    def _validate_params(self) -> None:
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(opacity, (int, float)):
            if not (0 < opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            # Mirror Shadow's tuple validation so bad configs fail early
            # instead of producing silently wrong blends.
            if not (0 <= opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Alpha-blend soft-edged brownish blobs into the image.

        Args:
            image: Input image (H, W, C), uint8.
            bboxes: Optional boxes; returned as an unmodified copy.
            rng: Random generator; a fresh one is created when omitted.

        Returns:
            AugmentationResult with the stained image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        num_stains = self.params.get("num_stains", (1, 3))
        max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(num_stains, tuple):
            num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])
        output = image.astype(np.float32)
        # Clamp the radius to at least 1: rng.integers(low, high) requires
        # low < high, and the previous version crashed on tiny images
        # (max_radius == 0 made the center sampling range empty).
        max_radius = max(1, int(max_radius_ratio * min(h, w)))
        for _ in range(num_stains):
            # Random stain position, kept fully inside the image when it
            # fits; degenerate dimensions fall back to the image center
            # (the old code raised when w or h <= 2 * max_radius).
            if w > 2 * max_radius:
                cx = int(rng.integers(max_radius, w - max_radius))
            else:
                cx = w // 2
            if h > 2 * max_radius:
                cy = int(rng.integers(max_radius, h - max_radius))
            else:
                cy = h // 2
            # At least one pixel of radius so the stain is never empty.
            radius = max(1, int(rng.integers(max_radius // 3, max_radius)))
            # Create stain mask with irregular edges
            yy, xx = np.ogrid[:h, :w]
            dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
            # Per-pixel noise makes the stain outline irregular.
            noise = rng.uniform(0.8, 1.2, (h, w))
            mask = (dist < radius * noise).astype(np.float32)
            # Blur for soft edges
            mask = cv2.GaussianBlur(mask, (21, 21), 0)
            # Random stain color (brownish/yellowish)
            stain_color = np.array([
                rng.integers(180, 220),  # R
                rng.integers(160, 200),  # G
                rng.integers(120, 160),  # B
            ], dtype=np.float32)
            # Alpha-blend the stain over the page.
            mask_3d = mask[:, :, np.newaxis]
            output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity
        output = np.clip(output, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_stains": num_stains, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}

View File

@@ -0,0 +1,145 @@
"""
Geometric augmentation transforms.
Provides geometric transforms for document image augmentation:
- PerspectiveWarp: Subtle perspective distortion
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PerspectiveWarp(BaseAugmentation):
    """
    Applies subtle perspective transformation to the image.
    Simulates viewing document at slight angle. Very conservative
    by default to preserve text readability.
    IMPORTANT: This transform affects bounding box coordinates.
    Parameters:
        max_warp: Maximum warp as proportion of image size (default: 0.02).
    """
    name = "perspective_warp"
    affects_geometry = True
    def _validate_params(self) -> None:
        # Warps above 10% of the image size would make text unreadable.
        max_warp = self.params.get("max_warp", 0.02)
        if not (0 < max_warp <= 0.1):
            raise ValueError("max_warp must be between 0 and 0.1")
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Warp the image by randomly perturbing its four corners.
        Args:
            image: Input image array.
            bboxes: Optional YOLO-format boxes (N, 5); transformed to
                match the warped image when present.
            rng: Random generator; a fresh one is created when omitted.
        Returns:
            AugmentationResult with the warped image, transformed boxes,
            and the 3x3 perspective matrix used.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        max_warp = self.params.get("max_warp", 0.02)
        # Original corners
        src_pts = np.float32([
            [0, 0],
            [w, 0],
            [w, h],
            [0, h],
        ])
        # Add random perturbations to corners
        max_offset = max_warp * min(h, w)
        dst_pts = src_pts.copy()
        for i in range(4):
            dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
            dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)
        # Compute perspective transform matrix
        transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
        # Apply perspective transform; replicated border pixels avoid
        # black wedges appearing at the image edges.
        warped = cv2.warpPerspective(
            image, transform_matrix, (w, h),
            borderMode=cv2.BORDER_REPLICATE
        )
        # Transform bounding boxes if present
        transformed_bboxes = None
        if bboxes is not None:
            transformed_bboxes = self._transform_bboxes(
                bboxes, transform_matrix, w, h
            )
        return AugmentationResult(
            image=warped,
            bboxes=transformed_bboxes,
            transform_matrix=transform_matrix,
            metadata={"max_warp": max_warp},
        )
    def _transform_bboxes(
        self,
        bboxes: np.ndarray,
        transform_matrix: np.ndarray,
        w: int,
        h: int,
    ) -> np.ndarray:
        """Transform bounding boxes using perspective matrix.
        Each YOLO box (class, cx, cy, w, h; normalized) is converted to
        pixel corners, warped, then re-boxed as the axis-aligned bounding
        box of the warped corners and re-normalized/clamped to [0, 1].
        """
        if len(bboxes) == 0:
            return bboxes.copy()
        transformed = []
        for bbox in bboxes:
            class_id, x_center, y_center, width, height = bbox
            # Convert normalized coords to pixel coords
            x_center_px = x_center * w
            y_center_px = y_center * h
            width_px = width * w
            height_px = height * h
            # Get corner points
            x1 = x_center_px - width_px / 2
            y1 = y_center_px - height_px / 2
            x2 = x_center_px + width_px / 2
            y2 = y_center_px + height_px / 2
            # Transform all 4 corners (shape (4, 1, 2) as required by cv2)
            corners = np.float32([
                [x1, y1],
                [x2, y1],
                [x2, y2],
                [x1, y2],
            ]).reshape(-1, 1, 2)
            transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
            transformed_corners = transformed_corners.reshape(-1, 2)
            # Get bounding box of transformed corners
            new_x1 = np.min(transformed_corners[:, 0])
            new_y1 = np.min(transformed_corners[:, 1])
            new_x2 = np.max(transformed_corners[:, 0])
            new_y2 = np.max(transformed_corners[:, 1])
            # Convert back to normalized center format
            new_width = (new_x2 - new_x1) / w
            new_height = (new_y2 - new_y1) / h
            new_x_center = ((new_x1 + new_x2) / 2) / w
            new_y_center = ((new_y1 + new_y2) / 2) / h
            # Clamp to valid range
            new_x_center = np.clip(new_x_center, 0, 1)
            new_y_center = np.clip(new_y_center, 0, 1)
            new_width = np.clip(new_width, 0, 1)
            new_height = np.clip(new_height, 0, 1)
            transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])
        return np.array(transformed, dtype=np.float32)
    def get_preview_params(self) -> dict:
        return {"max_warp": 0.03}

View File

@@ -0,0 +1,167 @@
"""
Lighting augmentation transforms.
Provides lighting effects for document image augmentation:
- LightingVariation: Adjusts brightness and contrast
- Shadow: Adds shadow overlay effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class LightingVariation(BaseAugmentation):
    """
    Adjusts image brightness and contrast.

    Simulates different lighting conditions during document capture.
    Safe for documents with conservative default parameters.

    Parameters:
        brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
        contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
    """

    name = "lighting_variation"
    affects_geometry = False

    def _validate_params(self) -> None:
        brightness = self.params.get("brightness_range", (-0.1, 0.1))
        if not (isinstance(brightness, tuple) and len(brightness) == 2):
            raise ValueError("brightness_range must be a (min, max) tuple")
        contrast = self.params.get("contrast_range", (0.9, 1.1))
        if not (isinstance(contrast, tuple) and len(contrast) == 2):
            raise ValueError("contrast_range must be a (min, max) tuple")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Apply a random contrast scale (about the mean) and brightness shift."""
        if rng is None:
            rng = np.random.default_rng()
        brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
        contrast_range = self.params.get("contrast_range", (0.9, 1.1))
        brightness = rng.uniform(brightness_range[0], brightness_range[1])
        contrast = rng.uniform(contrast_range[0], contrast_range[1])
        pixels = image.astype(np.float32)
        # Scale contrast around the image mean, then shift brightness.
        mean = pixels.mean()
        pixels = (pixels - mean) * contrast + mean + brightness * 255
        adjusted = np.clip(pixels, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=adjusted,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"brightness": brightness, "contrast": contrast},
        )

    def get_preview_params(self) -> dict:
        return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
class Shadow(BaseAugmentation):
    """
    Adds shadow overlay effects to the image.

    Simulates shadows from objects or hands during document capture.

    Parameters:
        num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
        opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
    """

    name = "shadow"
    affects_geometry = False

    def _validate_params(self) -> None:
        opacity = self.params.get("opacity", (0.2, 0.4))
        if isinstance(opacity, (int, float)):
            if not (0 <= opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            if not (0 <= opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Darken one or more soft-edged polygonal regions of the image."""
        if rng is None:
            rng = np.random.default_rng()
        num_shadows = self.params.get("num_shadows", (1, 2))
        opacity = self.params.get("opacity", (0.2, 0.4))
        if isinstance(num_shadows, tuple):
            num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])
        h, w = image.shape[:2]
        output = image.astype(np.float32)
        # Soft-edge blur kernel: ~1/10 of the short side, odd, at least 31.
        blur_size = max(31, min(h, w) // 10)
        if blur_size % 2 == 0:
            blur_size += 1
        for _ in range(num_shadows):
            vertices = self._random_polygon(rng, h, w)
            mask = np.zeros((h, w), dtype=np.float32)
            pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [pts], 1.0)
            mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
            # Multiplicative darkening weighted by the blurred mask.
            output *= 1 - opacity * mask[:, :, np.newaxis]
        output = np.clip(output, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=output,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"num_shadows": num_shadows, "opacity": opacity},
        )

    def _random_polygon(
        self, rng: np.random.Generator, h: int, w: int
    ) -> list[tuple[int, int]]:
        """Build a random polygon anchored on one image border."""
        num_vertices = rng.integers(3, 6)
        edge = rng.integers(0, 4)
        if edge == 0:  # top border
            start = (rng.integers(0, w), 0)
        elif edge == 1:  # right border
            start = (w, rng.integers(0, h))
        elif edge == 2:  # bottom border
            start = (rng.integers(0, w), h)
        else:  # left border
            start = (0, rng.integers(0, h))
        points = [start]
        for _ in range(num_vertices - 1):
            points.append((rng.integers(0, w), rng.integers(0, h)))
        return points

View File

@@ -0,0 +1,142 @@
"""
Noise augmentation transforms.
Provides noise effects for document image augmentation:
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
- SaltPepper: Adds salt and pepper noise for impulse noise effects
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianNoise(BaseAugmentation):
    """
    Adds Gaussian noise to the image.

    Simulates sensor noise from cameras or scanners.
    Document-safe with conservative default parameters.

    Parameters:
        mean: Mean of the Gaussian noise (default: 0).
        std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
    """

    name = "gaussian_noise"
    affects_geometry = False

    def _validate_params(self) -> None:
        std = self.params.get("std", (5, 15))
        if isinstance(std, (int, float)):
            if std < 0:
                raise ValueError("std must be non-negative")
        elif isinstance(std, tuple):
            invalid = len(std) != 2 or std[0] < 0 or std[1] < std[0]
            if invalid:
                raise ValueError("std tuple must be (min, max) with min <= max >= 0")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Add Gaussian noise with a (possibly sampled) std, clipped to uint8."""
        if rng is None:
            rng = np.random.default_rng()
        mean = self.params.get("mean", 0)
        std = self.params.get("std", (5, 15))
        if isinstance(std, tuple):
            std = rng.uniform(std[0], std[1])
        noise = rng.normal(mean, std, image.shape).astype(np.float32)
        noisy = np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=noisy,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"applied_std": std},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"mean": 0, "std": 15}
class SaltPepper(BaseAugmentation):
    """
    Adds salt and pepper (impulse) noise to the image.

    Simulates defects from damaged sensors or transmission errors.
    Very sparse by default to preserve document readability.

    Parameters:
        amount: Proportion of pixels to affect, can be float or (min, max) tuple.
            Default: (0.001, 0.005) for very sparse noise.
        salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
    """

    name = "salt_pepper"
    affects_geometry = False

    def _validate_params(self) -> None:
        amount = self.params.get("amount", (0.001, 0.005))
        if isinstance(amount, (int, float)):
            if not (0 <= amount <= 1):
                raise ValueError("amount must be between 0 and 1")
        elif isinstance(amount, tuple):
            if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
                raise ValueError("amount tuple must be (min, max) in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Set a sparse random subset of pixels to pure white or black."""
        if rng is None:
            rng = np.random.default_rng()
        amount = self.params.get("amount", (0.001, 0.005))
        salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
        if isinstance(amount, tuple):
            amount = rng.uniform(amount[0], amount[1])
        output = image.copy()
        h, w = image.shape[:2]
        pixel_count = h * w
        # Split the affected-pixel budget between salt and pepper.
        num_salt = int(pixel_count * amount * salt_vs_pepper)
        num_pepper = int(pixel_count * amount * (1 - salt_vs_pepper))
        # 255 = salt (white), 0 = pepper (black); salt is drawn first to
        # keep the random stream identical to the original implementation.
        for count, value in ((num_salt, 255), (num_pepper, 0)):
            if count > 0:
                rows = rng.integers(0, h, count)
                cols = rng.integers(0, w, count)
                output[rows, cols] = value
        return AugmentationResult(
            image=output,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"applied_amount": amount},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"amount": 0.01, "salt_vs_pepper": 0.5}

View File

@@ -0,0 +1,159 @@
"""
Texture augmentation transforms.
Provides texture effects for document image augmentation:
- PaperTexture: Adds paper grain/texture
- ScannerArtifacts: Adds scanner line and dust artifacts
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PaperTexture(BaseAugmentation):
    """
    Adds paper texture/grain to the image.

    Simulates different paper types and ages.

    Parameters:
        texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
        intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
    """

    name = "paper_texture"
    affects_geometry = False

    def _validate_params(self) -> None:
        intensity = self.params.get("intensity", (0.05, 0.15))
        if isinstance(intensity, (int, float)):
            if not (0 < intensity <= 1):
                raise ValueError("intensity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Overlay additive random grain on the image.

        Args:
            image: Input image (H, W, C), uint8.
            bboxes: Optional boxes; returned as an unmodified copy.
            rng: Random generator; a fresh one is created when omitted.

        Returns:
            AugmentationResult with the textured image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        texture_type = self.params.get("texture_type", "random")
        intensity = self.params.get("intensity", (0.05, 0.15))
        if texture_type == "random":
            texture_type = rng.choice(["fine", "coarse"])
        if isinstance(intensity, tuple):
            intensity = rng.uniform(intensity[0], intensity[1])
        if texture_type == "fine":
            # Fine grain: per-pixel noise, lightly smoothed.
            noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
            noise = cv2.GaussianBlur(noise, (3, 3), 0)
        else:
            # Coarse grain: generate at quarter resolution and upscale.
            # max(1, ...) guards against a zero-sized noise array (and the
            # resulting cv2.resize crash) for images under 4 px per side.
            small_h, small_w = max(1, h // 4), max(1, w // 4)
            noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
            noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
            noise = cv2.GaussianBlur(noise, (5, 5), 0)
        # Additive overlay scaled by intensity, then clipped back to uint8.
        output = image.astype(np.float32) + noise[:, :, np.newaxis] * intensity * 255
        output = np.clip(output, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"texture_type": texture_type, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        return {"texture_type": "coarse", "intensity": 0.15}
class ScannerArtifacts(BaseAugmentation):
    """
    Adds scanner artifacts to the image.

    Simulates scanner imperfections like lines and dust spots.

    Parameters:
        line_probability: Probability of adding scan lines (default: 0.3).
        dust_probability: Probability of adding dust spots (default: 0.4).
    """

    name = "scanner_artifacts"
    affects_geometry = False

    def _validate_params(self) -> None:
        line_prob = self.params.get("line_probability", 0.3)
        dust_prob = self.params.get("dust_probability", 0.4)
        if not (0 <= line_prob <= 1):
            raise ValueError("line_probability must be between 0 and 1")
        if not (0 <= dust_prob <= 1):
            raise ValueError("dust_probability must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Optionally draw horizontal scan lines and dark dust spots."""
        if rng is None:
            rng = np.random.default_rng()
        h, w = image.shape[:2]
        line_probability = self.params.get("line_probability", 0.3)
        dust_probability = self.params.get("dust_probability", 0.4)
        output = image.copy()
        if rng.random() < line_probability:
            self._draw_scan_lines(output, rng, h)
        if rng.random() < dust_probability:
            self._draw_dust(output, rng, h, w)
        return AugmentationResult(
            image=output,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={
                "line_probability": line_probability,
                "dust_probability": dust_probability,
            },
        )

    def _draw_scan_lines(
        self, output: np.ndarray, rng: np.random.Generator, h: int
    ) -> None:
        """Blend 1-3 semi-transparent horizontal lines into the image in place."""
        for _ in range(rng.integers(1, 4)):
            y = rng.integers(0, h)
            thickness = rng.integers(1, 3)
            # Half the time a light streak, otherwise a dark one.
            if rng.random() > 0.5:
                shade = rng.integers(200, 240)
            else:
                shade = rng.integers(50, 100)
            alpha = rng.uniform(0.3, 0.6)
            for dy in range(thickness):
                row = y + dy
                if row < h:
                    blended = output[row, :].astype(np.float32) * (1 - alpha) + shade * alpha
                    output[row, :] = blended.astype(np.uint8)

    def _draw_dust(
        self, output: np.ndarray, rng: np.random.Generator, h: int, w: int
    ) -> None:
        """Stamp 5-19 small dark filled circles onto the image in place."""
        for _ in range(rng.integers(5, 20)):
            x = rng.integers(0, w)
            y = rng.integers(0, h)
            radius = rng.integers(1, 3)
            shade = rng.integers(50, 120)
            cv2.circle(output, (x, y), radius, int(shade), -1)

    def get_preview_params(self) -> dict:
        # High probabilities so the preview reliably shows both artifacts.
        return {"line_probability": 0.8, "dust_probability": 0.8}

View File

@@ -0,0 +1,5 @@
"""Shared training utilities."""
# Re-export the trainer API at the package root so callers can use
# `from shared.training import YOLOTrainer` instead of the submodule path.
from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult

__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]

View File

@@ -0,0 +1,239 @@
"""
Shared YOLO Training Module
Unified training logic for both CLI and Web API.
"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
logger = logging.getLogger(__name__)
@dataclass
class TrainingConfig:
    """Training configuration shared by the CLI and Web API trainers.

    All fields have defaults suitable for invoice-field detection; only
    ``data_yaml`` must be supplied by the caller.
    """

    # Model settings
    model_path: str = "yolo11n.pt"  # Base model name or path to trained .pt file
    data_yaml: str = ""  # Path to the dataset's data.yaml (required)

    # Training hyperparameters
    epochs: int = 100
    batch_size: int = 16
    image_size: int = 640  # Square training resolution passed as "imgsz"
    learning_rate: float = 0.01  # Initial learning rate, passed as "lr0"
    device: str = "0"  # Device selector string handed to YOLO (e.g. GPU index)

    # Output settings
    project: str = "runs/train"  # Root directory for run outputs
    name: str = "invoice_fields"  # Run name (subdirectory under project)

    # Performance settings
    workers: int = 4  # Dataloader worker count
    cache: bool = False  # Whether YOLO should cache dataset images

    # Resume settings
    resume: bool = False
    resume_from: str | None = None  # Path to checkpoint to resume from

    # Document-specific augmentation (optimized for invoices): no flips,
    # no mosaic/mixup, small rotations/translations, mild saturation and
    # brightness jitter only.
    augmentation: dict[str, Any] = field(default_factory=lambda: {
        "degrees": 5.0,
        "translate": 0.05,
        "scale": 0.2,
        "shear": 0.0,
        "perspective": 0.0,
        "flipud": 0.0,
        "fliplr": 0.0,
        "mosaic": 0.0,
        "mixup": 0.0,
        "hsv_h": 0.0,
        "hsv_s": 0.1,
        "hsv_v": 0.2,
    })
@dataclass
class TrainingResult:
    """Outcome of a training run."""

    success: bool  # True when training completed without error
    model_path: str | None = None  # Path to best.pt, if produced
    metrics: dict[str, float] = field(default_factory=dict)  # e.g. mAP50, mAP50-95, precision, recall
    error: str | None = None  # Error message when success is False
    save_dir: str | None = None  # Directory containing run artifacts
class YOLOTrainer:
    """Unified YOLO trainer for CLI and Web API."""

    def __init__(
        self,
        config: TrainingConfig,
        log_callback: Callable[[str, str], None] | None = None,
    ):
        """
        Initialize trainer.

        Args:
            config: Training configuration
            log_callback: Optional callback for logging (level, message)
        """
        self.config = config
        self._log_callback = log_callback

    def _log(self, level: str, message: str) -> None:
        """Forward a message to the callback (if set) and the module logger."""
        if self._log_callback:
            self._log_callback(level, message)
        # Dispatch table instead of an if/elif chain; unknown levels are
        # silently ignored, matching the previous behavior.
        log_fn = {
            "INFO": logger.info,
            "ERROR": logger.error,
            "WARNING": logger.warning,
        }.get(level)
        if log_fn is not None:
            log_fn(message)

    def validate_config(self) -> tuple[bool, str | None]:
        """
        Validate training configuration.

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check model path. Models named "yolo*" (e.g. "yolo11n.pt") are
        # downloaded by ultralytics on demand, so they need not exist
        # locally. The previous logic rejected such names whenever the
        # file was absent, because the ".pt" suffix routed them into the
        # existence check.
        model_path = Path(self.config.model_path)
        if not model_path.name.startswith("yolo"):
            if model_path.suffix != ".pt":
                return False, f"Invalid model: {self.config.model_path}"
            if not model_path.exists():
                return False, f"Model file not found: {self.config.model_path}"
        # Check data.yaml
        if not self.config.data_yaml:
            return False, "data_yaml is required"
        data_yaml = Path(self.config.data_yaml)
        if not data_yaml.exists():
            return False, f"data.yaml not found: {self.config.data_yaml}"
        return True, None

    def train(self) -> TrainingResult:
        """
        Run YOLO training.

        Returns:
            TrainingResult with model path and metrics
        """
        try:
            from ultralytics import YOLO
        except ImportError:
            return TrainingResult(
                success=False,
                error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
            )
        # Validate config
        is_valid, error = self.validate_config()
        if not is_valid:
            return TrainingResult(success=False, error=error)
        self._log("INFO", "Starting YOLO training")
        self._log("INFO", f" Model: {self.config.model_path}")
        self._log("INFO", f" Data: {self.config.data_yaml}")
        self._log("INFO", f" Epochs: {self.config.epochs}")
        self._log("INFO", f" Batch size: {self.config.batch_size}")
        self._log("INFO", f" Image size: {self.config.image_size}")
        try:
            # Load model: resume from the checkpoint when requested and
            # present, otherwise start from the configured base model.
            if self.config.resume and self.config.resume_from:
                resume_path = Path(self.config.resume_from)
                if resume_path.exists():
                    self._log("INFO", f"Resuming from: {resume_path}")
                    model = YOLO(str(resume_path))
                else:
                    model = YOLO(self.config.model_path)
            else:
                model = YOLO(self.config.model_path)
            # Build training arguments
            train_args = {
                "data": str(Path(self.config.data_yaml).absolute()),
                "epochs": self.config.epochs,
                "batch": self.config.batch_size,
                "imgsz": self.config.image_size,
                "lr0": self.config.learning_rate,
                "device": self.config.device,
                "project": self.config.project,
                "name": self.config.name,
                "exist_ok": True,
                "pretrained": True,
                "verbose": True,
                "workers": self.config.workers,
                "cache": self.config.cache,
                "resume": self.config.resume and self.config.resume_from is not None,
            }
            # Add augmentation settings
            train_args.update(self.config.augmentation)
            # Train
            results = model.train(**train_args)
            # Get best model path
            best_model = Path(results.save_dir) / "weights" / "best.pt"
            # Extract metrics (keys follow ultralytics' results_dict naming)
            metrics = {}
            if hasattr(results, "results_dict"):
                metrics = {
                    "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
                    "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
                    "precision": results.results_dict.get("metrics/precision(B)", 0),
                    "recall": results.results_dict.get("metrics/recall(B)", 0),
                }
            self._log("INFO", "Training completed successfully")
            self._log("INFO", f" Best model: {best_model}")
            self._log("INFO", f" mAP@0.5: {metrics.get('mAP50', 'N/A')}")
            return TrainingResult(
                success=True,
                model_path=str(best_model) if best_model.exists() else None,
                metrics=metrics,
                save_dir=str(results.save_dir),
            )
        except Exception as e:
            self._log("ERROR", f"Training failed: {e}")
            return TrainingResult(success=False, error=str(e))

    def validate(self, split: str = "val") -> dict[str, float]:
        """
        Run validation on trained model.

        Args:
            split: Dataset split to validate on ("val" or "test")

        Returns:
            Validation metrics, or an empty dict if validation fails.
        """
        try:
            from ultralytics import YOLO
            model = YOLO(self.config.model_path)
            metrics = model.val(data=self.config.data_yaml, split=split)
            return {
                "mAP50": metrics.box.map50,
                "mAP50-95": metrics.box.map,
                "precision": metrics.box.mp,
                "recall": metrics.box.mr,
            }
        except Exception as e:
            self._log("ERROR", f"Validation failed: {e}")
            return {}

View File

@@ -199,67 +199,63 @@ def main():
db.close()
return
# Start training
# Start training using shared trainer
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
from ultralytics import YOLO
from shared.training import YOLOTrainer, TrainingConfig
# Load model
# Determine resume checkpoint
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
if args.resume and last_checkpoint.exists():
print(f"Resuming from: {last_checkpoint}")
model = YOLO(str(last_checkpoint))
else:
model = YOLO(args.model)
resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
# Training arguments
# Create training config
data_yaml = dataset_dir / 'dataset.yaml'
train_args = {
'data': str(data_yaml.absolute()),
'epochs': args.epochs,
'batch': args.batch,
'imgsz': args.imgsz,
'project': args.project,
'name': args.name,
'device': args.device,
'exist_ok': True,
'pretrained': True,
'verbose': True,
'workers': args.workers,
'cache': args.cache,
'resume': args.resume and last_checkpoint.exists(),
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
'scale': 0.2,
'shear': 0.0,
'perspective': 0.0,
'flipud': 0.0,
'fliplr': 0.0,
'mosaic': 0.0,
'mixup': 0.0,
'hsv_h': 0.0,
'hsv_s': 0.1,
'hsv_v': 0.2,
}
config = TrainingConfig(
model_path=args.model,
data_yaml=str(data_yaml),
epochs=args.epochs,
batch_size=args.batch,
image_size=args.imgsz,
device=args.device,
project=args.project,
name=args.name,
workers=args.workers,
cache=args.cache,
resume=args.resume,
resume_from=resume_from,
)
# Train
results = model.train(**train_args)
# Run training
trainer = YOLOTrainer(config=config)
result = trainer.train()
if not result.success:
print(f"\nError: Training failed - {result.error}")
db.close()
sys.exit(1)
# Print results
print("\n" + "=" * 60)
print("Training Complete")
print("=" * 60)
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
print(f"Best model: {result.model_path}")
print(f"Save directory: {result.save_dir}")
if result.metrics:
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
# Validate on test set
print("\nRunning validation on test set...")
metrics = model.val(split='test')
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
if result.model_path:
config.model_path = result.model_path
config.data_yaml = str(data_yaml)
test_trainer = YOLOTrainer(config=config)
test_metrics = test_trainer.validate(split='test')
if test_metrics:
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
# Close database
db.close()