Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -4,6 +4,7 @@ import { DashboardOverview } from './components/DashboardOverview'
import { Dashboard } from './components/Dashboard'
import { DocumentDetail } from './components/DocumentDetail'
import { Training } from './components/Training'
import { DatasetDetail } from './components/DatasetDetail'
import { Models } from './components/Models'
import { Login } from './components/Login'
import { InferenceDemo } from './components/InferenceDemo'
@@ -55,7 +56,14 @@ const App: React.FC = () => {
case 'demo':
return <InferenceDemo />
case 'training':
return <Training />
return <Training onNavigate={handleNavigate} />
case 'dataset-detail':
return (
<DatasetDetail
datasetId={selectedDocId || ''}
onBack={() => setCurrentView('training')}
/>
)
case 'models':
return <Models />
default:

View File

@@ -0,0 +1,118 @@
/**
* Tests for augmentation API endpoints.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { augmentationApi } from './augmentation'
import apiClient from '../client'
// Mock the API client
vi.mock('../client', () => ({
default: {
get: vi.fn(),
post: vi.fn(),
},
}))
describe('augmentationApi', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockResponse = {
data: {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getTypes()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
expect(result.augmentation_types).toHaveLength(1)
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockResponse = {
data: {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getPresets()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
expect(result.presets).toHaveLength(2)
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockResponse = {
data: {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: { std: 15 },
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.preview('doc-123', {
augmentation_type: 'gaussian_noise',
params: { std: 15 },
})
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
{
augmentation_type: 'gaussian_noise',
params: { std: 15 },
},
{ params: { page: 1 } }
)
expect(result.preview_url).toBe('data:image/png;base64,xxx')
})
it('should support custom page number', async () => {
const mockResponse = {
data: {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: {},
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
await augmentationApi.preview(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: {} },
2
)
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
expect.anything(),
{ params: { page: 2 } }
)
})
})
})

View File

@@ -0,0 +1,144 @@
/**
* Augmentation API endpoints.
*
 * Provides functions for fetching augmentation types and presets, previewing augmentations, and creating augmented datasets.
*/
import apiClient from '../client'
// Types
export interface AugmentationTypeInfo {
name: string
description: string
affects_geometry: boolean
stage: string
default_params: Record<string, unknown>
}
export interface AugmentationTypesResponse {
augmentation_types: AugmentationTypeInfo[]
}
export interface PresetInfo {
name: string
description: string
config?: Record<string, unknown>
}
export interface PresetsResponse {
presets: PresetInfo[]
}
export interface PreviewRequest {
augmentation_type: string
params: Record<string, unknown>
}
export interface PreviewResponse {
preview_url: string
original_url: string
applied_params: Record<string, unknown>
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationConfig {
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
gaussian_noise?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface BatchRequest {
dataset_id: string
config: AugmentationConfig
output_name: string
multiplier: number
}
export interface BatchResponse {
task_id: string
status: string
message: string
estimated_images: number
}
// API functions
export const augmentationApi = {
/**
* Fetch available augmentation types.
*/
async getTypes(): Promise<AugmentationTypesResponse> {
const response = await apiClient.get<AugmentationTypesResponse>(
'/api/v1/admin/augmentation/types'
)
return response.data
},
/**
* Fetch augmentation presets.
*/
async getPresets(): Promise<PresetsResponse> {
const response = await apiClient.get<PresetsResponse>(
'/api/v1/admin/augmentation/presets'
)
return response.data
},
/**
* Preview a single augmentation on a document page.
*/
async preview(
documentId: string,
request: PreviewRequest,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview/${documentId}`,
request,
{ params: { page } }
)
return response.data
},
/**
* Preview full augmentation config on a document page.
*/
async previewConfig(
documentId: string,
config: AugmentationConfig,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview-config/${documentId}`,
config,
{ params: { page } }
)
return response.data
},
/**
* Create an augmented dataset.
*/
async createBatch(request: BatchRequest): Promise<BatchResponse> {
const response = await apiClient.post<BatchResponse>(
'/api/v1/admin/augmentation/batch',
request
)
return response.data
},
}
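
Not part of the diff: a minimal usage sketch of the augmentationApi client defined above, assuming it is imported from a sibling module. The document id, dataset id, and output name are placeholders, not values taken from this commit.

import { augmentationApi } from './augmentation'

// Illustrative flow only: list the available types, preview one augmentation
// on page 1 of a document, then queue an augmented-dataset build.
// All ids and names below are placeholders.
async function demoAugmentation(documentId: string): Promise<void> {
  const { augmentation_types } = await augmentationApi.getTypes()
  console.log('available types:', augmentation_types.map((t) => t.name))

  const preview = await augmentationApi.preview(documentId, {
    augmentation_type: 'gaussian_noise',
    params: { std: 15 },
  })
  console.log('preview image:', preview.preview_url)

  const batch = await augmentationApi.createBatch({
    dataset_id: 'placeholder-dataset-id',
    config: {
      gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
      preserve_bboxes: true,
    },
    output_name: 'augmented-demo',
    multiplier: 2,
  })
  console.log('batch task queued:', batch.task_id, batch.status)
}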

View File

@@ -0,0 +1,52 @@
import apiClient from '../client'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingTaskResponse,
} from '../types'
export const datasetsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<DatasetListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
params,
})
return data
},
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
return data
},
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
const { data } = await apiClient.get(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
remove: async (datasetId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
trainFromDataset: async (
datasetId: string,
req: DatasetTrainRequest
): Promise<TrainingTaskResponse> => {
const { data } = await apiClient.post(
`/api/v1/admin/training/datasets/${datasetId}/train`,
req
)
return data
},
}
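
Again outside the diff, a sketch of how datasetsApi might be chained, assuming the caller already has annotated document ids. The dataset and task names are placeholders; in the real UI the detail endpoint would be polled until the build reaches the ready state.

import { datasetsApi } from './datasets'

// Hypothetical flow: create a dataset, check whether the build finished,
// then start a training task from it. Names and ratios are placeholders.
async function createAndTrain(documentIds: string[]): Promise<void> {
  const created = await datasetsApi.create({
    name: 'invoices-demo',
    document_ids: documentIds,
    train_ratio: 0.7,
    val_ratio: 0.2,
  })

  const detail = await datasetsApi.getDetail(created.dataset_id)
  if (detail.status !== 'ready') {
    console.log('dataset not ready yet:', detail.status)
    return
  }

  const task = await datasetsApi.trainFromDataset(created.dataset_id, {
    name: 'train-invoices-demo',
    config: { epochs: 100, batch_size: 16 },
  })
  console.log('training task started:', task.task_id)
}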

View File

@@ -21,14 +21,20 @@ export const documentsApi = {
return data
},
upload: async (file: File): Promise<UploadDocumentResponse> => {
upload: async (file: File, groupKey?: string): Promise<UploadDocumentResponse> => {
const formData = new FormData()
formData.append('file', file)
const params: Record<string, string> = {}
if (groupKey) {
params.group_key = groupKey
}
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
params,
})
return data
},
@@ -77,4 +83,16 @@ export const documentsApi = {
)
return data
},
updateGroupKey: async (
documentId: string,
groupKey: string | null
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/group-key`,
null,
{ params: { group_key: groupKey } }
)
return data
},
}
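
A short sketch of the new group-key plumbing in documentsApi, not part of the commit itself; the group key value is a placeholder, and passing null is used to clear the assignment, as the DocumentDetail editor further below does.

import { documentsApi } from './documents'

// Illustrative only: upload a file into a group, then clear its group key.
// The group key string is a placeholder.
async function uploadIntoGroup(file: File): Promise<void> {
  const uploaded = await documentsApi.upload(file, 'invoices-2026-q1')
  console.log('uploaded', uploaded.document_id, 'group:', uploaded.group_key)

  // null clears the group assignment on the backend.
  const cleared = await documentsApi.updateGroupKey(uploaded.document_id, null)
  console.log(cleared.message)
}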

View File

@@ -2,3 +2,6 @@ export { documentsApi } from './documents'
export { annotationsApi } from './annotations'
export { trainingApi } from './training'
export { inferenceApi } from './inference'
export { datasetsApi } from './datasets'
export { augmentationApi } from './augmentation'
export { modelsApi } from './models'

View File

@@ -0,0 +1,55 @@
import apiClient from '../client'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
} from '../types'
export const modelsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<ModelVersionListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models', {
params,
})
return data
},
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
return data
},
getActive: async (): Promise<ActiveModelResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
return data
},
activate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
return data
},
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
return data
},
archive: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
return data
},
delete: async (versionId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
return data
},
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
return data
},
}
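
A sketch of the activation flow exposed by modelsApi, outside the diff; it assumes the list endpoint returns the newest versions first, which nothing shown here guarantees.

import { modelsApi } from './models'

// Hypothetical promotion flow: take the first listed model version, activate
// it for inference, then confirm the active model. List ordering is assumed.
async function promoteFirstListedModel(): Promise<void> {
  const { models } = await modelsApi.list({ limit: 1 })
  if (models.length === 0) {
    console.log('no model versions available')
    return
  }

  const result = await modelsApi.activate(models[0].version_id)
  console.log(result.message)

  const active = await modelsApi.getActive()
  console.log('active model:', active.has_active_model ? active.model?.version : 'none')
}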

View File

@@ -8,6 +8,7 @@ export interface DocumentItem {
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
group_key: string | null
created_at: string
updated_at: string
annotation_count?: number
@@ -59,6 +60,7 @@ export interface DocumentDetailResponse {
auto_label_error: string | null
upload_source: string
batch_id: string | null
group_key: string | null
csv_field_values: Record<string, string> | null
can_annotate: boolean
annotation_lock_until: string | null
@@ -113,7 +115,11 @@ export interface ErrorResponse {
export interface UploadDocumentResponse {
document_id: string
filename: string
file_size: number
page_count: number
status: string
group_key: string | null
auto_label_started: boolean
message: string
}
@@ -171,3 +177,165 @@ export interface InferenceResult {
export interface InferenceResponse {
result: InferenceResult
}
// Dataset types
export interface DatasetCreateRequest {
name: string
description?: string
document_ids: string[]
train_ratio?: number
val_ratio?: number
seed?: number
}
export interface DatasetResponse {
dataset_id: string
name: string
status: string
message: string
}
export interface DatasetDocumentItem {
document_id: string
split: string
page_count: number
annotation_count: number
}
export interface DatasetListItem {
dataset_id: string
name: string
description: string | null
status: string
training_status: string | null
active_training_task_id: string | null
total_documents: number
total_images: number
total_annotations: number
created_at: string
}
export interface DatasetListResponse {
total: number
limit: number
offset: number
datasets: DatasetListItem[]
}
export interface DatasetDetailResponse {
dataset_id: string
name: string
description: string | null
status: string
train_ratio: number
val_ratio: number
seed: number
total_documents: number
total_images: number
total_annotations: number
dataset_path: string | null
error_message: string | null
documents: DatasetDocumentItem[]
created_at: string
updated_at: string
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationTrainingConfig {
gaussian_noise?: AugmentationParams
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface DatasetTrainRequest {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs?: number
batch_size?: number
image_size?: number
learning_rate?: number
device?: string
augmentation?: AugmentationTrainingConfig
augmentation_multiplier?: number
}
}
export interface TrainingTaskResponse {
task_id: string
status: string
message: string
}
// Model Version types
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface ModelVersionDetailResponse {
version_id: string
version: string
name: string
description: string | null
model_path: string
status: string
is_active: boolean
task_id: string | null
dataset_id: string | null
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
document_count: number
training_config: Record<string, unknown> | null
file_size: number | null
trained_at: string | null
activated_at: string | null
created_at: string
updated_at: string
}
export interface ModelVersionListResponse {
total: number
limit: number
offset: number
models: ModelVersionItem[]
}
export interface ModelVersionResponse {
version_id: string
status: string
message: string
}
export interface ActiveModelResponse {
has_active_model: boolean
model: ModelVersionItem | null
}
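
For reference, one way the training-related shapes above compose into a request body. Purely illustrative: the hyperparameters and augmentation values are placeholders chosen to satisfy the declared types, and the import path assumes the types module is local.

import type { DatasetTrainRequest } from './types'

// Illustrative DatasetTrainRequest with a light augmentation setup.
// Every value here is a placeholder that type-checks against the interfaces above.
export const exampleTrainRequest: DatasetTrainRequest = {
  name: 'train-invoices-demo',
  config: {
    model_name: 'yolo11n.pt',
    epochs: 100,
    batch_size: 16,
    image_size: 640,
    augmentation: {
      gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
      perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
      preserve_bboxes: true,
      seed: 42,
    },
    augmentation_multiplier: 2,
  },
}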

View File

@@ -0,0 +1,251 @@
/**
* Tests for AugmentationConfig component.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { AugmentationConfig } from './AugmentationConfig'
import { augmentationApi } from '../api/endpoints/augmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Default mock data
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise to simulate sensor noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective transformation',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
{
name: 'gaussian_blur',
description: 'Applies Gaussian blur',
affects_geometry: false,
stage: 'blur',
default_params: { kernel_size: 5 },
},
],
}
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
{ name: 'moderate', description: 'Balanced augmentation settings' },
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
],
}
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('AugmentationConfig', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
})
describe('rendering', () => {
it('should render enable checkbox', async () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
})
it('should be collapsed when disabled', () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
// Config options should not be visible
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
})
it('should expand when enabled', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/preset/i)).toBeInTheDocument()
})
})
})
describe('preset selection', () => {
it('should display available presets', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('conservative')).toBeInTheDocument()
expect(screen.getByText('moderate')).toBeInTheDocument()
expect(screen.getByText('aggressive')).toBeInTheDocument()
})
})
it('should call onConfigChange when preset is selected', async () => {
const user = userEvent.setup()
const onConfigChange = vi.fn()
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={onConfigChange}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('moderate')).toBeInTheDocument()
})
await user.click(screen.getByText('moderate'))
expect(onConfigChange).toHaveBeenCalled()
})
})
describe('enable toggle', () => {
it('should call onEnabledChange when checkbox is toggled', async () => {
const user = userEvent.setup()
const onEnabledChange = vi.fn()
render(
<AugmentationConfig
enabled={false}
onEnabledChange={onEnabledChange}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
expect(onEnabledChange).toHaveBeenCalledWith(true)
})
})
describe('augmentation types', () => {
it('should display augmentation types when in custom mode', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
})
})
it('should indicate which augmentations affect geometry', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
// perspective_warp affects geometry
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
})
})
})
describe('loading state', () => {
it('should show loading indicator while fetching types', () => {
vi.mocked(augmentationApi.getTypes).mockImplementation(
() => new Promise(() => {})
)
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,136 @@
/**
* AugmentationConfig component for configuring image augmentation during training.
*
* Provides preset selection and optional custom augmentation type configuration.
*/
import React from 'react'
import { Loader2, AlertTriangle } from 'lucide-react'
import { useAugmentation } from '../hooks/useAugmentation'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
interface AugmentationConfigProps {
enabled: boolean
onEnabledChange: (enabled: boolean) => void
config: Partial<AugmentationConfigType>
onConfigChange: (config: Partial<AugmentationConfigType>) => void
showCustomOptions?: boolean
}
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
enabled,
onEnabledChange,
config,
onConfigChange,
showCustomOptions = false,
}) => {
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
const isLoading = isLoadingTypes || isLoadingPresets
const handlePresetSelect = (presetName: string) => {
const preset = presets.find((p) => p.name === presetName)
if (preset && preset.config) {
onConfigChange(preset.config as Partial<AugmentationConfigType>)
} else {
// Apply a basic config based on preset name
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
conservative: {
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
},
moderate: {
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
lighting_variation: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
},
aggressive: {
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
motion_blur: { enabled: true, probability: 0.3, params: {} },
lighting_variation: { enabled: true, probability: 0.5, params: {} },
shadow: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
wrinkle: { enabled: true, probability: 0.2, params: {} },
stain: { enabled: true, probability: 0.2, params: {} },
},
}
onConfigChange(presetConfigs[presetName] || {})
}
}
return (
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
{/* Enable checkbox */}
<label className="flex items-center gap-2 cursor-pointer">
<input
type="checkbox"
checked={enabled}
onChange={(e) => onEnabledChange(e.target.checked)}
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
aria-label="Enable augmentation"
/>
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
</label>
{/* Expanded content when enabled */}
{enabled && (
<div className="mt-4 space-y-4">
{isLoading ? (
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
</div>
) : (
<>
{/* Preset selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
<div className="flex flex-wrap gap-2">
{presets.map((preset) => (
<button
key={preset.name}
onClick={() => handlePresetSelect(preset.name)}
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
title={preset.description}
>
{preset.name}
</button>
))}
</div>
</div>
{/* Custom options (if enabled) */}
{showCustomOptions && (
<div className="border-t border-warm-divider pt-4">
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
<div className="grid gap-2">
{augmentationTypes.map((type) => (
<div
key={type.name}
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
>
<div className="flex items-center gap-2">
<span className="text-sm text-warm-text-primary">{type.name}</span>
{type.affects_geometry && (
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
<AlertTriangle size={12} />
affects bbox
</span>
)}
</div>
<span className="text-xs text-warm-text-muted">{type.stage}</span>
</div>
))}
</div>
</div>
)}
</>
)}
</div>
)}
</div>
)
}
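
Outside the diff, the smallest controlled usage of this component; the parent owns both the enabled flag and the config object, exactly as TrainDialog does further below. As the tests above suggest, a QueryClientProvider is assumed higher in the tree because useAugmentation fetches types and presets.

import React, { useState } from 'react'
import { AugmentationConfig } from './AugmentationConfig'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'

// Minimal controlled wrapper (illustrative only).
export const AugmentationSection: React.FC = () => {
  const [enabled, setEnabled] = useState(false)
  const [config, setConfig] = useState<Partial<AugmentationConfigType>>({})

  return (
    <AugmentationConfig
      enabled={enabled}
      onEnabledChange={setEnabled}
      config={config}
      onConfigChange={setConfig}
      showCustomOptions
    />
  )
}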

View File

@@ -144,6 +144,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Annotations
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Group
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
Auto-label
</th>
@@ -153,13 +156,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<tbody>
{isLoading ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
Loading documents...
</td>
</tr>
) : documents.length === 0 ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
No documents found. Upload your first document to get started.
</td>
</tr>
@@ -213,6 +216,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<td className="py-4 px-4 text-sm text-warm-text-secondary">
{doc.annotation_count || 0} annotations
</td>
<td className="py-4 px-4 text-sm text-warm-text-muted">
{doc.group_key || '-'}
</td>
<td className="py-4 px-4">
{doc.auto_label_status === 'running' && progress && (
<div className="w-full">

View File

@@ -0,0 +1,122 @@
import React from 'react'
import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
import { Button } from './Button'
import { useDatasetDetail } from '../hooks/useDatasets'
interface DatasetDetailProps {
datasetId: string
onBack: () => void
}
const SPLIT_STYLES: Record<string, string> = {
train: 'bg-warm-state-info/10 text-warm-state-info',
val: 'bg-warm-state-warning/10 text-warm-state-warning',
test: 'bg-warm-state-success/10 text-warm-state-success',
}
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
if (isLoading) {
return (
<div className="flex items-center justify-center py-20 text-warm-text-muted">
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
</div>
)
}
if (error || !dataset) {
return (
<div className="p-8 max-w-7xl mx-auto">
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back
</button>
<p className="text-warm-state-error">Failed to load dataset.</p>
</div>
)
}
const statusIcon = dataset.status === 'ready'
? <Check size={14} className="text-warm-state-success" />
: dataset.status === 'failed'
? <AlertCircle size={14} className="text-warm-state-error" />
: <Loader2 size={14} className="animate-spin text-warm-state-info" />
return (
<div className="p-8 max-w-7xl mx-auto">
{/* Header */}
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back to Datasets
</button>
<div className="flex items-center justify-between mb-6">
<div>
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
{dataset.name} {statusIcon}
</h2>
{dataset.description && (
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
)}
</div>
{dataset.status === 'ready' && (
<Button><Play size={14} className="mr-1" />Start Training</Button>
)}
</div>
{dataset.error_message && (
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
{dataset.error_message}
</div>
)}
{/* Stats */}
<div className="grid grid-cols-4 gap-4 mb-8">
{[
['Documents', dataset.total_documents],
['Images', dataset.total_images],
['Annotations', dataset.total_annotations],
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
].map(([label, value]) => (
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
</div>
))}
</div>
{/* Document list */}
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{dataset.documents.map(doc => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4">
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
{doc.split}
</span>
</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
</tr>
))}
</tbody>
</table>
</div>
<p className="text-xs text-warm-text-muted mt-4">
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
</p>
</div>
)
}

View File

@@ -1,8 +1,9 @@
import React, { useState, useRef, useEffect } from 'react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
import { useDocuments } from '../hooks/useDocuments'
import { documentsApi } from '../api/endpoints/documents'
import type { AnnotationItem } from '../api/types'
@@ -26,7 +27,7 @@ const FIELD_CLASSES: Record<number, string> = {
}
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
const { document, annotations, isLoading } = useDocumentDetail(docId)
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
const {
createAnnotation,
updateAnnotation,
@@ -34,10 +35,13 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
isCreating,
isDeleting,
} = useAnnotations(docId)
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
const [selectedId, setSelectedId] = useState<string | null>(null)
const [zoom, setZoom] = useState(100)
const [isDrawing, setIsDrawing] = useState(false)
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
const [selectedClassId, setSelectedClassId] = useState<number>(0)
@@ -426,6 +430,65 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
{new Date(document.created_at).toLocaleDateString()}
</span>
</div>
<div className="flex justify-between items-center text-xs">
<span className="text-warm-text-muted">Group</span>
{isEditingGroupKey ? (
<div className="flex items-center gap-1">
<input
type="text"
value={editGroupKeyValue}
onChange={(e) => setEditGroupKeyValue(e.target.value)}
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
placeholder="group key"
autoFocus
/>
<button
onClick={() => {
updateGroupKey(
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
{
onSuccess: () => {
setIsEditingGroupKey(false)
refetch()
},
onError: () => {
alert('Failed to update group key. Please try again.')
},
}
)
}}
disabled={isUpdatingGroupKey}
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
>
<Check size={14} />
</button>
<button
onClick={() => {
setIsEditingGroupKey(false)
setEditGroupKeyValue(document.group_key || '')
}}
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
>
<X size={14} />
</button>
</div>
) : (
<div className="flex items-center gap-1">
<span className="text-warm-text-secondary font-medium">
{document.group_key || '-'}
</span>
<button
onClick={() => {
setEditGroupKeyValue(document.group_key || '')
setIsEditingGroupKey(true)
}}
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
>
<Edit2 size={12} />
</button>
</div>
)}
</div>
</div>
</div>
</div>

View File

@@ -1,73 +1,108 @@
import React from 'react';
import React, { useState } from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
import { Loader2, Power, CheckCircle } from 'lucide-react';
import { Button } from './Button';
import { useModels, useModelDetail } from '../hooks';
import type { ModelVersionItem } from '../api/types';
const CHART_DATA = [
{ name: 'Model A', value: 75 },
{ name: 'Model B', value: 82 },
{ name: 'Model C', value: 95 },
{ name: 'Model D', value: 68 },
];
const METRICS_DATA = [
{ name: 'Precision', value: 88 },
{ name: 'Recall', value: 76 },
{ name: 'F1 Score', value: 91 },
{ name: 'Accuracy', value: 82 },
];
const JOBS = [
{ id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
{ id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
{ id: 3, name: 'Model Training Compentr 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
];
const formatDate = (dateString: string | null): string => {
if (!dateString) return 'N/A';
return new Date(dateString).toLocaleString();
};
export const Models: React.FC = () => {
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
const { models, isLoading, activateModel, isActivating } = useModels();
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
// Build chart data from selected model's metrics
const metricsData = modelDetail ? [
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
] : [
{ name: 'Precision', value: 0 },
{ name: 'Recall', value: 0 },
{ name: 'mAP', value: 0 },
];
// Build comparison chart from all models (with placeholder if empty)
const chartData = models.length > 0
? models.slice(0, 4).map(m => ({
name: m.version,
value: (m.metrics_mAP ?? 0) * 100,
}))
: [
{ name: 'Model A', value: 0 },
{ name: 'Model B', value: 0 },
{ name: 'Model C', value: 0 },
{ name: 'Model D', value: 0 },
];
return (
<div className="p-8 max-w-7xl mx-auto flex gap-8">
{/* Left: Job History */}
<div className="flex-1">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
<div className="space-y-4">
{JOBS.map(job => (
<div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
<p className="text-sm text-warm-text-muted">Started {job.date}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
{job.status}
</span>
</div>
{job.status === 'Running' ? (
<div className="mt-4">
<div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
<div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
</div>
{isLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
</div>
) : models.length === 0 ? (
<div className="text-center py-12 text-warm-text-muted">
No model versions found. Complete a training task to create a model version.
</div>
) : (
<div className="space-y-4">
{models.map(model => (
<div
key={model.version_id}
onClick={() => setSelectedModel(model)}
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
selectedModel?.version_id === model.version_id
? 'border-warm-text-secondary'
: 'border-warm-border hover:border-warm-divider'
}`}
>
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
{model.name}
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
</h4>
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
model.is_active
? 'bg-warm-state-info/10 text-warm-state-info'
: 'bg-warm-selected text-warm-state-success'
}`}>
{model.is_active ? 'Active' : model.status}
</span>
</div>
) : (
<div className="mt-4 flex gap-8">
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
<span className="text-lg font-mono text-warm-text-secondary">
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
<span className="text-lg font-mono text-warm-text-secondary">100%</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
</div>
</div>
)}
</div>
))}
</div>
</div>
))}
</div>
)}
</div>
{/* Right: Model Detail */}
@@ -75,27 +110,34 @@ export const Models: React.FC = () => {
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
<div className="flex justify-between items-center mb-6">
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
<span className="text-sm font-medium text-warm-state-success">Completed</span>
<span className={`text-sm font-medium ${
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
}`}>
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
</span>
</div>
<div className="mb-8">
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
<p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
<p className="font-medium text-warm-text-primary">
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
</p>
</div>
<div className="space-y-8">
{/* Chart 1 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Bar Rate Metrics</h4>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={CHART_DATA}>
<BarChart data={chartData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" hide />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
<Tooltip
cursor={{fill: '#F1F0ED'}}
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
@@ -105,14 +147,17 @@ export const Models: React.FC = () => {
{/* Chart 2 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={METRICS_DATA}>
<BarChart data={metricsData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip cursor={{fill: '#F1F0ED'}} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
@@ -121,14 +166,43 @@ export const Models: React.FC = () => {
</div>
<div className="mt-8 space-y-3">
<Button className="w-full">Download Model</Button>
{selectedModel && !selectedModel.is_active ? (
<Button
className="w-full"
onClick={() => activateModel(selectedModel.version_id)}
disabled={isActivating}
>
{isActivating ? (
<>
<Loader2 size={16} className="mr-2 animate-spin" />
Activating...
</>
) : (
<>
<Power size={16} className="mr-2" />
Activate for Inference
</>
)}
</Button>
) : (
<Button className="w-full" disabled={!selectedModel}>
{selectedModel?.is_active ? (
<>
<CheckCircle size={16} className="mr-2" />
Currently Active
</>
) : (
'Select a Model'
)}
</Button>
)}
<div className="flex gap-3">
<Button variant="secondary" className="flex-1">View Logs</Button>
<Button variant="secondary" className="flex-1">Use as Base</Button>
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
</div>
</div>
</div>
</div>
</div>
);
};
};

View File

@@ -1,113 +1,482 @@
import React, { useState } from 'react';
import { Check, AlertCircle } from 'lucide-react';
import { Button } from './Button';
import { DocumentStatus } from '../types';
import React, { useState, useMemo } from 'react'
import { useQuery } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { AugmentationConfig } from './AugmentationConfig'
import { useDatasets } from '../hooks/useDatasets'
import { useTrainingDocuments } from '../hooks/useTraining'
import { trainingApi } from '../api/endpoints'
import type { DatasetListItem } from '../api/types'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
export const Training: React.FC = () => {
const [split, setSplit] = useState(80);
type Tab = 'datasets' | 'create'
const docs = [
{ id: '1', name: 'Document Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
{ id: '2', name: 'Document Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '3', name: 'Document Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '4', name: 'Document Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '5', name: 'Document Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '6', name: 'Document Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '8', name: 'Document Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
];
interface TrainingProps {
onNavigate?: (view: string, id?: string) => void
}
const STATUS_STYLES: Record<string, string> = {
ready: 'bg-warm-state-success/10 text-warm-state-success',
building: 'bg-warm-state-info/10 text-warm-state-info',
training: 'bg-warm-state-info/10 text-warm-state-info',
failed: 'bg-warm-state-error/10 text-warm-state-error',
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
running: 'bg-warm-state-info/10 text-warm-state-info',
}
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
// If there's an active training task, show training status
const displayStatus = trainingStatus === 'running'
? 'training'
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
? 'pending'
: status
return (
<div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
{/* Document Selection List */}
<div className="flex-1 flex flex-col">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>
<div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
<div className="overflow-auto flex-1">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
</tr>
</thead>
<tbody>
{docs.map(doc => (
<tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
<td className="py-3 px-4">
{doc.status === DocumentStatus.VERIFIED ? (
<div className="flex items-center text-warm-state-success text-sm font-medium">
<div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
<Check size={12} strokeWidth={3}/>
</div>
Verified
</div>
) : (
<div className="flex items-center text-warm-text-muted text-sm">
<div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
<span className="font-bold text-[10px]">!</span>
</div>
Partial
</div>
)}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
{displayStatus}
</span>
)
}
{/* Configuration Panel */}
<div className="w-96">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>
<div className="space-y-6">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
<input
type="text"
placeholder="e.g. Invoices Q4"
// --- Train Dialog ---
interface TrainDialogProps {
dataset: DatasetListItem
onClose: () => void
onSubmit: (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => void
isPending: boolean
}
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
const [name, setName] = useState(`train-${dataset.name}`)
const [epochs, setEpochs] = useState(100)
const [batchSize, setBatchSize] = useState(16)
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
// Fetch available trained models
const { data: modelsData } = useQuery({
queryKey: ['training', 'models', 'completed'],
queryFn: () => trainingApi.getModels({ status: 'completed' }),
})
const completedModels = modelsData?.models ?? []
const handleSubmit = () => {
onSubmit({
name,
config: {
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
epochs,
batch_size: batchSize,
augmentation: augmentationEnabled
? (augmentationConfig as AugmentationConfigType)
: undefined,
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
},
})
}
return (
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
<p className="text-sm text-warm-text-muted mb-4">
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
</p>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
{/* Base Model Selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
<select
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
onChange={e => {
if (e.target.value === 'pretrained') {
setBaseModelType('pretrained')
setBaseModelVersionId(null)
} else {
setBaseModelType('existing')
setBaseModelVersionId(e.target.value)
}
}}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
<option value="pretrained">yolo11n.pt (Pretrained)</option>
{completedModels.map(m => (
<option key={m.task_id} value={m.task_id}>
{m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
</option>
))}
</select>
<p className="text-xs text-warm-text-muted mt-1">
{baseModelType === 'pretrained'
? 'Start from pretrained YOLO model'
: 'Continue training from an existing model (incremental training)'}
</p>
</div>
<div className="flex gap-4">
<div className="flex-1">
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
<input
id="train-epochs"
type="number"
min={1}
max={1000}
value={epochs}
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
<select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
<option>LayoutLMv3 (Standard)</option>
<option>Donut (Beta)</option>
</select>
</div>
<div>
<div className="flex justify-between mb-2">
<label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
<span className="text-xs font-mono text-warm-text-muted">{split}% / {100-split}%</span>
</div>
<input
type="range"
min="50"
max="95"
value={split}
onChange={(e) => setSplit(parseInt(e.target.value))}
className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
<div className="flex-1">
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
<input
id="train-batch-size"
type="number"
min={1}
max={128}
value={batchSize}
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
</div>
{/* Augmentation Configuration */}
<AugmentationConfig
enabled={augmentationEnabled}
onEnabledChange={setAugmentationEnabled}
config={augmentationConfig}
onConfigChange={setAugmentationConfig}
/>
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
{augmentationEnabled && (
<div>
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
Augmentation Multiplier
</label>
<input
id="aug-multiplier"
type="number"
min={1}
max={10}
value={augmentationMultiplier}
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
<p className="text-xs text-warm-text-muted mt-1">
Number of augmented copies per original image (1-10)
</p>
</div>
)}
</div>
<div className="flex justify-end gap-3 mt-6">
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
</Button>
</div>
</div>
</div>
)
}
// --- Dataset List ---
const DatasetList: React.FC<{
onNavigate?: (view: string, id?: string) => void
onSwitchTab: (tab: Tab) => void
}> = ({ onNavigate, onSwitchTab }) => {
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
const handleTrain = (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => {
if (!trainTarget) return
// Pass config to the training API
const trainRequest = {
name: config.name,
config: config.config,
}
trainFromDataset(
{ datasetId: trainTarget.dataset_id, req: trainRequest },
{ onSuccess: () => setTrainTarget(null) },
)
}
if (isLoading) {
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
}
if (datasets.length === 0) {
return (
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
<Database size={48} className="mb-4 opacity-40" />
<p className="text-lg mb-2">No datasets yet</p>
<p className="text-sm mb-4">Create a dataset to start training</p>
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
</div>
)
}
return (
<>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
</tr>
</thead>
<tbody>
{datasets.map(ds => (
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
<td className="py-3 px-4">
<div className="flex gap-1">
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
<Eye size={14} />
</button>
{ds.status === 'ready' && (
<button title="Train" onClick={() => setTrainTarget(ds)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
<Play size={14} />
</button>
)}
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
disabled={isDeleting}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
<Trash2 size={14} />
</button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
{trainTarget && (
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
)}
</>
)
}
// --- Create Dataset ---
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
const { createDatasetAsync, isCreating } = useDatasets()
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
const [name, setName] = useState('')
const [description, setDescription] = useState('')
const [trainRatio, setTrainRatio] = useState(0.7)
const [valRatio, setValRatio] = useState(0.2)
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
const toggleDoc = (id: string) => {
setSelectedIds(prev => {
const next = new Set(prev)
if (next.has(id)) { next.delete(id) } else { next.add(id) }
return next
})
}
const toggleAll = () => {
if (selectedIds.size === documents.length) {
setSelectedIds(new Set())
} else {
setSelectedIds(new Set(documents.map((d) => d.document_id)))
}
}
const handleCreate = async () => {
await createDatasetAsync({
name,
description: description || undefined,
document_ids: [...selectedIds],
train_ratio: trainRatio,
val_ratio: valRatio,
})
onSwitchTab('datasets')
}
return (
<div className="flex gap-8">
{/* Document selection */}
<div className="flex-1 flex flex-col">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
{isLoadingDocs ? (
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
) : (
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
<div className="overflow-auto max-h-[calc(100vh-240px)]">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12">
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{documents.map((doc) => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
onClick={() => toggleDoc(doc.document_id)}>
<td className="py-3 pl-6 pr-4">
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
</td>
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
)}
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
</div>
{/* Config panel */}
<div className="w-80">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
<div className="flex gap-2 text-sm">
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Train</span>
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Val</span>
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Test</span>
<input type="number" value={testRatio} readOnly
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
</div>
</div>
</div>
<div className="pt-4 border-t border-warm-border">
<Button className="w-full h-12">Start Training</Button>
{selectedIds.size > 0 && selectedIds.size < 10 && (
<p className="text-xs text-warm-state-warning mb-2">
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
</p>
)}
<Button className="w-full h-11" onClick={handleCreate}
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
</Button>
</div>
</div>
</div>
</div>
</div>
  )
}
// --- Main Training Component ---
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
const [activeTab, setActiveTab] = useState<Tab>('datasets')
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="flex items-center justify-between mb-6">
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
</div>
{/* Tabs */}
<div className="flex gap-1 mb-6 border-b border-warm-border">
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
<button key={key} onClick={() => setActiveTab(key)}
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
activeTab === key
? 'border-warm-state-info text-warm-state-info'
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
}`}>
{label}
</button>
))}
</div>
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
</div>
)
}

View File

@@ -11,6 +11,7 @@ interface UploadModalProps {
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
const [groupKey, setGroupKey] = useState('')
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef<HTMLInputElement>(null)
@@ -61,10 +62,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
// Upload files one by one
for (const file of selectedFiles) {
await new Promise<void>((resolve, reject) => {
uploadDocument(file, {
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
})
uploadDocument(
{ file, groupKey: groupKey || undefined },
{
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
}
)
})
}
@@ -72,6 +76,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
setTimeout(() => {
onClose()
setSelectedFiles([])
setGroupKey('')
setUploadStatus('idle')
}, 1500)
} catch (error) {
@@ -85,6 +90,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
return // Prevent closing during upload
}
setSelectedFiles([])
setGroupKey('')
setUploadStatus('idle')
setErrorMessage('')
onClose()
@@ -173,6 +179,26 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
</div>
)}
{/* Group Key Input */}
{selectedFiles.length > 0 && (
<div className="mb-6">
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
Group Key (optional)
</label>
<input
type="text"
value={groupKey}
onChange={(e) => setGroupKey(e.target.value)}
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
disabled={uploadStatus === 'uploading'}
/>
<p className="text-xs text-warm-text-muted mt-1">
Use group keys to organize documents into logical groups
</p>
</div>
)}
{/* Status Messages */}
{uploadStatus === 'success' && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">

View File

@@ -2,3 +2,6 @@ export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'
export { useDatasets, useDatasetDetail } from './useDatasets'
export { useAugmentation } from './useAugmentation'
export { useModels, useModelDetail, useActiveModel } from './useModels'

View File

@@ -0,0 +1,226 @@
/**
* Tests for useAugmentation hook.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { augmentationApi } from '../api/endpoints/augmentation'
import { useAugmentation } from './useAugmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('useAugmentation', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective warp',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.augmentationTypes).toHaveLength(2)
expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
})
it('should handle error when fetching types', async () => {
vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.typesError).toBeTruthy()
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
{ name: 'aggressive', description: 'Strong augmentations' },
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingPresets).toBe(false)
})
expect(result.current.presets).toHaveLength(3)
expect(result.current.presets[0].name).toBe('conservative')
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockPreview = {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: { std: 15 },
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
// Call preview mutation
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: { std: 15 },
page: 1,
})
await waitFor(() => {
expect(augmentationApi.preview).toHaveBeenCalledWith(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: { std: 15 } },
1
)
})
})
it('should track preview loading state', async () => {
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockImplementation(
() => new Promise((resolve) => setTimeout(resolve, 100))
)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.isPreviewing).toBe(false)
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: {},
page: 1,
})
// State update happens asynchronously
await waitFor(() => {
expect(result.current.isPreviewing).toBe(true)
})
})
})
describe('createBatch', () => {
it('should create augmented dataset', async () => {
const mockResponse = {
task_id: 'task-123',
status: 'pending',
message: 'Augmentation task queued',
estimated_images: 100,
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
result.current.createBatch({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
await waitFor(() => {
expect(augmentationApi.createBatch).toHaveBeenCalledWith({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
})
})
})
})

View File

@@ -0,0 +1,121 @@
/**
* Hook for managing augmentation operations.
*
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
*/
import { useQuery, useMutation } from '@tanstack/react-query'
import {
augmentationApi,
type AugmentationTypesResponse,
type PresetsResponse,
type PreviewResponse,
type BatchRequest,
type BatchResponse,
type AugmentationConfig,
} from '../api/endpoints/augmentation'
interface PreviewParams {
documentId: string
augmentationType: string
params: Record<string, unknown>
page?: number
}
interface PreviewConfigParams {
documentId: string
config: AugmentationConfig
page?: number
}
export const useAugmentation = () => {
// Fetch augmentation types
const {
data: typesData,
isLoading: isLoadingTypes,
error: typesError,
} = useQuery<AugmentationTypesResponse>({
queryKey: ['augmentation', 'types'],
queryFn: () => augmentationApi.getTypes(),
staleTime: 5 * 60 * 1000, // Cache for 5 minutes
})
// Fetch presets
const {
data: presetsData,
isLoading: isLoadingPresets,
error: presetsError,
} = useQuery<PresetsResponse>({
queryKey: ['augmentation', 'presets'],
queryFn: () => augmentationApi.getPresets(),
staleTime: 5 * 60 * 1000,
})
// Preview single augmentation mutation
const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
augmentationApi.preview(
documentId,
{ augmentation_type: augmentationType, params },
page
),
onError: (error) => {
console.error('Preview augmentation failed:', error)
},
})
// Preview full config mutation
const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
mutationFn: ({ documentId, config, page = 1 }) =>
augmentationApi.previewConfig(documentId, config, page),
onError: (error) => {
console.error('Preview config failed:', error)
},
})
// Create augmented dataset mutation
const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
mutationFn: (request) => augmentationApi.createBatch(request),
onError: (error) => {
console.error('Create augmented dataset failed:', error)
},
})
return {
// Types data
augmentationTypes: typesData?.augmentation_types || [],
isLoadingTypes,
typesError,
// Presets data
presets: presetsData?.presets || [],
isLoadingPresets,
presetsError,
// Preview single augmentation
preview: previewMutation.mutate,
previewAsync: previewMutation.mutateAsync,
isPreviewing: previewMutation.isPending,
previewData: previewMutation.data,
previewError: previewMutation.error,
// Preview full config
previewConfig: previewConfigMutation.mutate,
previewConfigAsync: previewConfigMutation.mutateAsync,
isPreviewingConfig: previewConfigMutation.isPending,
previewConfigData: previewConfigMutation.data,
previewConfigError: previewConfigMutation.error,
// Create batch
createBatch: createBatchMutation.mutate,
createBatchAsync: createBatchMutation.mutateAsync,
isCreatingBatch: createBatchMutation.isPending,
batchData: createBatchMutation.data,
batchError: createBatchMutation.error,
// Reset functions for clearing stale mutation state
resetPreview: previewMutation.reset,
resetPreviewConfig: previewConfigMutation.reset,
resetBatch: createBatchMutation.reset,
}
}
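A minimal usage sketch of the hook above from a hypothetical preview panel. The component, its props, and the import path are illustrative; only the hook's return fields (augmentationTypes, isLoadingTypes, preview, isPreviewing, previewData) are taken from useAugmentation, and default_params is assumed to be a plain parameter record as in the mocked responses.

```typescript
import React from 'react'
import { useAugmentation } from '../hooks/useAugmentation'

// Hypothetical preview panel: renders one button per augmentation type and
// shows the last preview image returned by the preview mutation.
export const AugmentationPreviewPanel: React.FC<{ documentId: string }> = ({ documentId }) => {
  const { augmentationTypes, isLoadingTypes, preview, isPreviewing, previewData } = useAugmentation()

  if (isLoadingTypes) return <p>Loading augmentation types...</p>

  return (
    <div>
      {augmentationTypes.map(t => (
        <button
          key={t.name}
          disabled={isPreviewing}
          // Preview page 1 of the document using the type's default parameters
          onClick={() => preview({ documentId, augmentationType: t.name, params: t.default_params, page: 1 })}
        >
          {t.name}
        </button>
      ))}
      {previewData && <img src={previewData.preview_url} alt="Augmented preview" />}
    </div>
  )
}
```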

View File

@@ -0,0 +1,84 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { datasetsApi } from '../api/endpoints'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetTrainRequest,
} from '../api/types'
export const useDatasets = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
queryKey: ['datasets', params],
queryFn: () => datasetsApi.list(params),
staleTime: 30000,
// Poll every 5 seconds when there's an active training task
refetchInterval: (query) => {
const datasets = query.state.data?.datasets ?? []
const hasActiveTraining = datasets.some(
d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
)
return hasActiveTraining ? 5000 : false
},
})
const createMutation = useMutation({
mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const deleteMutation = useMutation({
mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const trainMutation = useMutation({
mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
datasetsApi.trainFromDataset(datasetId, req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
return {
datasets: data?.datasets ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
createDataset: createMutation.mutate,
createDatasetAsync: createMutation.mutateAsync,
isCreating: createMutation.isPending,
deleteDataset: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
trainFromDataset: trainMutation.mutate,
trainFromDatasetAsync: trainMutation.mutateAsync,
isTraining: trainMutation.isPending,
}
}
export const useDatasetDetail = (datasetId: string | null) => {
const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
queryKey: ['datasets', datasetId],
queryFn: () => datasetsApi.getDetail(datasetId!),
enabled: !!datasetId,
staleTime: 30000,
})
return {
dataset: data ?? null,
isLoading,
error,
}
}

View File

@@ -18,7 +18,16 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
})
const uploadMutation = useMutation({
mutationFn: (file: File) => documentsApi.upload(file),
mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
documentsApi.upload(file, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const updateGroupKeyMutation = useMutation({
mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
documentsApi.updateGroupKey(documentId, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
@@ -74,5 +83,8 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
isUpdatingStatus: updateStatusMutation.isPending,
triggerAutoLabel: triggerAutoLabelMutation.mutate,
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
updateGroupKey: updateGroupKeyMutation.mutate,
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
}
}

View File

@@ -0,0 +1,98 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { modelsApi } from '../api/endpoints'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ActiveModelResponse,
} from '../api/types'
export const useModels = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
queryKey: ['models', params],
queryFn: () => modelsApi.list(params),
staleTime: 30000,
})
const activateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.activate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const deactivateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const archiveMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.archive(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
const deleteMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.delete(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
return {
models: data?.models ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
activateModel: activateMutation.mutate,
activateModelAsync: activateMutation.mutateAsync,
isActivating: activateMutation.isPending,
deactivateModel: deactivateMutation.mutate,
isDeactivating: deactivateMutation.isPending,
archiveModel: archiveMutation.mutate,
isArchiving: archiveMutation.isPending,
deleteModel: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
}
}
export const useModelDetail = (versionId: string | null) => {
const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
queryKey: ['models', versionId],
queryFn: () => modelsApi.getDetail(versionId!),
enabled: !!versionId,
staleTime: 30000,
})
return {
model: data ?? null,
isLoading,
error,
}
}
export const useActiveModel = () => {
const { data, isLoading, error } = useQuery<ActiveModelResponse>({
queryKey: ['models', 'active'],
queryFn: () => modelsApi.getActive(),
staleTime: 30000,
})
return {
hasActiveModel: data?.has_active_model ?? false,
activeModel: data?.model ?? null,
isLoading,
error,
}
}
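A small sketch of how the active-model query might be consumed, e.g. in a status banner. The component is hypothetical; hasActiveModel, activeModel, and isLoading come from useActiveModel above, and the name/version fields assume the frontend ActiveModelResponse mirrors the model_versions columns.

```typescript
import React from 'react'
import { useActiveModel } from '../hooks/useModels'

// Hypothetical banner showing which model version currently serves inference.
export const ActiveModelBanner: React.FC = () => {
  const { hasActiveModel, activeModel, isLoading } = useActiveModel()

  if (isLoading) return null
  if (!hasActiveModel) return <p>No model is currently active for inference.</p>

  return (
    <p>
      Active model: {activeModel?.name} (v{activeModel?.version})
    </p>
  )
}
```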

View File

@@ -0,0 +1,8 @@
-- Add group_key column to admin_documents
-- Allows users to organize documents into logical groups
-- Add the column (nullable, VARCHAR 255)
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
-- Add index for filtering/grouping queries
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);

View File

@@ -0,0 +1,49 @@
-- Model versions table for tracking trained model deployments.
-- Each training run can produce a model version for inference.
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
-- Training association
task_id UUID REFERENCES training_tasks(task_id) ON DELETE SET NULL,
dataset_id UUID REFERENCES training_datasets(dataset_id) ON DELETE SET NULL,
-- Training metrics
metrics_mAP DOUBLE PRECISION,
metrics_precision DOUBLE PRECISION,
metrics_recall DOUBLE PRECISION,
document_count INTEGER NOT NULL DEFAULT 0,
-- Training configuration snapshot
training_config JSONB,
-- File info
file_size BIGINT,
-- Timestamps
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS idx_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS idx_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS idx_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_dataset_id ON model_versions(dataset_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_created ON model_versions(created_at);
-- Ensure only one active model at a time
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
ON model_versions(is_active) WHERE is_active = TRUE;
-- Comment
COMMENT ON TABLE model_versions IS 'Trained model versions for inference deployment';

View File

@@ -0,0 +1,46 @@
-- Add missing columns to training_tasks table
-- Add name column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
-- Add description column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
-- Add admin_token column (for multi-tenant support)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
-- Add task_type column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
-- Add recurring schedule columns
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
-- Add result metrics columns (for display without parsing JSONB)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
-- Rename metrics to config if exists
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM information_schema.columns
WHERE table_name = 'training_tasks' AND column_name = 'metrics'
AND NOT EXISTS (SELECT 1 FROM information_schema.columns
WHERE table_name = 'training_tasks' AND column_name = 'config')) THEN
ALTER TABLE training_tasks RENAME COLUMN metrics TO config;
END IF;
END $$;
-- Add updated_at column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
-- Create index on name
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
-- Create index on metrics_mAP
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);

View File

@@ -0,0 +1,14 @@
-- Fix foreign key constraints on model_versions table to allow CASCADE delete
-- Drop existing constraints
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
-- Add constraints with ON DELETE SET NULL
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_dataset_id_fkey
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_task_id_fkey
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;

View File

@@ -25,6 +25,7 @@ from inference.data.admin_models import (
AnnotationHistory,
TrainingDataset,
DatasetDocument,
ModelVersion,
)
logger = logging.getLogger(__name__)
@@ -110,6 +111,7 @@ class AdminDB:
page_count: int = 1,
upload_source: str = "ui",
csv_field_values: dict[str, Any] | None = None,
group_key: str | None = None,
admin_token: str | None = None, # Deprecated, kept for compatibility
) -> str:
"""Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
page_count=page_count,
upload_source=upload_source,
csv_field_values=csv_field_values,
group_key=group_key,
)
session.add(document)
session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
document.updated_at = datetime.utcnow()
session.add(document)
def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
"""Update document group key."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.group_key = group_key
document.updated_at = datetime.utcnow()
session.add(document)
return True
return False
def delete_document(self, document_id: str) -> bool:
"""Delete a document and its annotations."""
with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
session.expunge(d)
return list(datasets), total
def get_active_training_tasks_for_datasets(
self, dataset_ids: list[str]
) -> dict[str, dict[str, str]]:
"""Get active (pending/scheduled/running) training tasks for datasets.
Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
"""
if not dataset_ids:
return {}
# Validate UUIDs before query
valid_uuids = []
for d in dataset_ids:
try:
valid_uuids.append(UUID(d))
except ValueError:
logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
continue
if not valid_uuids:
return {}
with get_session_context() as session:
statement = select(TrainingTask).where(
TrainingTask.dataset_id.in_(valid_uuids),
TrainingTask.status.in_(["pending", "scheduled", "running"]),
)
results = session.exec(statement).all()
return {
str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
for t in results
}
def update_dataset_status(
self,
dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
session.delete(dataset)
session.commit()
return True
# ==========================================================================
# Model Version Operations
# ==========================================================================
def create_model_version(
self,
version: str,
name: str,
model_path: str,
description: str | None = None,
task_id: str | UUID | None = None,
dataset_id: str | UUID | None = None,
metrics_mAP: float | None = None,
metrics_precision: float | None = None,
metrics_recall: float | None = None,
document_count: int = 0,
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
model = ModelVersion(
version=version,
name=name,
model_path=model_path,
description=description,
task_id=UUID(str(task_id)) if task_id else None,
dataset_id=UUID(str(dataset_id)) if dataset_id else None,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
)
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Get a model version by ID."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if model:
session.expunge(model)
return model
def get_model_versions(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[ModelVersion], int]:
"""List model versions with optional status filter."""
with get_session_context() as session:
query = select(ModelVersion)
count_query = select(func.count()).select_from(ModelVersion)
if status:
query = query.where(ModelVersion.status == status)
count_query = count_query.where(ModelVersion.status == status)
total = session.exec(count_query).one()
models = session.exec(
query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
).all()
for m in models:
session.expunge(m)
return list(models), total
def get_active_model_version(self) -> ModelVersion | None:
"""Get the currently active model version for inference."""
with get_session_context() as session:
result = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).first()
if result:
session.expunge(result)
return result
def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Activate a model version for inference (deactivates all others)."""
with get_session_context() as session:
# Deactivate all versions
all_versions = session.exec(
select(ModelVersion).where(ModelVersion.is_active == True)
).all()
for v in all_versions:
v.is_active = False
v.status = "inactive"
v.updated_at = datetime.utcnow()
session.add(v)
# Activate the specified version
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = True
model.status = "active"
model.activated_at = datetime.utcnow()
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Deactivate a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
model.is_active = False
model.status = "inactive"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def update_model_version(
self,
version_id: str | UUID,
name: str | None = None,
description: str | None = None,
status: str | None = None,
) -> ModelVersion | None:
"""Update model version metadata."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
if name is not None:
model.name = name
if description is not None:
model.description = description
if status is not None:
model.status = status
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
"""Archive a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return None
# Cannot archive active model
if model.is_active:
return None
model.status = "archived"
model.updated_at = datetime.utcnow()
session.add(model)
session.commit()
session.refresh(model)
session.expunge(model)
return model
def delete_model_version(self, version_id: str | UUID) -> bool:
"""Delete a model version."""
with get_session_context() as session:
model = session.get(ModelVersion, UUID(str(version_id)))
if not model:
return False
# Cannot delete active model
if model.is_active:
return False
session.delete(model)
session.commit()
return True

View File

@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
group_key: str | None = Field(default=None, max_length=255, index=True)
# User-defined grouping key for document organization
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Model Version Management
# =============================================================================
class ModelVersion(SQLModel, table=True):
"""Model version for inference deployment."""
__tablename__ = "model_versions"
version_id: UUID = Field(default_factory=uuid4, primary_key=True)
version: str = Field(max_length=50, index=True)
# Semantic version e.g., "1.0.0", "2.1.0"
name: str = Field(max_length=255)
description: str | None = Field(default=None)
model_path: str = Field(max_length=512)
# Path to the model weights file
status: str = Field(default="inactive", max_length=20, index=True)
# Status: active, inactive, archived
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training metrics
metrics_mAP: float | None = Field(default=None)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
document_count: int = Field(default=0)
# Number of documents used in training
# Training configuration snapshot
training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of epochs, batch_size, etc.
# File info
file_size: int | None = Field(default=None)
# Model file size in bytes
# Timestamps
trained_at: datetime | None = Field(default=None)
# When training completed
activated_at: datetime | None = Field(default=None)
# When this version was last activated
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -49,6 +49,111 @@ def get_engine():
return _engine
def run_migrations() -> None:
"""Run database migrations for new columns."""
engine = get_engine()
migrations = [
# Migration 004: Training datasets tables and dataset_id on training_tasks
(
"training_datasets_tables",
"""
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
""",
),
(
"dataset_documents_table",
"""
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
""",
),
(
"training_tasks_dataset_id",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
""",
),
# Migration 005: Add group_key to admin_documents
(
"admin_documents_group_key",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
""",
),
# Migration 006: Model versions table
(
"model_versions_table",
"""
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
task_id UUID REFERENCES training_tasks(task_id),
dataset_id UUID REFERENCES training_datasets(dataset_id),
metrics_mAP FLOAT,
metrics_precision FLOAT,
metrics_recall FLOAT,
document_count INTEGER NOT NULL DEFAULT 0,
training_config JSONB,
file_size BIGINT,
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
""",
),
]
with engine.connect() as conn:
for name, sql in migrations:
try:
conn.execute(text(sql))
conn.commit()
logger.info(f"Migration '{name}' applied successfully")
except Exception as e:
# Log but don't fail - column may already exist
logger.debug(f"Migration '{name}' skipped or failed: {e}")
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
# Run migrations for new columns
run_migrations()
def get_session() -> Session:
"""Get a new database session."""

View File

@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_augmentation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",

View File

@@ -0,0 +1,15 @@
"""Augmentation API module."""
from fastapi import APIRouter
from .routes import register_augmentation_routes
def create_augmentation_router() -> APIRouter:
"""Create and configure the augmentation router."""
router = APIRouter(prefix="/augmentation", tags=["augmentation"])
register_augmentation_routes(router)
return router
__all__ = ["create_augmentation_router"]

View File

@@ -0,0 +1,162 @@
"""Augmentation API routes."""
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.schemas.admin.augmentation import (
AugmentationBatchRequest,
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewRequest,
AugmentationPreviewResponse,
AugmentationTypeInfo,
AugmentationTypesResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
PresetInfo,
PresetsResponse,
)
def register_augmentation_routes(router: APIRouter) -> None:
"""Register augmentation endpoints on the router."""
@router.get(
"/types",
response_model=AugmentationTypesResponse,
summary="List available augmentation types",
)
async def list_augmentation_types(
admin_token: AdminTokenDep,
) -> AugmentationTypesResponse:
"""
List all available augmentation types with descriptions and parameters.
"""
from shared.augmentation.pipeline import (
AUGMENTATION_REGISTRY,
AugmentationPipeline,
)
types = []
for name, aug_class in AUGMENTATION_REGISTRY.items():
# Create instance with empty params to get preview params
aug = aug_class({})
types.append(
AugmentationTypeInfo(
name=name,
description=(aug_class.__doc__ or "").strip(),
affects_geometry=aug_class.affects_geometry,
stage=AugmentationPipeline.STAGE_MAPPING[name],
default_params=aug.get_preview_params(),
)
)
return AugmentationTypesResponse(augmentation_types=types)
@router.get(
"/presets",
response_model=PresetsResponse,
summary="Get augmentation presets",
)
async def get_presets(
admin_token: AdminTokenDep,
) -> PresetsResponse:
"""Get predefined augmentation presets for common use cases."""
from shared.augmentation.presets import list_presets
presets = [PresetInfo(**p) for p in list_presets()]
return PresetsResponse(presets=presets)
@router.post(
"/preview/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview augmentation on document image",
)
async def preview_augmentation(
document_id: str,
request: AugmentationPreviewRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Returns URLs to original and augmented preview images.
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.preview_single(
document_id=document_id,
page=page,
augmentation_type=request.augmentation_type,
params=request.params,
)
@router.post(
"/preview-config/{document_id}",
response_model=AugmentationPreviewResponse,
summary="Preview full augmentation config on document",
)
async def preview_config(
document_id: str,
config: AugmentationConfigSchema,
admin_token: AdminTokenDep,
db: AdminDBDep,
page: int = Query(default=1, ge=1, description="Page number"),
) -> AugmentationPreviewResponse:
"""Preview complete augmentation pipeline on a document page."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.preview_config(
document_id=document_id,
page=page,
config=config,
)
@router.post(
"/batch",
response_model=AugmentationBatchResponse,
summary="Create augmented dataset (offline preprocessing)",
)
async def create_augmented_dataset(
request: AugmentationBatchRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
This runs as a background task. The augmented images are stored
alongside the original dataset for training.
"""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.create_augmented_dataset(
source_dataset_id=request.dataset_id,
config=request.config,
output_name=request.output_name,
multiplier=request.multiplier,
)
@router.get(
"/datasets",
response_model=AugmentedDatasetListResponse,
summary="List augmented datasets",
)
async def list_augmented_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: int = Query(default=20, ge=1, le=100, description="Page size"),
offset: int = Query(default=0, ge=0, description="Offset"),
) -> AugmentedDatasetListResponse:
"""List all augmented datasets."""
from inference.web.services.augmentation_service import AugmentationService
service = AugmentationService(db=db)
return await service.list_augmented_datasets(limit=limit, offset=offset)
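For reference, a hedged sketch of how the /batch endpoint above could be driven from the frontend through the useAugmentation hook's createBatch mutation. The request shape follows BatchRequest as exercised in the hook tests; the specific augmentation config, output name, and helper name are made up for illustration.

```typescript
import { useAugmentation } from '../hooks/useAugmentation'

// Hypothetical helper hook: queues offline augmentation for an existing dataset.
export function useQueueAugmentation() {
  const { createBatchAsync, isCreatingBatch } = useAugmentation()

  const queue = async (datasetId: string) => {
    const response = await createBatchAsync({
      dataset_id: datasetId,
      config: {
        // Illustrative config: one enabled augmentation applied with 50% probability
        gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
      },
      output_name: `${datasetId}-augmented`,
      multiplier: 2,
    })
    // The backend queues this as a background task; the task id can be polled later.
    return response.task_id
  }

  return { queue, isCreatingBatch }
}
```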

View File

@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
group_key: Annotated[
str | None,
Query(description="Optional group key for document organization", max_length=255),
] = None,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
group_key=group_key,
)
# Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
group_key=group_key,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
group_key=doc.group_key if hasattr(doc, 'group_key') else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
group_key=document.group_key if hasattr(document, 'group_key') else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
return response
@router.patch(
"/{document_id}/group-key",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document group key",
description="Update the group key for a document.",
)
async def update_document_group_key(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
group_key: Annotated[
str | None,
Query(description="New group key (null to clear)"),
] = None,
) -> dict:
"""Update document group key."""
_validate_uuid(document_id, "document_id")
# Validate group_key length
if group_key and len(group_key) > 255:
raise HTTPException(
status_code=400,
detail="Group key must be 255 characters or less",
)
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update group key
db.update_document_group_key(document_id, group_key)
return {
"status": "updated",
"document_id": document_id,
"group_key": group_key,
"message": "Document group key updated",
}
return router
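A sketch of what the matching client call might look like. Note that the endpoint takes group_key as a query parameter rather than a JSON body, and omitting it clears the key. The apiClient import, the route prefix, and the function name (mirroring documentsApi.updateGroupKey used in useDocuments) are assumptions.

```typescript
import apiClient from '../client'

// Hypothetical client function for the PATCH /documents/{id}/group-key endpoint.
// group_key is sent as a query parameter; passing null omits it, which clears the key.
export async function updateGroupKey(documentId: string, groupKey: string | null) {
  const response = await apiClient.patch(
    `/api/v1/admin/documents/${documentId}/group-key`,
    undefined, // no request body
    { params: { group_key: groupKey ?? undefined } },
  )
  return response.data as {
    status: string
    document_id: string
    group_key: string | null
    message: string
  }
}
```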

View File

@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
return router

View File

@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
# Validate minimum document count for proper train/val/test split
if len(request.document_ids) < 10:
raise HTTPException(
status_code=400,
detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
)
dataset = db.create_dataset(
name=request.name,
description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
# Get active training tasks for each dataset (graceful degradation on error)
dataset_ids = [str(d.dataset_id) for d in datasets]
try:
active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
except Exception:
logger.exception("Failed to fetch active training tasks")
active_tasks = {}
return DatasetListResponse(
total=total,
limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
name=d.name,
description=d.description,
status=d.status,
training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
description="Create a training task. Set base_model_version_id in config for incremental training.",
)
async def train_from_dataset(
dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
"""Create a training task from a dataset.
For incremental training, set config.base_model_version_id to a model version UUID.
The training will use that model as the starting point instead of a pretrained model.
"""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
)
config_dict = request.config.model_dump()
# Resolve base_model_version_id to actual model path for incremental training
base_model_version_id = config_dict.get("base_model_version_id")
if base_model_version_id:
_validate_uuid(base_model_version_id, "base_model_version_id")
base_model = db.get_model_version(base_model_version_id)
if not base_model:
raise HTTPException(
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
base_model.version,
base_model.model_path,
)
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
task_type="finetune" if base_model_version_id else "train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
message = (
f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
if base_model_version_id
else "Training task created from dataset"
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
message=message,
)
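A hedged sketch of the incremental-training request body (the route prefix and auth header are assumptions; the field names follow TrainingTaskCreate/TrainingConfig elsewhere in this commit; <placeholders> are hypothetical):

import requests

payload = {
    "name": "invoices-finetune-2026-01",
    "config": {
        "base_model_version_id": "<model-version-uuid>",  # resolved server-side to base_model_path
        "epochs": 50,
        "batch_size": 16,
    },
}
resp = requests.post(
    "http://localhost:8000/api/v1/admin/training/datasets/<dataset-uuid>/train",  # prefix assumed
    json=payload,
    headers={"Authorization": "Bearer <admin-token>"},
    timeout=10,
)
print(resp.json())
# Expected shape: {"task_id": "...", "status": "pending",
#                  "message": "Incremental training task created (base: ...)"}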

View File

@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
)
@router.get(
"/models",
"/completed-tasks",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
summary="Get completed training tasks",
description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
)
async def get_training_models(
async def get_completed_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[

View File

@@ -0,0 +1,333 @@
"""Model Version Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, Request
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
ModelVersionItem,
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_model_routes(router: APIRouter) -> None:
"""Register model version endpoints on the router."""
@router.post(
"/models",
response_model=ModelVersionResponse,
summary="Create model version",
description="Register a new model version for deployment.",
)
async def create_model_version(
request: ModelVersionCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Create a new model version."""
if request.task_id:
_validate_uuid(request.task_id, "task_id")
if request.dataset_id:
_validate_uuid(request.dataset_id, "dataset_id")
model = db.create_model_version(
version=request.version,
name=request.name,
model_path=request.model_path,
description=request.description,
task_id=request.task_id,
dataset_id=request.dataset_id,
metrics_mAP=request.metrics_mAP,
metrics_precision=request.metrics_precision,
metrics_recall=request.metrics_recall,
document_count=request.document_count,
training_config=request.training_config,
file_size=request.file_size,
trained_at=request.trained_at,
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version created successfully",
)
@router.get(
"/models",
response_model=ModelVersionListResponse,
summary="List model versions",
)
async def list_model_versions(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> ModelVersionListResponse:
"""List model versions with optional status filter."""
models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
return ModelVersionListResponse(
total=total,
limit=limit,
offset=offset,
models=[
ModelVersionItem(
version_id=str(m.version_id),
version=m.version,
name=m.name,
status=m.status,
is_active=m.is_active,
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
activated_at=m.activated_at,
created_at=m.created_at,
)
for m in models
],
)
@router.get(
"/models/active",
response_model=ActiveModelResponse,
summary="Get active model",
description="Get the currently active model for inference.",
)
async def get_active_model(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ActiveModelResponse:
"""Get the currently active model version."""
model = db.get_active_model_version()
if not model:
return ActiveModelResponse(has_active_model=False, model=None)
return ActiveModelResponse(
has_active_model=True,
model=ModelVersionItem(
version_id=str(model.version_id),
version=model.version,
name=model.name,
status=model.status,
is_active=model.is_active,
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
),
)
@router.get(
"/models/{version_id}",
response_model=ModelVersionDetailResponse,
summary="Get model version detail",
)
async def get_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionDetailResponse:
"""Get detailed model version information."""
_validate_uuid(version_id, "version_id")
model = db.get_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionDetailResponse(
version_id=str(model.version_id),
version=model.version,
name=model.name,
description=model.description,
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
metrics_precision=model.metrics_precision,
metrics_recall=model.metrics_recall,
document_count=model.document_count,
training_config=model.training_config,
file_size=model.file_size,
trained_at=model.trained_at,
activated_at=model.activated_at,
created_at=model.created_at,
updated_at=model.updated_at,
)
@router.patch(
"/models/{version_id}",
response_model=ModelVersionResponse,
summary="Update model version",
)
async def update_model_version(
version_id: str,
request: ModelVersionUpdateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Update model version metadata."""
_validate_uuid(version_id, "version_id")
model = db.update_model_version(
version_id=version_id,
name=request.name,
description=request.description,
status=request.status,
)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version updated successfully",
)
@router.post(
"/models/{version_id}/activate",
response_model=ModelVersionResponse,
summary="Activate model version",
description="Activate a model version for inference (deactivates all others).",
)
async def activate_model_version(
version_id: str,
request: Request,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
model = db.activate_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
# Trigger model reload in inference service
inference_service = getattr(request.app.state, "inference_service", None)
model_reloaded = False
if inference_service:
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info(f"Inference model reloaded to version {model.version}")
except Exception as e:
logger.warning(f"Failed to reload inference model: {e}")
message = "Model version activated for inference"
if model_reloaded:
message += " (model reloaded)"
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message=message,
)
@router.post(
"/models/{version_id}/deactivate",
response_model=ModelVersionResponse,
summary="Deactivate model version",
)
async def deactivate_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Deactivate a model version."""
_validate_uuid(version_id, "version_id")
model = db.deactivate_model_version(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version deactivated",
)
@router.post(
"/models/{version_id}/archive",
response_model=ModelVersionResponse,
summary="Archive model version",
)
async def archive_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ModelVersionResponse:
"""Archive a model version."""
_validate_uuid(version_id, "version_id")
model = db.archive_model_version(version_id)
if not model:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot archive active model",
)
return ModelVersionResponse(
version_id=str(model.version_id),
status=model.status,
message="Model version archived",
)
@router.delete(
"/models/{version_id}",
summary="Delete model version",
)
async def delete_model_version(
version_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a model version."""
_validate_uuid(version_id, "version_id")
success = db.delete_model_version(version_id)
if not success:
raise HTTPException(
status_code=400,
detail="Model version not found or cannot delete active model",
)
return {"message": "Model version deleted"}
@router.post(
"/models/reload",
summary="Reload inference model",
description="Reload the inference model from the currently active model version.",
)
async def reload_inference_model(
request: Request,
admin_token: AdminTokenDep,
) -> dict:
"""Reload the inference model from active version."""
inference_service = getattr(request.app.state, "inference_service", None)
if not inference_service:
raise HTTPException(
status_code=500,
detail="Inference service not available",
)
try:
model_reloaded = inference_service.reload_model()
if model_reloaded:
logger.info("Inference model manually reloaded")
return {"message": "Model reloaded successfully", "reloaded": True}
else:
return {"message": "Model already up to date", "reloaded": False}
except Exception as e:
logger.error(f"Failed to reload model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to reload model: {e}",
)
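A short operational sketch of the promote-and-reload flow these endpoints enable (mount prefix and auth header assumed, as above):

import requests

BASE = "http://localhost:8000/api/v1/admin/training"  # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}

# Activate a version: all other versions are deactivated and the inference
# service attempts a hot reload in the same process.
resp = requests.post(f"{BASE}/models/<version-uuid>/activate", headers=HEADERS, timeout=30)
print(resp.json())  # e.g. {"status": "active", "message": "Model version activated for inference (model reloaded)"}

# Force a reload later, e.g. if activation happened on another instance.
resp = requests.post(f"{BASE}/models/reload", headers=HEADERS, timeout=30)
print(resp.json())  # {"message": "Model reloaded successfully", "reloaded": True} or "Model already up to date"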

View File

@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
create_annotation_router,
create_augmentation_router,
create_auth_router,
create_documents_router,
create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""
config = config or default_config
# Create inference service
# Create model path resolver that reads from database
def get_active_model_path():
"""Resolve active model path from database."""
try:
db = AdminDB()
active_model = db.get_active_model_version()
if active_model and active_model.model_path:
return active_model.model_path
except Exception as e:
logger.warning(f"Failed to get active model from database: {e}")
return None
# Create inference service with database model resolver
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
model_path_resolver=get_active_model_path,
)
# Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
# Store inference service for access by routes (e.g., model reload)
# This will be set after app creation
app = FastAPI(
title="Invoice Field Extraction API",
description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include batch upload routes
app.include_router(batch_upload_router)
# Store inference service in app state for access by routes
app.state.inference_service = inference_service
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:

View File

@@ -110,6 +110,7 @@ class TrainingScheduler:
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@ class TrainingScheduler:
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Get augmentation config if present
augmentation_config = config.get("augmentation")
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
self._db.add_training_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
)
else:
# Train from pretrained model
effective_model = model_name
# Use dataset if available, otherwise export from scratch
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if not dataset or not dataset.dataset_path:
raise ValueError(f"Dataset {dataset_id} not found or has no path")
data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
dataset_path = Path(dataset.dataset_path)
self._db.add_training_log(
task_id, "INFO",
f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@ class TrainingScheduler:
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
dataset_path = Path(data_yaml).parent
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Apply augmentation if config is provided
if augmentation_config and self._has_enabled_augmentations(augmentation_config):
aug_result = self._apply_augmentation(
task_id, dataset_path, augmentation_config, augmentation_multiplier
)
if aug_result:
self._db.add_training_log(
task_id, "INFO",
f"Augmentation complete: {aug_result['augmented_images']} new images "
f"(total: {aug_result['total_images']})",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
model_name=effective_model, # Use base model or pretrained model
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -159,11 +192,94 @@ class TrainingScheduler:
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _create_model_version_from_training(
self,
task_id: str,
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return
# Get task info for name
task = self._db.get_training_task(task_id)
task_name = task.name if task else f"Task {task_id[:8]}"
# Generate version number based on existing versions
existing_versions = self._db.get_model_versions(limit=1, offset=0)
version_count = existing_versions[1] if existing_versions else 0
version = f"v{version_count + 1}.0"
# Extract metrics from result
metrics = result.get("metrics", {})
metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
metrics_precision = metrics.get("precision")
metrics_recall = metrics.get("recall")
# Get file size if possible
file_size = None
model_file = Path(model_path)
if model_file.exists():
file_size = model_file.stat().st_size
# Get document count from dataset if available
document_count = 0
if dataset_id:
dataset = self._db.get_dataset(dataset_id)
if dataset:
document_count = dataset.total_documents
# Create model version
model_version = self._db.create_model_version(
version=version,
name=task_name,
model_path=str(model_path),
description=f"Auto-created from training task {task_id[:8]}",
task_id=task_id,
dataset_id=dataset_id,
metrics_mAP=metrics_mAP,
metrics_precision=metrics_precision,
metrics_recall=metrics_recall,
document_count=document_count,
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
)
logger.info(
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
self._db.add_training_log(
task_id, "INFO",
f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
)
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
# Create log callback that writes to DB
def log_callback(level: str, message: str) -> None:
self._db.add_training_log(task_id, level, message)
# Create shared training config
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
)
# Run training using shared trainer
trainer = YOLOTrainer(config=config, log_callback=log_callback)
result = trainer.train()
if not result.success:
raise ValueError(result.error or "Training failed")
return {
"model_path": result.model_path,
"metrics": result.metrics,
}
def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
"""Check if any augmentations are enabled in the config."""
augmentation_fields = [
"perspective_warp", "wrinkle", "edge_damage", "stain",
"lighting_variation", "shadow", "gaussian_blur", "motion_blur",
"gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
]
for field in augmentation_fields:
if field in aug_config:
field_config = aug_config[field]
if isinstance(field_config, dict) and field_config.get("enabled", False):
return True
return False
def _apply_augmentation(
self,
task_id: str,
dataset_path: Path,
aug_config: dict[str, Any],
multiplier: int,
) -> dict[str, int] | None:
"""Apply augmentation to dataset before training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
from shared.augmentation import DatasetAugmenter
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
f"Applying augmentation with multiplier={multiplier}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
augmenter = DatasetAugmenter(aug_config)
result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
return result
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
logger.error(f"Augmentation failed for task {task_id}: {e}")
self._db.add_training_log(
task_id, "WARNING",
f"Augmentation failed: {e}. Continuing with original dataset.",
)
return None
# Global scheduler instance

View File

@@ -10,6 +10,7 @@ from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse

View File

@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AugmentationParamsSchema(BaseModel):
"""Single augmentation parameters."""
enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
probability: float = Field(
default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
)
params: dict[str, Any] = Field(
default_factory=dict, description="Type-specific parameters"
)
class AugmentationConfigSchema(BaseModel):
"""Complete augmentation configuration."""
# Geometric transforms
perspective_warp: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Degradation effects
wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
edge_damage: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Lighting effects
lighting_variation: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
# Blur effects
gaussian_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
motion_blur: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Noise effects
gaussian_noise: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
salt_pepper: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Texture effects
paper_texture: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
scanner_artifacts: AugmentationParamsSchema = Field(
default_factory=AugmentationParamsSchema
)
# Global settings
preserve_bboxes: bool = Field(
default=True, description="Whether to adjust bboxes for geometric transforms"
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
class AugmentationTypeInfo(BaseModel):
"""Information about an augmentation type."""
name: str = Field(..., description="Augmentation name")
description: str = Field(..., description="Augmentation description")
affects_geometry: bool = Field(
..., description="Whether this augmentation affects bbox coordinates"
)
stage: str = Field(..., description="Processing stage")
default_params: dict[str, Any] = Field(
default_factory=dict, description="Default parameters"
)
class AugmentationTypesResponse(BaseModel):
"""Response for listing augmentation types."""
augmentation_types: list[AugmentationTypeInfo] = Field(
..., description="Available augmentation types"
)
class PresetInfo(BaseModel):
"""Information about a preset."""
name: str = Field(..., description="Preset name")
description: str = Field(..., description="Preset description")
class PresetsResponse(BaseModel):
"""Response for listing presets."""
presets: list[PresetInfo] = Field(..., description="Available presets")
class AugmentationPreviewRequest(BaseModel):
"""Request to preview augmentation on an image."""
augmentation_type: str = Field(..., description="Type of augmentation to preview")
params: dict[str, Any] = Field(
default_factory=dict, description="Override parameters"
)
class AugmentationPreviewResponse(BaseModel):
"""Response with preview image data."""
preview_url: str = Field(..., description="URL to preview image")
original_url: str = Field(..., description="URL to original image")
applied_params: dict[str, Any] = Field(..., description="Applied parameters")
class AugmentationBatchRequest(BaseModel):
"""Request to augment a dataset offline."""
dataset_id: str = Field(..., description="Source dataset UUID")
config: AugmentationConfigSchema = Field(..., description="Augmentation config")
output_name: str = Field(
..., min_length=1, max_length=255, description="Output dataset name"
)
multiplier: int = Field(
default=2, ge=1, le=10, description="Augmented copies per image"
)
class AugmentationBatchResponse(BaseModel):
"""Response for batch augmentation."""
task_id: str = Field(..., description="Background task UUID")
status: str = Field(..., description="Task status")
message: str = Field(..., description="Status message")
estimated_images: int = Field(..., description="Estimated total images")
class AugmentedDatasetItem(BaseModel):
"""Single augmented dataset in list."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
created_at: datetime = Field(..., description="Creation timestamp")
class AugmentedDatasetListResponse(BaseModel):
"""Response for listing augmented datasets."""
total: int = Field(..., ge=0, description="Total datasets")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
datasets: list[AugmentedDatasetItem] = Field(
default_factory=list, description="Dataset list"
)
class AugmentedDatasetDetailResponse(BaseModel):
"""Detailed augmented dataset response."""
dataset_id: str = Field(..., description="Dataset UUID")
source_dataset_id: str = Field(..., description="Source dataset UUID")
name: str = Field(..., description="Dataset name")
status: str = Field(..., description="Dataset status")
config: AugmentationConfigSchema | None = Field(
None, description="Augmentation config used"
)
multiplier: int = Field(..., description="Augmentation multiplier")
total_original_images: int = Field(..., description="Original image count")
total_augmented_images: int = Field(..., description="Augmented image count")
dataset_path: str | None = Field(None, description="Dataset path on disk")
error_message: str | None = Field(None, description="Error message if failed")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")

View File

@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
name: str
description: str | None
status: str
training_status: str | None = None
active_training_task_id: str | None = None
total_documents: int
total_images: int
total_annotations: int

View File

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)

View File

@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class ModelVersionCreateRequest(BaseModel):
"""Request to create a model version."""
version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
name: str = Field(..., min_length=1, max_length=255, description="Model name")
model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
description: str | None = Field(None, description="Optional description")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
document_count: int = Field(0, ge=0, description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
class ModelVersionUpdateRequest(BaseModel):
"""Request to update a model version."""
name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
description: str | None = Field(None, description="Description")
status: str | None = Field(None, description="Status (inactive, archived)")
class ModelVersionItem(BaseModel):
"""Model version in list view."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
class ModelVersionListResponse(BaseModel):
"""Paginated model version list."""
total: int = Field(..., ge=0, description="Total model versions")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")
class ModelVersionDetailResponse(BaseModel):
"""Detailed model version info."""
version_id: str = Field(..., description="Version UUID")
version: str = Field(..., description="Semantic version")
name: str = Field(..., description="Model name")
description: str | None = Field(None, description="Description")
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision")
metrics_recall: float | None = Field(None, description="Recall")
document_count: int = Field(..., description="Documents used in training")
training_config: dict[str, Any] | None = Field(None, description="Training configuration")
file_size: int | None = Field(None, description="Model file size in bytes")
trained_at: datetime | None = Field(None, description="Training completion time")
activated_at: datetime | None = Field(None, description="Last activation time")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class ModelVersionResponse(BaseModel):
"""Response for model version operation."""
version_id: str = Field(..., description="Version UUID")
status: str = Field(..., description="Model status")
message: str = Field(..., description="Status message")
class ActiveModelResponse(BaseModel):
"""Response for active model query."""
has_active_model: bool = Field(..., description="Whether an active model exists")
model: ModelVersionItem | None = Field(None, description="Active model if exists")

View File

@@ -5,13 +5,18 @@ from typing import Any
from pydantic import BaseModel, Field
from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
)
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
default="invoice_fields", description="Training project name"
)
# Data augmentation settings
augmentation: AugmentationConfigSchema | None = Field(
default=None,
description="Augmentation configuration. If provided, augments dataset before training.",
)
augmentation_multiplier: int = Field(
default=2,
ge=1,
le=10,
description="Number of augmented copies per original image",
)
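A minimal sketch of an augmentation-enabled config; nested dicts are coerced into AugmentationConfigSchema/AugmentationParamsSchema by pydantic, and the import path is an assumption based on the schemas package in this commit:

from inference.web.schemas.admin import TrainingConfig  # assumed re-export via the package __init__

config = TrainingConfig(
    epochs=100,
    augmentation={
        "gaussian_noise": {"enabled": True, "probability": 0.3, "params": {"mean": 0, "std": (5, 15)}},
        "lighting_variation": {"enabled": True, "probability": 0.5},
    },
    augmentation_multiplier=3,
)
# The scheduler only applies augmentation when at least one entry has enabled=True
# (see _has_enabled_augmentations in the scheduler above).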
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""

View File

@@ -0,0 +1,317 @@
"""Augmentation service for handling augmentation operations."""
import base64
import io
import re
import uuid
from pathlib import Path
from typing import Any
import numpy as np
from fastapi import HTTPException
from PIL import Image
from inference.data.admin_db import AdminDB
from inference.web.schemas.admin.augmentation import (
AugmentationBatchResponse,
AugmentationConfigSchema,
AugmentationPreviewResponse,
AugmentedDatasetItem,
AugmentedDatasetListResponse,
)
# Constants
PREVIEW_MAX_SIZE = 800
PREVIEW_SEED = 42
UUID_PATTERN = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
re.IGNORECASE,
)
class AugmentationService:
"""Service for augmentation operations."""
def __init__(self, db: AdminDB) -> None:
"""Initialize service with database connection."""
self.db = db
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
"""
Validate UUID format to prevent path traversal.
Args:
value: Value to validate.
field_name: Field name for error message.
Raises:
HTTPException: If value is not a valid UUID.
"""
if not UUID_PATTERN.match(value):
raise HTTPException(
status_code=400,
detail=f"Invalid {field_name} format: {value}",
)
async def preview_single(
self,
document_id: str,
page: int,
augmentation_type: str,
params: dict[str, Any],
) -> AugmentationPreviewResponse:
"""
Preview a single augmentation on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
augmentation_type: Name of augmentation to apply.
params: Override parameters.
Returns:
Preview response with image URLs.
Raises:
HTTPException: If document not found or augmentation invalid.
"""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
# Validate augmentation type
if augmentation_type not in AUGMENTATION_REGISTRY:
raise HTTPException(
status_code=400,
detail=f"Unknown augmentation type: {augmentation_type}. "
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
)
# Get document and load image
image = await self._load_document_page(document_id, page)
# Create config with only this augmentation enabled
config_kwargs = {
augmentation_type: AugmentationParams(
enabled=True,
probability=1.0, # Always apply for preview
params=params,
),
"seed": PREVIEW_SEED, # Deterministic preview
}
config = AugmentationConfig(**config_kwargs)
pipeline = AugmentationPipeline(config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=params,
)
async def preview_config(
self,
document_id: str,
page: int,
config: AugmentationConfigSchema,
) -> AugmentationPreviewResponse:
"""
Preview full augmentation config on a document page.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
config: Full augmentation configuration.
Returns:
Preview response with image URLs.
"""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
# Load image
image = await self._load_document_page(document_id, page)
# Convert Pydantic model to internal config
config_dict = config.model_dump()
internal_config = AugmentationConfig.from_dict(config_dict)
pipeline = AugmentationPipeline(internal_config)
# Apply augmentation
result = pipeline.apply(image)
# Convert to base64 URLs
original_url = self._image_to_data_url(image)
preview_url = self._image_to_data_url(result.image)
return AugmentationPreviewResponse(
preview_url=preview_url,
original_url=original_url,
applied_params=config_dict,
)
async def create_augmented_dataset(
self,
source_dataset_id: str,
config: AugmentationConfigSchema,
output_name: str,
multiplier: int,
) -> AugmentationBatchResponse:
"""
Create a new augmented dataset from an existing dataset.
Args:
source_dataset_id: Source dataset UUID.
config: Augmentation configuration.
output_name: Name for the new dataset.
multiplier: Number of augmented copies per image.
Returns:
Batch response with task ID.
Raises:
HTTPException: If source dataset not found.
"""
# Validate source dataset exists
try:
source_dataset = self.db.get_dataset(source_dataset_id)
if source_dataset is None:
raise HTTPException(
status_code=404,
detail=f"Source dataset not found: {source_dataset_id}",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Source dataset not found: {source_dataset_id}",
) from e
# Create task ID for background processing
task_id = str(uuid.uuid4())
# Estimate total images
estimated_images = (
source_dataset.total_images * multiplier
if hasattr(source_dataset, "total_images")
else 0
)
# TODO: Queue background task for actual augmentation
# For now, return pending status
return AugmentationBatchResponse(
task_id=task_id,
status="pending",
message=f"Augmentation task queued for dataset '{output_name}'",
estimated_images=estimated_images,
)
async def list_augmented_datasets(
self,
limit: int = 20,
offset: int = 0,
) -> AugmentedDatasetListResponse:
"""
List augmented datasets.
Args:
limit: Maximum number of datasets to return.
offset: Number of datasets to skip.
Returns:
List response with datasets.
"""
# TODO: Implement actual database query for augmented datasets
# For now, return empty list
return AugmentedDatasetListResponse(
total=0,
limit=limit,
offset=offset,
datasets=[],
)
async def _load_document_page(
self,
document_id: str,
page: int,
) -> np.ndarray:
"""
Load a document page as numpy array.
Args:
document_id: Document UUID.
page: Page number (1-indexed).
Returns:
Image as numpy array (H, W, C) with dtype uint8.
Raises:
HTTPException: If document or page not found.
"""
# Validate document_id format to prevent path traversal
self._validate_uuid(document_id, "document_id")
# Get document from database
try:
document = self.db.get_document(document_id)
if document is None:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Document not found: {document_id}",
) from e
# Get image path for page
if hasattr(document, "images_dir"):
images_dir = Path(document.images_dir)
else:
# Fallback to constructed path
from inference.web.core.config import get_settings
settings = get_settings()
images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
# Find image for page
page_idx = page - 1 # Convert to 0-indexed
image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
if page_idx >= len(image_files):
raise HTTPException(
status_code=404,
detail=f"Page {page} not found for document {document_id}",
)
# Load image
image_path = image_files[page_idx]
pil_image = Image.open(image_path).convert("RGB")
return np.array(pil_image)
def _image_to_data_url(self, image: np.ndarray) -> str:
"""Convert numpy image to base64 data URL."""
pil_image = Image.fromarray(image)
# Resize for preview if too large
max_size = PREVIEW_MAX_SIZE
if max(pil_image.size) > max_size:
ratio = max_size / max(pil_image.size)
new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# Convert to base64
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{base64_data}"

View File

@@ -81,29 +81,18 @@ class DatasetBuilder:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
# 3. Group documents by group_key and assign splits
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
for doc in doc_list:
doc_id = str(doc.document_id)
split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
"total_annotations": total_annotations,
}
def _assign_splits_by_group(
self,
documents: list,
train_ratio: float,
val_ratio: float,
seed: int,
) -> dict[str, str]:
"""Assign splits based on group_key.
Logic:
- Documents with the same group_key stay together in the same split
- Documents without a group_key are treated as single-document groups
- All groups (including single-document ones) are shuffled together and assigned to splits by ratio
Args:
documents: List of AdminDocument objects
train_ratio: Fraction for training set
val_ratio: Fraction for validation set
seed: Random seed for reproducibility
Returns:
Dict mapping document_id (str) -> split ("train"/"val"/"test")
"""
# Group documents by group_key
# None/empty group_key treated as unique (each doc is its own group)
groups: dict[str | None, list] = {}
for doc in documents:
key = doc.group_key if doc.group_key else None
if key is None:
# Treat each ungrouped doc as its own unique group
# Use document_id as pseudo-key
key = f"__ungrouped_{doc.document_id}"
groups.setdefault(key, []).append(doc)
# Separate single-doc groups from multi-doc groups
single_doc_groups: list[tuple[str | None, list]] = []
multi_doc_groups: list[tuple[str | None, list]] = []
for key, docs in groups.items():
if len(docs) == 1:
single_doc_groups.append((key, docs))
else:
multi_doc_groups.append((key, docs))
# Initialize result mapping
doc_splits: dict[str, str] = {}
# Combine all groups for splitting
all_groups = single_doc_groups + multi_doc_groups
# Shuffle all groups and assign splits
if all_groups:
rng = random.Random(seed)
rng.shuffle(all_groups)
n_groups = len(all_groups)
n_train = max(1, round(n_groups * train_ratio))
# Ensure at least 1 in val if we have more than 1 group
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
for i, (_key, docs) in enumerate(all_groups):
if i < n_train:
split = "train"
elif i < n_train + n_val:
split = "val"
else:
split = "test"
for doc in docs:
doc_splits[str(doc.document_id)] = split
logger.info(
"Split assignment: %d total groups shuffled (train=%d, val=%d)",
len(all_groups),
sum(1 for s in doc_splits.values() if s == "train"),
sum(1 for s in doc_splits.values() if s == "val"),
)
return doc_splits
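To make the grouping behaviour concrete, here is a standalone sketch that mirrors _assign_splits_by_group on toy data (illustrative only; it does not use the real builder or database):

import random
from collections import defaultdict

docs = [("d1", "acme"), ("d2", "acme"), ("d3", None), ("d4", None), ("d5", "globex"), ("d6", None)]

groups: dict[str, list[str]] = defaultdict(list)
for doc_id, key in docs:
    groups[key or f"__ungrouped_{doc_id}"].append(doc_id)  # ungrouped docs become singleton groups

all_groups = list(groups.items())
random.Random(42).shuffle(all_groups)
n_groups = len(all_groups)                                     # 5 groups here
n_train = max(1, round(n_groups * 0.8))                        # 4
n_val = max(1 if n_groups > 1 else 0, round(n_groups * 0.1))   # 1

splits: dict[str, str] = {}
for i, (_, members) in enumerate(all_groups):
    split = "train" if i < n_train else ("val" if i < n_train + n_val else "test")
    for doc_id in members:
        splits[doc_id] = split

print(splits)  # d1 and d2 always share a split because they share group_key "acme"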
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {

View File

@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Type alias for model path resolver function
ModelPathResolver = Callable[[], str | Path | None]
@dataclass
class ServiceResult:
"""Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
model_config: Model configuration (default model settings)
storage_config: Storage configuration
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
self._pipeline = None
self._detector = None
self._is_initialized = False
self._current_model_path: Path | None = None
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
Priority:
1. Active model from database (via resolver)
2. Default model from config
"""
if self._model_path_resolver:
try:
db_model_path = self._model_path_resolver()
if db_model_path and Path(db_model_path).exists():
logger.info(f"Using active model from database: {db_model_path}")
return Path(db_model_path)
elif db_model_path:
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
except Exception as e:
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
@@ -74,16 +105,20 @@ class InferenceService:
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
# Resolve model path (from DB or config)
model_path = self._resolve_model_path()
self._current_model_path = model_path
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
@@ -92,12 +127,36 @@ class InferenceService:
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
def reload_model(self) -> bool:
"""Reload the model if active model has changed.
Returns:
True if model was reloaded, False if no change needed.
"""
new_model_path = self._resolve_model_path()
if self._current_model_path == new_model_path:
logger.debug("Model unchanged, no reload needed")
return False
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
self._is_initialized = False
self._pipeline = None
self._detector = None
self.initialize()
return True
@property
def current_model_path(self) -> Path | None:
"""Get the currently loaded model path."""
return self._current_model_path
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""

View File

@@ -0,0 +1,24 @@
"""
Document Image Augmentation Module.
Provides augmentation transformations for training data enhancement,
specifically designed for document images (invoices, forms, etc.).
Key features:
- Document-safe augmentations that preserve text readability
- Support for both offline preprocessing and runtime augmentation
- Bbox-aware geometric transforms
- Configurable augmentation pipeline
"""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.dataset_augmenter import DatasetAugmenter
__all__ = [
"AugmentationConfig",
"AugmentationParams",
"AugmentationResult",
"BaseAugmentation",
"DatasetAugmenter",
]

View File

@@ -0,0 +1,108 @@
"""
Base classes for augmentation transforms.
Provides abstract base class and result dataclass for all augmentation
implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import numpy as np
@dataclass
class AugmentationResult:
"""
Result of applying an augmentation.
Attributes:
image: The augmented image as numpy array (H, W, C).
bboxes: Updated bounding boxes if geometric transform was applied.
Format: (N, 5) array with [class_id, x_center, y_center, width, height].
transform_matrix: The transformation matrix if applicable (for bbox adjustment).
applied: Whether the augmentation was actually applied.
metadata: Additional metadata about the augmentation.
"""
image: np.ndarray
bboxes: np.ndarray | None = None
transform_matrix: np.ndarray | None = None
applied: bool = True
metadata: dict[str, Any] | None = None
class BaseAugmentation(ABC):
"""
Abstract base class for all augmentations.
Subclasses must implement:
- _validate_params(): Validate augmentation parameters
- apply(): Apply the augmentation to an image
Class attributes:
name: Human-readable name of the augmentation.
affects_geometry: True if this augmentation modifies bbox coordinates.
"""
name: str = "base"
affects_geometry: bool = False
def __init__(self, params: dict[str, Any]) -> None:
"""
Initialize augmentation with parameters.
Args:
params: Dictionary of augmentation-specific parameters.
"""
self.params = params
self._validate_params()
@abstractmethod
def _validate_params(self) -> None:
"""
Validate augmentation parameters.
Raises:
ValueError: If parameters are invalid.
"""
pass
@abstractmethod
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
"""
Apply augmentation to image.
IMPORTANT: Implementations must NOT modify the input image or bboxes.
Always create copies before modifying.
Args:
image: Input image as numpy array (H, W, C) with dtype uint8.
bboxes: Optional bounding boxes in YOLO format (N, 5) array.
Each row: [class_id, x_center, y_center, width, height].
Coordinates are normalized to 0-1 range.
rng: Random number generator for reproducibility.
If None, a new generator should be created.
Returns:
AugmentationResult with augmented image and optionally updated bboxes.
"""
pass
def get_preview_params(self) -> dict[str, Any]:
"""
Get parameters optimized for preview display.
Override this method to provide parameters that produce
clearly visible effects for preview/demo purposes.
Returns:
Dictionary of preview parameters.
"""
return dict(self.params)
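A toy subclass to illustrate the contract (not part of this commit; it simply inverts pixel values):

import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation

class InvertAugmentation(BaseAugmentation):
    """Invert pixel values; purely illustrative."""

    name = "invert"
    affects_geometry = False  # bboxes are untouched

    def _validate_params(self) -> None:
        pass  # no parameters for this toy transform

    def apply(self, image, bboxes=None, rng=None):
        # Never modify the input in place; work on a copy.
        return AugmentationResult(image=255 - image.copy(), bboxes=bboxes, applied=True)

result = InvertAugmentation({}).apply(np.zeros((4, 4, 3), dtype=np.uint8))
print(result.image[0, 0])  # [255 255 255]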

View File

@@ -0,0 +1,274 @@
"""
Augmentation configuration module.
Provides dataclasses for configuring document image augmentations.
All default values are document-safe (conservative) to preserve text readability.
"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class AugmentationParams:
"""
Parameters for a single augmentation type.
Attributes:
enabled: Whether this augmentation is enabled.
probability: Probability of applying this augmentation (0.0 to 1.0).
params: Type-specific parameters dictionary.
"""
enabled: bool = False
probability: float = 0.5
params: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"enabled": self.enabled,
"probability": self.probability,
"params": dict(self.params),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
"""Create from dictionary."""
return cls(
enabled=data.get("enabled", False),
probability=data.get("probability", 0.5),
params=dict(data.get("params", {})),
)
def _default_perspective_warp() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"max_warp": 0.02}, # Very conservative - 2% max distortion
)
def _default_wrinkle() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"intensity": 0.3, "num_wrinkles": (2, 5)},
)
def _default_edge_damage() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"max_damage_ratio": 0.05}, # Max 5% of edge damaged
)
def _default_stain() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={
"num_stains": (1, 3),
"max_radius_ratio": 0.1,
"opacity": (0.1, 0.3),
},
)
def _default_lighting_variation() -> AugmentationParams:
return AugmentationParams(
enabled=True, # Safe default, commonly needed
probability=0.5,
params={
"brightness_range": (-0.1, 0.1),
"contrast_range": (0.9, 1.1),
},
)
def _default_shadow() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
)
def _default_gaussian_blur() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
)
def _default_motion_blur() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
)
def _default_gaussian_noise() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"mean": 0, "std": (5, 15)}, # Conservative noise levels
)
def _default_salt_pepper() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"amount": (0.001, 0.005)}, # Very sparse
)
def _default_paper_texture() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"texture_type": "random", "intensity": (0.05, 0.15)},
)
def _default_scanner_artifacts() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"line_probability": 0.3, "dust_probability": 0.4},
)
@dataclass
class AugmentationConfig:
"""
Complete augmentation configuration.
All augmentation types have document-safe defaults that preserve
text readability. Only lighting_variation is enabled by default.
Attributes:
perspective_warp: Geometric perspective transform (affects bboxes).
wrinkle: Paper wrinkle/crease simulation.
edge_damage: Damaged/torn edge effects.
stain: Coffee stain/smudge effects.
lighting_variation: Brightness and contrast variation.
shadow: Shadow overlay effects.
gaussian_blur: Gaussian blur for focus issues.
motion_blur: Motion blur simulation.
gaussian_noise: Gaussian noise for sensor noise.
salt_pepper: Salt and pepper noise.
paper_texture: Paper texture overlay.
scanner_artifacts: Scanner line and dust artifacts.
preserve_bboxes: Whether to adjust bboxes for geometric transforms.
seed: Random seed for reproducibility.
"""
# Geometric transforms (affects bboxes)
perspective_warp: AugmentationParams = field(
default_factory=_default_perspective_warp
)
# Degradation effects
wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
stain: AugmentationParams = field(default_factory=_default_stain)
# Lighting effects
lighting_variation: AugmentationParams = field(
default_factory=_default_lighting_variation
)
shadow: AugmentationParams = field(default_factory=_default_shadow)
# Blur effects
gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)
# Noise effects
gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)
# Texture effects
paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
scanner_artifacts: AugmentationParams = field(
default_factory=_default_scanner_artifacts
)
# Global settings
preserve_bboxes: bool = True
seed: int | None = None
# List of all augmentation field names
_AUGMENTATION_FIELDS: ClassVar[tuple[str, ...]] = (
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
result: dict[str, Any] = {
"preserve_bboxes": self.preserve_bboxes,
"seed": self.seed,
}
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
result[field_name] = params.to_dict()
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
"""Create from dictionary."""
kwargs: dict[str, Any] = {
"preserve_bboxes": data.get("preserve_bboxes", True),
"seed": data.get("seed"),
}
for field_name in cls._AUGMENTATION_FIELDS:
if field_name in data:
field_data = data[field_name]
if isinstance(field_data, dict):
kwargs[field_name] = AugmentationParams.from_dict(field_data)
return cls(**kwargs)
def get_enabled_augmentations(self) -> list[str]:
"""Get list of enabled augmentation names."""
enabled = []
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
if params.enabled:
enabled.append(field_name)
return enabled
def validate(self) -> None:
"""
Validate configuration.
Raises:
ValueError: If any configuration value is invalid.
"""
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
if not (0.0 <= params.probability <= 1.0):
raise ValueError(
f"{field_name}.probability must be between 0 and 1, "
f"got {params.probability}"
)
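
# Example: round-tripping a configuration through plain dictionaries, as a JSON
# API caller would. The keys shown are illustrative; unspecified fields keep the
# document-safe defaults defined above.
config = AugmentationConfig.from_dict({
    "gaussian_noise": {"enabled": True, "probability": 0.3, "params": {"std": (5, 10)}},
    "seed": 42,
})
config.validate()
print(config.get_enabled_augmentations())  # ['lighting_variation', 'gaussian_noise']
print(config.to_dict()["gaussian_noise"]["params"])  # {'std': (5, 10)}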

View File

@@ -0,0 +1,206 @@
"""
Dataset Augmenter Module.
Applies augmentation pipeline to YOLO datasets,
creating new augmented images and label files.
"""
import logging
import zlib
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
logger = logging.getLogger(__name__)
class DatasetAugmenter:
"""
Augments YOLO datasets by creating new images and label files.
Reads images from dataset/images/train/ and labels from dataset/labels/train/,
applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
"""
def __init__(
self,
config: dict[str, Any],
seed: int | None = None,
) -> None:
"""
Initialize augmenter with configuration.
Args:
config: Dictionary mapping augmentation names to their settings.
Each augmentation should have 'enabled', 'probability', and 'params'.
seed: Random seed for reproducibility.
"""
self._config_dict = config
self._seed = seed
self._config = self._build_config(config, seed)
def _build_config(
self,
config_dict: dict[str, Any],
seed: int | None,
) -> AugmentationConfig:
"""Build AugmentationConfig from dictionary."""
kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
for aug_name, aug_settings in config_dict.items():
if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
kwargs[aug_name] = AugmentationParams(
enabled=aug_settings.get("enabled", False),
probability=aug_settings.get("probability", 0.5),
params=aug_settings.get("params", {}),
)
return AugmentationConfig(**kwargs)
def augment_dataset(
self,
dataset_path: Path,
multiplier: int = 1,
split: str = "train",
) -> dict[str, int]:
"""
Augment a YOLO dataset.
Args:
dataset_path: Path to dataset root (containing images/ and labels/).
multiplier: Number of augmented copies per original image.
split: Which split to augment (default: "train").
Returns:
Summary dict with original_images, augmented_images, total_images.
"""
images_dir = dataset_path / "images" / split
labels_dir = dataset_path / "labels" / split
if not images_dir.exists():
raise ValueError(f"Images directory not found: {images_dir}")
# Find all images
image_extensions = ("*.png", "*.jpg", "*.jpeg")
image_files: list[Path] = []
for ext in image_extensions:
image_files.extend(images_dir.glob(ext))
original_count = len(image_files)
augmented_count = 0
if multiplier <= 0:
return {
"original_images": original_count,
"augmented_images": 0,
"total_images": original_count,
}
# Process each image
for img_path in image_files:
# Load image
pil_image = Image.open(img_path).convert("RGB")
image = np.array(pil_image)
# Load corresponding label
label_path = labels_dir / f"{img_path.stem}.txt"
bboxes = self._load_bboxes(label_path) if label_path.exists() else None
# Create multiple augmented versions
for aug_idx in range(multiplier):
# Create pipeline with adjusted seed for each augmentation
aug_seed = None
if self._seed is not None:
# Use a content-stable hash (crc32) so results are reproducible across runs;
# built-in hash() is randomized per process for strings.
aug_seed = self._seed + aug_idx + zlib.crc32(img_path.stem.encode()) % 10000
pipeline = AugmentationPipeline(
self._build_config(self._config_dict, aug_seed)
)
# Apply augmentation
result = pipeline.apply(image, bboxes)
# Save augmented image
aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
aug_img_path = images_dir / aug_name
aug_pil = Image.fromarray(result.image)
aug_pil.save(aug_img_path)
# Save augmented label
aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
self._save_bboxes(aug_label_path, result.bboxes)
augmented_count += 1
logger.info(
"Dataset augmentation complete: %d original, %d augmented",
original_count,
augmented_count,
)
return {
"original_images": original_count,
"augmented_images": augmented_count,
"total_images": original_count + augmented_count,
}
def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
"""
Load bounding boxes from YOLO label file.
Args:
label_path: Path to label file.
Returns:
Array of shape (N, 5) with class_id, x_center, y_center, width, height.
Returns None if file is empty or doesn't exist.
"""
if not label_path.exists():
return None
content = label_path.read_text().strip()
if not content:
return None
bboxes = []
for line in content.split("\n"):
parts = line.strip().split()
if len(parts) == 5:
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
bboxes.append([class_id, x_center, y_center, width, height])
if not bboxes:
return None
return np.array(bboxes, dtype=np.float32)
def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
"""
Save bounding boxes to YOLO label file.
Args:
label_path: Path to save label file.
bboxes: Array of shape (N, 5) or None for empty labels.
"""
if bboxes is None or len(bboxes) == 0:
label_path.write_text("")
return
lines = []
for bbox in bboxes:
class_id = int(bbox[0])
x_center = bbox[1]
y_center = bbox[2]
width = bbox[3]
height = bbox[4]
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
label_path.write_text("\n".join(lines))
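
# Example: augmenting the train split with two extra copies per image. The dataset
# path is a placeholder; images/train/ and labels/train/ are expected to exist, and
# the config keys follow AugmentationConfig._AUGMENTATION_FIELDS.
from pathlib import Path

augmenter = DatasetAugmenter(
    config={
        "lighting_variation": {"enabled": True, "probability": 0.8, "params": {}},
        "gaussian_noise": {"enabled": True, "probability": 0.5, "params": {"std": (5, 10)}},
    },
    seed=123,
)
summary = augmenter.augment_dataset(Path("/path/to/dataset"), multiplier=2)
print(summary)  # e.g. {'original_images': 100, 'augmented_images': 200, 'total_images': 300}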

View File

@@ -0,0 +1,184 @@
"""
Augmentation pipeline module.
Orchestrates multiple augmentations with proper ordering and
provides preview functionality.
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
from shared.augmentation.transforms.geometric import PerspectiveWarp
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
# Registry of augmentation classes
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
"perspective_warp": PerspectiveWarp,
"wrinkle": Wrinkle,
"edge_damage": EdgeDamage,
"stain": Stain,
"lighting_variation": LightingVariation,
"shadow": Shadow,
"gaussian_blur": GaussianBlur,
"motion_blur": MotionBlur,
"gaussian_noise": GaussianNoise,
"salt_pepper": SaltPepper,
"paper_texture": PaperTexture,
"scanner_artifacts": ScannerArtifacts,
}
class AugmentationPipeline:
"""
Orchestrates multiple augmentations with proper ordering.
Augmentations are applied in the following order:
1. Geometric (perspective_warp) - affects bboxes
2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
3. Lighting (lighting_variation, shadow)
4. Texture (paper_texture, scanner_artifacts)
5. Blur (gaussian_blur, motion_blur)
6. Noise (gaussian_noise, salt_pepper) - applied last
"""
STAGE_ORDER = [
"geometric",
"degradation",
"lighting",
"texture",
"blur",
"noise",
]
STAGE_MAPPING = {
"perspective_warp": "geometric",
"wrinkle": "degradation",
"edge_damage": "degradation",
"stain": "degradation",
"lighting_variation": "lighting",
"shadow": "lighting",
"paper_texture": "texture",
"scanner_artifacts": "texture",
"gaussian_blur": "blur",
"motion_blur": "blur",
"gaussian_noise": "noise",
"salt_pepper": "noise",
}
def __init__(self, config: AugmentationConfig) -> None:
"""
Initialize pipeline with configuration.
Args:
config: Augmentation configuration.
"""
self.config = config
self._rng = np.random.default_rng(config.seed)
self._augmentations = self._build_augmentations()
def _build_augmentations(
self,
) -> list[tuple[str, BaseAugmentation, float]]:
"""Build ordered list of (name, augmentation, probability) tuples."""
augmentations: list[tuple[str, BaseAugmentation, float]] = []
for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
params: AugmentationParams = getattr(self.config, aug_name)
if params.enabled:
aug = aug_class(params.params)
augmentations.append((aug_name, aug, params.probability))
# Sort by stage order
def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
name, _, _ = item
stage = self.STAGE_MAPPING[name]
return self.STAGE_ORDER.index(stage)
return sorted(augmentations, key=sort_key)
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
) -> AugmentationResult:
"""
Apply augmentation pipeline to image.
Args:
image: Input image (H, W, C) as numpy array with dtype uint8.
bboxes: Optional bounding boxes in YOLO format (N, 5).
Returns:
AugmentationResult with augmented image and optionally adjusted bboxes.
"""
current_image = image.copy()
current_bboxes = bboxes.copy() if bboxes is not None else None
applied_augmentations: list[str] = []
for name, aug, probability in self._augmentations:
if self._rng.random() < probability:
result = aug.apply(current_image, current_bboxes, self._rng)
current_image = result.image
if result.bboxes is not None and self.config.preserve_bboxes:
current_bboxes = result.bboxes
applied_augmentations.append(name)
return AugmentationResult(
image=current_image,
bboxes=current_bboxes,
metadata={"applied_augmentations": applied_augmentations},
)
def preview(
self,
image: np.ndarray,
augmentation_name: str,
) -> np.ndarray:
"""
Preview a single augmentation deterministically.
Args:
image: Input image.
augmentation_name: Name of augmentation to preview.
Returns:
Augmented image.
Raises:
ValueError: If augmentation_name is not recognized.
"""
if augmentation_name not in AUGMENTATION_REGISTRY:
raise ValueError(f"Unknown augmentation: {augmentation_name}")
params: AugmentationParams = getattr(self.config, augmentation_name)
aug = AUGMENTATION_REGISTRY[augmentation_name](params.params)
# Use deterministic RNG for preview
preview_rng = np.random.default_rng(42)
result = aug.apply(image.copy(), rng=preview_rng)
return result.image
def get_available_augmentations() -> list[dict[str, Any]]:
"""
Get list of available augmentations with metadata.
Returns:
List of dictionaries with augmentation info.
"""
augmentations = []
for name, aug_class in AUGMENTATION_REGISTRY.items():
augmentations.append({
"name": name,
"description": aug_class.__doc__ or "",
"affects_geometry": aug_class.affects_geometry,
"stage": AugmentationPipeline.STAGE_MAPPING[name],
})
return augmentations
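
# Example: running the pipeline on a synthetic white "page" with one YOLO box.
# Values are illustrative; gaussian_noise is forced on (p=1.0) while
# lighting_variation keeps its default 0.5 probability.
import numpy as np

config = AugmentationConfig(
    gaussian_noise=AugmentationParams(enabled=True, probability=1.0, params={"std": 10}),
    seed=7,
)
pipeline = AugmentationPipeline(config)

page = np.full((640, 480, 3), 255, dtype=np.uint8)
boxes = np.array([[0, 0.5, 0.5, 0.2, 0.1]], dtype=np.float32)

result = pipeline.apply(page, boxes)
print(result.metadata["applied_augmentations"])  # always contains 'gaussian_noise' here
noisy = pipeline.preview(page, "gaussian_noise")  # deterministic single-transform preview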

View File

@@ -0,0 +1,212 @@
"""
Predefined augmentation presets for common document scenarios.
Presets provide ready-to-use configurations optimized for different
use cases, from conservative (preserves text readability) to aggressive
(simulates poor document quality).
"""
from typing import Any
from shared.augmentation.config import AugmentationConfig, AugmentationParams
PRESETS: dict[str, dict[str, Any]] = {
"conservative": {
"description": "Safe augmentations that preserve text readability",
"config": {
"lighting_variation": {
"enabled": True,
"probability": 0.5,
"params": {
"brightness_range": (-0.1, 0.1),
"contrast_range": (0.9, 1.1),
},
},
"gaussian_noise": {
"enabled": True,
"probability": 0.3,
"params": {"std": (3, 10)},
},
},
},
"moderate": {
"description": "Balanced augmentations for typical document degradation",
"config": {
"lighting_variation": {
"enabled": True,
"probability": 0.5,
"params": {
"brightness_range": (-0.15, 0.15),
"contrast_range": (0.85, 1.15),
},
},
"shadow": {
"enabled": True,
"probability": 0.3,
"params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
},
"gaussian_noise": {
"enabled": True,
"probability": 0.3,
"params": {"std": (5, 12)},
},
"gaussian_blur": {
"enabled": True,
"probability": 0.2,
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
},
"paper_texture": {
"enabled": True,
"probability": 0.3,
"params": {"intensity": (0.05, 0.12)},
},
},
},
"aggressive": {
"description": "Heavy augmentations simulating poor scan quality",
"config": {
"perspective_warp": {
"enabled": True,
"probability": 0.3,
"params": {"max_warp": 0.02},
},
"wrinkle": {
"enabled": True,
"probability": 0.4,
"params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
},
"stain": {
"enabled": True,
"probability": 0.3,
"params": {
"num_stains": (1, 2),
"max_radius_ratio": 0.08,
"opacity": (0.1, 0.25),
},
},
"lighting_variation": {
"enabled": True,
"probability": 0.6,
"params": {
"brightness_range": (-0.2, 0.2),
"contrast_range": (0.8, 1.2),
},
},
"shadow": {
"enabled": True,
"probability": 0.4,
"params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
},
"gaussian_blur": {
"enabled": True,
"probability": 0.3,
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
},
"motion_blur": {
"enabled": True,
"probability": 0.2,
"params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
},
"gaussian_noise": {
"enabled": True,
"probability": 0.4,
"params": {"std": (8, 18)},
},
"paper_texture": {
"enabled": True,
"probability": 0.4,
"params": {"intensity": (0.08, 0.15)},
},
"scanner_artifacts": {
"enabled": True,
"probability": 0.3,
"params": {"line_probability": 0.4, "dust_probability": 0.5},
},
"edge_damage": {
"enabled": True,
"probability": 0.2,
"params": {"max_damage_ratio": 0.04},
},
},
},
"scanned_document": {
"description": "Simulates typical scanned document artifacts",
"config": {
"scanner_artifacts": {
"enabled": True,
"probability": 0.5,
"params": {"line_probability": 0.4, "dust_probability": 0.5},
},
"paper_texture": {
"enabled": True,
"probability": 0.4,
"params": {"intensity": (0.05, 0.12)},
},
"lighting_variation": {
"enabled": True,
"probability": 0.3,
"params": {
"brightness_range": (-0.1, 0.1),
"contrast_range": (0.9, 1.1),
},
},
"gaussian_noise": {
"enabled": True,
"probability": 0.3,
"params": {"std": (5, 12)},
},
},
},
}
def get_preset_config(preset_name: str) -> dict[str, Any]:
"""
Get the configuration dictionary for a preset.
Args:
preset_name: Name of the preset.
Returns:
Configuration dictionary.
Raises:
ValueError: If preset is not found.
"""
if preset_name not in PRESETS:
raise ValueError(
f"Unknown preset: {preset_name}. "
f"Available presets: {list(PRESETS.keys())}"
)
return PRESETS[preset_name]["config"]
def create_config_from_preset(preset_name: str) -> AugmentationConfig:
"""
Create an AugmentationConfig from a preset.
Args:
preset_name: Name of the preset.
Returns:
AugmentationConfig instance.
Raises:
ValueError: If preset is not found.
"""
config_dict = get_preset_config(preset_name)
return AugmentationConfig.from_dict(config_dict)
def list_presets() -> list[dict[str, str]]:
"""
List all available presets.
Returns:
List of dictionaries with name and description.
"""
return [
{"name": name, "description": preset["description"]}
for name, preset in PRESETS.items()
]
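
# Example: preset-driven configuration. list_presets() can feed a UI dropdown and
# create_config_from_preset() yields a ready-to-use AugmentationConfig that plugs
# straight into the pipeline.
from shared.augmentation.pipeline import AugmentationPipeline

for entry in list_presets():
    print(f"{entry['name']}: {entry['description']}")

config = create_config_from_preset("moderate")
print(config.get_enabled_augmentations())
# ['lighting_variation', 'shadow', 'gaussian_blur', 'gaussian_noise', 'paper_texture']
pipeline = AugmentationPipeline(config)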

View File

@@ -0,0 +1,13 @@
"""
Augmentation transform implementations.
Each module contains related augmentation classes:
- geometric.py: Perspective warp and other geometric transforms
- degradation.py: Wrinkle, edge damage, stain effects
- lighting.py: Lighting variation and shadow effects
- blur.py: Gaussian and motion blur
- noise.py: Gaussian and salt-pepper noise
- texture.py: Paper texture and scanner artifacts
"""
# Will be populated as transforms are implemented

View File

@@ -0,0 +1,144 @@
"""
Blur augmentation transforms.
Provides blur effects for document image augmentation:
- GaussianBlur: Simulates out-of-focus capture
- MotionBlur: Simulates camera/document movement during capture
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianBlur(BaseAugmentation):
"""
Applies Gaussian blur to the image.
Simulates out-of-focus capture or low-quality optics.
Conservative defaults to preserve text readability.
Parameters:
kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
"""
name = "gaussian_blur"
affects_geometry = False
def _validate_params(self) -> None:
kernel_size = self.params.get("kernel_size", (3, 5))
if isinstance(kernel_size, int):
if kernel_size < 1 or kernel_size % 2 == 0:
raise ValueError("kernel_size must be a positive odd integer")
elif isinstance(kernel_size, tuple):
if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
raise ValueError("kernel_size tuple must be (min, max) with min >= 1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
kernel_size = self.params.get("kernel_size", (3, 5))
sigma = self.params.get("sigma", (0.5, 1.5))
if isinstance(kernel_size, tuple):
# Choose random odd kernel size
min_k, max_k = kernel_size
possible_sizes = [k for k in range(min_k, max_k + 1) if k % 2 == 1]
if not possible_sizes:
possible_sizes = [min_k if min_k % 2 == 1 else min_k + 1]
kernel_size = rng.choice(possible_sizes)
if isinstance(sigma, tuple):
sigma = rng.uniform(sigma[0], sigma[1])
# Ensure kernel size is odd
if kernel_size % 2 == 0:
kernel_size += 1
# Apply Gaussian blur
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
return AugmentationResult(
image=blurred,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"kernel_size": kernel_size, "sigma": sigma},
)
def get_preview_params(self) -> dict:
return {"kernel_size": 5, "sigma": 1.5}
class MotionBlur(BaseAugmentation):
"""
Applies motion blur to the image.
Simulates camera shake or document movement during capture.
Parameters:
kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
angle_range: Motion angle range in degrees (default: (-45, 45)).
"""
name = "motion_blur"
affects_geometry = False
def _validate_params(self) -> None:
kernel_size = self.params.get("kernel_size", (5, 9))
if isinstance(kernel_size, int):
if kernel_size < 3:
raise ValueError("kernel_size must be at least 3")
elif isinstance(kernel_size, tuple):
if kernel_size[0] < 3:
raise ValueError("kernel_size min must be at least 3")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
kernel_size = self.params.get("kernel_size", (5, 9))
angle_range = self.params.get("angle_range", (-45, 45))
if isinstance(kernel_size, tuple):
kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)
angle = rng.uniform(angle_range[0], angle_range[1])
# Create motion blur kernel
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
# Draw a line in the center of the kernel
center = kernel_size // 2
angle_rad = np.deg2rad(angle)
for i in range(kernel_size):
offset = i - center
x = int(center + offset * np.cos(angle_rad))
y = int(center + offset * np.sin(angle_rad))
if 0 <= x < kernel_size and 0 <= y < kernel_size:
kernel[y, x] = 1.0
# Normalize kernel
kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel
# Apply motion blur
blurred = cv2.filter2D(image, -1, kernel)
return AugmentationResult(
image=blurred,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"kernel_size": kernel_size, "angle": angle},
)
def get_preview_params(self) -> dict:
return {"kernel_size": 7, "angle_range": (-30, 30)}

View File

@@ -0,0 +1,259 @@
"""
Degradation augmentation transforms.
Provides degradation effects for document image augmentation:
- Wrinkle: Paper wrinkle/crease simulation
- EdgeDamage: Damaged/torn edge effects
- Stain: Coffee stain/smudge effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class Wrinkle(BaseAugmentation):
"""
Simulates paper wrinkles/creases using displacement mapping.
Document-friendly: Uses subtle displacement to preserve text readability.
Parameters:
intensity: Wrinkle intensity (0-1) (default: 0.3).
num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
"""
name = "wrinkle"
affects_geometry = False
def _validate_params(self) -> None:
intensity = self.params.get("intensity", 0.3)
if not (0 < intensity <= 1):
raise ValueError("intensity must be between 0 and 1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
intensity = self.params.get("intensity", 0.3)
num_wrinkles = self.params.get("num_wrinkles", (2, 5))
if isinstance(num_wrinkles, tuple):
num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)
# Create displacement maps
displacement_x = np.zeros((h, w), dtype=np.float32)
displacement_y = np.zeros((h, w), dtype=np.float32)
for _ in range(num_wrinkles):
# Random wrinkle parameters
angle = rng.uniform(0, np.pi)
x0 = rng.uniform(0, w)
y0 = rng.uniform(0, h)
length = rng.uniform(0.3, 0.8) * min(h, w)
width = rng.uniform(0.02, 0.05) * min(h, w)
# Create coordinate grids
xx, yy = np.meshgrid(np.arange(w), np.arange(h))
# Distance from wrinkle line
dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)
# Gaussian falloff perpendicular to wrinkle
mask = np.exp(-dy**2 / (2 * width**2))
mask *= (np.abs(dx) < length / 2).astype(np.float32)
# Displacement perpendicular to wrinkle
disp_amount = intensity * rng.uniform(2, 8)
displacement_x += mask * disp_amount * np.sin(angle)
displacement_y += mask * disp_amount * np.cos(angle)
# Create remap coordinates
map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)
# Apply displacement
augmented = cv2.remap(
image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
)
# Add subtle shading along wrinkles
max_disp = np.max(np.abs(displacement_y)) + 1e-6
shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
shading = shading[:, :, np.newaxis]
augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)
return AugmentationResult(
image=augmented,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
)
def get_preview_params(self) -> dict:
return {"intensity": 0.5, "num_wrinkles": 3}
class EdgeDamage(BaseAugmentation):
"""
Adds damaged/torn edge effects to the image.
Simulates worn or torn document edges.
Parameters:
max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
edges: Which edges to potentially damage (default: all).
"""
name = "edge_damage"
affects_geometry = False
def _validate_params(self) -> None:
max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
if not (0 < max_damage_ratio <= 0.2):
raise ValueError("max_damage_ratio must be between 0 and 0.2")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
edges = self.params.get("edges", ["top", "bottom", "left", "right"])
output = image.copy()
# Select random edge to damage
edge = rng.choice(edges)
damage_size = int(max_damage_ratio * min(h, w))
if edge == "top":
# Create irregular top edge
for x in range(w):
depth = rng.integers(0, damage_size + 1)
if depth > 0:
# Random color (white or darker)
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
output[:depth, x] = color
elif edge == "bottom":
for x in range(w):
depth = rng.integers(0, damage_size + 1)
if depth > 0:
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
output[h - depth:, x] = color
elif edge == "left":
for y in range(h):
depth = rng.integers(0, damage_size + 1)
if depth > 0:
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
output[y, :depth] = color
else: # right
for y in range(h):
depth = rng.integers(0, damage_size + 1)
if depth > 0:
color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
output[y, w - depth:] = color
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"edge": edge, "damage_size": damage_size},
)
def get_preview_params(self) -> dict:
return {"max_damage_ratio": 0.08}
class Stain(BaseAugmentation):
"""
Adds coffee stain/smudge effects to the image.
Simulates accidental stains on documents.
Parameters:
num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
"""
name = "stain"
affects_geometry = False
def _validate_params(self) -> None:
opacity = self.params.get("opacity", (0.1, 0.3))
if isinstance(opacity, (int, float)):
if not (0 < opacity <= 1):
raise ValueError("opacity must be between 0 and 1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
num_stains = self.params.get("num_stains", (1, 3))
max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
opacity = self.params.get("opacity", (0.1, 0.3))
if isinstance(num_stains, tuple):
num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
if isinstance(opacity, tuple):
opacity = rng.uniform(opacity[0], opacity[1])
output = image.astype(np.float32)
max_radius = int(max_radius_ratio * min(h, w))
for _ in range(num_stains):
# Random stain position and size
cx = rng.integers(max_radius, w - max_radius)
cy = rng.integers(max_radius, h - max_radius)
radius = rng.integers(max_radius // 3, max_radius)
# Create stain mask with irregular edges
yy, xx = np.ogrid[:h, :w]
dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
# Add noise to make edges irregular
noise = rng.uniform(0.8, 1.2, (h, w))
mask = (dist < radius * noise).astype(np.float32)
# Blur for soft edges
mask = cv2.GaussianBlur(mask, (21, 21), 0)
# Random stain color (brownish/yellowish)
stain_color = np.array([
rng.integers(180, 220), # R
rng.integers(160, 200), # G
rng.integers(120, 160), # B
], dtype=np.float32)
# Apply stain
mask_3d = mask[:, :, np.newaxis]
output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity
output = np.clip(output, 0, 255).astype(np.uint8)
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"num_stains": num_stains, "opacity": opacity},
)
def get_preview_params(self) -> dict:
return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}

View File

@@ -0,0 +1,145 @@
"""
Geometric augmentation transforms.
Provides geometric transforms for document image augmentation:
- PerspectiveWarp: Subtle perspective distortion
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PerspectiveWarp(BaseAugmentation):
"""
Applies subtle perspective transformation to the image.
Simulates viewing document at slight angle. Very conservative
by default to preserve text readability.
IMPORTANT: This transform affects bounding box coordinates.
Parameters:
max_warp: Maximum warp as proportion of image size (default: 0.02).
"""
name = "perspective_warp"
affects_geometry = True
def _validate_params(self) -> None:
max_warp = self.params.get("max_warp", 0.02)
if not (0 < max_warp <= 0.1):
raise ValueError("max_warp must be between 0 and 0.1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
max_warp = self.params.get("max_warp", 0.02)
# Original corners
src_pts = np.float32([
[0, 0],
[w, 0],
[w, h],
[0, h],
])
# Add random perturbations to corners
max_offset = max_warp * min(h, w)
dst_pts = src_pts.copy()
for i in range(4):
dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)
# Compute perspective transform matrix
transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
# Apply perspective transform
warped = cv2.warpPerspective(
image, transform_matrix, (w, h),
borderMode=cv2.BORDER_REPLICATE
)
# Transform bounding boxes if present
transformed_bboxes = None
if bboxes is not None:
transformed_bboxes = self._transform_bboxes(
bboxes, transform_matrix, w, h
)
return AugmentationResult(
image=warped,
bboxes=transformed_bboxes,
transform_matrix=transform_matrix,
metadata={"max_warp": max_warp},
)
def _transform_bboxes(
self,
bboxes: np.ndarray,
transform_matrix: np.ndarray,
w: int,
h: int,
) -> np.ndarray:
"""Transform bounding boxes using perspective matrix."""
if len(bboxes) == 0:
return bboxes.copy()
transformed = []
for bbox in bboxes:
class_id, x_center, y_center, width, height = bbox
# Convert normalized coords to pixel coords
x_center_px = x_center * w
y_center_px = y_center * h
width_px = width * w
height_px = height * h
# Get corner points
x1 = x_center_px - width_px / 2
y1 = y_center_px - height_px / 2
x2 = x_center_px + width_px / 2
y2 = y_center_px + height_px / 2
# Transform all 4 corners
corners = np.float32([
[x1, y1],
[x2, y1],
[x2, y2],
[x1, y2],
]).reshape(-1, 1, 2)
transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
transformed_corners = transformed_corners.reshape(-1, 2)
# Get bounding box of transformed corners
new_x1 = np.min(transformed_corners[:, 0])
new_y1 = np.min(transformed_corners[:, 1])
new_x2 = np.max(transformed_corners[:, 0])
new_y2 = np.max(transformed_corners[:, 1])
# Convert back to normalized center format
new_width = (new_x2 - new_x1) / w
new_height = (new_y2 - new_y1) / h
new_x_center = ((new_x1 + new_x2) / 2) / w
new_y_center = ((new_y1 + new_y2) / 2) / h
# Clamp to valid range
new_x_center = np.clip(new_x_center, 0, 1)
new_y_center = np.clip(new_y_center, 0, 1)
new_width = np.clip(new_width, 0, 1)
new_height = np.clip(new_height, 0, 1)
transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])
return np.array(transformed, dtype=np.float32)
def get_preview_params(self) -> dict:
return {"max_warp": 0.03}

View File

@@ -0,0 +1,167 @@
"""
Lighting augmentation transforms.
Provides lighting effects for document image augmentation:
- LightingVariation: Adjusts brightness and contrast
- Shadow: Adds shadow overlay effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class LightingVariation(BaseAugmentation):
"""
Adjusts image brightness and contrast.
Simulates different lighting conditions during document capture.
Safe for documents with conservative default parameters.
Parameters:
brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
"""
name = "lighting_variation"
affects_geometry = False
def _validate_params(self) -> None:
brightness = self.params.get("brightness_range", (-0.1, 0.1))
contrast = self.params.get("contrast_range", (0.9, 1.1))
if not isinstance(brightness, tuple) or len(brightness) != 2:
raise ValueError("brightness_range must be a (min, max) tuple")
if not isinstance(contrast, tuple) or len(contrast) != 2:
raise ValueError("contrast_range must be a (min, max) tuple")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
contrast_range = self.params.get("contrast_range", (0.9, 1.1))
# Random brightness and contrast
brightness = rng.uniform(brightness_range[0], brightness_range[1])
contrast = rng.uniform(contrast_range[0], contrast_range[1])
# Apply adjustments
adjusted = image.astype(np.float32)
# Contrast adjustment (multiply around mean)
mean = adjusted.mean()
adjusted = (adjusted - mean) * contrast + mean
# Brightness adjustment (add offset)
adjusted = adjusted + brightness * 255
# Clip and convert back
adjusted = np.clip(adjusted, 0, 255).astype(np.uint8)
return AugmentationResult(
image=adjusted,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"brightness": brightness, "contrast": contrast},
)
def get_preview_params(self) -> dict:
return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
class Shadow(BaseAugmentation):
"""
Adds shadow overlay effects to the image.
Simulates shadows from objects or hands during document capture.
Parameters:
num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
"""
name = "shadow"
affects_geometry = False
def _validate_params(self) -> None:
opacity = self.params.get("opacity", (0.2, 0.4))
if isinstance(opacity, (int, float)):
if not (0 <= opacity <= 1):
raise ValueError("opacity must be between 0 and 1")
elif isinstance(opacity, tuple):
if not (0 <= opacity[0] <= opacity[1] <= 1):
raise ValueError("opacity tuple must be in range [0, 1]")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
num_shadows = self.params.get("num_shadows", (1, 2))
opacity = self.params.get("opacity", (0.2, 0.4))
if isinstance(num_shadows, tuple):
num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
if isinstance(opacity, tuple):
opacity = rng.uniform(opacity[0], opacity[1])
h, w = image.shape[:2]
output = image.astype(np.float32)
for _ in range(num_shadows):
# Generate random shadow polygon
num_vertices = rng.integers(3, 6)
vertices = []
# Start from a random edge
edge = rng.integers(0, 4)
if edge == 0: # Top
start = (rng.integers(0, w), 0)
elif edge == 1: # Right
start = (w, rng.integers(0, h))
elif edge == 2: # Bottom
start = (rng.integers(0, w), h)
else: # Left
start = (0, rng.integers(0, h))
vertices.append(start)
# Add random vertices
for _ in range(num_vertices - 1):
x = rng.integers(0, w)
y = rng.integers(0, h)
vertices.append((x, y))
# Create shadow mask
mask = np.zeros((h, w), dtype=np.float32)
pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
cv2.fillPoly(mask, [pts], 1.0)
# Blur the mask for soft edges
blur_size = max(31, min(h, w) // 10)
if blur_size % 2 == 0:
blur_size += 1
mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
# Apply shadow
shadow_factor = 1 - opacity * mask[:, :, np.newaxis]
output = output * shadow_factor
output = np.clip(output, 0, 255).astype(np.uint8)
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"num_shadows": num_shadows, "opacity": opacity},
)
def get_preview_params(self) -> dict:
return {"num_shadows": 1, "opacity": 0.3}

View File

@@ -0,0 +1,142 @@
"""
Noise augmentation transforms.
Provides noise effects for document image augmentation:
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
- SaltPepper: Adds salt and pepper noise for impulse noise effects
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianNoise(BaseAugmentation):
"""
Adds Gaussian noise to the image.
Simulates sensor noise from cameras or scanners.
Document-safe with conservative default parameters.
Parameters:
mean: Mean of the Gaussian noise (default: 0).
std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
"""
name = "gaussian_noise"
affects_geometry = False
def _validate_params(self) -> None:
std = self.params.get("std", (5, 15))
if isinstance(std, (int, float)):
if std < 0:
raise ValueError("std must be non-negative")
elif isinstance(std, tuple):
if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
raise ValueError("std tuple must be (min, max) with min <= max >= 0")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
mean = self.params.get("mean", 0)
std = self.params.get("std", (5, 15))
if isinstance(std, tuple):
std = rng.uniform(std[0], std[1])
# Generate noise
noise = rng.normal(mean, std, image.shape).astype(np.float32)
# Apply noise
noisy = image.astype(np.float32) + noise
noisy = np.clip(noisy, 0, 255).astype(np.uint8)
return AugmentationResult(
image=noisy,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"applied_std": std},
)
def get_preview_params(self) -> dict[str, Any]:
return {"mean": 0, "std": 15}
class SaltPepper(BaseAugmentation):
"""
Adds salt and pepper (impulse) noise to the image.
Simulates defects from damaged sensors or transmission errors.
Very sparse by default to preserve document readability.
Parameters:
amount: Proportion of pixels to affect, can be float or (min, max) tuple.
Default: (0.001, 0.005) for very sparse noise.
salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
"""
name = "salt_pepper"
affects_geometry = False
def _validate_params(self) -> None:
amount = self.params.get("amount", (0.001, 0.005))
if isinstance(amount, (int, float)):
if not (0 <= amount <= 1):
raise ValueError("amount must be between 0 and 1")
elif isinstance(amount, tuple):
if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
raise ValueError("amount tuple must be (min, max) in range [0, 1]")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
amount = self.params.get("amount", (0.001, 0.005))
salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
if isinstance(amount, tuple):
amount = rng.uniform(amount[0], amount[1])
# Copy image
output = image.copy()
h, w = image.shape[:2]
total_pixels = h * w
# Calculate number of salt and pepper pixels
num_salt = int(total_pixels * amount * salt_vs_pepper)
num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))
# Add salt (white pixels)
if num_salt > 0:
salt_coords = (
rng.integers(0, h, num_salt),
rng.integers(0, w, num_salt),
)
output[salt_coords] = 255
# Add pepper (black pixels)
if num_pepper > 0:
pepper_coords = (
rng.integers(0, h, num_pepper),
rng.integers(0, w, num_pepper),
)
output[pepper_coords] = 0
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"applied_amount": amount},
)
def get_preview_params(self) -> dict[str, Any]:
return {"amount": 0.01, "salt_vs_pepper": 0.5}

View File

@@ -0,0 +1,159 @@
"""
Texture augmentation transforms.
Provides texture effects for document image augmentation:
- PaperTexture: Adds paper grain/texture
- ScannerArtifacts: Adds scanner line and dust artifacts
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PaperTexture(BaseAugmentation):
"""
Adds paper texture/grain to the image.
Simulates different paper types and ages.
Parameters:
texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
"""
name = "paper_texture"
affects_geometry = False
def _validate_params(self) -> None:
intensity = self.params.get("intensity", (0.05, 0.15))
if isinstance(intensity, (int, float)):
if not (0 < intensity <= 1):
raise ValueError("intensity must be between 0 and 1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
texture_type = self.params.get("texture_type", "random")
intensity = self.params.get("intensity", (0.05, 0.15))
if texture_type == "random":
texture_type = rng.choice(["fine", "coarse"])
if isinstance(intensity, tuple):
intensity = rng.uniform(intensity[0], intensity[1])
# Generate base noise
if texture_type == "fine":
# Fine grain texture
noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
noise = cv2.GaussianBlur(noise, (3, 3), 0)
else:
# Coarse texture
# Generate at lower resolution and upscale
small_h, small_w = h // 4, w // 4
noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
noise = cv2.GaussianBlur(noise, (5, 5), 0)
# Apply texture
output = image.astype(np.float32)
noise_3d = noise[:, :, np.newaxis] * intensity * 255
output = output + noise_3d
output = np.clip(output, 0, 255).astype(np.uint8)
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={"texture_type": texture_type, "intensity": intensity},
)
def get_preview_params(self) -> dict:
return {"texture_type": "coarse", "intensity": 0.15}
class ScannerArtifacts(BaseAugmentation):
"""
Adds scanner artifacts to the image.
Simulates scanner imperfections like lines and dust spots.
Parameters:
line_probability: Probability of adding scan lines (default: 0.3).
dust_probability: Probability of adding dust spots (default: 0.4).
"""
name = "scanner_artifacts"
affects_geometry = False
def _validate_params(self) -> None:
line_prob = self.params.get("line_probability", 0.3)
dust_prob = self.params.get("dust_probability", 0.4)
if not (0 <= line_prob <= 1):
raise ValueError("line_probability must be between 0 and 1")
if not (0 <= dust_prob <= 1):
raise ValueError("dust_probability must be between 0 and 1")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
rng = rng or np.random.default_rng()
h, w = image.shape[:2]
line_probability = self.params.get("line_probability", 0.3)
dust_probability = self.params.get("dust_probability", 0.4)
output = image.copy()
# Add scan lines
if rng.random() < line_probability:
num_lines = rng.integers(1, 4)
for _ in range(num_lines):
y = rng.integers(0, h)
thickness = rng.integers(1, 3)
# Light or dark line
color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)
# Make line partially transparent
alpha = rng.uniform(0.3, 0.6)
for dy in range(thickness):
if y + dy < h:
output[y + dy, :] = (
output[y + dy, :].astype(np.float32) * (1 - alpha) +
color * alpha
).astype(np.uint8)
# Add dust spots
if rng.random() < dust_probability:
num_dust = rng.integers(5, 20)
for _ in range(num_dust):
x = rng.integers(0, w)
y = rng.integers(0, h)
radius = rng.integers(1, 3)
# Dark dust spot
color = rng.integers(50, 120)
cv2.circle(output, (x, y), radius, int(color), -1)
return AugmentationResult(
image=output,
bboxes=bboxes.copy() if bboxes is not None else None,
metadata={
"line_probability": line_probability,
"dust_probability": dust_probability,
},
)
def get_preview_params(self) -> dict:
return {"line_probability": 0.8, "dust_probability": 0.8}

View File

@@ -0,0 +1,5 @@
"""Shared training utilities."""
from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult
__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]

View File

@@ -0,0 +1,239 @@
"""
Shared YOLO Training Module
Unified training logic for both CLI and Web API.
"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
logger = logging.getLogger(__name__)
@dataclass
class TrainingConfig:
"""Training configuration."""
# Model settings
model_path: str = "yolo11n.pt" # Base model or path to trained model
data_yaml: str = "" # Path to data.yaml
# Training hyperparameters
epochs: int = 100
batch_size: int = 16
image_size: int = 640
learning_rate: float = 0.01
device: str = "0"
# Output settings
project: str = "runs/train"
name: str = "invoice_fields"
# Performance settings
workers: int = 4
cache: bool = False
# Resume settings
resume: bool = False
resume_from: str | None = None # Path to checkpoint
# Document-specific augmentation (optimized for invoices)
augmentation: dict[str, Any] = field(default_factory=lambda: {
"degrees": 5.0,
"translate": 0.05,
"scale": 0.2,
"shear": 0.0,
"perspective": 0.0,
"flipud": 0.0,
"fliplr": 0.0,
"mosaic": 0.0,
"mixup": 0.0,
"hsv_h": 0.0,
"hsv_s": 0.1,
"hsv_v": 0.2,
})
@dataclass
class TrainingResult:
"""Training result."""
success: bool
model_path: str | None = None
metrics: dict[str, float] = field(default_factory=dict)
error: str | None = None
save_dir: str | None = None
class YOLOTrainer:
"""Unified YOLO trainer for CLI and Web API."""
def __init__(
self,
config: TrainingConfig,
log_callback: Callable[[str, str], None] | None = None,
):
"""
Initialize trainer.
Args:
config: Training configuration
log_callback: Optional callback for logging (level, message)
"""
self.config = config
self._log_callback = log_callback
def _log(self, level: str, message: str) -> None:
"""Log a message."""
if self._log_callback:
self._log_callback(level, message)
if level == "INFO":
logger.info(message)
elif level == "ERROR":
logger.error(message)
elif level == "WARNING":
logger.warning(message)
def validate_config(self) -> tuple[bool, str | None]:
"""
Validate training configuration.
Returns:
Tuple of (is_valid, error_message)
"""
# Check model path. Names like "yolo11n.pt" are downloaded by Ultralytics
# on demand, so a missing file is only an error for other paths.
model_path = Path(self.config.model_path)
downloadable = model_path.name.startswith("yolo")
if model_path.suffix != ".pt" and not downloadable:
return False, f"Invalid model: {self.config.model_path}"
if not model_path.exists() and not downloadable:
return False, f"Model file not found: {self.config.model_path}"
# Check data.yaml
if not self.config.data_yaml:
return False, "data_yaml is required"
data_yaml = Path(self.config.data_yaml)
if not data_yaml.exists():
return False, f"data.yaml not found: {self.config.data_yaml}"
return True, None
def train(self) -> TrainingResult:
"""
Run YOLO training.
Returns:
TrainingResult with model path and metrics
"""
try:
from ultralytics import YOLO
except ImportError:
return TrainingResult(
success=False,
error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
)
# Validate config
is_valid, error = self.validate_config()
if not is_valid:
return TrainingResult(success=False, error=error)
self._log("INFO", f"Starting YOLO training")
self._log("INFO", f" Model: {self.config.model_path}")
self._log("INFO", f" Data: {self.config.data_yaml}")
self._log("INFO", f" Epochs: {self.config.epochs}")
self._log("INFO", f" Batch size: {self.config.batch_size}")
self._log("INFO", f" Image size: {self.config.image_size}")
try:
# Load model
if self.config.resume and self.config.resume_from:
resume_path = Path(self.config.resume_from)
if resume_path.exists():
self._log("INFO", f"Resuming from: {resume_path}")
model = YOLO(str(resume_path))
else:
model = YOLO(self.config.model_path)
else:
model = YOLO(self.config.model_path)
# Build training arguments
train_args = {
"data": str(Path(self.config.data_yaml).absolute()),
"epochs": self.config.epochs,
"batch": self.config.batch_size,
"imgsz": self.config.image_size,
"lr0": self.config.learning_rate,
"device": self.config.device,
"project": self.config.project,
"name": self.config.name,
"exist_ok": True,
"pretrained": True,
"verbose": True,
"workers": self.config.workers,
"cache": self.config.cache,
"resume": self.config.resume and self.config.resume_from is not None,
}
# Add augmentation settings
train_args.update(self.config.augmentation)
# Train
results = model.train(**train_args)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
self._log("INFO", f"Training completed successfully")
self._log("INFO", f" Best model: {best_model}")
self._log("INFO", f" mAP@0.5: {metrics.get('mAP50', 'N/A')}")
return TrainingResult(
success=True,
model_path=str(best_model) if best_model.exists() else None,
metrics=metrics,
save_dir=str(results.save_dir),
)
except Exception as e:
self._log("ERROR", f"Training failed: {e}")
return TrainingResult(success=False, error=str(e))
def validate(self, split: str = "val") -> dict[str, float]:
"""
Run validation on trained model.
Args:
split: Dataset split to validate on ("val" or "test")
Returns:
Validation metrics
"""
try:
from ultralytics import YOLO
model = YOLO(self.config.model_path)
metrics = model.val(data=self.config.data_yaml, split=split)
return {
"mAP50": metrics.box.map50,
"mAP50-95": metrics.box.map,
"precision": metrics.box.mp,
"recall": metrics.box.mr,
}
except Exception as e:
self._log("ERROR", f"Validation failed: {e}")
return {}
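
# Example: driving the shared trainer the same way from a CLI script or an API
# worker. The paths are placeholders; the callback receives (level, message)
# pairs for streaming logs to a job record or stdout.
config = TrainingConfig(
    model_path="yolo11n.pt",
    data_yaml="/path/to/dataset/dataset.yaml",
    epochs=50,
    batch_size=8,
    image_size=1280,
    device="0",
)
trainer = YOLOTrainer(config, log_callback=lambda level, msg: print(f"[{level}] {msg}"))
result = trainer.train()
if result.success:
    print(result.model_path, result.metrics.get("mAP50"))
else:
    print(f"Training failed: {result.error}")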

View File

@@ -199,67 +199,63 @@ def main():
db.close()
return
# Start training
# Start training using shared trainer
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
from ultralytics import YOLO
from shared.training import YOLOTrainer, TrainingConfig
# Load model
# Determine resume checkpoint
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
if args.resume and last_checkpoint.exists():
print(f"Resuming from: {last_checkpoint}")
model = YOLO(str(last_checkpoint))
else:
model = YOLO(args.model)
resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
# Training arguments
# Create training config
data_yaml = dataset_dir / 'dataset.yaml'
train_args = {
'data': str(data_yaml.absolute()),
'epochs': args.epochs,
'batch': args.batch,
'imgsz': args.imgsz,
'project': args.project,
'name': args.name,
'device': args.device,
'exist_ok': True,
'pretrained': True,
'verbose': True,
'workers': args.workers,
'cache': args.cache,
'resume': args.resume and last_checkpoint.exists(),
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
'scale': 0.2,
'shear': 0.0,
'perspective': 0.0,
'flipud': 0.0,
'fliplr': 0.0,
'mosaic': 0.0,
'mixup': 0.0,
'hsv_h': 0.0,
'hsv_s': 0.1,
'hsv_v': 0.2,
}
config = TrainingConfig(
model_path=args.model,
data_yaml=str(data_yaml),
epochs=args.epochs,
batch_size=args.batch,
image_size=args.imgsz,
device=args.device,
project=args.project,
name=args.name,
workers=args.workers,
cache=args.cache,
resume=args.resume,
resume_from=resume_from,
)
# Train
results = model.train(**train_args)
# Run training
trainer = YOLOTrainer(config=config)
result = trainer.train()
if not result.success:
print(f"\nError: Training failed - {result.error}")
db.close()
sys.exit(1)
# Print results
print("\n" + "=" * 60)
print("Training Complete")
print("=" * 60)
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
print(f"Best model: {result.model_path}")
print(f"Save directory: {result.save_dir}")
if result.metrics:
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
# Validate on test set
print("\nRunning validation on test set...")
metrics = model.val(split='test')
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
if result.model_path:
config.model_path = result.model_path
config.data_yaml = str(data_yaml)
test_trainer = YOLOTrainer(config=config)
test_metrics = test_trainer.validate(split='test')
if test_metrics:
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
# Close database
db.close()

View File

@@ -0,0 +1,106 @@
task: detect
mode: train
model: runs/train/invoice_fields/weights/last.pt
data: /home/kai/invoice-data/dataset/dataset.yaml
epochs: 100
time: null
patience: 100
batch: 8
imgsz: 1280
save: true
save_period: -1
cache: false
device: '0'
workers: 8
project: runs/train
name: invoice_fields
exist_ok: true
pretrained: true
optimizer: auto
verbose: true
seed: 0
deterministic: true
single_cls: false
rect: false
cos_lr: false
close_mosaic: 10
resume: runs/train/invoice_fields/weights/last.pt
amp: true
fraction: 1.0
profile: false
freeze: null
multi_scale: false
compile: false
overlap_mask: true
mask_ratio: 4
dropout: 0.0
val: true
split: val
save_json: false
conf: null
iou: 0.7
max_det: 300
half: false
dnn: false
plots: true
source: null
vid_stride: 1
stream_buffer: false
visualize: false
augment: false
agnostic_nms: false
classes: null
retina_masks: false
embed: null
show: false
save_frames: false
save_txt: false
save_conf: false
save_crop: false
show_labels: true
show_conf: true
show_boxes: true
line_width: null
format: torchscript
keras: false
optimize: false
int8: false
dynamic: false
simplify: true
opset: null
workspace: null
nms: false
lr0: 0.01
lrf: 0.01
momentum: 0.937
weight_decay: 0.0005
warmup_epochs: 3.0
warmup_momentum: 0.8
warmup_bias_lr: 0.0
box: 7.5
cls: 0.5
dfl: 1.5
pose: 12.0
kobj: 1.0
nbs: 64
hsv_h: 0.0
hsv_s: 0.1
hsv_v: 0.2
degrees: 5.0
translate: 0.05
scale: 0.2
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.0
bgr: 0.0
mosaic: 0.0
mixup: 0.0
cutmix: 0.0
copy_paste: 0.0
copy_paste_mode: flip
auto_augment: randaugment
erasing: 0.4
cfg: null
tracker: botsort.yaml
save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields
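
The args.yaml above is the resolved Ultralytics configuration dumped at training time. A small sketch (hypothetical path, assumes PyYAML is installed) for pulling out the settings that matter when comparing runs:

import yaml  # PyYAML

# Hypothetical path; adjust to the run you want to inspect.
with open("runs/train/invoice_fields/args.yaml") as f:
    args = yaml.safe_load(f)

keys_of_interest = ["model", "epochs", "batch", "imgsz", "resume",
                    "degrees", "translate", "scale", "hsv_s", "hsv_v", "mosaic"]
for key in keys_of_interest:
    print(f"{key}: {args.get(key)}")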

View File

@@ -0,0 +1,101 @@
epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
1,507.213,0.97459,2.44056,1.09407,0.62221,0.6847,0.6581,0.53505,0.66652,1.13762,0.89651,0.00333127,0.00333127,0.00333127
2,960.144,0.65449,1.06648,0.90865,0.62998,0.73947,0.70066,0.58144,0.63394,1.03248,0.89183,0.00659863,0.00659863,0.00659863
3,1415.64,0.64921,0.97725,0.90901,0.70271,0.7713,0.76974,0.60362,0.7453,0.92005,0.9044,0.00979998,0.00979998,0.00979998
4,1871.6,0.61159,0.90785,0.90192,0.7264,0.77507,0.78958,0.66909,0.57561,0.83662,0.87737,0.009703,0.009703,0.009703
5,2317.68,0.55503,0.81107,0.88628,0.75742,0.81685,0.8287,0.70772,0.57193,0.74866,0.87531,0.009604,0.009604,0.009604
6,2756.85,0.52443,0.75067,0.8767,0.77362,0.82743,0.84758,0.73413,0.53604,0.70643,0.86947,0.009505,0.009505,0.009505
7,3211.38,0.49889,0.70526,0.8696,0.79041,0.83537,0.86367,0.75099,0.52784,0.67021,0.86686,0.009406,0.009406,0.009406
8,3657.41,0.47889,0.66715,0.86261,0.80402,0.85067,0.87834,0.7573,0.5365,0.64019,0.86754,0.009307,0.009307,0.009307
9,4100.61,0.46417,0.63485,0.85761,0.81776,0.84993,0.88612,0.77638,0.50423,0.61379,0.85539,0.009208,0.009208,0.009208
10,4544.47,0.45154,0.61098,0.85334,0.82044,0.86313,0.89168,0.7685,0.54578,0.59853,0.8668,0.009109,0.009109,0.009109
11,4983.87,0.44,0.58833,0.84961,0.82995,0.87408,0.90194,0.7972,0.48628,0.57037,0.84686,0.00901,0.00901,0.00901
12,5419.2,0.42826,0.57051,0.84622,0.83627,0.87342,0.90482,0.79702,0.49356,0.55869,0.84836,0.008911,0.008911,0.008911
13,5854.78,0.42052,0.55192,0.84327,0.83405,0.88243,0.90883,0.79807,0.50137,0.54503,0.85114,0.008812,0.008812,0.008812
14,6288.44,0.41497,0.53873,0.84105,0.83687,0.8769,0.90767,0.79881,0.49873,0.54655,0.85078,0.008713,0.008713,0.008713
15,6724.06,0.40759,0.52501,0.83904,0.84211,0.88472,0.91329,0.81111,0.47871,0.52982,0.84461,0.008614,0.008614,0.008614
16,7164.64,0.39956,0.51303,0.83591,0.85078,0.88609,0.91766,0.80948,0.49199,0.52052,0.84634,0.008515,0.008515,0.008515
17,7626.93,0.39651,0.50255,0.83453,0.84894,0.89151,0.92036,0.80185,0.5227,0.52229,0.8514,0.008416,0.008416,0.008416
18,490.978,0.3953,0.49497,0.83435,0.86113,0.87974,0.92114,0.80521,0.51284,0.51491,0.84926,0.008317,0.008317,0.008317
19,976.368,0.39234,0.49242,0.8328,0.86521,0.88431,0.92347,0.80898,0.5037,0.50744,0.84824,0.008218,0.008218,0.008218
20,1457.68,0.3864,0.48591,0.8308,0.86361,0.88895,0.92337,0.80562,0.51844,0.50692,0.8497,0.008119,0.008119,0.008119
21,1938.95,0.3826,0.47627,0.83086,0.85282,0.89488,0.92276,0.80107,0.53348,0.50706,0.85418,0.00802,0.00802,0.00802
22,2418,0.37996,0.46715,0.82911,0.86698,0.89705,0.92757,0.81493,0.5038,0.49105,0.84342,0.007921,0.007921,0.007921
23,2900.33,0.37555,0.46089,0.82769,0.86871,0.89566,0.92949,0.82251,0.48233,0.48473,0.84078,0.007822,0.007822,0.007822
24,3381.33,0.3717,0.45531,0.82676,0.87203,0.89509,0.93002,0.80968,0.52892,0.49231,0.85017,0.007723,0.007723,0.007723
25,3867.95,0.36913,0.44591,0.82544,0.87636,0.89074,0.93018,0.81083,0.5286,0.4882,0.84889,0.007624,0.007624,0.007624
26,4354.94,0.3662,0.44016,0.82483,0.86957,0.89807,0.92963,0.80554,0.5486,0.49546,0.8527,0.007525,0.007525,0.007525
27,4839.34,0.36364,0.43794,0.82368,0.87515,0.89602,0.93102,0.80927,0.54175,0.49001,0.85159,0.007426,0.007426,0.007426
28,5324.36,0.36043,0.42951,0.8234,0.87074,0.90178,0.93175,0.8107,0.53742,0.4862,0.85096,0.007327,0.007327,0.007327
29,5810.68,0.35852,0.42834,0.82242,0.87574,0.8992,0.93108,0.80885,0.53965,0.48712,0.85055,0.007228,0.007228,0.007228
30,6294.55,0.35572,0.42015,0.82155,0.87635,0.90084,0.93172,0.81312,0.5252,0.48089,0.84711,0.007129,0.007129,0.007129
31,6778.25,0.3539,0.41622,0.82025,0.8777,0.90106,0.93275,0.81359,0.52726,0.47864,0.84712,0.00703,0.00703,0.00703
32,7261.88,0.35095,0.4105,0.82048,0.87684,0.90364,0.9337,0.81598,0.52269,0.47547,0.84625,0.006931,0.006931,0.006931
33,7745.73,0.34835,0.40539,0.81904,0.87629,0.90454,0.933,0.81501,0.52381,0.47649,0.84713,0.006832,0.006832,0.006832
34,8234.78,0.34699,0.402,0.8182,0.8755,0.90506,0.93336,0.81532,0.52435,0.4756,0.84749,0.006733,0.006733,0.006733
35,8716.45,0.34487,0.39711,0.81704,0.87551,0.90432,0.93325,0.81529,0.52489,0.47558,0.84763,0.006634,0.006634,0.006634
36,9200.09,0.34343,0.39638,0.81749,0.87551,0.90495,0.93343,0.81477,0.52659,0.4765,0.84798,0.006535,0.006535,0.006535
37,9683.76,0.34013,0.3907,0.81647,0.87738,0.90416,0.93371,0.8152,0.52755,0.47568,0.84807,0.006436,0.006436,0.006436
38,10164.4,0.33826,0.38626,0.81468,0.87923,0.90342,0.93399,0.81571,0.52756,0.47527,0.84794,0.006337,0.006337,0.006337
39,10648.6,0.3366,0.3812,0.81517,0.8786,0.90333,0.93368,0.81448,0.53035,0.47601,0.84857,0.006238,0.006238,0.006238
40,11130.6,0.3353,0.37879,0.8151,0.87974,0.90405,0.93411,0.81544,0.52739,0.47333,0.84768,0.006139,0.006139,0.006139
41,11612.6,0.33395,0.37397,0.8143,0.88034,0.90315,0.93432,0.81514,0.529,0.47304,0.84765,0.00604,0.00604,0.00604
42,12097.5,0.33104,0.37164,0.81448,0.87942,0.90449,0.93429,0.81484,0.53077,0.47386,0.84799,0.005941,0.005941,0.005941
43,12579.8,0.33016,0.3681,0.81305,0.87964,0.90412,0.93457,0.81473,0.53129,0.47382,0.84789,0.005842,0.005842,0.005842
44,13065.3,0.32845,0.36431,0.81191,0.88092,0.90312,0.9348,0.81538,0.53025,0.47337,0.84751,0.005743,0.005743,0.005743
45,13550.4,0.32642,0.36127,0.81295,0.8805,0.905,0.93502,0.81544,0.53124,0.47333,0.8477,0.005644,0.005644,0.005644
46,14034.3,0.32483,0.35761,0.81158,0.87898,0.90636,0.935,0.81556,0.53135,0.47317,0.84772,0.005545,0.005545,0.005545
47,14517.3,0.32236,0.35337,0.81014,0.88018,0.90502,0.93493,0.81547,0.53228,0.473,0.8478,0.005446,0.005446,0.005446
48,14998.6,0.3211,0.35051,0.81064,0.87941,0.9055,0.93481,0.81473,0.5353,0.47335,0.84839,0.005347,0.005347,0.005347
49,15479.2,0.32043,0.34797,0.8097,0.87884,0.90584,0.93482,0.81429,0.53741,0.47359,0.84867,0.005248,0.005248,0.005248
50,15962,0.3182,0.34589,0.80867,0.87776,0.90777,0.93476,0.81395,0.53841,0.47358,0.84884,0.005149,0.005149,0.005149
51,16445.4,0.31722,0.34332,0.80879,0.87932,0.90605,0.93488,0.81439,0.53827,0.47301,0.84874,0.00505,0.00505,0.00505
52,16925.3,0.31437,0.33963,0.80846,0.87925,0.90688,0.93521,0.81468,0.53772,0.47254,0.84861,0.004951,0.004951,0.004951
53,17410.4,0.31435,0.33632,0.80792,0.88015,0.9069,0.93538,0.8148,0.53789,0.47193,0.84865,0.004852,0.004852,0.004852
54,17893.2,0.31341,0.33352,0.80833,0.88132,0.90605,0.93552,0.81465,0.53934,0.47218,0.84889,0.004753,0.004753,0.004753
55,18375.2,0.31091,0.33239,0.80688,0.88151,0.90617,0.93547,0.81453,0.54019,0.47227,0.84885,0.004654,0.004654,0.004654
56,18857.4,0.30906,0.32753,0.80616,0.88348,0.90504,0.93565,0.81495,0.53974,0.47179,0.84864,0.004555,0.004555,0.004555
57,19340.6,0.30722,0.32334,0.80582,0.88502,0.90326,0.93558,0.81484,0.53977,0.47174,0.84859,0.004456,0.004456,0.004456
58,19824.8,0.30575,0.3195,0.80592,0.88348,0.90487,0.93566,0.81511,0.53921,0.47158,0.84832,0.004357,0.004357,0.004357
59,20305.7,0.30426,0.31846,0.80494,0.88419,0.90477,0.93575,0.81534,0.5387,0.47138,0.84819,0.004258,0.004258,0.004258
60,20785.1,0.30295,0.3154,0.80494,0.88302,0.90624,0.93568,0.81572,0.53769,0.47106,0.84788,0.004159,0.004159,0.004159
61,21263.8,0.3013,0.3131,0.80436,0.88438,0.90545,0.93572,0.81606,0.53622,0.47079,0.84762,0.00406,0.00406,0.00406
62,21746,0.30019,0.31077,0.80391,0.88296,0.90732,0.93578,0.8165,0.53455,0.47011,0.84724,0.003961,0.003961,0.003961
63,22225,0.29841,0.30656,0.80379,0.88244,0.90779,0.93591,0.81693,0.53417,0.47003,0.84715,0.003862,0.003862,0.003862
64,22704.7,0.29696,0.30489,0.80284,0.88493,0.90596,0.9359,0.81716,0.53362,0.47013,0.84707,0.003763,0.003763,0.003763
65,23182.9,0.2952,0.30022,0.80288,0.88366,0.90731,0.93594,0.81737,0.533,0.47024,0.84695,0.003664,0.003664,0.003664
66,23663.6,0.29337,0.29898,0.80273,0.88514,0.90611,0.93609,0.81805,0.53189,0.47015,0.84664,0.003565,0.003565,0.003565
67,24149.7,0.29248,0.29492,0.80242,0.88664,0.90536,0.93607,0.81783,0.53208,0.4704,0.84665,0.003466,0.003466,0.003466
68,24629,0.28987,0.29155,0.8009,0.89014,0.90207,0.93611,0.81811,0.53203,0.47046,0.84656,0.003367,0.003367,0.003367
69,25109.9,0.28942,0.29004,0.8011,0.88939,0.90353,0.93619,0.81872,0.53127,0.4704,0.84631,0.003268,0.003268,0.003268
70,25590.5,0.28752,0.28571,0.80059,0.88926,0.90393,0.93627,0.81909,0.53074,0.47023,0.84624,0.003169,0.003169,0.003169
71,26072.8,0.28546,0.28301,0.7999,0.88844,0.90476,0.93631,0.81967,0.52999,0.47005,0.84612,0.00307,0.00307,0.00307
72,26552.9,0.2842,0.28027,0.79942,0.88801,0.90505,0.93622,0.81978,0.52939,0.46994,0.84607,0.002971,0.002971,0.002971
73,27035.5,0.28297,0.27907,0.79956,0.88781,0.90499,0.93615,0.82032,0.528,0.4694,0.84578,0.002872,0.002872,0.002872
74,27518.8,0.2812,0.27446,0.79886,0.88848,0.90493,0.93611,0.82061,0.52675,0.46906,0.84549,0.002773,0.002773,0.002773
75,28007.4,0.27866,0.27202,0.79684,0.88889,0.90467,0.9361,0.82099,0.5257,0.4692,0.84529,0.002674,0.002674,0.002674
76,28499.1,0.27708,0.26798,0.7978,0.88807,0.9054,0.93615,0.82138,0.52574,0.46928,0.84523,0.002575,0.002575,0.002575
77,28993.2,0.27398,0.2644,0.79612,0.88825,0.9055,0.93616,0.82161,0.52491,0.46925,0.84496,0.002476,0.002476,0.002476
78,29480.5,0.27359,0.26209,0.79678,0.88876,0.90547,0.93617,0.82172,0.52467,0.4691,0.84498,0.002377,0.002377,0.002377
79,29970.7,0.27153,0.25905,0.79585,0.88942,0.90548,0.93613,0.82211,0.52407,0.46892,0.84474,0.002278,0.002278,0.002278
80,30453.8,0.2696,0.25647,0.79513,0.88897,0.90665,0.93617,0.82246,0.52298,0.46881,0.84445,0.002179,0.002179,0.002179
81,30936.1,0.26895,0.25375,0.79466,0.88799,0.90791,0.93617,0.8226,0.52253,0.46904,0.84445,0.00208,0.00208,0.00208
82,31418.6,0.26733,0.25025,0.79474,0.88945,0.90695,0.93608,0.82293,0.52172,0.4694,0.84434,0.001981,0.001981,0.001981
83,31911.4,0.26537,0.24754,0.79479,0.89112,0.90496,0.93604,0.82338,0.52094,0.46932,0.84421,0.001882,0.001882,0.001882
84,32402,0.26344,0.24514,0.79369,0.8928,0.90319,0.93598,0.82353,0.52015,0.46957,0.84407,0.001783,0.001783,0.001783
85,32903.1,0.26045,0.24052,0.79226,0.89211,0.90347,0.93615,0.82427,0.51861,0.46958,0.84372,0.001684,0.001684,0.001684
86,33413.6,0.25867,0.23781,0.79209,0.89286,0.90279,0.9362,0.82493,0.51664,0.47018,0.84338,0.001585,0.001585,0.001585
87,33923.6,0.257,0.23463,0.792,0.89299,0.90297,0.93614,0.8254,0.5147,0.46974,0.84305,0.001486,0.001486,0.001486
88,34436.4,0.25569,0.23153,0.79149,0.89278,0.90277,0.93609,0.82622,0.51242,0.4697,0.84266,0.001387,0.001387,0.001387
89,34949.6,0.25343,0.22868,0.791,0.89137,0.90434,0.93599,0.82675,0.51036,0.46947,0.84227,0.001288,0.001288,0.001288
90,35449.6,0.25194,0.22502,0.79051,0.89128,0.90489,0.93591,0.82729,0.5092,0.46975,0.84219,0.001189,0.001189,0.001189
91,35960.5,0.2502,0.22225,0.78959,0.8898,0.90646,0.93586,0.82781,0.50761,0.46999,0.84186,0.00109,0.00109,0.00109
92,36452.1,0.24777,0.21844,0.78906,0.89054,0.9057,0.93593,0.82831,0.50603,0.47043,0.84154,0.000991,0.000991,0.000991
93,36942.9,0.24554,0.21503,0.78861,0.88979,0.90679,0.93584,0.82858,0.50495,0.4703,0.84125,0.000892,0.000892,0.000892
94,37430.3,0.2434,0.21193,0.78799,0.88928,0.90756,0.93566,0.8288,0.5041,0.47075,0.84109,0.000793,0.000793,0.000793
95,37918.7,0.2413,0.20892,0.78736,0.8899,0.90683,0.93567,0.82882,0.50339,0.47152,0.84105,0.000694,0.000694,0.000694
96,38404.3,0.2405,0.20619,0.78595,0.88999,0.90713,0.9355,0.82912,0.50244,0.47239,0.84104,0.000595,0.000595,0.000595
97,38893.8,0.23808,0.2031,0.78683,0.88982,0.90634,0.93531,0.82938,0.50187,0.47281,0.84116,0.000496,0.000496,0.000496
98,39382.8,0.23581,0.20034,0.78643,0.89144,0.9045,0.93517,0.82959,0.50119,0.47383,0.84128,0.000397,0.000397,0.000397
99,39871.7,0.23432,0.19778,0.78568,0.89187,0.90415,0.93488,0.82953,0.50058,0.47452,0.84126,0.000298,0.000298,0.000298
100,40359.7,0.233,0.19485,0.78528,0.89228,0.90349,0.93471,0.82961,0.50029,0.47497,0.84139,0.000199,0.000199,0.000199
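
results.csv logs one row per epoch; the time column resets around epoch 18, consistent with the resumed run recorded in the args.yaml above. A short sketch, assuming pandas and matplotlib are installed, for plotting the validation mAP curves from it (the path is illustrative):

import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical path to the results.csv shown above.
df = pd.read_csv("runs/train/invoice_fields/results.csv")
df.columns = df.columns.str.strip()  # headers may carry stray spaces in some versions

plt.plot(df["epoch"], df["metrics/mAP50(B)"], label="mAP@0.5")
plt.plot(df["epoch"], df["metrics/mAP50-95(B)"], label="mAP@0.5-0.95")
plt.xlabel("epoch")
plt.ylabel("mAP")
plt.legend()
plt.savefig("map_curves.png", dpi=150)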

View File

@@ -0,0 +1,106 @@
task: detect
mode: train
model: yolo11n.pt
data: /home/kai/invoice-data/dataset/dataset.yaml
epochs: 100
time: null
patience: 100
batch: 16
imgsz: 1280
save: true
save_period: -1
cache: false
device: '0'
workers: 8
project: runs/train
name: invoice_yolo11n_full
exist_ok: true
pretrained: true
optimizer: auto
verbose: true
seed: 0
deterministic: true
single_cls: false
rect: false
cos_lr: false
close_mosaic: 10
resume: false
amp: true
fraction: 1.0
profile: false
freeze: null
multi_scale: false
compile: false
overlap_mask: true
mask_ratio: 4
dropout: 0.0
val: true
split: val
save_json: false
conf: null
iou: 0.7
max_det: 300
half: false
dnn: false
plots: true
source: null
vid_stride: 1
stream_buffer: false
visualize: false
augment: false
agnostic_nms: false
classes: null
retina_masks: false
embed: null
show: false
save_frames: false
save_txt: false
save_conf: false
save_crop: false
show_labels: true
show_conf: true
show_boxes: true
line_width: null
format: torchscript
keras: false
optimize: false
int8: false
dynamic: false
simplify: true
opset: null
workspace: null
nms: false
lr0: 0.01
lrf: 0.01
momentum: 0.937
weight_decay: 0.0005
warmup_epochs: 3.0
warmup_momentum: 0.8
warmup_bias_lr: 0.1
box: 7.5
cls: 0.5
dfl: 1.5
pose: 12.0
kobj: 1.0
nbs: 64
hsv_h: 0.0
hsv_s: 0.1
hsv_v: 0.2
degrees: 5.0
translate: 0.05
scale: 0.2
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.0
bgr: 0.0
mosaic: 0.0
mixup: 0.0
cutmix: 0.0
copy_paste: 0.0
copy_paste_mode: flip
auto_augment: randaugment
erasing: 0.4
cfg: null
tracker: botsort.yaml
save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_yolo11n_full
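
This second run differs from the previous one mainly in model checkpoint, batch size, resume flag, warmup_bias_lr, and run name. A quick sketch (hypothetical paths, assumes PyYAML) for diffing two of these args.yaml dumps rather than eyeballing them:

import yaml

def diff_args(path_a: str, path_b: str) -> dict:
    """Return the keys whose values differ between two args.yaml dumps."""
    with open(path_a) as fa, open(path_b) as fb:
        a, b = yaml.safe_load(fa), yaml.safe_load(fb)
    return {k: (a.get(k), b.get(k))
            for k in sorted(set(a) | set(b)) if a.get(k) != b.get(k)}

# Hypothetical paths to the two runs in this commit.
for key, (old, new) in diff_args(
    "runs/train/invoice_fields/args.yaml",
    "runs/train/invoice_yolo11n_full/args.yaml",
).items():
    print(f"{key}: {old} -> {new}")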

View File

@@ -0,0 +1,101 @@
epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
1,217.641,0.79856,2.56507,1.01986,0.8921,0.84545,0.90033,0.81508,0.49347,0.92369,0.83815,0.00332991,0.00332991,0.00332991
2,410.275,0.506,1.04596,0.85661,0.9164,0.85418,0.92852,0.79726,0.63851,0.72277,0.87404,0.00659728,0.00659728,0.00659728
3,598.713,0.49618,0.68647,0.85624,0.93775,0.80014,0.91481,0.80383,0.5798,0.76691,0.86835,0.00979865,0.00979865,0.00979865
4,782.868,0.44059,0.53532,0.84299,0.94668,0.90421,0.96101,0.8832,0.43961,0.49649,0.84298,0.009703,0.009703,0.009703
5,967.898,0.37596,0.44308,0.82667,0.88316,0.81492,0.91376,0.82272,0.50616,0.70202,0.85825,0.009604,0.009604,0.009604
6,1152.03,0.33999,0.39482,0.81661,0.81567,0.73691,0.82644,0.75085,0.43451,0.92038,0.86158,0.009505,0.009505,0.009505
7,1335.35,0.31114,0.35971,0.80992,0.95256,0.89807,0.96383,0.84241,0.60248,0.48455,0.88156,0.009406,0.009406,0.009406
8,1518.68,0.29176,0.33987,0.80516,0.97058,0.91185,0.97221,0.85356,0.58408,0.43771,0.86239,0.009307,0.009307,0.009307
9,1702.03,0.27683,0.3214,0.80166,0.96403,0.91663,0.9736,0.85118,0.6359,0.43055,0.88091,0.009208,0.009208,0.009208
10,1891.76,0.26487,0.30796,0.79943,0.96201,0.92669,0.97715,0.894,0.46381,0.37437,0.84314,0.009109,0.009109,0.009109
11,2081.79,0.25744,0.29846,0.79614,0.96562,0.92382,0.97554,0.79415,0.81302,0.46321,0.93576,0.00901,0.00901,0.00901
12,2273.34,0.24726,0.28842,0.79445,0.96248,0.92901,0.97544,0.7642,0.93193,0.48807,0.98769,0.008911,0.008911,0.008911
13,2461.71,0.24266,0.27619,0.79413,0.9672,0.93016,0.97834,0.69208,1.24698,0.59927,1.19042,0.008812,0.008812,0.008812
14,2649.81,0.23391,0.26941,0.79165,0.96579,0.93247,0.98028,0.72182,1.0815,0.50846,1.08505,0.008713,0.008713,0.008713
15,2837.95,0.22893,0.2651,0.79082,0.9639,0.93414,0.9807,0.8522,0.64306,0.38931,0.88909,0.008614,0.008614,0.008614
16,3021.27,0.22269,0.25369,0.78809,0.97667,0.92754,0.98283,0.79198,0.89233,0.43512,0.98623,0.008515,0.008515,0.008515
17,3209.3,0.21937,0.24886,0.78797,0.96559,0.93193,0.98178,0.67198,1.35518,0.59949,1.30325,0.008416,0.008416,0.008416
18,3400.76,0.21415,0.24489,0.78789,0.95973,0.94489,0.98156,0.63227,1.45967,0.63486,1.42858,0.008317,0.008317,0.008317
19,3590.96,0.21227,0.23986,0.78736,0.96136,0.94369,0.98263,0.83035,0.76379,0.39831,0.92541,0.008218,0.008218,0.008218
20,3779.15,0.20834,0.23506,0.78475,0.96214,0.93563,0.97908,0.5024,1.92081,0.81282,1.99717,0.008119,0.008119,0.008119
21,3976.3,0.20592,0.23055,0.78534,0.9636,0.94141,0.98186,0.71087,1.20783,0.53596,1.17998,0.00802,0.00802,0.00802
22,4165.69,0.20195,0.22554,0.78431,0.96621,0.94394,0.98458,0.86353,0.6245,0.35194,0.86978,0.007921,0.007921,0.007921
23,4357.69,0.19847,0.22066,0.78362,0.9745,0.93877,0.98501,0.84365,0.71155,0.37717,0.91395,0.007822,0.007822,0.007822
24,4548.46,0.19715,0.21991,0.78423,0.96456,0.94907,0.98541,0.77136,0.9901,0.45056,1.02401,0.007723,0.007723,0.007723
25,4738.2,0.19207,0.21284,0.7821,0.97136,0.94417,0.98568,0.8139,0.83053,0.41526,0.94475,0.007624,0.007624,0.007624
26,4926.95,0.19124,0.21138,0.7823,0.9712,0.94333,0.98466,0.78106,0.94702,0.44977,1.0068,0.007525,0.007525,0.007525
27,5115.34,0.18944,0.21166,0.78245,0.97207,0.9347,0.98325,0.57941,1.64865,0.72474,1.70082,0.007426,0.007426,0.007426
28,5303.92,0.18817,0.20777,0.7814,0.96837,0.9519,0.98672,0.77596,0.96734,0.43218,1.01592,0.007327,0.007327,0.007327
29,5493.53,0.18565,0.20231,0.78154,0.9719,0.94481,0.98552,0.67875,1.27094,0.57309,1.27411,0.007228,0.007228,0.007228
30,5682.31,0.18424,0.19916,0.77989,0.96714,0.95269,0.98588,0.73712,1.08764,0.52054,1.09884,0.007129,0.007129,0.007129
31,5870.11,0.1812,0.19544,0.78013,0.9698,0.95028,0.98687,0.72258,1.15995,0.50918,1.14823,0.00703,0.00703,0.00703
32,6060.08,0.1801,0.19571,0.7799,0.9699,0.95342,0.98761,0.83372,0.70788,0.36986,0.89674,0.006931,0.006931,0.006931
33,6244.87,0.17816,0.19373,0.7789,0.96278,0.95546,0.98684,0.85265,0.66817,0.36063,0.88589,0.006832,0.006832,0.006832
34,6427.13,0.176,0.1909,0.77864,0.96962,0.95053,0.98707,0.85351,0.70717,0.36925,0.90913,0.006733,0.006733,0.006733
35,6619.57,0.17377,0.18513,0.77794,0.97418,0.94725,0.98662,0.83346,0.77385,0.3987,0.92964,0.006634,0.006634,0.006634
36,6807.72,0.17359,0.18454,0.77885,0.97363,0.95,0.98703,0.85072,0.72917,0.36745,0.91556,0.006535,0.006535,0.006535
37,7001.26,0.17126,0.18179,0.77796,0.96337,0.95646,0.98744,0.79259,0.86734,0.4072,0.96666,0.006436,0.006436,0.006436
38,7186.91,0.16989,0.17967,0.77791,0.97277,0.94891,0.98737,0.8268,0.78213,0.38811,0.9293,0.006337,0.006337,0.006337
39,7372.72,0.16823,0.17959,0.77698,0.96961,0.95206,0.98714,0.86035,0.69764,0.35696,0.90589,0.006238,0.006238,0.006238
40,7558.51,0.16639,0.17648,0.77676,0.96471,0.9592,0.98756,0.85427,0.70458,0.35248,0.90042,0.006139,0.006139,0.006139
41,7747.05,0.16641,0.17472,0.77686,0.96565,0.95698,0.98751,0.76422,1.01303,0.45638,1.04876,0.00604,0.00604,0.00604
42,7930.66,0.16528,0.17295,0.77783,0.98029,0.94412,0.98709,0.68134,1.27076,0.55229,1.27714,0.005941,0.005941,0.005941
43,8113.31,0.16304,0.17093,0.77627,0.96909,0.95623,0.98692,0.77479,0.97729,0.45248,1.03101,0.005842,0.005842,0.005842
44,8298.3,0.16163,0.16817,0.77509,0.96809,0.9575,0.98709,0.80448,0.85637,0.40945,0.96298,0.005743,0.005743,0.005743
45,8485.25,0.16053,0.16768,0.77535,0.97311,0.95211,0.98726,0.8047,0.85835,0.40888,0.96605,0.005644,0.005644,0.005644
46,8669.88,0.15959,0.16634,0.77576,0.97431,0.95186,0.98739,0.797,0.87218,0.41446,0.97186,0.005545,0.005545,0.005545
47,8853.93,0.15778,0.16234,0.77599,0.97532,0.95052,0.98702,0.78582,0.92511,0.43102,1.0039,0.005446,0.005446,0.005446
48,9037.52,0.15602,0.16175,0.77439,0.97529,0.94998,0.9874,0.84071,0.77064,0.38361,0.93995,0.005347,0.005347,0.005347
49,9223.9,0.15478,0.1604,0.77364,0.97345,0.95143,0.98729,0.84662,0.73185,0.37143,0.92248,0.005248,0.005248,0.005248
50,9411.45,0.15431,0.1584,0.77449,0.98033,0.94592,0.98733,0.86995,0.63137,0.34173,0.8816,0.005149,0.005149,0.005149
51,9595.98,0.15305,0.15648,0.77414,0.97318,0.95431,0.98753,0.86938,0.63298,0.34305,0.88158,0.00505,0.00505,0.00505
52,9779.61,0.15291,0.15561,0.77441,0.97824,0.94936,0.98777,0.87333,0.60174,0.33831,0.87135,0.004951,0.004951,0.004951
53,9963.17,0.15193,0.15454,0.77361,0.97077,0.95579,0.98737,0.87038,0.62864,0.34404,0.87819,0.004852,0.004852,0.004852
54,10151.9,0.14978,0.15218,0.77431,0.97892,0.94812,0.9874,0.86664,0.6559,0.35033,0.88752,0.004753,0.004753,0.004753
55,10335.2,0.14867,0.14954,0.77318,0.98227,0.94556,0.98734,0.86535,0.66191,0.35233,0.89052,0.004654,0.004654,0.004654
56,10517.7,0.14781,0.1504,0.77387,0.97187,0.95472,0.98731,0.85393,0.70291,0.36551,0.90638,0.004555,0.004555,0.004555
57,10704.4,0.14704,0.14654,0.77286,0.96973,0.95539,0.98731,0.84386,0.75517,0.38619,0.92774,0.004456,0.004456,0.004456
58,10888.6,0.14478,0.14588,0.77324,0.9792,0.94676,0.9872,0.84023,0.76095,0.38846,0.93011,0.004357,0.004357,0.004357
59,11071.2,0.14408,0.14418,0.7724,0.9709,0.95553,0.98729,0.8499,0.71784,0.37089,0.91332,0.004258,0.004258,0.004258
60,11256.4,0.1427,0.14165,0.77106,0.96919,0.95682,0.98729,0.85156,0.70256,0.36509,0.90774,0.004159,0.004159,0.004159
61,11444.8,0.14194,0.14087,0.77269,0.96601,0.96121,0.98731,0.85107,0.70839,0.36753,0.90948,0.00406,0.00406,0.00406
62,11630.9,0.14062,0.13882,0.77215,0.96628,0.96081,0.98736,0.84858,0.73074,0.3762,0.92033,0.003961,0.003961,0.003961
63,11816.5,0.13938,0.13865,0.77152,0.96711,0.95961,0.98744,0.85214,0.70862,0.36754,0.91079,0.003862,0.003862,0.003862
64,12005.3,0.13858,0.13687,0.77045,0.96702,0.9595,0.98748,0.85574,0.69672,0.36084,0.90588,0.003763,0.003763,0.003763
65,12191.8,0.13775,0.13411,0.77132,0.96785,0.95943,0.9874,0.85729,0.68875,0.35766,0.90221,0.003664,0.003664,0.003664
66,12379.6,0.13556,0.13271,0.77167,0.96725,0.96005,0.98735,0.85898,0.68174,0.3561,0.89887,0.003565,0.003565,0.003565
67,12565.4,0.13463,0.13108,0.77009,0.97381,0.95338,0.98732,0.86031,0.67263,0.35399,0.89484,0.003466,0.003466,0.003466
68,12752.5,0.13515,0.1311,0.77095,0.96906,0.95916,0.98725,0.86029,0.66717,0.35274,0.89292,0.003367,0.003367,0.003367
69,12940.8,0.13415,0.12957,0.76963,0.97126,0.95685,0.9873,0.86049,0.6644,0.35306,0.8918,0.003268,0.003268,0.003268
70,13133.2,0.13179,0.12737,0.76937,0.97287,0.95478,0.98727,0.86047,0.6632,0.35246,0.89193,0.003169,0.003169,0.003169
71,13319.4,0.13185,0.1274,0.77079,0.97267,0.95587,0.98722,0.86193,0.65949,0.35213,0.89086,0.00307,0.00307,0.00307
72,13504.8,0.12947,0.12446,0.76998,0.97199,0.95686,0.98725,0.86401,0.64895,0.34877,0.88741,0.002971,0.002971,0.002971
73,13695.1,0.12876,0.12321,0.76883,0.9723,0.9569,0.98725,0.86643,0.64091,0.3473,0.88447,0.002872,0.002872,0.002872
74,13882,0.12828,0.12194,0.76915,0.97256,0.95686,0.98732,0.86847,0.6322,0.34702,0.88109,0.002773,0.002773,0.002773
75,14075,0.12664,0.11944,0.76878,0.97277,0.95678,0.98726,0.86861,0.63123,0.3482,0.88086,0.002674,0.002674,0.002674
76,14259.9,0.12587,0.11965,0.7692,0.9727,0.95673,0.98717,0.86916,0.62721,0.34811,0.87932,0.002575,0.002575,0.002575
77,14451.2,0.12433,0.1174,0.76838,0.97267,0.95663,0.9872,0.87057,0.62032,0.34709,0.8769,0.002476,0.002476,0.002476
78,14636.3,0.12352,0.11507,0.76971,0.97087,0.95829,0.98721,0.87189,0.61271,0.34667,0.87445,0.002377,0.002377,0.002377
79,14821.1,0.1231,0.11454,0.76897,0.97195,0.95752,0.98722,0.87292,0.60714,0.34596,0.87271,0.002278,0.002278,0.002278
80,15007.3,0.12117,0.11285,0.76864,0.97146,0.95789,0.98726,0.8735,0.6031,0.34515,0.87163,0.002179,0.002179,0.002179
81,15199.5,0.12029,0.11158,0.76708,0.97018,0.95938,0.9872,0.87378,0.60116,0.34538,0.87113,0.00208,0.00208,0.00208
82,15390.8,0.11877,0.10949,0.76719,0.97021,0.95964,0.98721,0.87422,0.59897,0.3457,0.87057,0.001981,0.001981,0.001981
83,15577.6,0.11812,0.10818,0.76749,0.97013,0.95951,0.98722,0.87429,0.59878,0.34524,0.87054,0.001882,0.001882,0.001882
84,15761.5,0.11687,0.10634,0.76703,0.97155,0.9583,0.98713,0.87407,0.59964,0.34532,0.8709,0.001783,0.001783,0.001783
85,15946.2,0.11551,0.10455,0.7672,0.9717,0.95797,0.9871,0.87367,0.60049,0.34569,0.87136,0.001684,0.001684,0.001684
86,16130.6,0.11474,0.10479,0.76737,0.97183,0.95808,0.98712,0.87406,0.5981,0.34504,0.87076,0.001585,0.001585,0.001585
87,16324.1,0.11337,0.10221,0.76695,0.97137,0.95881,0.98708,0.87382,0.59851,0.34519,0.87106,0.001486,0.001486,0.001486
88,16517.1,0.11185,0.10043,0.76513,0.97121,0.95899,0.98707,0.87379,0.59906,0.34583,0.87135,0.001387,0.001387,0.001387
89,16708.5,0.11103,0.09846,0.76565,0.97113,0.95904,0.98709,0.87369,0.59838,0.34599,0.87138,0.001288,0.001288,0.001288
90,16896.9,0.11054,0.0982,0.76703,0.97095,0.95892,0.98712,0.87377,0.59757,0.34552,0.87126,0.001189,0.001189,0.001189
91,17091.5,0.10967,0.09616,0.76665,0.97037,0.9595,0.98704,0.87361,0.59635,0.34561,0.87111,0.00109,0.00109,0.00109
92,17282.9,0.10834,0.09481,0.76509,0.9726,0.95743,0.98704,0.87372,0.5956,0.34572,0.87108,0.000991,0.000991,0.000991
93,17471.1,0.10692,0.09247,0.76461,0.97255,0.95738,0.9871,0.87368,0.59467,0.34689,0.8707,0.000892,0.000892,0.000892
94,17654.7,0.10578,0.09076,0.76573,0.97167,0.95786,0.9872,0.87367,0.59367,0.34732,0.87049,0.000793,0.000793,0.000793
95,17858.1,0.10457,0.08903,0.7648,0.97097,0.95816,0.98718,0.87394,0.59295,0.34757,0.87044,0.000694,0.000694,0.000694
96,18048,0.10283,0.08802,0.76437,0.97358,0.95577,0.98712,0.8737,0.59392,0.34877,0.87087,0.000595,0.000595,0.000595
97,18233,0.10269,0.08685,0.76468,0.97469,0.95492,0.98712,0.8741,0.59227,0.34903,0.87042,0.000496,0.000496,0.000496
98,18418.2,0.10143,0.0852,0.7644,0.97473,0.95512,0.98709,0.87397,0.59171,0.35007,0.8704,0.000397,0.000397,0.000397
99,18605.1,0.10052,0.08363,0.76442,0.97443,0.95526,0.98712,0.87396,0.5922,0.35121,0.87087,0.000298,0.000298,0.000298
100,18790,0.09925,0.08228,0.76465,0.97498,0.95493,0.98711,0.8737,0.59312,0.35293,0.87138,0.000199,0.000199,0.000199
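
A small sketch (standard library only, path illustrative) for picking the best epoch out of a results.csv like the one above, using mAP@0.5-0.95 as the criterion:

import csv

# Hypothetical path to the results.csv shown above.
with open("runs/train/invoice_yolo11n_full/results.csv", newline="") as f:
    rows = [{k.strip(): v for k, v in row.items()} for row in csv.DictReader(f)]

best = max(rows, key=lambda r: float(r["metrics/mAP50-95(B)"]))
print(f"best epoch: {best['epoch']}  "
      f"mAP50: {float(best['metrics/mAP50(B)']):.4f}  "
      f"mAP50-95: {float(best['metrics/mAP50-95(B)']):.4f}")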

View File

@@ -0,0 +1 @@
# Tests for augmentation module

View File

@@ -0,0 +1,347 @@
"""
Tests for augmentation base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
class TestAugmentationResult:
"""Tests for AugmentationResult dataclass."""
def test_minimal_result(self) -> None:
"""Test creating result with only required field."""
from shared.augmentation.base import AugmentationResult
image = np.zeros((100, 100, 3), dtype=np.uint8)
result = AugmentationResult(image=image)
assert result.image is image
assert result.bboxes is None
assert result.transform_matrix is None
assert result.applied is True
assert result.metadata is None
def test_full_result(self) -> None:
"""Test creating result with all fields."""
from shared.augmentation.base import AugmentationResult
image = np.zeros((100, 100, 3), dtype=np.uint8)
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
transform = np.eye(3)
metadata = {"applied_transform": "wrinkle"}
result = AugmentationResult(
image=image,
bboxes=bboxes,
transform_matrix=transform,
applied=True,
metadata=metadata,
)
assert result.image is image
np.testing.assert_array_equal(result.bboxes, bboxes)
np.testing.assert_array_equal(result.transform_matrix, transform)
assert result.applied is True
assert result.metadata == {"applied_transform": "wrinkle"}
def test_not_applied(self) -> None:
"""Test result when augmentation was not applied."""
from shared.augmentation.base import AugmentationResult
image = np.zeros((100, 100, 3), dtype=np.uint8)
result = AugmentationResult(image=image, applied=False)
assert result.applied is False
class TestBaseAugmentation:
"""Tests for BaseAugmentation abstract class."""
def test_cannot_instantiate_directly(self) -> None:
"""Test that BaseAugmentation cannot be instantiated."""
from shared.augmentation.base import BaseAugmentation
with pytest.raises(TypeError):
BaseAugmentation({}) # type: ignore
def test_subclass_must_implement_apply(self) -> None:
"""Test that subclass must implement apply method."""
from shared.augmentation.base import BaseAugmentation
class IncompleteAugmentation(BaseAugmentation):
name = "incomplete"
def _validate_params(self) -> None:
pass
# Missing apply method
with pytest.raises(TypeError):
IncompleteAugmentation({}) # type: ignore
def test_subclass_must_implement_validate_params(self) -> None:
"""Test that subclass must implement _validate_params."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class IncompleteAugmentation(BaseAugmentation):
name = "incomplete"
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
return AugmentationResult(image=image)
# Missing _validate_params method
with pytest.raises(TypeError):
IncompleteAugmentation({}) # type: ignore
def test_valid_subclass(self) -> None:
"""Test creating a valid subclass."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class DummyAugmentation(BaseAugmentation):
name = "dummy"
affects_geometry = False
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
return AugmentationResult(image=image, bboxes=bboxes)
aug = DummyAugmentation({"param1": "value1"})
assert aug.name == "dummy"
assert aug.affects_geometry is False
assert aug.params == {"param1": "value1"}
def test_apply_returns_augmentation_result(self) -> None:
"""Test that apply returns AugmentationResult."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class DummyAugmentation(BaseAugmentation):
name = "dummy"
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
# Simple pass-through
return AugmentationResult(image=image, bboxes=bboxes)
aug = DummyAugmentation({})
image = np.zeros((100, 100, 3), dtype=np.uint8)
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
result = aug.apply(image, bboxes)
assert isinstance(result, AugmentationResult)
assert result.image is image
np.testing.assert_array_equal(result.bboxes, bboxes)
def test_affects_geometry_default(self) -> None:
"""Test that affects_geometry defaults to False."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class DummyAugmentation(BaseAugmentation):
name = "dummy"
# Not setting affects_geometry
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
return AugmentationResult(image=image)
aug = DummyAugmentation({})
assert aug.affects_geometry is False
def test_validate_params_called_on_init(self) -> None:
"""Test that _validate_params is called during initialization."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
validation_called = {"called": False}
class ValidatingAugmentation(BaseAugmentation):
name = "validating"
def _validate_params(self) -> None:
validation_called["called"] = True
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
return AugmentationResult(image=image)
ValidatingAugmentation({})
assert validation_called["called"] is True
def test_validate_params_raises_on_invalid(self) -> None:
"""Test that _validate_params can raise ValueError."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class StrictAugmentation(BaseAugmentation):
name = "strict"
def _validate_params(self) -> None:
if "required_param" not in self.params:
raise ValueError("required_param is required")
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
return AugmentationResult(image=image)
with pytest.raises(ValueError, match="required_param"):
StrictAugmentation({})
# Should work with required param
aug = StrictAugmentation({"required_param": "value"})
assert aug.params["required_param"] == "value"
def test_rng_usage(self) -> None:
"""Test that random generator can be passed and used."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class RandomAugmentation(BaseAugmentation):
name = "random"
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
if rng is None:
rng = np.random.default_rng()
# Use rng to generate a random value
random_value = rng.random()
return AugmentationResult(
image=image,
metadata={"random_value": random_value},
)
aug = RandomAugmentation({})
image = np.zeros((100, 100, 3), dtype=np.uint8)
# With same seed, should get same result
rng1 = np.random.default_rng(42)
rng2 = np.random.default_rng(42)
result1 = aug.apply(image, rng=rng1)
result2 = aug.apply(image, rng=rng2)
assert result1.metadata is not None
assert result2.metadata is not None
assert result1.metadata["random_value"] == result2.metadata["random_value"]
class TestAugmentationResultImmutability:
"""Tests for ensuring result doesn't mutate input."""
def test_image_not_modified(self) -> None:
"""Test that original image is not modified."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class ModifyingAugmentation(BaseAugmentation):
name = "modifying"
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
# Should copy before modifying
modified = image.copy()
modified[:] = 255
return AugmentationResult(image=modified)
aug = ModifyingAugmentation({})
original = np.zeros((100, 100, 3), dtype=np.uint8)
original_copy = original.copy()
result = aug.apply(original)
# Original should be unchanged
np.testing.assert_array_equal(original, original_copy)
# Result should be modified
assert np.all(result.image == 255)
def test_bboxes_not_modified(self) -> None:
"""Test that original bboxes are not modified."""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class BboxModifyingAugmentation(BaseAugmentation):
name = "bbox_modifying"
affects_geometry = True
def _validate_params(self) -> None:
pass
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
if bboxes is not None:
# Should copy before modifying
modified_bboxes = bboxes.copy()
modified_bboxes[:, 1:] *= 0.5 # Scale down
return AugmentationResult(image=image, bboxes=modified_bboxes)
return AugmentationResult(image=image)
aug = BboxModifyingAugmentation({})
image = np.zeros((100, 100, 3), dtype=np.uint8)
original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
original_bboxes_copy = original_bboxes.copy()
result = aug.apply(image, original_bboxes)
# Original should be unchanged
np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
# Result should be modified
assert result.bboxes is not None
np.testing.assert_array_almost_equal(
result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
)
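The tests above effectively specify the public contract for shared/augmentation/base.py. A minimal sketch of what would satisfy them, assuming a dataclass result and an ABC (not necessarily the committed implementation), looks roughly like this:

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class AugmentationResult:
    """Image plus optionally transformed bboxes and free-form metadata."""

    image: np.ndarray
    bboxes: np.ndarray | None = None
    metadata: dict[str, Any] | None = None


class BaseAugmentation(ABC):
    """Base class: params are validated eagerly, geometry changes are opt-in."""

    name: str = "base"
    affects_geometry: bool = False

    def __init__(self, params: dict[str, Any]) -> None:
        self.params = params
        self._validate_params()  # raising here is what the init tests rely on

    @abstractmethod
    def _validate_params(self) -> None:
        """Raise ValueError if self.params is invalid."""

    @abstractmethod
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Return a new AugmentationResult; never mutate the inputs."""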

View File

@@ -0,0 +1,283 @@
"""
Tests for augmentation configuration module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from typing import Any
import pytest
class TestAugmentationParams:
"""Tests for AugmentationParams dataclass."""
def test_default_values(self) -> None:
"""Test default parameter values."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams()
assert params.enabled is False
assert params.probability == 0.5
assert params.params == {}
def test_custom_values(self) -> None:
"""Test creating params with custom values."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams(
enabled=True,
probability=0.8,
params={"intensity": 0.5, "num_wrinkles": (2, 5)},
)
assert params.enabled is True
assert params.probability == 0.8
assert params.params["intensity"] == 0.5
assert params.params["num_wrinkles"] == (2, 5)
def test_immutability_params_dict(self) -> None:
"""Test that params dict is independent between instances."""
from shared.augmentation.config import AugmentationParams
params1 = AugmentationParams()
params2 = AugmentationParams()
# Modifying one should not affect the other
params1.params["test"] = "value"
assert "test" not in params2.params
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
from shared.augmentation.config import AugmentationParams
params = AugmentationParams(
enabled=True,
probability=0.7,
params={"key": "value"},
)
result = params.to_dict()
assert result == {
"enabled": True,
"probability": 0.7,
"params": {"key": "value"},
}
def test_from_dict(self) -> None:
"""Test creation from dictionary."""
from shared.augmentation.config import AugmentationParams
data = {
"enabled": True,
"probability": 0.6,
"params": {"intensity": 0.3},
}
params = AugmentationParams.from_dict(data)
assert params.enabled is True
assert params.probability == 0.6
assert params.params == {"intensity": 0.3}
def test_from_dict_with_defaults(self) -> None:
"""Test creation from partial dictionary uses defaults."""
from shared.augmentation.config import AugmentationParams
data: dict[str, Any] = {"enabled": True}
params = AugmentationParams.from_dict(data)
assert params.enabled is True
assert params.probability == 0.5 # default
assert params.params == {} # default
class TestAugmentationConfig:
"""Tests for AugmentationConfig dataclass."""
def test_default_values(self) -> None:
"""Test that all augmentation types have defaults."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
# All augmentation types should exist
augmentation_types = [
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
]
for aug_type in augmentation_types:
assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
params = getattr(config, aug_type)
assert hasattr(params, "enabled")
assert hasattr(params, "probability")
assert hasattr(params, "params")
def test_global_settings_defaults(self) -> None:
"""Test global settings default values."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
assert config.preserve_bboxes is True
assert config.seed is None
def test_custom_seed(self) -> None:
"""Test setting custom seed for reproducibility."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig(seed=42)
assert config.seed == 42
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig(seed=123, preserve_bboxes=False)
result = config.to_dict()
assert isinstance(result, dict)
assert result["seed"] == 123
assert result["preserve_bboxes"] is False
assert "perspective_warp" in result
assert "wrinkle" in result
def test_from_dict(self) -> None:
"""Test creation from dictionary."""
from shared.augmentation.config import AugmentationConfig
data = {
"seed": 456,
"preserve_bboxes": False,
"wrinkle": {
"enabled": True,
"probability": 0.8,
"params": {"intensity": 0.5},
},
}
config = AugmentationConfig.from_dict(data)
assert config.seed == 456
assert config.preserve_bboxes is False
assert config.wrinkle.enabled is True
assert config.wrinkle.probability == 0.8
assert config.wrinkle.params["intensity"] == 0.5
def test_from_dict_with_partial_data(self) -> None:
"""Test creation from partial dictionary uses defaults."""
from shared.augmentation.config import AugmentationConfig
data: dict[str, Any] = {
"wrinkle": {"enabled": True},
}
config = AugmentationConfig.from_dict(data)
# Explicitly set value
assert config.wrinkle.enabled is True
# Default values
assert config.preserve_bboxes is True
assert config.seed is None
assert config.gaussian_blur.enabled is False
def test_get_enabled_augmentations(self) -> None:
"""Test getting list of enabled augmentations."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
config = AugmentationConfig(
wrinkle=AugmentationParams(enabled=True),
stain=AugmentationParams(enabled=True),
gaussian_blur=AugmentationParams(enabled=False),
)
enabled = config.get_enabled_augmentations()
assert "wrinkle" in enabled
assert "stain" in enabled
assert "gaussian_blur" not in enabled
def test_document_safe_defaults(self) -> None:
"""Test that default params are document-safe (conservative)."""
from shared.augmentation.config import AugmentationConfig
config = AugmentationConfig()
# Perspective warp should be very conservative
assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
# Noise should be subtle
noise_std = config.gaussian_noise.params.get("std", (5, 15))
if isinstance(noise_std, tuple):
assert noise_std[1] <= 20 # Max std should be low
def test_immutability_between_instances(self) -> None:
"""Test that config instances are independent."""
from shared.augmentation.config import AugmentationConfig
config1 = AugmentationConfig()
config2 = AugmentationConfig()
# Modifying one should not affect the other
config1.wrinkle.params["test"] = "value"
assert "test" not in config2.wrinkle.params
class TestAugmentationConfigValidation:
"""Tests for configuration validation."""
def test_probability_range_validation(self) -> None:
"""Test that probability values are validated."""
from shared.augmentation.config import AugmentationParams
# Valid range
params = AugmentationParams(probability=0.5)
assert params.probability == 0.5
# Edge cases
params_zero = AugmentationParams(probability=0.0)
assert params_zero.probability == 0.0
params_one = AugmentationParams(probability=1.0)
assert params_one.probability == 1.0
def test_config_validate_method(self) -> None:
"""Test the validate method catches invalid configurations."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
# Invalid probability
config = AugmentationConfig(
wrinkle=AugmentationParams(probability=1.5), # Invalid
)
with pytest.raises(ValueError, match="probability"):
config.validate()
def test_config_validate_negative_probability(self) -> None:
"""Test validation catches negative probability."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
config = AugmentationConfig(
wrinkle=AugmentationParams(probability=-0.1),
)
with pytest.raises(ValueError, match="probability"):
config.validate()
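A minimal sketch of the AugmentationParams shape these tests pin down, assuming a plain dataclass; the full AugmentationConfig in this commit would add one such field per augmentation type plus preserve_bboxes, seed, and a validate() method:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class AugmentationParams:
    enabled: bool = False
    probability: float = 0.5
    # default_factory keeps the params dict independent between instances
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            "enabled": self.enabled,
            "probability": self.probability,
            "params": dict(self.params),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
        # Missing keys fall back to the dataclass defaults.
        return cls(
            enabled=data.get("enabled", False),
            probability=data.get("probability", 0.5),
            params=dict(data.get("params", {})),
        )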

View File

@@ -0,0 +1,338 @@
"""
Tests for augmentation pipeline module.
TDD Phase 2: RED - Write tests first, then implement to pass.
"""
from typing import Any
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
class TestAugmentationPipeline:
"""Tests for AugmentationPipeline class."""
def test_create_with_config(self) -> None:
"""Test creating pipeline with config."""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig()
pipeline = AugmentationPipeline(config)
assert pipeline.config is config
def test_create_with_seed(self) -> None:
"""Test creating pipeline with seed for reproducibility."""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig(seed=42)
pipeline = AugmentationPipeline(config)
assert pipeline.config.seed == 42
def test_apply_returns_augmentation_result(self) -> None:
"""Test that apply returns AugmentationResult."""
from shared.augmentation.base import AugmentationResult
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig()
pipeline = AugmentationPipeline(config)
image = np.zeros((100, 100, 3), dtype=np.uint8)
result = pipeline.apply(image)
assert isinstance(result, AugmentationResult)
assert result.image is not None
assert result.image.shape == image.shape
def test_apply_with_bboxes(self) -> None:
"""Test apply with bounding boxes."""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig()
pipeline = AugmentationPipeline(config)
image = np.zeros((100, 100, 3), dtype=np.uint8)
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)
result = pipeline.apply(image, bboxes)
# Bboxes should be preserved when preserve_bboxes=True
assert result.bboxes is not None
def test_apply_no_augmentations_enabled(self) -> None:
"""Test apply when no augmentations are enabled."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
# All augmentations default to disabled; one is explicitly disabled for clarity
config = AugmentationConfig(
lighting_variation=AugmentationParams(enabled=False),
)
pipeline = AugmentationPipeline(config)
image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
result = pipeline.apply(image)
# Image should be unchanged (or a copy)
np.testing.assert_array_equal(result.image, image)
def test_apply_does_not_mutate_input(self) -> None:
"""Test that apply does not mutate input image."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig(
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
)
pipeline = AugmentationPipeline(config)
image = np.full((100, 100, 3), 128, dtype=np.uint8)
original_copy = image.copy()
pipeline.apply(image)
np.testing.assert_array_equal(image, original_copy)
def test_reproducibility_with_seed(self) -> None:
"""Test that same seed produces same results."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
config1 = AugmentationConfig(
seed=42,
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
)
config2 = AugmentationConfig(
seed=42,
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
)
pipeline1 = AugmentationPipeline(config1)
pipeline2 = AugmentationPipeline(config2)
image = np.full((100, 100, 3), 128, dtype=np.uint8)
result1 = pipeline1.apply(image.copy())
result2 = pipeline2.apply(image.copy())
np.testing.assert_array_equal(result1.image, result2.image)
def test_metadata_contains_applied_augmentations(self) -> None:
"""Test that metadata lists applied augmentations."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig(
seed=42,
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
)
pipeline = AugmentationPipeline(config)
image = np.full((100, 100, 3), 128, dtype=np.uint8)
result = pipeline.apply(image)
assert result.metadata is not None
assert "applied_augmentations" in result.metadata
# Both should be applied with probability=1.0
assert "gaussian_noise" in result.metadata["applied_augmentations"]
assert "lighting_variation" in result.metadata["applied_augmentations"]
class TestAugmentationPipelineStageOrder:
"""Tests for pipeline stage ordering."""
def test_stage_order_defined(self) -> None:
"""Test that stage order is defined."""
from shared.augmentation.pipeline import AugmentationPipeline
assert hasattr(AugmentationPipeline, "STAGE_ORDER")
expected_stages = [
"geometric",
"degradation",
"lighting",
"texture",
"blur",
"noise",
]
assert AugmentationPipeline.STAGE_ORDER == expected_stages
def test_stage_mapping_defined(self) -> None:
"""Test that all augmentation types are mapped to stages."""
from shared.augmentation.pipeline import AugmentationPipeline
assert hasattr(AugmentationPipeline, "STAGE_MAPPING")
expected_mappings = {
"perspective_warp": "geometric",
"wrinkle": "degradation",
"edge_damage": "degradation",
"stain": "degradation",
"lighting_variation": "lighting",
"shadow": "lighting",
"paper_texture": "texture",
"scanner_artifacts": "texture",
"gaussian_blur": "blur",
"motion_blur": "blur",
"gaussian_noise": "noise",
"salt_pepper": "noise",
}
for aug_name, stage in expected_mappings.items():
assert aug_name in AugmentationPipeline.STAGE_MAPPING
assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage
def test_geometric_before_degradation(self) -> None:
"""Test that geometric transforms are applied before degradation."""
from shared.augmentation.pipeline import AugmentationPipeline
stages = AugmentationPipeline.STAGE_ORDER
geometric_idx = stages.index("geometric")
degradation_idx = stages.index("degradation")
assert geometric_idx < degradation_idx
def test_noise_applied_last(self) -> None:
"""Test that noise is applied last."""
from shared.augmentation.pipeline import AugmentationPipeline
stages = AugmentationPipeline.STAGE_ORDER
assert stages[-1] == "noise"
class TestAugmentationRegistry:
"""Tests for augmentation registry."""
def test_registry_exists(self) -> None:
"""Test that augmentation registry exists."""
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
assert isinstance(AUGMENTATION_REGISTRY, dict)
def test_registry_contains_all_types(self) -> None:
"""Test that registry contains all augmentation types."""
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY
expected_types = [
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
]
for aug_type in expected_types:
assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"
class TestPipelinePreview:
"""Tests for pipeline preview functionality."""
def test_preview_single_augmentation(self) -> None:
"""Test previewing a single augmentation."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig(
gaussian_noise=AugmentationParams(
enabled=True, probability=1.0, params={"std": (10, 10)}
),
)
pipeline = AugmentationPipeline(config)
image = np.full((100, 100, 3), 128, dtype=np.uint8)
preview = pipeline.preview(image, "gaussian_noise")
assert preview.shape == image.shape
assert preview.dtype == np.uint8
# Preview should modify the image
assert not np.array_equal(preview, image)
def test_preview_unknown_augmentation_raises(self) -> None:
"""Test that previewing unknown augmentation raises error."""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig()
pipeline = AugmentationPipeline(config)
image = np.zeros((100, 100, 3), dtype=np.uint8)
with pytest.raises(ValueError, match="Unknown augmentation"):
pipeline.preview(image, "non_existent_augmentation")
def test_preview_is_deterministic(self) -> None:
"""Test that preview produces deterministic results."""
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
config = AugmentationConfig(
gaussian_noise=AugmentationParams(enabled=True),
)
pipeline = AugmentationPipeline(config)
image = np.full((100, 100, 3), 128, dtype=np.uint8)
preview1 = pipeline.preview(image, "gaussian_noise")
preview2 = pipeline.preview(image, "gaussian_noise")
np.testing.assert_array_equal(preview1, preview2)
class TestPipelineGetAvailableAugmentations:
"""Tests for getting available augmentations."""
def test_get_available_augmentations(self) -> None:
"""Test getting list of available augmentations."""
from shared.augmentation.pipeline import get_available_augmentations
augmentations = get_available_augmentations()
assert isinstance(augmentations, list)
assert len(augmentations) == 12
# Each item should have name, description, affects_geometry, and stage
for aug in augmentations:
assert "name" in aug
assert "description" in aug
assert "affects_geometry" in aug
assert "stage" in aug
def test_get_available_augmentations_includes_all_types(self) -> None:
"""Test that all augmentation types are included."""
from shared.augmentation.pipeline import get_available_augmentations
augmentations = get_available_augmentations()
names = [aug["name"] for aug in augmentations]
expected = [
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
]
for name in expected:
assert name in names
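The stage tables and apply loop these tests describe can be sketched roughly as follows; the registry is assumed to map names to configured BaseAugmentation instances and is not reproduced here:

from typing import Any

import numpy as np

STAGE_ORDER = ["geometric", "degradation", "lighting", "texture", "blur", "noise"]

STAGE_MAPPING = {
    "perspective_warp": "geometric",
    "wrinkle": "degradation",
    "edge_damage": "degradation",
    "stain": "degradation",
    "lighting_variation": "lighting",
    "shadow": "lighting",
    "paper_texture": "texture",
    "scanner_artifacts": "texture",
    "gaussian_blur": "blur",
    "motion_blur": "blur",
    "gaussian_noise": "noise",
    "salt_pepper": "noise",
}


def ordered_augmentations(enabled: list[str]) -> list[str]:
    """Sort enabled augmentation names by pipeline stage."""
    return sorted(enabled, key=lambda name: STAGE_ORDER.index(STAGE_MAPPING[name]))


def apply_pipeline(
    image: np.ndarray,
    enabled: dict[str, float],  # augmentation name -> probability
    registry: dict[str, Any],   # augmentation name -> configured instance
    rng: np.random.Generator,
) -> tuple[np.ndarray, list[str]]:
    """Apply enabled augmentations in stage order; return image plus applied names."""
    out = image.copy()  # never mutate the caller's array
    applied: list[str] = []
    for name in ordered_augmentations(list(enabled)):
        if rng.random() < enabled[name]:
            out = registry[name].apply(out, rng=rng).image
            applied.append(name)
    return out, applied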

View File

@@ -0,0 +1,102 @@
"""
Tests for augmentation presets module.
TDD Phase 4: RED - Write tests first, then implement to pass.
"""
import pytest
class TestPresets:
"""Tests for augmentation presets."""
def test_presets_dict_exists(self) -> None:
"""Test that PRESETS dictionary exists."""
from shared.augmentation.presets import PRESETS
assert isinstance(PRESETS, dict)
assert len(PRESETS) > 0
def test_expected_presets_exist(self) -> None:
"""Test that expected presets are defined."""
from shared.augmentation.presets import PRESETS
expected_presets = ["conservative", "moderate", "aggressive", "scanned_document"]
for preset_name in expected_presets:
assert preset_name in PRESETS, f"Missing preset: {preset_name}"
def test_preset_structure(self) -> None:
"""Test that each preset has required structure."""
from shared.augmentation.presets import PRESETS
for name, preset in PRESETS.items():
assert "description" in preset, f"Preset {name} missing description"
assert "config" in preset, f"Preset {name} missing config"
assert isinstance(preset["description"], str)
assert isinstance(preset["config"], dict)
def test_get_preset_config(self) -> None:
"""Test getting config from preset."""
from shared.augmentation.presets import get_preset_config
config = get_preset_config("conservative")
assert config is not None
# Should have at least some augmentations defined
assert len(config) > 0
def test_get_preset_config_unknown_raises(self) -> None:
"""Test that getting unknown preset raises error."""
from shared.augmentation.presets import get_preset_config
with pytest.raises(ValueError, match="Unknown preset"):
get_preset_config("nonexistent_preset")
def test_create_config_from_preset(self) -> None:
"""Test creating AugmentationConfig from preset."""
from shared.augmentation.config import AugmentationConfig
from shared.augmentation.presets import create_config_from_preset
config = create_config_from_preset("moderate")
assert isinstance(config, AugmentationConfig)
def test_conservative_preset_is_safe(self) -> None:
"""Test that conservative preset only enables safe augmentations."""
from shared.augmentation.presets import create_config_from_preset
config = create_config_from_preset("conservative")
# Should NOT enable geometric transforms
assert config.perspective_warp.enabled is False
# Should NOT enable heavy degradation
assert config.wrinkle.enabled is False
assert config.edge_damage.enabled is False
assert config.stain.enabled is False
def test_aggressive_preset_enables_more(self) -> None:
"""Test that aggressive preset enables more augmentations."""
from shared.augmentation.presets import create_config_from_preset
config = create_config_from_preset("aggressive")
enabled = config.get_enabled_augmentations()
# Should enable multiple augmentation types
assert len(enabled) >= 6
def test_list_presets(self) -> None:
"""Test listing available presets."""
from shared.augmentation.presets import list_presets
presets = list_presets()
assert isinstance(presets, list)
assert len(presets) >= 4
# Each item should have name and description
for preset in presets:
assert "name" in preset
assert "description" in preset

View File

@@ -0,0 +1 @@
# Tests for augmentation transforms

View File

@@ -0,0 +1,293 @@
"""
Tests for DatasetAugmenter.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
class TestDatasetAugmenter:
"""Tests for DatasetAugmenter class."""
@pytest.fixture
def sample_dataset(self, tmp_path: Path) -> Path:
"""Create a sample YOLO dataset structure."""
dataset_dir = tmp_path / "dataset"
# Create directory structure
for split in ["train", "val", "test"]:
(dataset_dir / "images" / split).mkdir(parents=True)
(dataset_dir / "labels" / split).mkdir(parents=True)
# Create sample images and labels
for i in range(3):
# Create 100x100 white image
img = Image.new("RGB", (100, 100), color="white")
img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
img.save(img_path)
# Create label with 2 bboxes
# Format: class_id x_center y_center width height
label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
label_path.write_text(label_content)
# Create data.yaml
data_yaml = dataset_dir / "data.yaml"
data_yaml.write_text(
"path: .\n"
"train: images/train\n"
"val: images/val\n"
"test: images/test\n"
"nc: 10\n"
"names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
)
return dataset_dir
@pytest.fixture
def augmentation_config(self) -> dict:
"""Create a sample augmentation config."""
return {
"gaussian_noise": {
"enabled": True,
"probability": 1.0,
"params": {"std": 10},
},
"gaussian_blur": {
"enabled": True,
"probability": 1.0,
"params": {"kernel_size": 3},
},
}
def test_augmenter_creates_additional_images(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmenter creates new augmented images."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
augmenter = DatasetAugmenter(augmentation_config)
# Count original images
original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
assert original_count == 3
# Apply augmentation with multiplier=2
result = augmenter.augment_dataset(sample_dataset, multiplier=2)
# Should now have original + 2x augmented = 3 + 6 = 9 images
new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
assert new_count == 9
assert result["augmented_images"] == 6
def test_augmenter_creates_matching_labels(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmenter creates label files for each augmented image."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
augmenter = DatasetAugmenter(augmentation_config)
augmenter.augment_dataset(sample_dataset, multiplier=2)
# Check that each image has a matching label file
images = list((sample_dataset / "images" / "train").glob("*.png"))
labels = list((sample_dataset / "labels" / "train").glob("*.txt"))
assert len(images) == len(labels)
# Check that augmented images have corresponding labels
for img_path in images:
label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
assert label_path.exists(), f"Missing label for {img_path.name}"
def test_augmented_labels_have_valid_format(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmented label files have valid YOLO format."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
augmenter = DatasetAugmenter(augmentation_config)
augmenter.augment_dataset(sample_dataset, multiplier=1)
# Check all label files
for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
content = label_path.read_text().strip()
if not content:
continue # Empty labels are valid (background images)
for line in content.split("\n"):
parts = line.split()
assert len(parts) == 5, f"Invalid label format in {label_path.name}"
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
# Check values are in valid range
assert 0 <= class_id < 100, f"Invalid class_id: {class_id}"
assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
assert 0 <= width <= 1, f"Invalid width: {width}"
assert 0 <= height <= 1, f"Invalid height: {height}"
def test_augmented_images_are_different(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmented images are actually different from originals."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
# Load original image
original_path = sample_dataset / "images" / "train" / "doc_0.png"
original_img = np.array(Image.open(original_path))
augmenter = DatasetAugmenter(augmentation_config)
augmenter.augment_dataset(sample_dataset, multiplier=1)
# Find augmented version
aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
assert aug_path.exists()
aug_img = np.array(Image.open(aug_path))
# Images should be different (due to noise/blur)
assert not np.array_equal(original_img, aug_img)
def test_augmented_images_same_size(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmented images have same size as originals."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
# Get original size
original_path = sample_dataset / "images" / "train" / "doc_0.png"
original_img = Image.open(original_path)
original_size = original_img.size
augmenter = DatasetAugmenter(augmentation_config)
augmenter.augment_dataset(sample_dataset, multiplier=1)
# Check all augmented images have same size
for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
img = Image.open(img_path)
assert img.size == original_size, f"{img_path.name} has wrong size"
def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
"""Test that perspective_warp augmentation updates bbox coordinates."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
config = {
"perspective_warp": {
"enabled": True,
"probability": 1.0,
"params": {"max_warp": 0.05}, # Use larger warp for visible difference
},
}
# Read original label
original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
original_bboxes = [line.split() for line in original_label.strip().split("\n")]
augmenter = DatasetAugmenter(config)
augmenter.augment_dataset(sample_dataset, multiplier=1)
# Read augmented label
aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]
# Same number of bboxes
assert len(original_bboxes) == len(aug_bboxes)
# At least one bbox should have different coordinates
# (perspective warp changes geometry)
differences_found = False
for orig, aug in zip(original_bboxes, aug_bboxes):
# Class ID should be same
assert orig[0] == aug[0]
# Coordinates might differ
if orig[1:] != aug[1:]:
differences_found = True
assert differences_found, "Perspective warp should change bbox coordinates"
def test_augmenter_only_processes_train_split(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmenter only processes train split by default."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
# Add a val image
val_img = Image.new("RGB", (100, 100), color="white")
val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
(sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")
augmenter = DatasetAugmenter(augmentation_config)
augmenter.augment_dataset(sample_dataset, multiplier=2)
# Val should still have only 1 image
val_count = len(list((sample_dataset / "images" / "val").glob("*.png")))
assert val_count == 1
def test_augmenter_with_multiplier_zero_does_nothing(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that multiplier=0 creates no augmented images."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
augmenter = DatasetAugmenter(augmentation_config)
result = augmenter.augment_dataset(sample_dataset, multiplier=0)
new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
assert new_count == original_count
assert result["augmented_images"] == 0
def test_augmenter_with_seed_is_reproducible(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that same seed produces same augmentation results."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
# Create two separate datasets
import shutil
dataset1 = sample_dataset
dataset2 = sample_dataset.parent / "dataset2"
shutil.copytree(dataset1, dataset2)
# Augment both with same seed
augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
augmenter1.augment_dataset(dataset1, multiplier=1)
augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
augmenter2.augment_dataset(dataset2, multiplier=1)
# Compare augmented images
aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))
assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"
def test_augmenter_returns_summary(
self, sample_dataset: Path, augmentation_config: dict
):
"""Test that augmenter returns a summary of what was done."""
from shared.augmentation.dataset_augmenter import DatasetAugmenter
augmenter = DatasetAugmenter(augmentation_config)
result = augmenter.augment_dataset(sample_dataset, multiplier=2)
assert "original_images" in result
assert "augmented_images" in result
assert "total_images" in result
assert result["original_images"] == 3
assert result["augmented_images"] == 6
assert result["total_images"] == 9

View File

@@ -0,0 +1,261 @@
"""
Tests for augmentation API routes.
TDD Phase 5: RED - Write tests first, then implement to pass.
"""
import pytest
from fastapi.testclient import TestClient
class TestAugmentationTypesEndpoint:
"""Tests for GET /admin/augmentation/types endpoint."""
def test_list_augmentation_types(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test listing available augmentation types."""
response = admin_client.get(
"/api/v1/admin/augmentation/types",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "augmentation_types" in data
assert len(data["augmentation_types"]) == 12
# Check structure
aug_type = data["augmentation_types"][0]
assert "name" in aug_type
assert "description" in aug_type
assert "affects_geometry" in aug_type
assert "stage" in aug_type
def test_list_augmentation_types_unauthorized(
self, admin_client: TestClient
) -> None:
"""Test that unauthorized request is rejected."""
response = admin_client.get("/api/v1/admin/augmentation/types")
assert response.status_code == 401
class TestAugmentationPresetsEndpoint:
"""Tests for GET /admin/augmentation/presets endpoint."""
def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
"""Test listing available presets."""
response = admin_client.get(
"/api/v1/admin/augmentation/presets",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "presets" in data
assert len(data["presets"]) >= 4
# Check expected presets exist
preset_names = [p["name"] for p in data["presets"]]
assert "conservative" in preset_names
assert "moderate" in preset_names
assert "aggressive" in preset_names
assert "scanned_document" in preset_names
class TestAugmentationPreviewEndpoint:
"""Tests for POST /admin/augmentation/preview/{document_id} endpoint."""
def test_preview_augmentation(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test previewing augmentation on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {"std": 15},
},
)
assert response.status_code == 200
data = response.json()
assert "preview_url" in data
assert "original_url" in data
assert "applied_params" in data
def test_preview_invalid_augmentation_type(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test that invalid augmentation type returns error."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "nonexistent",
"params": {},
},
)
assert response.status_code == 400
def test_preview_nonexistent_document(
self,
admin_client: TestClient,
admin_token: str,
) -> None:
"""Test that nonexistent document returns 404."""
response = admin_client.post(
"/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
headers={"X-Admin-Token": admin_token},
json={
"augmentation_type": "gaussian_noise",
"params": {},
},
)
assert response.status_code == 404
class TestAugmentationPreviewConfigEndpoint:
"""Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""
def test_preview_config(
self,
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
) -> None:
"""Test previewing full config on a document."""
response = admin_client.post(
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
headers={"X-Admin-Token": admin_token},
json={
"gaussian_noise": {"enabled": True, "probability": 1.0},
"lighting_variation": {"enabled": True, "probability": 1.0},
"preserve_bboxes": True,
"seed": 42,
},
)
assert response.status_code == 200
data = response.json()
assert "preview_url" in data
assert "original_url" in data
class TestAugmentationBatchEndpoint:
"""Tests for POST /admin/augmentation/batch endpoint."""
def test_create_augmented_dataset(
self,
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
) -> None:
"""Test creating augmented dataset."""
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
json={
"dataset_id": sample_dataset_id,
"config": {
"gaussian_noise": {"enabled": True, "probability": 0.5},
"preserve_bboxes": True,
},
"output_name": "test_augmented_dataset",
"multiplier": 2,
},
)
assert response.status_code == 200
data = response.json()
assert "task_id" in data
assert "status" in data
assert "estimated_images" in data
def test_create_augmented_dataset_invalid_multiplier(
self,
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
) -> None:
"""Test that invalid multiplier is rejected."""
response = admin_client.post(
"/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token},
json={
"dataset_id": sample_dataset_id,
"config": {},
"output_name": "test",
"multiplier": 100, # Too high
},
)
assert response.status_code == 422 # Validation error
class TestAugmentedDatasetsListEndpoint:
"""Tests for GET /admin/augmentation/datasets endpoint."""
def test_list_augmented_datasets(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test listing augmented datasets."""
response = admin_client.get(
"/api/v1/admin/augmentation/datasets",
headers={"X-Admin-Token": admin_token},
)
assert response.status_code == 200
data = response.json()
assert "total" in data
assert "limit" in data
assert "offset" in data
assert "datasets" in data
assert isinstance(data["datasets"], list)
def test_list_augmented_datasets_pagination(
self, admin_client: TestClient, admin_token: str
) -> None:
"""Test pagination parameters."""
response = admin_client.get(
"/api/v1/admin/augmentation/datasets",
headers={"X-Admin-Token": admin_token},
params={"limit": 5, "offset": 0},
)
assert response.status_code == 200
data = response.json()
assert data["limit"] == 5
assert data["offset"] == 0
# Fixtures for tests
@pytest.fixture
def sample_document_id() -> str:
"""Provide a sample document ID for testing."""
# This would need to be created in test setup
return "test-document-id"
@pytest.fixture
def sample_dataset_id() -> str:
"""Provide a sample dataset ID for testing."""
# This would need to be created in test setup
return "test-dataset-id"

View File

@@ -329,3 +329,414 @@ class TestDatasetBuilder:
results.append([(d["document_id"], d["split"]) for d in docs])
assert results[0] == results[1]
class TestAssignSplitsByGroup:
"""Tests for _assign_splits_by_group method with group_key logic."""
def _make_mock_doc(self, doc_id, group_key=None):
"""Create a mock AdminDocument with document_id and group_key."""
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.group_key = group_key
doc.page_count = 1
return doc
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
"""Documents with unique group_key are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 3 documents, each with unique group_key
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-C"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1 # Ensure val is not empty
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key=""),
self._make_mock_doc(uuid4(), group_key=None),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# Each null/empty group_key doc is independent, distributed across splits
# With 3 docs: ensure at least 1 in train and 1 in val
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
"""Documents with same group_key should be assigned to the same split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 6 documents in 2 groups
docs = [
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-A"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
self._make_mock_doc(uuid4(), group_key="supplier-B"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
# All docs in supplier-A should have same split
splits_a = [result[str(d.document_id)] for d in docs[:3]]
assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"
# All docs in supplier-B should have same split
splits_b = [result[str(d.document_id)] for d in docs[3:]]
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
"""Multi-doc groups should be split according to train/val/test ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 10 groups with 2 docs each
docs = []
for i in range(10):
group_key = f"group-{i}"
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
# Count groups per split
group_splits = {}
for doc in docs:
split = result[str(doc.document_id)]
if doc.group_key not in group_splits:
group_splits[doc.group_key] = split
else:
# Verify same group has same split
assert group_splits[doc.group_key] == split
split_counts = {"train": 0, "val": 0, "test": 0}
for split in group_splits.values():
split_counts[split] += 1
# With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
assert split_counts["train"] >= 6
assert split_counts["train"] <= 8
assert split_counts["val"] >= 1
assert split_counts["val"] <= 3
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
"""Mix of single-doc and multi-doc groups should be handled correctly."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
# Single-doc groups
self._make_mock_doc(uuid4(), group_key="single-1"),
self._make_mock_doc(uuid4(), group_key="single-2"),
self._make_mock_doc(uuid4(), group_key=None),
# Multi-doc groups
self._make_mock_doc(uuid4(), group_key="multi-A"),
self._make_mock_doc(uuid4(), group_key="multi-A"),
self._make_mock_doc(uuid4(), group_key="multi-B"),
self._make_mock_doc(uuid4(), group_key="multi-B"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
# All groups are shuffled and distributed
# Ensure at least 1 in train and 1 in val
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 1
assert val_count >= 1
# Multi-doc groups stay together
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
"""Same seed should produce same split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-B"),
self._make_mock_doc(uuid4(), group_key="group-C"),
self._make_mock_doc(uuid4(), group_key="group-C"),
]
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
assert result1 == result2
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
"""Different seeds should potentially produce different split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# Many groups to increase chance of different results
docs = []
for i in range(20):
group_key = f"group-{i}"
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)
# Results should be different (very likely with 20 groups)
assert result1 != result2
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
"""Every document should be assigned a split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key="group-A"),
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key="single"),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
assert len(result) == len(docs)
for doc in docs:
assert str(doc.document_id) in result
assert result[str(doc.document_id)] in ["train", "val", "test"]
def test_empty_documents_list(self, tmp_path, mock_admin_db):
"""Empty document list should return empty result."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
assert result == {}
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
"""When all groups have multiple docs, splits should follow ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
# 5 groups with 3 docs each
docs = []
for i in range(5):
group_key = f"group-{i}"
for _ in range(3):
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
# Group splits
group_splits = {}
for doc in docs:
if doc.group_key not in group_splits:
group_splits[doc.group_key] = result[str(doc.document_id)]
split_counts = {"train": 0, "val": 0, "test": 0}
for split in group_splits.values():
split_counts[split] += 1
# With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
assert split_counts["train"] >= 2
assert split_counts["train"] <= 4
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
"""When all groups have single doc, they are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
docs = [
self._make_mock_doc(uuid4(), group_key="unique-1"),
self._make_mock_doc(uuid4(), group_key="unique-2"),
self._make_mock_doc(uuid4(), group_key="unique-3"),
self._make_mock_doc(uuid4(), group_key=None),
self._make_mock_doc(uuid4(), group_key=""),
]
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
# With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
train_count = sum(1 for s in result.values() if s == "train")
val_count = sum(1 for s in result.values() if s == "val")
assert train_count >= 2
assert val_count >= 1 # Ensure val is not empty
class TestBuildDatasetWithGroupKey:
"""Integration tests for build_dataset with group_key logic."""
@pytest.fixture
def grouped_documents(self, tmp_path):
"""Create documents with various group_key configurations."""
doc_ids = []
docs = []
# Create 3 groups: 2 multi-doc groups + 2 single-doc groups
group_configs = [
("supplier-A", 3), # Multi-doc group: 3 docs
("supplier-B", 2), # Multi-doc group: 2 docs
("unique-1", 1), # Single-doc group
(None, 1), # Null group_key
]
for group_key, count in group_configs:
for _ in range(count):
doc_id = uuid4()
doc_ids.append(doc_id)
# Create image files
doc_dir = tmp_path / "admin_images" / str(doc_id)
doc_dir.mkdir(parents=True)
for page in range(1, 3):
(doc_dir / f"page_{page}.png").write_bytes(b"fake-png")
# Create mock document
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.filename = f"{doc_id}.pdf"
doc.page_count = 2
doc.group_key = group_key
doc.file_path = str(doc_dir)
docs.append(doc)
return tmp_path, docs
@pytest.fixture
def grouped_annotations(self, grouped_documents):
"""Create annotations for grouped documents."""
tmp_path, docs = grouped_documents
annotations = {}
for doc in docs:
doc_anns = []
for page in range(1, 3):
ann = MagicMock(spec=AdminAnnotation)
ann.document_id = doc.document_id
ann.page_number = page
ann.class_id = 0
ann.class_name = "invoice_number"
ann.x_center = 0.5
ann.y_center = 0.3
ann.width = 0.2
ann.height = 0.05
doc_anns.append(ann)
annotations[str(doc.document_id)] = doc_anns
return annotations
def test_build_respects_group_key_splits(
self, grouped_documents, grouped_annotations, mock_admin_db
):
"""build_dataset should use group_key for split assignment."""
from inference.web.services.dataset_builder import DatasetBuilder
tmp_path, docs = grouped_documents
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
grouped_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
train_ratio=0.5,
val_ratio=0.5,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
# Get the document splits from add_dataset_documents call
call_args = mock_admin_db.add_dataset_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
# Build mapping of doc_id -> split
doc_split_map = {d["document_id"]: d["split"] for d in docs_added}
# Verify all docs are assigned a valid split
for doc_id in doc_split_map:
assert doc_split_map[doc_id] in ("train", "val", "test")
# Verify multi-doc groups stay together
supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"
supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
"""All docs with same group_key should go to same split."""
from inference.web.services.dataset_builder import DatasetBuilder
# Create 5 docs all with same group_key
docs = []
for i in range(5):
doc_id = uuid4()
doc_dir = tmp_path / "admin_images" / str(doc_id)
doc_dir.mkdir(parents=True)
(doc_dir / "page_1.png").write_bytes(b"fake-png")
doc = MagicMock(spec=AdminDocument)
doc.document_id = doc_id
doc.filename = f"{doc_id}.pdf"
doc.page_count = 1
doc.group_key = "same-group"
docs.append(doc)
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.return_value = []
dataset = mock_admin_db.create_dataset.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
train_ratio=0.6,
val_ratio=0.2,
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
call_args = mock_admin_db.add_dataset_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
splits = [d["split"] for d in docs_added]
# All should be in the same split (one group)
assert len(set(splits)) == 1, "All docs with same group_key should be in same split"
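A compact sketch of the grouping behaviour pinned down by these tests (not the DatasetBuilder method itself): documents are bucketed by group_key, null or empty keys become one-doc groups, the group keys are shuffled with the given seed, and whole groups are sliced into train/val/test so related documents never straddle a split:

import random
from collections import defaultdict


def assign_splits_by_group(docs, train_ratio: float, val_ratio: float, seed: int) -> dict[str, str]:
    groups: dict[str, list[str]] = defaultdict(list)
    for doc in docs:
        key = doc.group_key or f"__single__{doc.document_id}"
        groups[key].append(str(doc.document_id))
    keys = sorted(groups)                 # stable order before seeding
    random.Random(seed).shuffle(keys)
    n = len(keys)
    n_train = max(1, round(n * train_ratio)) if n else 0
    n_val = max(1, round(n * val_ratio)) if n > n_train else 0
    result: dict[str, str] = {}
    for i, key in enumerate(keys):
        split = "train" if i < n_train else "val" if i < n_train + n_val else "test"
        for doc_id in groups[key]:
            result[doc_id] = split        # every doc in a group gets the same split
    return result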

View File

@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
TEST_TOKEN = "test-admin-token-12345"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
# Generate 10 unique UUIDs for minimum document count tests
TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
def _make_dataset(**overrides) -> MagicMock:
defaults = dict(
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
"total_documents": 2,
"total_images": 4,
"total_annotations": 10,
"total_documents": 10,
"total_images": 20,
"total_annotations": 50,
}
request = DatasetCreateRequest(
name="test-dataset",
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
document_ids=TEST_DOC_UUIDS, # Use 10 documents to meet minimum
)
with patch(
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
def test_create_dataset_fails_with_less_than_10_documents(self):
"""Test that creating dataset fails if fewer than 10 documents provided."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# Only 2 documents - should fail
request = DatasetCreateRequest(
name="test-dataset",
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
assert "got 2" in exc_info.value.detail
# Ensure DB was never called since validation failed first
mock_db.create_dataset.assert_not_called()
def test_create_dataset_fails_with_9_documents(self):
"""Test boundary condition: 9 documents should fail."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# 9 documents - just under the limit
request = DatasetCreateRequest(
name="test-dataset",
document_ids=TEST_DOC_UUIDS[:9],
)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
def test_create_dataset_succeeds_with_exactly_10_documents(self):
"""Test boundary condition: exactly 10 documents should succeed."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
# Exactly 10 documents - should pass
request = DatasetCreateRequest(
name="test-dataset",
document_ids=TEST_DOC_UUIDS[:10],
)
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
):
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_dataset.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
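# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this commit: the three tests above
# expect create_dataset to reject requests with fewer than 10 documents before
# any DB call, with a 400 whose detail includes "Minimum 10 documents required"
# and "got N".  A guard of roughly this shape would satisfy them; the constant
# and helper names are assumptions, not the actual route code.
# ---------------------------------------------------------------------------
from fastapi import HTTPException

MIN_DATASET_DOCUMENTS = 10  # assumed constant name

def _validate_minimum_documents(document_ids: list[str]) -> None:
    """Raise 400 when a dataset creation request has too few documents."""
    if len(document_ids) < MIN_DATASET_DOCUMENTS:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Minimum {MIN_DATASET_DOCUMENTS} documents required, "
                f"got {len(document_ids)}"
            ),
        )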
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
def test_incremental_training_with_base_model(self):
"""Test training with base_model_version_id for incremental training."""
fn = _find_endpoint("train_from_dataset")
mock_model_version = MagicMock()
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
mock_model_version.version = "1.0.0"
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = mock_model_version
mock_db.create_training_task.return_value = TEST_TASK_UUID
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
# Verify model version was looked up
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
# Verify task was created with finetune type
call_kwargs = mock_db.create_training_task.call_args[1]
assert call_kwargs["task_type"] == "finetune"
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
assert result.task_id == TEST_TASK_UUID
assert "Incremental training" in result.message
def test_incremental_training_with_invalid_base_model_fails(self):
"""Test that training fails if base_model_version_id doesn't exist."""
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = None
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
assert "Base model version not found" in exc_info.value.detail

View File

@@ -0,0 +1,399 @@
"""
Tests for Model Version API routes.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
)
TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"
def _make_model_version(**overrides) -> MagicMock:
"""Create a mock ModelVersion."""
defaults = dict(
version_id=UUID(TEST_VERSION_UUID),
version="1.0.0",
name="test-model-v1",
description="Test model version",
model_path="/models/test-model-v1.pt",
status="inactive",
is_active=False,
task_id=UUID(TEST_TASK_UUID),
dataset_id=UUID(TEST_DATASET_UUID),
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=100,
training_config={"epochs": 100, "batch_size": 16},
file_size=52428800,
trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
activated_at=None,
created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
)
defaults.update(overrides)
model = MagicMock(spec=ModelVersion)
for k, v in defaults.items():
setattr(model, k, v)
return model
def _find_endpoint(name: str):
"""Find endpoint function by name."""
router = create_training_router()
for route in router.routes:
if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
return route.endpoint
raise AssertionError(f"Endpoint {name} not found")
class TestModelVersionRouterRegistration:
"""Tests that model version endpoints are registered."""
def test_router_has_model_endpoints(self):
router = create_training_router()
paths = [route.path for route in router.routes]
assert any("models" in p for p in paths)
def test_has_create_model_version_endpoint(self):
endpoint = _find_endpoint("create_model_version")
assert endpoint is not None
def test_has_list_model_versions_endpoint(self):
endpoint = _find_endpoint("list_model_versions")
assert endpoint is not None
def test_has_get_active_model_endpoint(self):
endpoint = _find_endpoint("get_active_model")
assert endpoint is not None
def test_has_activate_model_version_endpoint(self):
endpoint = _find_endpoint("activate_model_version")
assert endpoint is not None
class TestCreateModelVersionRoute:
"""Tests for POST /admin/training/models."""
def test_create_model_version(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
description="Test model",
metrics_mAP=0.935,
document_count=100,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.create_model_version.assert_called_once()
assert result.version_id == TEST_VERSION_UUID
assert result.status == "inactive"
assert result.message == "Model version created successfully"
def test_create_model_version_with_task_and_dataset(self):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model-v1",
model_path="/models/test-model-v1.pt",
task_id=TEST_TASK_UUID,
dataset_id=TEST_DATASET_UUID,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
call_kwargs = mock_db.create_model_version.call_args[1]
assert call_kwargs["task_id"] == TEST_TASK_UUID
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
class TestListModelVersionsRoute:
"""Tests for GET /admin/training/models."""
def test_list_model_versions(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = (
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
2,
)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
assert result.total == 2
assert len(result.models) == 2
assert result.models[0].version == "1.0.0"
def test_list_model_versions_with_status_filter(self):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
assert result.total == 1
assert result.models[0].status == "active"
class TestGetActiveModelRoute:
"""Tests for GET /admin/training/models/active."""
def test_get_active_model_when_exists(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is True
assert result.model is not None
assert result.model.is_active is True
def test_get_active_model_when_none(self):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = None
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
assert result.has_active_model is False
assert result.model is None
class TestGetModelVersionRoute:
"""Tests for GET /admin/training/models/{version_id}."""
def test_get_model_version(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = _make_model_version()
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.version_id == TEST_VERSION_UUID
assert result.version == "1.0.0"
assert result.name == "test-model-v1"
assert result.metrics_mAP == 0.935
def test_get_model_version_not_found(self):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestUpdateModelVersionRoute:
"""Tests for PATCH /admin/training/models/{version_id}."""
def test_update_model_version(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
mock_db.update_model_version.assert_called_once_with(
version_id=TEST_VERSION_UUID,
name="updated-name",
description="Updated description",
status=None,
)
assert result.message == "Model version updated successfully"
def test_update_model_version_not_found(self):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = None
request = ModelVersionUpdateRequest(name="updated-name")
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/activate."""
def test_activate_model_version(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
assert result.message == "Model version activated for inference"
def test_activate_model_version_not_found(self):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
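# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this commit: the activate tests above
# imply a single-active-model invariant in the DB layer -- activating one
# version deactivates every other version, and an unknown id yields None.  An
# in-memory sketch of that invariant, with assumed names and data shapes:
# ---------------------------------------------------------------------------
def activate_model_version_in_memory(versions: dict, version_id: str):
    """Mark one version active and all others inactive; None if id unknown."""
    target = versions.get(version_id)
    if target is None:
        return None
    for other in versions.values():
        other["is_active"] = False
        other["status"] = "inactive"
    target["is_active"] = True
    target["status"] = "active"
    return target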
class TestDeactivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
def test_deactivate_model_version(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "inactive"
assert result.message == "Model version deactivated"
def test_deactivate_model_version_not_found(self):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/archive."""
def test_archive_model_version(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert result.status == "archived"
assert result.message == "Model version archived"
def test_archive_active_model_fails(self):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
class TestDeleteModelVersionRoute:
"""Tests for DELETE /admin/training/models/{version_id}."""
def test_delete_model_version(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = True
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
assert result["message"] == "Model version deleted"
def test_delete_active_model_fails(self):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = False
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
assert exc_info.value.status_code == 400
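# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this commit: the archive and delete
# tests above treat a falsy return from the DB layer as "refused because the
# version is active" and expect the route to map it to HTTP 400.  An assumed
# route-level guard (the detail message is not asserted by the tests):
# ---------------------------------------------------------------------------
from fastapi import HTTPException

def _require_archivable_result(result, action: str):
    """Translate a refused archive/delete into a 400 response."""
    if not result:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot {action} an active model version",
        )
    return result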
class TestModelVersionSchemas:
"""Tests for model version Pydantic schemas."""
def test_create_request_validation(self):
request = ModelVersionCreateRequest(
version="1.0.0",
name="test-model",
model_path="/models/test.pt",
)
assert request.version == "1.0.0"
assert request.name == "test-model"
assert request.document_count == 0
def test_create_request_with_metrics(self):
request = ModelVersionCreateRequest(
version="2.0.0",
name="test-model-v2",
model_path="/models/v2.pt",
metrics_mAP=0.95,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
assert request.metrics_mAP == 0.95
assert request.document_count == 500
def test_update_request_partial(self):
request = ModelVersionUpdateRequest(name="new-name")
assert request.name == "new-name"
assert request.description is None
assert request.status is None