= ({ onNavigate }) => {
|
{doc.annotation_count || 0} annotations
|
+
+ {doc.group_key || '-'}
+ |
{doc.auto_label_status === 'running' && progress && (
diff --git a/frontend/src/components/DatasetDetail.tsx b/frontend/src/components/DatasetDetail.tsx
new file mode 100644
index 0000000..b70b704
--- /dev/null
+++ b/frontend/src/components/DatasetDetail.tsx
@@ -0,0 +1,122 @@
+import React from 'react'
+import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
+import { Button } from './Button'
+import { useDatasetDetail } from '../hooks/useDatasets'
+
+interface DatasetDetailProps {
+ datasetId: string
+ onBack: () => void
+}
+
+const SPLIT_STYLES: Record = {
+ train: 'bg-warm-state-info/10 text-warm-state-info',
+ val: 'bg-warm-state-warning/10 text-warm-state-warning',
+ test: 'bg-warm-state-success/10 text-warm-state-success',
+}
+
+export const DatasetDetail: React.FC = ({ datasetId, onBack }) => {
+ const { dataset, isLoading, error } = useDatasetDetail(datasetId)
+
+ if (isLoading) {
+ return (
+
+ Loading dataset...
+
+ )
+ }
+
+ if (error || !dataset) {
+ return (
+
+
+ Failed to load dataset.
+
+ )
+ }
+
+ const statusIcon = dataset.status === 'ready'
+ ?
+ : dataset.status === 'failed'
+ ?
+ :
+
+ return (
+
+ {/* Header */}
+
+
+
+
+
+ {dataset.name} {statusIcon}
+
+ {dataset.description && (
+ {dataset.description}
+ )}
+
+ {dataset.status === 'ready' && (
+
+ )}
+
+
+ {dataset.error_message && (
+
+ {dataset.error_message}
+
+ )}
+
+ {/* Stats */}
+
+ {[
+ ['Documents', dataset.total_documents],
+ ['Images', dataset.total_images],
+ ['Annotations', dataset.total_annotations],
+ ['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
+ ].map(([label, value]) => (
+
+ ))}
+
+
+ {/* Document list */}
+ Documents
+
+
+
+
+ | Document ID |
+ Split |
+ Pages |
+ Annotations |
+
+
+
+ {dataset.documents.map(doc => (
+
+ | {doc.document_id.slice(0, 8)}... |
+
+
+ {doc.split}
+
+ |
+ {doc.page_count} |
+ {doc.annotation_count} |
+
+ ))}
+
+
+
+
+
+ Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
+ {dataset.dataset_path && <> | Path: {dataset.dataset_path}>}
+
+
+ )
+}
diff --git a/frontend/src/components/DocumentDetail.tsx b/frontend/src/components/DocumentDetail.tsx
index cf4da73..258a4ba 100644
--- a/frontend/src/components/DocumentDetail.tsx
+++ b/frontend/src/components/DocumentDetail.tsx
@@ -1,8 +1,9 @@
import React, { useState, useRef, useEffect } from 'react'
-import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
+import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
+import { useDocuments } from '../hooks/useDocuments'
import { documentsApi } from '../api/endpoints/documents'
import type { AnnotationItem } from '../api/types'
@@ -26,7 +27,7 @@ const FIELD_CLASSES: Record = {
}
export const DocumentDetail: React.FC = ({ docId, onBack }) => {
- const { document, annotations, isLoading } = useDocumentDetail(docId)
+ const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
const {
createAnnotation,
updateAnnotation,
@@ -34,10 +35,13 @@ export const DocumentDetail: React.FC = ({ docId, onBack })
isCreating,
isDeleting,
} = useAnnotations(docId)
+ const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
const [selectedId, setSelectedId] = useState(null)
const [zoom, setZoom] = useState(100)
const [isDrawing, setIsDrawing] = useState(false)
+ const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
+ const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
const [selectedClassId, setSelectedClassId] = useState(0)
@@ -426,6 +430,65 @@ export const DocumentDetail: React.FC = ({ docId, onBack })
{new Date(document.created_at).toLocaleDateString()}
+
+ Group
+ {isEditingGroupKey ? (
+
+ setEditGroupKeyValue(e.target.value)}
+ className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
+ placeholder="group key"
+ autoFocus
+ />
+
+
+
+ ) : (
+
+
+ {document.group_key || '-'}
+
+
+
+ )}
+
diff --git a/frontend/src/components/Models.tsx b/frontend/src/components/Models.tsx
index c35052f..bfe2222 100644
--- a/frontend/src/components/Models.tsx
+++ b/frontend/src/components/Models.tsx
@@ -1,73 +1,108 @@
-import React from 'react';
+import React, { useState } from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
+import { Loader2, Power, CheckCircle } from 'lucide-react';
import { Button } from './Button';
+import { useModels, useModelDetail } from '../hooks';
+import type { ModelVersionItem } from '../api/types';
-const CHART_DATA = [
- { name: 'Model A', value: 75 },
- { name: 'Model B', value: 82 },
- { name: 'Model C', value: 95 },
- { name: 'Model D', value: 68 },
-];
-
-const METRICS_DATA = [
- { name: 'Precision', value: 88 },
- { name: 'Recall', value: 76 },
- { name: 'F1 Score', value: 91 },
- { name: 'Accuracy', value: 82 },
-];
-
-const JOBS = [
- { id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
- { id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
- { id: 3, name: 'Model Training Compentr 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
-];
+const formatDate = (dateString: string | null): string => {
+ if (!dateString) return 'N/A';
+ return new Date(dateString).toLocaleString();
+};
export const Models: React.FC = () => {
+ const [selectedModel, setSelectedModel] = useState(null);
+ const { models, isLoading, activateModel, isActivating } = useModels();
+ const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
+
+ // Build chart data from selected model's metrics
+ const metricsData = modelDetail ? [
+ { name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
+ { name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
+ { name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
+ ] : [
+ { name: 'Precision', value: 0 },
+ { name: 'Recall', value: 0 },
+ { name: 'mAP', value: 0 },
+ ];
+
+ // Build comparison chart from all models (with placeholder if empty)
+ const chartData = models.length > 0
+ ? models.slice(0, 4).map(m => ({
+ name: m.version,
+ value: (m.metrics_mAP ?? 0) * 100,
+ }))
+ : [
+ { name: 'Model A', value: 0 },
+ { name: 'Model B', value: 0 },
+ { name: 'Model C', value: 0 },
+ { name: 'Model D', value: 0 },
+ ];
+
return (
{/* Left: Job History */}
Models & History
- Recent Training Jobs
+ Model Versions
-
- {JOBS.map(job => (
-
-
-
- {job.name}
- Started {job.date}
-
-
- {job.status}
-
-
-
- {job.status === 'Running' ? (
-
-
+ {isLoading ? (
+
+
+
+ ) : models.length === 0 ? (
+
+ No model versions found. Complete a training task to create a model version.
+
+ ) : (
+
+ {models.map(model => (
+ setSelectedModel(model)}
+ className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
+ selectedModel?.version_id === model.version_id
+ ? 'border-warm-text-secondary'
+ : 'border-warm-border hover:border-warm-divider'
+ }`}
+ >
+
+
+
+ {model.name}
+ {model.is_active && }
+
+ Trained {formatDate(model.trained_at)}
+
+
+ {model.is_active ? 'Active' : model.status}
+
- ) : (
+
- Success
- {job.success}
+ Documents
+ {model.document_count}
- Performance
- {job.metrics}%
+ mAP
+
+ {model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
+
- Completed
- 100%
+ Version
+ {model.version}
- )}
-
- ))}
-
+
+ ))}
+
+ )}
{/* Right: Model Detail */}
@@ -75,27 +110,34 @@ export const Models: React.FC = () => {
Model Detail
- Completed
+
+ {selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
+
Model name
- Invoices Q4 v2.1
+
+ {selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
+
{/* Chart 1 */}
- Bar Rate Metrics
+ Model Comparison (mAP)
-
+
-
+
- [`${value.toFixed(1)}%`, 'mAP']}
/>
@@ -105,14 +147,17 @@ export const Models: React.FC = () => {
{/* Chart 2 */}
- Entity Extraction Accuracy
+ Performance Metrics
-
+
-
+ [`${value.toFixed(1)}%`, 'Score']}
+ />
@@ -121,14 +166,43 @@ export const Models: React.FC = () => {
-
+ {selectedModel && !selectedModel.is_active ? (
+
+ ) : (
+
+ )}
-
-
+
+
);
-};
\ No newline at end of file
+};
diff --git a/frontend/src/components/Training.tsx b/frontend/src/components/Training.tsx
index 39a9976..13c63eb 100644
--- a/frontend/src/components/Training.tsx
+++ b/frontend/src/components/Training.tsx
@@ -1,113 +1,482 @@
-import React, { useState } from 'react';
-import { Check, AlertCircle } from 'lucide-react';
-import { Button } from './Button';
-import { DocumentStatus } from '../types';
+import React, { useState, useMemo } from 'react'
+import { useQuery } from '@tanstack/react-query'
+import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
+import { Button } from './Button'
+import { AugmentationConfig } from './AugmentationConfig'
+import { useDatasets } from '../hooks/useDatasets'
+import { useTrainingDocuments } from '../hooks/useTraining'
+import { trainingApi } from '../api/endpoints'
+import type { DatasetListItem } from '../api/types'
+import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
-export const Training: React.FC = () => {
- const [split, setSplit] = useState(80);
+type Tab = 'datasets' | 'create'
- const docs = [
- { id: '1', name: 'Document Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
- { id: '2', name: 'Document Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
- { id: '3', name: 'Document Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
- { id: '4', name: 'Document Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
- { id: '5', name: 'Document Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
- { id: '6', name: 'Document Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
- { id: '8', name: 'Document Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
- ];
+interface TrainingProps {
+ onNavigate?: (view: string, id?: string) => void
+}
+
+const STATUS_STYLES: Record = {
+ ready: 'bg-warm-state-success/10 text-warm-state-success',
+ building: 'bg-warm-state-info/10 text-warm-state-info',
+ training: 'bg-warm-state-info/10 text-warm-state-info',
+ failed: 'bg-warm-state-error/10 text-warm-state-error',
+ pending: 'bg-warm-state-warning/10 text-warm-state-warning',
+ scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
+ running: 'bg-warm-state-info/10 text-warm-state-info',
+}
+
+const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
+ // If there's an active training task, show training status
+ const displayStatus = trainingStatus === 'running'
+ ? 'training'
+ : trainingStatus === 'pending' || trainingStatus === 'scheduled'
+ ? 'pending'
+ : status
return (
-
- {/* Document Selection List */}
-
- Document Selection
-
-
-
+
+ {(displayStatus === 'building' || displayStatus === 'training') && }
+ {displayStatus === 'ready' && }
+ {displayStatus === 'failed' && }
+ {displayStatus}
+
+ )
+}
- {/* Configuration Panel */}
-
-
- Training Configuration
-
-
-
-
- void
+ onSubmit: (config: {
+ name: string
+ config: {
+ model_name?: string
+ base_model_version_id?: string | null
+ epochs: number
+ batch_size: number
+ augmentation?: AugmentationConfigType
+ augmentation_multiplier?: number
+ }
+ }) => void
+ isPending: boolean
+}
+
+const TrainDialog: React.FC = ({ dataset, onClose, onSubmit, isPending }) => {
+ const [name, setName] = useState(`train-${dataset.name}`)
+ const [epochs, setEpochs] = useState(100)
+ const [batchSize, setBatchSize] = useState(16)
+ const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
+ const [baseModelVersionId, setBaseModelVersionId] = useState(null)
+ const [augmentationEnabled, setAugmentationEnabled] = useState(false)
+ const [augmentationConfig, setAugmentationConfig] = useState>({})
+ const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
+
+ // Fetch available trained models
+ const { data: modelsData } = useQuery({
+ queryKey: ['training', 'models', 'completed'],
+ queryFn: () => trainingApi.getModels({ status: 'completed' }),
+ })
+ const completedModels = modelsData?.models ?? []
+
+ const handleSubmit = () => {
+ onSubmit({
+ name,
+ config: {
+ model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
+ base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
+ epochs,
+ batch_size: batchSize,
+ augmentation: augmentationEnabled
+ ? (augmentationConfig as AugmentationConfigType)
+ : undefined,
+ augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
+ },
+ })
+ }
+
+ return (
+
+ e.stopPropagation()}>
+ Start Training
+
+ Dataset: {dataset.name}
+ {' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
+
+
+
+
+
+ setName(e.target.value)}
+ className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
+
+
+ {/* Base Model Selection */}
+
+
+
+
+ {baseModelType === 'pretrained'
+ ? 'Start from pretrained YOLO model'
+ : 'Continue training from an existing model (incremental training)'}
+
+
+
+
+
+
+ setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
-
-
-
-
-
-
-
-
-
- {split}% / {100-split}%
-
- setSplit(parseInt(e.target.value))}
- className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
+
+
+ setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
+ className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
+
+
+ {/* Augmentation Configuration */}
+
+
+ {/* Augmentation Multiplier - only shown when augmentation is enabled */}
+ {augmentationEnabled && (
+
+
+ setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
+ className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
+ />
+
+ Number of augmented copies per original image (1-10)
+
+
+ )}
+
+
+
+
+
+
+
+
+ )
+}
+
+// --- Dataset List ---
+
+const DatasetList: React.FC<{
+ onNavigate?: (view: string, id?: string) => void
+ onSwitchTab: (tab: Tab) => void
+}> = ({ onNavigate, onSwitchTab }) => {
+ const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
+ const [trainTarget, setTrainTarget] = useState (null)
+
+ const handleTrain = (config: {
+ name: string
+ config: {
+ model_name?: string
+ base_model_version_id?: string | null
+ epochs: number
+ batch_size: number
+ augmentation?: AugmentationConfigType
+ augmentation_multiplier?: number
+ }
+ }) => {
+ if (!trainTarget) return
+ // Pass config to the training API
+ const trainRequest = {
+ name: config.name,
+ config: config.config,
+ }
+ trainFromDataset(
+ { datasetId: trainTarget.dataset_id, req: trainRequest },
+ { onSuccess: () => setTrainTarget(null) },
+ )
+ }
+
+ if (isLoading) {
+ return Loading datasets...
+ }
+
+ if (datasets.length === 0) {
+ return (
+
+
+ No datasets yet
+ Create a dataset to start training
+
+
+ )
+ }
+
+ return (
+ <>
+
+
+
+
+ | Name |
+ Status |
+ Docs |
+ Images |
+ Annotations |
+ Created |
+ Actions |
+
+
+
+ {datasets.map(ds => (
+
+ | {ds.name} |
+ |
+ {ds.total_documents} |
+ {ds.total_images} |
+ {ds.total_annotations} |
+ {new Date(ds.created_at).toLocaleDateString()} |
+
+
+
+ {ds.status === 'ready' && (
+
+ )}
+
+
+ |
+
+ ))}
+
+
+
+
+ {trainTarget && (
+ setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
+ )}
+ >
+ )
+}
+
+// --- Create Dataset ---
+
+const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
+ const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
+ const { createDatasetAsync, isCreating } = useDatasets()
+
+ const [selectedIds, setSelectedIds] = useState>(new Set())
+ const [name, setName] = useState('')
+ const [description, setDescription] = useState('')
+ const [trainRatio, setTrainRatio] = useState(0.7)
+ const [valRatio, setValRatio] = useState(0.2)
+
+ const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
+
+ const toggleDoc = (id: string) => {
+ setSelectedIds(prev => {
+ const next = new Set(prev)
+ if (next.has(id)) { next.delete(id) } else { next.add(id) }
+ return next
+ })
+ }
+
+ const toggleAll = () => {
+ if (selectedIds.size === documents.length) {
+ setSelectedIds(new Set())
+ } else {
+ setSelectedIds(new Set(documents.map((d) => d.document_id)))
+ }
+ }
+
+ const handleCreate = async () => {
+ await createDatasetAsync({
+ name,
+ description: description || undefined,
+ document_ids: [...selectedIds],
+ train_ratio: trainRatio,
+ val_ratio: valRatio,
+ })
+ onSwitchTab('datasets')
+ }
+
+ return (
+
+ {/* Document selection */}
+
+ Select Documents
+ {isLoadingDocs ? (
+ Loading...
+ ) : (
+
+ )}
+ {selectedIds.size} of {documents.length} documents selected
+
+
+ {/* Config panel */}
+
+
+ Dataset Configuration
+
+
+
+ setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
+ className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
+
+
+
+
+
+
+
+
-
+ {selectedIds.size > 0 && selectedIds.size < 10 && (
+
+ Minimum 10 documents required for training ({selectedIds.size}/10 selected)
+
+ )}
+
- );
-};
\ No newline at end of file
+ )
+}
+
+// --- Main Training Component ---
+
+export const Training: React.FC = ({ onNavigate }) => {
+ const [activeTab, setActiveTab] = useState('datasets')
+
+ return (
+
+
+ Training
+
+
+ {/* Tabs */}
+
+ {([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
+
+ ))}
+
+
+ {activeTab === 'datasets' && }
+ {activeTab === 'create' && }
+
+ )
+}
diff --git a/frontend/src/components/UploadModal.tsx b/frontend/src/components/UploadModal.tsx
index ca7b9b3..f76df93 100644
--- a/frontend/src/components/UploadModal.tsx
+++ b/frontend/src/components/UploadModal.tsx
@@ -11,6 +11,7 @@ interface UploadModalProps {
export const UploadModal: React.FC = ({ isOpen, onClose }) => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState([])
+ const [groupKey, setGroupKey] = useState('')
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef(null)
@@ -61,10 +62,13 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) =>
// Upload files one by one
for (const file of selectedFiles) {
await new Promise((resolve, reject) => {
- uploadDocument(file, {
- onSuccess: () => resolve(),
- onError: (error: Error) => reject(error),
- })
+ uploadDocument(
+ { file, groupKey: groupKey || undefined },
+ {
+ onSuccess: () => resolve(),
+ onError: (error: Error) => reject(error),
+ }
+ )
})
}
@@ -72,6 +76,7 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) =>
setTimeout(() => {
onClose()
setSelectedFiles([])
+ setGroupKey('')
setUploadStatus('idle')
}, 1500)
} catch (error) {
@@ -85,6 +90,7 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) =>
return // Prevent closing during upload
}
setSelectedFiles([])
+ setGroupKey('')
setUploadStatus('idle')
setErrorMessage('')
onClose()
@@ -173,6 +179,26 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) =>
)}
+ {/* Group Key Input */}
+ {selectedFiles.length > 0 && (
+
+
+ setGroupKey(e.target.value)}
+ placeholder="e.g., 2024-Q1, supplier-abc, project-name"
+ className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
+ disabled={uploadStatus === 'uploading'}
+ />
+
+ Use group keys to organize documents into logical groups
+
+
+ )}
+
{/* Status Messages */}
{uploadStatus === 'success' && (
diff --git a/frontend/src/hooks/index.ts b/frontend/src/hooks/index.ts
index d642394..d1b5a7b 100644
--- a/frontend/src/hooks/index.ts
+++ b/frontend/src/hooks/index.ts
@@ -2,3 +2,6 @@ export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'
+export { useDatasets, useDatasetDetail } from './useDatasets'
+export { useAugmentation } from './useAugmentation'
+export { useModels, useModelDetail, useActiveModel } from './useModels'
diff --git a/frontend/src/hooks/useAugmentation.test.tsx b/frontend/src/hooks/useAugmentation.test.tsx
new file mode 100644
index 0000000..33a5dde
--- /dev/null
+++ b/frontend/src/hooks/useAugmentation.test.tsx
@@ -0,0 +1,226 @@
+/**
+ * Tests for useAugmentation hook.
+ *
+ * TDD Phase 1: RED - Write tests first, then implement to pass.
+ */
+
+import { describe, it, expect, vi, beforeEach } from 'vitest'
+import { renderHook, waitFor } from '@testing-library/react'
+import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
+import { augmentationApi } from '../api/endpoints/augmentation'
+import { useAugmentation } from './useAugmentation'
+import type { ReactNode } from 'react'
+
+// Mock the API
+vi.mock('../api/endpoints/augmentation', () => ({
+ augmentationApi: {
+ getTypes: vi.fn(),
+ getPresets: vi.fn(),
+ preview: vi.fn(),
+ previewConfig: vi.fn(),
+ createBatch: vi.fn(),
+ },
+}))
+
+// Test wrapper with QueryClient
+const createWrapper = () => {
+ const queryClient = new QueryClient({
+ defaultOptions: {
+ queries: {
+ retry: false,
+ },
+ },
+ })
+ return ({ children }: { children: ReactNode }) => (
+ {children}
+ )
+}
+
+describe('useAugmentation', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ describe('getTypes', () => {
+ it('should fetch augmentation types', async () => {
+ const mockTypes = {
+ augmentation_types: [
+ {
+ name: 'gaussian_noise',
+ description: 'Adds Gaussian noise',
+ affects_geometry: false,
+ stage: 'noise',
+ default_params: { mean: 0, std: 15 },
+ },
+ {
+ name: 'perspective_warp',
+ description: 'Applies perspective warp',
+ affects_geometry: true,
+ stage: 'geometric',
+ default_params: { max_warp: 0.02 },
+ },
+ ],
+ }
+ vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingTypes).toBe(false)
+ })
+
+ expect(result.current.augmentationTypes).toHaveLength(2)
+ expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
+ })
+
+ it('should handle error when fetching types', async () => {
+ vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingTypes).toBe(false)
+ })
+
+ expect(result.current.typesError).toBeTruthy()
+ })
+ })
+
+ describe('getPresets', () => {
+ it('should fetch augmentation presets', async () => {
+ const mockPresets = {
+ presets: [
+ { name: 'conservative', description: 'Safe augmentations' },
+ { name: 'moderate', description: 'Balanced augmentations' },
+ { name: 'aggressive', description: 'Strong augmentations' },
+ ],
+ }
+ vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
+ vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingPresets).toBe(false)
+ })
+
+ expect(result.current.presets).toHaveLength(3)
+ expect(result.current.presets[0].name).toBe('conservative')
+ })
+ })
+
+ describe('preview', () => {
+ it('should preview single augmentation', async () => {
+ const mockPreview = {
+ preview_url: 'data:image/png;base64,xxx',
+ original_url: 'data:image/png;base64,yyy',
+ applied_params: { std: 15 },
+ }
+ vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
+ vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
+ vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingTypes).toBe(false)
+ })
+
+ // Call preview mutation
+ result.current.preview({
+ documentId: 'doc-123',
+ augmentationType: 'gaussian_noise',
+ params: { std: 15 },
+ page: 1,
+ })
+
+ await waitFor(() => {
+ expect(augmentationApi.preview).toHaveBeenCalledWith(
+ 'doc-123',
+ { augmentation_type: 'gaussian_noise', params: { std: 15 } },
+ 1
+ )
+ })
+ })
+
+ it('should track preview loading state', async () => {
+ vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
+ vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
+ vi.mocked(augmentationApi.preview).mockImplementation(
+ () => new Promise((resolve) => setTimeout(resolve, 100))
+ )
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingTypes).toBe(false)
+ })
+
+ expect(result.current.isPreviewing).toBe(false)
+
+ result.current.preview({
+ documentId: 'doc-123',
+ augmentationType: 'gaussian_noise',
+ params: {},
+ page: 1,
+ })
+
+ // State update happens asynchronously
+ await waitFor(() => {
+ expect(result.current.isPreviewing).toBe(true)
+ })
+ })
+ })
+
+ describe('createBatch', () => {
+ it('should create augmented dataset', async () => {
+ const mockResponse = {
+ task_id: 'task-123',
+ status: 'pending',
+ message: 'Augmentation task queued',
+ estimated_images: 100,
+ }
+ vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
+ vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
+ vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
+
+ const { result } = renderHook(() => useAugmentation(), {
+ wrapper: createWrapper(),
+ })
+
+ await waitFor(() => {
+ expect(result.current.isLoadingTypes).toBe(false)
+ })
+
+ result.current.createBatch({
+ dataset_id: 'dataset-123',
+ config: {
+ gaussian_noise: { enabled: true, probability: 0.5, params: {} },
+ },
+ output_name: 'augmented-dataset',
+ multiplier: 2,
+ })
+
+ await waitFor(() => {
+ expect(augmentationApi.createBatch).toHaveBeenCalledWith({
+ dataset_id: 'dataset-123',
+ config: {
+ gaussian_noise: { enabled: true, probability: 0.5, params: {} },
+ },
+ output_name: 'augmented-dataset',
+ multiplier: 2,
+ })
+ })
+ })
+ })
+})
diff --git a/frontend/src/hooks/useAugmentation.ts b/frontend/src/hooks/useAugmentation.ts
new file mode 100644
index 0000000..dd299ce
--- /dev/null
+++ b/frontend/src/hooks/useAugmentation.ts
@@ -0,0 +1,121 @@
+/**
+ * Hook for managing augmentation operations.
+ *
+ * Provides functions for fetching augmentation types, presets, and previewing augmentations.
+ */
+
+import { useQuery, useMutation } from '@tanstack/react-query'
+import {
+ augmentationApi,
+ type AugmentationTypesResponse,
+ type PresetsResponse,
+ type PreviewResponse,
+ type BatchRequest,
+ type BatchResponse,
+ type AugmentationConfig,
+} from '../api/endpoints/augmentation'
+
+interface PreviewParams {
+ documentId: string
+ augmentationType: string
+ params: Record
+ page?: number
+}
+
+interface PreviewConfigParams {
+ documentId: string
+ config: AugmentationConfig
+ page?: number
+}
+
+export const useAugmentation = () => {
+ // Fetch augmentation types
+ const {
+ data: typesData,
+ isLoading: isLoadingTypes,
+ error: typesError,
+ } = useQuery({
+ queryKey: ['augmentation', 'types'],
+ queryFn: () => augmentationApi.getTypes(),
+ staleTime: 5 * 60 * 1000, // Cache for 5 minutes
+ })
+
+ // Fetch presets
+ const {
+ data: presetsData,
+ isLoading: isLoadingPresets,
+ error: presetsError,
+ } = useQuery({
+ queryKey: ['augmentation', 'presets'],
+ queryFn: () => augmentationApi.getPresets(),
+ staleTime: 5 * 60 * 1000,
+ })
+
+ // Preview single augmentation mutation
+ const previewMutation = useMutation({
+ mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
+ augmentationApi.preview(
+ documentId,
+ { augmentation_type: augmentationType, params },
+ page
+ ),
+ onError: (error) => {
+ console.error('Preview augmentation failed:', error)
+ },
+ })
+
+ // Preview full config mutation
+ const previewConfigMutation = useMutation({
+ mutationFn: ({ documentId, config, page = 1 }) =>
+ augmentationApi.previewConfig(documentId, config, page),
+ onError: (error) => {
+ console.error('Preview config failed:', error)
+ },
+ })
+
+ // Create augmented dataset mutation
+ const createBatchMutation = useMutation({
+ mutationFn: (request) => augmentationApi.createBatch(request),
+ onError: (error) => {
+ console.error('Create augmented dataset failed:', error)
+ },
+ })
+
+ return {
+ // Types data
+ augmentationTypes: typesData?.augmentation_types || [],
+ isLoadingTypes,
+ typesError,
+
+ // Presets data
+ presets: presetsData?.presets || [],
+ isLoadingPresets,
+ presetsError,
+
+ // Preview single augmentation
+ preview: previewMutation.mutate,
+ previewAsync: previewMutation.mutateAsync,
+ isPreviewing: previewMutation.isPending,
+ previewData: previewMutation.data,
+ previewError: previewMutation.error,
+
+ // Preview full config
+ previewConfig: previewConfigMutation.mutate,
+ previewConfigAsync: previewConfigMutation.mutateAsync,
+ isPreviewingConfig: previewConfigMutation.isPending,
+ previewConfigData: previewConfigMutation.data,
+ previewConfigError: previewConfigMutation.error,
+
+ // Create batch
+ createBatch: createBatchMutation.mutate,
+ createBatchAsync: createBatchMutation.mutateAsync,
+ isCreatingBatch: createBatchMutation.isPending,
+ batchData: createBatchMutation.data,
+ batchError: createBatchMutation.error,
+
+ // Reset functions for clearing stale mutation state
+ resetPreview: previewMutation.reset,
+ resetPreviewConfig: previewConfigMutation.reset,
+ resetBatch: createBatchMutation.reset,
+ }
+}
diff --git a/frontend/src/hooks/useDatasets.ts b/frontend/src/hooks/useDatasets.ts
new file mode 100644
index 0000000..a6fe6e0
--- /dev/null
+++ b/frontend/src/hooks/useDatasets.ts
@@ -0,0 +1,84 @@
+import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
+import { datasetsApi } from '../api/endpoints'
+import type {
+ DatasetCreateRequest,
+ DatasetDetailResponse,
+ DatasetListResponse,
+ DatasetTrainRequest,
+} from '../api/types'
+
+export const useDatasets = (params?: {
+ status?: string
+ limit?: number
+ offset?: number
+}) => {
+ const queryClient = useQueryClient()
+
+ const { data, isLoading, error, refetch } = useQuery({
+ queryKey: ['datasets', params],
+ queryFn: () => datasetsApi.list(params),
+ staleTime: 30000,
+ // Poll every 5 seconds when there's an active training task
+ refetchInterval: (query) => {
+ const datasets = query.state.data?.datasets ?? []
+ const hasActiveTraining = datasets.some(
+ d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
+ )
+ return hasActiveTraining ? 5000 : false
+ },
+ })
+
+ const createMutation = useMutation({
+ mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['datasets'] })
+ },
+ })
+
+ const deleteMutation = useMutation({
+ mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['datasets'] })
+ },
+ })
+
+ const trainMutation = useMutation({
+ mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
+ datasetsApi.trainFromDataset(datasetId, req),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['datasets'] })
+ queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
+ },
+ })
+
+ return {
+ datasets: data?.datasets ?? [],
+ total: data?.total ?? 0,
+ isLoading,
+ error,
+ refetch,
+ createDataset: createMutation.mutate,
+ createDatasetAsync: createMutation.mutateAsync,
+ isCreating: createMutation.isPending,
+ deleteDataset: deleteMutation.mutate,
+ isDeleting: deleteMutation.isPending,
+ trainFromDataset: trainMutation.mutate,
+ trainFromDatasetAsync: trainMutation.mutateAsync,
+ isTraining: trainMutation.isPending,
+ }
+}
+
+export const useDatasetDetail = (datasetId: string | null) => {
+ const { data, isLoading, error } = useQuery({
+ queryKey: ['datasets', datasetId],
+ queryFn: () => datasetsApi.getDetail(datasetId!),
+ enabled: !!datasetId,
+ staleTime: 30000,
+ })
+
+ return {
+ dataset: data ?? null,
+ isLoading,
+ error,
+ }
+}
diff --git a/frontend/src/hooks/useDocuments.ts b/frontend/src/hooks/useDocuments.ts
index 22e07c1..b75a126 100644
--- a/frontend/src/hooks/useDocuments.ts
+++ b/frontend/src/hooks/useDocuments.ts
@@ -18,7 +18,16 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
})
const uploadMutation = useMutation({
- mutationFn: (file: File) => documentsApi.upload(file),
+ mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
+ documentsApi.upload(file, groupKey),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['documents'] })
+ },
+ })
+
+ const updateGroupKeyMutation = useMutation({
+ mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
+ documentsApi.updateGroupKey(documentId, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
@@ -74,5 +83,8 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
isUpdatingStatus: updateStatusMutation.isPending,
triggerAutoLabel: triggerAutoLabelMutation.mutate,
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
+ updateGroupKey: updateGroupKeyMutation.mutate,
+ updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
+ isUpdatingGroupKey: updateGroupKeyMutation.isPending,
}
}
diff --git a/frontend/src/hooks/useModels.ts b/frontend/src/hooks/useModels.ts
new file mode 100644
index 0000000..9b3d8f7
--- /dev/null
+++ b/frontend/src/hooks/useModels.ts
@@ -0,0 +1,98 @@
+import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
+import { modelsApi } from '../api/endpoints'
+import type {
+ ModelVersionListResponse,
+ ModelVersionDetailResponse,
+ ActiveModelResponse,
+} from '../api/types'
+
+export const useModels = (params?: {
+ status?: string
+ limit?: number
+ offset?: number
+}) => {
+ const queryClient = useQueryClient()
+
+ const { data, isLoading, error, refetch } = useQuery({
+ queryKey: ['models', params],
+ queryFn: () => modelsApi.list(params),
+ staleTime: 30000,
+ })
+
+ const activateMutation = useMutation({
+ mutationFn: (versionId: string) => modelsApi.activate(versionId),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['models'] })
+ queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
+ },
+ })
+
+ const deactivateMutation = useMutation({
+ mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['models'] })
+ queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
+ },
+ })
+
+ const archiveMutation = useMutation({
+ mutationFn: (versionId: string) => modelsApi.archive(versionId),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['models'] })
+ },
+ })
+
+ const deleteMutation = useMutation({
+ mutationFn: (versionId: string) => modelsApi.delete(versionId),
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: ['models'] })
+ },
+ })
+
+ return {
+ models: data?.models ?? [],
+ total: data?.total ?? 0,
+ isLoading,
+ error,
+ refetch,
+ activateModel: activateMutation.mutate,
+ activateModelAsync: activateMutation.mutateAsync,
+ isActivating: activateMutation.isPending,
+ deactivateModel: deactivateMutation.mutate,
+ isDeactivating: deactivateMutation.isPending,
+ archiveModel: archiveMutation.mutate,
+ isArchiving: archiveMutation.isPending,
+ deleteModel: deleteMutation.mutate,
+ isDeleting: deleteMutation.isPending,
+ }
+}
+
+export const useModelDetail = (versionId: string | null) => {
+ const { data, isLoading, error } = useQuery({
+ queryKey: ['models', versionId],
+ queryFn: () => modelsApi.getDetail(versionId!),
+ enabled: !!versionId,
+ staleTime: 30000,
+ })
+
+ return {
+ model: data ?? null,
+ isLoading,
+ error,
+ }
+}
+
+export const useActiveModel = () => {
+ const { data, isLoading, error } = useQuery({
+ queryKey: ['models', 'active'],
+ queryFn: () => modelsApi.getActive(),
+ staleTime: 30000,
+ })
+
+ return {
+ hasActiveModel: data?.has_active_model ?? false,
+ activeModel: data?.model ?? null,
+ isLoading,
+ error,
+ }
+}
diff --git a/migrations/005_add_group_key.sql b/migrations/005_add_group_key.sql
new file mode 100644
index 0000000..981caa4
--- /dev/null
+++ b/migrations/005_add_group_key.sql
@@ -0,0 +1,8 @@
+-- Add group_key column to admin_documents
+-- Allows users to organize documents into logical groups
+
+-- Add the column (nullable, VARCHAR 255)
+ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
+
+-- Add index for filtering/grouping queries
+CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
diff --git a/migrations/006_model_versions.sql b/migrations/006_model_versions.sql
new file mode 100644
index 0000000..4d4e7f3
--- /dev/null
+++ b/migrations/006_model_versions.sql
@@ -0,0 +1,49 @@
+-- Model versions table for tracking trained model deployments.
+-- Each training run can produce a model version for inference.
+
+CREATE TABLE IF NOT EXISTS model_versions (
+ version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+ version VARCHAR(50) NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ description TEXT,
+ model_path VARCHAR(512) NOT NULL,
+ status VARCHAR(20) NOT NULL DEFAULT 'inactive',
+ is_active BOOLEAN NOT NULL DEFAULT FALSE,
+
+ -- Training association
+ task_id UUID REFERENCES training_tasks(task_id) ON DELETE SET NULL,
+ dataset_id UUID REFERENCES training_datasets(dataset_id) ON DELETE SET NULL,
+
+ -- Training metrics
+ metrics_mAP DOUBLE PRECISION,
+ metrics_precision DOUBLE PRECISION,
+ metrics_recall DOUBLE PRECISION,
+ document_count INTEGER NOT NULL DEFAULT 0,
+
+ -- Training configuration snapshot
+ training_config JSONB,
+
+ -- File info
+ file_size BIGINT,
+
+ -- Timestamps
+ trained_at TIMESTAMP WITH TIME ZONE,
+ activated_at TIMESTAMP WITH TIME ZONE,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
+);
+
+-- Indexes
+CREATE INDEX IF NOT EXISTS idx_model_versions_version ON model_versions(version);
+CREATE INDEX IF NOT EXISTS idx_model_versions_status ON model_versions(status);
+CREATE INDEX IF NOT EXISTS idx_model_versions_is_active ON model_versions(is_active);
+CREATE INDEX IF NOT EXISTS idx_model_versions_task_id ON model_versions(task_id);
+CREATE INDEX IF NOT EXISTS idx_model_versions_dataset_id ON model_versions(dataset_id);
+CREATE INDEX IF NOT EXISTS idx_model_versions_created ON model_versions(created_at);
+
+-- Ensure only one active model at a time
+CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
+ ON model_versions(is_active) WHERE is_active = TRUE;
+
+-- Comment
+COMMENT ON TABLE model_versions IS 'Trained model versions for inference deployment';
diff --git a/migrations/007_training_tasks_extra_columns.sql b/migrations/007_training_tasks_extra_columns.sql
new file mode 100644
index 0000000..b4ef0ae
--- /dev/null
+++ b/migrations/007_training_tasks_extra_columns.sql
@@ -0,0 +1,46 @@
+-- Add missing columns to training_tasks table
+
+-- Add name column
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
+UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
+ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
+
+-- Add description column
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
+
+-- Add admin_token column (for multi-tenant support)
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
+
+-- Add task_type column
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
+
+-- Add recurring schedule columns
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
+
+-- Add result metrics columns (for display without parsing JSONB)
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
+
+-- Rename metrics to config if exists
+DO $$
+BEGIN
+ IF EXISTS (SELECT 1 FROM information_schema.columns
+ WHERE table_name = 'training_tasks' AND column_name = 'metrics'
+ AND NOT EXISTS (SELECT 1 FROM information_schema.columns
+ WHERE table_name = 'training_tasks' AND column_name = 'config')) THEN
+ ALTER TABLE training_tasks RENAME COLUMN metrics TO config;
+ END IF;
+END $$;
+
+-- Add updated_at column
+ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
+
+-- Create index on name
+CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
+
+-- Create index on metrics_mAP
+CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
diff --git a/migrations/008_fix_model_versions_fk.sql b/migrations/008_fix_model_versions_fk.sql
new file mode 100644
index 0000000..7a4dfc8
--- /dev/null
+++ b/migrations/008_fix_model_versions_fk.sql
@@ -0,0 +1,14 @@
+-- Fix foreign key constraints on model_versions table to allow CASCADE delete
+
+-- Drop existing constraints
+ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
+ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
+
+-- Add constraints with ON DELETE SET NULL
+ALTER TABLE model_versions
+ADD CONSTRAINT model_versions_dataset_id_fkey
+FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
+
+ALTER TABLE model_versions
+ADD CONSTRAINT model_versions_task_id_fkey
+FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
diff --git a/packages/inference/inference/data/admin_db.py b/packages/inference/inference/data/admin_db.py
index 02f9d8c..9f9765d 100644
--- a/packages/inference/inference/data/admin_db.py
+++ b/packages/inference/inference/data/admin_db.py
@@ -25,6 +25,7 @@ from inference.data.admin_models import (
AnnotationHistory,
TrainingDataset,
DatasetDocument,
+ ModelVersion,
)
logger = logging.getLogger(__name__)
@@ -110,6 +111,7 @@ class AdminDB:
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
+ group_key: str | None = None,
admin_token: str | None = None, # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
+ group_key=group_key,
)
session.add(document)
session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
document.updated_at = datetime.utcnow()
session.add(document)
+ def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
+ """Update document group key."""
+ with get_session_context() as session:
+ document = session.get(AdminDocument, UUID(document_id))
+ if document:
+ document.group_key = group_key
+ document.updated_at = datetime.utcnow()
+ session.add(document)
+ return True
+ return False
+
def delete_document(self, document_id: str) -> bool:
"""Delete a document and its annotations."""
with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
session.expunge(d)
return list(datasets), total
+ def get_active_training_tasks_for_datasets(
+ self, dataset_ids: list[str]
+ ) -> dict[str, dict[str, str]]:
+ """Get active (pending/scheduled/running) training tasks for datasets.
+
+ Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
+ """
+ if not dataset_ids:
+ return {}
+
+ # Validate UUIDs before query
+ valid_uuids = []
+ for d in dataset_ids:
+ try:
+ valid_uuids.append(UUID(d))
+ except ValueError:
+ logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
+ continue
+
+ if not valid_uuids:
+ return {}
+
+ with get_session_context() as session:
+ statement = select(TrainingTask).where(
+ TrainingTask.dataset_id.in_(valid_uuids),
+ TrainingTask.status.in_(["pending", "scheduled", "running"]),
+ )
+ results = session.exec(statement).all()
+ return {
+ str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
+ for t in results
+ }
+
def update_dataset_status(
self,
dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
session.delete(dataset)
session.commit()
return True
+
+ # ==========================================================================
+ # Model Version Operations
+ # ==========================================================================
+
+ def create_model_version(
+ self,
+ version: str,
+ name: str,
+ model_path: str,
+ description: str | None = None,
+ task_id: str | UUID | None = None,
+ dataset_id: str | UUID | None = None,
+ metrics_mAP: float | None = None,
+ metrics_precision: float | None = None,
+ metrics_recall: float | None = None,
+ document_count: int = 0,
+ training_config: dict[str, Any] | None = None,
+ file_size: int | None = None,
+ trained_at: datetime | None = None,
+ ) -> ModelVersion:
+ """Create a new model version."""
+ with get_session_context() as session:
+ model = ModelVersion(
+ version=version,
+ name=name,
+ model_path=model_path,
+ description=description,
+ task_id=UUID(str(task_id)) if task_id else None,
+ dataset_id=UUID(str(dataset_id)) if dataset_id else None,
+ metrics_mAP=metrics_mAP,
+ metrics_precision=metrics_precision,
+ metrics_recall=metrics_recall,
+ document_count=document_count,
+ training_config=training_config,
+ file_size=file_size,
+ trained_at=trained_at,
+ )
+ session.add(model)
+ session.commit()
+ session.refresh(model)
+ session.expunge(model)
+ return model
+
+ def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
+ """Get a model version by ID."""
+ with get_session_context() as session:
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if model:
+ session.expunge(model)
+ return model
+
+ def get_model_versions(
+ self,
+ status: str | None = None,
+ limit: int = 20,
+ offset: int = 0,
+ ) -> tuple[list[ModelVersion], int]:
+ """List model versions with optional status filter."""
+ with get_session_context() as session:
+ query = select(ModelVersion)
+ count_query = select(func.count()).select_from(ModelVersion)
+ if status:
+ query = query.where(ModelVersion.status == status)
+ count_query = count_query.where(ModelVersion.status == status)
+ total = session.exec(count_query).one()
+ models = session.exec(
+ query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
+ ).all()
+ for m in models:
+ session.expunge(m)
+ return list(models), total
+
+ def get_active_model_version(self) -> ModelVersion | None:
+ """Get the currently active model version for inference."""
+ with get_session_context() as session:
+ result = session.exec(
+ select(ModelVersion).where(ModelVersion.is_active == True)
+ ).first()
+ if result:
+ session.expunge(result)
+ return result
+
+ def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
+ """Activate a model version for inference (deactivates all others)."""
+ with get_session_context() as session:
+ # Deactivate all versions
+ all_versions = session.exec(
+ select(ModelVersion).where(ModelVersion.is_active == True)
+ ).all()
+ for v in all_versions:
+ v.is_active = False
+ v.status = "inactive"
+ v.updated_at = datetime.utcnow()
+ session.add(v)
+
+ # Activate the specified version
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if not model:
+ return None
+ model.is_active = True
+ model.status = "active"
+ model.activated_at = datetime.utcnow()
+ model.updated_at = datetime.utcnow()
+ session.add(model)
+ session.commit()
+ session.refresh(model)
+ session.expunge(model)
+ return model
+
+ def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
+ """Deactivate a model version."""
+ with get_session_context() as session:
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if not model:
+ return None
+ model.is_active = False
+ model.status = "inactive"
+ model.updated_at = datetime.utcnow()
+ session.add(model)
+ session.commit()
+ session.refresh(model)
+ session.expunge(model)
+ return model
+
+ def update_model_version(
+ self,
+ version_id: str | UUID,
+ name: str | None = None,
+ description: str | None = None,
+ status: str | None = None,
+ ) -> ModelVersion | None:
+ """Update model version metadata."""
+ with get_session_context() as session:
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if not model:
+ return None
+ if name is not None:
+ model.name = name
+ if description is not None:
+ model.description = description
+ if status is not None:
+ model.status = status
+ model.updated_at = datetime.utcnow()
+ session.add(model)
+ session.commit()
+ session.refresh(model)
+ session.expunge(model)
+ return model
+
+ def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
+ """Archive a model version."""
+ with get_session_context() as session:
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if not model:
+ return None
+ # Cannot archive active model
+ if model.is_active:
+ return None
+ model.status = "archived"
+ model.updated_at = datetime.utcnow()
+ session.add(model)
+ session.commit()
+ session.refresh(model)
+ session.expunge(model)
+ return model
+
+ def delete_model_version(self, version_id: str | UUID) -> bool:
+ """Delete a model version."""
+ with get_session_context() as session:
+ model = session.get(ModelVersion, UUID(str(version_id)))
+ if not model:
+ return False
+ # Cannot delete active model
+ if model.is_active:
+ return False
+ session.delete(model)
+ session.commit()
+ return True
diff --git a/packages/inference/inference/data/admin_models.py b/packages/inference/inference/data/admin_models.py
index a374d07..ca7e5f2 100644
--- a/packages/inference/inference/data/admin_models.py
+++ b/packages/inference/inference/data/admin_models.py
@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
+ group_key: str | None = Field(default=None, max_length=255, index=True)
+ # User-defined grouping key for document organization
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow)
+# =============================================================================
+# Model Version Management
+# =============================================================================
+
+
+class ModelVersion(SQLModel, table=True):
+ """Model version for inference deployment."""
+
+ __tablename__ = "model_versions"
+
+ version_id: UUID = Field(default_factory=uuid4, primary_key=True)
+ version: str = Field(max_length=50, index=True)
+ # Semantic version e.g., "1.0.0", "2.1.0"
+ name: str = Field(max_length=255)
+ description: str | None = Field(default=None)
+ model_path: str = Field(max_length=512)
+ # Path to the model weights file
+ status: str = Field(default="inactive", max_length=20, index=True)
+ # Status: active, inactive, archived
+ is_active: bool = Field(default=False, index=True)
+ # Only one version can be active at a time for inference
+
+ # Training association
+ task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
+ dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
+
+ # Training metrics
+ metrics_mAP: float | None = Field(default=None)
+ metrics_precision: float | None = Field(default=None)
+ metrics_recall: float | None = Field(default=None)
+ document_count: int = Field(default=0)
+ # Number of documents used in training
+
+ # Training configuration snapshot
+ training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
+ # Snapshot of epochs, batch_size, etc.
+
+ # File info
+ file_size: int | None = Field(default=None)
+ # Model file size in bytes
+
+ # Timestamps
+ trained_at: datetime | None = Field(default=None)
+ # When training completed
+ activated_at: datetime | None = Field(default=None)
+ # When this version was last activated
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+
# =============================================================================
# Annotation History (v2)
# =============================================================================
diff --git a/packages/inference/inference/data/database.py b/packages/inference/inference/data/database.py
index 7613b6f..15b4c14 100644
--- a/packages/inference/inference/data/database.py
+++ b/packages/inference/inference/data/database.py
@@ -49,6 +49,111 @@ def get_engine():
return _engine
+def run_migrations() -> None:
+ """Run database migrations for new columns."""
+ engine = get_engine()
+
+ migrations = [
+ # Migration 004: Training datasets tables and dataset_id on training_tasks
+ (
+ "training_datasets_tables",
+ """
+ CREATE TABLE IF NOT EXISTS training_datasets (
+ dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+ name VARCHAR(255) NOT NULL,
+ description TEXT,
+ status VARCHAR(20) NOT NULL DEFAULT 'building',
+ train_ratio FLOAT NOT NULL DEFAULT 0.8,
+ val_ratio FLOAT NOT NULL DEFAULT 0.1,
+ seed INTEGER NOT NULL DEFAULT 42,
+ total_documents INTEGER NOT NULL DEFAULT 0,
+ total_images INTEGER NOT NULL DEFAULT 0,
+ total_annotations INTEGER NOT NULL DEFAULT 0,
+ dataset_path VARCHAR(512),
+ error_message TEXT,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
+ );
+ CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
+ """,
+ ),
+ (
+ "dataset_documents_table",
+ """
+ CREATE TABLE IF NOT EXISTS dataset_documents (
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+ dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
+ document_id UUID NOT NULL REFERENCES admin_documents(document_id),
+ split VARCHAR(10) NOT NULL,
+ page_count INTEGER NOT NULL DEFAULT 0,
+ annotation_count INTEGER NOT NULL DEFAULT 0,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ UNIQUE(dataset_id, document_id)
+ );
+ CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
+ CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
+ """,
+ ),
+ (
+ "training_tasks_dataset_id",
+ """
+ ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
+ CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
+ """,
+ ),
+ # Migration 005: Add group_key to admin_documents
+ (
+ "admin_documents_group_key",
+ """
+ ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
+ CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
+ """,
+ ),
+ # Migration 006: Model versions table
+ (
+ "model_versions_table",
+ """
+ CREATE TABLE IF NOT EXISTS model_versions (
+ version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+ version VARCHAR(50) NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ description TEXT,
+ model_path VARCHAR(512) NOT NULL,
+ status VARCHAR(20) NOT NULL DEFAULT 'inactive',
+ is_active BOOLEAN NOT NULL DEFAULT FALSE,
+ task_id UUID REFERENCES training_tasks(task_id),
+ dataset_id UUID REFERENCES training_datasets(dataset_id),
+ metrics_mAP FLOAT,
+ metrics_precision FLOAT,
+ metrics_recall FLOAT,
+ document_count INTEGER NOT NULL DEFAULT 0,
+ training_config JSONB,
+ file_size BIGINT,
+ trained_at TIMESTAMP WITH TIME ZONE,
+ activated_at TIMESTAMP WITH TIME ZONE,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
+ );
+ CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
+ CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
+ CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
+ CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
+ CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
+ """,
+ ),
+ ]
+
+ with engine.connect() as conn:
+ for name, sql in migrations:
+ try:
+ conn.execute(text(sql))
+ conn.commit()
+ logger.info(f"Migration '{name}' applied successfully")
+ except Exception as e:
+ # Log but don't fail - column may already exist
+ logger.debug(f"Migration '{name}' skipped or failed: {e}")
+
+
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
+ # Run migrations for new columns
+ run_migrations()
+
def get_session() -> Session:
"""Get a new database session."""
diff --git a/packages/inference/inference/web/api/v1/admin/__init__.py b/packages/inference/inference/web/api/v1/admin/__init__.py
index 8d5081b..ce7a5d0 100644
--- a/packages/inference/inference/web/api/v1/admin/__init__.py
+++ b/packages/inference/inference/web/api/v1/admin/__init__.py
@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
+from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
+ "create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
diff --git a/packages/inference/inference/web/api/v1/admin/augmentation/__init__.py b/packages/inference/inference/web/api/v1/admin/augmentation/__init__.py
new file mode 100644
index 0000000..185e649
--- /dev/null
+++ b/packages/inference/inference/web/api/v1/admin/augmentation/__init__.py
@@ -0,0 +1,15 @@
+"""Augmentation API module."""
+
+from fastapi import APIRouter
+
+from .routes import register_augmentation_routes
+
+
+def create_augmentation_router() -> APIRouter:
+ """Create and configure the augmentation router."""
+ router = APIRouter(prefix="/augmentation", tags=["augmentation"])
+ register_augmentation_routes(router)
+ return router
+
+
+__all__ = ["create_augmentation_router"]
diff --git a/packages/inference/inference/web/api/v1/admin/augmentation/routes.py b/packages/inference/inference/web/api/v1/admin/augmentation/routes.py
new file mode 100644
index 0000000..fbf6e3e
--- /dev/null
+++ b/packages/inference/inference/web/api/v1/admin/augmentation/routes.py
@@ -0,0 +1,162 @@
+"""Augmentation API routes."""
+
+from typing import Annotated
+
+from fastapi import APIRouter, HTTPException, Query
+
+from inference.web.core.auth import AdminDBDep, AdminTokenDep
+from inference.web.schemas.admin.augmentation import (
+ AugmentationBatchRequest,
+ AugmentationBatchResponse,
+ AugmentationConfigSchema,
+ AugmentationPreviewRequest,
+ AugmentationPreviewResponse,
+ AugmentationTypeInfo,
+ AugmentationTypesResponse,
+ AugmentedDatasetItem,
+ AugmentedDatasetListResponse,
+ PresetInfo,
+ PresetsResponse,
+)
+
+
+def register_augmentation_routes(router: APIRouter) -> None:
+ """Register augmentation endpoints on the router."""
+
+ @router.get(
+ "/types",
+ response_model=AugmentationTypesResponse,
+ summary="List available augmentation types",
+ )
+ async def list_augmentation_types(
+ admin_token: AdminTokenDep,
+ ) -> AugmentationTypesResponse:
+ """
+ List all available augmentation types with descriptions and parameters.
+ """
+ from shared.augmentation.pipeline import (
+ AUGMENTATION_REGISTRY,
+ AugmentationPipeline,
+ )
+
+ types = []
+ for name, aug_class in AUGMENTATION_REGISTRY.items():
+ # Create instance with empty params to get preview params
+ aug = aug_class({})
+ types.append(
+ AugmentationTypeInfo(
+ name=name,
+ description=(aug_class.__doc__ or "").strip(),
+ affects_geometry=aug_class.affects_geometry,
+ stage=AugmentationPipeline.STAGE_MAPPING[name],
+ default_params=aug.get_preview_params(),
+ )
+ )
+
+ return AugmentationTypesResponse(augmentation_types=types)
+
+ @router.get(
+ "/presets",
+ response_model=PresetsResponse,
+ summary="Get augmentation presets",
+ )
+ async def get_presets(
+ admin_token: AdminTokenDep,
+ ) -> PresetsResponse:
+ """Get predefined augmentation presets for common use cases."""
+ from shared.augmentation.presets import list_presets
+
+ presets = [PresetInfo(**p) for p in list_presets()]
+ return PresetsResponse(presets=presets)
+
+ @router.post(
+ "/preview/{document_id}",
+ response_model=AugmentationPreviewResponse,
+ summary="Preview augmentation on document image",
+ )
+ async def preview_augmentation(
+ document_id: str,
+ request: AugmentationPreviewRequest,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ page: int = Query(default=1, ge=1, description="Page number"),
+ ) -> AugmentationPreviewResponse:
+ """
+ Preview a single augmentation on a document page.
+
+ Returns URLs to original and augmented preview images.
+ """
+ from inference.web.services.augmentation_service import AugmentationService
+
+ service = AugmentationService(db=db)
+ return await service.preview_single(
+ document_id=document_id,
+ page=page,
+ augmentation_type=request.augmentation_type,
+ params=request.params,
+ )
+
+ @router.post(
+ "/preview-config/{document_id}",
+ response_model=AugmentationPreviewResponse,
+ summary="Preview full augmentation config on document",
+ )
+ async def preview_config(
+ document_id: str,
+ config: AugmentationConfigSchema,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ page: int = Query(default=1, ge=1, description="Page number"),
+ ) -> AugmentationPreviewResponse:
+ """Preview complete augmentation pipeline on a document page."""
+ from inference.web.services.augmentation_service import AugmentationService
+
+ service = AugmentationService(db=db)
+ return await service.preview_config(
+ document_id=document_id,
+ page=page,
+ config=config,
+ )
+
+ @router.post(
+ "/batch",
+ response_model=AugmentationBatchResponse,
+ summary="Create augmented dataset (offline preprocessing)",
+ )
+ async def create_augmented_dataset(
+ request: AugmentationBatchRequest,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> AugmentationBatchResponse:
+ """
+ Create a new augmented dataset from an existing dataset.
+
+ This runs as a background task. The augmented images are stored
+ alongside the original dataset for training.
+ """
+ from inference.web.services.augmentation_service import AugmentationService
+
+ service = AugmentationService(db=db)
+ return await service.create_augmented_dataset(
+ source_dataset_id=request.dataset_id,
+ config=request.config,
+ output_name=request.output_name,
+ multiplier=request.multiplier,
+ )
+
+ @router.get(
+ "/datasets",
+ response_model=AugmentedDatasetListResponse,
+ summary="List augmented datasets",
+ )
+ async def list_augmented_datasets(
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ limit: int = Query(default=20, ge=1, le=100, description="Page size"),
+ offset: int = Query(default=0, ge=0, description="Offset"),
+ ) -> AugmentedDatasetListResponse:
+ """List all augmented datasets."""
+ from inference.web.services.augmentation_service import AugmentationService
+
+ service = AugmentationService(db=db)
+ return await service.list_augmented_datasets(limit=limit, offset=offset)
diff --git a/packages/inference/inference/web/api/v1/admin/documents.py b/packages/inference/inference/web/api/v1/admin/documents.py
index fd2f355..f78db66 100644
--- a/packages/inference/inference/web/api/v1/admin/documents.py
+++ b/packages/inference/inference/web/api/v1/admin/documents.py
@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
+ group_key: Annotated[
+ str | None,
+ Query(description="Optional group key for document organization", max_length=255),
+ ] = None,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
+ # Validate group_key length
+ if group_key and len(group_key) > 255:
+ raise HTTPException(
+ status_code=400,
+ detail="Group key must be 255 characters or less",
+ )
+
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
+ group_key=group_key,
)
# Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
+ group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
+ group_key=doc.group_key if hasattr(doc, 'group_key') else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
+ group_key=document.group_key if hasattr(document, 'group_key') else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
return response
+ @router.patch(
+ "/{document_id}/group-key",
+ responses={
+ 401: {"model": ErrorResponse, "description": "Invalid token"},
+ 404: {"model": ErrorResponse, "description": "Document not found"},
+ },
+ summary="Update document group key",
+ description="Update the group key for a document.",
+ )
+ async def update_document_group_key(
+ document_id: str,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ group_key: Annotated[
+ str | None,
+ Query(description="New group key (null to clear)"),
+ ] = None,
+ ) -> dict:
+ """Update document group key."""
+ _validate_uuid(document_id, "document_id")
+
+ # Validate group_key length
+ if group_key and len(group_key) > 255:
+ raise HTTPException(
+ status_code=400,
+ detail="Group key must be 255 characters or less",
+ )
+
+ # Verify document exists
+ document = db.get_document_by_token(document_id, admin_token)
+ if document is None:
+ raise HTTPException(
+ status_code=404,
+ detail="Document not found or does not belong to this token",
+ )
+
+ # Update group key
+ db.update_document_group_key(document_id, group_key)
+
+ return {
+ "status": "updated",
+ "document_id": document_id,
+ "group_key": group_key,
+ "message": "Document group key updated",
+ }
+
return router
diff --git a/packages/inference/inference/web/api/v1/admin/training/__init__.py b/packages/inference/inference/web/api/v1/admin/training/__init__.py
index d2fba3c..cde7547 100644
--- a/packages/inference/inference/web/api/v1/admin/training/__init__.py
+++ b/packages/inference/inference/web/api/v1/admin/training/__init__.py
@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
+from .models import register_model_routes
def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
+ register_model_routes(router)
return router
diff --git a/packages/inference/inference/web/api/v1/admin/training/datasets.py b/packages/inference/inference/web/api/v1/admin/training/datasets.py
index a46c4b3..bf93239 100644
--- a/packages/inference/inference/web/api/v1/admin/training/datasets.py
+++ b/packages/inference/inference/web/api/v1/admin/training/datasets.py
@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
+ # Validate minimum document count for proper train/val/test split
+ if len(request.document_ids) < 10:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
+ )
+
dataset = db.create_dataset(
name=request.name,
description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
+
+ # Get active training tasks for each dataset (graceful degradation on error)
+ dataset_ids = [str(d.dataset_id) for d in datasets]
+ try:
+ active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
+ except Exception:
+ logger.exception("Failed to fetch active training tasks")
+ active_tasks = {}
+
return DatasetListResponse(
total=total,
limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
name=d.name,
description=d.description,
status=d.status,
+ training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
+ active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
+ description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
- """Create a training task from a dataset."""
+ """Create a training task from a dataset.
+
+ For incremental training, set config.base_model_version_id to a model version UUID.
+ The training will use that model as the starting point instead of a pretrained model.
+ """
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
)
config_dict = request.config.model_dump()
+
+ # Resolve base_model_version_id to actual model path for incremental training
+ base_model_version_id = config_dict.get("base_model_version_id")
+ if base_model_version_id:
+ _validate_uuid(base_model_version_id, "base_model_version_id")
+ base_model = db.get_model_version(base_model_version_id)
+ if not base_model:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Base model version not found: {base_model_version_id}",
+ )
+ # Store the resolved model path for the training worker
+ config_dict["base_model_path"] = base_model.model_path
+ config_dict["base_model_version"] = base_model.version
+ logger.info(
+ "Incremental training: using model %s (%s) as base",
+ base_model.version,
+ base_model.model_path,
+ )
+
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
- task_type="train",
+ task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
+ message = (
+ f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
+ if base_model_version_id
+ else "Training task created from dataset"
+ )
+
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
- message="Training task created from dataset",
+ message=message,
)
diff --git a/packages/inference/inference/web/api/v1/admin/training/documents.py b/packages/inference/inference/web/api/v1/admin/training/documents.py
index 27e935a..18e9e7d 100644
--- a/packages/inference/inference/web/api/v1/admin/training/documents.py
+++ b/packages/inference/inference/web/api/v1/admin/training/documents.py
@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
)
@router.get(
- "/models",
+ "/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
- summary="Get trained models",
- description="Get list of trained models with metrics and download links.",
+ summary="Get completed training tasks",
+ description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
- async def get_training_models(
+ async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
diff --git a/packages/inference/inference/web/api/v1/admin/training/models.py b/packages/inference/inference/web/api/v1/admin/training/models.py
new file mode 100644
index 0000000..fcbb64b
--- /dev/null
+++ b/packages/inference/inference/web/api/v1/admin/training/models.py
@@ -0,0 +1,333 @@
+"""Model Version Endpoints."""
+
+import logging
+from typing import Annotated
+
+from fastapi import APIRouter, HTTPException, Query, Request
+
+from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.schemas.admin import (
+ ModelVersionCreateRequest,
+ ModelVersionUpdateRequest,
+ ModelVersionItem,
+ ModelVersionListResponse,
+ ModelVersionDetailResponse,
+ ModelVersionResponse,
+ ActiveModelResponse,
+)
+
+from ._utils import _validate_uuid
+
+logger = logging.getLogger(__name__)
+
+
+def register_model_routes(router: APIRouter) -> None:
+ """Register model version endpoints on the router."""
+
+ @router.post(
+ "/models",
+ response_model=ModelVersionResponse,
+ summary="Create model version",
+ description="Register a new model version for deployment.",
+ )
+ async def create_model_version(
+ request: ModelVersionCreateRequest,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionResponse:
+ """Create a new model version."""
+ if request.task_id:
+ _validate_uuid(request.task_id, "task_id")
+ if request.dataset_id:
+ _validate_uuid(request.dataset_id, "dataset_id")
+
+ model = db.create_model_version(
+ version=request.version,
+ name=request.name,
+ model_path=request.model_path,
+ description=request.description,
+ task_id=request.task_id,
+ dataset_id=request.dataset_id,
+ metrics_mAP=request.metrics_mAP,
+ metrics_precision=request.metrics_precision,
+ metrics_recall=request.metrics_recall,
+ document_count=request.document_count,
+ training_config=request.training_config,
+ file_size=request.file_size,
+ trained_at=request.trained_at,
+ )
+
+ return ModelVersionResponse(
+ version_id=str(model.version_id),
+ status=model.status,
+ message="Model version created successfully",
+ )
+
+ @router.get(
+ "/models",
+ response_model=ModelVersionListResponse,
+ summary="List model versions",
+ )
+ async def list_model_versions(
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ status: Annotated[str | None, Query(description="Filter by status")] = None,
+ limit: Annotated[int, Query(ge=1, le=100)] = 20,
+ offset: Annotated[int, Query(ge=0)] = 0,
+ ) -> ModelVersionListResponse:
+ """List model versions with optional status filter."""
+ models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
+ return ModelVersionListResponse(
+ total=total,
+ limit=limit,
+ offset=offset,
+ models=[
+ ModelVersionItem(
+ version_id=str(m.version_id),
+ version=m.version,
+ name=m.name,
+ status=m.status,
+ is_active=m.is_active,
+ metrics_mAP=m.metrics_mAP,
+ document_count=m.document_count,
+ trained_at=m.trained_at,
+ activated_at=m.activated_at,
+ created_at=m.created_at,
+ )
+ for m in models
+ ],
+ )
+
+ @router.get(
+ "/models/active",
+ response_model=ActiveModelResponse,
+ summary="Get active model",
+ description="Get the currently active model for inference.",
+ )
+ async def get_active_model(
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ActiveModelResponse:
+ """Get the currently active model version."""
+ model = db.get_active_model_version()
+ if not model:
+ return ActiveModelResponse(has_active_model=False, model=None)
+
+ return ActiveModelResponse(
+ has_active_model=True,
+ model=ModelVersionItem(
+ version_id=str(model.version_id),
+ version=model.version,
+ name=model.name,
+ status=model.status,
+ is_active=model.is_active,
+ metrics_mAP=model.metrics_mAP,
+ document_count=model.document_count,
+ trained_at=model.trained_at,
+ activated_at=model.activated_at,
+ created_at=model.created_at,
+ ),
+ )
+
+ @router.get(
+ "/models/{version_id}",
+ response_model=ModelVersionDetailResponse,
+ summary="Get model version detail",
+ )
+ async def get_model_version(
+ version_id: str,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionDetailResponse:
+ """Get detailed model version information."""
+ _validate_uuid(version_id, "version_id")
+ model = db.get_model_version(version_id)
+ if not model:
+ raise HTTPException(status_code=404, detail="Model version not found")
+
+ return ModelVersionDetailResponse(
+ version_id=str(model.version_id),
+ version=model.version,
+ name=model.name,
+ description=model.description,
+ model_path=model.model_path,
+ status=model.status,
+ is_active=model.is_active,
+ task_id=str(model.task_id) if model.task_id else None,
+ dataset_id=str(model.dataset_id) if model.dataset_id else None,
+ metrics_mAP=model.metrics_mAP,
+ metrics_precision=model.metrics_precision,
+ metrics_recall=model.metrics_recall,
+ document_count=model.document_count,
+ training_config=model.training_config,
+ file_size=model.file_size,
+ trained_at=model.trained_at,
+ activated_at=model.activated_at,
+ created_at=model.created_at,
+ updated_at=model.updated_at,
+ )
+
+ @router.patch(
+ "/models/{version_id}",
+ response_model=ModelVersionResponse,
+ summary="Update model version",
+ )
+ async def update_model_version(
+ version_id: str,
+ request: ModelVersionUpdateRequest,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionResponse:
+ """Update model version metadata."""
+ _validate_uuid(version_id, "version_id")
+ model = db.update_model_version(
+ version_id=version_id,
+ name=request.name,
+ description=request.description,
+ status=request.status,
+ )
+ if not model:
+ raise HTTPException(status_code=404, detail="Model version not found")
+
+ return ModelVersionResponse(
+ version_id=str(model.version_id),
+ status=model.status,
+ message="Model version updated successfully",
+ )
+
+ @router.post(
+ "/models/{version_id}/activate",
+ response_model=ModelVersionResponse,
+ summary="Activate model version",
+ description="Activate a model version for inference (deactivates all others).",
+ )
+ async def activate_model_version(
+ version_id: str,
+ request: Request,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionResponse:
+ """Activate a model version for inference."""
+ _validate_uuid(version_id, "version_id")
+ model = db.activate_model_version(version_id)
+ if not model:
+ raise HTTPException(status_code=404, detail="Model version not found")
+
+ # Trigger model reload in inference service
+ inference_service = getattr(request.app.state, "inference_service", None)
+ model_reloaded = False
+ if inference_service:
+ try:
+ model_reloaded = inference_service.reload_model()
+ if model_reloaded:
+ logger.info(f"Inference model reloaded to version {model.version}")
+ except Exception as e:
+ logger.warning(f"Failed to reload inference model: {e}")
+
+ message = "Model version activated for inference"
+ if model_reloaded:
+ message += " (model reloaded)"
+
+ return ModelVersionResponse(
+ version_id=str(model.version_id),
+ status=model.status,
+ message=message,
+ )
+
+ @router.post(
+ "/models/{version_id}/deactivate",
+ response_model=ModelVersionResponse,
+ summary="Deactivate model version",
+ )
+ async def deactivate_model_version(
+ version_id: str,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionResponse:
+ """Deactivate a model version."""
+ _validate_uuid(version_id, "version_id")
+ model = db.deactivate_model_version(version_id)
+ if not model:
+ raise HTTPException(status_code=404, detail="Model version not found")
+
+ return ModelVersionResponse(
+ version_id=str(model.version_id),
+ status=model.status,
+ message="Model version deactivated",
+ )
+
+ @router.post(
+ "/models/{version_id}/archive",
+ response_model=ModelVersionResponse,
+ summary="Archive model version",
+ )
+ async def archive_model_version(
+ version_id: str,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> ModelVersionResponse:
+ """Archive a model version."""
+ _validate_uuid(version_id, "version_id")
+ model = db.archive_model_version(version_id)
+ if not model:
+ raise HTTPException(
+ status_code=400,
+ detail="Model version not found or cannot archive active model",
+ )
+
+ return ModelVersionResponse(
+ version_id=str(model.version_id),
+ status=model.status,
+ message="Model version archived",
+ )
+
+ @router.delete(
+ "/models/{version_id}",
+ summary="Delete model version",
+ )
+ async def delete_model_version(
+ version_id: str,
+ admin_token: AdminTokenDep,
+ db: AdminDBDep,
+ ) -> dict:
+ """Delete a model version."""
+ _validate_uuid(version_id, "version_id")
+ success = db.delete_model_version(version_id)
+ if not success:
+ raise HTTPException(
+ status_code=400,
+ detail="Model version not found or cannot delete active model",
+ )
+
+ return {"message": "Model version deleted"}
+
+ @router.post(
+ "/models/reload",
+ summary="Reload inference model",
+ description="Reload the inference model from the currently active model version.",
+ )
+ async def reload_inference_model(
+ request: Request,
+ admin_token: AdminTokenDep,
+ ) -> dict:
+ """Reload the inference model from active version."""
+ inference_service = getattr(request.app.state, "inference_service", None)
+ if not inference_service:
+ raise HTTPException(
+ status_code=500,
+ detail="Inference service not available",
+ )
+
+ try:
+ model_reloaded = inference_service.reload_model()
+ if model_reloaded:
+ logger.info("Inference model manually reloaded")
+ return {"message": "Model reloaded successfully", "reloaded": True}
+ else:
+ return {"message": "Model already up to date", "reloaded": False}
+ except Exception as e:
+ logger.error(f"Failed to reload model: {e}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to reload model: {e}",
+ )
diff --git a/packages/inference/inference/web/app.py b/packages/inference/inference/web/app.py
index 2cfbfb6..f14e259 100644
--- a/packages/inference/inference/web/app.py
+++ b/packages/inference/inference/web/app.py
@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
+ create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""
config = config or default_config
- # Create inference service
+ # Create model path resolver that reads from database
+ def get_active_model_path():
+ """Resolve active model path from database."""
+ try:
+ db = AdminDB()
+ active_model = db.get_active_model_version()
+ if active_model and active_model.model_path:
+ return active_model.model_path
+ except Exception as e:
+ logger.warning(f"Failed to get active model from database: {e}")
+ return None
+
+ # Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
+ model_path_resolver=get_active_model_path,
)
# Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
+ # Store inference service for access by routes (e.g., model reload)
+ # This will be set after app creation
+
app = FastAPI(
title="Invoice Field Extraction API",
description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
+ augmentation_router = create_augmentation_router()
+ app.include_router(augmentation_router, prefix="/api/v1/admin")
+
# Include batch upload routes
app.include_router(batch_upload_router)
+ # Store inference service in app state for access by routes
+ app.state.inference_service = inference_service
+
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
diff --git a/packages/inference/inference/web/core/scheduler.py b/packages/inference/inference/web/core/scheduler.py
index ec36469..7ece72b 100644
--- a/packages/inference/inference/web/core/scheduler.py
+++ b/packages/inference/inference/web/core/scheduler.py
@@ -110,6 +110,7 @@ class TrainingScheduler:
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
+ base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@ class TrainingScheduler:
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
+ # Get augmentation config if present
+ augmentation_config = config.get("augmentation")
+ augmentation_multiplier = config.get("augmentation_multiplier", 2)
+
+ # Determine which model to use as base
+ if base_model_path:
+ # Incremental training: use existing trained model
+ if not Path(base_model_path).exists():
+ raise ValueError(f"Base model not found: {base_model_path}")
+ effective_model = base_model_path
+ self._db.add_training_log(
+ task_id, "INFO",
+ f"Incremental training from: {base_model_path}",
+ )
+ else:
+ # Train from pretrained model
+ effective_model = model_name
+
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
+ dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@ class TrainingScheduler:
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
+ dataset_path = Path(data_yaml).parent
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
+ # Apply augmentation if config is provided
+ if augmentation_config and self._has_enabled_augmentations(augmentation_config):
+ aug_result = self._apply_augmentation(
+ task_id, dataset_path, augmentation_config, augmentation_multiplier
+ )
+ if aug_result:
+ self._db.add_training_log(
+ task_id, "INFO",
+ f"Augmentation complete: {aug_result['augmented_images']} new images "
+ f"(total: {aug_result['total_images']})",
+ )
+
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
- model_name=model_name,
+ model_name=effective_model, # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -159,11 +192,94 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
+ # Auto-create model version for the completed training
+ self._create_model_version_from_training(
+ task_id=task_id,
+ config=config,
+ dataset_id=dataset_id,
+ result=result,
+ )
+
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
+ def _create_model_version_from_training(
+ self,
+ task_id: str,
+ config: dict[str, Any],
+ dataset_id: str | None,
+ result: dict[str, Any],
+ ) -> None:
+ """Create a model version entry from completed training."""
+ try:
+ model_path = result.get("model_path")
+ if not model_path:
+ logger.warning(f"No model path in training result for task {task_id}")
+ return
+
+ # Get task info for name
+ task = self._db.get_training_task(task_id)
+ task_name = task.name if task else f"Task {task_id[:8]}"
+
+ # Generate version number based on existing versions
+ existing_versions = self._db.get_model_versions(limit=1, offset=0)
+ version_count = existing_versions[1] if existing_versions else 0
+ version = f"v{version_count + 1}.0"
+
+ # Extract metrics from result
+ metrics = result.get("metrics", {})
+ metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
+ metrics_precision = metrics.get("precision")
+ metrics_recall = metrics.get("recall")
+
+ # Get file size if possible
+ file_size = None
+ model_file = Path(model_path)
+ if model_file.exists():
+ file_size = model_file.stat().st_size
+
+ # Get document count from dataset if available
+ document_count = 0
+ if dataset_id:
+ dataset = self._db.get_dataset(dataset_id)
+ if dataset:
+ document_count = dataset.total_documents
+
+ # Create model version
+ model_version = self._db.create_model_version(
+ version=version,
+ name=task_name,
+ model_path=str(model_path),
+ description=f"Auto-created from training task {task_id[:8]}",
+ task_id=task_id,
+ dataset_id=dataset_id,
+ metrics_mAP=metrics_mAP,
+ metrics_precision=metrics_precision,
+ metrics_recall=metrics_recall,
+ document_count=document_count,
+ training_config=config,
+ file_size=file_size,
+ trained_at=datetime.utcnow(),
+ )
+
+ logger.info(
+ f"Created model version {version} (ID: {model_version.version_id}) "
+ f"from training task {task_id}"
+ )
+ self._db.add_training_log(
+ task_id, "INFO",
+ f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to create model version for task {task_id}: {e}")
+ self._db.add_training_log(
+ task_id, "WARNING",
+ f"Failed to auto-create model version: {e}",
+ )
+
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
device: str,
project_name: str,
) -> dict[str, Any]:
- """Run YOLO training."""
+ """Run YOLO training using shared trainer."""
+ from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
+
+ # Create log callback that writes to DB
+ def log_callback(level: str, message: str) -> None:
+ self._db.add_training_log(task_id, level, message)
+
+ # Create shared training config
+ # Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
+ config = SharedTrainingConfig(
+ model_path=model_name,
+ data_yaml=data_yaml,
+ epochs=epochs,
+ batch_size=batch_size,
+ image_size=image_size,
+ learning_rate=learning_rate,
+ device=device,
+ project="runs/train",
+ name=f"{project_name}/task_{task_id[:8]}",
+ workers=0,
+ )
+
+ # Run training using shared trainer
+ trainer = YOLOTrainer(config=config, log_callback=log_callback)
+ result = trainer.train()
+
+ if not result.success:
+ raise ValueError(result.error or "Training failed")
+
+ return {
+ "model_path": result.model_path,
+ "metrics": result.metrics,
+ }
+
+ def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
+ """Check if any augmentations are enabled in the config."""
+ augmentation_fields = [
+ "perspective_warp", "wrinkle", "edge_damage", "stain",
+ "lighting_variation", "shadow", "gaussian_blur", "motion_blur",
+ "gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
+ ]
+ for field in augmentation_fields:
+ if field in aug_config:
+ field_config = aug_config[field]
+ if isinstance(field_config, dict) and field_config.get("enabled", False):
+ return True
+ return False
+
+ def _apply_augmentation(
+ self,
+ task_id: str,
+ dataset_path: Path,
+ aug_config: dict[str, Any],
+ multiplier: int,
+ ) -> dict[str, int] | None:
+ """Apply augmentation to dataset before training."""
try:
- from ultralytics import YOLO
-
- # Log training start
- self._db.add_training_log(
- task_id, "INFO",
- f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
- )
-
- # Load model
- model = YOLO(model_name)
-
- # Train
- results = model.train(
- data=data_yaml,
- epochs=epochs,
- batch=batch_size,
- imgsz=image_size,
- lr0=learning_rate,
- device=device,
- project=f"runs/train/{project_name}",
- name=f"task_{task_id[:8]}",
- exist_ok=True,
- verbose=True,
- )
-
- # Get best model path
- best_model = Path(results.save_dir) / "weights" / "best.pt"
-
- # Extract metrics
- metrics = {}
- if hasattr(results, "results_dict"):
- metrics = {
- "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
- "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
- "precision": results.results_dict.get("metrics/precision(B)", 0),
- "recall": results.results_dict.get("metrics/recall(B)", 0),
- }
+ from shared.augmentation import DatasetAugmenter
self._db.add_training_log(
task_id, "INFO",
- f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
+ f"Applying augmentation with multiplier={multiplier}",
)
- return {
- "model_path": str(best_model) if best_model.exists() else None,
- "metrics": metrics,
- }
+ augmenter = DatasetAugmenter(aug_config)
+ result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
+
+ return result
- except ImportError:
- self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
- raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
- self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
- raise
+ logger.error(f"Augmentation failed for task {task_id}: {e}")
+ self._db.add_training_log(
+ task_id, "WARNING",
+ f"Augmentation failed: {e}. Continuing with original dataset.",
+ )
+ return None
# Global scheduler instance
diff --git a/packages/inference/inference/web/schemas/admin/__init__.py b/packages/inference/inference/web/schemas/admin/__init__.py
index 1300b4e..ca4d999 100644
--- a/packages/inference/inference/web/schemas/admin/__init__.py
+++ b/packages/inference/inference/web/schemas/admin/__init__.py
@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
+from .models import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
diff --git a/packages/inference/inference/web/schemas/admin/augmentation.py b/packages/inference/inference/web/schemas/admin/augmentation.py
new file mode 100644
index 0000000..3cd638d
--- /dev/null
+++ b/packages/inference/inference/web/schemas/admin/augmentation.py
@@ -0,0 +1,187 @@
+"""Admin Augmentation Schemas."""
+
+from datetime import datetime
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class AugmentationParamsSchema(BaseModel):
+ """Single augmentation parameters."""
+
+ enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
+ probability: float = Field(
+ default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
+ )
+ params: dict[str, Any] = Field(
+ default_factory=dict, description="Type-specific parameters"
+ )
+
+
+class AugmentationConfigSchema(BaseModel):
+ """Complete augmentation configuration."""
+
+ # Geometric transforms
+ perspective_warp: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+
+ # Degradation effects
+ wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
+ edge_damage: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+ stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
+
+ # Lighting effects
+ lighting_variation: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+ shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
+
+ # Blur effects
+ gaussian_blur: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+ motion_blur: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+
+ # Noise effects
+ gaussian_noise: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+ salt_pepper: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+
+ # Texture effects
+ paper_texture: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+ scanner_artifacts: AugmentationParamsSchema = Field(
+ default_factory=AugmentationParamsSchema
+ )
+
+ # Global settings
+ preserve_bboxes: bool = Field(
+ default=True, description="Whether to adjust bboxes for geometric transforms"
+ )
+ seed: int | None = Field(default=None, description="Random seed for reproducibility")
+
+
+class AugmentationTypeInfo(BaseModel):
+ """Information about an augmentation type."""
+
+ name: str = Field(..., description="Augmentation name")
+ description: str = Field(..., description="Augmentation description")
+ affects_geometry: bool = Field(
+ ..., description="Whether this augmentation affects bbox coordinates"
+ )
+ stage: str = Field(..., description="Processing stage")
+ default_params: dict[str, Any] = Field(
+ default_factory=dict, description="Default parameters"
+ )
+
+
+class AugmentationTypesResponse(BaseModel):
+ """Response for listing augmentation types."""
+
+ augmentation_types: list[AugmentationTypeInfo] = Field(
+ ..., description="Available augmentation types"
+ )
+
+
+class PresetInfo(BaseModel):
+ """Information about a preset."""
+
+ name: str = Field(..., description="Preset name")
+ description: str = Field(..., description="Preset description")
+
+
+class PresetsResponse(BaseModel):
+ """Response for listing presets."""
+
+ presets: list[PresetInfo] = Field(..., description="Available presets")
+
+
+class AugmentationPreviewRequest(BaseModel):
+ """Request to preview augmentation on an image."""
+
+ augmentation_type: str = Field(..., description="Type of augmentation to preview")
+ params: dict[str, Any] = Field(
+ default_factory=dict, description="Override parameters"
+ )
+
+
+class AugmentationPreviewResponse(BaseModel):
+ """Response with preview image data."""
+
+ preview_url: str = Field(..., description="URL to preview image")
+ original_url: str = Field(..., description="URL to original image")
+ applied_params: dict[str, Any] = Field(..., description="Applied parameters")
+
+
+class AugmentationBatchRequest(BaseModel):
+ """Request to augment a dataset offline."""
+
+ dataset_id: str = Field(..., description="Source dataset UUID")
+ config: AugmentationConfigSchema = Field(..., description="Augmentation config")
+ output_name: str = Field(
+ ..., min_length=1, max_length=255, description="Output dataset name"
+ )
+ multiplier: int = Field(
+ default=2, ge=1, le=10, description="Augmented copies per image"
+ )
+
+
+class AugmentationBatchResponse(BaseModel):
+ """Response for batch augmentation."""
+
+ task_id: str = Field(..., description="Background task UUID")
+ status: str = Field(..., description="Task status")
+ message: str = Field(..., description="Status message")
+ estimated_images: int = Field(..., description="Estimated total images")
+
+
+class AugmentedDatasetItem(BaseModel):
+ """Single augmented dataset in list."""
+
+ dataset_id: str = Field(..., description="Dataset UUID")
+ source_dataset_id: str = Field(..., description="Source dataset UUID")
+ name: str = Field(..., description="Dataset name")
+ status: str = Field(..., description="Dataset status")
+ multiplier: int = Field(..., description="Augmentation multiplier")
+ total_original_images: int = Field(..., description="Original image count")
+ total_augmented_images: int = Field(..., description="Augmented image count")
+ created_at: datetime = Field(..., description="Creation timestamp")
+
+
+class AugmentedDatasetListResponse(BaseModel):
+ """Response for listing augmented datasets."""
+
+ total: int = Field(..., ge=0, description="Total datasets")
+ limit: int = Field(..., ge=1, description="Page size")
+ offset: int = Field(..., ge=0, description="Current offset")
+ datasets: list[AugmentedDatasetItem] = Field(
+ default_factory=list, description="Dataset list"
+ )
+
+
+class AugmentedDatasetDetailResponse(BaseModel):
+ """Detailed augmented dataset response."""
+
+ dataset_id: str = Field(..., description="Dataset UUID")
+ source_dataset_id: str = Field(..., description="Source dataset UUID")
+ name: str = Field(..., description="Dataset name")
+ status: str = Field(..., description="Dataset status")
+ config: AugmentationConfigSchema | None = Field(
+ None, description="Augmentation config used"
+ )
+ multiplier: int = Field(..., description="Augmentation multiplier")
+ total_original_images: int = Field(..., description="Original image count")
+ total_augmented_images: int = Field(..., description="Augmented image count")
+ dataset_path: str | None = Field(None, description="Dataset path on disk")
+ error_message: str | None = Field(None, description="Error message if failed")
+ created_at: datetime = Field(..., description="Creation timestamp")
+ completed_at: datetime | None = Field(None, description="Completion timestamp")
diff --git a/packages/inference/inference/web/schemas/admin/datasets.py b/packages/inference/inference/web/schemas/admin/datasets.py
index f7e38c9..e1c9420 100644
--- a/packages/inference/inference/web/schemas/admin/datasets.py
+++ b/packages/inference/inference/web/schemas/admin/datasets.py
@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
name: str
description: str | None
status: str
+ training_status: str | None = None
+ active_training_task_id: str | None = None
total_documents: int
total_images: int
total_annotations: int
diff --git a/packages/inference/inference/web/schemas/admin/documents.py b/packages/inference/inference/web/schemas/admin/documents.py
index fdf3874..cf3ea82 100644
--- a/packages/inference/inference/web/schemas/admin/documents.py
+++ b/packages/inference/inference/web/schemas/admin/documents.py
@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
+ group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
+ group_key: str | None = Field(None, description="User-defined group key")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
+ group_key: str | None = Field(None, description="User-defined group key")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
diff --git a/packages/inference/inference/web/schemas/admin/models.py b/packages/inference/inference/web/schemas/admin/models.py
new file mode 100644
index 0000000..7359a1e
--- /dev/null
+++ b/packages/inference/inference/web/schemas/admin/models.py
@@ -0,0 +1,95 @@
+"""Admin Model Version Schemas."""
+
+from datetime import datetime
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class ModelVersionCreateRequest(BaseModel):
+ """Request to create a model version."""
+
+ version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
+ name: str = Field(..., min_length=1, max_length=255, description="Model name")
+ model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
+ description: str | None = Field(None, description="Optional description")
+ task_id: str | None = Field(None, description="Training task UUID")
+ dataset_id: str | None = Field(None, description="Dataset UUID")
+ metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
+ metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
+ metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
+ document_count: int = Field(0, ge=0, description="Documents used in training")
+ training_config: dict[str, Any] | None = Field(None, description="Training configuration")
+ file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
+ trained_at: datetime | None = Field(None, description="Training completion time")
+
+
+class ModelVersionUpdateRequest(BaseModel):
+ """Request to update a model version."""
+
+ name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
+ description: str | None = Field(None, description="Description")
+ status: str | None = Field(None, description="Status (inactive, archived)")
+
+
+class ModelVersionItem(BaseModel):
+ """Model version in list view."""
+
+ version_id: str = Field(..., description="Version UUID")
+ version: str = Field(..., description="Semantic version")
+ name: str = Field(..., description="Model name")
+ status: str = Field(..., description="Status (active, inactive, archived)")
+ is_active: bool = Field(..., description="Is currently active for inference")
+ metrics_mAP: float | None = Field(None, description="Mean Average Precision")
+ document_count: int = Field(..., description="Documents used in training")
+ trained_at: datetime | None = Field(None, description="Training completion time")
+ activated_at: datetime | None = Field(None, description="Last activation time")
+ created_at: datetime = Field(..., description="Creation timestamp")
+
+
+class ModelVersionListResponse(BaseModel):
+ """Paginated model version list."""
+
+ total: int = Field(..., ge=0, description="Total model versions")
+ limit: int = Field(..., ge=1, description="Page size")
+ offset: int = Field(..., ge=0, description="Current offset")
+ models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
+
+
+class ModelVersionDetailResponse(BaseModel):
+ """Detailed model version info."""
+
+ version_id: str = Field(..., description="Version UUID")
+ version: str = Field(..., description="Semantic version")
+ name: str = Field(..., description="Model name")
+ description: str | None = Field(None, description="Description")
+ model_path: str = Field(..., description="Path to model file")
+ status: str = Field(..., description="Status (active, inactive, archived)")
+ is_active: bool = Field(..., description="Is currently active for inference")
+ task_id: str | None = Field(None, description="Training task UUID")
+ dataset_id: str | None = Field(None, description="Dataset UUID")
+ metrics_mAP: float | None = Field(None, description="Mean Average Precision")
+ metrics_precision: float | None = Field(None, description="Precision")
+ metrics_recall: float | None = Field(None, description="Recall")
+ document_count: int = Field(..., description="Documents used in training")
+ training_config: dict[str, Any] | None = Field(None, description="Training configuration")
+ file_size: int | None = Field(None, description="Model file size in bytes")
+ trained_at: datetime | None = Field(None, description="Training completion time")
+ activated_at: datetime | None = Field(None, description="Last activation time")
+ created_at: datetime = Field(..., description="Creation timestamp")
+ updated_at: datetime = Field(..., description="Last update timestamp")
+
+
+class ModelVersionResponse(BaseModel):
+ """Response for model version operation."""
+
+ version_id: str = Field(..., description="Version UUID")
+ status: str = Field(..., description="Model status")
+ message: str = Field(..., description="Status message")
+
+
+class ActiveModelResponse(BaseModel):
+ """Response for active model query."""
+
+ has_active_model: bool = Field(..., description="Whether an active model exists")
+ model: ModelVersionItem | None = Field(None, description="Active model if exists")
diff --git a/packages/inference/inference/web/schemas/admin/training.py b/packages/inference/inference/web/schemas/admin/training.py
index 6958692..2fe0cfd 100644
--- a/packages/inference/inference/web/schemas/admin/training.py
+++ b/packages/inference/inference/web/schemas/admin/training.py
@@ -5,13 +5,18 @@ from typing import Any
from pydantic import BaseModel, Field
+from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
- model_name: str = Field(default="yolo11n.pt", description="Base model name")
+ model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
+ base_model_version_id: str | None = Field(
+ default=None,
+ description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
+ )
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
default="invoice_fields", description="Training project name"
)
+ # Data augmentation settings
+ augmentation: AugmentationConfigSchema | None = Field(
+ default=None,
+ description="Augmentation configuration. If provided, augments dataset before training.",
+ )
+ augmentation_multiplier: int = Field(
+ default=2,
+ ge=1,
+ le=10,
+ description="Number of augmented copies per original image",
+ )
+
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""
diff --git a/packages/inference/inference/web/services/augmentation_service.py b/packages/inference/inference/web/services/augmentation_service.py
new file mode 100644
index 0000000..e13e22a
--- /dev/null
+++ b/packages/inference/inference/web/services/augmentation_service.py
@@ -0,0 +1,317 @@
+"""Augmentation service for handling augmentation operations."""
+
+import base64
+import io
+import re
+import uuid
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from fastapi import HTTPException
+from PIL import Image
+
+from inference.data.admin_db import AdminDB
+from inference.web.schemas.admin.augmentation import (
+ AugmentationBatchResponse,
+ AugmentationConfigSchema,
+ AugmentationPreviewResponse,
+ AugmentedDatasetItem,
+ AugmentedDatasetListResponse,
+)
+
+# Constants
+PREVIEW_MAX_SIZE = 800
+PREVIEW_SEED = 42
+UUID_PATTERN = re.compile(
+ r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
+ re.IGNORECASE,
+)
+
+
+class AugmentationService:
+ """Service for augmentation operations."""
+
+ def __init__(self, db: AdminDB) -> None:
+ """Initialize service with database connection."""
+ self.db = db
+
+ def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
+ """
+ Validate UUID format to prevent path traversal.
+
+ Args:
+ value: Value to validate.
+ field_name: Field name for error message.
+
+ Raises:
+ HTTPException: If value is not a valid UUID.
+ """
+ if not UUID_PATTERN.match(value):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid {field_name} format: {value}",
+ )
+
+ async def preview_single(
+ self,
+ document_id: str,
+ page: int,
+ augmentation_type: str,
+ params: dict[str, Any],
+ ) -> AugmentationPreviewResponse:
+ """
+ Preview a single augmentation on a document page.
+
+ Args:
+ document_id: Document UUID.
+ page: Page number (1-indexed).
+ augmentation_type: Name of augmentation to apply.
+ params: Override parameters.
+
+ Returns:
+ Preview response with image URLs.
+
+ Raises:
+ HTTPException: If document not found or augmentation invalid.
+ """
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
+
+ # Validate augmentation type
+ if augmentation_type not in AUGMENTATION_REGISTRY:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unknown augmentation type: {augmentation_type}. "
+ f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
+ )
+
+ # Get document and load image
+ image = await self._load_document_page(document_id, page)
+
+ # Create config with only this augmentation enabled
+ config_kwargs = {
+ augmentation_type: AugmentationParams(
+ enabled=True,
+ probability=1.0, # Always apply for preview
+ params=params,
+ ),
+ "seed": PREVIEW_SEED, # Deterministic preview
+ }
+ config = AugmentationConfig(**config_kwargs)
+ pipeline = AugmentationPipeline(config)
+
+ # Apply augmentation
+ result = pipeline.apply(image)
+
+ # Convert to base64 URLs
+ original_url = self._image_to_data_url(image)
+ preview_url = self._image_to_data_url(result.image)
+
+ return AugmentationPreviewResponse(
+ preview_url=preview_url,
+ original_url=original_url,
+ applied_params=params,
+ )
+
+ async def preview_config(
+ self,
+ document_id: str,
+ page: int,
+ config: AugmentationConfigSchema,
+ ) -> AugmentationPreviewResponse:
+ """
+ Preview full augmentation config on a document page.
+
+ Args:
+ document_id: Document UUID.
+ page: Page number (1-indexed).
+ config: Full augmentation configuration.
+
+ Returns:
+ Preview response with image URLs.
+ """
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ # Load image
+ image = await self._load_document_page(document_id, page)
+
+ # Convert Pydantic model to internal config
+ config_dict = config.model_dump()
+ internal_config = AugmentationConfig.from_dict(config_dict)
+ pipeline = AugmentationPipeline(internal_config)
+
+ # Apply augmentation
+ result = pipeline.apply(image)
+
+ # Convert to base64 URLs
+ original_url = self._image_to_data_url(image)
+ preview_url = self._image_to_data_url(result.image)
+
+ return AugmentationPreviewResponse(
+ preview_url=preview_url,
+ original_url=original_url,
+ applied_params=config_dict,
+ )
+
+ async def create_augmented_dataset(
+ self,
+ source_dataset_id: str,
+ config: AugmentationConfigSchema,
+ output_name: str,
+ multiplier: int,
+ ) -> AugmentationBatchResponse:
+ """
+ Create a new augmented dataset from an existing dataset.
+
+ Args:
+ source_dataset_id: Source dataset UUID.
+ config: Augmentation configuration.
+ output_name: Name for the new dataset.
+ multiplier: Number of augmented copies per image.
+
+ Returns:
+ Batch response with task ID.
+
+ Raises:
+ HTTPException: If source dataset not found.
+ """
+ # Validate source dataset exists
+ try:
+ source_dataset = self.db.get_dataset(source_dataset_id)
+ if source_dataset is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Source dataset not found: {source_dataset_id}",
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Source dataset not found: {source_dataset_id}",
+ ) from e
+
+ # Create task ID for background processing
+ task_id = str(uuid.uuid4())
+
+ # Estimate total images
+ estimated_images = (
+ source_dataset.total_images * multiplier
+ if hasattr(source_dataset, "total_images")
+ else 0
+ )
+
+ # TODO: Queue background task for actual augmentation
+ # For now, return pending status
+
+ return AugmentationBatchResponse(
+ task_id=task_id,
+ status="pending",
+ message=f"Augmentation task queued for dataset '{output_name}'",
+ estimated_images=estimated_images,
+ )
+
+ async def list_augmented_datasets(
+ self,
+ limit: int = 20,
+ offset: int = 0,
+ ) -> AugmentedDatasetListResponse:
+ """
+ List augmented datasets.
+
+ Args:
+ limit: Maximum number of datasets to return.
+ offset: Number of datasets to skip.
+
+ Returns:
+ List response with datasets.
+ """
+ # TODO: Implement actual database query for augmented datasets
+ # For now, return empty list
+
+ return AugmentedDatasetListResponse(
+ total=0,
+ limit=limit,
+ offset=offset,
+ datasets=[],
+ )
+
+ async def _load_document_page(
+ self,
+ document_id: str,
+ page: int,
+ ) -> np.ndarray:
+ """
+ Load a document page as numpy array.
+
+ Args:
+ document_id: Document UUID.
+ page: Page number (1-indexed).
+
+ Returns:
+ Image as numpy array (H, W, C) with dtype uint8.
+
+ Raises:
+ HTTPException: If document or page not found.
+ """
+ # Validate document_id format to prevent path traversal
+ self._validate_uuid(document_id, "document_id")
+
+ # Get document from database
+ try:
+ document = self.db.get_document(document_id)
+ if document is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Document not found: {document_id}",
+ )
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Document not found: {document_id}",
+ ) from e
+
+ # Get image path for page
+ if hasattr(document, "images_dir"):
+ images_dir = Path(document.images_dir)
+ else:
+ # Fallback to constructed path
+ from inference.web.core.config import get_settings
+
+ settings = get_settings()
+ images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
+
+ # Find image for page
+ page_idx = page - 1 # Convert to 0-indexed
+ image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
+
+ if page_idx >= len(image_files):
+ raise HTTPException(
+ status_code=404,
+ detail=f"Page {page} not found for document {document_id}",
+ )
+
+ # Load image
+ image_path = image_files[page_idx]
+ pil_image = Image.open(image_path).convert("RGB")
+ return np.array(pil_image)
+
+ def _image_to_data_url(self, image: np.ndarray) -> str:
+ """Convert numpy image to base64 data URL."""
+ pil_image = Image.fromarray(image)
+
+ # Resize for preview if too large
+ max_size = PREVIEW_MAX_SIZE
+ if max(pil_image.size) > max_size:
+ ratio = max_size / max(pil_image.size)
+ new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
+ pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
+
+ # Convert to base64
+ buffer = io.BytesIO()
+ pil_image.save(buffer, format="PNG")
+ base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+ return f"data:image/png;base64,{base64_data}"
diff --git a/packages/inference/inference/web/services/dataset_builder.py b/packages/inference/inference/web/services/dataset_builder.py
index 30c69ce..fc383b4 100644
--- a/packages/inference/inference/web/services/dataset_builder.py
+++ b/packages/inference/inference/web/services/dataset_builder.py
@@ -81,29 +81,18 @@ class DatasetBuilder:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
- # 3. Shuffle and split documents
+ # 3. Group documents by group_key and assign splits
doc_list = list(documents)
- rng = random.Random(seed)
- rng.shuffle(doc_list)
-
- n = len(doc_list)
- n_train = max(1, round(n * train_ratio))
- n_val = max(0, round(n * val_ratio))
- n_test = n - n_train - n_val
-
- splits = (
- ["train"] * n_train
- + ["val"] * n_val
- + ["test"] * n_test
- )
+ doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
- for doc, split in zip(doc_list, splits):
+ for doc in doc_list:
doc_id = str(doc.document_id)
+ split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
"total_annotations": total_annotations,
}
+ def _assign_splits_by_group(
+ self,
+ documents: list,
+ train_ratio: float,
+ val_ratio: float,
+ seed: int,
+ ) -> dict[str, str]:
+ """Assign splits based on group_key.
+
+ Logic:
+ - Documents with same group_key stay together in the same split
+ - Groups with only 1 document go directly to train
+ - Groups with 2+ documents participate in shuffle & split
+
+ Args:
+ documents: List of AdminDocument objects
+ train_ratio: Fraction for training set
+ val_ratio: Fraction for validation set
+ seed: Random seed for reproducibility
+
+ Returns:
+ Dict mapping document_id (str) -> split ("train"/"val"/"test")
+ """
+ # Group documents by group_key
+ # None/empty group_key treated as unique (each doc is its own group)
+ groups: dict[str | None, list] = {}
+ for doc in documents:
+ key = doc.group_key if doc.group_key else None
+ if key is None:
+ # Treat each ungrouped doc as its own unique group
+ # Use document_id as pseudo-key
+ key = f"__ungrouped_{doc.document_id}"
+ groups.setdefault(key, []).append(doc)
+
+ # Separate single-doc groups from multi-doc groups
+ single_doc_groups: list[tuple[str | None, list]] = []
+ multi_doc_groups: list[tuple[str | None, list]] = []
+
+ for key, docs in groups.items():
+ if len(docs) == 1:
+ single_doc_groups.append((key, docs))
+ else:
+ multi_doc_groups.append((key, docs))
+
+ # Initialize result mapping
+ doc_splits: dict[str, str] = {}
+
+ # Combine all groups for splitting
+ all_groups = single_doc_groups + multi_doc_groups
+
+ # Shuffle all groups and assign splits
+ if all_groups:
+ rng = random.Random(seed)
+ rng.shuffle(all_groups)
+
+ n_groups = len(all_groups)
+ n_train = max(1, round(n_groups * train_ratio))
+ # Ensure at least 1 in val if we have more than 1 group
+ n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
+
+ for i, (_key, docs) in enumerate(all_groups):
+ if i < n_train:
+ split = "train"
+ elif i < n_train + n_val:
+ split = "val"
+ else:
+ split = "test"
+
+ for doc in docs:
+ doc_splits[str(doc.document_id)] = split
+
+ logger.info(
+ "Split assignment: %d total groups shuffled (train=%d, val=%d)",
+ len(all_groups),
+ sum(1 for s in doc_splits.values() if s == "train"),
+ sum(1 for s in doc_splits.values() if s == "val"),
+ )
+
+ return doc_splits
+
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {
diff --git a/packages/inference/inference/web/services/inference.py b/packages/inference/inference/web/services/inference.py
index f087569..84d4028 100644
--- a/packages/inference/inference/web/services/inference.py
+++ b/packages/inference/inference/web/services/inference.py
@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Type alias for model path resolver function
+ModelPathResolver = Callable[[], Path | None]
+
+
@dataclass
class ServiceResult:
"""Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
+ Supports dynamic model loading from database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
+ model_path_resolver: ModelPathResolver | None = None,
) -> None:
"""
Initialize inference service.
Args:
- model_config: Model configuration
+ model_config: Model configuration (default model settings)
storage_config: Storage configuration
+ model_path_resolver: Optional function to resolve model path from database.
+ If provided, will be called to get active model path.
+ If returns None, falls back to model_config.model_path.
"""
self.model_config = model_config
self.storage_config = storage_config
+ self._model_path_resolver = model_path_resolver
self._pipeline = None
self._detector = None
self._is_initialized = False
+ self._current_model_path: Path | None = None
+
+ def _resolve_model_path(self) -> Path:
+ """Resolve the model path to use for inference.
+
+ Priority:
+ 1. Active model from database (via resolver)
+ 2. Default model from config
+ """
+ if self._model_path_resolver:
+ try:
+ db_model_path = self._model_path_resolver()
+ if db_model_path and Path(db_model_path).exists():
+ logger.info(f"Using active model from database: {db_model_path}")
+ return Path(db_model_path)
+ elif db_model_path:
+ logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
+ except Exception as e:
+ logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
+
+ return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
@@ -74,16 +105,20 @@ class InferenceService:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
+ # Resolve model path (from DB or config)
+ model_path = self._resolve_model_path()
+ self._current_model_path = model_path
+
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
- str(self.model_config.model_path),
+ str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
- model_path=str(self.model_config.model_path),
+ model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
@@ -92,12 +127,36 @@ class InferenceService:
self._is_initialized = True
elapsed = time.time() - start_time
- logger.info(f"Inference service initialized in {elapsed:.2f}s")
+ logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
+ def reload_model(self) -> bool:
+ """Reload the model if active model has changed.
+
+ Returns:
+ True if model was reloaded, False if no change needed.
+ """
+ new_model_path = self._resolve_model_path()
+
+ if self._current_model_path == new_model_path:
+ logger.debug("Model unchanged, no reload needed")
+ return False
+
+ logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
+ self._is_initialized = False
+ self._pipeline = None
+ self._detector = None
+ self.initialize()
+ return True
+
+ @property
+ def current_model_path(self) -> Path | None:
+ """Get the currently loaded model path."""
+ return self._current_model_path
+
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
diff --git a/packages/shared/shared/augmentation/__init__.py b/packages/shared/shared/augmentation/__init__.py
new file mode 100644
index 0000000..ec8ed5b
--- /dev/null
+++ b/packages/shared/shared/augmentation/__init__.py
@@ -0,0 +1,24 @@
+"""
+Document Image Augmentation Module.
+
+Provides augmentation transformations for training data enhancement,
+specifically designed for document images (invoices, forms, etc.).
+
+Key features:
+- Document-safe augmentations that preserve text readability
+- Support for both offline preprocessing and runtime augmentation
+- Bbox-aware geometric transforms
+- Configurable augmentation pipeline
+"""
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+from shared.augmentation.config import AugmentationConfig, AugmentationParams
+from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+__all__ = [
+ "AugmentationConfig",
+ "AugmentationParams",
+ "AugmentationResult",
+ "BaseAugmentation",
+ "DatasetAugmenter",
+]
diff --git a/packages/shared/shared/augmentation/base.py b/packages/shared/shared/augmentation/base.py
new file mode 100644
index 0000000..b342125
--- /dev/null
+++ b/packages/shared/shared/augmentation/base.py
@@ -0,0 +1,108 @@
+"""
+Base classes for augmentation transforms.
+
+Provides abstract base class and result dataclass for all augmentation
+implementations.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any
+
+import numpy as np
+
+
+@dataclass
+class AugmentationResult:
+ """
+ Result of applying an augmentation.
+
+ Attributes:
+ image: The augmented image as numpy array (H, W, C).
+ bboxes: Updated bounding boxes if geometric transform was applied.
+ Format: (N, 5) array with [class_id, x_center, y_center, width, height].
+ transform_matrix: The transformation matrix if applicable (for bbox adjustment).
+ applied: Whether the augmentation was actually applied.
+ metadata: Additional metadata about the augmentation.
+ """
+
+ image: np.ndarray
+ bboxes: np.ndarray | None = None
+ transform_matrix: np.ndarray | None = None
+ applied: bool = True
+ metadata: dict[str, Any] | None = None
+
+
+class BaseAugmentation(ABC):
+ """
+ Abstract base class for all augmentations.
+
+ Subclasses must implement:
+ - _validate_params(): Validate augmentation parameters
+ - apply(): Apply the augmentation to an image
+
+ Class attributes:
+ name: Human-readable name of the augmentation.
+ affects_geometry: True if this augmentation modifies bbox coordinates.
+ """
+
+ name: str = "base"
+ affects_geometry: bool = False
+
+ def __init__(self, params: dict[str, Any]) -> None:
+ """
+ Initialize augmentation with parameters.
+
+ Args:
+ params: Dictionary of augmentation-specific parameters.
+ """
+ self.params = params
+ self._validate_params()
+
+ @abstractmethod
+ def _validate_params(self) -> None:
+ """
+ Validate augmentation parameters.
+
+ Raises:
+ ValueError: If parameters are invalid.
+ """
+ pass
+
+ @abstractmethod
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ """
+ Apply augmentation to image.
+
+ IMPORTANT: Implementations must NOT modify the input image or bboxes.
+ Always create copies before modifying.
+
+ Args:
+ image: Input image as numpy array (H, W, C) with dtype uint8.
+ bboxes: Optional bounding boxes in YOLO format (N, 5) array.
+ Each row: [class_id, x_center, y_center, width, height].
+ Coordinates are normalized to 0-1 range.
+ rng: Random number generator for reproducibility.
+ If None, a new generator should be created.
+
+ Returns:
+ AugmentationResult with augmented image and optionally updated bboxes.
+ """
+ pass
+
+ def get_preview_params(self) -> dict[str, Any]:
+ """
+ Get parameters optimized for preview display.
+
+ Override this method to provide parameters that produce
+ clearly visible effects for preview/demo purposes.
+
+ Returns:
+ Dictionary of preview parameters.
+ """
+ return dict(self.params)
diff --git a/packages/shared/shared/augmentation/config.py b/packages/shared/shared/augmentation/config.py
new file mode 100644
index 0000000..5c7cfbc
--- /dev/null
+++ b/packages/shared/shared/augmentation/config.py
@@ -0,0 +1,274 @@
+"""
+Augmentation configuration module.
+
+Provides dataclasses for configuring document image augmentations.
+All default values are document-safe (conservative) to preserve text readability.
+"""
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class AugmentationParams:
+ """
+ Parameters for a single augmentation type.
+
+ Attributes:
+ enabled: Whether this augmentation is enabled.
+ probability: Probability of applying this augmentation (0.0 to 1.0).
+ params: Type-specific parameters dictionary.
+ """
+
+ enabled: bool = False
+ probability: float = 0.5
+ params: dict[str, Any] = field(default_factory=dict)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary for serialization."""
+ return {
+ "enabled": self.enabled,
+ "probability": self.probability,
+ "params": dict(self.params),
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
+ """Create from dictionary."""
+ return cls(
+ enabled=data.get("enabled", False),
+ probability=data.get("probability", 0.5),
+ params=dict(data.get("params", {})),
+ )
+
+
+def _default_perspective_warp() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.3,
+ params={"max_warp": 0.02}, # Very conservative - 2% max distortion
+ )
+
+
+def _default_wrinkle() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.3,
+ params={"intensity": 0.3, "num_wrinkles": (2, 5)},
+ )
+
+
+def _default_edge_damage() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={"max_damage_ratio": 0.05}, # Max 5% of edge damaged
+ )
+
+
+def _default_stain() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={
+ "num_stains": (1, 3),
+ "max_radius_ratio": 0.1,
+ "opacity": (0.1, 0.3),
+ },
+ )
+
+
+def _default_lighting_variation() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=True, # Safe default, commonly needed
+ probability=0.5,
+ params={
+ "brightness_range": (-0.1, 0.1),
+ "contrast_range": (0.9, 1.1),
+ },
+ )
+
+
+def _default_shadow() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.3,
+ params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
+ )
+
+
+def _default_gaussian_blur() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
+ )
+
+
+def _default_motion_blur() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
+ )
+
+
+def _default_gaussian_noise() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.3,
+ params={"mean": 0, "std": (5, 15)}, # Conservative noise levels
+ )
+
+
+def _default_salt_pepper() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={"amount": (0.001, 0.005)}, # Very sparse
+ )
+
+
+def _default_paper_texture() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.3,
+ params={"texture_type": "random", "intensity": (0.05, 0.15)},
+ )
+
+
+def _default_scanner_artifacts() -> AugmentationParams:
+ return AugmentationParams(
+ enabled=False,
+ probability=0.2,
+ params={"line_probability": 0.3, "dust_probability": 0.4},
+ )
+
+
+@dataclass
+class AugmentationConfig:
+ """
+ Complete augmentation configuration.
+
+ All augmentation types have document-safe defaults that preserve
+ text readability. Only lighting_variation is enabled by default.
+
+ Attributes:
+ perspective_warp: Geometric perspective transform (affects bboxes).
+ wrinkle: Paper wrinkle/crease simulation.
+ edge_damage: Damaged/torn edge effects.
+ stain: Coffee stain/smudge effects.
+ lighting_variation: Brightness and contrast variation.
+ shadow: Shadow overlay effects.
+ gaussian_blur: Gaussian blur for focus issues.
+ motion_blur: Motion blur simulation.
+ gaussian_noise: Gaussian noise for sensor noise.
+ salt_pepper: Salt and pepper noise.
+ paper_texture: Paper texture overlay.
+ scanner_artifacts: Scanner line and dust artifacts.
+ preserve_bboxes: Whether to adjust bboxes for geometric transforms.
+ seed: Random seed for reproducibility.
+ """
+
+ # Geometric transforms (affects bboxes)
+ perspective_warp: AugmentationParams = field(
+ default_factory=_default_perspective_warp
+ )
+
+ # Degradation effects
+ wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
+ edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
+ stain: AugmentationParams = field(default_factory=_default_stain)
+
+ # Lighting effects
+ lighting_variation: AugmentationParams = field(
+ default_factory=_default_lighting_variation
+ )
+ shadow: AugmentationParams = field(default_factory=_default_shadow)
+
+ # Blur effects
+ gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
+ motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)
+
+ # Noise effects
+ gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
+ salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)
+
+ # Texture effects
+ paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
+ scanner_artifacts: AugmentationParams = field(
+ default_factory=_default_scanner_artifacts
+ )
+
+ # Global settings
+ preserve_bboxes: bool = True
+ seed: int | None = None
+
+ # List of all augmentation field names
+ _AUGMENTATION_FIELDS: tuple[str, ...] = (
+ "perspective_warp",
+ "wrinkle",
+ "edge_damage",
+ "stain",
+ "lighting_variation",
+ "shadow",
+ "gaussian_blur",
+ "motion_blur",
+ "gaussian_noise",
+ "salt_pepper",
+ "paper_texture",
+ "scanner_artifacts",
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert to dictionary for serialization."""
+ result: dict[str, Any] = {
+ "preserve_bboxes": self.preserve_bboxes,
+ "seed": self.seed,
+ }
+
+ for field_name in self._AUGMENTATION_FIELDS:
+ params: AugmentationParams = getattr(self, field_name)
+ result[field_name] = params.to_dict()
+
+ return result
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
+ """Create from dictionary."""
+ kwargs: dict[str, Any] = {
+ "preserve_bboxes": data.get("preserve_bboxes", True),
+ "seed": data.get("seed"),
+ }
+
+ for field_name in cls._AUGMENTATION_FIELDS:
+ if field_name in data:
+ field_data = data[field_name]
+ if isinstance(field_data, dict):
+ kwargs[field_name] = AugmentationParams.from_dict(field_data)
+
+ return cls(**kwargs)
+
+ def get_enabled_augmentations(self) -> list[str]:
+ """Get list of enabled augmentation names."""
+ enabled = []
+ for field_name in self._AUGMENTATION_FIELDS:
+ params: AugmentationParams = getattr(self, field_name)
+ if params.enabled:
+ enabled.append(field_name)
+ return enabled
+
+ def validate(self) -> None:
+ """
+ Validate configuration.
+
+ Raises:
+ ValueError: If any configuration value is invalid.
+ """
+ for field_name in self._AUGMENTATION_FIELDS:
+ params: AugmentationParams = getattr(self, field_name)
+ if not (0.0 <= params.probability <= 1.0):
+ raise ValueError(
+ f"{field_name}.probability must be between 0 and 1, "
+ f"got {params.probability}"
+ )
diff --git a/packages/shared/shared/augmentation/dataset_augmenter.py b/packages/shared/shared/augmentation/dataset_augmenter.py
new file mode 100644
index 0000000..514fe71
--- /dev/null
+++ b/packages/shared/shared/augmentation/dataset_augmenter.py
@@ -0,0 +1,206 @@
+"""
+Dataset Augmenter Module.
+
+Applies augmentation pipeline to YOLO datasets,
+creating new augmented images and label files.
+"""
+
+import logging
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from PIL import Image
+
+from shared.augmentation.config import AugmentationConfig, AugmentationParams
+from shared.augmentation.pipeline import AugmentationPipeline
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetAugmenter:
+ """
+ Augments YOLO datasets by creating new images and label files.
+
+ Reads images from dataset/images/train/ and labels from dataset/labels/train/,
+ applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
+ """
+
+ def __init__(
+ self,
+ config: dict[str, Any],
+ seed: int | None = None,
+ ) -> None:
+ """
+ Initialize augmenter with configuration.
+
+ Args:
+ config: Dictionary mapping augmentation names to their settings.
+ Each augmentation should have 'enabled', 'probability', and 'params'.
+ seed: Random seed for reproducibility.
+ """
+ self._config_dict = config
+ self._seed = seed
+ self._config = self._build_config(config, seed)
+
+ def _build_config(
+ self,
+ config_dict: dict[str, Any],
+ seed: int | None,
+ ) -> AugmentationConfig:
+ """Build AugmentationConfig from dictionary."""
+ kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
+
+ for aug_name, aug_settings in config_dict.items():
+ if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
+ kwargs[aug_name] = AugmentationParams(
+ enabled=aug_settings.get("enabled", False),
+ probability=aug_settings.get("probability", 0.5),
+ params=aug_settings.get("params", {}),
+ )
+
+ return AugmentationConfig(**kwargs)
+
+ def augment_dataset(
+ self,
+ dataset_path: Path,
+ multiplier: int = 1,
+ split: str = "train",
+ ) -> dict[str, int]:
+ """
+ Augment a YOLO dataset.
+
+ Args:
+ dataset_path: Path to dataset root (containing images/ and labels/).
+ multiplier: Number of augmented copies per original image.
+ split: Which split to augment (default: "train").
+
+ Returns:
+ Summary dict with original_images, augmented_images, total_images.
+ """
+ images_dir = dataset_path / "images" / split
+ labels_dir = dataset_path / "labels" / split
+
+ if not images_dir.exists():
+ raise ValueError(f"Images directory not found: {images_dir}")
+
+ # Find all images
+ image_extensions = ("*.png", "*.jpg", "*.jpeg")
+ image_files: list[Path] = []
+ for ext in image_extensions:
+ image_files.extend(images_dir.glob(ext))
+
+ original_count = len(image_files)
+ augmented_count = 0
+
+ if multiplier <= 0:
+ return {
+ "original_images": original_count,
+ "augmented_images": 0,
+ "total_images": original_count,
+ }
+
+ # Process each image
+ for img_path in image_files:
+ # Load image
+ pil_image = Image.open(img_path).convert("RGB")
+ image = np.array(pil_image)
+
+ # Load corresponding label
+ label_path = labels_dir / f"{img_path.stem}.txt"
+ bboxes = self._load_bboxes(label_path) if label_path.exists() else None
+
+ # Create multiple augmented versions
+ for aug_idx in range(multiplier):
+ # Create pipeline with adjusted seed for each augmentation
+ aug_seed = None
+ if self._seed is not None:
+ aug_seed = self._seed + aug_idx + hash(img_path.stem) % 10000
+
+ pipeline = AugmentationPipeline(
+ self._build_config(self._config_dict, aug_seed)
+ )
+
+ # Apply augmentation
+ result = pipeline.apply(image, bboxes)
+
+ # Save augmented image
+ aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
+ aug_img_path = images_dir / aug_name
+ aug_pil = Image.fromarray(result.image)
+ aug_pil.save(aug_img_path)
+
+ # Save augmented label
+ aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
+ self._save_bboxes(aug_label_path, result.bboxes)
+
+ augmented_count += 1
+
+ logger.info(
+ "Dataset augmentation complete: %d original, %d augmented",
+ original_count,
+ augmented_count,
+ )
+
+ return {
+ "original_images": original_count,
+ "augmented_images": augmented_count,
+ "total_images": original_count + augmented_count,
+ }
+
+ def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
+ """
+ Load bounding boxes from YOLO label file.
+
+ Args:
+ label_path: Path to label file.
+
+ Returns:
+ Array of shape (N, 5) with class_id, x_center, y_center, width, height.
+ Returns None if file is empty or doesn't exist.
+ """
+ if not label_path.exists():
+ return None
+
+ content = label_path.read_text().strip()
+ if not content:
+ return None
+
+ bboxes = []
+ for line in content.split("\n"):
+ parts = line.strip().split()
+ if len(parts) == 5:
+ class_id = int(parts[0])
+ x_center = float(parts[1])
+ y_center = float(parts[2])
+ width = float(parts[3])
+ height = float(parts[4])
+ bboxes.append([class_id, x_center, y_center, width, height])
+
+ if not bboxes:
+ return None
+
+ return np.array(bboxes, dtype=np.float32)
+
+ def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
+ """
+ Save bounding boxes to YOLO label file.
+
+ Args:
+ label_path: Path to save label file.
+ bboxes: Array of shape (N, 5) or None for empty labels.
+ """
+ if bboxes is None or len(bboxes) == 0:
+ label_path.write_text("")
+ return
+
+ lines = []
+ for bbox in bboxes:
+ class_id = int(bbox[0])
+ x_center = bbox[1]
+ y_center = bbox[2]
+ width = bbox[3]
+ height = bbox[4]
+ lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
+
+ label_path.write_text("\n".join(lines))
diff --git a/packages/shared/shared/augmentation/pipeline.py b/packages/shared/shared/augmentation/pipeline.py
new file mode 100644
index 0000000..ff5bfec
--- /dev/null
+++ b/packages/shared/shared/augmentation/pipeline.py
@@ -0,0 +1,184 @@
+"""
+Augmentation pipeline module.
+
+Orchestrates multiple augmentations with proper ordering and
+provides preview functionality.
+"""
+
+from typing import Any
+
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+from shared.augmentation.config import AugmentationConfig, AugmentationParams
+from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
+from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
+from shared.augmentation.transforms.geometric import PerspectiveWarp
+from shared.augmentation.transforms.lighting import LightingVariation, Shadow
+from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
+from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
+
+# Registry of augmentation classes
+AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
+ "perspective_warp": PerspectiveWarp,
+ "wrinkle": Wrinkle,
+ "edge_damage": EdgeDamage,
+ "stain": Stain,
+ "lighting_variation": LightingVariation,
+ "shadow": Shadow,
+ "gaussian_blur": GaussianBlur,
+ "motion_blur": MotionBlur,
+ "gaussian_noise": GaussianNoise,
+ "salt_pepper": SaltPepper,
+ "paper_texture": PaperTexture,
+ "scanner_artifacts": ScannerArtifacts,
+}
+
+
+class AugmentationPipeline:
+ """
+ Orchestrates multiple augmentations with proper ordering.
+
+ Augmentations are applied in the following order:
+ 1. Geometric (perspective_warp) - affects bboxes
+ 2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
+ 3. Lighting (lighting_variation, shadow)
+ 4. Texture (paper_texture, scanner_artifacts)
+ 5. Blur (gaussian_blur, motion_blur)
+ 6. Noise (gaussian_noise, salt_pepper) - applied last
+ """
+
+ STAGE_ORDER = [
+ "geometric",
+ "degradation",
+ "lighting",
+ "texture",
+ "blur",
+ "noise",
+ ]
+
+ STAGE_MAPPING = {
+ "perspective_warp": "geometric",
+ "wrinkle": "degradation",
+ "edge_damage": "degradation",
+ "stain": "degradation",
+ "lighting_variation": "lighting",
+ "shadow": "lighting",
+ "paper_texture": "texture",
+ "scanner_artifacts": "texture",
+ "gaussian_blur": "blur",
+ "motion_blur": "blur",
+ "gaussian_noise": "noise",
+ "salt_pepper": "noise",
+ }
+
+ def __init__(self, config: AugmentationConfig) -> None:
+ """
+ Initialize pipeline with configuration.
+
+ Args:
+ config: Augmentation configuration.
+ """
+ self.config = config
+ self._rng = np.random.default_rng(config.seed)
+ self._augmentations = self._build_augmentations()
+
+ def _build_augmentations(
+ self,
+ ) -> list[tuple[str, BaseAugmentation, float]]:
+ """Build ordered list of (name, augmentation, probability) tuples."""
+ augmentations: list[tuple[str, BaseAugmentation, float]] = []
+
+ for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
+ params: AugmentationParams = getattr(self.config, aug_name)
+ if params.enabled:
+ aug = aug_class(params.params)
+ augmentations.append((aug_name, aug, params.probability))
+
+ # Sort by stage order
+ def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
+ name, _, _ = item
+ stage = self.STAGE_MAPPING[name]
+ return self.STAGE_ORDER.index(stage)
+
+ return sorted(augmentations, key=sort_key)
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ ) -> AugmentationResult:
+ """
+ Apply augmentation pipeline to image.
+
+ Args:
+ image: Input image (H, W, C) as numpy array with dtype uint8.
+ bboxes: Optional bounding boxes in YOLO format (N, 5).
+
+ Returns:
+ AugmentationResult with augmented image and optionally adjusted bboxes.
+ """
+ current_image = image.copy()
+ current_bboxes = bboxes.copy() if bboxes is not None else None
+ applied_augmentations: list[str] = []
+
+ for name, aug, probability in self._augmentations:
+ if self._rng.random() < probability:
+ result = aug.apply(current_image, current_bboxes, self._rng)
+ current_image = result.image
+ if result.bboxes is not None and self.config.preserve_bboxes:
+ current_bboxes = result.bboxes
+ applied_augmentations.append(name)
+
+ return AugmentationResult(
+ image=current_image,
+ bboxes=current_bboxes,
+ metadata={"applied_augmentations": applied_augmentations},
+ )
+
+ def preview(
+ self,
+ image: np.ndarray,
+ augmentation_name: str,
+ ) -> np.ndarray:
+ """
+ Preview a single augmentation deterministically.
+
+ Args:
+ image: Input image.
+ augmentation_name: Name of augmentation to preview.
+
+ Returns:
+ Augmented image.
+
+ Raises:
+ ValueError: If augmentation_name is not recognized.
+ """
+ if augmentation_name not in AUGMENTATION_REGISTRY:
+ raise ValueError(f"Unknown augmentation: {augmentation_name}")
+
+ params: AugmentationParams = getattr(self.config, augmentation_name)
+ aug = AUGMENTATION_REGISTRY[augmentation_name](params.params)
+
+ # Use deterministic RNG for preview
+ preview_rng = np.random.default_rng(42)
+ result = aug.apply(image.copy(), rng=preview_rng)
+ return result.image
+
+
+def get_available_augmentations() -> list[dict[str, Any]]:
+ """
+ Get list of available augmentations with metadata.
+
+ Returns:
+ List of dictionaries with augmentation info.
+ """
+ augmentations = []
+ for name, aug_class in AUGMENTATION_REGISTRY.items():
+ augmentations.append({
+ "name": name,
+ "description": aug_class.__doc__ or "",
+ "affects_geometry": aug_class.affects_geometry,
+ "stage": AugmentationPipeline.STAGE_MAPPING[name],
+ })
+ return augmentations
diff --git a/packages/shared/shared/augmentation/presets.py b/packages/shared/shared/augmentation/presets.py
new file mode 100644
index 0000000..6bad56c
--- /dev/null
+++ b/packages/shared/shared/augmentation/presets.py
@@ -0,0 +1,212 @@
+"""
+Predefined augmentation presets for common document scenarios.
+
+Presets provide ready-to-use configurations optimized for different
+use cases, from conservative (preserves text readability) to aggressive
+(simulates poor document quality).
+"""
+
+from typing import Any
+
+from shared.augmentation.config import AugmentationConfig, AugmentationParams
+
+
+PRESETS: dict[str, dict[str, Any]] = {
+ "conservative": {
+ "description": "Safe augmentations that preserve text readability",
+ "config": {
+ "lighting_variation": {
+ "enabled": True,
+ "probability": 0.5,
+ "params": {
+ "brightness_range": (-0.1, 0.1),
+ "contrast_range": (0.9, 1.1),
+ },
+ },
+ "gaussian_noise": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"std": (3, 10)},
+ },
+ },
+ },
+ "moderate": {
+ "description": "Balanced augmentations for typical document degradation",
+ "config": {
+ "lighting_variation": {
+ "enabled": True,
+ "probability": 0.5,
+ "params": {
+ "brightness_range": (-0.15, 0.15),
+ "contrast_range": (0.85, 1.15),
+ },
+ },
+ "shadow": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
+ },
+ "gaussian_noise": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"std": (5, 12)},
+ },
+ "gaussian_blur": {
+ "enabled": True,
+ "probability": 0.2,
+ "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
+ },
+ "paper_texture": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"intensity": (0.05, 0.12)},
+ },
+ },
+ },
+ "aggressive": {
+ "description": "Heavy augmentations simulating poor scan quality",
+ "config": {
+ "perspective_warp": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"max_warp": 0.02},
+ },
+ "wrinkle": {
+ "enabled": True,
+ "probability": 0.4,
+ "params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
+ },
+ "stain": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {
+ "num_stains": (1, 2),
+ "max_radius_ratio": 0.08,
+ "opacity": (0.1, 0.25),
+ },
+ },
+ "lighting_variation": {
+ "enabled": True,
+ "probability": 0.6,
+ "params": {
+ "brightness_range": (-0.2, 0.2),
+ "contrast_range": (0.8, 1.2),
+ },
+ },
+ "shadow": {
+ "enabled": True,
+ "probability": 0.4,
+ "params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
+ },
+ "gaussian_blur": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
+ },
+ "motion_blur": {
+ "enabled": True,
+ "probability": 0.2,
+ "params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
+ },
+ "gaussian_noise": {
+ "enabled": True,
+ "probability": 0.4,
+ "params": {"std": (8, 18)},
+ },
+ "paper_texture": {
+ "enabled": True,
+ "probability": 0.4,
+ "params": {"intensity": (0.08, 0.15)},
+ },
+ "scanner_artifacts": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"line_probability": 0.4, "dust_probability": 0.5},
+ },
+ "edge_damage": {
+ "enabled": True,
+ "probability": 0.2,
+ "params": {"max_damage_ratio": 0.04},
+ },
+ },
+ },
+ "scanned_document": {
+ "description": "Simulates typical scanned document artifacts",
+ "config": {
+ "scanner_artifacts": {
+ "enabled": True,
+ "probability": 0.5,
+ "params": {"line_probability": 0.4, "dust_probability": 0.5},
+ },
+ "paper_texture": {
+ "enabled": True,
+ "probability": 0.4,
+ "params": {"intensity": (0.05, 0.12)},
+ },
+ "lighting_variation": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {
+ "brightness_range": (-0.1, 0.1),
+ "contrast_range": (0.9, 1.1),
+ },
+ },
+ "gaussian_noise": {
+ "enabled": True,
+ "probability": 0.3,
+ "params": {"std": (5, 12)},
+ },
+ },
+ },
+}
+
+
+def get_preset_config(preset_name: str) -> dict[str, Any]:
+ """
+ Get the configuration dictionary for a preset.
+
+ Args:
+ preset_name: Name of the preset.
+
+ Returns:
+ Configuration dictionary.
+
+ Raises:
+ ValueError: If preset is not found.
+ """
+ if preset_name not in PRESETS:
+ raise ValueError(
+ f"Unknown preset: {preset_name}. "
+ f"Available presets: {list(PRESETS.keys())}"
+ )
+ return PRESETS[preset_name]["config"]
+
+
+def create_config_from_preset(preset_name: str) -> AugmentationConfig:
+ """
+ Create an AugmentationConfig from a preset.
+
+ Args:
+ preset_name: Name of the preset.
+
+ Returns:
+ AugmentationConfig instance.
+
+ Raises:
+ ValueError: If preset is not found.
+ """
+ config_dict = get_preset_config(preset_name)
+ return AugmentationConfig.from_dict(config_dict)
+
+
+def list_presets() -> list[dict[str, str]]:
+ """
+ List all available presets.
+
+ Returns:
+ List of dictionaries with name and description.
+ """
+ return [
+ {"name": name, "description": preset["description"]}
+ for name, preset in PRESETS.items()
+ ]
diff --git a/packages/shared/shared/augmentation/transforms/__init__.py b/packages/shared/shared/augmentation/transforms/__init__.py
new file mode 100644
index 0000000..1fe9e61
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/__init__.py
@@ -0,0 +1,13 @@
+"""
+Augmentation transform implementations.
+
+Each module contains related augmentation classes:
+- geometric.py: Perspective warp and other geometric transforms
+- degradation.py: Wrinkle, edge damage, stain effects
+- lighting.py: Lighting variation and shadow effects
+- blur.py: Gaussian and motion blur
+- noise.py: Gaussian and salt-pepper noise
+- texture.py: Paper texture and scanner artifacts
+"""
+
+# Will be populated as transforms are implemented
diff --git a/packages/shared/shared/augmentation/transforms/blur.py b/packages/shared/shared/augmentation/transforms/blur.py
new file mode 100644
index 0000000..b5902e6
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/blur.py
@@ -0,0 +1,144 @@
+"""
+Blur augmentation transforms.
+
+Provides blur effects for document image augmentation:
+- GaussianBlur: Simulates out-of-focus capture
+- MotionBlur: Simulates camera/document movement during capture
+"""
+
+import cv2
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class GaussianBlur(BaseAugmentation):
+ """
+ Applies Gaussian blur to the image.
+
+ Simulates out-of-focus capture or low-quality optics.
+ Conservative defaults to preserve text readability.
+
+ Parameters:
+ kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
+ sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
+ """
+
+ name = "gaussian_blur"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ kernel_size = self.params.get("kernel_size", (3, 5))
+ if isinstance(kernel_size, int):
+ if kernel_size < 1 or kernel_size % 2 == 0:
+ raise ValueError("kernel_size must be a positive odd integer")
+ elif isinstance(kernel_size, tuple):
+ if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
+ raise ValueError("kernel_size tuple must be (min, max) with min >= 1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ kernel_size = self.params.get("kernel_size", (3, 5))
+ sigma = self.params.get("sigma", (0.5, 1.5))
+
+ if isinstance(kernel_size, tuple):
+ # Choose random odd kernel size
+ min_k, max_k = kernel_size
+ possible_sizes = [k for k in range(min_k, max_k + 1) if k % 2 == 1]
+ if not possible_sizes:
+ possible_sizes = [min_k if min_k % 2 == 1 else min_k + 1]
+ kernel_size = rng.choice(possible_sizes)
+
+ if isinstance(sigma, tuple):
+ sigma = rng.uniform(sigma[0], sigma[1])
+
+ # Ensure kernel size is odd
+ if kernel_size % 2 == 0:
+ kernel_size += 1
+
+ # Apply Gaussian blur
+ blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
+
+ return AugmentationResult(
+ image=blurred,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"kernel_size": kernel_size, "sigma": sigma},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"kernel_size": 5, "sigma": 1.5}
+
+
+class MotionBlur(BaseAugmentation):
+ """
+ Applies motion blur to the image.
+
+ Simulates camera shake or document movement during capture.
+
+ Parameters:
+ kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
+ angle_range: Motion angle range in degrees (default: (-45, 45)).
+ """
+
+ name = "motion_blur"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ kernel_size = self.params.get("kernel_size", (5, 9))
+ if isinstance(kernel_size, int):
+ if kernel_size < 3:
+ raise ValueError("kernel_size must be at least 3")
+ elif isinstance(kernel_size, tuple):
+ if kernel_size[0] < 3:
+ raise ValueError("kernel_size min must be at least 3")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ kernel_size = self.params.get("kernel_size", (5, 9))
+ angle_range = self.params.get("angle_range", (-45, 45))
+
+ if isinstance(kernel_size, tuple):
+ kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)
+
+ angle = rng.uniform(angle_range[0], angle_range[1])
+
+ # Create motion blur kernel
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
+
+ # Draw a line in the center of the kernel
+ center = kernel_size // 2
+ angle_rad = np.deg2rad(angle)
+
+ for i in range(kernel_size):
+ offset = i - center
+ x = int(center + offset * np.cos(angle_rad))
+ y = int(center + offset * np.sin(angle_rad))
+ if 0 <= x < kernel_size and 0 <= y < kernel_size:
+ kernel[y, x] = 1.0
+
+ # Normalize kernel
+ kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel
+
+ # Apply motion blur
+ blurred = cv2.filter2D(image, -1, kernel)
+
+ return AugmentationResult(
+ image=blurred,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"kernel_size": kernel_size, "angle": angle},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"kernel_size": 7, "angle_range": (-30, 30)}
diff --git a/packages/shared/shared/augmentation/transforms/degradation.py b/packages/shared/shared/augmentation/transforms/degradation.py
new file mode 100644
index 0000000..b5c4560
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/degradation.py
@@ -0,0 +1,259 @@
+"""
+Degradation augmentation transforms.
+
+Provides degradation effects for document image augmentation:
+- Wrinkle: Paper wrinkle/crease simulation
+- EdgeDamage: Damaged/torn edge effects
+- Stain: Coffee stain/smudge effects
+"""
+
+import cv2
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class Wrinkle(BaseAugmentation):
+ """
+ Simulates paper wrinkles/creases using displacement mapping.
+
+ Document-friendly: Uses subtle displacement to preserve text readability.
+
+ Parameters:
+ intensity: Wrinkle intensity (0-1) (default: 0.3).
+ num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
+ """
+
+ name = "wrinkle"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ intensity = self.params.get("intensity", 0.3)
+ if not (0 < intensity <= 1):
+ raise ValueError("intensity must be between 0 and 1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ intensity = self.params.get("intensity", 0.3)
+ num_wrinkles = self.params.get("num_wrinkles", (2, 5))
+
+ if isinstance(num_wrinkles, tuple):
+ num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)
+
+ # Create displacement maps
+ displacement_x = np.zeros((h, w), dtype=np.float32)
+ displacement_y = np.zeros((h, w), dtype=np.float32)
+
+ for _ in range(num_wrinkles):
+ # Random wrinkle parameters
+ angle = rng.uniform(0, np.pi)
+ x0 = rng.uniform(0, w)
+ y0 = rng.uniform(0, h)
+ length = rng.uniform(0.3, 0.8) * min(h, w)
+ width = rng.uniform(0.02, 0.05) * min(h, w)
+
+ # Create coordinate grids
+ xx, yy = np.meshgrid(np.arange(w), np.arange(h))
+
+ # Distance from wrinkle line
+ dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
+ dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)
+
+ # Gaussian falloff perpendicular to wrinkle
+ mask = np.exp(-dy**2 / (2 * width**2))
+ mask *= (np.abs(dx) < length / 2).astype(np.float32)
+
+ # Displacement perpendicular to wrinkle
+ disp_amount = intensity * rng.uniform(2, 8)
+ displacement_x += mask * disp_amount * np.sin(angle)
+ displacement_y += mask * disp_amount * np.cos(angle)
+
+ # Create remap coordinates
+ map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
+ map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)
+
+ # Apply displacement
+ augmented = cv2.remap(
+ image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
+ )
+
+ # Add subtle shading along wrinkles
+ max_disp = np.max(np.abs(displacement_y)) + 1e-6
+ shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
+ shading = shading[:, :, np.newaxis]
+ augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)
+
+ return AugmentationResult(
+ image=augmented,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"intensity": 0.5, "num_wrinkles": 3}
+
+
+class EdgeDamage(BaseAugmentation):
+ """
+ Adds damaged/torn edge effects to the image.
+
+ Simulates worn or torn document edges.
+
+ Parameters:
+ max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
+ edges: Which edges to potentially damage (default: all).
+ """
+
+ name = "edge_damage"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
+ if not (0 < max_damage_ratio <= 0.2):
+ raise ValueError("max_damage_ratio must be between 0 and 0.2")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
+ edges = self.params.get("edges", ["top", "bottom", "left", "right"])
+
+ output = image.copy()
+
+ # Select random edge to damage
+ edge = rng.choice(edges)
+ damage_size = int(max_damage_ratio * min(h, w))
+
+ if edge == "top":
+ # Create irregular top edge
+ for x in range(w):
+ depth = rng.integers(0, damage_size + 1)
+ if depth > 0:
+ # Random color (white or darker)
+ color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
+ output[:depth, x] = color
+
+ elif edge == "bottom":
+ for x in range(w):
+ depth = rng.integers(0, damage_size + 1)
+ if depth > 0:
+ color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
+ output[h - depth:, x] = color
+
+ elif edge == "left":
+ for y in range(h):
+ depth = rng.integers(0, damage_size + 1)
+ if depth > 0:
+ color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
+ output[y, :depth] = color
+
+ else: # right
+ for y in range(h):
+ depth = rng.integers(0, damage_size + 1)
+ if depth > 0:
+ color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
+ output[y, w - depth:] = color
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"edge": edge, "damage_size": damage_size},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"max_damage_ratio": 0.08}
+
+
+class Stain(BaseAugmentation):
+ """
+ Adds coffee stain/smudge effects to the image.
+
+ Simulates accidental stains on documents.
+
+ Parameters:
+ num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
+ max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
+ opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
+ """
+
+ name = "stain"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ opacity = self.params.get("opacity", (0.1, 0.3))
+ if isinstance(opacity, (int, float)):
+ if not (0 < opacity <= 1):
+ raise ValueError("opacity must be between 0 and 1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ num_stains = self.params.get("num_stains", (1, 3))
+ max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
+ opacity = self.params.get("opacity", (0.1, 0.3))
+
+ if isinstance(num_stains, tuple):
+ num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
+ if isinstance(opacity, tuple):
+ opacity = rng.uniform(opacity[0], opacity[1])
+
+ output = image.astype(np.float32)
+ max_radius = int(max_radius_ratio * min(h, w))
+
+ for _ in range(num_stains):
+ # Random stain position and size
+ cx = rng.integers(max_radius, w - max_radius)
+ cy = rng.integers(max_radius, h - max_radius)
+ radius = rng.integers(max_radius // 3, max_radius)
+
+ # Create stain mask with irregular edges
+ yy, xx = np.ogrid[:h, :w]
+ dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
+
+ # Add noise to make edges irregular
+ noise = rng.uniform(0.8, 1.2, (h, w))
+ mask = (dist < radius * noise).astype(np.float32)
+
+ # Blur for soft edges
+ mask = cv2.GaussianBlur(mask, (21, 21), 0)
+
+ # Random stain color (brownish/yellowish)
+ stain_color = np.array([
+ rng.integers(180, 220), # R
+ rng.integers(160, 200), # G
+ rng.integers(120, 160), # B
+ ], dtype=np.float32)
+
+ # Apply stain
+ mask_3d = mask[:, :, np.newaxis]
+ output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity
+
+ output = np.clip(output, 0, 255).astype(np.uint8)
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"num_stains": num_stains, "opacity": opacity},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}
diff --git a/packages/shared/shared/augmentation/transforms/geometric.py b/packages/shared/shared/augmentation/transforms/geometric.py
new file mode 100644
index 0000000..04dc0be
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/geometric.py
@@ -0,0 +1,145 @@
+"""
+Geometric augmentation transforms.
+
+Provides geometric transforms for document image augmentation:
+- PerspectiveWarp: Subtle perspective distortion
+"""
+
+import cv2
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class PerspectiveWarp(BaseAugmentation):
+ """
+ Applies subtle perspective transformation to the image.
+
+ Simulates viewing document at slight angle. Very conservative
+ by default to preserve text readability.
+
+ IMPORTANT: This transform affects bounding box coordinates.
+
+ Parameters:
+ max_warp: Maximum warp as proportion of image size (default: 0.02).
+ """
+
+ name = "perspective_warp"
+ affects_geometry = True
+
+ def _validate_params(self) -> None:
+ max_warp = self.params.get("max_warp", 0.02)
+ if not (0 < max_warp <= 0.1):
+ raise ValueError("max_warp must be between 0 and 0.1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ max_warp = self.params.get("max_warp", 0.02)
+
+ # Original corners
+ src_pts = np.float32([
+ [0, 0],
+ [w, 0],
+ [w, h],
+ [0, h],
+ ])
+
+ # Add random perturbations to corners
+ max_offset = max_warp * min(h, w)
+ dst_pts = src_pts.copy()
+ for i in range(4):
+ dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
+ dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)
+
+ # Compute perspective transform matrix
+ transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
+
+ # Apply perspective transform
+ warped = cv2.warpPerspective(
+ image, transform_matrix, (w, h),
+ borderMode=cv2.BORDER_REPLICATE
+ )
+
+ # Transform bounding boxes if present
+ transformed_bboxes = None
+ if bboxes is not None:
+ transformed_bboxes = self._transform_bboxes(
+ bboxes, transform_matrix, w, h
+ )
+
+ return AugmentationResult(
+ image=warped,
+ bboxes=transformed_bboxes,
+ transform_matrix=transform_matrix,
+ metadata={"max_warp": max_warp},
+ )
+
+ def _transform_bboxes(
+ self,
+ bboxes: np.ndarray,
+ transform_matrix: np.ndarray,
+ w: int,
+ h: int,
+ ) -> np.ndarray:
+ """Transform bounding boxes using perspective matrix."""
+ if len(bboxes) == 0:
+ return bboxes.copy()
+
+ transformed = []
+ for bbox in bboxes:
+ class_id, x_center, y_center, width, height = bbox
+
+ # Convert normalized coords to pixel coords
+ x_center_px = x_center * w
+ y_center_px = y_center * h
+ width_px = width * w
+ height_px = height * h
+
+ # Get corner points
+ x1 = x_center_px - width_px / 2
+ y1 = y_center_px - height_px / 2
+ x2 = x_center_px + width_px / 2
+ y2 = y_center_px + height_px / 2
+
+ # Transform all 4 corners
+ corners = np.float32([
+ [x1, y1],
+ [x2, y1],
+ [x2, y2],
+ [x1, y2],
+ ]).reshape(-1, 1, 2)
+
+ transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
+ transformed_corners = transformed_corners.reshape(-1, 2)
+
+ # Get bounding box of transformed corners
+ new_x1 = np.min(transformed_corners[:, 0])
+ new_y1 = np.min(transformed_corners[:, 1])
+ new_x2 = np.max(transformed_corners[:, 0])
+ new_y2 = np.max(transformed_corners[:, 1])
+
+ # Convert back to normalized center format
+ new_width = (new_x2 - new_x1) / w
+ new_height = (new_y2 - new_y1) / h
+ new_x_center = ((new_x1 + new_x2) / 2) / w
+ new_y_center = ((new_y1 + new_y2) / 2) / h
+
+ # Clamp to valid range
+ new_x_center = np.clip(new_x_center, 0, 1)
+ new_y_center = np.clip(new_y_center, 0, 1)
+ new_width = np.clip(new_width, 0, 1)
+ new_height = np.clip(new_height, 0, 1)
+
+ transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])
+
+ return np.array(transformed, dtype=np.float32)
+
+ def get_preview_params(self) -> dict:
+ return {"max_warp": 0.03}
diff --git a/packages/shared/shared/augmentation/transforms/lighting.py b/packages/shared/shared/augmentation/transforms/lighting.py
new file mode 100644
index 0000000..93341c9
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/lighting.py
@@ -0,0 +1,167 @@
+"""
+Lighting augmentation transforms.
+
+Provides lighting effects for document image augmentation:
+- LightingVariation: Adjusts brightness and contrast
+- Shadow: Adds shadow overlay effects
+"""
+
+import cv2
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class LightingVariation(BaseAugmentation):
+ """
+ Adjusts image brightness and contrast.
+
+ Simulates different lighting conditions during document capture.
+ Safe for documents with conservative default parameters.
+
+ Parameters:
+ brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
+ contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
+ """
+
+ name = "lighting_variation"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ brightness = self.params.get("brightness_range", (-0.1, 0.1))
+ contrast = self.params.get("contrast_range", (0.9, 1.1))
+
+ if not isinstance(brightness, tuple) or len(brightness) != 2:
+ raise ValueError("brightness_range must be a (min, max) tuple")
+ if not isinstance(contrast, tuple) or len(contrast) != 2:
+ raise ValueError("contrast_range must be a (min, max) tuple")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
+ contrast_range = self.params.get("contrast_range", (0.9, 1.1))
+
+ # Random brightness and contrast
+ brightness = rng.uniform(brightness_range[0], brightness_range[1])
+ contrast = rng.uniform(contrast_range[0], contrast_range[1])
+
+ # Apply adjustments
+ adjusted = image.astype(np.float32)
+
+ # Contrast adjustment (multiply around mean)
+ mean = adjusted.mean()
+ adjusted = (adjusted - mean) * contrast + mean
+
+ # Brightness adjustment (add offset)
+ adjusted = adjusted + brightness * 255
+
+ # Clip and convert back
+ adjusted = np.clip(adjusted, 0, 255).astype(np.uint8)
+
+ return AugmentationResult(
+ image=adjusted,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"brightness": brightness, "contrast": contrast},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
+
+
+class Shadow(BaseAugmentation):
+ """
+ Adds shadow overlay effects to the image.
+
+ Simulates shadows from objects or hands during document capture.
+
+ Parameters:
+ num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
+ opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
+ """
+
+ name = "shadow"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ opacity = self.params.get("opacity", (0.2, 0.4))
+ if isinstance(opacity, (int, float)):
+ if not (0 <= opacity <= 1):
+ raise ValueError("opacity must be between 0 and 1")
+ elif isinstance(opacity, tuple):
+ if not (0 <= opacity[0] <= opacity[1] <= 1):
+ raise ValueError("opacity tuple must be in range [0, 1]")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ num_shadows = self.params.get("num_shadows", (1, 2))
+ opacity = self.params.get("opacity", (0.2, 0.4))
+
+ if isinstance(num_shadows, tuple):
+ num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
+ if isinstance(opacity, tuple):
+ opacity = rng.uniform(opacity[0], opacity[1])
+
+ h, w = image.shape[:2]
+ output = image.astype(np.float32)
+
+ for _ in range(num_shadows):
+ # Generate random shadow polygon
+ num_vertices = rng.integers(3, 6)
+ vertices = []
+
+ # Start from a random edge
+ edge = rng.integers(0, 4)
+ if edge == 0: # Top
+ start = (rng.integers(0, w), 0)
+ elif edge == 1: # Right
+ start = (w, rng.integers(0, h))
+ elif edge == 2: # Bottom
+ start = (rng.integers(0, w), h)
+ else: # Left
+ start = (0, rng.integers(0, h))
+
+ vertices.append(start)
+
+ # Add random vertices
+ for _ in range(num_vertices - 1):
+ x = rng.integers(0, w)
+ y = rng.integers(0, h)
+ vertices.append((x, y))
+
+ # Create shadow mask
+ mask = np.zeros((h, w), dtype=np.float32)
+ pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
+ cv2.fillPoly(mask, [pts], 1.0)
+
+ # Blur the mask for soft edges
+ blur_size = max(31, min(h, w) // 10)
+ if blur_size % 2 == 0:
+ blur_size += 1
+ mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
+
+ # Apply shadow
+ shadow_factor = 1 - opacity * mask[:, :, np.newaxis]
+ output = output * shadow_factor
+
+ output = np.clip(output, 0, 255).astype(np.uint8)
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"num_shadows": num_shadows, "opacity": opacity},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"num_shadows": 1, "opacity": 0.3}
diff --git a/packages/shared/shared/augmentation/transforms/noise.py b/packages/shared/shared/augmentation/transforms/noise.py
new file mode 100644
index 0000000..c0b6eb2
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/noise.py
@@ -0,0 +1,142 @@
+"""
+Noise augmentation transforms.
+
+Provides noise effects for document image augmentation:
+- GaussianNoise: Adds Gaussian noise to simulate sensor noise
+- SaltPepper: Adds salt and pepper noise for impulse noise effects
+"""
+
+from typing import Any
+
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class GaussianNoise(BaseAugmentation):
+ """
+ Adds Gaussian noise to the image.
+
+ Simulates sensor noise from cameras or scanners.
+ Document-safe with conservative default parameters.
+
+ Parameters:
+ mean: Mean of the Gaussian noise (default: 0).
+ std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
+ """
+
+ name = "gaussian_noise"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ std = self.params.get("std", (5, 15))
+ if isinstance(std, (int, float)):
+ if std < 0:
+ raise ValueError("std must be non-negative")
+ elif isinstance(std, tuple):
+ if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
+ raise ValueError("std tuple must be (min, max) with min <= max >= 0")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ mean = self.params.get("mean", 0)
+ std = self.params.get("std", (5, 15))
+
+ if isinstance(std, tuple):
+ std = rng.uniform(std[0], std[1])
+
+ # Generate noise
+ noise = rng.normal(mean, std, image.shape).astype(np.float32)
+
+ # Apply noise
+ noisy = image.astype(np.float32) + noise
+ noisy = np.clip(noisy, 0, 255).astype(np.uint8)
+
+ return AugmentationResult(
+ image=noisy,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"applied_std": std},
+ )
+
+ def get_preview_params(self) -> dict[str, Any]:
+ return {"mean": 0, "std": 15}
+
+
+class SaltPepper(BaseAugmentation):
+ """
+ Adds salt and pepper (impulse) noise to the image.
+
+ Simulates defects from damaged sensors or transmission errors.
+ Very sparse by default to preserve document readability.
+
+ Parameters:
+ amount: Proportion of pixels to affect, can be float or (min, max) tuple.
+ Default: (0.001, 0.005) for very sparse noise.
+ salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
+ """
+
+ name = "salt_pepper"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ amount = self.params.get("amount", (0.001, 0.005))
+ if isinstance(amount, (int, float)):
+ if not (0 <= amount <= 1):
+ raise ValueError("amount must be between 0 and 1")
+ elif isinstance(amount, tuple):
+ if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
+ raise ValueError("amount tuple must be (min, max) in range [0, 1]")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ amount = self.params.get("amount", (0.001, 0.005))
+ salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
+
+ if isinstance(amount, tuple):
+ amount = rng.uniform(amount[0], amount[1])
+
+ # Copy image
+ output = image.copy()
+ h, w = image.shape[:2]
+ total_pixels = h * w
+
+ # Calculate number of salt and pepper pixels
+ num_salt = int(total_pixels * amount * salt_vs_pepper)
+ num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))
+
+ # Add salt (white pixels)
+ if num_salt > 0:
+ salt_coords = (
+ rng.integers(0, h, num_salt),
+ rng.integers(0, w, num_salt),
+ )
+ output[salt_coords] = 255
+
+ # Add pepper (black pixels)
+ if num_pepper > 0:
+ pepper_coords = (
+ rng.integers(0, h, num_pepper),
+ rng.integers(0, w, num_pepper),
+ )
+ output[pepper_coords] = 0
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"applied_amount": amount},
+ )
+
+ def get_preview_params(self) -> dict[str, Any]:
+ return {"amount": 0.01, "salt_vs_pepper": 0.5}
diff --git a/packages/shared/shared/augmentation/transforms/texture.py b/packages/shared/shared/augmentation/transforms/texture.py
new file mode 100644
index 0000000..41287d2
--- /dev/null
+++ b/packages/shared/shared/augmentation/transforms/texture.py
@@ -0,0 +1,159 @@
+"""
+Texture augmentation transforms.
+
+Provides texture effects for document image augmentation:
+- PaperTexture: Adds paper grain/texture
+- ScannerArtifacts: Adds scanner line and dust artifacts
+"""
+
+import cv2
+import numpy as np
+
+from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+
+class PaperTexture(BaseAugmentation):
+ """
+ Adds paper texture/grain to the image.
+
+ Simulates different paper types and ages.
+
+ Parameters:
+ texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
+ intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
+ """
+
+ name = "paper_texture"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ intensity = self.params.get("intensity", (0.05, 0.15))
+ if isinstance(intensity, (int, float)):
+ if not (0 < intensity <= 1):
+ raise ValueError("intensity must be between 0 and 1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ texture_type = self.params.get("texture_type", "random")
+ intensity = self.params.get("intensity", (0.05, 0.15))
+
+ if texture_type == "random":
+ texture_type = rng.choice(["fine", "coarse"])
+
+ if isinstance(intensity, tuple):
+ intensity = rng.uniform(intensity[0], intensity[1])
+
+ # Generate base noise
+ if texture_type == "fine":
+ # Fine grain texture
+ noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
+ noise = cv2.GaussianBlur(noise, (3, 3), 0)
+ else:
+ # Coarse texture
+ # Generate at lower resolution and upscale
+ small_h, small_w = h // 4, w // 4
+ noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
+ noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
+ noise = cv2.GaussianBlur(noise, (5, 5), 0)
+
+ # Apply texture
+ output = image.astype(np.float32)
+ noise_3d = noise[:, :, np.newaxis] * intensity * 255
+ output = output + noise_3d
+
+ output = np.clip(output, 0, 255).astype(np.uint8)
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={"texture_type": texture_type, "intensity": intensity},
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"texture_type": "coarse", "intensity": 0.15}
+
+
+class ScannerArtifacts(BaseAugmentation):
+ """
+ Adds scanner artifacts to the image.
+
+ Simulates scanner imperfections like lines and dust spots.
+
+ Parameters:
+ line_probability: Probability of adding scan lines (default: 0.3).
+ dust_probability: Probability of adding dust spots (default: 0.4).
+ """
+
+ name = "scanner_artifacts"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ line_prob = self.params.get("line_probability", 0.3)
+ dust_prob = self.params.get("dust_probability", 0.4)
+ if not (0 <= line_prob <= 1):
+ raise ValueError("line_probability must be between 0 and 1")
+ if not (0 <= dust_prob <= 1):
+ raise ValueError("dust_probability must be between 0 and 1")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ rng = rng or np.random.default_rng()
+
+ h, w = image.shape[:2]
+ line_probability = self.params.get("line_probability", 0.3)
+ dust_probability = self.params.get("dust_probability", 0.4)
+
+ output = image.copy()
+
+ # Add scan lines
+ if rng.random() < line_probability:
+ num_lines = rng.integers(1, 4)
+ for _ in range(num_lines):
+ y = rng.integers(0, h)
+ thickness = rng.integers(1, 3)
+ # Light or dark line
+ color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)
+
+ # Make line partially transparent
+ alpha = rng.uniform(0.3, 0.6)
+ for dy in range(thickness):
+ if y + dy < h:
+ output[y + dy, :] = (
+ output[y + dy, :].astype(np.float32) * (1 - alpha) +
+ color * alpha
+ ).astype(np.uint8)
+
+ # Add dust spots
+ if rng.random() < dust_probability:
+ num_dust = rng.integers(5, 20)
+ for _ in range(num_dust):
+ x = rng.integers(0, w)
+ y = rng.integers(0, h)
+ radius = rng.integers(1, 3)
+
+ # Dark dust spot
+ color = rng.integers(50, 120)
+ cv2.circle(output, (x, y), radius, int(color), -1)
+
+ return AugmentationResult(
+ image=output,
+ bboxes=bboxes.copy() if bboxes is not None else None,
+ metadata={
+ "line_probability": line_probability,
+ "dust_probability": dust_probability,
+ },
+ )
+
+ def get_preview_params(self) -> dict:
+ return {"line_probability": 0.8, "dust_probability": 0.8}
diff --git a/packages/shared/shared/training/__init__.py b/packages/shared/shared/training/__init__.py
new file mode 100644
index 0000000..76d3c08
--- /dev/null
+++ b/packages/shared/shared/training/__init__.py
@@ -0,0 +1,5 @@
+"""Shared training utilities."""
+
+from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult
+
+__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]
diff --git a/packages/shared/shared/training/yolo_trainer.py b/packages/shared/shared/training/yolo_trainer.py
new file mode 100644
index 0000000..59435cd
--- /dev/null
+++ b/packages/shared/shared/training/yolo_trainer.py
@@ -0,0 +1,239 @@
+"""
+Shared YOLO Training Module
+
+Unified training logic for both CLI and Web API.
+"""
+
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TrainingConfig:
+ """Training configuration."""
+
+ # Model settings
+ model_path: str = "yolo11n.pt" # Base model or path to trained model
+ data_yaml: str = "" # Path to data.yaml
+
+ # Training hyperparameters
+ epochs: int = 100
+ batch_size: int = 16
+ image_size: int = 640
+ learning_rate: float = 0.01
+ device: str = "0"
+
+ # Output settings
+ project: str = "runs/train"
+ name: str = "invoice_fields"
+
+ # Performance settings
+ workers: int = 4
+ cache: bool = False
+
+ # Resume settings
+ resume: bool = False
+ resume_from: str | None = None # Path to checkpoint
+
+ # Document-specific augmentation (optimized for invoices)
+ augmentation: dict[str, Any] = field(default_factory=lambda: {
+ "degrees": 5.0,
+ "translate": 0.05,
+ "scale": 0.2,
+ "shear": 0.0,
+ "perspective": 0.0,
+ "flipud": 0.0,
+ "fliplr": 0.0,
+ "mosaic": 0.0,
+ "mixup": 0.0,
+ "hsv_h": 0.0,
+ "hsv_s": 0.1,
+ "hsv_v": 0.2,
+ })
+
+
+@dataclass
+class TrainingResult:
+ """Training result."""
+
+ success: bool
+ model_path: str | None = None
+ metrics: dict[str, float] = field(default_factory=dict)
+ error: str | None = None
+ save_dir: str | None = None
+
+
+class YOLOTrainer:
+ """Unified YOLO trainer for CLI and Web API."""
+
+ def __init__(
+ self,
+ config: TrainingConfig,
+ log_callback: Callable[[str, str], None] | None = None,
+ ):
+ """
+ Initialize trainer.
+
+ Args:
+ config: Training configuration
+ log_callback: Optional callback for logging (level, message)
+ """
+ self.config = config
+ self._log_callback = log_callback
+
+ def _log(self, level: str, message: str) -> None:
+ """Log a message."""
+ if self._log_callback:
+ self._log_callback(level, message)
+ if level == "INFO":
+ logger.info(message)
+ elif level == "ERROR":
+ logger.error(message)
+ elif level == "WARNING":
+ logger.warning(message)
+
+ def validate_config(self) -> tuple[bool, str | None]:
+ """
+ Validate training configuration.
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ """
+ # Check model path
+ model_path = Path(self.config.model_path)
+ if not model_path.suffix == ".pt":
+ # Could be a model name like "yolo11n.pt" which is downloaded
+ if not model_path.name.startswith("yolo"):
+ return False, f"Invalid model: {self.config.model_path}"
+ elif not model_path.exists():
+ return False, f"Model file not found: {self.config.model_path}"
+
+ # Check data.yaml
+ if not self.config.data_yaml:
+ return False, "data_yaml is required"
+ data_yaml = Path(self.config.data_yaml)
+ if not data_yaml.exists():
+ return False, f"data.yaml not found: {self.config.data_yaml}"
+
+ return True, None
+
+ def train(self) -> TrainingResult:
+ """
+ Run YOLO training.
+
+ Returns:
+ TrainingResult with model path and metrics
+ """
+ try:
+ from ultralytics import YOLO
+ except ImportError:
+ return TrainingResult(
+ success=False,
+ error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
+ )
+
+ # Validate config
+ is_valid, error = self.validate_config()
+ if not is_valid:
+ return TrainingResult(success=False, error=error)
+
+ self._log("INFO", f"Starting YOLO training")
+ self._log("INFO", f" Model: {self.config.model_path}")
+ self._log("INFO", f" Data: {self.config.data_yaml}")
+ self._log("INFO", f" Epochs: {self.config.epochs}")
+ self._log("INFO", f" Batch size: {self.config.batch_size}")
+ self._log("INFO", f" Image size: {self.config.image_size}")
+
+ try:
+ # Load model
+ if self.config.resume and self.config.resume_from:
+ resume_path = Path(self.config.resume_from)
+ if resume_path.exists():
+ self._log("INFO", f"Resuming from: {resume_path}")
+ model = YOLO(str(resume_path))
+ else:
+ model = YOLO(self.config.model_path)
+ else:
+ model = YOLO(self.config.model_path)
+
+ # Build training arguments
+ train_args = {
+ "data": str(Path(self.config.data_yaml).absolute()),
+ "epochs": self.config.epochs,
+ "batch": self.config.batch_size,
+ "imgsz": self.config.image_size,
+ "lr0": self.config.learning_rate,
+ "device": self.config.device,
+ "project": self.config.project,
+ "name": self.config.name,
+ "exist_ok": True,
+ "pretrained": True,
+ "verbose": True,
+ "workers": self.config.workers,
+ "cache": self.config.cache,
+ "resume": self.config.resume and self.config.resume_from is not None,
+ }
+
+ # Add augmentation settings
+ train_args.update(self.config.augmentation)
+
+ # Train
+ results = model.train(**train_args)
+
+ # Get best model path
+ best_model = Path(results.save_dir) / "weights" / "best.pt"
+
+ # Extract metrics
+ metrics = {}
+ if hasattr(results, "results_dict"):
+ metrics = {
+ "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
+ "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
+ "precision": results.results_dict.get("metrics/precision(B)", 0),
+ "recall": results.results_dict.get("metrics/recall(B)", 0),
+ }
+
+ self._log("INFO", f"Training completed successfully")
+ self._log("INFO", f" Best model: {best_model}")
+ self._log("INFO", f" mAP@0.5: {metrics.get('mAP50', 'N/A')}")
+
+ return TrainingResult(
+ success=True,
+ model_path=str(best_model) if best_model.exists() else None,
+ metrics=metrics,
+ save_dir=str(results.save_dir),
+ )
+
+ except Exception as e:
+ self._log("ERROR", f"Training failed: {e}")
+ return TrainingResult(success=False, error=str(e))
+
+ def validate(self, split: str = "val") -> dict[str, float]:
+ """
+ Run validation on trained model.
+
+ Args:
+ split: Dataset split to validate on ("val" or "test")
+
+ Returns:
+ Validation metrics
+ """
+ try:
+ from ultralytics import YOLO
+
+ model = YOLO(self.config.model_path)
+ metrics = model.val(data=self.config.data_yaml, split=split)
+
+ return {
+ "mAP50": metrics.box.map50,
+ "mAP50-95": metrics.box.map,
+ "precision": metrics.box.mp,
+ "recall": metrics.box.mr,
+ }
+ except Exception as e:
+ self._log("ERROR", f"Validation failed: {e}")
+ return {}
diff --git a/packages/training/training/cli/train.py b/packages/training/training/cli/train.py
index ca64863..fc96997 100644
--- a/packages/training/training/cli/train.py
+++ b/packages/training/training/cli/train.py
@@ -199,67 +199,63 @@ def main():
db.close()
return
- # Start training
+ # Start training using shared trainer
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
- from ultralytics import YOLO
+ from shared.training import YOLOTrainer, TrainingConfig
- # Load model
+ # Determine resume checkpoint
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
- if args.resume and last_checkpoint.exists():
- print(f"Resuming from: {last_checkpoint}")
- model = YOLO(str(last_checkpoint))
- else:
- model = YOLO(args.model)
+ resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
- # Training arguments
+ # Create training config
data_yaml = dataset_dir / 'dataset.yaml'
- train_args = {
- 'data': str(data_yaml.absolute()),
- 'epochs': args.epochs,
- 'batch': args.batch,
- 'imgsz': args.imgsz,
- 'project': args.project,
- 'name': args.name,
- 'device': args.device,
- 'exist_ok': True,
- 'pretrained': True,
- 'verbose': True,
- 'workers': args.workers,
- 'cache': args.cache,
- 'resume': args.resume and last_checkpoint.exists(),
- # Document-specific augmentation settings
- 'degrees': 5.0,
- 'translate': 0.05,
- 'scale': 0.2,
- 'shear': 0.0,
- 'perspective': 0.0,
- 'flipud': 0.0,
- 'fliplr': 0.0,
- 'mosaic': 0.0,
- 'mixup': 0.0,
- 'hsv_h': 0.0,
- 'hsv_s': 0.1,
- 'hsv_v': 0.2,
- }
+ config = TrainingConfig(
+ model_path=args.model,
+ data_yaml=str(data_yaml),
+ epochs=args.epochs,
+ batch_size=args.batch,
+ image_size=args.imgsz,
+ device=args.device,
+ project=args.project,
+ name=args.name,
+ workers=args.workers,
+ cache=args.cache,
+ resume=args.resume,
+ resume_from=resume_from,
+ )
- # Train
- results = model.train(**train_args)
+ # Run training
+ trainer = YOLOTrainer(config=config)
+ result = trainer.train()
+
+ if not result.success:
+ print(f"\nError: Training failed - {result.error}")
+ db.close()
+ sys.exit(1)
# Print results
print("\n" + "=" * 60)
print("Training Complete")
print("=" * 60)
- print(f"Best model: {args.project}/{args.name}/weights/best.pt")
- print(f"Last model: {args.project}/{args.name}/weights/last.pt")
+ print(f"Best model: {result.model_path}")
+ print(f"Save directory: {result.save_dir}")
+ if result.metrics:
+ print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
+ print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
# Validate on test set
print("\nRunning validation on test set...")
- metrics = model.val(split='test')
- print(f"mAP50: {metrics.box.map50:.4f}")
- print(f"mAP50-95: {metrics.box.map:.4f}")
+ if result.model_path:
+ config.model_path = result.model_path
+ config.data_yaml = str(data_yaml)
+ test_trainer = YOLOTrainer(config=config)
+ test_metrics = test_trainer.validate(split='test')
+ if test_metrics:
+ print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
+ print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
# Close database
db.close()
diff --git a/runs_backup/train/invoice_fields/args.yaml b/runs_backup/train/invoice_fields/args.yaml
new file mode 100644
index 0000000..6aa033a
--- /dev/null
+++ b/runs_backup/train/invoice_fields/args.yaml
@@ -0,0 +1,106 @@
+task: detect
+mode: train
+model: runs/train/invoice_fields/weights/last.pt
+data: /home/kai/invoice-data/dataset/dataset.yaml
+epochs: 100
+time: null
+patience: 100
+batch: 8
+imgsz: 1280
+save: true
+save_period: -1
+cache: false
+device: '0'
+workers: 8
+project: runs/train
+name: invoice_fields
+exist_ok: true
+pretrained: true
+optimizer: auto
+verbose: true
+seed: 0
+deterministic: true
+single_cls: false
+rect: false
+cos_lr: false
+close_mosaic: 10
+resume: runs/train/invoice_fields/weights/last.pt
+amp: true
+fraction: 1.0
+profile: false
+freeze: null
+multi_scale: false
+compile: false
+overlap_mask: true
+mask_ratio: 4
+dropout: 0.0
+val: true
+split: val
+save_json: false
+conf: null
+iou: 0.7
+max_det: 300
+half: false
+dnn: false
+plots: true
+source: null
+vid_stride: 1
+stream_buffer: false
+visualize: false
+augment: false
+agnostic_nms: false
+classes: null
+retina_masks: false
+embed: null
+show: false
+save_frames: false
+save_txt: false
+save_conf: false
+save_crop: false
+show_labels: true
+show_conf: true
+show_boxes: true
+line_width: null
+format: torchscript
+keras: false
+optimize: false
+int8: false
+dynamic: false
+simplify: true
+opset: null
+workspace: null
+nms: false
+lr0: 0.01
+lrf: 0.01
+momentum: 0.937
+weight_decay: 0.0005
+warmup_epochs: 3.0
+warmup_momentum: 0.8
+warmup_bias_lr: 0.0
+box: 7.5
+cls: 0.5
+dfl: 1.5
+pose: 12.0
+kobj: 1.0
+nbs: 64
+hsv_h: 0.0
+hsv_s: 0.1
+hsv_v: 0.2
+degrees: 5.0
+translate: 0.05
+scale: 0.2
+shear: 0.0
+perspective: 0.0
+flipud: 0.0
+fliplr: 0.0
+bgr: 0.0
+mosaic: 0.0
+mixup: 0.0
+cutmix: 0.0
+copy_paste: 0.0
+copy_paste_mode: flip
+auto_augment: randaugment
+erasing: 0.4
+cfg: null
+tracker: botsort.yaml
+save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields
diff --git a/runs_backup/train/invoice_fields/results.csv b/runs_backup/train/invoice_fields/results.csv
new file mode 100644
index 0000000..2a35a32
--- /dev/null
+++ b/runs_backup/train/invoice_fields/results.csv
@@ -0,0 +1,101 @@
+epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
+1,507.213,0.97459,2.44056,1.09407,0.62221,0.6847,0.6581,0.53505,0.66652,1.13762,0.89651,0.00333127,0.00333127,0.00333127
+2,960.144,0.65449,1.06648,0.90865,0.62998,0.73947,0.70066,0.58144,0.63394,1.03248,0.89183,0.00659863,0.00659863,0.00659863
+3,1415.64,0.64921,0.97725,0.90901,0.70271,0.7713,0.76974,0.60362,0.7453,0.92005,0.9044,0.00979998,0.00979998,0.00979998
+4,1871.6,0.61159,0.90785,0.90192,0.7264,0.77507,0.78958,0.66909,0.57561,0.83662,0.87737,0.009703,0.009703,0.009703
+5,2317.68,0.55503,0.81107,0.88628,0.75742,0.81685,0.8287,0.70772,0.57193,0.74866,0.87531,0.009604,0.009604,0.009604
+6,2756.85,0.52443,0.75067,0.8767,0.77362,0.82743,0.84758,0.73413,0.53604,0.70643,0.86947,0.009505,0.009505,0.009505
+7,3211.38,0.49889,0.70526,0.8696,0.79041,0.83537,0.86367,0.75099,0.52784,0.67021,0.86686,0.009406,0.009406,0.009406
+8,3657.41,0.47889,0.66715,0.86261,0.80402,0.85067,0.87834,0.7573,0.5365,0.64019,0.86754,0.009307,0.009307,0.009307
+9,4100.61,0.46417,0.63485,0.85761,0.81776,0.84993,0.88612,0.77638,0.50423,0.61379,0.85539,0.009208,0.009208,0.009208
+10,4544.47,0.45154,0.61098,0.85334,0.82044,0.86313,0.89168,0.7685,0.54578,0.59853,0.8668,0.009109,0.009109,0.009109
+11,4983.87,0.44,0.58833,0.84961,0.82995,0.87408,0.90194,0.7972,0.48628,0.57037,0.84686,0.00901,0.00901,0.00901
+12,5419.2,0.42826,0.57051,0.84622,0.83627,0.87342,0.90482,0.79702,0.49356,0.55869,0.84836,0.008911,0.008911,0.008911
+13,5854.78,0.42052,0.55192,0.84327,0.83405,0.88243,0.90883,0.79807,0.50137,0.54503,0.85114,0.008812,0.008812,0.008812
+14,6288.44,0.41497,0.53873,0.84105,0.83687,0.8769,0.90767,0.79881,0.49873,0.54655,0.85078,0.008713,0.008713,0.008713
+15,6724.06,0.40759,0.52501,0.83904,0.84211,0.88472,0.91329,0.81111,0.47871,0.52982,0.84461,0.008614,0.008614,0.008614
+16,7164.64,0.39956,0.51303,0.83591,0.85078,0.88609,0.91766,0.80948,0.49199,0.52052,0.84634,0.008515,0.008515,0.008515
+17,7626.93,0.39651,0.50255,0.83453,0.84894,0.89151,0.92036,0.80185,0.5227,0.52229,0.8514,0.008416,0.008416,0.008416
+18,490.978,0.3953,0.49497,0.83435,0.86113,0.87974,0.92114,0.80521,0.51284,0.51491,0.84926,0.008317,0.008317,0.008317
+19,976.368,0.39234,0.49242,0.8328,0.86521,0.88431,0.92347,0.80898,0.5037,0.50744,0.84824,0.008218,0.008218,0.008218
+20,1457.68,0.3864,0.48591,0.8308,0.86361,0.88895,0.92337,0.80562,0.51844,0.50692,0.8497,0.008119,0.008119,0.008119
+21,1938.95,0.3826,0.47627,0.83086,0.85282,0.89488,0.92276,0.80107,0.53348,0.50706,0.85418,0.00802,0.00802,0.00802
+22,2418,0.37996,0.46715,0.82911,0.86698,0.89705,0.92757,0.81493,0.5038,0.49105,0.84342,0.007921,0.007921,0.007921
+23,2900.33,0.37555,0.46089,0.82769,0.86871,0.89566,0.92949,0.82251,0.48233,0.48473,0.84078,0.007822,0.007822,0.007822
+24,3381.33,0.3717,0.45531,0.82676,0.87203,0.89509,0.93002,0.80968,0.52892,0.49231,0.85017,0.007723,0.007723,0.007723
+25,3867.95,0.36913,0.44591,0.82544,0.87636,0.89074,0.93018,0.81083,0.5286,0.4882,0.84889,0.007624,0.007624,0.007624
+26,4354.94,0.3662,0.44016,0.82483,0.86957,0.89807,0.92963,0.80554,0.5486,0.49546,0.8527,0.007525,0.007525,0.007525
+27,4839.34,0.36364,0.43794,0.82368,0.87515,0.89602,0.93102,0.80927,0.54175,0.49001,0.85159,0.007426,0.007426,0.007426
+28,5324.36,0.36043,0.42951,0.8234,0.87074,0.90178,0.93175,0.8107,0.53742,0.4862,0.85096,0.007327,0.007327,0.007327
+29,5810.68,0.35852,0.42834,0.82242,0.87574,0.8992,0.93108,0.80885,0.53965,0.48712,0.85055,0.007228,0.007228,0.007228
+30,6294.55,0.35572,0.42015,0.82155,0.87635,0.90084,0.93172,0.81312,0.5252,0.48089,0.84711,0.007129,0.007129,0.007129
+31,6778.25,0.3539,0.41622,0.82025,0.8777,0.90106,0.93275,0.81359,0.52726,0.47864,0.84712,0.00703,0.00703,0.00703
+32,7261.88,0.35095,0.4105,0.82048,0.87684,0.90364,0.9337,0.81598,0.52269,0.47547,0.84625,0.006931,0.006931,0.006931
+33,7745.73,0.34835,0.40539,0.81904,0.87629,0.90454,0.933,0.81501,0.52381,0.47649,0.84713,0.006832,0.006832,0.006832
+34,8234.78,0.34699,0.402,0.8182,0.8755,0.90506,0.93336,0.81532,0.52435,0.4756,0.84749,0.006733,0.006733,0.006733
+35,8716.45,0.34487,0.39711,0.81704,0.87551,0.90432,0.93325,0.81529,0.52489,0.47558,0.84763,0.006634,0.006634,0.006634
+36,9200.09,0.34343,0.39638,0.81749,0.87551,0.90495,0.93343,0.81477,0.52659,0.4765,0.84798,0.006535,0.006535,0.006535
+37,9683.76,0.34013,0.3907,0.81647,0.87738,0.90416,0.93371,0.8152,0.52755,0.47568,0.84807,0.006436,0.006436,0.006436
+38,10164.4,0.33826,0.38626,0.81468,0.87923,0.90342,0.93399,0.81571,0.52756,0.47527,0.84794,0.006337,0.006337,0.006337
+39,10648.6,0.3366,0.3812,0.81517,0.8786,0.90333,0.93368,0.81448,0.53035,0.47601,0.84857,0.006238,0.006238,0.006238
+40,11130.6,0.3353,0.37879,0.8151,0.87974,0.90405,0.93411,0.81544,0.52739,0.47333,0.84768,0.006139,0.006139,0.006139
+41,11612.6,0.33395,0.37397,0.8143,0.88034,0.90315,0.93432,0.81514,0.529,0.47304,0.84765,0.00604,0.00604,0.00604
+42,12097.5,0.33104,0.37164,0.81448,0.87942,0.90449,0.93429,0.81484,0.53077,0.47386,0.84799,0.005941,0.005941,0.005941
+43,12579.8,0.33016,0.3681,0.81305,0.87964,0.90412,0.93457,0.81473,0.53129,0.47382,0.84789,0.005842,0.005842,0.005842
+44,13065.3,0.32845,0.36431,0.81191,0.88092,0.90312,0.9348,0.81538,0.53025,0.47337,0.84751,0.005743,0.005743,0.005743
+45,13550.4,0.32642,0.36127,0.81295,0.8805,0.905,0.93502,0.81544,0.53124,0.47333,0.8477,0.005644,0.005644,0.005644
+46,14034.3,0.32483,0.35761,0.81158,0.87898,0.90636,0.935,0.81556,0.53135,0.47317,0.84772,0.005545,0.005545,0.005545
+47,14517.3,0.32236,0.35337,0.81014,0.88018,0.90502,0.93493,0.81547,0.53228,0.473,0.8478,0.005446,0.005446,0.005446
+48,14998.6,0.3211,0.35051,0.81064,0.87941,0.9055,0.93481,0.81473,0.5353,0.47335,0.84839,0.005347,0.005347,0.005347
+49,15479.2,0.32043,0.34797,0.8097,0.87884,0.90584,0.93482,0.81429,0.53741,0.47359,0.84867,0.005248,0.005248,0.005248
+50,15962,0.3182,0.34589,0.80867,0.87776,0.90777,0.93476,0.81395,0.53841,0.47358,0.84884,0.005149,0.005149,0.005149
+51,16445.4,0.31722,0.34332,0.80879,0.87932,0.90605,0.93488,0.81439,0.53827,0.47301,0.84874,0.00505,0.00505,0.00505
+52,16925.3,0.31437,0.33963,0.80846,0.87925,0.90688,0.93521,0.81468,0.53772,0.47254,0.84861,0.004951,0.004951,0.004951
+53,17410.4,0.31435,0.33632,0.80792,0.88015,0.9069,0.93538,0.8148,0.53789,0.47193,0.84865,0.004852,0.004852,0.004852
+54,17893.2,0.31341,0.33352,0.80833,0.88132,0.90605,0.93552,0.81465,0.53934,0.47218,0.84889,0.004753,0.004753,0.004753
+55,18375.2,0.31091,0.33239,0.80688,0.88151,0.90617,0.93547,0.81453,0.54019,0.47227,0.84885,0.004654,0.004654,0.004654
+56,18857.4,0.30906,0.32753,0.80616,0.88348,0.90504,0.93565,0.81495,0.53974,0.47179,0.84864,0.004555,0.004555,0.004555
+57,19340.6,0.30722,0.32334,0.80582,0.88502,0.90326,0.93558,0.81484,0.53977,0.47174,0.84859,0.004456,0.004456,0.004456
+58,19824.8,0.30575,0.3195,0.80592,0.88348,0.90487,0.93566,0.81511,0.53921,0.47158,0.84832,0.004357,0.004357,0.004357
+59,20305.7,0.30426,0.31846,0.80494,0.88419,0.90477,0.93575,0.81534,0.5387,0.47138,0.84819,0.004258,0.004258,0.004258
+60,20785.1,0.30295,0.3154,0.80494,0.88302,0.90624,0.93568,0.81572,0.53769,0.47106,0.84788,0.004159,0.004159,0.004159
+61,21263.8,0.3013,0.3131,0.80436,0.88438,0.90545,0.93572,0.81606,0.53622,0.47079,0.84762,0.00406,0.00406,0.00406
+62,21746,0.30019,0.31077,0.80391,0.88296,0.90732,0.93578,0.8165,0.53455,0.47011,0.84724,0.003961,0.003961,0.003961
+63,22225,0.29841,0.30656,0.80379,0.88244,0.90779,0.93591,0.81693,0.53417,0.47003,0.84715,0.003862,0.003862,0.003862
+64,22704.7,0.29696,0.30489,0.80284,0.88493,0.90596,0.9359,0.81716,0.53362,0.47013,0.84707,0.003763,0.003763,0.003763
+65,23182.9,0.2952,0.30022,0.80288,0.88366,0.90731,0.93594,0.81737,0.533,0.47024,0.84695,0.003664,0.003664,0.003664
+66,23663.6,0.29337,0.29898,0.80273,0.88514,0.90611,0.93609,0.81805,0.53189,0.47015,0.84664,0.003565,0.003565,0.003565
+67,24149.7,0.29248,0.29492,0.80242,0.88664,0.90536,0.93607,0.81783,0.53208,0.4704,0.84665,0.003466,0.003466,0.003466
+68,24629,0.28987,0.29155,0.8009,0.89014,0.90207,0.93611,0.81811,0.53203,0.47046,0.84656,0.003367,0.003367,0.003367
+69,25109.9,0.28942,0.29004,0.8011,0.88939,0.90353,0.93619,0.81872,0.53127,0.4704,0.84631,0.003268,0.003268,0.003268
+70,25590.5,0.28752,0.28571,0.80059,0.88926,0.90393,0.93627,0.81909,0.53074,0.47023,0.84624,0.003169,0.003169,0.003169
+71,26072.8,0.28546,0.28301,0.7999,0.88844,0.90476,0.93631,0.81967,0.52999,0.47005,0.84612,0.00307,0.00307,0.00307
+72,26552.9,0.2842,0.28027,0.79942,0.88801,0.90505,0.93622,0.81978,0.52939,0.46994,0.84607,0.002971,0.002971,0.002971
+73,27035.5,0.28297,0.27907,0.79956,0.88781,0.90499,0.93615,0.82032,0.528,0.4694,0.84578,0.002872,0.002872,0.002872
+74,27518.8,0.2812,0.27446,0.79886,0.88848,0.90493,0.93611,0.82061,0.52675,0.46906,0.84549,0.002773,0.002773,0.002773
+75,28007.4,0.27866,0.27202,0.79684,0.88889,0.90467,0.9361,0.82099,0.5257,0.4692,0.84529,0.002674,0.002674,0.002674
+76,28499.1,0.27708,0.26798,0.7978,0.88807,0.9054,0.93615,0.82138,0.52574,0.46928,0.84523,0.002575,0.002575,0.002575
+77,28993.2,0.27398,0.2644,0.79612,0.88825,0.9055,0.93616,0.82161,0.52491,0.46925,0.84496,0.002476,0.002476,0.002476
+78,29480.5,0.27359,0.26209,0.79678,0.88876,0.90547,0.93617,0.82172,0.52467,0.4691,0.84498,0.002377,0.002377,0.002377
+79,29970.7,0.27153,0.25905,0.79585,0.88942,0.90548,0.93613,0.82211,0.52407,0.46892,0.84474,0.002278,0.002278,0.002278
+80,30453.8,0.2696,0.25647,0.79513,0.88897,0.90665,0.93617,0.82246,0.52298,0.46881,0.84445,0.002179,0.002179,0.002179
+81,30936.1,0.26895,0.25375,0.79466,0.88799,0.90791,0.93617,0.8226,0.52253,0.46904,0.84445,0.00208,0.00208,0.00208
+82,31418.6,0.26733,0.25025,0.79474,0.88945,0.90695,0.93608,0.82293,0.52172,0.4694,0.84434,0.001981,0.001981,0.001981
+83,31911.4,0.26537,0.24754,0.79479,0.89112,0.90496,0.93604,0.82338,0.52094,0.46932,0.84421,0.001882,0.001882,0.001882
+84,32402,0.26344,0.24514,0.79369,0.8928,0.90319,0.93598,0.82353,0.52015,0.46957,0.84407,0.001783,0.001783,0.001783
+85,32903.1,0.26045,0.24052,0.79226,0.89211,0.90347,0.93615,0.82427,0.51861,0.46958,0.84372,0.001684,0.001684,0.001684
+86,33413.6,0.25867,0.23781,0.79209,0.89286,0.90279,0.9362,0.82493,0.51664,0.47018,0.84338,0.001585,0.001585,0.001585
+87,33923.6,0.257,0.23463,0.792,0.89299,0.90297,0.93614,0.8254,0.5147,0.46974,0.84305,0.001486,0.001486,0.001486
+88,34436.4,0.25569,0.23153,0.79149,0.89278,0.90277,0.93609,0.82622,0.51242,0.4697,0.84266,0.001387,0.001387,0.001387
+89,34949.6,0.25343,0.22868,0.791,0.89137,0.90434,0.93599,0.82675,0.51036,0.46947,0.84227,0.001288,0.001288,0.001288
+90,35449.6,0.25194,0.22502,0.79051,0.89128,0.90489,0.93591,0.82729,0.5092,0.46975,0.84219,0.001189,0.001189,0.001189
+91,35960.5,0.2502,0.22225,0.78959,0.8898,0.90646,0.93586,0.82781,0.50761,0.46999,0.84186,0.00109,0.00109,0.00109
+92,36452.1,0.24777,0.21844,0.78906,0.89054,0.9057,0.93593,0.82831,0.50603,0.47043,0.84154,0.000991,0.000991,0.000991
+93,36942.9,0.24554,0.21503,0.78861,0.88979,0.90679,0.93584,0.82858,0.50495,0.4703,0.84125,0.000892,0.000892,0.000892
+94,37430.3,0.2434,0.21193,0.78799,0.88928,0.90756,0.93566,0.8288,0.5041,0.47075,0.84109,0.000793,0.000793,0.000793
+95,37918.7,0.2413,0.20892,0.78736,0.8899,0.90683,0.93567,0.82882,0.50339,0.47152,0.84105,0.000694,0.000694,0.000694
+96,38404.3,0.2405,0.20619,0.78595,0.88999,0.90713,0.9355,0.82912,0.50244,0.47239,0.84104,0.000595,0.000595,0.000595
+97,38893.8,0.23808,0.2031,0.78683,0.88982,0.90634,0.93531,0.82938,0.50187,0.47281,0.84116,0.000496,0.000496,0.000496
+98,39382.8,0.23581,0.20034,0.78643,0.89144,0.9045,0.93517,0.82959,0.50119,0.47383,0.84128,0.000397,0.000397,0.000397
+99,39871.7,0.23432,0.19778,0.78568,0.89187,0.90415,0.93488,0.82953,0.50058,0.47452,0.84126,0.000298,0.000298,0.000298
+100,40359.7,0.233,0.19485,0.78528,0.89228,0.90349,0.93471,0.82961,0.50029,0.47497,0.84139,0.000199,0.000199,0.000199
diff --git a/runs_backup/train/invoice_yolo11n_full/args.yaml b/runs_backup/train/invoice_yolo11n_full/args.yaml
new file mode 100644
index 0000000..19a11c5
--- /dev/null
+++ b/runs_backup/train/invoice_yolo11n_full/args.yaml
@@ -0,0 +1,106 @@
+task: detect
+mode: train
+model: yolo11n.pt
+data: /home/kai/invoice-data/dataset/dataset.yaml
+epochs: 100
+time: null
+patience: 100
+batch: 16
+imgsz: 1280
+save: true
+save_period: -1
+cache: false
+device: '0'
+workers: 8
+project: runs/train
+name: invoice_yolo11n_full
+exist_ok: true
+pretrained: true
+optimizer: auto
+verbose: true
+seed: 0
+deterministic: true
+single_cls: false
+rect: false
+cos_lr: false
+close_mosaic: 10
+resume: false
+amp: true
+fraction: 1.0
+profile: false
+freeze: null
+multi_scale: false
+compile: false
+overlap_mask: true
+mask_ratio: 4
+dropout: 0.0
+val: true
+split: val
+save_json: false
+conf: null
+iou: 0.7
+max_det: 300
+half: false
+dnn: false
+plots: true
+source: null
+vid_stride: 1
+stream_buffer: false
+visualize: false
+augment: false
+agnostic_nms: false
+classes: null
+retina_masks: false
+embed: null
+show: false
+save_frames: false
+save_txt: false
+save_conf: false
+save_crop: false
+show_labels: true
+show_conf: true
+show_boxes: true
+line_width: null
+format: torchscript
+keras: false
+optimize: false
+int8: false
+dynamic: false
+simplify: true
+opset: null
+workspace: null
+nms: false
+lr0: 0.01
+lrf: 0.01
+momentum: 0.937
+weight_decay: 0.0005
+warmup_epochs: 3.0
+warmup_momentum: 0.8
+warmup_bias_lr: 0.1
+box: 7.5
+cls: 0.5
+dfl: 1.5
+pose: 12.0
+kobj: 1.0
+nbs: 64
+hsv_h: 0.0
+hsv_s: 0.1
+hsv_v: 0.2
+degrees: 5.0
+translate: 0.05
+scale: 0.2
+shear: 0.0
+perspective: 0.0
+flipud: 0.0
+fliplr: 0.0
+bgr: 0.0
+mosaic: 0.0
+mixup: 0.0
+cutmix: 0.0
+copy_paste: 0.0
+copy_paste_mode: flip
+auto_augment: randaugment
+erasing: 0.4
+cfg: null
+tracker: botsort.yaml
+save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_yolo11n_full
diff --git a/runs_backup/train/invoice_yolo11n_full/results.csv b/runs_backup/train/invoice_yolo11n_full/results.csv
new file mode 100644
index 0000000..181dab6
--- /dev/null
+++ b/runs_backup/train/invoice_yolo11n_full/results.csv
@@ -0,0 +1,101 @@
+epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
+1,217.641,0.79856,2.56507,1.01986,0.8921,0.84545,0.90033,0.81508,0.49347,0.92369,0.83815,0.00332991,0.00332991,0.00332991
+2,410.275,0.506,1.04596,0.85661,0.9164,0.85418,0.92852,0.79726,0.63851,0.72277,0.87404,0.00659728,0.00659728,0.00659728
+3,598.713,0.49618,0.68647,0.85624,0.93775,0.80014,0.91481,0.80383,0.5798,0.76691,0.86835,0.00979865,0.00979865,0.00979865
+4,782.868,0.44059,0.53532,0.84299,0.94668,0.90421,0.96101,0.8832,0.43961,0.49649,0.84298,0.009703,0.009703,0.009703
+5,967.898,0.37596,0.44308,0.82667,0.88316,0.81492,0.91376,0.82272,0.50616,0.70202,0.85825,0.009604,0.009604,0.009604
+6,1152.03,0.33999,0.39482,0.81661,0.81567,0.73691,0.82644,0.75085,0.43451,0.92038,0.86158,0.009505,0.009505,0.009505
+7,1335.35,0.31114,0.35971,0.80992,0.95256,0.89807,0.96383,0.84241,0.60248,0.48455,0.88156,0.009406,0.009406,0.009406
+8,1518.68,0.29176,0.33987,0.80516,0.97058,0.91185,0.97221,0.85356,0.58408,0.43771,0.86239,0.009307,0.009307,0.009307
+9,1702.03,0.27683,0.3214,0.80166,0.96403,0.91663,0.9736,0.85118,0.6359,0.43055,0.88091,0.009208,0.009208,0.009208
+10,1891.76,0.26487,0.30796,0.79943,0.96201,0.92669,0.97715,0.894,0.46381,0.37437,0.84314,0.009109,0.009109,0.009109
+11,2081.79,0.25744,0.29846,0.79614,0.96562,0.92382,0.97554,0.79415,0.81302,0.46321,0.93576,0.00901,0.00901,0.00901
+12,2273.34,0.24726,0.28842,0.79445,0.96248,0.92901,0.97544,0.7642,0.93193,0.48807,0.98769,0.008911,0.008911,0.008911
+13,2461.71,0.24266,0.27619,0.79413,0.9672,0.93016,0.97834,0.69208,1.24698,0.59927,1.19042,0.008812,0.008812,0.008812
+14,2649.81,0.23391,0.26941,0.79165,0.96579,0.93247,0.98028,0.72182,1.0815,0.50846,1.08505,0.008713,0.008713,0.008713
+15,2837.95,0.22893,0.2651,0.79082,0.9639,0.93414,0.9807,0.8522,0.64306,0.38931,0.88909,0.008614,0.008614,0.008614
+16,3021.27,0.22269,0.25369,0.78809,0.97667,0.92754,0.98283,0.79198,0.89233,0.43512,0.98623,0.008515,0.008515,0.008515
+17,3209.3,0.21937,0.24886,0.78797,0.96559,0.93193,0.98178,0.67198,1.35518,0.59949,1.30325,0.008416,0.008416,0.008416
+18,3400.76,0.21415,0.24489,0.78789,0.95973,0.94489,0.98156,0.63227,1.45967,0.63486,1.42858,0.008317,0.008317,0.008317
+19,3590.96,0.21227,0.23986,0.78736,0.96136,0.94369,0.98263,0.83035,0.76379,0.39831,0.92541,0.008218,0.008218,0.008218
+20,3779.15,0.20834,0.23506,0.78475,0.96214,0.93563,0.97908,0.5024,1.92081,0.81282,1.99717,0.008119,0.008119,0.008119
+21,3976.3,0.20592,0.23055,0.78534,0.9636,0.94141,0.98186,0.71087,1.20783,0.53596,1.17998,0.00802,0.00802,0.00802
+22,4165.69,0.20195,0.22554,0.78431,0.96621,0.94394,0.98458,0.86353,0.6245,0.35194,0.86978,0.007921,0.007921,0.007921
+23,4357.69,0.19847,0.22066,0.78362,0.9745,0.93877,0.98501,0.84365,0.71155,0.37717,0.91395,0.007822,0.007822,0.007822
+24,4548.46,0.19715,0.21991,0.78423,0.96456,0.94907,0.98541,0.77136,0.9901,0.45056,1.02401,0.007723,0.007723,0.007723
+25,4738.2,0.19207,0.21284,0.7821,0.97136,0.94417,0.98568,0.8139,0.83053,0.41526,0.94475,0.007624,0.007624,0.007624
+26,4926.95,0.19124,0.21138,0.7823,0.9712,0.94333,0.98466,0.78106,0.94702,0.44977,1.0068,0.007525,0.007525,0.007525
+27,5115.34,0.18944,0.21166,0.78245,0.97207,0.9347,0.98325,0.57941,1.64865,0.72474,1.70082,0.007426,0.007426,0.007426
+28,5303.92,0.18817,0.20777,0.7814,0.96837,0.9519,0.98672,0.77596,0.96734,0.43218,1.01592,0.007327,0.007327,0.007327
+29,5493.53,0.18565,0.20231,0.78154,0.9719,0.94481,0.98552,0.67875,1.27094,0.57309,1.27411,0.007228,0.007228,0.007228
+30,5682.31,0.18424,0.19916,0.77989,0.96714,0.95269,0.98588,0.73712,1.08764,0.52054,1.09884,0.007129,0.007129,0.007129
+31,5870.11,0.1812,0.19544,0.78013,0.9698,0.95028,0.98687,0.72258,1.15995,0.50918,1.14823,0.00703,0.00703,0.00703
+32,6060.08,0.1801,0.19571,0.7799,0.9699,0.95342,0.98761,0.83372,0.70788,0.36986,0.89674,0.006931,0.006931,0.006931
+33,6244.87,0.17816,0.19373,0.7789,0.96278,0.95546,0.98684,0.85265,0.66817,0.36063,0.88589,0.006832,0.006832,0.006832
+34,6427.13,0.176,0.1909,0.77864,0.96962,0.95053,0.98707,0.85351,0.70717,0.36925,0.90913,0.006733,0.006733,0.006733
+35,6619.57,0.17377,0.18513,0.77794,0.97418,0.94725,0.98662,0.83346,0.77385,0.3987,0.92964,0.006634,0.006634,0.006634
+36,6807.72,0.17359,0.18454,0.77885,0.97363,0.95,0.98703,0.85072,0.72917,0.36745,0.91556,0.006535,0.006535,0.006535
+37,7001.26,0.17126,0.18179,0.77796,0.96337,0.95646,0.98744,0.79259,0.86734,0.4072,0.96666,0.006436,0.006436,0.006436
+38,7186.91,0.16989,0.17967,0.77791,0.97277,0.94891,0.98737,0.8268,0.78213,0.38811,0.9293,0.006337,0.006337,0.006337
+39,7372.72,0.16823,0.17959,0.77698,0.96961,0.95206,0.98714,0.86035,0.69764,0.35696,0.90589,0.006238,0.006238,0.006238
+40,7558.51,0.16639,0.17648,0.77676,0.96471,0.9592,0.98756,0.85427,0.70458,0.35248,0.90042,0.006139,0.006139,0.006139
+41,7747.05,0.16641,0.17472,0.77686,0.96565,0.95698,0.98751,0.76422,1.01303,0.45638,1.04876,0.00604,0.00604,0.00604
+42,7930.66,0.16528,0.17295,0.77783,0.98029,0.94412,0.98709,0.68134,1.27076,0.55229,1.27714,0.005941,0.005941,0.005941
+43,8113.31,0.16304,0.17093,0.77627,0.96909,0.95623,0.98692,0.77479,0.97729,0.45248,1.03101,0.005842,0.005842,0.005842
+44,8298.3,0.16163,0.16817,0.77509,0.96809,0.9575,0.98709,0.80448,0.85637,0.40945,0.96298,0.005743,0.005743,0.005743
+45,8485.25,0.16053,0.16768,0.77535,0.97311,0.95211,0.98726,0.8047,0.85835,0.40888,0.96605,0.005644,0.005644,0.005644
+46,8669.88,0.15959,0.16634,0.77576,0.97431,0.95186,0.98739,0.797,0.87218,0.41446,0.97186,0.005545,0.005545,0.005545
+47,8853.93,0.15778,0.16234,0.77599,0.97532,0.95052,0.98702,0.78582,0.92511,0.43102,1.0039,0.005446,0.005446,0.005446
+48,9037.52,0.15602,0.16175,0.77439,0.97529,0.94998,0.9874,0.84071,0.77064,0.38361,0.93995,0.005347,0.005347,0.005347
+49,9223.9,0.15478,0.1604,0.77364,0.97345,0.95143,0.98729,0.84662,0.73185,0.37143,0.92248,0.005248,0.005248,0.005248
+50,9411.45,0.15431,0.1584,0.77449,0.98033,0.94592,0.98733,0.86995,0.63137,0.34173,0.8816,0.005149,0.005149,0.005149
+51,9595.98,0.15305,0.15648,0.77414,0.97318,0.95431,0.98753,0.86938,0.63298,0.34305,0.88158,0.00505,0.00505,0.00505
+52,9779.61,0.15291,0.15561,0.77441,0.97824,0.94936,0.98777,0.87333,0.60174,0.33831,0.87135,0.004951,0.004951,0.004951
+53,9963.17,0.15193,0.15454,0.77361,0.97077,0.95579,0.98737,0.87038,0.62864,0.34404,0.87819,0.004852,0.004852,0.004852
+54,10151.9,0.14978,0.15218,0.77431,0.97892,0.94812,0.9874,0.86664,0.6559,0.35033,0.88752,0.004753,0.004753,0.004753
+55,10335.2,0.14867,0.14954,0.77318,0.98227,0.94556,0.98734,0.86535,0.66191,0.35233,0.89052,0.004654,0.004654,0.004654
+56,10517.7,0.14781,0.1504,0.77387,0.97187,0.95472,0.98731,0.85393,0.70291,0.36551,0.90638,0.004555,0.004555,0.004555
+57,10704.4,0.14704,0.14654,0.77286,0.96973,0.95539,0.98731,0.84386,0.75517,0.38619,0.92774,0.004456,0.004456,0.004456
+58,10888.6,0.14478,0.14588,0.77324,0.9792,0.94676,0.9872,0.84023,0.76095,0.38846,0.93011,0.004357,0.004357,0.004357
+59,11071.2,0.14408,0.14418,0.7724,0.9709,0.95553,0.98729,0.8499,0.71784,0.37089,0.91332,0.004258,0.004258,0.004258
+60,11256.4,0.1427,0.14165,0.77106,0.96919,0.95682,0.98729,0.85156,0.70256,0.36509,0.90774,0.004159,0.004159,0.004159
+61,11444.8,0.14194,0.14087,0.77269,0.96601,0.96121,0.98731,0.85107,0.70839,0.36753,0.90948,0.00406,0.00406,0.00406
+62,11630.9,0.14062,0.13882,0.77215,0.96628,0.96081,0.98736,0.84858,0.73074,0.3762,0.92033,0.003961,0.003961,0.003961
+63,11816.5,0.13938,0.13865,0.77152,0.96711,0.95961,0.98744,0.85214,0.70862,0.36754,0.91079,0.003862,0.003862,0.003862
+64,12005.3,0.13858,0.13687,0.77045,0.96702,0.9595,0.98748,0.85574,0.69672,0.36084,0.90588,0.003763,0.003763,0.003763
+65,12191.8,0.13775,0.13411,0.77132,0.96785,0.95943,0.9874,0.85729,0.68875,0.35766,0.90221,0.003664,0.003664,0.003664
+66,12379.6,0.13556,0.13271,0.77167,0.96725,0.96005,0.98735,0.85898,0.68174,0.3561,0.89887,0.003565,0.003565,0.003565
+67,12565.4,0.13463,0.13108,0.77009,0.97381,0.95338,0.98732,0.86031,0.67263,0.35399,0.89484,0.003466,0.003466,0.003466
+68,12752.5,0.13515,0.1311,0.77095,0.96906,0.95916,0.98725,0.86029,0.66717,0.35274,0.89292,0.003367,0.003367,0.003367
+69,12940.8,0.13415,0.12957,0.76963,0.97126,0.95685,0.9873,0.86049,0.6644,0.35306,0.8918,0.003268,0.003268,0.003268
+70,13133.2,0.13179,0.12737,0.76937,0.97287,0.95478,0.98727,0.86047,0.6632,0.35246,0.89193,0.003169,0.003169,0.003169
+71,13319.4,0.13185,0.1274,0.77079,0.97267,0.95587,0.98722,0.86193,0.65949,0.35213,0.89086,0.00307,0.00307,0.00307
+72,13504.8,0.12947,0.12446,0.76998,0.97199,0.95686,0.98725,0.86401,0.64895,0.34877,0.88741,0.002971,0.002971,0.002971
+73,13695.1,0.12876,0.12321,0.76883,0.9723,0.9569,0.98725,0.86643,0.64091,0.3473,0.88447,0.002872,0.002872,0.002872
+74,13882,0.12828,0.12194,0.76915,0.97256,0.95686,0.98732,0.86847,0.6322,0.34702,0.88109,0.002773,0.002773,0.002773
+75,14075,0.12664,0.11944,0.76878,0.97277,0.95678,0.98726,0.86861,0.63123,0.3482,0.88086,0.002674,0.002674,0.002674
+76,14259.9,0.12587,0.11965,0.7692,0.9727,0.95673,0.98717,0.86916,0.62721,0.34811,0.87932,0.002575,0.002575,0.002575
+77,14451.2,0.12433,0.1174,0.76838,0.97267,0.95663,0.9872,0.87057,0.62032,0.34709,0.8769,0.002476,0.002476,0.002476
+78,14636.3,0.12352,0.11507,0.76971,0.97087,0.95829,0.98721,0.87189,0.61271,0.34667,0.87445,0.002377,0.002377,0.002377
+79,14821.1,0.1231,0.11454,0.76897,0.97195,0.95752,0.98722,0.87292,0.60714,0.34596,0.87271,0.002278,0.002278,0.002278
+80,15007.3,0.12117,0.11285,0.76864,0.97146,0.95789,0.98726,0.8735,0.6031,0.34515,0.87163,0.002179,0.002179,0.002179
+81,15199.5,0.12029,0.11158,0.76708,0.97018,0.95938,0.9872,0.87378,0.60116,0.34538,0.87113,0.00208,0.00208,0.00208
+82,15390.8,0.11877,0.10949,0.76719,0.97021,0.95964,0.98721,0.87422,0.59897,0.3457,0.87057,0.001981,0.001981,0.001981
+83,15577.6,0.11812,0.10818,0.76749,0.97013,0.95951,0.98722,0.87429,0.59878,0.34524,0.87054,0.001882,0.001882,0.001882
+84,15761.5,0.11687,0.10634,0.76703,0.97155,0.9583,0.98713,0.87407,0.59964,0.34532,0.8709,0.001783,0.001783,0.001783
+85,15946.2,0.11551,0.10455,0.7672,0.9717,0.95797,0.9871,0.87367,0.60049,0.34569,0.87136,0.001684,0.001684,0.001684
+86,16130.6,0.11474,0.10479,0.76737,0.97183,0.95808,0.98712,0.87406,0.5981,0.34504,0.87076,0.001585,0.001585,0.001585
+87,16324.1,0.11337,0.10221,0.76695,0.97137,0.95881,0.98708,0.87382,0.59851,0.34519,0.87106,0.001486,0.001486,0.001486
+88,16517.1,0.11185,0.10043,0.76513,0.97121,0.95899,0.98707,0.87379,0.59906,0.34583,0.87135,0.001387,0.001387,0.001387
+89,16708.5,0.11103,0.09846,0.76565,0.97113,0.95904,0.98709,0.87369,0.59838,0.34599,0.87138,0.001288,0.001288,0.001288
+90,16896.9,0.11054,0.0982,0.76703,0.97095,0.95892,0.98712,0.87377,0.59757,0.34552,0.87126,0.001189,0.001189,0.001189
+91,17091.5,0.10967,0.09616,0.76665,0.97037,0.9595,0.98704,0.87361,0.59635,0.34561,0.87111,0.00109,0.00109,0.00109
+92,17282.9,0.10834,0.09481,0.76509,0.9726,0.95743,0.98704,0.87372,0.5956,0.34572,0.87108,0.000991,0.000991,0.000991
+93,17471.1,0.10692,0.09247,0.76461,0.97255,0.95738,0.9871,0.87368,0.59467,0.34689,0.8707,0.000892,0.000892,0.000892
+94,17654.7,0.10578,0.09076,0.76573,0.97167,0.95786,0.9872,0.87367,0.59367,0.34732,0.87049,0.000793,0.000793,0.000793
+95,17858.1,0.10457,0.08903,0.7648,0.97097,0.95816,0.98718,0.87394,0.59295,0.34757,0.87044,0.000694,0.000694,0.000694
+96,18048,0.10283,0.08802,0.76437,0.97358,0.95577,0.98712,0.8737,0.59392,0.34877,0.87087,0.000595,0.000595,0.000595
+97,18233,0.10269,0.08685,0.76468,0.97469,0.95492,0.98712,0.8741,0.59227,0.34903,0.87042,0.000496,0.000496,0.000496
+98,18418.2,0.10143,0.0852,0.7644,0.97473,0.95512,0.98709,0.87397,0.59171,0.35007,0.8704,0.000397,0.000397,0.000397
+99,18605.1,0.10052,0.08363,0.76442,0.97443,0.95526,0.98712,0.87396,0.5922,0.35121,0.87087,0.000298,0.000298,0.000298
+100,18790,0.09925,0.08228,0.76465,0.97498,0.95493,0.98711,0.8737,0.59312,0.35293,0.87138,0.000199,0.000199,0.000199
diff --git a/tests/shared/augmentation/__init__.py b/tests/shared/augmentation/__init__.py
new file mode 100644
index 0000000..56fc805
--- /dev/null
+++ b/tests/shared/augmentation/__init__.py
@@ -0,0 +1 @@
+# Tests for augmentation module
diff --git a/tests/shared/augmentation/test_base.py b/tests/shared/augmentation/test_base.py
new file mode 100644
index 0000000..3f238d1
--- /dev/null
+++ b/tests/shared/augmentation/test_base.py
@@ -0,0 +1,347 @@
+"""
+Tests for augmentation base module.
+
+TDD Phase 1: RED - Write tests first, then implement to pass.
+"""
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+
+
+class TestAugmentationResult:
+ """Tests for AugmentationResult dataclass."""
+
+ def test_minimal_result(self) -> None:
+ """Test creating result with only required field."""
+ from shared.augmentation.base import AugmentationResult
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ result = AugmentationResult(image=image)
+
+ assert result.image is image
+ assert result.bboxes is None
+ assert result.transform_matrix is None
+ assert result.applied is True
+ assert result.metadata is None
+
+ def test_full_result(self) -> None:
+ """Test creating result with all fields."""
+ from shared.augmentation.base import AugmentationResult
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
+ transform = np.eye(3)
+ metadata = {"applied_transform": "wrinkle"}
+
+ result = AugmentationResult(
+ image=image,
+ bboxes=bboxes,
+ transform_matrix=transform,
+ applied=True,
+ metadata=metadata,
+ )
+
+ assert result.image is image
+ np.testing.assert_array_equal(result.bboxes, bboxes)
+ np.testing.assert_array_equal(result.transform_matrix, transform)
+ assert result.applied is True
+ assert result.metadata == {"applied_transform": "wrinkle"}
+
+ def test_not_applied(self) -> None:
+ """Test result when augmentation was not applied."""
+ from shared.augmentation.base import AugmentationResult
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ result = AugmentationResult(image=image, applied=False)
+
+ assert result.applied is False
+
+
+class TestBaseAugmentation:
+ """Tests for BaseAugmentation abstract class."""
+
+ def test_cannot_instantiate_directly(self) -> None:
+ """Test that BaseAugmentation cannot be instantiated."""
+ from shared.augmentation.base import BaseAugmentation
+
+ with pytest.raises(TypeError):
+ BaseAugmentation({}) # type: ignore
+
+ def test_subclass_must_implement_apply(self) -> None:
+ """Test that subclass must implement apply method."""
+ from shared.augmentation.base import BaseAugmentation
+
+ class IncompleteAugmentation(BaseAugmentation):
+ name = "incomplete"
+
+ def _validate_params(self) -> None:
+ pass
+
+ # Missing apply method
+
+ with pytest.raises(TypeError):
+ IncompleteAugmentation({}) # type: ignore
+
+ def test_subclass_must_implement_validate_params(self) -> None:
+ """Test that subclass must implement _validate_params."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class IncompleteAugmentation(BaseAugmentation):
+ name = "incomplete"
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ return AugmentationResult(image=image)
+
+ # Missing _validate_params method
+
+ with pytest.raises(TypeError):
+ IncompleteAugmentation({}) # type: ignore
+
+ def test_valid_subclass(self) -> None:
+ """Test creating a valid subclass."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class DummyAugmentation(BaseAugmentation):
+ name = "dummy"
+ affects_geometry = False
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ return AugmentationResult(image=image, bboxes=bboxes)
+
+ aug = DummyAugmentation({"param1": "value1"})
+
+ assert aug.name == "dummy"
+ assert aug.affects_geometry is False
+ assert aug.params == {"param1": "value1"}
+
+ def test_apply_returns_augmentation_result(self) -> None:
+ """Test that apply returns AugmentationResult."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class DummyAugmentation(BaseAugmentation):
+ name = "dummy"
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ # Simple pass-through
+ return AugmentationResult(image=image, bboxes=bboxes)
+
+ aug = DummyAugmentation({})
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
+
+ result = aug.apply(image, bboxes)
+
+ assert isinstance(result, AugmentationResult)
+ assert result.image is image
+ np.testing.assert_array_equal(result.bboxes, bboxes)
+
+ def test_affects_geometry_default(self) -> None:
+ """Test that affects_geometry defaults to False."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class DummyAugmentation(BaseAugmentation):
+ name = "dummy"
+ # Not setting affects_geometry
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ return AugmentationResult(image=image)
+
+ aug = DummyAugmentation({})
+
+ assert aug.affects_geometry is False
+
+ def test_validate_params_called_on_init(self) -> None:
+ """Test that _validate_params is called during initialization."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ validation_called = {"called": False}
+
+ class ValidatingAugmentation(BaseAugmentation):
+ name = "validating"
+
+ def _validate_params(self) -> None:
+ validation_called["called"] = True
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ return AugmentationResult(image=image)
+
+ ValidatingAugmentation({})
+
+ assert validation_called["called"] is True
+
+ def test_validate_params_raises_on_invalid(self) -> None:
+ """Test that _validate_params can raise ValueError."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class StrictAugmentation(BaseAugmentation):
+ name = "strict"
+
+ def _validate_params(self) -> None:
+ if "required_param" not in self.params:
+ raise ValueError("required_param is required")
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ return AugmentationResult(image=image)
+
+ with pytest.raises(ValueError, match="required_param"):
+ StrictAugmentation({})
+
+ # Should work with required param
+ aug = StrictAugmentation({"required_param": "value"})
+ assert aug.params["required_param"] == "value"
+
+ def test_rng_usage(self) -> None:
+ """Test that random generator can be passed and used."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class RandomAugmentation(BaseAugmentation):
+ name = "random"
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ if rng is None:
+ rng = np.random.default_rng()
+ # Use rng to generate a random value
+ random_value = rng.random()
+ return AugmentationResult(
+ image=image,
+ metadata={"random_value": random_value},
+ )
+
+ aug = RandomAugmentation({})
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+
+ # With same seed, should get same result
+ rng1 = np.random.default_rng(42)
+ rng2 = np.random.default_rng(42)
+
+ result1 = aug.apply(image, rng=rng1)
+ result2 = aug.apply(image, rng=rng2)
+
+ assert result1.metadata is not None
+ assert result2.metadata is not None
+ assert result1.metadata["random_value"] == result2.metadata["random_value"]
+
+
+class TestAugmentationResultImmutability:
+ """Tests for ensuring result doesn't mutate input."""
+
+ def test_image_not_modified(self) -> None:
+ """Test that original image is not modified."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class ModifyingAugmentation(BaseAugmentation):
+ name = "modifying"
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ # Should copy before modifying
+ modified = image.copy()
+ modified[:] = 255
+ return AugmentationResult(image=modified)
+
+ aug = ModifyingAugmentation({})
+ original = np.zeros((100, 100, 3), dtype=np.uint8)
+ original_copy = original.copy()
+
+ result = aug.apply(original)
+
+ # Original should be unchanged
+ np.testing.assert_array_equal(original, original_copy)
+ # Result should be modified
+ assert np.all(result.image == 255)
+
+ def test_bboxes_not_modified(self) -> None:
+ """Test that original bboxes are not modified."""
+ from shared.augmentation.base import AugmentationResult, BaseAugmentation
+
+ class BboxModifyingAugmentation(BaseAugmentation):
+ name = "bbox_modifying"
+ affects_geometry = True
+
+ def _validate_params(self) -> None:
+ pass
+
+ def apply(
+ self,
+ image: np.ndarray,
+ bboxes: np.ndarray | None = None,
+ rng: np.random.Generator | None = None,
+ ) -> AugmentationResult:
+ if bboxes is not None:
+ # Should copy before modifying
+ modified_bboxes = bboxes.copy()
+ modified_bboxes[:, 1:] *= 0.5 # Scale down
+ return AugmentationResult(image=image, bboxes=modified_bboxes)
+ return AugmentationResult(image=image)
+
+ aug = BboxModifyingAugmentation({})
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
+ original_bboxes_copy = original_bboxes.copy()
+
+ result = aug.apply(image, original_bboxes)
+
+ # Original should be unchanged
+ np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
+ # Result should be modified
+ assert result.bboxes is not None
+ np.testing.assert_array_almost_equal(
+ result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
+ )
diff --git a/tests/shared/augmentation/test_config.py b/tests/shared/augmentation/test_config.py
new file mode 100644
index 0000000..4b6047c
--- /dev/null
+++ b/tests/shared/augmentation/test_config.py
@@ -0,0 +1,283 @@
+"""
+Tests for augmentation configuration module.
+
+TDD Phase 1: RED - Write tests first, then implement to pass.
+"""
+
+from typing import Any
+
+import pytest
+
+
+class TestAugmentationParams:
+ """Tests for AugmentationParams dataclass."""
+
+ def test_default_values(self) -> None:
+ """Test default parameter values."""
+ from shared.augmentation.config import AugmentationParams
+
+ params = AugmentationParams()
+
+ assert params.enabled is False
+ assert params.probability == 0.5
+ assert params.params == {}
+
+ def test_custom_values(self) -> None:
+ """Test creating params with custom values."""
+ from shared.augmentation.config import AugmentationParams
+
+ params = AugmentationParams(
+ enabled=True,
+ probability=0.8,
+ params={"intensity": 0.5, "num_wrinkles": (2, 5)},
+ )
+
+ assert params.enabled is True
+ assert params.probability == 0.8
+ assert params.params["intensity"] == 0.5
+ assert params.params["num_wrinkles"] == (2, 5)
+
+ def test_immutability_params_dict(self) -> None:
+ """Test that params dict is independent between instances."""
+ from shared.augmentation.config import AugmentationParams
+
+ params1 = AugmentationParams()
+ params2 = AugmentationParams()
+
+ # Modifying one should not affect the other
+ params1.params["test"] = "value"
+
+ assert "test" not in params2.params
+
+ def test_to_dict(self) -> None:
+ """Test conversion to dictionary."""
+ from shared.augmentation.config import AugmentationParams
+
+ params = AugmentationParams(
+ enabled=True,
+ probability=0.7,
+ params={"key": "value"},
+ )
+
+ result = params.to_dict()
+
+ assert result == {
+ "enabled": True,
+ "probability": 0.7,
+ "params": {"key": "value"},
+ }
+
+ def test_from_dict(self) -> None:
+ """Test creation from dictionary."""
+ from shared.augmentation.config import AugmentationParams
+
+ data = {
+ "enabled": True,
+ "probability": 0.6,
+ "params": {"intensity": 0.3},
+ }
+
+ params = AugmentationParams.from_dict(data)
+
+ assert params.enabled is True
+ assert params.probability == 0.6
+ assert params.params == {"intensity": 0.3}
+
+ def test_from_dict_with_defaults(self) -> None:
+ """Test creation from partial dictionary uses defaults."""
+ from shared.augmentation.config import AugmentationParams
+
+ data: dict[str, Any] = {"enabled": True}
+
+ params = AugmentationParams.from_dict(data)
+
+ assert params.enabled is True
+ assert params.probability == 0.5 # default
+ assert params.params == {} # default
+
+
+class TestAugmentationConfig:
+ """Tests for AugmentationConfig dataclass."""
+
+ def test_default_values(self) -> None:
+ """Test that all augmentation types have defaults."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config = AugmentationConfig()
+
+ # All augmentation types should exist
+ augmentation_types = [
+ "perspective_warp",
+ "wrinkle",
+ "edge_damage",
+ "stain",
+ "lighting_variation",
+ "shadow",
+ "gaussian_blur",
+ "motion_blur",
+ "gaussian_noise",
+ "salt_pepper",
+ "paper_texture",
+ "scanner_artifacts",
+ ]
+
+ for aug_type in augmentation_types:
+ assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
+ params = getattr(config, aug_type)
+ assert hasattr(params, "enabled")
+ assert hasattr(params, "probability")
+ assert hasattr(params, "params")
+
+ def test_global_settings_defaults(self) -> None:
+ """Test global settings default values."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config = AugmentationConfig()
+
+ assert config.preserve_bboxes is True
+ assert config.seed is None
+
+ def test_custom_seed(self) -> None:
+ """Test setting custom seed for reproducibility."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config = AugmentationConfig(seed=42)
+
+ assert config.seed == 42
+
+ def test_to_dict(self) -> None:
+ """Test conversion to dictionary."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config = AugmentationConfig(seed=123, preserve_bboxes=False)
+
+ result = config.to_dict()
+
+ assert isinstance(result, dict)
+ assert result["seed"] == 123
+ assert result["preserve_bboxes"] is False
+ assert "perspective_warp" in result
+ assert "wrinkle" in result
+
+ def test_from_dict(self) -> None:
+ """Test creation from dictionary."""
+ from shared.augmentation.config import AugmentationConfig
+
+ data = {
+ "seed": 456,
+ "preserve_bboxes": False,
+ "wrinkle": {
+ "enabled": True,
+ "probability": 0.8,
+ "params": {"intensity": 0.5},
+ },
+ }
+
+ config = AugmentationConfig.from_dict(data)
+
+ assert config.seed == 456
+ assert config.preserve_bboxes is False
+ assert config.wrinkle.enabled is True
+ assert config.wrinkle.probability == 0.8
+ assert config.wrinkle.params["intensity"] == 0.5
+
+ def test_from_dict_with_partial_data(self) -> None:
+ """Test creation from partial dictionary uses defaults."""
+ from shared.augmentation.config import AugmentationConfig
+
+ data: dict[str, Any] = {
+ "wrinkle": {"enabled": True},
+ }
+
+ config = AugmentationConfig.from_dict(data)
+
+ # Explicitly set value
+ assert config.wrinkle.enabled is True
+ # Default values
+ assert config.preserve_bboxes is True
+ assert config.seed is None
+ assert config.gaussian_blur.enabled is False
+
+ def test_get_enabled_augmentations(self) -> None:
+ """Test getting list of enabled augmentations."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+
+ config = AugmentationConfig(
+ wrinkle=AugmentationParams(enabled=True),
+ stain=AugmentationParams(enabled=True),
+ gaussian_blur=AugmentationParams(enabled=False),
+ )
+
+ enabled = config.get_enabled_augmentations()
+
+ assert "wrinkle" in enabled
+ assert "stain" in enabled
+ assert "gaussian_blur" not in enabled
+
+ def test_document_safe_defaults(self) -> None:
+ """Test that default params are document-safe (conservative)."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config = AugmentationConfig()
+
+ # Perspective warp should be very conservative
+ assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
+
+ # Noise should be subtle
+ noise_std = config.gaussian_noise.params.get("std", (5, 15))
+ if isinstance(noise_std, tuple):
+ assert noise_std[1] <= 20 # Max std should be low
+
+ def test_immutability_between_instances(self) -> None:
+ """Test that config instances are independent."""
+ from shared.augmentation.config import AugmentationConfig
+
+ config1 = AugmentationConfig()
+ config2 = AugmentationConfig()
+
+ # Modifying one should not affect the other
+ config1.wrinkle.params["test"] = "value"
+
+ assert "test" not in config2.wrinkle.params
+
+
+class TestAugmentationConfigValidation:
+ """Tests for configuration validation."""
+
+ def test_probability_range_validation(self) -> None:
+ """Test that probability values are validated."""
+ from shared.augmentation.config import AugmentationParams
+
+ # Valid range
+ params = AugmentationParams(probability=0.5)
+ assert params.probability == 0.5
+
+ # Edge cases
+ params_zero = AugmentationParams(probability=0.0)
+ assert params_zero.probability == 0.0
+
+ params_one = AugmentationParams(probability=1.0)
+ assert params_one.probability == 1.0
+
+ def test_config_validate_method(self) -> None:
+ """Test the validate method catches invalid configurations."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+
+ # Invalid probability
+ config = AugmentationConfig(
+ wrinkle=AugmentationParams(probability=1.5), # Invalid
+ )
+
+ with pytest.raises(ValueError, match="probability"):
+ config.validate()
+
+ def test_config_validate_negative_probability(self) -> None:
+ """Test validation catches negative probability."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+
+ config = AugmentationConfig(
+ wrinkle=AugmentationParams(probability=-0.1),
+ )
+
+ with pytest.raises(ValueError, match="probability"):
+ config.validate()
diff --git a/tests/shared/augmentation/test_pipeline.py b/tests/shared/augmentation/test_pipeline.py
new file mode 100644
index 0000000..f317b00
--- /dev/null
+++ b/tests/shared/augmentation/test_pipeline.py
@@ -0,0 +1,338 @@
+"""
+Tests for augmentation pipeline module.
+
+TDD Phase 2: RED - Write tests first, then implement to pass.
+"""
+
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+
+class TestAugmentationPipeline:
+ """Tests for AugmentationPipeline class."""
+
+ def test_create_with_config(self) -> None:
+ """Test creating pipeline with config."""
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig()
+ pipeline = AugmentationPipeline(config)
+
+ assert pipeline.config is config
+
+ def test_create_with_seed(self) -> None:
+ """Test creating pipeline with seed for reproducibility."""
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig(seed=42)
+ pipeline = AugmentationPipeline(config)
+
+ assert pipeline.config.seed == 42
+
+ def test_apply_returns_augmentation_result(self) -> None:
+ """Test that apply returns AugmentationResult."""
+ from shared.augmentation.base import AugmentationResult
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig()
+ pipeline = AugmentationPipeline(config)
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ result = pipeline.apply(image)
+
+ assert isinstance(result, AugmentationResult)
+ assert result.image is not None
+ assert result.image.shape == image.shape
+
+ def test_apply_with_bboxes(self) -> None:
+ """Test apply with bounding boxes."""
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig()
+ pipeline = AugmentationPipeline(config)
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+ bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)
+
+ result = pipeline.apply(image, bboxes)
+
+ # Bboxes should be preserved when preserve_bboxes=True
+ assert result.bboxes is not None
+
+ def test_apply_no_augmentations_enabled(self) -> None:
+ """Test apply when no augmentations are enabled."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ # Disable all augmentations
+ config = AugmentationConfig(
+ lighting_variation=AugmentationParams(enabled=False),
+ )
+ pipeline = AugmentationPipeline(config)
+
+ image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+ result = pipeline.apply(image)
+
+ # Image should be unchanged (or a copy)
+ np.testing.assert_array_equal(result.image, image)
+
+ def test_apply_does_not_mutate_input(self) -> None:
+ """Test that apply does not mutate input image."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig(
+ lighting_variation=AugmentationParams(enabled=True, probability=1.0),
+ )
+ pipeline = AugmentationPipeline(config)
+
+ image = np.full((100, 100, 3), 128, dtype=np.uint8)
+ original_copy = image.copy()
+
+ pipeline.apply(image)
+
+ np.testing.assert_array_equal(image, original_copy)
+
+ def test_reproducibility_with_seed(self) -> None:
+ """Test that same seed produces same results."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config1 = AugmentationConfig(
+ seed=42,
+ gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
+ )
+ config2 = AugmentationConfig(
+ seed=42,
+ gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
+ )
+
+ pipeline1 = AugmentationPipeline(config1)
+ pipeline2 = AugmentationPipeline(config2)
+
+ image = np.full((100, 100, 3), 128, dtype=np.uint8)
+
+ result1 = pipeline1.apply(image.copy())
+ result2 = pipeline2.apply(image.copy())
+
+ np.testing.assert_array_equal(result1.image, result2.image)
+
+ def test_metadata_contains_applied_augmentations(self) -> None:
+ """Test that metadata lists applied augmentations."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig(
+ seed=42,
+ gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
+ lighting_variation=AugmentationParams(enabled=True, probability=1.0),
+ )
+ pipeline = AugmentationPipeline(config)
+
+ image = np.full((100, 100, 3), 128, dtype=np.uint8)
+ result = pipeline.apply(image)
+
+ assert result.metadata is not None
+ assert "applied_augmentations" in result.metadata
+ # Both should be applied with probability=1.0
+ assert "gaussian_noise" in result.metadata["applied_augmentations"]
+ assert "lighting_variation" in result.metadata["applied_augmentations"]
+
+
+class TestAugmentationPipelineStageOrder:
+ """Tests for pipeline stage ordering."""
+
+ def test_stage_order_defined(self) -> None:
+ """Test that stage order is defined."""
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ assert hasattr(AugmentationPipeline, "STAGE_ORDER")
+ expected_stages = [
+ "geometric",
+ "degradation",
+ "lighting",
+ "texture",
+ "blur",
+ "noise",
+ ]
+ assert AugmentationPipeline.STAGE_ORDER == expected_stages
+
+ def test_stage_mapping_defined(self) -> None:
+ """Test that all augmentation types are mapped to stages."""
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ assert hasattr(AugmentationPipeline, "STAGE_MAPPING")
+
+ expected_mappings = {
+ "perspective_warp": "geometric",
+ "wrinkle": "degradation",
+ "edge_damage": "degradation",
+ "stain": "degradation",
+ "lighting_variation": "lighting",
+ "shadow": "lighting",
+ "paper_texture": "texture",
+ "scanner_artifacts": "texture",
+ "gaussian_blur": "blur",
+ "motion_blur": "blur",
+ "gaussian_noise": "noise",
+ "salt_pepper": "noise",
+ }
+
+ for aug_name, stage in expected_mappings.items():
+ assert aug_name in AugmentationPipeline.STAGE_MAPPING
+ assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage
+
+ def test_geometric_before_degradation(self) -> None:
+ """Test that geometric transforms are applied before degradation."""
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ stages = AugmentationPipeline.STAGE_ORDER
+ geometric_idx = stages.index("geometric")
+ degradation_idx = stages.index("degradation")
+
+ assert geometric_idx < degradation_idx
+
+ def test_noise_applied_last(self) -> None:
+ """Test that noise is applied last."""
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ stages = AugmentationPipeline.STAGE_ORDER
+ assert stages[-1] == "noise"
+
+
+class TestAugmentationRegistry:
+ """Tests for augmentation registry."""
+
+ def test_registry_exists(self) -> None:
+ """Test that augmentation registry exists."""
+ from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
+
+ assert isinstance(AUGMENTATION_REGISTRY, dict)
+
+ def test_registry_contains_all_types(self) -> None:
+ """Test that registry contains all augmentation types."""
+ from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
+
+ expected_types = [
+ "perspective_warp",
+ "wrinkle",
+ "edge_damage",
+ "stain",
+ "lighting_variation",
+ "shadow",
+ "gaussian_blur",
+ "motion_blur",
+ "gaussian_noise",
+ "salt_pepper",
+ "paper_texture",
+ "scanner_artifacts",
+ ]
+
+ for aug_type in expected_types:
+ assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"
+
+
+class TestPipelinePreview:
+ """Tests for pipeline preview functionality."""
+
+ def test_preview_single_augmentation(self) -> None:
+ """Test previewing a single augmentation."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig(
+ gaussian_noise=AugmentationParams(
+ enabled=True, probability=1.0, params={"std": (10, 10)}
+ ),
+ )
+ pipeline = AugmentationPipeline(config)
+
+ image = np.full((100, 100, 3), 128, dtype=np.uint8)
+ preview = pipeline.preview(image, "gaussian_noise")
+
+ assert preview.shape == image.shape
+ assert preview.dtype == np.uint8
+ # Preview should modify the image
+ assert not np.array_equal(preview, image)
+
+ def test_preview_unknown_augmentation_raises(self) -> None:
+ """Test that previewing unknown augmentation raises error."""
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig()
+ pipeline = AugmentationPipeline(config)
+
+ image = np.zeros((100, 100, 3), dtype=np.uint8)
+
+ with pytest.raises(ValueError, match="Unknown augmentation"):
+ pipeline.preview(image, "non_existent_augmentation")
+
+ def test_preview_is_deterministic(self) -> None:
+ """Test that preview produces deterministic results."""
+ from shared.augmentation.config import AugmentationConfig, AugmentationParams
+ from shared.augmentation.pipeline import AugmentationPipeline
+
+ config = AugmentationConfig(
+ gaussian_noise=AugmentationParams(enabled=True),
+ )
+ pipeline = AugmentationPipeline(config)
+
+ image = np.full((100, 100, 3), 128, dtype=np.uint8)
+
+ preview1 = pipeline.preview(image, "gaussian_noise")
+ preview2 = pipeline.preview(image, "gaussian_noise")
+
+ np.testing.assert_array_equal(preview1, preview2)
+
+
+class TestPipelineGetAvailableAugmentations:
+ """Tests for getting available augmentations."""
+
+ def test_get_available_augmentations(self) -> None:
+ """Test getting list of available augmentations."""
+ from shared.augmentation.pipeline import get_available_augmentations
+
+ augmentations = get_available_augmentations()
+
+ assert isinstance(augmentations, list)
+ assert len(augmentations) == 12
+
+ # Each item should have name, description, affects_geometry
+ for aug in augmentations:
+ assert "name" in aug
+ assert "description" in aug
+ assert "affects_geometry" in aug
+ assert "stage" in aug
+
+ def test_get_available_augmentations_includes_all_types(self) -> None:
+ """Test that all augmentation types are included."""
+ from shared.augmentation.pipeline import get_available_augmentations
+
+ augmentations = get_available_augmentations()
+ names = [aug["name"] for aug in augmentations]
+
+ expected = [
+ "perspective_warp",
+ "wrinkle",
+ "edge_damage",
+ "stain",
+ "lighting_variation",
+ "shadow",
+ "gaussian_blur",
+ "motion_blur",
+ "gaussian_noise",
+ "salt_pepper",
+ "paper_texture",
+ "scanner_artifacts",
+ ]
+
+ for name in expected:
+ assert name in names
diff --git a/tests/shared/augmentation/test_presets.py b/tests/shared/augmentation/test_presets.py
new file mode 100644
index 0000000..00e4051
--- /dev/null
+++ b/tests/shared/augmentation/test_presets.py
@@ -0,0 +1,102 @@
+"""
+Tests for augmentation presets module.
+
+TDD Phase 4: RED - Write tests first, then implement to pass.
+"""
+
+import pytest
+
+
+class TestPresets:
+ """Tests for augmentation presets."""
+
+ def test_presets_dict_exists(self) -> None:
+ """Test that PRESETS dictionary exists."""
+ from shared.augmentation.presets import PRESETS
+
+ assert isinstance(PRESETS, dict)
+ assert len(PRESETS) > 0
+
+ def test_expected_presets_exist(self) -> None:
+ """Test that expected presets are defined."""
+ from shared.augmentation.presets import PRESETS
+
+ expected_presets = ["conservative", "moderate", "aggressive", "scanned_document"]
+
+ for preset_name in expected_presets:
+ assert preset_name in PRESETS, f"Missing preset: {preset_name}"
+
+ def test_preset_structure(self) -> None:
+ """Test that each preset has required structure."""
+ from shared.augmentation.presets import PRESETS
+
+ for name, preset in PRESETS.items():
+ assert "description" in preset, f"Preset {name} missing description"
+ assert "config" in preset, f"Preset {name} missing config"
+ assert isinstance(preset["description"], str)
+ assert isinstance(preset["config"], dict)
+
+ def test_get_preset_config(self) -> None:
+ """Test getting config from preset."""
+ from shared.augmentation.presets import get_preset_config
+
+ config = get_preset_config("conservative")
+
+ assert config is not None
+ # Should have at least some augmentations defined
+ assert len(config) > 0
+
+ def test_get_preset_config_unknown_raises(self) -> None:
+ """Test that getting unknown preset raises error."""
+ from shared.augmentation.presets import get_preset_config
+
+ with pytest.raises(ValueError, match="Unknown preset"):
+ get_preset_config("nonexistent_preset")
+
+ def test_create_config_from_preset(self) -> None:
+ """Test creating AugmentationConfig from preset."""
+ from shared.augmentation.config import AugmentationConfig
+ from shared.augmentation.presets import create_config_from_preset
+
+ config = create_config_from_preset("moderate")
+
+ assert isinstance(config, AugmentationConfig)
+
+ def test_conservative_preset_is_safe(self) -> None:
+ """Test that conservative preset only enables safe augmentations."""
+ from shared.augmentation.presets import create_config_from_preset
+
+ config = create_config_from_preset("conservative")
+
+ # Should NOT enable geometric transforms
+ assert config.perspective_warp.enabled is False
+
+ # Should NOT enable heavy degradation
+ assert config.wrinkle.enabled is False
+ assert config.edge_damage.enabled is False
+ assert config.stain.enabled is False
+
+ def test_aggressive_preset_enables_more(self) -> None:
+ """Test that aggressive preset enables more augmentations."""
+ from shared.augmentation.presets import create_config_from_preset
+
+ config = create_config_from_preset("aggressive")
+
+ enabled = config.get_enabled_augmentations()
+
+ # Should enable multiple augmentation types
+ assert len(enabled) >= 6
+
+ def test_list_presets(self) -> None:
+ """Test listing available presets."""
+ from shared.augmentation.presets import list_presets
+
+ presets = list_presets()
+
+ assert isinstance(presets, list)
+ assert len(presets) >= 4
+
+ # Each item should have name and description
+ for preset in presets:
+ assert "name" in preset
+ assert "description" in preset
diff --git a/tests/shared/augmentation/transforms/__init__.py b/tests/shared/augmentation/transforms/__init__.py
new file mode 100644
index 0000000..f66e61f
--- /dev/null
+++ b/tests/shared/augmentation/transforms/__init__.py
@@ -0,0 +1 @@
+# Tests for augmentation transforms
diff --git a/tests/shared/test_dataset_augmenter.py b/tests/shared/test_dataset_augmenter.py
new file mode 100644
index 0000000..d9704e9
--- /dev/null
+++ b/tests/shared/test_dataset_augmenter.py
@@ -0,0 +1,293 @@
+"""
+Tests for DatasetAugmenter.
+
+TDD Phase 1: RED - Write tests first, then implement to pass.
+"""
+
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+class TestDatasetAugmenter:
+ """Tests for DatasetAugmenter class."""
+
+ @pytest.fixture
+ def sample_dataset(self, tmp_path: Path) -> Path:
+ """Create a sample YOLO dataset structure."""
+ dataset_dir = tmp_path / "dataset"
+
+ # Create directory structure
+ for split in ["train", "val", "test"]:
+ (dataset_dir / "images" / split).mkdir(parents=True)
+ (dataset_dir / "labels" / split).mkdir(parents=True)
+
+ # Create sample images and labels
+ for i in range(3):
+ # Create 100x100 white image
+ img = Image.new("RGB", (100, 100), color="white")
+ img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
+ img.save(img_path)
+
+ # Create label with 2 bboxes
+ # Format: class_id x_center y_center width height
+ label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
+ label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
+ label_path.write_text(label_content)
+
+ # Create data.yaml
+ data_yaml = dataset_dir / "data.yaml"
+ data_yaml.write_text(
+ "path: .\n"
+ "train: images/train\n"
+ "val: images/val\n"
+ "test: images/test\n"
+ "nc: 10\n"
+ "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
+ )
+
+ return dataset_dir
+
+ @pytest.fixture
+ def augmentation_config(self) -> dict:
+ """Create a sample augmentation config."""
+ return {
+ "gaussian_noise": {
+ "enabled": True,
+ "probability": 1.0,
+ "params": {"std": 10},
+ },
+ "gaussian_blur": {
+ "enabled": True,
+ "probability": 1.0,
+ "params": {"kernel_size": 3},
+ },
+ }
+
+ def test_augmenter_creates_additional_images(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmenter creates new augmented images."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ augmenter = DatasetAugmenter(augmentation_config)
+
+ # Count original images
+ original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
+ assert original_count == 3
+
+ # Apply augmentation with multiplier=2
+ result = augmenter.augment_dataset(sample_dataset, multiplier=2)
+
+ # Should now have original + 2x augmented = 3 + 6 = 9 images
+ new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
+ assert new_count == 9
+ assert result["augmented_images"] == 6
+
+ def test_augmenter_creates_matching_labels(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmenter creates label files for each augmented image."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ augmenter.augment_dataset(sample_dataset, multiplier=2)
+
+ # Check that each image has a matching label file
+ images = list((sample_dataset / "images" / "train").glob("*.png"))
+ labels = list((sample_dataset / "labels" / "train").glob("*.txt"))
+
+ assert len(images) == len(labels)
+
+ # Check that augmented images have corresponding labels
+ for img_path in images:
+ label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
+ assert label_path.exists(), f"Missing label for {img_path.name}"
+
+ def test_augmented_labels_have_valid_format(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmented label files have valid YOLO format."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ augmenter.augment_dataset(sample_dataset, multiplier=1)
+
+ # Check all label files
+ for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
+ content = label_path.read_text().strip()
+ if not content:
+ continue # Empty labels are valid (background images)
+
+ for line in content.split("\n"):
+ parts = line.split()
+ assert len(parts) == 5, f"Invalid label format in {label_path.name}"
+
+ class_id = int(parts[0])
+ x_center = float(parts[1])
+ y_center = float(parts[2])
+ width = float(parts[3])
+ height = float(parts[4])
+
+ # Check values are in valid range
+ assert 0 <= class_id < 100, f"Invalid class_id: {class_id}"
+ assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
+ assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
+ assert 0 <= width <= 1, f"Invalid width: {width}"
+ assert 0 <= height <= 1, f"Invalid height: {height}"
+
+ def test_augmented_images_are_different(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmented images are actually different from originals."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ # Load original image
+ original_path = sample_dataset / "images" / "train" / "doc_0.png"
+ original_img = np.array(Image.open(original_path))
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ augmenter.augment_dataset(sample_dataset, multiplier=1)
+
+ # Find augmented version
+ aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
+ assert aug_path.exists()
+
+ aug_img = np.array(Image.open(aug_path))
+
+ # Images should be different (due to noise/blur)
+ assert not np.array_equal(original_img, aug_img)
+
+ def test_augmented_images_same_size(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmented images have same size as originals."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ # Get original size
+ original_path = sample_dataset / "images" / "train" / "doc_0.png"
+ original_img = Image.open(original_path)
+ original_size = original_img.size
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ augmenter.augment_dataset(sample_dataset, multiplier=1)
+
+ # Check all augmented images have same size
+ for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
+ img = Image.open(img_path)
+ assert img.size == original_size, f"{img_path.name} has wrong size"
+
+ def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
+ """Test that perspective_warp augmentation updates bbox coordinates."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ config = {
+ "perspective_warp": {
+ "enabled": True,
+ "probability": 1.0,
+ "params": {"max_warp": 0.05}, # Use larger warp for visible difference
+ },
+ }
+
+ # Read original label
+ original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
+ original_bboxes = [line.split() for line in original_label.strip().split("\n")]
+
+ augmenter = DatasetAugmenter(config)
+ augmenter.augment_dataset(sample_dataset, multiplier=1)
+
+ # Read augmented label
+ aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
+ aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]
+
+ # Same number of bboxes
+ assert len(original_bboxes) == len(aug_bboxes)
+
+ # At least one bbox should have different coordinates
+ # (perspective warp changes geometry)
+ differences_found = False
+ for orig, aug in zip(original_bboxes, aug_bboxes):
+ # Class ID should be same
+ assert orig[0] == aug[0]
+ # Coordinates might differ
+ if orig[1:] != aug[1:]:
+ differences_found = True
+
+ assert differences_found, "Perspective warp should change bbox coordinates"
+
+ def test_augmenter_only_processes_train_split(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmenter only processes train split by default."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ # Add a val image
+ val_img = Image.new("RGB", (100, 100), color="white")
+ val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
+ (sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ augmenter.augment_dataset(sample_dataset, multiplier=2)
+
+ # Val should still have only 1 image
+ val_count = len(list((sample_dataset / "images" / "val").glob("*.png")))
+ assert val_count == 1
+
+ def test_augmenter_with_multiplier_zero_does_nothing(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that multiplier=0 creates no augmented images."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ result = augmenter.augment_dataset(sample_dataset, multiplier=0)
+
+ new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
+ assert new_count == original_count
+ assert result["augmented_images"] == 0
+
+ def test_augmenter_with_seed_is_reproducible(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that same seed produces same augmentation results."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ # Create two separate datasets
+ import shutil
+ dataset1 = sample_dataset
+ dataset2 = sample_dataset.parent / "dataset2"
+ shutil.copytree(dataset1, dataset2)
+
+ # Augment both with same seed
+ augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
+ augmenter1.augment_dataset(dataset1, multiplier=1)
+
+ augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
+ augmenter2.augment_dataset(dataset2, multiplier=1)
+
+ # Compare augmented images
+ aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
+ aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))
+
+ assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"
+
+ def test_augmenter_returns_summary(
+ self, sample_dataset: Path, augmentation_config: dict
+ ):
+ """Test that augmenter returns a summary of what was done."""
+ from shared.augmentation.dataset_augmenter import DatasetAugmenter
+
+ augmenter = DatasetAugmenter(augmentation_config)
+ result = augmenter.augment_dataset(sample_dataset, multiplier=2)
+
+ assert "original_images" in result
+ assert "augmented_images" in result
+ assert "total_images" in result
+ assert result["original_images"] == 3
+ assert result["augmented_images"] == 6
+ assert result["total_images"] == 9
diff --git a/tests/web/test_augmentation_routes.py b/tests/web/test_augmentation_routes.py
new file mode 100644
index 0000000..698d876
--- /dev/null
+++ b/tests/web/test_augmentation_routes.py
@@ -0,0 +1,261 @@
+"""
+Tests for augmentation API routes.
+
+TDD Phase 5: RED - Write tests first, then implement to pass.
+"""
+
+import pytest
+from fastapi.testclient import TestClient
+
+
+class TestAugmentationTypesEndpoint:
+ """Tests for GET /admin/augmentation/types endpoint."""
+
+ def test_list_augmentation_types(
+ self, admin_client: TestClient, admin_token: str
+ ) -> None:
+ """Test listing available augmentation types."""
+ response = admin_client.get(
+ "/api/v1/admin/augmentation/types",
+ headers={"X-Admin-Token": admin_token},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "augmentation_types" in data
+ assert len(data["augmentation_types"]) == 12
+
+ # Check structure
+ aug_type = data["augmentation_types"][0]
+ assert "name" in aug_type
+ assert "description" in aug_type
+ assert "affects_geometry" in aug_type
+ assert "stage" in aug_type
+
+ def test_list_augmentation_types_unauthorized(
+ self, admin_client: TestClient
+ ) -> None:
+ """Test that unauthorized request is rejected."""
+ response = admin_client.get("/api/v1/admin/augmentation/types")
+
+ assert response.status_code == 401
+
+
+class TestAugmentationPresetsEndpoint:
+ """Tests for GET /admin/augmentation/presets endpoint."""
+
+ def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
+ """Test listing available presets."""
+ response = admin_client.get(
+ "/api/v1/admin/augmentation/presets",
+ headers={"X-Admin-Token": admin_token},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "presets" in data
+ assert len(data["presets"]) >= 4
+
+ # Check expected presets exist
+ preset_names = [p["name"] for p in data["presets"]]
+ assert "conservative" in preset_names
+ assert "moderate" in preset_names
+ assert "aggressive" in preset_names
+ assert "scanned_document" in preset_names
+
+
+class TestAugmentationPreviewEndpoint:
+ """Tests for POST /admin/augmentation/preview/{document_id} endpoint."""
+
+ def test_preview_augmentation(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ sample_document_id: str,
+ ) -> None:
+ """Test previewing augmentation on a document."""
+ response = admin_client.post(
+ f"/api/v1/admin/augmentation/preview/{sample_document_id}",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "augmentation_type": "gaussian_noise",
+ "params": {"std": 15},
+ },
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "preview_url" in data
+ assert "original_url" in data
+ assert "applied_params" in data
+
+ def test_preview_invalid_augmentation_type(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ sample_document_id: str,
+ ) -> None:
+ """Test that invalid augmentation type returns error."""
+ response = admin_client.post(
+ f"/api/v1/admin/augmentation/preview/{sample_document_id}",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "augmentation_type": "nonexistent",
+ "params": {},
+ },
+ )
+
+ assert response.status_code == 400
+
+ def test_preview_nonexistent_document(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ ) -> None:
+ """Test that nonexistent document returns 404."""
+ response = admin_client.post(
+ "/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "augmentation_type": "gaussian_noise",
+ "params": {},
+ },
+ )
+
+ assert response.status_code == 404
+
+
+class TestAugmentationPreviewConfigEndpoint:
+ """Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""
+
+ def test_preview_config(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ sample_document_id: str,
+ ) -> None:
+ """Test previewing full config on a document."""
+ response = admin_client.post(
+ f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "gaussian_noise": {"enabled": True, "probability": 1.0},
+ "lighting_variation": {"enabled": True, "probability": 1.0},
+ "preserve_bboxes": True,
+ "seed": 42,
+ },
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "preview_url" in data
+ assert "original_url" in data
+
+
+class TestAugmentationBatchEndpoint:
+ """Tests for POST /admin/augmentation/batch endpoint."""
+
+ def test_create_augmented_dataset(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ sample_dataset_id: str,
+ ) -> None:
+ """Test creating augmented dataset."""
+ response = admin_client.post(
+ "/api/v1/admin/augmentation/batch",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "dataset_id": sample_dataset_id,
+ "config": {
+ "gaussian_noise": {"enabled": True, "probability": 0.5},
+ "preserve_bboxes": True,
+ },
+ "output_name": "test_augmented_dataset",
+ "multiplier": 2,
+ },
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "task_id" in data
+ assert "status" in data
+ assert "estimated_images" in data
+
+ def test_create_augmented_dataset_invalid_multiplier(
+ self,
+ admin_client: TestClient,
+ admin_token: str,
+ sample_dataset_id: str,
+ ) -> None:
+ """Test that invalid multiplier is rejected."""
+ response = admin_client.post(
+ "/api/v1/admin/augmentation/batch",
+ headers={"X-Admin-Token": admin_token},
+ json={
+ "dataset_id": sample_dataset_id,
+ "config": {},
+ "output_name": "test",
+ "multiplier": 100, # Too high
+ },
+ )
+
+ assert response.status_code == 422 # Validation error
+
+
+class TestAugmentedDatasetsListEndpoint:
+ """Tests for GET /admin/augmentation/datasets endpoint."""
+
+ def test_list_augmented_datasets(
+ self, admin_client: TestClient, admin_token: str
+ ) -> None:
+ """Test listing augmented datasets."""
+ response = admin_client.get(
+ "/api/v1/admin/augmentation/datasets",
+ headers={"X-Admin-Token": admin_token},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "total" in data
+ assert "limit" in data
+ assert "offset" in data
+ assert "datasets" in data
+ assert isinstance(data["datasets"], list)
+
+ def test_list_augmented_datasets_pagination(
+ self, admin_client: TestClient, admin_token: str
+ ) -> None:
+ """Test pagination parameters."""
+ response = admin_client.get(
+ "/api/v1/admin/augmentation/datasets",
+ headers={"X-Admin-Token": admin_token},
+ params={"limit": 5, "offset": 0},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert data["limit"] == 5
+ assert data["offset"] == 0
+
+
+# Fixtures for tests
+@pytest.fixture
+def sample_document_id() -> str:
+ """Provide a sample document ID for testing."""
+ # This would need to be created in test setup
+ return "test-document-id"
+
+
+@pytest.fixture
+def sample_dataset_id() -> str:
+ """Provide a sample dataset ID for testing."""
+ # This would need to be created in test setup
+ return "test-dataset-id"
diff --git a/tests/web/test_dataset_builder.py b/tests/web/test_dataset_builder.py
index ae79912..1c052d4 100644
--- a/tests/web/test_dataset_builder.py
+++ b/tests/web/test_dataset_builder.py
@@ -329,3 +329,414 @@ class TestDatasetBuilder:
results.append([(d["document_id"], d["split"]) for d in docs])
assert results[0] == results[1]
+
+
+class TestAssignSplitsByGroup:
+ """Tests for _assign_splits_by_group method with group_key logic."""
+
+ def _make_mock_doc(self, doc_id, group_key=None):
+ """Create a mock AdminDocument with document_id and group_key."""
+ doc = MagicMock(spec=AdminDocument)
+ doc.document_id = doc_id
+ doc.group_key = group_key
+ doc.page_count = 1
+ return doc
+
+ def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
+ """Documents with unique group_key are distributed across splits."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ # 3 documents, each with unique group_key
+ docs = [
+ self._make_mock_doc(uuid4(), group_key="group-A"),
+ self._make_mock_doc(uuid4(), group_key="group-B"),
+ self._make_mock_doc(uuid4(), group_key="group-C"),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
+
+ # With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
+ train_count = sum(1 for s in result.values() if s == "train")
+ val_count = sum(1 for s in result.values() if s == "val")
+ assert train_count >= 1
+ assert val_count >= 1 # Ensure val is not empty
+
+ def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
+ """Documents with null/empty group_key are each treated as independent single-doc groups."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ docs = [
+ self._make_mock_doc(uuid4(), group_key=None),
+ self._make_mock_doc(uuid4(), group_key=""),
+ self._make_mock_doc(uuid4(), group_key=None),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
+
+ # Each null/empty group_key doc is independent, distributed across splits
+ # With 3 docs: ensure at least 1 in train and 1 in val
+ train_count = sum(1 for s in result.values() if s == "train")
+ val_count = sum(1 for s in result.values() if s == "val")
+ assert train_count >= 1
+ assert val_count >= 1
+
+ def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
+ """Documents with same group_key should be assigned to the same split."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ # 6 documents in 2 groups
+ docs = [
+ self._make_mock_doc(uuid4(), group_key="supplier-A"),
+ self._make_mock_doc(uuid4(), group_key="supplier-A"),
+ self._make_mock_doc(uuid4(), group_key="supplier-A"),
+ self._make_mock_doc(uuid4(), group_key="supplier-B"),
+ self._make_mock_doc(uuid4(), group_key="supplier-B"),
+ self._make_mock_doc(uuid4(), group_key="supplier-B"),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
+
+ # All docs in supplier-A should have same split
+ splits_a = [result[str(d.document_id)] for d in docs[:3]]
+ assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"
+
+ # All docs in supplier-B should have same split
+ splits_b = [result[str(d.document_id)] for d in docs[3:]]
+ assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
+
+ def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
+ """Multi-doc groups should be split according to train/val/test ratios."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ # 10 groups with 2 docs each
+ docs = []
+ for i in range(10):
+ group_key = f"group-{i}"
+ docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
+ docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
+
+ # Count groups per split
+ group_splits = {}
+ for doc in docs:
+ split = result[str(doc.document_id)]
+ if doc.group_key not in group_splits:
+ group_splits[doc.group_key] = split
+ else:
+ # Verify same group has same split
+ assert group_splits[doc.group_key] == split
+
+ split_counts = {"train": 0, "val": 0, "test": 0}
+ for split in group_splits.values():
+ split_counts[split] += 1
+
+ # With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
+ assert split_counts["train"] >= 6
+ assert split_counts["train"] <= 8
+ assert split_counts["val"] >= 1
+ assert split_counts["val"] <= 3
+
+ def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
+ """Mix of single-doc and multi-doc groups should be handled correctly."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ docs = [
+ # Single-doc groups
+ self._make_mock_doc(uuid4(), group_key="single-1"),
+ self._make_mock_doc(uuid4(), group_key="single-2"),
+ self._make_mock_doc(uuid4(), group_key=None),
+ # Multi-doc groups
+ self._make_mock_doc(uuid4(), group_key="multi-A"),
+ self._make_mock_doc(uuid4(), group_key="multi-A"),
+ self._make_mock_doc(uuid4(), group_key="multi-B"),
+ self._make_mock_doc(uuid4(), group_key="multi-B"),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
+
+ # All groups are shuffled and distributed
+ # Ensure at least 1 in train and 1 in val
+ train_count = sum(1 for s in result.values() if s == "train")
+ val_count = sum(1 for s in result.values() if s == "val")
+ assert train_count >= 1
+ assert val_count >= 1
+
+ # Multi-doc groups stay together
+ assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
+ assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
+
+ def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
+ """Same seed should produce same split assignments."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ docs = [
+ self._make_mock_doc(uuid4(), group_key="group-A"),
+ self._make_mock_doc(uuid4(), group_key="group-A"),
+ self._make_mock_doc(uuid4(), group_key="group-B"),
+ self._make_mock_doc(uuid4(), group_key="group-B"),
+ self._make_mock_doc(uuid4(), group_key="group-C"),
+ self._make_mock_doc(uuid4(), group_key="group-C"),
+ ]
+
+ result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
+ result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
+
+ assert result1 == result2
+
+ def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
+ """Different seeds should potentially produce different split assignments."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ # Many groups to increase chance of different results
+ docs = []
+ for i in range(20):
+ group_key = f"group-{i}"
+ docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
+ docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
+
+ result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
+ result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)
+
+ # Results should be different (very likely with 20 groups)
+ assert result1 != result2
+
+ def test_all_docs_assigned(self, tmp_path, mock_admin_db):
+ """Every document should be assigned a split."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ docs = [
+ self._make_mock_doc(uuid4(), group_key="group-A"),
+ self._make_mock_doc(uuid4(), group_key="group-A"),
+ self._make_mock_doc(uuid4(), group_key=None),
+ self._make_mock_doc(uuid4(), group_key="single"),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
+
+ assert len(result) == len(docs)
+ for doc in docs:
+ assert str(doc.document_id) in result
+ assert result[str(doc.document_id)] in ["train", "val", "test"]
+
+ def test_empty_documents_list(self, tmp_path, mock_admin_db):
+ """Empty document list should return empty result."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
+
+ assert result == {}
+
+ def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
+ """When all groups have multiple docs, splits should follow ratios."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ # 5 groups with 3 docs each
+ docs = []
+ for i in range(5):
+ group_key = f"group-{i}"
+ for _ in range(3):
+ docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
+
+ # Group splits
+ group_splits = {}
+ for doc in docs:
+ if doc.group_key not in group_splits:
+ group_splits[doc.group_key] = result[str(doc.document_id)]
+
+ split_counts = {"train": 0, "val": 0, "test": 0}
+ for split in group_splits.values():
+ split_counts[split] += 1
+
+ # With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
+ assert split_counts["train"] >= 2
+ assert split_counts["train"] <= 4
+
+ def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
+ """When all groups have single doc, they are distributed across splits."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+
+ docs = [
+ self._make_mock_doc(uuid4(), group_key="unique-1"),
+ self._make_mock_doc(uuid4(), group_key="unique-2"),
+ self._make_mock_doc(uuid4(), group_key="unique-3"),
+ self._make_mock_doc(uuid4(), group_key=None),
+ self._make_mock_doc(uuid4(), group_key=""),
+ ]
+
+ result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
+
+ # With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
+ train_count = sum(1 for s in result.values() if s == "train")
+ val_count = sum(1 for s in result.values() if s == "val")
+ assert train_count >= 2
+ assert val_count >= 1 # Ensure val is not empty
+
+
+class TestBuildDatasetWithGroupKey:
+ """Integration tests for build_dataset with group_key logic."""
+
+ @pytest.fixture
+ def grouped_documents(self, tmp_path):
+ """Create documents with various group_key configurations."""
+ doc_ids = []
+ docs = []
+
+ # Create 3 groups: 2 multi-doc groups + 2 single-doc groups
+ group_configs = [
+ ("supplier-A", 3), # Multi-doc group: 3 docs
+ ("supplier-B", 2), # Multi-doc group: 2 docs
+ ("unique-1", 1), # Single-doc group
+ (None, 1), # Null group_key
+ ]
+
+ for group_key, count in group_configs:
+ for _ in range(count):
+ doc_id = uuid4()
+ doc_ids.append(doc_id)
+
+ # Create image files
+ doc_dir = tmp_path / "admin_images" / str(doc_id)
+ doc_dir.mkdir(parents=True)
+ for page in range(1, 3):
+ (doc_dir / f"page_{page}.png").write_bytes(b"fake-png")
+
+ # Create mock document
+ doc = MagicMock(spec=AdminDocument)
+ doc.document_id = doc_id
+ doc.filename = f"{doc_id}.pdf"
+ doc.page_count = 2
+ doc.group_key = group_key
+ doc.file_path = str(doc_dir)
+ docs.append(doc)
+
+ return tmp_path, docs
+
+ @pytest.fixture
+ def grouped_annotations(self, grouped_documents):
+ """Create annotations for grouped documents."""
+ tmp_path, docs = grouped_documents
+ annotations = {}
+ for doc in docs:
+ doc_anns = []
+ for page in range(1, 3):
+ ann = MagicMock(spec=AdminAnnotation)
+ ann.document_id = doc.document_id
+ ann.page_number = page
+ ann.class_id = 0
+ ann.class_name = "invoice_number"
+ ann.x_center = 0.5
+ ann.y_center = 0.3
+ ann.width = 0.2
+ ann.height = 0.05
+ doc_anns.append(ann)
+ annotations[str(doc.document_id)] = doc_anns
+ return annotations
+
+ def test_build_respects_group_key_splits(
+ self, grouped_documents, grouped_annotations, mock_admin_db
+ ):
+ """build_dataset should use group_key for split assignment."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ tmp_path, docs = grouped_documents
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+ mock_admin_db.get_documents_by_ids.return_value = docs
+ mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
+ grouped_annotations.get(str(doc_id), [])
+ )
+
+ dataset = mock_admin_db.create_dataset.return_value
+ builder.build_dataset(
+ dataset_id=str(dataset.dataset_id),
+ document_ids=[str(d.document_id) for d in docs],
+ train_ratio=0.5,
+ val_ratio=0.5,
+ seed=42,
+ admin_images_dir=tmp_path / "admin_images",
+ )
+
+ # Get the document splits from add_dataset_documents call
+ call_args = mock_admin_db.add_dataset_documents.call_args
+ docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
+
+ # Build mapping of doc_id -> split
+ doc_split_map = {d["document_id"]: d["split"] for d in docs_added}
+
+ # Verify all docs are assigned a valid split
+ for doc_id in doc_split_map:
+ assert doc_split_map[doc_id] in ("train", "val", "test")
+
+ # Verify multi-doc groups stay together
+ supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
+ supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
+ assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"
+
+ supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
+ supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
+ assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
+
+ def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
+ """All docs with same group_key should go to same split."""
+ from inference.web.services.dataset_builder import DatasetBuilder
+
+ # Create 5 docs all with same group_key
+ docs = []
+ for i in range(5):
+ doc_id = uuid4()
+ doc_dir = tmp_path / "admin_images" / str(doc_id)
+ doc_dir.mkdir(parents=True)
+ (doc_dir / "page_1.png").write_bytes(b"fake-png")
+
+ doc = MagicMock(spec=AdminDocument)
+ doc.document_id = doc_id
+ doc.filename = f"{doc_id}.pdf"
+ doc.page_count = 1
+ doc.group_key = "same-group"
+ docs.append(doc)
+
+ builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
+ mock_admin_db.get_documents_by_ids.return_value = docs
+ mock_admin_db.get_annotations_for_document.return_value = []
+
+ dataset = mock_admin_db.create_dataset.return_value
+ builder.build_dataset(
+ dataset_id=str(dataset.dataset_id),
+ document_ids=[str(d.document_id) for d in docs],
+ train_ratio=0.6,
+ val_ratio=0.2,
+ seed=42,
+ admin_images_dir=tmp_path / "admin_images",
+ )
+
+ call_args = mock_admin_db.add_dataset_documents.call_args
+ docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
+
+ splits = [d["split"] for d in docs_added]
+ # All should be in the same split (one group)
+ assert len(set(splits)) == 1, "All docs with same group_key should be in same split"
diff --git a/tests/web/test_dataset_routes.py b/tests/web/test_dataset_routes.py
index d2add37..2f1e5a5 100644
--- a/tests/web/test_dataset_routes.py
+++ b/tests/web/test_dataset_routes.py
@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
+# Generate 10 unique UUIDs for minimum document count tests
+TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
+
def _make_dataset(**overrides) -> MagicMock:
defaults = dict(
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
- "total_documents": 2,
- "total_images": 4,
- "total_annotations": 10,
+ "total_documents": 10,
+ "total_images": 20,
+ "total_annotations": 50,
}
request = DatasetCreateRequest(
name="test-dataset",
- document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
+ document_ids=TEST_DOC_UUIDS, # Use 10 documents to meet minimum
)
with patch(
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
+ def test_create_dataset_fails_with_less_than_10_documents(self):
+ """Test that creating dataset fails if fewer than 10 documents provided."""
+ fn = _find_endpoint("create_dataset")
+
+ mock_db = MagicMock()
+
+ # Only 2 documents - should fail
+ request = DatasetCreateRequest(
+ name="test-dataset",
+ document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
+ )
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ assert exc_info.value.status_code == 400
+ assert "Minimum 10 documents required" in exc_info.value.detail
+ assert "got 2" in exc_info.value.detail
+ # Ensure DB was never called since validation failed first
+ mock_db.create_dataset.assert_not_called()
+
+ def test_create_dataset_fails_with_9_documents(self):
+ """Test boundary condition: 9 documents should fail."""
+ fn = _find_endpoint("create_dataset")
+
+ mock_db = MagicMock()
+
+ # 9 documents - just under the limit
+ request = DatasetCreateRequest(
+ name="test-dataset",
+ document_ids=TEST_DOC_UUIDS[:9],
+ )
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ assert exc_info.value.status_code == 400
+ assert "Minimum 10 documents required" in exc_info.value.detail
+
+ def test_create_dataset_succeeds_with_exactly_10_documents(self):
+ """Test boundary condition: exactly 10 documents should succeed."""
+ fn = _find_endpoint("create_dataset")
+
+ mock_db = MagicMock()
+ mock_db.create_dataset.return_value = _make_dataset(status="building")
+
+ mock_builder = MagicMock()
+
+ # Exactly 10 documents - should pass
+ request = DatasetCreateRequest(
+ name="test-dataset",
+ document_ids=TEST_DOC_UUIDS[:10],
+ )
+
+ with patch(
+ "inference.web.services.dataset_builder.DatasetBuilder",
+ return_value=mock_builder,
+ ):
+ result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ mock_db.create_dataset.assert_called_once()
+ assert result.dataset_id == TEST_DATASET_UUID
+
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
+
+ def test_incremental_training_with_base_model(self):
+ """Test training with base_model_version_id for incremental training."""
+ fn = _find_endpoint("train_from_dataset")
+
+ mock_model_version = MagicMock()
+ mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
+ mock_model_version.version = "1.0.0"
+
+ mock_db = MagicMock()
+ mock_db.get_dataset.return_value = _make_dataset(status="ready")
+ mock_db.get_model_version.return_value = mock_model_version
+ mock_db.create_training_task.return_value = TEST_TASK_UUID
+
+ base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
+ config = TrainingConfig(base_model_version_id=base_model_uuid)
+ request = DatasetTrainRequest(name="incremental-train", config=config)
+
+ result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ # Verify model version was looked up
+ mock_db.get_model_version.assert_called_once_with(base_model_uuid)
+
+ # Verify task was created with finetune type
+ call_kwargs = mock_db.create_training_task.call_args[1]
+ assert call_kwargs["task_type"] == "finetune"
+ assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
+ assert call_kwargs["config"]["base_model_version"] == "1.0.0"
+
+ assert result.task_id == TEST_TASK_UUID
+ assert "Incremental training" in result.message
+
+ def test_incremental_training_with_invalid_base_model_fails(self):
+ """Test that training fails if base_model_version_id doesn't exist."""
+ fn = _find_endpoint("train_from_dataset")
+
+ mock_db = MagicMock()
+ mock_db.get_dataset.return_value = _make_dataset(status="ready")
+ mock_db.get_model_version.return_value = None
+
+ base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
+ config = TrainingConfig(base_model_version_id=base_model_uuid)
+ request = DatasetTrainRequest(name="incremental-train", config=config)
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 404
+ assert "Base model version not found" in exc_info.value.detail
diff --git a/tests/web/test_model_versions.py b/tests/web/test_model_versions.py
new file mode 100644
index 0000000..e28249a
--- /dev/null
+++ b/tests/web/test_model_versions.py
@@ -0,0 +1,399 @@
+"""
+Tests for Model Version API routes.
+"""
+
+import asyncio
+from datetime import datetime, timezone
+from unittest.mock import MagicMock
+from uuid import UUID
+
+import pytest
+
+from inference.data.admin_models import ModelVersion
+from inference.web.api.v1.admin.training import create_training_router
+from inference.web.schemas.admin import (
+ ModelVersionCreateRequest,
+ ModelVersionUpdateRequest,
+)
+
+
+TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
+TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
+TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
+TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
+TEST_TOKEN = "test-admin-token-12345"
+
+
+def _make_model_version(**overrides) -> MagicMock:
+ """Create a mock ModelVersion."""
+ defaults = dict(
+ version_id=UUID(TEST_VERSION_UUID),
+ version="1.0.0",
+ name="test-model-v1",
+ description="Test model version",
+ model_path="/models/test-model-v1.pt",
+ status="inactive",
+ is_active=False,
+ task_id=UUID(TEST_TASK_UUID),
+ dataset_id=UUID(TEST_DATASET_UUID),
+ metrics_mAP=0.935,
+ metrics_precision=0.92,
+ metrics_recall=0.88,
+ document_count=100,
+ training_config={"epochs": 100, "batch_size": 16},
+ file_size=52428800,
+ trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
+ activated_at=None,
+ created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
+ )
+ defaults.update(overrides)
+ model = MagicMock(spec=ModelVersion)
+ for k, v in defaults.items():
+ setattr(model, k, v)
+ return model
+
+
+def _find_endpoint(name: str):
+ """Find endpoint function by name."""
+ router = create_training_router()
+ for route in router.routes:
+ if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
+ return route.endpoint
+ raise AssertionError(f"Endpoint {name} not found")
+
+
+class TestModelVersionRouterRegistration:
+ """Tests that model version endpoints are registered."""
+
+ def test_router_has_model_endpoints(self):
+ router = create_training_router()
+ paths = [route.path for route in router.routes]
+ assert any("models" in p for p in paths)
+
+ def test_has_create_model_version_endpoint(self):
+ endpoint = _find_endpoint("create_model_version")
+ assert endpoint is not None
+
+ def test_has_list_model_versions_endpoint(self):
+ endpoint = _find_endpoint("list_model_versions")
+ assert endpoint is not None
+
+ def test_has_get_active_model_endpoint(self):
+ endpoint = _find_endpoint("get_active_model")
+ assert endpoint is not None
+
+ def test_has_activate_model_version_endpoint(self):
+ endpoint = _find_endpoint("activate_model_version")
+ assert endpoint is not None
+
+
+class TestCreateModelVersionRoute:
+ """Tests for POST /admin/training/models."""
+
+ def test_create_model_version(self):
+ fn = _find_endpoint("create_model_version")
+
+ mock_db = MagicMock()
+ mock_db.create_model_version.return_value = _make_model_version()
+
+ request = ModelVersionCreateRequest(
+ version="1.0.0",
+ name="test-model-v1",
+ model_path="/models/test-model-v1.pt",
+ description="Test model",
+ metrics_mAP=0.935,
+ document_count=100,
+ )
+
+ result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ mock_db.create_model_version.assert_called_once()
+ assert result.version_id == TEST_VERSION_UUID
+ assert result.status == "inactive"
+ assert result.message == "Model version created successfully"
+
+ def test_create_model_version_with_task_and_dataset(self):
+ fn = _find_endpoint("create_model_version")
+
+ mock_db = MagicMock()
+ mock_db.create_model_version.return_value = _make_model_version()
+
+ request = ModelVersionCreateRequest(
+ version="1.0.0",
+ name="test-model-v1",
+ model_path="/models/test-model-v1.pt",
+ task_id=TEST_TASK_UUID,
+ dataset_id=TEST_DATASET_UUID,
+ )
+
+ result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ call_kwargs = mock_db.create_model_version.call_args[1]
+ assert call_kwargs["task_id"] == TEST_TASK_UUID
+ assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
+
+
+class TestListModelVersionsRoute:
+ """Tests for GET /admin/training/models."""
+
+ def test_list_model_versions(self):
+ fn = _find_endpoint("list_model_versions")
+
+ mock_db = MagicMock()
+ mock_db.get_model_versions.return_value = (
+ [_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
+ 2,
+ )
+
+ result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
+
+ assert result.total == 2
+ assert len(result.models) == 2
+ assert result.models[0].version == "1.0.0"
+
+ def test_list_model_versions_with_status_filter(self):
+ fn = _find_endpoint("list_model_versions")
+
+ mock_db = MagicMock()
+ mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
+
+ result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
+
+ mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
+ assert result.total == 1
+ assert result.models[0].status == "active"
+
+
+class TestGetActiveModelRoute:
+ """Tests for GET /admin/training/models/active."""
+
+ def test_get_active_model_when_exists(self):
+ fn = _find_endpoint("get_active_model")
+
+ mock_db = MagicMock()
+ mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
+
+ result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
+
+ assert result.has_active_model is True
+ assert result.model is not None
+ assert result.model.is_active is True
+
+ def test_get_active_model_when_none(self):
+ fn = _find_endpoint("get_active_model")
+
+ mock_db = MagicMock()
+ mock_db.get_active_model_version.return_value = None
+
+ result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
+
+ assert result.has_active_model is False
+ assert result.model is None
+
+
+class TestGetModelVersionRoute:
+ """Tests for GET /admin/training/models/{version_id}."""
+
+ def test_get_model_version(self):
+ fn = _find_endpoint("get_model_version")
+
+ mock_db = MagicMock()
+ mock_db.get_model_version.return_value = _make_model_version()
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+
+ assert result.version_id == TEST_VERSION_UUID
+ assert result.version == "1.0.0"
+ assert result.name == "test-model-v1"
+ assert result.metrics_mAP == 0.935
+
+ def test_get_model_version_not_found(self):
+ fn = _find_endpoint("get_model_version")
+
+ mock_db = MagicMock()
+ mock_db.get_model_version.return_value = None
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 404
+
+
+class TestUpdateModelVersionRoute:
+ """Tests for PATCH /admin/training/models/{version_id}."""
+
+ def test_update_model_version(self):
+ fn = _find_endpoint("update_model_version")
+
+ mock_db = MagicMock()
+ mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
+
+ request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
+
+ mock_db.update_model_version.assert_called_once_with(
+ version_id=TEST_VERSION_UUID,
+ name="updated-name",
+ description="Updated description",
+ status=None,
+ )
+ assert result.message == "Model version updated successfully"
+
+ def test_update_model_version_not_found(self):
+ fn = _find_endpoint("update_model_version")
+
+ mock_db = MagicMock()
+ mock_db.update_model_version.return_value = None
+
+ request = ModelVersionUpdateRequest(name="updated-name")
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 404
+
+
+class TestActivateModelVersionRoute:
+ """Tests for POST /admin/training/models/{version_id}/activate."""
+
+ def test_activate_model_version(self):
+ fn = _find_endpoint("activate_model_version")
+
+ mock_db = MagicMock()
+ mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+
+ mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
+ assert result.status == "active"
+ assert result.message == "Model version activated for inference"
+
+ def test_activate_model_version_not_found(self):
+ fn = _find_endpoint("activate_model_version")
+
+ mock_db = MagicMock()
+ mock_db.activate_model_version.return_value = None
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 404
+
+
+class TestDeactivateModelVersionRoute:
+ """Tests for POST /admin/training/models/{version_id}/deactivate."""
+
+ def test_deactivate_model_version(self):
+ fn = _find_endpoint("deactivate_model_version")
+
+ mock_db = MagicMock()
+ mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+
+ assert result.status == "inactive"
+ assert result.message == "Model version deactivated"
+
+ def test_deactivate_model_version_not_found(self):
+ fn = _find_endpoint("deactivate_model_version")
+
+ mock_db = MagicMock()
+ mock_db.deactivate_model_version.return_value = None
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 404
+
+
+class TestArchiveModelVersionRoute:
+ """Tests for POST /admin/training/models/{version_id}/archive."""
+
+ def test_archive_model_version(self):
+ fn = _find_endpoint("archive_model_version")
+
+ mock_db = MagicMock()
+ mock_db.archive_model_version.return_value = _make_model_version(status="archived")
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+
+ assert result.status == "archived"
+ assert result.message == "Model version archived"
+
+ def test_archive_active_model_fails(self):
+ fn = _find_endpoint("archive_model_version")
+
+ mock_db = MagicMock()
+ mock_db.archive_model_version.return_value = None
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 400
+
+
+class TestDeleteModelVersionRoute:
+ """Tests for DELETE /admin/training/models/{version_id}."""
+
+ def test_delete_model_version(self):
+ fn = _find_endpoint("delete_model_version")
+
+ mock_db = MagicMock()
+ mock_db.delete_model_version.return_value = True
+
+ result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+
+ mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
+ assert result["message"] == "Model version deleted"
+
+ def test_delete_active_model_fails(self):
+ fn = _find_endpoint("delete_model_version")
+
+ mock_db = MagicMock()
+ mock_db.delete_model_version.return_value = False
+
+ from fastapi import HTTPException
+
+ with pytest.raises(HTTPException) as exc_info:
+ asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+ assert exc_info.value.status_code == 400
+
+
+class TestModelVersionSchemas:
+ """Tests for model version Pydantic schemas."""
+
+ def test_create_request_validation(self):
+ request = ModelVersionCreateRequest(
+ version="1.0.0",
+ name="test-model",
+ model_path="/models/test.pt",
+ )
+ assert request.version == "1.0.0"
+ assert request.name == "test-model"
+ assert request.document_count == 0
+
+ def test_create_request_with_metrics(self):
+ request = ModelVersionCreateRequest(
+ version="2.0.0",
+ name="test-model-v2",
+ model_path="/models/v2.pt",
+ metrics_mAP=0.95,
+ metrics_precision=0.92,
+ metrics_recall=0.88,
+ document_count=500,
+ )
+ assert request.metrics_mAP == 0.95
+ assert request.document_count == 500
+
+ def test_update_request_partial(self):
+ request = ModelVersionUpdateRequest(name="new-name")
+ assert request.name == "new-name"
+ assert request.description is None
+ assert request.status is None
|