WIP
This commit is contained in:
@@ -4,6 +4,7 @@ import { DashboardOverview } from './components/DashboardOverview'
|
|||||||
import { Dashboard } from './components/Dashboard'
|
import { Dashboard } from './components/Dashboard'
|
||||||
import { DocumentDetail } from './components/DocumentDetail'
|
import { DocumentDetail } from './components/DocumentDetail'
|
||||||
import { Training } from './components/Training'
|
import { Training } from './components/Training'
|
||||||
|
import { DatasetDetail } from './components/DatasetDetail'
|
||||||
import { Models } from './components/Models'
|
import { Models } from './components/Models'
|
||||||
import { Login } from './components/Login'
|
import { Login } from './components/Login'
|
||||||
import { InferenceDemo } from './components/InferenceDemo'
|
import { InferenceDemo } from './components/InferenceDemo'
|
||||||
@@ -55,7 +56,14 @@ const App: React.FC = () => {
|
|||||||
case 'demo':
|
case 'demo':
|
||||||
return <InferenceDemo />
|
return <InferenceDemo />
|
||||||
case 'training':
|
case 'training':
|
||||||
return <Training />
|
return <Training onNavigate={handleNavigate} />
|
||||||
|
case 'dataset-detail':
|
||||||
|
return (
|
||||||
|
<DatasetDetail
|
||||||
|
datasetId={selectedDocId || ''}
|
||||||
|
onBack={() => setCurrentView('training')}
|
||||||
|
/>
|
||||||
|
)
|
||||||
case 'models':
|
case 'models':
|
||||||
return <Models />
|
return <Models />
|
||||||
default:
|
default:
|
||||||
|
|||||||
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
/**
|
||||||
|
* Tests for augmentation API endpoints.
|
||||||
|
*
|
||||||
|
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||||
|
import { augmentationApi } from './augmentation'
|
||||||
|
import apiClient from '../client'
|
||||||
|
|
||||||
|
// Mock the API client
|
||||||
|
vi.mock('../client', () => ({
|
||||||
|
default: {
|
||||||
|
get: vi.fn(),
|
||||||
|
post: vi.fn(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('augmentationApi', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getTypes', () => {
|
||||||
|
it('should fetch augmentation types', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
data: {
|
||||||
|
augmentation_types: [
|
||||||
|
{
|
||||||
|
name: 'gaussian_noise',
|
||||||
|
description: 'Adds Gaussian noise',
|
||||||
|
affects_geometry: false,
|
||||||
|
stage: 'noise',
|
||||||
|
default_params: { mean: 0, std: 15 },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||||
|
|
||||||
|
const result = await augmentationApi.getTypes()
|
||||||
|
|
||||||
|
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
|
||||||
|
expect(result.augmentation_types).toHaveLength(1)
|
||||||
|
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getPresets', () => {
|
||||||
|
it('should fetch augmentation presets', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
data: {
|
||||||
|
presets: [
|
||||||
|
{ name: 'conservative', description: 'Safe augmentations' },
|
||||||
|
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||||
|
|
||||||
|
const result = await augmentationApi.getPresets()
|
||||||
|
|
||||||
|
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
|
||||||
|
expect(result.presets).toHaveLength(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('preview', () => {
|
||||||
|
it('should preview single augmentation', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
data: {
|
||||||
|
preview_url: 'data:image/png;base64,xxx',
|
||||||
|
original_url: 'data:image/png;base64,yyy',
|
||||||
|
applied_params: { std: 15 },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||||
|
|
||||||
|
const result = await augmentationApi.preview('doc-123', {
|
||||||
|
augmentation_type: 'gaussian_noise',
|
||||||
|
params: { std: 15 },
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(apiClient.post).toHaveBeenCalledWith(
|
||||||
|
'/api/v1/admin/augmentation/preview/doc-123',
|
||||||
|
{
|
||||||
|
augmentation_type: 'gaussian_noise',
|
||||||
|
params: { std: 15 },
|
||||||
|
},
|
||||||
|
{ params: { page: 1 } }
|
||||||
|
)
|
||||||
|
expect(result.preview_url).toBe('data:image/png;base64,xxx')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support custom page number', async () => {
|
||||||
|
const mockResponse = {
|
||||||
|
data: {
|
||||||
|
preview_url: 'data:image/png;base64,xxx',
|
||||||
|
original_url: 'data:image/png;base64,yyy',
|
||||||
|
applied_params: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||||
|
|
||||||
|
await augmentationApi.preview(
|
||||||
|
'doc-123',
|
||||||
|
{ augmentation_type: 'gaussian_noise', params: {} },
|
||||||
|
2
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(apiClient.post).toHaveBeenCalledWith(
|
||||||
|
'/api/v1/admin/augmentation/preview/doc-123',
|
||||||
|
expect.anything(),
|
||||||
|
{ params: { page: 2 } }
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
144
frontend/src/api/endpoints/augmentation.ts
Normal file
144
frontend/src/api/endpoints/augmentation.ts
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
/**
|
||||||
|
* Augmentation API endpoints.
|
||||||
|
*
|
||||||
|
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import apiClient from '../client'
|
||||||
|
|
||||||
|
// Types
|
||||||
|
export interface AugmentationTypeInfo {
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
affects_geometry: boolean
|
||||||
|
stage: string
|
||||||
|
default_params: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AugmentationTypesResponse {
|
||||||
|
augmentation_types: AugmentationTypeInfo[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PresetInfo {
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
config?: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PresetsResponse {
|
||||||
|
presets: PresetInfo[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PreviewRequest {
|
||||||
|
augmentation_type: string
|
||||||
|
params: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PreviewResponse {
|
||||||
|
preview_url: string
|
||||||
|
original_url: string
|
||||||
|
applied_params: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AugmentationParams {
|
||||||
|
enabled: boolean
|
||||||
|
probability: number
|
||||||
|
params: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AugmentationConfig {
|
||||||
|
perspective_warp?: AugmentationParams
|
||||||
|
wrinkle?: AugmentationParams
|
||||||
|
edge_damage?: AugmentationParams
|
||||||
|
stain?: AugmentationParams
|
||||||
|
lighting_variation?: AugmentationParams
|
||||||
|
shadow?: AugmentationParams
|
||||||
|
gaussian_blur?: AugmentationParams
|
||||||
|
motion_blur?: AugmentationParams
|
||||||
|
gaussian_noise?: AugmentationParams
|
||||||
|
salt_pepper?: AugmentationParams
|
||||||
|
paper_texture?: AugmentationParams
|
||||||
|
scanner_artifacts?: AugmentationParams
|
||||||
|
preserve_bboxes?: boolean
|
||||||
|
seed?: number | null
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BatchRequest {
|
||||||
|
dataset_id: string
|
||||||
|
config: AugmentationConfig
|
||||||
|
output_name: string
|
||||||
|
multiplier: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BatchResponse {
|
||||||
|
task_id: string
|
||||||
|
status: string
|
||||||
|
message: string
|
||||||
|
estimated_images: number
|
||||||
|
}
|
||||||
|
|
||||||
|
// API functions
|
||||||
|
export const augmentationApi = {
|
||||||
|
/**
|
||||||
|
* Fetch available augmentation types.
|
||||||
|
*/
|
||||||
|
async getTypes(): Promise<AugmentationTypesResponse> {
|
||||||
|
const response = await apiClient.get<AugmentationTypesResponse>(
|
||||||
|
'/api/v1/admin/augmentation/types'
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch augmentation presets.
|
||||||
|
*/
|
||||||
|
async getPresets(): Promise<PresetsResponse> {
|
||||||
|
const response = await apiClient.get<PresetsResponse>(
|
||||||
|
'/api/v1/admin/augmentation/presets'
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preview a single augmentation on a document page.
|
||||||
|
*/
|
||||||
|
async preview(
|
||||||
|
documentId: string,
|
||||||
|
request: PreviewRequest,
|
||||||
|
page: number = 1
|
||||||
|
): Promise<PreviewResponse> {
|
||||||
|
const response = await apiClient.post<PreviewResponse>(
|
||||||
|
`/api/v1/admin/augmentation/preview/${documentId}`,
|
||||||
|
request,
|
||||||
|
{ params: { page } }
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preview full augmentation config on a document page.
|
||||||
|
*/
|
||||||
|
async previewConfig(
|
||||||
|
documentId: string,
|
||||||
|
config: AugmentationConfig,
|
||||||
|
page: number = 1
|
||||||
|
): Promise<PreviewResponse> {
|
||||||
|
const response = await apiClient.post<PreviewResponse>(
|
||||||
|
`/api/v1/admin/augmentation/preview-config/${documentId}`,
|
||||||
|
config,
|
||||||
|
{ params: { page } }
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an augmented dataset.
|
||||||
|
*/
|
||||||
|
async createBatch(request: BatchRequest): Promise<BatchResponse> {
|
||||||
|
const response = await apiClient.post<BatchResponse>(
|
||||||
|
'/api/v1/admin/augmentation/batch',
|
||||||
|
request
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
}
|
||||||
52
frontend/src/api/endpoints/datasets.ts
Normal file
52
frontend/src/api/endpoints/datasets.ts
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import apiClient from '../client'
|
||||||
|
import type {
|
||||||
|
DatasetCreateRequest,
|
||||||
|
DatasetDetailResponse,
|
||||||
|
DatasetListResponse,
|
||||||
|
DatasetResponse,
|
||||||
|
DatasetTrainRequest,
|
||||||
|
TrainingTaskResponse,
|
||||||
|
} from '../types'
|
||||||
|
|
||||||
|
export const datasetsApi = {
|
||||||
|
list: async (params?: {
|
||||||
|
status?: string
|
||||||
|
limit?: number
|
||||||
|
offset?: number
|
||||||
|
}): Promise<DatasetListResponse> => {
|
||||||
|
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
|
||||||
|
params,
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
|
||||||
|
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
|
||||||
|
const { data } = await apiClient.get(
|
||||||
|
`/api/v1/admin/training/datasets/${datasetId}`
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
remove: async (datasetId: string): Promise<{ message: string }> => {
|
||||||
|
const { data } = await apiClient.delete(
|
||||||
|
`/api/v1/admin/training/datasets/${datasetId}`
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
trainFromDataset: async (
|
||||||
|
datasetId: string,
|
||||||
|
req: DatasetTrainRequest
|
||||||
|
): Promise<TrainingTaskResponse> => {
|
||||||
|
const { data } = await apiClient.post(
|
||||||
|
`/api/v1/admin/training/datasets/${datasetId}/train`,
|
||||||
|
req
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -21,14 +21,20 @@ export const documentsApi = {
|
|||||||
return data
|
return data
|
||||||
},
|
},
|
||||||
|
|
||||||
upload: async (file: File): Promise<UploadDocumentResponse> => {
|
upload: async (file: File, groupKey?: string): Promise<UploadDocumentResponse> => {
|
||||||
const formData = new FormData()
|
const formData = new FormData()
|
||||||
formData.append('file', file)
|
formData.append('file', file)
|
||||||
|
|
||||||
|
const params: Record<string, string> = {}
|
||||||
|
if (groupKey) {
|
||||||
|
params.group_key = groupKey
|
||||||
|
}
|
||||||
|
|
||||||
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'multipart/form-data',
|
'Content-Type': 'multipart/form-data',
|
||||||
},
|
},
|
||||||
|
params,
|
||||||
})
|
})
|
||||||
return data
|
return data
|
||||||
},
|
},
|
||||||
@@ -77,4 +83,16 @@ export const documentsApi = {
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
},
|
},
|
||||||
|
|
||||||
|
updateGroupKey: async (
|
||||||
|
documentId: string,
|
||||||
|
groupKey: string | null
|
||||||
|
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
|
||||||
|
const { data } = await apiClient.patch(
|
||||||
|
`/api/v1/admin/documents/${documentId}/group-key`,
|
||||||
|
null,
|
||||||
|
{ params: { group_key: groupKey } }
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,3 +2,6 @@ export { documentsApi } from './documents'
|
|||||||
export { annotationsApi } from './annotations'
|
export { annotationsApi } from './annotations'
|
||||||
export { trainingApi } from './training'
|
export { trainingApi } from './training'
|
||||||
export { inferenceApi } from './inference'
|
export { inferenceApi } from './inference'
|
||||||
|
export { datasetsApi } from './datasets'
|
||||||
|
export { augmentationApi } from './augmentation'
|
||||||
|
export { modelsApi } from './models'
|
||||||
|
|||||||
55
frontend/src/api/endpoints/models.ts
Normal file
55
frontend/src/api/endpoints/models.ts
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import apiClient from '../client'
|
||||||
|
import type {
|
||||||
|
ModelVersionListResponse,
|
||||||
|
ModelVersionDetailResponse,
|
||||||
|
ModelVersionResponse,
|
||||||
|
ActiveModelResponse,
|
||||||
|
} from '../types'
|
||||||
|
|
||||||
|
export const modelsApi = {
|
||||||
|
list: async (params?: {
|
||||||
|
status?: string
|
||||||
|
limit?: number
|
||||||
|
offset?: number
|
||||||
|
}): Promise<ModelVersionListResponse> => {
|
||||||
|
const { data } = await apiClient.get('/api/v1/admin/training/models', {
|
||||||
|
params,
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
|
||||||
|
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
getActive: async (): Promise<ActiveModelResponse> => {
|
||||||
|
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
activate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||||
|
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||||
|
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
archive: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||||
|
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
delete: async (versionId: string): Promise<{ message: string }> => {
|
||||||
|
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
|
||||||
|
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
|
||||||
|
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
|
||||||
|
return data
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ export interface DocumentItem {
|
|||||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||||
auto_label_error: string | null
|
auto_label_error: string | null
|
||||||
upload_source: string
|
upload_source: string
|
||||||
|
group_key: string | null
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
annotation_count?: number
|
annotation_count?: number
|
||||||
@@ -59,6 +60,7 @@ export interface DocumentDetailResponse {
|
|||||||
auto_label_error: string | null
|
auto_label_error: string | null
|
||||||
upload_source: string
|
upload_source: string
|
||||||
batch_id: string | null
|
batch_id: string | null
|
||||||
|
group_key: string | null
|
||||||
csv_field_values: Record<string, string> | null
|
csv_field_values: Record<string, string> | null
|
||||||
can_annotate: boolean
|
can_annotate: boolean
|
||||||
annotation_lock_until: string | null
|
annotation_lock_until: string | null
|
||||||
@@ -113,7 +115,11 @@ export interface ErrorResponse {
|
|||||||
export interface UploadDocumentResponse {
|
export interface UploadDocumentResponse {
|
||||||
document_id: string
|
document_id: string
|
||||||
filename: string
|
filename: string
|
||||||
|
file_size: number
|
||||||
|
page_count: number
|
||||||
status: string
|
status: string
|
||||||
|
group_key: string | null
|
||||||
|
auto_label_started: boolean
|
||||||
message: string
|
message: string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,3 +177,165 @@ export interface InferenceResult {
|
|||||||
export interface InferenceResponse {
|
export interface InferenceResponse {
|
||||||
result: InferenceResult
|
result: InferenceResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dataset types
|
||||||
|
|
||||||
|
export interface DatasetCreateRequest {
|
||||||
|
name: string
|
||||||
|
description?: string
|
||||||
|
document_ids: string[]
|
||||||
|
train_ratio?: number
|
||||||
|
val_ratio?: number
|
||||||
|
seed?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetResponse {
|
||||||
|
dataset_id: string
|
||||||
|
name: string
|
||||||
|
status: string
|
||||||
|
message: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetDocumentItem {
|
||||||
|
document_id: string
|
||||||
|
split: string
|
||||||
|
page_count: number
|
||||||
|
annotation_count: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetListItem {
|
||||||
|
dataset_id: string
|
||||||
|
name: string
|
||||||
|
description: string | null
|
||||||
|
status: string
|
||||||
|
training_status: string | null
|
||||||
|
active_training_task_id: string | null
|
||||||
|
total_documents: number
|
||||||
|
total_images: number
|
||||||
|
total_annotations: number
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetListResponse {
|
||||||
|
total: number
|
||||||
|
limit: number
|
||||||
|
offset: number
|
||||||
|
datasets: DatasetListItem[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetDetailResponse {
|
||||||
|
dataset_id: string
|
||||||
|
name: string
|
||||||
|
description: string | null
|
||||||
|
status: string
|
||||||
|
train_ratio: number
|
||||||
|
val_ratio: number
|
||||||
|
seed: number
|
||||||
|
total_documents: number
|
||||||
|
total_images: number
|
||||||
|
total_annotations: number
|
||||||
|
dataset_path: string | null
|
||||||
|
error_message: string | null
|
||||||
|
documents: DatasetDocumentItem[]
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AugmentationParams {
|
||||||
|
enabled: boolean
|
||||||
|
probability: number
|
||||||
|
params: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AugmentationTrainingConfig {
|
||||||
|
gaussian_noise?: AugmentationParams
|
||||||
|
perspective_warp?: AugmentationParams
|
||||||
|
wrinkle?: AugmentationParams
|
||||||
|
edge_damage?: AugmentationParams
|
||||||
|
stain?: AugmentationParams
|
||||||
|
lighting_variation?: AugmentationParams
|
||||||
|
shadow?: AugmentationParams
|
||||||
|
gaussian_blur?: AugmentationParams
|
||||||
|
motion_blur?: AugmentationParams
|
||||||
|
salt_pepper?: AugmentationParams
|
||||||
|
paper_texture?: AugmentationParams
|
||||||
|
scanner_artifacts?: AugmentationParams
|
||||||
|
preserve_bboxes?: boolean
|
||||||
|
seed?: number | null
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatasetTrainRequest {
|
||||||
|
name: string
|
||||||
|
config: {
|
||||||
|
model_name?: string
|
||||||
|
base_model_version_id?: string | null
|
||||||
|
epochs?: number
|
||||||
|
batch_size?: number
|
||||||
|
image_size?: number
|
||||||
|
learning_rate?: number
|
||||||
|
device?: string
|
||||||
|
augmentation?: AugmentationTrainingConfig
|
||||||
|
augmentation_multiplier?: number
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TrainingTaskResponse {
|
||||||
|
task_id: string
|
||||||
|
status: string
|
||||||
|
message: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model Version types
|
||||||
|
|
||||||
|
export interface ModelVersionItem {
|
||||||
|
version_id: string
|
||||||
|
version: string
|
||||||
|
name: string
|
||||||
|
status: string
|
||||||
|
is_active: boolean
|
||||||
|
metrics_mAP: number | null
|
||||||
|
document_count: number
|
||||||
|
trained_at: string | null
|
||||||
|
activated_at: string | null
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelVersionDetailResponse {
|
||||||
|
version_id: string
|
||||||
|
version: string
|
||||||
|
name: string
|
||||||
|
description: string | null
|
||||||
|
model_path: string
|
||||||
|
status: string
|
||||||
|
is_active: boolean
|
||||||
|
task_id: string | null
|
||||||
|
dataset_id: string | null
|
||||||
|
metrics_mAP: number | null
|
||||||
|
metrics_precision: number | null
|
||||||
|
metrics_recall: number | null
|
||||||
|
document_count: number
|
||||||
|
training_config: Record<string, unknown> | null
|
||||||
|
file_size: number | null
|
||||||
|
trained_at: string | null
|
||||||
|
activated_at: string | null
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelVersionListResponse {
|
||||||
|
total: number
|
||||||
|
limit: number
|
||||||
|
offset: number
|
||||||
|
models: ModelVersionItem[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelVersionResponse {
|
||||||
|
version_id: string
|
||||||
|
status: string
|
||||||
|
message: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ActiveModelResponse {
|
||||||
|
has_active_model: boolean
|
||||||
|
model: ModelVersionItem | null
|
||||||
|
}
|
||||||
|
|||||||
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
/**
|
||||||
|
* Tests for AugmentationConfig component.
|
||||||
|
*
|
||||||
|
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||||
|
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||||
|
import { AugmentationConfig } from './AugmentationConfig'
|
||||||
|
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||||
|
import type { ReactNode } from 'react'
|
||||||
|
|
||||||
|
// Mock the API
|
||||||
|
vi.mock('../api/endpoints/augmentation', () => ({
|
||||||
|
augmentationApi: {
|
||||||
|
getTypes: vi.fn(),
|
||||||
|
getPresets: vi.fn(),
|
||||||
|
preview: vi.fn(),
|
||||||
|
previewConfig: vi.fn(),
|
||||||
|
createBatch: vi.fn(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Default mock data
|
||||||
|
const mockTypes = {
|
||||||
|
augmentation_types: [
|
||||||
|
{
|
||||||
|
name: 'gaussian_noise',
|
||||||
|
description: 'Adds Gaussian noise to simulate sensor noise',
|
||||||
|
affects_geometry: false,
|
||||||
|
stage: 'noise',
|
||||||
|
default_params: { mean: 0, std: 15 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'perspective_warp',
|
||||||
|
description: 'Applies perspective transformation',
|
||||||
|
affects_geometry: true,
|
||||||
|
stage: 'geometric',
|
||||||
|
default_params: { max_warp: 0.02 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'gaussian_blur',
|
||||||
|
description: 'Applies Gaussian blur',
|
||||||
|
affects_geometry: false,
|
||||||
|
stage: 'blur',
|
||||||
|
default_params: { kernel_size: 5 },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockPresets = {
|
||||||
|
presets: [
|
||||||
|
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
|
||||||
|
{ name: 'moderate', description: 'Balanced augmentation settings' },
|
||||||
|
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test wrapper with QueryClient
|
||||||
|
const createWrapper = () => {
|
||||||
|
const queryClient = new QueryClient({
|
||||||
|
defaultOptions: {
|
||||||
|
queries: {
|
||||||
|
retry: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return ({ children }: { children: ReactNode }) => (
|
||||||
|
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('AugmentationConfig', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
|
||||||
|
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('rendering', () => {
|
||||||
|
it('should render enable checkbox', async () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={false}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should be collapsed when disabled', () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={false}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config options should not be visible
|
||||||
|
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should expand when enabled', async () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/preset/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('preset selection', () => {
|
||||||
|
it('should display available presets', async () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('conservative')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('aggressive')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should call onConfigChange when preset is selected', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onConfigChange = vi.fn()
|
||||||
|
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={onConfigChange}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
await user.click(screen.getByText('moderate'))
|
||||||
|
|
||||||
|
expect(onConfigChange).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('enable toggle', () => {
|
||||||
|
it('should call onEnabledChange when checkbox is toggled', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onEnabledChange = vi.fn()
|
||||||
|
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={false}
|
||||||
|
onEnabledChange={onEnabledChange}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
|
||||||
|
|
||||||
|
expect(onEnabledChange).toHaveBeenCalledWith(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('augmentation types', () => {
|
||||||
|
it('should display augmentation types when in custom mode', async () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
showCustomOptions={true}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
|
||||||
|
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should indicate which augmentations affect geometry', async () => {
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
showCustomOptions={true}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
// perspective_warp affects geometry
|
||||||
|
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
|
||||||
|
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('loading state', () => {
|
||||||
|
it('should show loading indicator while fetching types', () => {
|
||||||
|
vi.mocked(augmentationApi.getTypes).mockImplementation(
|
||||||
|
() => new Promise(() => {})
|
||||||
|
)
|
||||||
|
|
||||||
|
render(
|
||||||
|
<AugmentationConfig
|
||||||
|
enabled={true}
|
||||||
|
onEnabledChange={vi.fn()}
|
||||||
|
config={{}}
|
||||||
|
onConfigChange={vi.fn()}
|
||||||
|
/>,
|
||||||
|
{ wrapper: createWrapper() }
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
136
frontend/src/components/AugmentationConfig.tsx
Normal file
136
frontend/src/components/AugmentationConfig.tsx
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
/**
|
||||||
|
* AugmentationConfig component for configuring image augmentation during training.
|
||||||
|
*
|
||||||
|
* Provides preset selection and optional custom augmentation type configuration.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import React from 'react'
|
||||||
|
import { Loader2, AlertTriangle } from 'lucide-react'
|
||||||
|
import { useAugmentation } from '../hooks/useAugmentation'
|
||||||
|
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||||
|
|
||||||
|
interface AugmentationConfigProps {
|
||||||
|
enabled: boolean
|
||||||
|
onEnabledChange: (enabled: boolean) => void
|
||||||
|
config: Partial<AugmentationConfigType>
|
||||||
|
onConfigChange: (config: Partial<AugmentationConfigType>) => void
|
||||||
|
showCustomOptions?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
|
||||||
|
enabled,
|
||||||
|
onEnabledChange,
|
||||||
|
config,
|
||||||
|
onConfigChange,
|
||||||
|
showCustomOptions = false,
|
||||||
|
}) => {
|
||||||
|
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
|
||||||
|
|
||||||
|
const isLoading = isLoadingTypes || isLoadingPresets
|
||||||
|
|
||||||
|
const handlePresetSelect = (presetName: string) => {
|
||||||
|
const preset = presets.find((p) => p.name === presetName)
|
||||||
|
if (preset && preset.config) {
|
||||||
|
onConfigChange(preset.config as Partial<AugmentationConfigType>)
|
||||||
|
} else {
|
||||||
|
// Apply a basic config based on preset name
|
||||||
|
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
|
||||||
|
conservative: {
|
||||||
|
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
|
||||||
|
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
|
||||||
|
},
|
||||||
|
moderate: {
|
||||||
|
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
|
||||||
|
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
|
||||||
|
lighting_variation: { enabled: true, probability: 0.3, params: {} },
|
||||||
|
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
|
||||||
|
},
|
||||||
|
aggressive: {
|
||||||
|
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
|
||||||
|
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
|
||||||
|
motion_blur: { enabled: true, probability: 0.3, params: {} },
|
||||||
|
lighting_variation: { enabled: true, probability: 0.5, params: {} },
|
||||||
|
shadow: { enabled: true, probability: 0.3, params: {} },
|
||||||
|
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
|
||||||
|
wrinkle: { enabled: true, probability: 0.2, params: {} },
|
||||||
|
stain: { enabled: true, probability: 0.2, params: {} },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
onConfigChange(presetConfigs[presetName] || {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
|
||||||
|
{/* Enable checkbox */}
|
||||||
|
<label className="flex items-center gap-2 cursor-pointer">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={enabled}
|
||||||
|
onChange={(e) => onEnabledChange(e.target.checked)}
|
||||||
|
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
|
||||||
|
aria-label="Enable augmentation"
|
||||||
|
/>
|
||||||
|
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
|
||||||
|
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
{/* Expanded content when enabled */}
|
||||||
|
{enabled && (
|
||||||
|
<div className="mt-4 space-y-4">
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
|
||||||
|
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
|
||||||
|
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
{/* Preset selection */}
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
|
||||||
|
<div className="flex flex-wrap gap-2">
|
||||||
|
{presets.map((preset) => (
|
||||||
|
<button
|
||||||
|
key={preset.name}
|
||||||
|
onClick={() => handlePresetSelect(preset.name)}
|
||||||
|
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
|
||||||
|
title={preset.description}
|
||||||
|
>
|
||||||
|
{preset.name}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Custom options (if enabled) */}
|
||||||
|
{showCustomOptions && (
|
||||||
|
<div className="border-t border-warm-divider pt-4">
|
||||||
|
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
|
||||||
|
<div className="grid gap-2">
|
||||||
|
{augmentationTypes.map((type) => (
|
||||||
|
<div
|
||||||
|
key={type.name}
|
||||||
|
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm text-warm-text-primary">{type.name}</span>
|
||||||
|
{type.affects_geometry && (
|
||||||
|
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
|
||||||
|
<AlertTriangle size={12} />
|
||||||
|
affects bbox
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<span className="text-xs text-warm-text-muted">{type.stage}</span>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -144,6 +144,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
|||||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||||
Annotations
|
Annotations
|
||||||
</th>
|
</th>
|
||||||
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||||
|
Group
|
||||||
|
</th>
|
||||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
||||||
Auto-label
|
Auto-label
|
||||||
</th>
|
</th>
|
||||||
@@ -153,13 +156,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
|||||||
<tbody>
|
<tbody>
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
<tr>
|
<tr>
|
||||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
|
||||||
Loading documents...
|
Loading documents...
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
) : documents.length === 0 ? (
|
) : documents.length === 0 ? (
|
||||||
<tr>
|
<tr>
|
||||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
|
||||||
No documents found. Upload your first document to get started.
|
No documents found. Upload your first document to get started.
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
@@ -213,6 +216,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
|||||||
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
||||||
{doc.annotation_count || 0} annotations
|
{doc.annotation_count || 0} annotations
|
||||||
</td>
|
</td>
|
||||||
|
<td className="py-4 px-4 text-sm text-warm-text-muted">
|
||||||
|
{doc.group_key || '-'}
|
||||||
|
</td>
|
||||||
<td className="py-4 px-4">
|
<td className="py-4 px-4">
|
||||||
{doc.auto_label_status === 'running' && progress && (
|
{doc.auto_label_status === 'running' && progress && (
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
|
|||||||
122
frontend/src/components/DatasetDetail.tsx
Normal file
122
frontend/src/components/DatasetDetail.tsx
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
|
||||||
|
import { Button } from './Button'
|
||||||
|
import { useDatasetDetail } from '../hooks/useDatasets'
|
||||||
|
|
||||||
|
interface DatasetDetailProps {
|
||||||
|
datasetId: string
|
||||||
|
onBack: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const SPLIT_STYLES: Record<string, string> = {
|
||||||
|
train: 'bg-warm-state-info/10 text-warm-state-info',
|
||||||
|
val: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||||
|
test: 'bg-warm-state-success/10 text-warm-state-success',
|
||||||
|
}
|
||||||
|
|
||||||
|
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
|
||||||
|
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center py-20 text-warm-text-muted">
|
||||||
|
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error || !dataset) {
|
||||||
|
return (
|
||||||
|
<div className="p-8 max-w-7xl mx-auto">
|
||||||
|
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||||
|
<ArrowLeft size={16} />Back
|
||||||
|
</button>
|
||||||
|
<p className="text-warm-state-error">Failed to load dataset.</p>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusIcon = dataset.status === 'ready'
|
||||||
|
? <Check size={14} className="text-warm-state-success" />
|
||||||
|
: dataset.status === 'failed'
|
||||||
|
? <AlertCircle size={14} className="text-warm-state-error" />
|
||||||
|
: <Loader2 size={14} className="animate-spin text-warm-state-info" />
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="p-8 max-w-7xl mx-auto">
|
||||||
|
{/* Header */}
|
||||||
|
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||||
|
<ArrowLeft size={16} />Back to Datasets
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div className="flex items-center justify-between mb-6">
|
||||||
|
<div>
|
||||||
|
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
|
||||||
|
{dataset.name} {statusIcon}
|
||||||
|
</h2>
|
||||||
|
{dataset.description && (
|
||||||
|
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{dataset.status === 'ready' && (
|
||||||
|
<Button><Play size={14} className="mr-1" />Start Training</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{dataset.error_message && (
|
||||||
|
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
|
||||||
|
{dataset.error_message}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Stats */}
|
||||||
|
<div className="grid grid-cols-4 gap-4 mb-8">
|
||||||
|
{[
|
||||||
|
['Documents', dataset.total_documents],
|
||||||
|
['Images', dataset.total_images],
|
||||||
|
['Annotations', dataset.total_annotations],
|
||||||
|
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
|
||||||
|
].map(([label, value]) => (
|
||||||
|
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||||
|
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
|
||||||
|
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Document list */}
|
||||||
|
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
|
||||||
|
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||||
|
<table className="w-full text-left">
|
||||||
|
<thead className="bg-white border-b border-warm-border">
|
||||||
|
<tr>
|
||||||
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||||
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
|
||||||
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||||
|
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{dataset.documents.map(doc => (
|
||||||
|
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||||
|
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||||
|
<td className="py-3 px-4">
|
||||||
|
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||||
|
{doc.split}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||||
|
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p className="text-xs text-warm-text-muted mt-4">
|
||||||
|
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
|
||||||
|
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
import React, { useState, useRef, useEffect } from 'react'
|
import React, { useState, useRef, useEffect } from 'react'
|
||||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
|
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
|
||||||
import { Button } from './Button'
|
import { Button } from './Button'
|
||||||
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
||||||
import { useAnnotations } from '../hooks/useAnnotations'
|
import { useAnnotations } from '../hooks/useAnnotations'
|
||||||
|
import { useDocuments } from '../hooks/useDocuments'
|
||||||
import { documentsApi } from '../api/endpoints/documents'
|
import { documentsApi } from '../api/endpoints/documents'
|
||||||
import type { AnnotationItem } from '../api/types'
|
import type { AnnotationItem } from '../api/types'
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ const FIELD_CLASSES: Record<number, string> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
||||||
const { document, annotations, isLoading } = useDocumentDetail(docId)
|
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
|
||||||
const {
|
const {
|
||||||
createAnnotation,
|
createAnnotation,
|
||||||
updateAnnotation,
|
updateAnnotation,
|
||||||
@@ -34,10 +35,13 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
|
|||||||
isCreating,
|
isCreating,
|
||||||
isDeleting,
|
isDeleting,
|
||||||
} = useAnnotations(docId)
|
} = useAnnotations(docId)
|
||||||
|
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
|
||||||
|
|
||||||
const [selectedId, setSelectedId] = useState<string | null>(null)
|
const [selectedId, setSelectedId] = useState<string | null>(null)
|
||||||
const [zoom, setZoom] = useState(100)
|
const [zoom, setZoom] = useState(100)
|
||||||
const [isDrawing, setIsDrawing] = useState(false)
|
const [isDrawing, setIsDrawing] = useState(false)
|
||||||
|
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
|
||||||
|
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
|
||||||
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
||||||
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
||||||
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
||||||
@@ -426,6 +430,65 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
|
|||||||
{new Date(document.created_at).toLocaleDateString()}
|
{new Date(document.created_at).toLocaleDateString()}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
<div className="flex justify-between items-center text-xs">
|
||||||
|
<span className="text-warm-text-muted">Group</span>
|
||||||
|
{isEditingGroupKey ? (
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={editGroupKeyValue}
|
||||||
|
onChange={(e) => setEditGroupKeyValue(e.target.value)}
|
||||||
|
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||||
|
placeholder="group key"
|
||||||
|
autoFocus
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
updateGroupKey(
|
||||||
|
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
|
||||||
|
{
|
||||||
|
onSuccess: () => {
|
||||||
|
setIsEditingGroupKey(false)
|
||||||
|
refetch()
|
||||||
|
},
|
||||||
|
onError: () => {
|
||||||
|
alert('Failed to update group key. Please try again.')
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}}
|
||||||
|
disabled={isUpdatingGroupKey}
|
||||||
|
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
|
||||||
|
>
|
||||||
|
<Check size={14} />
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
setIsEditingGroupKey(false)
|
||||||
|
setEditGroupKeyValue(document.group_key || '')
|
||||||
|
}}
|
||||||
|
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
|
||||||
|
>
|
||||||
|
<X size={14} />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<span className="text-warm-text-secondary font-medium">
|
||||||
|
{document.group_key || '-'}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
setEditGroupKeyValue(document.group_key || '')
|
||||||
|
setIsEditingGroupKey(true)
|
||||||
|
}}
|
||||||
|
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
|
||||||
|
>
|
||||||
|
<Edit2 size={12} />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,73 +1,108 @@
|
|||||||
import React from 'react';
|
import React, { useState } from 'react';
|
||||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
||||||
|
import { Loader2, Power, CheckCircle } from 'lucide-react';
|
||||||
import { Button } from './Button';
|
import { Button } from './Button';
|
||||||
|
import { useModels, useModelDetail } from '../hooks';
|
||||||
|
import type { ModelVersionItem } from '../api/types';
|
||||||
|
|
||||||
const CHART_DATA = [
|
const formatDate = (dateString: string | null): string => {
|
||||||
{ name: 'Model A', value: 75 },
|
if (!dateString) return 'N/A';
|
||||||
{ name: 'Model B', value: 82 },
|
return new Date(dateString).toLocaleString();
|
||||||
{ name: 'Model C', value: 95 },
|
};
|
||||||
{ name: 'Model D', value: 68 },
|
|
||||||
];
|
|
||||||
|
|
||||||
const METRICS_DATA = [
|
|
||||||
{ name: 'Precision', value: 88 },
|
|
||||||
{ name: 'Recall', value: 76 },
|
|
||||||
{ name: 'F1 Score', value: 91 },
|
|
||||||
{ name: 'Accuracy', value: 82 },
|
|
||||||
];
|
|
||||||
|
|
||||||
const JOBS = [
|
|
||||||
{ id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
|
|
||||||
(Replaces the remaining hard-coded JOBS mock entries and the static job-history cards with live model versions from useModels/useModelDetail.)

export const Models: React.FC = () => {
  const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
  const { models, isLoading, activateModel, isActivating } = useModels();
  const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);

  // Build chart data from selected model's metrics
  const metricsData = modelDetail ? [
    { name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
    { name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
    { name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
  ] : [
    { name: 'Precision', value: 0 },
    { name: 'Recall', value: 0 },
    { name: 'mAP', value: 0 },
  ];

  // Build comparison chart from all models (with placeholder if empty)
  const chartData = models.length > 0
    ? models.slice(0, 4).map(m => ({
        name: m.version,
        value: (m.metrics_mAP ?? 0) * 100,
      }))
    : [
        { name: 'Model A', value: 0 },
        { name: 'Model B', value: 0 },
        { name: 'Model C', value: 0 },
        { name: 'Model D', value: 0 },
      ];

  return (
    <div className="p-8 max-w-7xl mx-auto flex gap-8">
      {/* Left: Job History */}
      <div className="flex-1">
        <h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
        <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>

        {isLoading ? (
          <div className="flex items-center justify-center py-12">
            <Loader2 className="animate-spin text-warm-text-muted" size={32} />
          </div>
        ) : models.length === 0 ? (
          <div className="text-center py-12 text-warm-text-muted">
            No model versions found. Complete a training task to create a model version.
          </div>
        ) : (
          <div className="space-y-4">
            {models.map(model => (
              <div
                key={model.version_id}
                onClick={() => setSelectedModel(model)}
                className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
                  selectedModel?.version_id === model.version_id
                    ? 'border-warm-text-secondary'
                    : 'border-warm-border hover:border-warm-divider'
                }`}
              >
                <div className="flex justify-between items-start mb-2">
                  <div>
                    <h4 className="font-semibold text-warm-text-primary text-lg mb-1">
                      {model.name}
                      {model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
                    </h4>
                    <p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
                  </div>
                  <span className={`px-3 py-1 rounded-full text-xs font-medium ${
                    model.is_active
                      ? 'bg-warm-state-info/10 text-warm-state-info'
                      : 'bg-warm-selected text-warm-state-success'
                  }`}>
                    {model.is_active ? 'Active' : model.status}
                  </span>
                </div>

                <div className="mt-4 flex gap-8">
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
                    <span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
                  </div>
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
                    <span className="text-lg font-mono text-warm-text-secondary">
                      {model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
                    </span>
                  </div>
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
                    <span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
                  </div>
                </div>
              </div>
            ))}
          </div>
        )}
      </div>

      {/* Right: Model Detail */}

@@ -75,27 +110,34 @@ export const Models: React.FC = () => {
        <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
          <div className="flex justify-between items-center mb-6">
            <h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
            <span className={`text-sm font-medium ${
              selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
            }`}>
              {selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
            </span>
          </div>

          <div className="mb-8">
            <p className="text-sm text-warm-text-muted mb-1">Model name</p>
            <p className="font-medium text-warm-text-primary">
              {selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
            </p>
          </div>

          <div className="space-y-8">
            {/* Chart 1 */}
            <div>
              <h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
              <div className="h-40">
                <ResponsiveContainer width="100%" height="100%">
                  <BarChart data={chartData}>
                    <CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
                    <XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
                    <YAxis hide domain={[0, 100]} />
                    <Tooltip
                      cursor={{fill: '#F1F0ED'}}
                      contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
                      formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
                    />
                    <Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
                  </BarChart>

@@ -105,14 +147,17 @@ export const Models: React.FC = () => {
            {/* Chart 2 */}
            <div>
              <h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
              <div className="h-40">
                <ResponsiveContainer width="100%" height="100%">
                  <BarChart data={metricsData}>
                    <CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
                    <XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
                    <YAxis hide domain={[0, 100]} />
                    <Tooltip
                      cursor={{fill: '#F1F0ED'}}
                      formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
                    />
                    <Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
                  </BarChart>
                </ResponsiveContainer>

@@ -121,10 +166,39 @@ export const Models: React.FC = () => {
          </div>

          <div className="mt-8 space-y-3">
            {selectedModel && !selectedModel.is_active ? (
              <Button
                className="w-full"
                onClick={() => activateModel(selectedModel.version_id)}
                disabled={isActivating}
              >
                {isActivating ? (
                  <>
                    <Loader2 size={16} className="mr-2 animate-spin" />
                    Activating...
                  </>
                ) : (
                  <>
                    <Power size={16} className="mr-2" />
                    Activate for Inference
                  </>
                )}
              </Button>
            ) : (
              <Button className="w-full" disabled={!selectedModel}>
                {selectedModel?.is_active ? (
                  <>
                    <CheckCircle size={16} className="mr-2" />
                    Currently Active
                  </>
                ) : (
                  'Select a Model'
                )}
              </Button>
            )}
            <div className="flex gap-3">
              <Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
              <Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
            </div>
          </div>
        </div>
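For readers of this hunk, a sketch of the fields the component reads from each model version item; this interface is inferred from the usage above and is not the committed declaration in ../api/types:

// Inferred from usage in the Models component above; illustrative only.
interface ModelVersionItemSketch {
  version_id: string
  name: string
  version: string
  status: string
  is_active: boolean
  trained_at: string          // passed to formatDate(...)
  document_count: number
  metrics_mAP?: number | null // 0-1, rendered as a percentage
}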
@@ -1,113 +1,482 @@
(Replaces the previous Training screen, which used a static document-selection table, a train/test split slider, and a LayoutLMv3/Donut base-model picker, with the dataset-based workflow below.)

import React, { useState, useMemo } from 'react'
import { useQuery } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { AugmentationConfig } from './AugmentationConfig'
import { useDatasets } from '../hooks/useDatasets'
import { useTrainingDocuments } from '../hooks/useTraining'
import { trainingApi } from '../api/endpoints'
import type { DatasetListItem } from '../api/types'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'

type Tab = 'datasets' | 'create'

interface TrainingProps {
  onNavigate?: (view: string, id?: string) => void
}

const STATUS_STYLES: Record<string, string> = {
  ready: 'bg-warm-state-success/10 text-warm-state-success',
  building: 'bg-warm-state-info/10 text-warm-state-info',
  training: 'bg-warm-state-info/10 text-warm-state-info',
  failed: 'bg-warm-state-error/10 text-warm-state-error',
  pending: 'bg-warm-state-warning/10 text-warm-state-warning',
  scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
  running: 'bg-warm-state-info/10 text-warm-state-info',
}

const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
  // If there's an active training task, show training status
  const displayStatus = trainingStatus === 'running'
    ? 'training'
    : trainingStatus === 'pending' || trainingStatus === 'scheduled'
      ? 'pending'
      : status

  return (
    <span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
      {(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
      {displayStatus === 'ready' && <Check size={12} className="mr-1" />}
      {displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
      {displayStatus}
    </span>
  )
}

// --- Train Dialog ---

interface TrainDialogProps {
  dataset: DatasetListItem
  onClose: () => void
  onSubmit: (config: {
    name: string
    config: {
      model_name?: string
      base_model_version_id?: string | null
      epochs: number
      batch_size: number
      augmentation?: AugmentationConfigType
      augmentation_multiplier?: number
    }
  }) => void
  isPending: boolean
}

const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
  const [name, setName] = useState(`train-${dataset.name}`)
  const [epochs, setEpochs] = useState(100)
  const [batchSize, setBatchSize] = useState(16)
  const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
  const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
  const [augmentationEnabled, setAugmentationEnabled] = useState(false)
  const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
  const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)

  // Fetch available trained models
  const { data: modelsData } = useQuery({
    queryKey: ['training', 'models', 'completed'],
    queryFn: () => trainingApi.getModels({ status: 'completed' }),
  })
  const completedModels = modelsData?.models ?? []

  const handleSubmit = () => {
    onSubmit({
      name,
      config: {
        model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
        base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
        epochs,
        batch_size: batchSize,
        augmentation: augmentationEnabled
          ? (augmentationConfig as AugmentationConfigType)
          : undefined,
        augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
      },
    })
  }

  return (
    <div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
      <div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
        <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
        <p className="text-sm text-warm-text-muted mb-4">
          Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
          {' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
        </p>

        <div className="space-y-4">
          <div>
            <label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
            <input type="text" value={name} onChange={e => setName(e.target.value)}
              className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
          </div>

          {/* Base Model Selection */}
          <div>
            <label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
            <select
              value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
              onChange={e => {
                if (e.target.value === 'pretrained') {
                  setBaseModelType('pretrained')
                  setBaseModelVersionId(null)
                } else {
                  setBaseModelType('existing')
                  setBaseModelVersionId(e.target.value)
                }
              }}
              className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
            >
              <option value="pretrained">yolo11n.pt (Pretrained)</option>
              {completedModels.map(m => (
                <option key={m.task_id} value={m.task_id}>
                  {m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
                </option>
              ))}
            </select>
            <p className="text-xs text-warm-text-muted mt-1">
              {baseModelType === 'pretrained'
                ? 'Start from pretrained YOLO model'
                : 'Continue training from an existing model (incremental training)'}
            </p>
          </div>

          <div className="flex gap-4">
            <div className="flex-1">
              <label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
              <input
                id="train-epochs"
                type="number"
                min={1}
                max={1000}
                value={epochs}
                onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
              />
            </div>
            <div className="flex-1">
              <label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
              <input
                id="train-batch-size"
                type="number"
                min={1}
                max={128}
                value={batchSize}
                onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
              />
            </div>
          </div>

          {/* Augmentation Configuration */}
          <AugmentationConfig
            enabled={augmentationEnabled}
            onEnabledChange={setAugmentationEnabled}
            config={augmentationConfig}
            onConfigChange={setAugmentationConfig}
          />

          {/* Augmentation Multiplier - only shown when augmentation is enabled */}
          {augmentationEnabled && (
            <div>
              <label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
                Augmentation Multiplier
              </label>
              <input
                id="aug-multiplier"
                type="number"
                min={1}
                max={10}
                value={augmentationMultiplier}
                onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
              />
              <p className="text-xs text-warm-text-muted mt-1">
                Number of augmented copies per original image (1-10)
              </p>
            </div>
          )}
        </div>

        <div className="flex justify-end gap-3 mt-6">
          <Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
          <Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
            {isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
          </Button>
        </div>
      </div>
    </div>
  )
}

// --- Dataset List ---

const DatasetList: React.FC<{
  onNavigate?: (view: string, id?: string) => void
  onSwitchTab: (tab: Tab) => void
}> = ({ onNavigate, onSwitchTab }) => {
  const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
  const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)

  const handleTrain = (config: {
    name: string
    config: {
      model_name?: string
      base_model_version_id?: string | null
      epochs: number
      batch_size: number
      augmentation?: AugmentationConfigType
      augmentation_multiplier?: number
    }
  }) => {
    if (!trainTarget) return
    // Pass config to the training API
    const trainRequest = {
      name: config.name,
      config: config.config,
    }
    trainFromDataset(
      { datasetId: trainTarget.dataset_id, req: trainRequest },
      { onSuccess: () => setTrainTarget(null) },
    )
  }

  if (isLoading) {
    return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
  }

  if (datasets.length === 0) {
    return (
      <div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
        <Database size={48} className="mb-4 opacity-40" />
        <p className="text-lg mb-2">No datasets yet</p>
        <p className="text-sm mb-4">Create a dataset to start training</p>
        <Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
      </div>
    )
  }

  return (
    <>
      <div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
        <table className="w-full text-left">
          <thead className="bg-white border-b border-warm-border">
            <tr>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
            </tr>
          </thead>
          <tbody>
            {datasets.map(ds => (
              <tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
                <td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
                <td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
                <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
                <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
                <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
                <td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
                <td className="py-3 px-4">
                  <div className="flex gap-1">
                    <button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
                      className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
                      <Eye size={14} />
                    </button>
                    {ds.status === 'ready' && (
                      <button title="Train" onClick={() => setTrainTarget(ds)}
                        className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
                        <Play size={14} />
                      </button>
                    )}
                    <button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
                      disabled={isDeleting}
                      className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
                      <Trash2 size={14} />
                    </button>
                  </div>
                </td>
              </tr>
            ))}
          </tbody>
        </table>
      </div>

      {trainTarget && (
        <TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
      )}
    </>
  )
}

// --- Create Dataset ---

const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
  const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
  const { createDatasetAsync, isCreating } = useDatasets()

  const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
  const [name, setName] = useState('')
  const [description, setDescription] = useState('')
  const [trainRatio, setTrainRatio] = useState(0.7)
  const [valRatio, setValRatio] = useState(0.2)

  const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])

  const toggleDoc = (id: string) => {
    setSelectedIds(prev => {
      const next = new Set(prev)
      if (next.has(id)) { next.delete(id) } else { next.add(id) }
      return next
    })
  }

  const toggleAll = () => {
    if (selectedIds.size === documents.length) {
      setSelectedIds(new Set())
    } else {
      setSelectedIds(new Set(documents.map((d) => d.document_id)))
    }
  }

  const handleCreate = async () => {
    await createDatasetAsync({
      name,
      description: description || undefined,
      document_ids: [...selectedIds],
      train_ratio: trainRatio,
      val_ratio: valRatio,
    })
    onSwitchTab('datasets')
  }

  return (
    <div className="flex gap-8">
      {/* Document selection */}
      <div className="flex-1 flex flex-col">
        <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
        {isLoadingDocs ? (
          <div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
        ) : (
          <div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
            <div className="overflow-auto max-h-[calc(100vh-240px)]">
              <table className="w-full text-left">
                <thead className="sticky top-0 bg-white border-b border-warm-border z-10">
                  <tr>
                    <th className="py-3 pl-6 pr-4 w-12">
                      <input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
                        onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
                    </th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
                  </tr>
                </thead>
                <tbody>
                  {documents.map((doc) => (
                    <tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
                      onClick={() => toggleDoc(doc.document_id)}>
                      <td className="py-3 pl-6 pr-4">
                        <input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
                          className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
                      </td>
                      <td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
                      <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
                      <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          </div>
        )}
        <p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
      </div>

      {/* Config panel */}
      <div className="w-80">
        <div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
          <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
          <div className="space-y-4">
            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
              <input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
            </div>
            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
              <textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
                className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
            </div>
            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
              <div className="flex gap-2 text-sm">
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Train</span>
                  <input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
                </div>
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Val</span>
                  <input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
                </div>
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Test</span>
                  <input type="number" value={testRatio} readOnly
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
                </div>
              </div>
            </div>

            <div className="pt-4 border-t border-warm-border">
              {selectedIds.size > 0 && selectedIds.size < 10 && (
                <p className="text-xs text-warm-state-warning mb-2">
                  Minimum 10 documents required for training ({selectedIds.size}/10 selected)
                </p>
              )}
              <Button className="w-full h-11" onClick={handleCreate}
                disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
                {isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
              </Button>
            </div>
          </div>
        </div>
      </div>
    </div>
  )
}

// --- Main Training Component ---

export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
  const [activeTab, setActiveTab] = useState<Tab>('datasets')

  return (
    <div className="p-8 max-w-7xl mx-auto">
      <div className="flex items-center justify-between mb-6">
        <h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
      </div>

      {/* Tabs */}
      <div className="flex gap-1 mb-6 border-b border-warm-border">
        {([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
          <button key={key} onClick={() => setActiveTab(key)}
            className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
              activeTab === key
                ? 'border-warm-state-info text-warm-state-info'
                : 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
            }`}>
            {label}
          </button>
        ))}
      </div>

      {activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
      {activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
    </div>
  )
}
@@ -11,6 +11,7 @@ interface UploadModalProps {
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
  const [isDragging, setIsDragging] = useState(false)
  const [selectedFiles, setSelectedFiles] = useState<File[]>([])
  const [groupKey, setGroupKey] = useState('')
  const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
  const [errorMessage, setErrorMessage] = useState('')
  const fileInputRef = useRef<HTMLInputElement>(null)

@@ -61,10 +62,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
      // Upload files one by one
      for (const file of selectedFiles) {
        await new Promise<void>((resolve, reject) => {
          uploadDocument(
            { file, groupKey: groupKey || undefined },
            {
              onSuccess: () => resolve(),
              onError: (error: Error) => reject(error),
            }
          )
        })
      }

@@ -72,6 +76,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
      setTimeout(() => {
        onClose()
        setSelectedFiles([])
        setGroupKey('')
        setUploadStatus('idle')
      }, 1500)
    } catch (error) {

@@ -85,6 +90,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
      return // Prevent closing during upload
    }
    setSelectedFiles([])
    setGroupKey('')
    setUploadStatus('idle')
    setErrorMessage('')
    onClose()

@@ -173,6 +179,26 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
        </div>
      )}

      {/* Group Key Input */}
      {selectedFiles.length > 0 && (
        <div className="mb-6">
          <label className="block text-sm font-medium text-warm-text-secondary mb-2">
            Group Key (optional)
          </label>
          <input
            type="text"
            value={groupKey}
            onChange={(e) => setGroupKey(e.target.value)}
            placeholder="e.g., 2024-Q1, supplier-abc, project-name"
            className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
            disabled={uploadStatus === 'uploading'}
          />
          <p className="text-xs text-warm-text-muted mt-1">
            Use group keys to organize documents into logical groups
          </p>
        </div>
      )}

      {/* Status Messages */}
      {uploadStatus === 'success' && (
        <div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
@@ -2,3 +2,6 @@ export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'
export { useDatasets, useDatasetDetail } from './useDatasets'
export { useAugmentation } from './useAugmentation'
export { useModels, useModelDetail, useActiveModel } from './useModels'
frontend/src/hooks/useAugmentation.test.tsx (new file, 226 lines)
@@ -0,0 +1,226 @@
/**
 * Tests for useAugmentation hook.
 *
 * TDD Phase 1: RED - Write tests first, then implement to pass.
 */

import { describe, it, expect, vi, beforeEach } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { augmentationApi } from '../api/endpoints/augmentation'
import { useAugmentation } from './useAugmentation'
import type { ReactNode } from 'react'

// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
  augmentationApi: {
    getTypes: vi.fn(),
    getPresets: vi.fn(),
    preview: vi.fn(),
    previewConfig: vi.fn(),
    createBatch: vi.fn(),
  },
}))

// Test wrapper with QueryClient
const createWrapper = () => {
  const queryClient = new QueryClient({
    defaultOptions: {
      queries: {
        retry: false,
      },
    },
  })
  return ({ children }: { children: ReactNode }) => (
    <QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
  )
}

describe('useAugmentation', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  describe('getTypes', () => {
    it('should fetch augmentation types', async () => {
      const mockTypes = {
        augmentation_types: [
          {
            name: 'gaussian_noise',
            description: 'Adds Gaussian noise',
            affects_geometry: false,
            stage: 'noise',
            default_params: { mean: 0, std: 15 },
          },
          {
            name: 'perspective_warp',
            description: 'Applies perspective warp',
            affects_geometry: true,
            stage: 'geometric',
            default_params: { max_warp: 0.02 },
          },
        ],
      }
      vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingTypes).toBe(false)
      })

      expect(result.current.augmentationTypes).toHaveLength(2)
      expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
    })

    it('should handle error when fetching types', async () => {
      vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingTypes).toBe(false)
      })

      expect(result.current.typesError).toBeTruthy()
    })
  })

  describe('getPresets', () => {
    it('should fetch augmentation presets', async () => {
      const mockPresets = {
        presets: [
          { name: 'conservative', description: 'Safe augmentations' },
          { name: 'moderate', description: 'Balanced augmentations' },
          { name: 'aggressive', description: 'Strong augmentations' },
        ],
      }
      vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
      vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingPresets).toBe(false)
      })

      expect(result.current.presets).toHaveLength(3)
      expect(result.current.presets[0].name).toBe('conservative')
    })
  })

  describe('preview', () => {
    it('should preview single augmentation', async () => {
      const mockPreview = {
        preview_url: 'data:image/png;base64,xxx',
        original_url: 'data:image/png;base64,yyy',
        applied_params: { std: 15 },
      }
      vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
      vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
      vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingTypes).toBe(false)
      })

      // Call preview mutation
      result.current.preview({
        documentId: 'doc-123',
        augmentationType: 'gaussian_noise',
        params: { std: 15 },
        page: 1,
      })

      await waitFor(() => {
        expect(augmentationApi.preview).toHaveBeenCalledWith(
          'doc-123',
          { augmentation_type: 'gaussian_noise', params: { std: 15 } },
          1
        )
      })
    })

    it('should track preview loading state', async () => {
      vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
      vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
      vi.mocked(augmentationApi.preview).mockImplementation(
        () => new Promise((resolve) => setTimeout(resolve, 100))
      )

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingTypes).toBe(false)
      })

      expect(result.current.isPreviewing).toBe(false)

      result.current.preview({
        documentId: 'doc-123',
        augmentationType: 'gaussian_noise',
        params: {},
        page: 1,
      })

      // State update happens asynchronously
      await waitFor(() => {
        expect(result.current.isPreviewing).toBe(true)
      })
    })
  })

  describe('createBatch', () => {
    it('should create augmented dataset', async () => {
      const mockResponse = {
        task_id: 'task-123',
        status: 'pending',
        message: 'Augmentation task queued',
        estimated_images: 100,
      }
      vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
      vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
      vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)

      const { result } = renderHook(() => useAugmentation(), {
        wrapper: createWrapper(),
      })

      await waitFor(() => {
        expect(result.current.isLoadingTypes).toBe(false)
      })

      result.current.createBatch({
        dataset_id: 'dataset-123',
        config: {
          gaussian_noise: { enabled: true, probability: 0.5, params: {} },
        },
        output_name: 'augmented-dataset',
        multiplier: 2,
      })

      await waitFor(() => {
        expect(augmentationApi.createBatch).toHaveBeenCalledWith({
          dataset_id: 'dataset-123',
          config: {
            gaussian_noise: { enabled: true, probability: 0.5, params: {} },
          },
          output_name: 'augmented-dataset',
          multiplier: 2,
        })
      })
    })
  })
})
frontend/src/hooks/useAugmentation.ts (new file, 121 lines)
@@ -0,0 +1,121 @@
/**
 * Hook for managing augmentation operations.
 *
 * Provides functions for fetching augmentation types, presets, and previewing augmentations.
 */

import { useQuery, useMutation } from '@tanstack/react-query'
import {
  augmentationApi,
  type AugmentationTypesResponse,
  type PresetsResponse,
  type PreviewResponse,
  type BatchRequest,
  type BatchResponse,
  type AugmentationConfig,
} from '../api/endpoints/augmentation'

interface PreviewParams {
  documentId: string
  augmentationType: string
  params: Record<string, unknown>
  page?: number
}

interface PreviewConfigParams {
  documentId: string
  config: AugmentationConfig
  page?: number
}

export const useAugmentation = () => {
  // Fetch augmentation types
  const {
    data: typesData,
    isLoading: isLoadingTypes,
    error: typesError,
  } = useQuery<AugmentationTypesResponse>({
    queryKey: ['augmentation', 'types'],
    queryFn: () => augmentationApi.getTypes(),
    staleTime: 5 * 60 * 1000, // Cache for 5 minutes
  })

  // Fetch presets
  const {
    data: presetsData,
    isLoading: isLoadingPresets,
    error: presetsError,
  } = useQuery<PresetsResponse>({
    queryKey: ['augmentation', 'presets'],
    queryFn: () => augmentationApi.getPresets(),
    staleTime: 5 * 60 * 1000,
  })

  // Preview single augmentation mutation
  const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
    mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
      augmentationApi.preview(
        documentId,
        { augmentation_type: augmentationType, params },
        page
      ),
    onError: (error) => {
      console.error('Preview augmentation failed:', error)
    },
  })

  // Preview full config mutation
  const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
    mutationFn: ({ documentId, config, page = 1 }) =>
      augmentationApi.previewConfig(documentId, config, page),
    onError: (error) => {
      console.error('Preview config failed:', error)
    },
  })

  // Create augmented dataset mutation
  const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
    mutationFn: (request) => augmentationApi.createBatch(request),
    onError: (error) => {
      console.error('Create augmented dataset failed:', error)
    },
  })

  return {
    // Types data
    augmentationTypes: typesData?.augmentation_types || [],
    isLoadingTypes,
    typesError,

    // Presets data
    presets: presetsData?.presets || [],
    isLoadingPresets,
    presetsError,

    // Preview single augmentation
    preview: previewMutation.mutate,
    previewAsync: previewMutation.mutateAsync,
    isPreviewing: previewMutation.isPending,
    previewData: previewMutation.data,
    previewError: previewMutation.error,

    // Preview full config
    previewConfig: previewConfigMutation.mutate,
    previewConfigAsync: previewConfigMutation.mutateAsync,
    isPreviewingConfig: previewConfigMutation.isPending,
    previewConfigData: previewConfigMutation.data,
    previewConfigError: previewConfigMutation.error,

    // Create batch
    createBatch: createBatchMutation.mutate,
    createBatchAsync: createBatchMutation.mutateAsync,
    isCreatingBatch: createBatchMutation.isPending,
    batchData: createBatchMutation.data,
    batchError: createBatchMutation.error,

    // Reset functions for clearing stale mutation state
    resetPreview: previewMutation.reset,
    resetPreviewConfig: previewConfigMutation.reset,
    resetBatch: createBatchMutation.reset,
  }
}
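A minimal usage sketch of the hook above; the component name and rendering details are illustrative only and not part of this commit:

// Illustrative only: a hypothetical component consuming useAugmentation.
import React from 'react'
import { useAugmentation } from './useAugmentation'

export const AugmentationPreviewButton: React.FC<{ documentId: string }> = ({ documentId }) => {
  const { augmentationTypes, preview, isPreviewing, previewData } = useAugmentation()

  return (
    <div>
      <button
        disabled={isPreviewing || augmentationTypes.length === 0}
        onClick={() =>
          preview({
            documentId,
            augmentationType: 'gaussian_noise',
            params: { std: 15 },
            page: 1,
          })
        }
      >
        {isPreviewing ? 'Previewing...' : 'Preview gaussian_noise'}
      </button>
      {previewData && <img src={previewData.preview_url} alt="Augmentation preview" />}
    </div>
  )
}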
frontend/src/hooks/useDatasets.ts (new file, 84 lines)
@@ -0,0 +1,84 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { datasetsApi } from '../api/endpoints'
import type {
  DatasetCreateRequest,
  DatasetDetailResponse,
  DatasetListResponse,
  DatasetTrainRequest,
} from '../api/types'

export const useDatasets = (params?: {
  status?: string
  limit?: number
  offset?: number
}) => {
  const queryClient = useQueryClient()

  const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
    queryKey: ['datasets', params],
    queryFn: () => datasetsApi.list(params),
    staleTime: 30000,
    // Poll every 5 seconds when there's an active training task
    refetchInterval: (query) => {
      const datasets = query.state.data?.datasets ?? []
      const hasActiveTraining = datasets.some(
        d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
      )
      return hasActiveTraining ? 5000 : false
    },
  })

  const createMutation = useMutation({
    mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['datasets'] })
    },
  })

  const deleteMutation = useMutation({
    mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['datasets'] })
    },
  })

  const trainMutation = useMutation({
    mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
      datasetsApi.trainFromDataset(datasetId, req),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['datasets'] })
      queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
    },
  })

  return {
    datasets: data?.datasets ?? [],
    total: data?.total ?? 0,
    isLoading,
    error,
    refetch,
    createDataset: createMutation.mutate,
    createDatasetAsync: createMutation.mutateAsync,
    isCreating: createMutation.isPending,
    deleteDataset: deleteMutation.mutate,
    isDeleting: deleteMutation.isPending,
    trainFromDataset: trainMutation.mutate,
    trainFromDatasetAsync: trainMutation.mutateAsync,
    isTraining: trainMutation.isPending,
  }
}

export const useDatasetDetail = (datasetId: string | null) => {
  const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
    queryKey: ['datasets', datasetId],
    queryFn: () => datasetsApi.getDetail(datasetId!),
    enabled: !!datasetId,
    staleTime: 30000,
  })

  return {
    dataset: data ?? null,
    isLoading,
    error,
  }
}
@@ -18,7 +18,16 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
  })

  const uploadMutation = useMutation({
    mutationFn: (file: File) => documentsApi.upload(file),
    mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
      documentsApi.upload(file, groupKey),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  const updateGroupKeyMutation = useMutation({
    mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
      documentsApi.updateGroupKey(documentId, groupKey),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
@@ -74,5 +83,8 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
    isUpdatingStatus: updateStatusMutation.isPending,
    triggerAutoLabel: triggerAutoLabelMutation.mutate,
    isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
    updateGroupKey: updateGroupKeyMutation.mutate,
    updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
    isUpdatingGroupKey: updateGroupKeyMutation.isPending,
  }
}
frontend/src/hooks/useModels.ts (new file, 98 lines)
@@ -0,0 +1,98 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { modelsApi } from '../api/endpoints'
import type {
  ModelVersionListResponse,
  ModelVersionDetailResponse,
  ActiveModelResponse,
} from '../api/types'

export const useModels = (params?: {
  status?: string
  limit?: number
  offset?: number
}) => {
  const queryClient = useQueryClient()

  const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
    queryKey: ['models', params],
    queryFn: () => modelsApi.list(params),
    staleTime: 30000,
  })

  const activateMutation = useMutation({
    mutationFn: (versionId: string) => modelsApi.activate(versionId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['models'] })
      queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
    },
  })

  const deactivateMutation = useMutation({
    mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['models'] })
      queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
    },
  })

  const archiveMutation = useMutation({
    mutationFn: (versionId: string) => modelsApi.archive(versionId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['models'] })
    },
  })

  const deleteMutation = useMutation({
    mutationFn: (versionId: string) => modelsApi.delete(versionId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['models'] })
    },
  })

  return {
    models: data?.models ?? [],
    total: data?.total ?? 0,
    isLoading,
    error,
    refetch,
    activateModel: activateMutation.mutate,
    activateModelAsync: activateMutation.mutateAsync,
    isActivating: activateMutation.isPending,
    deactivateModel: deactivateMutation.mutate,
    isDeactivating: deactivateMutation.isPending,
    archiveModel: archiveMutation.mutate,
    isArchiving: archiveMutation.isPending,
    deleteModel: deleteMutation.mutate,
    isDeleting: deleteMutation.isPending,
  }
}

export const useModelDetail = (versionId: string | null) => {
  const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
    queryKey: ['models', versionId],
    queryFn: () => modelsApi.getDetail(versionId!),
    enabled: !!versionId,
    staleTime: 30000,
  })

  return {
    model: data ?? null,
    isLoading,
    error,
  }
}

export const useActiveModel = () => {
  const { data, isLoading, error } = useQuery<ActiveModelResponse>({
    queryKey: ['models', 'active'],
    queryFn: () => modelsApi.getActive(),
    staleTime: 30000,
  })

  return {
    hasActiveModel: data?.has_active_model ?? false,
    activeModel: data?.model ?? null,
    isLoading,
    error,
  }
}
migrations/005_add_group_key.sql (new file, 8 lines)
@@ -0,0 +1,8 @@
-- Add group_key column to admin_documents
-- Allows users to organize documents into logical groups

-- Add the column (nullable, VARCHAR 255)
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);

-- Add index for filtering/grouping queries
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
migrations/006_model_versions.sql (new file, 49 lines)
@@ -0,0 +1,49 @@
-- Model versions table for tracking trained model deployments.
-- Each training run can produce a model version for inference.

CREATE TABLE IF NOT EXISTS model_versions (
    version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    version VARCHAR(50) NOT NULL,
    name VARCHAR(255) NOT NULL,
    description TEXT,
    model_path VARCHAR(512) NOT NULL,
    status VARCHAR(20) NOT NULL DEFAULT 'inactive',
    is_active BOOLEAN NOT NULL DEFAULT FALSE,

    -- Training association
    task_id UUID REFERENCES training_tasks(task_id) ON DELETE SET NULL,
    dataset_id UUID REFERENCES training_datasets(dataset_id) ON DELETE SET NULL,

    -- Training metrics
    metrics_mAP DOUBLE PRECISION,
    metrics_precision DOUBLE PRECISION,
    metrics_recall DOUBLE PRECISION,
    document_count INTEGER NOT NULL DEFAULT 0,

    -- Training configuration snapshot
    training_config JSONB,

    -- File info
    file_size BIGINT,

    -- Timestamps
    trained_at TIMESTAMP WITH TIME ZONE,
    activated_at TIMESTAMP WITH TIME ZONE,
    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);

-- Indexes
CREATE INDEX IF NOT EXISTS idx_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS idx_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS idx_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS idx_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_dataset_id ON model_versions(dataset_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_created ON model_versions(created_at);

-- Ensure only one active model at a time
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
    ON model_versions(is_active) WHERE is_active = TRUE;

-- Comment
COMMENT ON TABLE model_versions IS 'Trained model versions for inference deployment';
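Note on the partial unique index above: it allows any number of inactive rows but rejects a second row with is_active = TRUE, so the single-active rule is enforced at the database level and not only in application code. A minimal sketch of how a violation surfaces (the connection string and sample rows are assumptions, not part of this commit):

```python
import psycopg

with psycopg.connect("dbname=docs") as conn, conn.cursor() as cur:
    cur.execute(
        "INSERT INTO model_versions (version, name, model_path, is_active) "
        "VALUES ('1.0.0', 'baseline', '/models/v1.pt', TRUE)"
    )
    try:
        cur.execute(
            "INSERT INTO model_versions (version, name, model_path, is_active) "
            "VALUES ('1.1.0', 'retrained', '/models/v1_1.pt', TRUE)"
        )
    except psycopg.errors.UniqueViolation:
        # A second active row is refused; callers must deactivate the current
        # version before activating another one.
        conn.rollback()
```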
migrations/007_training_tasks_extra_columns.sql (new file, 46 lines)
@@ -0,0 +1,46 @@
-- Add missing columns to training_tasks table

-- Add name column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;

-- Add description column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;

-- Add admin_token column (for multi-tenant support)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);

-- Add task_type column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';

-- Add recurring schedule columns
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;

-- Add result metrics columns (for display without parsing JSONB)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;

-- Rename metrics to config if exists
DO $$
BEGIN
    IF EXISTS (SELECT 1 FROM information_schema.columns
               WHERE table_name = 'training_tasks' AND column_name = 'metrics'
               AND NOT EXISTS (SELECT 1 FROM information_schema.columns
                               WHERE table_name = 'training_tasks' AND column_name = 'config')) THEN
        ALTER TABLE training_tasks RENAME COLUMN metrics TO config;
    END IF;
END $$;

-- Add updated_at column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();

-- Create index on name
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);

-- Create index on metrics_mAP
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
migrations/008_fix_model_versions_fk.sql (new file, 14 lines)
@@ -0,0 +1,14 @@
-- Fix foreign key constraints on model_versions so referenced tasks/datasets can be deleted
-- (re-created with ON DELETE SET NULL)

-- Drop existing constraints
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;

-- Add constraints with ON DELETE SET NULL
ALTER TABLE model_versions
    ADD CONSTRAINT model_versions_dataset_id_fkey
    FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;

ALTER TABLE model_versions
    ADD CONSTRAINT model_versions_task_id_fkey
    FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
@@ -25,6 +25,7 @@ from inference.data.admin_models import (
    AnnotationHistory,
    TrainingDataset,
    DatasetDocument,
    ModelVersion,
)

logger = logging.getLogger(__name__)
@@ -110,6 +111,7 @@ class AdminDB:
        page_count: int = 1,
        upload_source: str = "ui",
        csv_field_values: dict[str, Any] | None = None,
        group_key: str | None = None,
        admin_token: str | None = None,  # Deprecated, kept for compatibility
    ) -> str:
        """Create a new document record."""
@@ -122,6 +124,7 @@ class AdminDB:
            page_count=page_count,
            upload_source=upload_source,
            csv_field_values=csv_field_values,
            group_key=group_key,
        )
        session.add(document)
        session.flush()
@@ -253,6 +256,17 @@ class AdminDB:
            document.updated_at = datetime.utcnow()
            session.add(document)

    def update_document_group_key(self, document_id: str, group_key: str | None) -> bool:
        """Update document group key."""
        with get_session_context() as session:
            document = session.get(AdminDocument, UUID(document_id))
            if document:
                document.group_key = group_key
                document.updated_at = datetime.utcnow()
                session.add(document)
                return True
            return False

    def delete_document(self, document_id: str) -> bool:
        """Delete a document and its annotations."""
        with get_session_context() as session:
@@ -1215,6 +1229,39 @@ class AdminDB:
            session.expunge(d)
        return list(datasets), total

    def get_active_training_tasks_for_datasets(
        self, dataset_ids: list[str]
    ) -> dict[str, dict[str, str]]:
        """Get active (pending/scheduled/running) training tasks for datasets.

        Returns a dict mapping dataset_id to {"task_id": ..., "status": ...}
        """
        if not dataset_ids:
            return {}

        # Validate UUIDs before query
        valid_uuids = []
        for d in dataset_ids:
            try:
                valid_uuids.append(UUID(d))
            except ValueError:
                logger.warning("Invalid UUID in get_active_training_tasks_for_datasets: %s", d)
                continue

        if not valid_uuids:
            return {}

        with get_session_context() as session:
            statement = select(TrainingTask).where(
                TrainingTask.dataset_id.in_(valid_uuids),
                TrainingTask.status.in_(["pending", "scheduled", "running"]),
            )
            results = session.exec(statement).all()
            return {
                str(t.dataset_id): {"task_id": str(t.task_id), "status": t.status}
                for t in results
            }

    def update_dataset_status(
        self,
        dataset_id: str | UUID,
@@ -1314,3 +1361,182 @@ class AdminDB:
            session.delete(dataset)
            session.commit()
            return True

    # ==========================================================================
    # Model Version Operations
    # ==========================================================================

    def create_model_version(
        self,
        version: str,
        name: str,
        model_path: str,
        description: str | None = None,
        task_id: str | UUID | None = None,
        dataset_id: str | UUID | None = None,
        metrics_mAP: float | None = None,
        metrics_precision: float | None = None,
        metrics_recall: float | None = None,
        document_count: int = 0,
        training_config: dict[str, Any] | None = None,
        file_size: int | None = None,
        trained_at: datetime | None = None,
    ) -> ModelVersion:
        """Create a new model version."""
        with get_session_context() as session:
            model = ModelVersion(
                version=version,
                name=name,
                model_path=model_path,
                description=description,
                task_id=UUID(str(task_id)) if task_id else None,
                dataset_id=UUID(str(dataset_id)) if dataset_id else None,
                metrics_mAP=metrics_mAP,
                metrics_precision=metrics_precision,
                metrics_recall=metrics_recall,
                document_count=document_count,
                training_config=training_config,
                file_size=file_size,
                trained_at=trained_at,
            )
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def get_model_version(self, version_id: str | UUID) -> ModelVersion | None:
        """Get a model version by ID."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if model:
                session.expunge(model)
            return model

    def get_model_versions(
        self,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[ModelVersion], int]:
        """List model versions with optional status filter."""
        with get_session_context() as session:
            query = select(ModelVersion)
            count_query = select(func.count()).select_from(ModelVersion)
            if status:
                query = query.where(ModelVersion.status == status)
                count_query = count_query.where(ModelVersion.status == status)
            total = session.exec(count_query).one()
            models = session.exec(
                query.order_by(ModelVersion.created_at.desc()).offset(offset).limit(limit)
            ).all()
            for m in models:
                session.expunge(m)
            return list(models), total

    def get_active_model_version(self) -> ModelVersion | None:
        """Get the currently active model version for inference."""
        with get_session_context() as session:
            result = session.exec(
                select(ModelVersion).where(ModelVersion.is_active == True)
            ).first()
            if result:
                session.expunge(result)
            return result

    def activate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
        """Activate a model version for inference (deactivates all others)."""
        with get_session_context() as session:
            # Deactivate all versions
            all_versions = session.exec(
                select(ModelVersion).where(ModelVersion.is_active == True)
            ).all()
            for v in all_versions:
                v.is_active = False
                v.status = "inactive"
                v.updated_at = datetime.utcnow()
                session.add(v)

            # Activate the specified version
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            model.is_active = True
            model.status = "active"
            model.activated_at = datetime.utcnow()
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def deactivate_model_version(self, version_id: str | UUID) -> ModelVersion | None:
        """Deactivate a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            model.is_active = False
            model.status = "inactive"
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def update_model_version(
        self,
        version_id: str | UUID,
        name: str | None = None,
        description: str | None = None,
        status: str | None = None,
    ) -> ModelVersion | None:
        """Update model version metadata."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            if name is not None:
                model.name = name
            if description is not None:
                model.description = description
            if status is not None:
                model.status = status
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def archive_model_version(self, version_id: str | UUID) -> ModelVersion | None:
        """Archive a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return None
            # Cannot archive active model
            if model.is_active:
                return None
            model.status = "archived"
            model.updated_at = datetime.utcnow()
            session.add(model)
            session.commit()
            session.refresh(model)
            session.expunge(model)
            return model

    def delete_model_version(self, version_id: str | UUID) -> bool:
        """Delete a model version."""
        with get_session_context() as session:
            model = session.get(ModelVersion, UUID(str(version_id)))
            if not model:
                return False
            # Cannot delete active model
            if model.is_active:
                return False
            session.delete(model)
            session.commit()
            return True
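Usage note: the intended flow is to register the artifact of a finished training run and then promote it. A minimal sketch, assuming the AdminDB import path and the surrounding training-completion code (task_id, dataset_id, metrics) which are not part of this diff:

```python
from inference.data.admin_db import AdminDB  # module path is an assumption

db = AdminDB()

# Register the weights file produced by a completed training task.
mv = db.create_model_version(
    version="1.2.0",
    name="Invoice detector",
    model_path="/data/models/invoice_v1_2_0.pt",  # hypothetical path
    task_id=task_id,        # from the completed TrainingTask (assumed in scope)
    dataset_id=dataset_id,
    metrics_mAP=0.91,
    document_count=120,
)

# Promote it: activate_model_version() first flips every other row to inactive,
# so inference always sees exactly one active model.
db.activate_model_version(mv.version_id)
assert db.get_active_model_version().version_id == mv.version_id
```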
@@ -70,6 +70,8 @@ class AdminDocument(SQLModel, table=True):
    # Upload source: ui, api
    batch_id: UUID | None = Field(default=None, index=True)
    # Link to batch upload (if uploaded via ZIP)
    group_key: str | None = Field(default=None, max_length=255, index=True)
    # User-defined grouping key for document organization
    csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Original CSV values for reference
    auto_label_queued_at: datetime | None = Field(default=None)
@@ -275,6 +277,56 @@ class TrainingDocumentLink(SQLModel, table=True):
    created_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Model Version Management
# =============================================================================


class ModelVersion(SQLModel, table=True):
    """Model version for inference deployment."""

    __tablename__ = "model_versions"

    version_id: UUID = Field(default_factory=uuid4, primary_key=True)
    version: str = Field(max_length=50, index=True)
    # Semantic version e.g., "1.0.0", "2.1.0"
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    model_path: str = Field(max_length=512)
    # Path to the model weights file
    status: str = Field(default="inactive", max_length=20, index=True)
    # Status: active, inactive, archived
    is_active: bool = Field(default=False, index=True)
    # Only one version can be active at a time for inference

    # Training association
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)

    # Training metrics
    metrics_mAP: float | None = Field(default=None)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    document_count: int = Field(default=0)
    # Number of documents used in training

    # Training configuration snapshot
    training_config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Snapshot of epochs, batch_size, etc.

    # File info
    file_size: int | None = Field(default=None)
    # Model file size in bytes

    # Timestamps
    trained_at: datetime | None = Field(default=None)
    # When training completed
    activated_at: datetime | None = Field(default=None)
    # When this version was last activated
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Annotation History (v2)
# =============================================================================
@@ -49,6 +49,111 @@ def get_engine():
    return _engine


def run_migrations() -> None:
    """Run database migrations for new columns."""
    engine = get_engine()

    migrations = [
        # Migration 004: Training datasets tables and dataset_id on training_tasks
        (
            "training_datasets_tables",
            """
            CREATE TABLE IF NOT EXISTS training_datasets (
                dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                name VARCHAR(255) NOT NULL,
                description TEXT,
                status VARCHAR(20) NOT NULL DEFAULT 'building',
                train_ratio FLOAT NOT NULL DEFAULT 0.8,
                val_ratio FLOAT NOT NULL DEFAULT 0.1,
                seed INTEGER NOT NULL DEFAULT 42,
                total_documents INTEGER NOT NULL DEFAULT 0,
                total_images INTEGER NOT NULL DEFAULT 0,
                total_annotations INTEGER NOT NULL DEFAULT 0,
                dataset_path VARCHAR(512),
                error_message TEXT,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
            """,
        ),
        (
            "dataset_documents_table",
            """
            CREATE TABLE IF NOT EXISTS dataset_documents (
                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
                document_id UUID NOT NULL REFERENCES admin_documents(document_id),
                split VARCHAR(10) NOT NULL,
                page_count INTEGER NOT NULL DEFAULT 0,
                annotation_count INTEGER NOT NULL DEFAULT 0,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                UNIQUE(dataset_id, document_id)
            );
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
            """,
        ),
        (
            "training_tasks_dataset_id",
            """
            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
            CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
            """,
        ),
        # Migration 005: Add group_key to admin_documents
        (
            "admin_documents_group_key",
            """
            ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
            CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
            """,
        ),
        # Migration 006: Model versions table
        (
            "model_versions_table",
            """
            CREATE TABLE IF NOT EXISTS model_versions (
                version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                version VARCHAR(50) NOT NULL,
                name VARCHAR(255) NOT NULL,
                description TEXT,
                model_path VARCHAR(512) NOT NULL,
                status VARCHAR(20) NOT NULL DEFAULT 'inactive',
                is_active BOOLEAN NOT NULL DEFAULT FALSE,
                task_id UUID REFERENCES training_tasks(task_id),
                dataset_id UUID REFERENCES training_datasets(dataset_id),
                metrics_mAP FLOAT,
                metrics_precision FLOAT,
                metrics_recall FLOAT,
                document_count INTEGER NOT NULL DEFAULT 0,
                training_config JSONB,
                file_size BIGINT,
                trained_at TIMESTAMP WITH TIME ZONE,
                activated_at TIMESTAMP WITH TIME ZONE,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS ix_model_versions_version ON model_versions(version);
            CREATE INDEX IF NOT EXISTS ix_model_versions_status ON model_versions(status);
            CREATE INDEX IF NOT EXISTS ix_model_versions_is_active ON model_versions(is_active);
            CREATE INDEX IF NOT EXISTS ix_model_versions_task_id ON model_versions(task_id);
            CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
            """,
        ),
    ]

    with engine.connect() as conn:
        for name, sql in migrations:
            try:
                conn.execute(text(sql))
                conn.commit()
                logger.info(f"Migration '{name}' applied successfully")
            except Exception as e:
                # Log but don't fail - column may already exist
                logger.debug(f"Migration '{name}' skipped or failed: {e}")


def create_db_and_tables() -> None:
    """Create all database tables."""
    from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
@@ -64,6 +169,9 @@ def create_db_and_tables() -> None:
    SQLModel.metadata.create_all(engine)
    logger.info("Database tables created/verified")

    # Run migrations for new columns
    run_migrations()


def get_session() -> Session:
    """Get a new database session."""
@@ -5,6 +5,7 @@ Document management, annotations, and training endpoints.
"""

from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
@@ -12,6 +13,7 @@ from inference.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_augmentation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
@@ -0,0 +1,15 @@
"""Augmentation API module."""

from fastapi import APIRouter

from .routes import register_augmentation_routes


def create_augmentation_router() -> APIRouter:
    """Create and configure the augmentation router."""
    router = APIRouter(prefix="/augmentation", tags=["augmentation"])
    register_augmentation_routes(router)
    return router


__all__ = ["create_augmentation_router"]
@@ -0,0 +1,162 @@
"""Augmentation API routes."""

from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminDBDep, AdminTokenDep
from inference.web.schemas.admin.augmentation import (
    AugmentationBatchRequest,
    AugmentationBatchResponse,
    AugmentationConfigSchema,
    AugmentationPreviewRequest,
    AugmentationPreviewResponse,
    AugmentationTypeInfo,
    AugmentationTypesResponse,
    AugmentedDatasetItem,
    AugmentedDatasetListResponse,
    PresetInfo,
    PresetsResponse,
)


def register_augmentation_routes(router: APIRouter) -> None:
    """Register augmentation endpoints on the router."""

    @router.get(
        "/types",
        response_model=AugmentationTypesResponse,
        summary="List available augmentation types",
    )
    async def list_augmentation_types(
        admin_token: AdminTokenDep,
    ) -> AugmentationTypesResponse:
        """
        List all available augmentation types with descriptions and parameters.
        """
        from shared.augmentation.pipeline import (
            AUGMENTATION_REGISTRY,
            AugmentationPipeline,
        )

        types = []
        for name, aug_class in AUGMENTATION_REGISTRY.items():
            # Create instance with empty params to get preview params
            aug = aug_class({})
            types.append(
                AugmentationTypeInfo(
                    name=name,
                    description=(aug_class.__doc__ or "").strip(),
                    affects_geometry=aug_class.affects_geometry,
                    stage=AugmentationPipeline.STAGE_MAPPING[name],
                    default_params=aug.get_preview_params(),
                )
            )

        return AugmentationTypesResponse(augmentation_types=types)

    @router.get(
        "/presets",
        response_model=PresetsResponse,
        summary="Get augmentation presets",
    )
    async def get_presets(
        admin_token: AdminTokenDep,
    ) -> PresetsResponse:
        """Get predefined augmentation presets for common use cases."""
        from shared.augmentation.presets import list_presets

        presets = [PresetInfo(**p) for p in list_presets()]
        return PresetsResponse(presets=presets)

    @router.post(
        "/preview/{document_id}",
        response_model=AugmentationPreviewResponse,
        summary="Preview augmentation on document image",
    )
    async def preview_augmentation(
        document_id: str,
        request: AugmentationPreviewRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """
        Preview a single augmentation on a document page.

        Returns URLs to original and augmented preview images.
        """
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        return await service.preview_single(
            document_id=document_id,
            page=page,
            augmentation_type=request.augmentation_type,
            params=request.params,
        )

    @router.post(
        "/preview-config/{document_id}",
        response_model=AugmentationPreviewResponse,
        summary="Preview full augmentation config on document",
    )
    async def preview_config(
        document_id: str,
        config: AugmentationConfigSchema,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        page: int = Query(default=1, ge=1, description="Page number"),
    ) -> AugmentationPreviewResponse:
        """Preview complete augmentation pipeline on a document page."""
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        return await service.preview_config(
            document_id=document_id,
            page=page,
            config=config,
        )

    @router.post(
        "/batch",
        response_model=AugmentationBatchResponse,
        summary="Create augmented dataset (offline preprocessing)",
    )
    async def create_augmented_dataset(
        request: AugmentationBatchRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AugmentationBatchResponse:
        """
        Create a new augmented dataset from an existing dataset.

        This runs as a background task. The augmented images are stored
        alongside the original dataset for training.
        """
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        return await service.create_augmented_dataset(
            source_dataset_id=request.dataset_id,
            config=request.config,
            output_name=request.output_name,
            multiplier=request.multiplier,
        )

    @router.get(
        "/datasets",
        response_model=AugmentedDatasetListResponse,
        summary="List augmented datasets",
    )
    async def list_augmented_datasets(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        limit: int = Query(default=20, ge=1, le=100, description="Page size"),
        offset: int = Query(default=0, ge=0, description="Offset"),
    ) -> AugmentedDatasetListResponse:
        """List all augmented datasets."""
        from inference.web.services.augmentation_service import AugmentationService

        service = AugmentationService(db=db)
        return await service.list_augmented_datasets(limit=limit, offset=offset)
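Usage note: a hedged sketch of calling the single-augmentation preview route over HTTP. The base URL, auth header name, and the example augmentation type/params are assumptions; only the route path suffix and the request/response shape come from the code above.

```python
import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/admin/augmentation/preview/<document-uuid>",
    params={"page": 1},
    json={"augmentation_type": "gaussian_noise", "params": {"mean": 0, "std": 15}},
    headers={"Authorization": "Bearer <admin-token>"},  # auth scheme is an assumption
)
resp.raise_for_status()
# AugmentationPreviewResponse: URLs of the original and augmented preview images
print(resp.json())
```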
@@ -91,8 +91,19 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            bool,
            Query(description="Trigger auto-labeling after upload"),
        ] = True,
        group_key: Annotated[
            str | None,
            Query(description="Optional group key for document organization", max_length=255),
        ] = None,
    ) -> DocumentUploadResponse:
        """Upload a document for labeling."""
        # Validate group_key length
        if group_key and len(group_key) > 255:
            raise HTTPException(
                status_code=400,
                detail="Group key must be 255 characters or less",
            )

        # Validate filename
        if not file.filename:
            raise HTTPException(status_code=400, detail="Filename is required")
@@ -131,6 +142,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            content_type=file.content_type or "application/octet-stream",
            file_path="",  # Will update after saving
            page_count=page_count,
            group_key=group_key,
        )

        # Save file to admin uploads
@@ -177,6 +189,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            file_size=len(content),
            page_count=page_count,
            status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
            group_key=group_key,
            auto_label_started=auto_label_started,
            message="Document uploaded successfully",
        )
@@ -277,6 +290,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                annotation_count=len(annotations),
                upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
                batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
                group_key=doc.group_key if hasattr(doc, 'group_key') else None,
                can_annotate=can_annotate,
                created_at=doc.created_at,
                updated_at=doc.updated_at,
@@ -421,6 +435,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            auto_label_error=document.auto_label_error,
            upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
            batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
            group_key=document.group_key if hasattr(document, 'group_key') else None,
            csv_field_values=csv_field_values,
            can_annotate=can_annotate,
            annotation_lock_until=annotation_lock_until,
@@ -548,4 +563,50 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:

        return response

    @router.patch(
        "/{document_id}/group-key",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Update document group key",
        description="Update the group key for a document.",
    )
    async def update_document_group_key(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        group_key: Annotated[
            str | None,
            Query(description="New group key (null to clear)"),
        ] = None,
    ) -> dict:
        """Update document group key."""
        _validate_uuid(document_id, "document_id")

        # Validate group_key length
        if group_key and len(group_key) > 255:
            raise HTTPException(
                status_code=400,
                detail="Group key must be 255 characters or less",
            )

        # Verify document exists
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Update group key
        db.update_document_group_key(document_id, group_key)

        return {
            "status": "updated",
            "document_id": document_id,
            "group_key": group_key,
            "message": "Document group key updated",
        }

    return router
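Usage note: the new group-key endpoint takes group_key as a query parameter, and omitting it clears the value. A minimal sketch of calling it over HTTP (base URL, router prefix, and auth header are assumptions):

```python
import httpx

resp = httpx.patch(
    "http://localhost:8000/api/v1/admin/documents/<document-uuid>/group-key",
    params={"group_key": "invoices-2024-q1"},  # omit to clear the group key
    headers={"Authorization": "Bearer <admin-token>"},
)
resp.raise_for_status()
# {"status": "updated", "document_id": "...", "group_key": "invoices-2024-q1", ...}
```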
@@ -11,6 +11,7 @@ from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes


def create_training_router() -> APIRouter:
@@ -21,6 +22,7 @@ def create_training_router() -> APIRouter:
    register_document_routes(router)
    register_export_routes(router)
    register_dataset_routes(router)
    register_model_routes(router)

    return router
@@ -41,6 +41,13 @@ def register_dataset_routes(router: APIRouter) -> None:
        from pathlib import Path
        from inference.web.services.dataset_builder import DatasetBuilder

        # Validate minimum document count for proper train/val/test split
        if len(request.document_ids) < 10:
            raise HTTPException(
                status_code=400,
                detail=f"Minimum 10 documents required for training dataset (got {len(request.document_ids)})",
            )

        dataset = db.create_dataset(
            name=request.name,
            description=request.description,
@@ -83,6 +90,15 @@ def register_dataset_routes(router: APIRouter) -> None:
    ) -> DatasetListResponse:
        """List training datasets."""
        datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)

        # Get active training tasks for each dataset (graceful degradation on error)
        dataset_ids = [str(d.dataset_id) for d in datasets]
        try:
            active_tasks = db.get_active_training_tasks_for_datasets(dataset_ids)
        except Exception:
            logger.exception("Failed to fetch active training tasks")
            active_tasks = {}

        return DatasetListResponse(
            total=total,
            limit=limit,
@@ -93,6 +109,8 @@ def register_dataset_routes(router: APIRouter) -> None:
                    name=d.name,
                    description=d.description,
                    status=d.status,
                    training_status=active_tasks.get(str(d.dataset_id), {}).get("status"),
                    active_training_task_id=active_tasks.get(str(d.dataset_id), {}).get("task_id"),
                    total_documents=d.total_documents,
                    total_images=d.total_images,
                    total_annotations=d.total_annotations,
@@ -175,6 +193,7 @@ def register_dataset_routes(router: APIRouter) -> None:
        "/datasets/{dataset_id}/train",
        response_model=TrainingTaskResponse,
        summary="Start training from dataset",
        description="Create a training task. Set base_model_version_id in config for incremental training.",
    )
    async def train_from_dataset(
        dataset_id: str,
@@ -182,7 +201,11 @@ def register_dataset_routes(router: APIRouter) -> None:
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Create a training task from a dataset."""
        """Create a training task from a dataset.

        For incremental training, set config.base_model_version_id to a model version UUID.
        The training will use that model as the starting point instead of a pretrained model.
        """
        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        if not dataset:
@@ -194,16 +217,42 @@ def register_dataset_routes(router: APIRouter) -> None:
            )

        config_dict = request.config.model_dump()

        # Resolve base_model_version_id to actual model path for incremental training
        base_model_version_id = config_dict.get("base_model_version_id")
        if base_model_version_id:
            _validate_uuid(base_model_version_id, "base_model_version_id")
            base_model = db.get_model_version(base_model_version_id)
            if not base_model:
                raise HTTPException(
                    status_code=404,
                    detail=f"Base model version not found: {base_model_version_id}",
                )
            # Store the resolved model path for the training worker
            config_dict["base_model_path"] = base_model.model_path
            config_dict["base_model_version"] = base_model.version
            logger.info(
                "Incremental training: using model %s (%s) as base",
                base_model.version,
                base_model.model_path,
            )

        task_id = db.create_training_task(
            admin_token=admin_token,
            name=request.name,
            task_type="train",
            task_type="finetune" if base_model_version_id else "train",
            config=config_dict,
            dataset_id=str(dataset.dataset_id),
        )

        message = (
            f"Incremental training task created (base: v{config_dict.get('base_model_version', 'N/A')})"
            if base_model_version_id
            else "Training task created from dataset"
        )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.PENDING,
            message="Training task created from dataset",
            message=message,
        )
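Usage note: incremental training is requested purely through the config payload. A hedged sketch of the request (base URL, training-router prefix, auth header, and the extra config keys such as epochs are assumptions; name, config, and base_model_version_id come from the route above):

```python
import httpx

payload = {
    "name": "Invoice detector v1.3",
    "config": {
        "epochs": 50,                                   # assumed config field
        "base_model_version_id": "<model-version-uuid>",  # triggers the "finetune" path
    },
}
resp = httpx.post(
    "http://localhost:8000/api/v1/admin/training/datasets/<dataset-uuid>/train",
    json=payload,
    headers={"Authorization": "Bearer <admin-token>"},
)
# TrainingTaskResponse: task_id, status=pending, message naming the base version
print(resp.json())
```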
@@ -145,15 +145,15 @@ def register_document_routes(router: APIRouter) -> None:
    )

    @router.get(
        "/models",
        "/completed-tasks",
        response_model=TrainingModelsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get trained models",
        summary="Get completed training tasks",
        description="Get list of trained models with metrics and download links.",
        description="Get list of completed training tasks with metrics and download links. For model versions, use /models endpoint.",
    )
    async def get_training_models(
    async def get_completed_training_tasks(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
333
packages/inference/inference/web/api/v1/admin/training/models.py
Normal file
333
packages/inference/inference/web/api/v1/admin/training/models.py
Normal file
@@ -0,0 +1,333 @@
"""Model Version Endpoints."""

import logging
from typing import Annotated

from fastapi import APIRouter, HTTPException, Query, Request

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    ModelVersionCreateRequest,
    ModelVersionUpdateRequest,
    ModelVersionItem,
    ModelVersionListResponse,
    ModelVersionDetailResponse,
    ModelVersionResponse,
    ActiveModelResponse,
)

from ._utils import _validate_uuid

logger = logging.getLogger(__name__)


def register_model_routes(router: APIRouter) -> None:
    """Register model version endpoints on the router."""

    @router.post(
        "/models",
        response_model=ModelVersionResponse,
        summary="Create model version",
        description="Register a new model version for deployment.",
    )
    async def create_model_version(
        request: ModelVersionCreateRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionResponse:
        """Create a new model version."""
        if request.task_id:
            _validate_uuid(request.task_id, "task_id")
        if request.dataset_id:
            _validate_uuid(request.dataset_id, "dataset_id")

        model = db.create_model_version(
            version=request.version,
            name=request.name,
            model_path=request.model_path,
            description=request.description,
            task_id=request.task_id,
            dataset_id=request.dataset_id,
            metrics_mAP=request.metrics_mAP,
            metrics_precision=request.metrics_precision,
            metrics_recall=request.metrics_recall,
            document_count=request.document_count,
            training_config=request.training_config,
            file_size=request.file_size,
            trained_at=request.trained_at,
        )

        return ModelVersionResponse(
            version_id=str(model.version_id),
            status=model.status,
            message="Model version created successfully",
        )

    @router.get(
        "/models",
        response_model=ModelVersionListResponse,
        summary="List model versions",
    )
    async def list_model_versions(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[str | None, Query(description="Filter by status")] = None,
        limit: Annotated[int, Query(ge=1, le=100)] = 20,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> ModelVersionListResponse:
        """List model versions with optional status filter."""
        models, total = db.get_model_versions(status=status, limit=limit, offset=offset)
        return ModelVersionListResponse(
            total=total,
            limit=limit,
            offset=offset,
            models=[
                ModelVersionItem(
                    version_id=str(m.version_id),
                    version=m.version,
                    name=m.name,
                    status=m.status,
                    is_active=m.is_active,
                    metrics_mAP=m.metrics_mAP,
                    document_count=m.document_count,
                    trained_at=m.trained_at,
                    activated_at=m.activated_at,
                    created_at=m.created_at,
                )
                for m in models
            ],
        )

    @router.get(
        "/models/active",
        response_model=ActiveModelResponse,
        summary="Get active model",
        description="Get the currently active model for inference.",
    )
    async def get_active_model(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ActiveModelResponse:
        """Get the currently active model version."""
        model = db.get_active_model_version()
        if not model:
            return ActiveModelResponse(has_active_model=False, model=None)

        return ActiveModelResponse(
            has_active_model=True,
            model=ModelVersionItem(
                version_id=str(model.version_id),
                version=model.version,
                name=model.name,
                status=model.status,
                is_active=model.is_active,
                metrics_mAP=model.metrics_mAP,
                document_count=model.document_count,
                trained_at=model.trained_at,
                activated_at=model.activated_at,
                created_at=model.created_at,
            ),
        )

    @router.get(
        "/models/{version_id}",
        response_model=ModelVersionDetailResponse,
        summary="Get model version detail",
    )
    async def get_model_version(
        version_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionDetailResponse:
        """Get detailed model version information."""
        _validate_uuid(version_id, "version_id")
        model = db.get_model_version(version_id)
        if not model:
            raise HTTPException(status_code=404, detail="Model version not found")

        return ModelVersionDetailResponse(
            version_id=str(model.version_id),
            version=model.version,
            name=model.name,
            description=model.description,
            model_path=model.model_path,
            status=model.status,
            is_active=model.is_active,
            task_id=str(model.task_id) if model.task_id else None,
            dataset_id=str(model.dataset_id) if model.dataset_id else None,
            metrics_mAP=model.metrics_mAP,
            metrics_precision=model.metrics_precision,
            metrics_recall=model.metrics_recall,
            document_count=model.document_count,
            training_config=model.training_config,
            file_size=model.file_size,
            trained_at=model.trained_at,
            activated_at=model.activated_at,
            created_at=model.created_at,
            updated_at=model.updated_at,
        )

    @router.patch(
        "/models/{version_id}",
        response_model=ModelVersionResponse,
        summary="Update model version",
    )
    async def update_model_version(
        version_id: str,
        request: ModelVersionUpdateRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionResponse:
        """Update model version metadata."""
        _validate_uuid(version_id, "version_id")
        model = db.update_model_version(
            version_id=version_id,
            name=request.name,
            description=request.description,
            status=request.status,
        )
        if not model:
            raise HTTPException(status_code=404, detail="Model version not found")

        return ModelVersionResponse(
            version_id=str(model.version_id),
            status=model.status,
            message="Model version updated successfully",
        )

    @router.post(
        "/models/{version_id}/activate",
        response_model=ModelVersionResponse,
        summary="Activate model version",
        description="Activate a model version for inference (deactivates all others).",
    )
    async def activate_model_version(
        version_id: str,
        request: Request,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionResponse:
        """Activate a model version for inference."""
        _validate_uuid(version_id, "version_id")
        model = db.activate_model_version(version_id)
        if not model:
            raise HTTPException(status_code=404, detail="Model version not found")

        # Trigger model reload in inference service
        inference_service = getattr(request.app.state, "inference_service", None)
        model_reloaded = False
        if inference_service:
            try:
                model_reloaded = inference_service.reload_model()
                if model_reloaded:
                    logger.info(f"Inference model reloaded to version {model.version}")
            except Exception as e:
                logger.warning(f"Failed to reload inference model: {e}")

        message = "Model version activated for inference"
        if model_reloaded:
            message += " (model reloaded)"

        return ModelVersionResponse(
            version_id=str(model.version_id),
            status=model.status,
            message=message,
        )

    @router.post(
        "/models/{version_id}/deactivate",
        response_model=ModelVersionResponse,
        summary="Deactivate model version",
    )
    async def deactivate_model_version(
        version_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionResponse:
        """Deactivate a model version."""
        _validate_uuid(version_id, "version_id")
        model = db.deactivate_model_version(version_id)
        if not model:
            raise HTTPException(status_code=404, detail="Model version not found")

        return ModelVersionResponse(
            version_id=str(model.version_id),
            status=model.status,
            message="Model version deactivated",
        )

    @router.post(
        "/models/{version_id}/archive",
        response_model=ModelVersionResponse,
        summary="Archive model version",
    )
    async def archive_model_version(
        version_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ModelVersionResponse:
        """Archive a model version."""
        _validate_uuid(version_id, "version_id")
        model = db.archive_model_version(version_id)
        if not model:
            raise HTTPException(
                status_code=400,
                detail="Model version not found or cannot archive active model",
            )

        return ModelVersionResponse(
            version_id=str(model.version_id),
            status=model.status,
            message="Model version archived",
        )

    @router.delete(
        "/models/{version_id}",
        summary="Delete model version",
    )
    async def delete_model_version(
        version_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Delete a model version."""
        _validate_uuid(version_id, "version_id")
        success = db.delete_model_version(version_id)
        if not success:
            raise HTTPException(
                status_code=400,
                detail="Model version not found or cannot delete active model",
            )

        return {"message": "Model version deleted"}

    @router.post(
        "/models/reload",
        summary="Reload inference model",
        description="Reload the inference model from the currently active model version.",
    )
    async def reload_inference_model(
        request: Request,
        admin_token: AdminTokenDep,
    ) -> dict:
        """Reload the inference model from active version."""
        inference_service = getattr(request.app.state, "inference_service", None)
        if not inference_service:
            raise HTTPException(
                status_code=500,
                detail="Inference service not available",
            )

        try:
            model_reloaded = inference_service.reload_model()
            if model_reloaded:
                logger.info("Inference model manually reloaded")
                return {"message": "Model reloaded successfully", "reloaded": True}
            else:
                return {"message": "Model already up to date", "reloaded": False}
        except Exception as e:
            logger.error(f"Failed to reload model: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to reload model: {e}",
            )
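For orientation, a client-side sketch of how the activate flow above might be exercised end to end. The base URL prefix and the bearer-token header are assumptions about how the training router is mounted and how AdminTokenDep authenticates, not facts from this diff; adjust both to the actual deployment.

# Sketch only: URL prefix and auth header are assumed placeholders.
import httpx

BASE = "http://localhost:8000/api/v1/admin/training"  # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}   # placeholder credential

with httpx.Client(base_url=BASE, headers=HEADERS, timeout=30) as client:
    versions = client.get("/models", params={"status": "inactive", "limit": 5}).json()
    if versions["models"]:
        version_id = versions["models"][0]["version_id"]
        # Activation deactivates every other version and asks the inference service to reload.
        print(client.post(f"/models/{version_id}/activate").json())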
@@ -37,6 +37,7 @@ from inference.web.core.rate_limiter import RateLimiter
# Admin API imports
from inference.web.api.v1.admin import (
    create_annotation_router,
+   create_augmentation_router,
    create_auth_router,
    create_documents_router,
    create_locks_router,
@@ -69,10 +70,23 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
    """
    config = config or default_config

-   # Create inference service
+   # Create model path resolver that reads from database
+   def get_active_model_path():
+       """Resolve active model path from database."""
+       try:
+           db = AdminDB()
+           active_model = db.get_active_model_version()
+           if active_model and active_model.model_path:
+               return active_model.model_path
+       except Exception as e:
+           logger.warning(f"Failed to get active model from database: {e}")
+       return None
+
+   # Create inference service with database model resolver
    inference_service = InferenceService(
        model_config=config.model,
        storage_config=config.storage,
+       model_path_resolver=get_active_model_path,
    )

    # Create async processing components
@@ -185,6 +199,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
            logger.error(f"Error closing database: {e}")

    # Create FastAPI app
+   # Store inference service for access by routes (e.g., model reload)
+   # This will be set after app creation
+
    app = FastAPI(
        title="Invoice Field Extraction API",
        description="""
@@ -255,9 +272,15 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
    training_router = create_training_router()
    app.include_router(training_router, prefix="/api/v1")

+   augmentation_router = create_augmentation_router()
+   app.include_router(augmentation_router, prefix="/api/v1/admin")
+
    # Include batch upload routes
    app.include_router(batch_upload_router)

+   # Store inference service in app state for access by routes
+   app.state.inference_service = inference_service
+
    # Root endpoint - serve HTML UI
    @app.get("/", response_class=HTMLResponse)
    async def root() -> str:
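The model_path_resolver hook is just a zero-argument callable that returns a path or None. A minimal sketch of how InferenceService.reload_model() could consume it, under the assumption that it re-queries the callable and swaps weights only when the path changed; this is illustrative, not the real service from the inference package.

# Illustrative stand-in for the real InferenceService; only the resolver contract is shown.
from typing import Callable, Optional


class ResolverBackedService:
    def __init__(
        self,
        default_path: str,
        model_path_resolver: Optional[Callable[[], Optional[str]]] = None,
    ) -> None:
        self._resolver = model_path_resolver
        self._current_path = default_path

    def reload_model(self) -> bool:
        """Return True only if the active path changed and weights were swapped."""
        new_path = self._resolver() if self._resolver else None
        if not new_path or new_path == self._current_path:
            return False
        # ... load weights from new_path into the detector here ...
        self._current_path = new_path
        return True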
@@ -110,6 +110,7 @@ class TrainingScheduler:
        try:
            # Get training configuration
            model_name = config.get("model_name", "yolo11n.pt")
+           base_model_path = config.get("base_model_path")  # For incremental training
            epochs = config.get("epochs", 100)
            batch_size = config.get("batch_size", 16)
            image_size = config.get("image_size", 640)
@@ -117,12 +118,31 @@
            device = config.get("device", "0")
            project_name = config.get("project_name", "invoice_fields")

+           # Get augmentation config if present
+           augmentation_config = config.get("augmentation")
+           augmentation_multiplier = config.get("augmentation_multiplier", 2)
+
+           # Determine which model to use as base
+           if base_model_path:
+               # Incremental training: use existing trained model
+               if not Path(base_model_path).exists():
+                   raise ValueError(f"Base model not found: {base_model_path}")
+               effective_model = base_model_path
+               self._db.add_training_log(
+                   task_id, "INFO",
+                   f"Incremental training from: {base_model_path}",
+               )
+           else:
+               # Train from pretrained model
+               effective_model = model_name
+
            # Use dataset if available, otherwise export from scratch
            if dataset_id:
                dataset = self._db.get_dataset(dataset_id)
                if not dataset or not dataset.dataset_path:
                    raise ValueError(f"Dataset {dataset_id} not found or has no path")
                data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
+               dataset_path = Path(dataset.dataset_path)
                self._db.add_training_log(
                    task_id, "INFO",
                    f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
@@ -132,15 +152,28 @@
                if not export_result:
                    raise ValueError("Failed to export training data")
                data_yaml = export_result["data_yaml"]
+               dataset_path = Path(data_yaml).parent
                self._db.add_training_log(
                    task_id, "INFO",
                    f"Exported {export_result['total_images']} images for training",
                )

+           # Apply augmentation if config is provided
+           if augmentation_config and self._has_enabled_augmentations(augmentation_config):
+               aug_result = self._apply_augmentation(
+                   task_id, dataset_path, augmentation_config, augmentation_multiplier
+               )
+               if aug_result:
+                   self._db.add_training_log(
+                       task_id, "INFO",
+                       f"Augmentation complete: {aug_result['augmented_images']} new images "
+                       f"(total: {aug_result['total_images']})",
+                   )
+
            # Run YOLO training
            result = self._run_yolo_training(
                task_id=task_id,
-               model_name=model_name,
+               model_name=effective_model,  # Use base model or pretrained model
                data_yaml=data_yaml,
                epochs=epochs,
                batch_size=batch_size,
@@ -159,11 +192,94 @@
            )
            self._db.add_training_log(task_id, "INFO", "Training completed successfully")

+           # Auto-create model version for the completed training
+           self._create_model_version_from_training(
+               task_id=task_id,
+               config=config,
+               dataset_id=dataset_id,
+               result=result,
+           )
+
        except Exception as e:
            logger.error(f"Training task {task_id} failed: {e}")
            self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
            raise

+   def _create_model_version_from_training(
+       self,
+       task_id: str,
+       config: dict[str, Any],
+       dataset_id: str | None,
+       result: dict[str, Any],
+   ) -> None:
+       """Create a model version entry from completed training."""
+       try:
+           model_path = result.get("model_path")
+           if not model_path:
+               logger.warning(f"No model path in training result for task {task_id}")
+               return
+
+           # Get task info for name
+           task = self._db.get_training_task(task_id)
+           task_name = task.name if task else f"Task {task_id[:8]}"
+
+           # Generate version number based on existing versions
+           existing_versions = self._db.get_model_versions(limit=1, offset=0)
+           version_count = existing_versions[1] if existing_versions else 0
+           version = f"v{version_count + 1}.0"
+
+           # Extract metrics from result
+           metrics = result.get("metrics", {})
+           metrics_mAP = metrics.get("mAP50") or metrics.get("mAP")
+           metrics_precision = metrics.get("precision")
+           metrics_recall = metrics.get("recall")
+
+           # Get file size if possible
+           file_size = None
+           model_file = Path(model_path)
+           if model_file.exists():
+               file_size = model_file.stat().st_size
+
+           # Get document count from dataset if available
+           document_count = 0
+           if dataset_id:
+               dataset = self._db.get_dataset(dataset_id)
+               if dataset:
+                   document_count = dataset.total_documents
+
+           # Create model version
+           model_version = self._db.create_model_version(
+               version=version,
+               name=task_name,
+               model_path=str(model_path),
+               description=f"Auto-created from training task {task_id[:8]}",
+               task_id=task_id,
+               dataset_id=dataset_id,
+               metrics_mAP=metrics_mAP,
+               metrics_precision=metrics_precision,
+               metrics_recall=metrics_recall,
+               document_count=document_count,
+               training_config=config,
+               file_size=file_size,
+               trained_at=datetime.utcnow(),
+           )
+
+           logger.info(
+               f"Created model version {version} (ID: {model_version.version_id}) "
+               f"from training task {task_id}"
+           )
+           self._db.add_training_log(
+               task_id, "INFO",
+               f"Model version {version} created "
+               f"(mAP: {f'{metrics_mAP:.3f}' if metrics_mAP else 'N/A'})",
+           )
+
+       except Exception as e:
+           logger.error(f"Failed to create model version for task {task_id}: {e}")
+           self._db.add_training_log(
+               task_id, "WARNING",
+               f"Failed to auto-create model version: {e}",
+           )
+
    def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
        """Export training data for a task."""
        from pathlib import Path
@@ -256,62 +372,82 @@ names: {list(FIELD_CLASSES.values())}
        device: str,
        project_name: str,
    ) -> dict[str, Any]:
-       """Run YOLO training."""
+       """Run YOLO training using shared trainer."""
+       from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
+
+       # Create log callback that writes to DB
+       def log_callback(level: str, message: str) -> None:
+           self._db.add_training_log(task_id, level, message)
+
+       # Create shared training config
+       # Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
+       config = SharedTrainingConfig(
+           model_path=model_name,
+           data_yaml=data_yaml,
+           epochs=epochs,
+           batch_size=batch_size,
+           image_size=image_size,
+           learning_rate=learning_rate,
+           device=device,
+           project="runs/train",
+           name=f"{project_name}/task_{task_id[:8]}",
+           workers=0,
+       )
+
+       # Run training using shared trainer
+       trainer = YOLOTrainer(config=config, log_callback=log_callback)
+       result = trainer.train()
+
+       if not result.success:
+           raise ValueError(result.error or "Training failed")
+
+       return {
+           "model_path": result.model_path,
+           "metrics": result.metrics,
+       }
+
+   def _has_enabled_augmentations(self, aug_config: dict[str, Any]) -> bool:
+       """Check if any augmentations are enabled in the config."""
+       augmentation_fields = [
+           "perspective_warp", "wrinkle", "edge_damage", "stain",
+           "lighting_variation", "shadow", "gaussian_blur", "motion_blur",
+           "gaussian_noise", "salt_pepper", "paper_texture", "scanner_artifacts",
+       ]
+       for field in augmentation_fields:
+           if field in aug_config:
+               field_config = aug_config[field]
+               if isinstance(field_config, dict) and field_config.get("enabled", False):
+                   return True
+       return False
+
+   def _apply_augmentation(
+       self,
+       task_id: str,
+       dataset_path: Path,
+       aug_config: dict[str, Any],
+       multiplier: int,
+   ) -> dict[str, int] | None:
+       """Apply augmentation to dataset before training."""
        try:
-           from ultralytics import YOLO
+           from shared.augmentation import DatasetAugmenter
-
-           # Log training start
-           self._db.add_training_log(
-               task_id, "INFO",
-               f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
-           )
-
-           # Load model
-           model = YOLO(model_name)
-
-           # Train
-           results = model.train(
-               data=data_yaml,
-               epochs=epochs,
-               batch=batch_size,
-               imgsz=image_size,
-               lr0=learning_rate,
-               device=device,
-               project=f"runs/train/{project_name}",
-               name=f"task_{task_id[:8]}",
-               exist_ok=True,
-               verbose=True,
-           )
-
-           # Get best model path
-           best_model = Path(results.save_dir) / "weights" / "best.pt"
-
-           # Extract metrics
-           metrics = {}
-           if hasattr(results, "results_dict"):
-               metrics = {
-                   "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
-                   "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
-                   "precision": results.results_dict.get("metrics/precision(B)", 0),
-                   "recall": results.results_dict.get("metrics/recall(B)", 0),
-               }
-
            self._db.add_training_log(
                task_id, "INFO",
-               f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
+               f"Applying augmentation with multiplier={multiplier}",
            )
-           return {
-               "model_path": str(best_model) if best_model.exists() else None,
-               "metrics": metrics,
-           }
+           augmenter = DatasetAugmenter(aug_config)
+           result = augmenter.augment_dataset(dataset_path, multiplier=multiplier)
+           return result
-       except ImportError:
-           self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
-           raise ValueError("Ultralytics (YOLO) not installed")
        except Exception as e:
-           self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
-           raise
+           logger.error(f"Augmentation failed for task {task_id}: {e}")
+           self._db.add_training_log(
+               task_id, "WARNING",
+               f"Augmentation failed: {e}. Continuing with original dataset.",
+           )
+           return None


# Global scheduler instance
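For reference, the shape of the augmentation block that _has_enabled_augmentations and _apply_augmentation expect once the Pydantic config has been dumped to a plain dict; the parameter values below are made up, only the structure matters.

# Hypothetical values; only the structure matters to the checks above.
augmentation_config = {
    "gaussian_noise": {"enabled": True, "probability": 0.5, "params": {"mean": 0, "std": 15}},
    "motion_blur": {"enabled": False, "probability": 0.3, "params": {}},
    "preserve_bboxes": True,
    "seed": 42,
}
# _has_enabled_augmentations(augmentation_config) -> True, because gaussian_noise is enabled.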
@@ -10,6 +10,7 @@ from .documents import *  # noqa: F401, F403
from .annotations import *  # noqa: F401, F403
from .training import *  # noqa: F401, F403
from .datasets import *  # noqa: F401, F403
+from .models import *  # noqa: F401, F403

# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
packages/inference/inference/web/schemas/admin/augmentation.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""Admin Augmentation Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class AugmentationParamsSchema(BaseModel):
    """Single augmentation parameters."""

    enabled: bool = Field(default=False, description="Whether this augmentation is enabled")
    probability: float = Field(
        default=0.5, ge=0, le=1, description="Probability of applying (0-1)"
    )
    params: dict[str, Any] = Field(
        default_factory=dict, description="Type-specific parameters"
    )


class AugmentationConfigSchema(BaseModel):
    """Complete augmentation configuration."""

    # Geometric transforms
    perspective_warp: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Degradation effects
    wrinkle: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)
    edge_damage: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    stain: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)

    # Lighting effects
    lighting_variation: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    shadow: AugmentationParamsSchema = Field(default_factory=AugmentationParamsSchema)

    # Blur effects
    gaussian_blur: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    motion_blur: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Noise effects
    gaussian_noise: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    salt_pepper: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Texture effects
    paper_texture: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )
    scanner_artifacts: AugmentationParamsSchema = Field(
        default_factory=AugmentationParamsSchema
    )

    # Global settings
    preserve_bboxes: bool = Field(
        default=True, description="Whether to adjust bboxes for geometric transforms"
    )
    seed: int | None = Field(default=None, description="Random seed for reproducibility")


class AugmentationTypeInfo(BaseModel):
    """Information about an augmentation type."""

    name: str = Field(..., description="Augmentation name")
    description: str = Field(..., description="Augmentation description")
    affects_geometry: bool = Field(
        ..., description="Whether this augmentation affects bbox coordinates"
    )
    stage: str = Field(..., description="Processing stage")
    default_params: dict[str, Any] = Field(
        default_factory=dict, description="Default parameters"
    )


class AugmentationTypesResponse(BaseModel):
    """Response for listing augmentation types."""

    augmentation_types: list[AugmentationTypeInfo] = Field(
        ..., description="Available augmentation types"
    )


class PresetInfo(BaseModel):
    """Information about a preset."""

    name: str = Field(..., description="Preset name")
    description: str = Field(..., description="Preset description")


class PresetsResponse(BaseModel):
    """Response for listing presets."""

    presets: list[PresetInfo] = Field(..., description="Available presets")


class AugmentationPreviewRequest(BaseModel):
    """Request to preview augmentation on an image."""

    augmentation_type: str = Field(..., description="Type of augmentation to preview")
    params: dict[str, Any] = Field(
        default_factory=dict, description="Override parameters"
    )


class AugmentationPreviewResponse(BaseModel):
    """Response with preview image data."""

    preview_url: str = Field(..., description="URL to preview image")
    original_url: str = Field(..., description="URL to original image")
    applied_params: dict[str, Any] = Field(..., description="Applied parameters")


class AugmentationBatchRequest(BaseModel):
    """Request to augment a dataset offline."""

    dataset_id: str = Field(..., description="Source dataset UUID")
    config: AugmentationConfigSchema = Field(..., description="Augmentation config")
    output_name: str = Field(
        ..., min_length=1, max_length=255, description="Output dataset name"
    )
    multiplier: int = Field(
        default=2, ge=1, le=10, description="Augmented copies per image"
    )


class AugmentationBatchResponse(BaseModel):
    """Response for batch augmentation."""

    task_id: str = Field(..., description="Background task UUID")
    status: str = Field(..., description="Task status")
    message: str = Field(..., description="Status message")
    estimated_images: int = Field(..., description="Estimated total images")


class AugmentedDatasetItem(BaseModel):
    """Single augmented dataset in list."""

    dataset_id: str = Field(..., description="Dataset UUID")
    source_dataset_id: str = Field(..., description="Source dataset UUID")
    name: str = Field(..., description="Dataset name")
    status: str = Field(..., description="Dataset status")
    multiplier: int = Field(..., description="Augmentation multiplier")
    total_original_images: int = Field(..., description="Original image count")
    total_augmented_images: int = Field(..., description="Augmented image count")
    created_at: datetime = Field(..., description="Creation timestamp")


class AugmentedDatasetListResponse(BaseModel):
    """Response for listing augmented datasets."""

    total: int = Field(..., ge=0, description="Total datasets")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    datasets: list[AugmentedDatasetItem] = Field(
        default_factory=list, description="Dataset list"
    )


class AugmentedDatasetDetailResponse(BaseModel):
    """Detailed augmented dataset response."""

    dataset_id: str = Field(..., description="Dataset UUID")
    source_dataset_id: str = Field(..., description="Source dataset UUID")
    name: str = Field(..., description="Dataset name")
    status: str = Field(..., description="Dataset status")
    config: AugmentationConfigSchema | None = Field(
        None, description="Augmentation config used"
    )
    multiplier: int = Field(..., description="Augmentation multiplier")
    total_original_images: int = Field(..., description="Original image count")
    total_augmented_images: int = Field(..., description="Augmented image count")
    dataset_path: str | None = Field(None, description="Dataset path on disk")
    error_message: str | None = Field(None, description="Error message if failed")
    created_at: datetime = Field(..., description="Creation timestamp")
    completed_at: datetime | None = Field(None, description="Completion timestamp")
@@ -63,6 +63,8 @@ class DatasetListItem(BaseModel):
    name: str
    description: str | None
    status: str
+   training_status: str | None = None
+   active_training_task_id: str | None = None
    total_documents: int
    total_images: int
    total_annotations: int
@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
+   group_key: str | None = Field(None, description="User-defined group key")
    auto_label_started: bool = Field(
        default=False, description="Whether auto-labeling was started"
    )
@@ -42,6 +43,7 @@ class DocumentItem(BaseModel):
    annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
+   group_key: str | None = Field(None, description="User-defined group key")
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
@@ -73,6 +75,7 @@ class DocumentDetailResponse(BaseModel):
    auto_label_error: str | None = Field(None, description="Auto-labeling error")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
+   group_key: str | None = Field(None, description="User-defined group key")
    csv_field_values: dict[str, str] | None = Field(
        None, description="CSV field values if uploaded via batch"
    )
packages/inference/inference/web/schemas/admin/models.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Admin Model Version Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field


class ModelVersionCreateRequest(BaseModel):
    """Request to create a model version."""

    version: str = Field(..., min_length=1, max_length=50, description="Semantic version")
    name: str = Field(..., min_length=1, max_length=255, description="Model name")
    model_path: str = Field(..., min_length=1, max_length=512, description="Path to model file")
    description: str | None = Field(None, description="Optional description")
    task_id: str | None = Field(None, description="Training task UUID")
    dataset_id: str | None = Field(None, description="Dataset UUID")
    metrics_mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
    metrics_recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
    document_count: int = Field(0, ge=0, description="Documents used in training")
    training_config: dict[str, Any] | None = Field(None, description="Training configuration")
    file_size: int | None = Field(None, ge=0, description="Model file size in bytes")
    trained_at: datetime | None = Field(None, description="Training completion time")


class ModelVersionUpdateRequest(BaseModel):
    """Request to update a model version."""

    name: str | None = Field(None, min_length=1, max_length=255, description="Model name")
    description: str | None = Field(None, description="Description")
    status: str | None = Field(None, description="Status (inactive, archived)")


class ModelVersionItem(BaseModel):
    """Model version in list view."""

    version_id: str = Field(..., description="Version UUID")
    version: str = Field(..., description="Semantic version")
    name: str = Field(..., description="Model name")
    status: str = Field(..., description="Status (active, inactive, archived)")
    is_active: bool = Field(..., description="Is currently active for inference")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    document_count: int = Field(..., description="Documents used in training")
    trained_at: datetime | None = Field(None, description="Training completion time")
    activated_at: datetime | None = Field(None, description="Last activation time")
    created_at: datetime = Field(..., description="Creation timestamp")


class ModelVersionListResponse(BaseModel):
    """Paginated model version list."""

    total: int = Field(..., ge=0, description="Total model versions")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    models: list[ModelVersionItem] = Field(default_factory=list, description="Model versions")


class ModelVersionDetailResponse(BaseModel):
    """Detailed model version info."""

    version_id: str = Field(..., description="Version UUID")
    version: str = Field(..., description="Semantic version")
    name: str = Field(..., description="Model name")
    description: str | None = Field(None, description="Description")
    model_path: str = Field(..., description="Path to model file")
    status: str = Field(..., description="Status (active, inactive, archived)")
    is_active: bool = Field(..., description="Is currently active for inference")
    task_id: str | None = Field(None, description="Training task UUID")
    dataset_id: str | None = Field(None, description="Dataset UUID")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, description="Precision")
    metrics_recall: float | None = Field(None, description="Recall")
    document_count: int = Field(..., description="Documents used in training")
    training_config: dict[str, Any] | None = Field(None, description="Training configuration")
    file_size: int | None = Field(None, description="Model file size in bytes")
    trained_at: datetime | None = Field(None, description="Training completion time")
    activated_at: datetime | None = Field(None, description="Last activation time")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class ModelVersionResponse(BaseModel):
    """Response for model version operation."""

    version_id: str = Field(..., description="Version UUID")
    status: str = Field(..., description="Model status")
    message: str = Field(..., description="Status message")


class ActiveModelResponse(BaseModel):
    """Response for active model query."""

    has_active_model: bool = Field(..., description="Whether an active model exists")
    model: ModelVersionItem | None = Field(None, description="Active model if exists")
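As a usage sketch, a manual registration payload built from these schemas might look like the following; the path and metric numbers are placeholders, not values taken from this repository.

# Placeholder values throughout; only the field names come from the schemas above.
example_create = ModelVersionCreateRequest(
    version="v3.0",
    name="invoice-fields-finetune",
    model_path="runs/train/invoice_fields/task_abc12345/weights/best.pt",
    description="Fine-tuned on a later dataset",
    metrics_mAP=0.91,
    metrics_precision=0.89,
    metrics_recall=0.87,
    document_count=420,
)
print(example_create.model_dump(exclude_none=True))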
@@ -5,13 +5,18 @@ from typing import Any

from pydantic import BaseModel, Field

+from .augmentation import AugmentationConfigSchema
from .enums import TrainingStatus, TrainingType


class TrainingConfig(BaseModel):
    """Training configuration."""

-   model_name: str = Field(default="yolo11n.pt", description="Base model name")
+   model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
+   base_model_version_id: str | None = Field(
+       default=None,
+       description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
+   )
    epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
    batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
    image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
@@ -21,6 +26,18 @@ class TrainingConfig(BaseModel):
        default="invoice_fields", description="Training project name"
    )

+   # Data augmentation settings
+   augmentation: AugmentationConfigSchema | None = Field(
+       default=None,
+       description="Augmentation configuration. If provided, augments dataset before training.",
+   )
+   augmentation_multiplier: int = Field(
+       default=2,
+       ge=1,
+       le=10,
+       description="Number of augmented copies per original image",
+   )


class TrainingTaskCreate(BaseModel):
    """Request to create a training task."""
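A hedged example of the new fields in use: an incremental run that starts from an existing model version and augments the dataset first. The version UUID is a placeholder and the numeric values are illustrative only.

# Placeholder UUID and illustrative values; field names follow TrainingConfig above.
incremental_config = TrainingConfig(
    base_model_version_id="<existing-model-version-uuid>",
    epochs=50,
    batch_size=8,
    augmentation=AugmentationConfigSchema(
        gaussian_noise=AugmentationParamsSchema(enabled=True, probability=0.5),
        motion_blur=AugmentationParamsSchema(enabled=True, probability=0.3),
    ),
    augmentation_multiplier=3,
)
# The datasets route resolves base_model_version_id to base_model_path before queuing the task.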
@@ -0,0 +1,317 @@
|
|||||||
|
"""Augmentation service for handling augmentation operations."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
from inference.web.schemas.admin.augmentation import (
|
||||||
|
AugmentationBatchResponse,
|
||||||
|
AugmentationConfigSchema,
|
||||||
|
AugmentationPreviewResponse,
|
||||||
|
AugmentedDatasetItem,
|
||||||
|
AugmentedDatasetListResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
PREVIEW_MAX_SIZE = 800
|
||||||
|
PREVIEW_SEED = 42
|
||||||
|
UUID_PATTERN = re.compile(
|
||||||
|
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AugmentationService:
|
||||||
|
"""Service for augmentation operations."""
|
||||||
|
|
||||||
|
def __init__(self, db: AdminDB) -> None:
|
||||||
|
"""Initialize service with database connection."""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def _validate_uuid(self, value: str, field_name: str = "ID") -> None:
|
||||||
|
"""
|
||||||
|
Validate UUID format to prevent path traversal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate.
|
||||||
|
field_name: Field name for error message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If value is not a valid UUID.
|
||||||
|
"""
|
||||||
|
if not UUID_PATTERN.match(value):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid {field_name} format: {value}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def preview_single(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page: int,
|
||||||
|
augmentation_type: str,
|
||||||
|
params: dict[str, Any],
|
||||||
|
) -> AugmentationPreviewResponse:
|
||||||
|
"""
|
||||||
|
Preview a single augmentation on a document page.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document UUID.
|
||||||
|
page: Page number (1-indexed).
|
||||||
|
augmentation_type: Name of augmentation to apply.
|
||||||
|
params: Override parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preview response with image URLs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If document not found or augmentation invalid.
|
||||||
|
"""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AUGMENTATION_REGISTRY, AugmentationPipeline
|
||||||
|
|
||||||
|
# Validate augmentation type
|
||||||
|
if augmentation_type not in AUGMENTATION_REGISTRY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unknown augmentation type: {augmentation_type}. "
|
||||||
|
f"Available: {list(AUGMENTATION_REGISTRY.keys())}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get document and load image
|
||||||
|
image = await self._load_document_page(document_id, page)
|
||||||
|
|
||||||
|
# Create config with only this augmentation enabled
|
||||||
|
config_kwargs = {
|
||||||
|
augmentation_type: AugmentationParams(
|
||||||
|
enabled=True,
|
||||||
|
probability=1.0, # Always apply for preview
|
||||||
|
params=params,
|
||||||
|
),
|
||||||
|
"seed": PREVIEW_SEED, # Deterministic preview
|
||||||
|
}
|
||||||
|
config = AugmentationConfig(**config_kwargs)
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
# Apply augmentation
|
||||||
|
result = pipeline.apply(image)
|
||||||
|
|
||||||
|
# Convert to base64 URLs
|
||||||
|
original_url = self._image_to_data_url(image)
|
||||||
|
preview_url = self._image_to_data_url(result.image)
|
||||||
|
|
||||||
|
return AugmentationPreviewResponse(
|
||||||
|
preview_url=preview_url,
|
||||||
|
original_url=original_url,
|
||||||
|
applied_params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def preview_config(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page: int,
|
||||||
|
config: AugmentationConfigSchema,
|
||||||
|
) -> AugmentationPreviewResponse:
|
||||||
|
"""
|
||||||
|
Preview full augmentation config on a document page.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document UUID.
|
||||||
|
page: Page number (1-indexed).
|
||||||
|
config: Full augmentation configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preview response with image URLs.
|
||||||
|
"""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
# Load image
|
||||||
|
image = await self._load_document_page(document_id, page)
|
||||||
|
|
||||||
|
# Convert Pydantic model to internal config
|
||||||
|
config_dict = config.model_dump()
|
||||||
|
internal_config = AugmentationConfig.from_dict(config_dict)
|
||||||
|
pipeline = AugmentationPipeline(internal_config)
|
||||||
|
|
||||||
|
# Apply augmentation
|
||||||
|
result = pipeline.apply(image)
|
||||||
|
|
||||||
|
# Convert to base64 URLs
|
||||||
|
original_url = self._image_to_data_url(image)
|
||||||
|
preview_url = self._image_to_data_url(result.image)
|
||||||
|
|
||||||
|
return AugmentationPreviewResponse(
|
||||||
|
preview_url=preview_url,
|
||||||
|
original_url=original_url,
|
||||||
|
applied_params=config_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_augmented_dataset(
|
||||||
|
self,
|
||||||
|
source_dataset_id: str,
|
||||||
|
config: AugmentationConfigSchema,
|
||||||
|
output_name: str,
|
||||||
|
multiplier: int,
|
||||||
|
) -> AugmentationBatchResponse:
|
||||||
|
"""
|
||||||
|
Create a new augmented dataset from an existing dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_dataset_id: Source dataset UUID.
|
||||||
|
config: Augmentation configuration.
|
||||||
|
output_name: Name for the new dataset.
|
||||||
|
multiplier: Number of augmented copies per image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Batch response with task ID.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If source dataset not found.
|
||||||
|
"""
|
||||||
|
# Validate source dataset exists
|
||||||
|
try:
|
||||||
|
source_dataset = self.db.get_dataset(source_dataset_id)
|
||||||
|
if source_dataset is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Source dataset not found: {source_dataset_id}",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Source dataset not found: {source_dataset_id}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Create task ID for background processing
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Estimate total images
|
||||||
|
estimated_images = (
|
||||||
|
source_dataset.total_images * multiplier
|
||||||
|
if hasattr(source_dataset, "total_images")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Queue background task for actual augmentation
|
||||||
|
# For now, return pending status
|
||||||
|
|
||||||
|
return AugmentationBatchResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
status="pending",
|
||||||
|
message=f"Augmentation task queued for dataset '{output_name}'",
|
||||||
|
estimated_images=estimated_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_augmented_datasets(
|
||||||
|
self,
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> AugmentedDatasetListResponse:
|
||||||
|
"""
|
||||||
|
List augmented datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of datasets to return.
|
||||||
|
offset: Number of datasets to skip.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List response with datasets.
|
||||||
|
"""
|
||||||
|
# TODO: Implement actual database query for augmented datasets
|
||||||
|
# For now, return empty list
|
||||||
|
|
||||||
|
return AugmentedDatasetListResponse(
|
||||||
|
total=0,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
datasets=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _load_document_page(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Load a document page as numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document UUID.
|
||||||
|
page: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image as numpy array (H, W, C) with dtype uint8.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If document or page not found.
|
||||||
|
"""
|
||||||
|
# Validate document_id format to prevent path traversal
|
||||||
|
self._validate_uuid(document_id, "document_id")
|
||||||
|
|
||||||
|
# Get document from database
|
||||||
|
try:
|
||||||
|
document = self.db.get_document(document_id)
|
||||||
|
if document is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Document not found: {document_id}",
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Document not found: {document_id}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Get image path for page
|
||||||
|
if hasattr(document, "images_dir"):
|
||||||
|
images_dir = Path(document.images_dir)
|
||||||
|
else:
|
||||||
|
# Fallback to constructed path
|
||||||
|
from inference.web.core.config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
images_dir = Path(settings.admin_storage_path) / "documents" / document_id / "images"
|
||||||
|
|
||||||
|
# Find image for page
|
||||||
|
page_idx = page - 1 # Convert to 0-indexed
|
||||||
|
image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
|
||||||
|
|
||||||
|
if page_idx >= len(image_files):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Page {page} not found for document {document_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load image
|
||||||
|
image_path = image_files[page_idx]
|
||||||
|
pil_image = Image.open(image_path).convert("RGB")
|
||||||
|
return np.array(pil_image)
|
||||||
|
|
||||||
|
def _image_to_data_url(self, image: np.ndarray) -> str:
|
||||||
|
"""Convert numpy image to base64 data URL."""
|
||||||
|
pil_image = Image.fromarray(image)
|
||||||
|
|
||||||
|
# Resize for preview if too large
|
||||||
|
max_size = PREVIEW_MAX_SIZE
|
||||||
|
if max(pil_image.size) > max_size:
|
||||||
|
ratio = max_size / max(pil_image.size)
|
||||||
|
new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))
|
||||||
|
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# Convert to base64
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
pil_image.save(buffer, format="PNG")
|
||||||
|
base64_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
return f"data:image/png;base64,{base64_data}"
|
||||||
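Note: a minimal sketch of how a consumer (e.g. a frontend test) can decode the base64 data URLs returned by the preview response above; the helper below is illustrative and not part of this commit.

import base64
import io

from PIL import Image


def data_url_to_image(data_url: str) -> Image.Image:
    """Decode a 'data:image/png;base64,...' string back into a PIL image."""
    header, encoded = data_url.split(",", 1)
    if not header.startswith("data:image/png;base64"):
        raise ValueError(f"Unexpected data URL header: {header}")
    return Image.open(io.BytesIO(base64.b64decode(encoded)))


# Given a preview response from the endpoint above:
#   original = data_url_to_image(response.original_url)
#   preview = data_url_to_image(response.preview_url)
# The long edge of both images is capped at PREVIEW_MAX_SIZE by _image_to_data_url.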
@@ -81,29 +81,18 @@ class DatasetBuilder:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

- # 3. Shuffle and split documents
+ # 3. Group documents by group_key and assign splits
doc_list = list(documents)
- rng = random.Random(seed)
- rng.shuffle(doc_list)
-
- n = len(doc_list)
- n_train = max(1, round(n * train_ratio))
- n_val = max(0, round(n * val_ratio))
- n_test = n - n_train - n_val
-
- splits = (
-     ["train"] * n_train
-     + ["val"] * n_val
-     + ["test"] * n_test
- )
+ doc_splits = self._assign_splits_by_group(doc_list, train_ratio, val_ratio, seed)

# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []

- for doc, split in zip(doc_list, splits):
+ for doc in doc_list:
doc_id = str(doc.document_id)
+ split = doc_splits[doc_id]
annotations = self._db.get_annotations_for_document(doc.document_id)

# Group annotations by page
@@ -174,6 +163,86 @@ class DatasetBuilder:
|
|||||||
"total_annotations": total_annotations,
|
"total_annotations": total_annotations,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _assign_splits_by_group(
|
||||||
|
self,
|
||||||
|
documents: list,
|
||||||
|
train_ratio: float,
|
||||||
|
val_ratio: float,
|
||||||
|
seed: int,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Assign splits based on group_key.
|
||||||
|
|
||||||
|
Logic:
|
||||||
|
- Documents with same group_key stay together in the same split
|
||||||
|
- Groups with only 1 document go directly to train
|
||||||
|
- Groups with 2+ documents participate in shuffle & split
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: List of AdminDocument objects
|
||||||
|
train_ratio: Fraction for training set
|
||||||
|
val_ratio: Fraction for validation set
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping document_id (str) -> split ("train"/"val"/"test")
|
||||||
|
"""
|
||||||
|
# Group documents by group_key
|
||||||
|
# None/empty group_key treated as unique (each doc is its own group)
|
||||||
|
groups: dict[str | None, list] = {}
|
||||||
|
for doc in documents:
|
||||||
|
key = doc.group_key if doc.group_key else None
|
||||||
|
if key is None:
|
||||||
|
# Treat each ungrouped doc as its own unique group
|
||||||
|
# Use document_id as pseudo-key
|
||||||
|
key = f"__ungrouped_{doc.document_id}"
|
||||||
|
groups.setdefault(key, []).append(doc)
|
||||||
|
|
||||||
|
# Separate single-doc groups from multi-doc groups
|
||||||
|
single_doc_groups: list[tuple[str | None, list]] = []
|
||||||
|
multi_doc_groups: list[tuple[str | None, list]] = []
|
||||||
|
|
||||||
|
for key, docs in groups.items():
|
||||||
|
if len(docs) == 1:
|
||||||
|
single_doc_groups.append((key, docs))
|
||||||
|
else:
|
||||||
|
multi_doc_groups.append((key, docs))
|
||||||
|
|
||||||
|
# Initialize result mapping
|
||||||
|
doc_splits: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Combine all groups for splitting
|
||||||
|
all_groups = single_doc_groups + multi_doc_groups
|
||||||
|
|
||||||
|
# Shuffle all groups and assign splits
|
||||||
|
if all_groups:
|
||||||
|
rng = random.Random(seed)
|
||||||
|
rng.shuffle(all_groups)
|
||||||
|
|
||||||
|
n_groups = len(all_groups)
|
||||||
|
n_train = max(1, round(n_groups * train_ratio))
|
||||||
|
# Ensure at least 1 in val if we have more than 1 group
|
||||||
|
n_val = max(1 if n_groups > 1 else 0, round(n_groups * val_ratio))
|
||||||
|
|
||||||
|
for i, (_key, docs) in enumerate(all_groups):
|
||||||
|
if i < n_train:
|
||||||
|
split = "train"
|
||||||
|
elif i < n_train + n_val:
|
||||||
|
split = "val"
|
||||||
|
else:
|
||||||
|
split = "test"
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
doc_splits[str(doc.document_id)] = split
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Split assignment: %d total groups shuffled (train=%d, val=%d)",
|
||||||
|
len(all_groups),
|
||||||
|
sum(1 for s in doc_splits.values() if s == "train"),
|
||||||
|
sum(1 for s in doc_splits.values() if s == "val"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return doc_splits
|
||||||
|
|
||||||
def _generate_data_yaml(self, dataset_dir: Path) -> None:
|
def _generate_data_yaml(self, dataset_dir: Path) -> None:
|
||||||
"""Generate YOLO data.yaml configuration file."""
|
"""Generate YOLO data.yaml configuration file."""
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
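Note: an illustrative sketch (not part of the commit) of the grouping guarantee that _assign_splits_by_group provides: documents sharing a group_key always land in the same split, while ungrouped documents each form their own pseudo-group. FakeDoc is only a stand-in for AdminDocument.

from dataclasses import dataclass


@dataclass
class FakeDoc:
    document_id: str
    group_key: str | None


docs = [
    FakeDoc("a1", "supplier-A"),
    FakeDoc("a2", "supplier-A"),  # same group_key as a1 -> guaranteed same split
    FakeDoc("b1", None),          # ungrouped -> treated as its own group
]

# splits = builder._assign_splits_by_group(docs, train_ratio=0.8, val_ratio=0.1, seed=42)
# assert splits["a1"] == splits["a2"]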
@@ -11,7 +11,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Callable

import numpy as np
from PIL import Image
@@ -22,6 +22,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


+ # Type alias for model path resolver function
+ ModelPathResolver = Callable[[], Path | None]
+
+
@dataclass
class ServiceResult:
    """Result from inference service."""
@@ -42,25 +46,52 @@ class InferenceService:
|
|||||||
Service for running invoice field extraction.
|
Service for running invoice field extraction.
|
||||||
|
|
||||||
Encapsulates YOLO detection and OCR extraction logic.
|
Encapsulates YOLO detection and OCR extraction logic.
|
||||||
|
Supports dynamic model loading from database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
storage_config: StorageConfig,
|
storage_config: StorageConfig,
|
||||||
|
model_path_resolver: ModelPathResolver | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize inference service.
|
Initialize inference service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_config: Model configuration
|
model_config: Model configuration (default model settings)
|
||||||
storage_config: Storage configuration
|
storage_config: Storage configuration
|
||||||
|
model_path_resolver: Optional function to resolve model path from database.
|
||||||
|
If provided, will be called to get active model path.
|
||||||
|
If returns None, falls back to model_config.model_path.
|
||||||
"""
|
"""
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.storage_config = storage_config
|
self.storage_config = storage_config
|
||||||
|
self._model_path_resolver = model_path_resolver
|
||||||
self._pipeline = None
|
self._pipeline = None
|
||||||
self._detector = None
|
self._detector = None
|
||||||
self._is_initialized = False
|
self._is_initialized = False
|
||||||
|
self._current_model_path: Path | None = None
|
||||||
|
|
||||||
|
def _resolve_model_path(self) -> Path:
|
||||||
|
"""Resolve the model path to use for inference.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. Active model from database (via resolver)
|
||||||
|
2. Default model from config
|
||||||
|
"""
|
||||||
|
if self._model_path_resolver:
|
||||||
|
try:
|
||||||
|
db_model_path = self._model_path_resolver()
|
||||||
|
if db_model_path and Path(db_model_path).exists():
|
||||||
|
logger.info(f"Using active model from database: {db_model_path}")
|
||||||
|
return Path(db_model_path)
|
||||||
|
elif db_model_path:
|
||||||
|
logger.warning(f"Active model path does not exist: {db_model_path}, falling back to default")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to resolve model path from database: {e}, falling back to default")
|
||||||
|
|
||||||
|
return self.model_config.model_path
|
||||||
|
|
||||||
def initialize(self) -> None:
|
def initialize(self) -> None:
|
||||||
"""Initialize the inference pipeline (lazy loading)."""
|
"""Initialize the inference pipeline (lazy loading)."""
|
||||||
@@ -74,16 +105,20 @@ class InferenceService:
|
|||||||
from inference.pipeline.pipeline import InferencePipeline
|
from inference.pipeline.pipeline import InferencePipeline
|
||||||
from inference.pipeline.yolo_detector import YOLODetector
|
from inference.pipeline.yolo_detector import YOLODetector
|
||||||
|
|
||||||
|
# Resolve model path (from DB or config)
|
||||||
|
model_path = self._resolve_model_path()
|
||||||
|
self._current_model_path = model_path
|
||||||
|
|
||||||
# Initialize YOLO detector for visualization
|
# Initialize YOLO detector for visualization
|
||||||
self._detector = YOLODetector(
|
self._detector = YOLODetector(
|
||||||
str(self.model_config.model_path),
|
str(model_path),
|
||||||
confidence_threshold=self.model_config.confidence_threshold,
|
confidence_threshold=self.model_config.confidence_threshold,
|
||||||
device="cuda" if self.model_config.use_gpu else "cpu",
|
device="cuda" if self.model_config.use_gpu else "cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize full pipeline
|
# Initialize full pipeline
|
||||||
self._pipeline = InferencePipeline(
|
self._pipeline = InferencePipeline(
|
||||||
model_path=str(self.model_config.model_path),
|
model_path=str(model_path),
|
||||||
confidence_threshold=self.model_config.confidence_threshold,
|
confidence_threshold=self.model_config.confidence_threshold,
|
||||||
use_gpu=self.model_config.use_gpu,
|
use_gpu=self.model_config.use_gpu,
|
||||||
dpi=self.model_config.dpi,
|
dpi=self.model_config.dpi,
|
||||||
@@ -92,12 +127,36 @@ class InferenceService:
|
|||||||
|
|
||||||
self._is_initialized = True
|
self._is_initialized = True
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
logger.info(f"Inference service initialized in {elapsed:.2f}s")
|
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize inference service: {e}")
|
logger.error(f"Failed to initialize inference service: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def reload_model(self) -> bool:
|
||||||
|
"""Reload the model if active model has changed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if model was reloaded, False if no change needed.
|
||||||
|
"""
|
||||||
|
new_model_path = self._resolve_model_path()
|
||||||
|
|
||||||
|
if self._current_model_path == new_model_path:
|
||||||
|
logger.debug("Model unchanged, no reload needed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"Reloading model: {self._current_model_path} -> {new_model_path}")
|
||||||
|
self._is_initialized = False
|
||||||
|
self._pipeline = None
|
||||||
|
self._detector = None
|
||||||
|
self.initialize()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_model_path(self) -> Path | None:
|
||||||
|
"""Get the currently loaded model path."""
|
||||||
|
return self._current_model_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_initialized(self) -> bool:
|
def is_initialized(self) -> bool:
|
||||||
"""Check if service is initialized."""
|
"""Check if service is initialized."""
|
||||||
|
|||||||
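Note: one possible wiring of the new model_path_resolver hook (illustrative only; the admin-database lookup below is assumed, not defined in this commit). The service still falls back to model_config.model_path when the resolver returns None or a missing file.

from pathlib import Path


def resolve_active_model() -> Path | None:
    # Hypothetical lookup of the model currently marked as active in the admin database.
    record = admin_db.get_active_model()
    return Path(record.model_path) if record else None


service = InferenceService(
    model_config=model_config,
    storage_config=storage_config,
    model_path_resolver=resolve_active_model,
)
service.initialize()

# After a different model is activated, e.g. from the admin UI:
service.reload_model()  # True only if the resolved path actually changed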
packages/shared/shared/augmentation/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
|
||||||
|
Document Image Augmentation Module.
|
||||||
|
|
||||||
|
Provides augmentation transformations for training data enhancement,
|
||||||
|
specifically designed for document images (invoices, forms, etc.).
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Document-safe augmentations that preserve text readability
|
||||||
|
- Support for both offline preprocessing and runtime augmentation
|
||||||
|
- Bbox-aware geometric transforms
|
||||||
|
- Configurable augmentation pipeline
|
||||||
|
"""
|
||||||
|
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.dataset_augmenter import DatasetAugmenter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AugmentationConfig",
|
||||||
|
"AugmentationParams",
|
||||||
|
"AugmentationResult",
|
||||||
|
"BaseAugmentation",
|
||||||
|
"DatasetAugmenter",
|
||||||
|
]
|
||||||
packages/shared/shared/augmentation/base.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
|
||||||
|
Base classes for augmentation transforms.
|
||||||
|
|
||||||
|
Provides abstract base class and result dataclass for all augmentation
|
||||||
|
implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AugmentationResult:
|
||||||
|
"""
|
||||||
|
Result of applying an augmentation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
image: The augmented image as numpy array (H, W, C).
|
||||||
|
bboxes: Updated bounding boxes if geometric transform was applied.
|
||||||
|
Format: (N, 5) array with [class_id, x_center, y_center, width, height].
|
||||||
|
transform_matrix: The transformation matrix if applicable (for bbox adjustment).
|
||||||
|
applied: Whether the augmentation was actually applied.
|
||||||
|
metadata: Additional metadata about the augmentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image: np.ndarray
|
||||||
|
bboxes: np.ndarray | None = None
|
||||||
|
transform_matrix: np.ndarray | None = None
|
||||||
|
applied: bool = True
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAugmentation(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all augmentations.
|
||||||
|
|
||||||
|
Subclasses must implement:
|
||||||
|
- _validate_params(): Validate augmentation parameters
|
||||||
|
- apply(): Apply the augmentation to an image
|
||||||
|
|
||||||
|
Class attributes:
|
||||||
|
name: Human-readable name of the augmentation.
|
||||||
|
affects_geometry: True if this augmentation modifies bbox coordinates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "base"
|
||||||
|
affects_geometry: bool = False
|
||||||
|
|
||||||
|
def __init__(self, params: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Initialize augmentation with parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Dictionary of augmentation-specific parameters.
|
||||||
|
"""
|
||||||
|
self.params = params
|
||||||
|
self._validate_params()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate augmentation parameters.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parameters are invalid.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
"""
|
||||||
|
Apply augmentation to image.
|
||||||
|
|
||||||
|
IMPORTANT: Implementations must NOT modify the input image or bboxes.
|
||||||
|
Always create copies before modifying.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image as numpy array (H, W, C) with dtype uint8.
|
||||||
|
bboxes: Optional bounding boxes in YOLO format (N, 5) array.
|
||||||
|
Each row: [class_id, x_center, y_center, width, height].
|
||||||
|
Coordinates are normalized to 0-1 range.
|
||||||
|
rng: Random number generator for reproducibility.
|
||||||
|
If None, a new generator should be created.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AugmentationResult with augmented image and optionally updated bboxes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_preview_params(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get parameters optimized for preview display.
|
||||||
|
|
||||||
|
Override this method to provide parameters that produce
|
||||||
|
clearly visible effects for preview/demo purposes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of preview parameters.
|
||||||
|
"""
|
||||||
|
return dict(self.params)
|
||||||
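Note: to make the BaseAugmentation contract concrete, a minimal illustrative subclass (not shipped in this module): it validates its parameters up front, never mutates the input, and copies the bboxes because it does not affect geometry.

import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class Invert(BaseAugmentation):
    """Blends the image toward its negative (example only)."""

    name = "invert"
    affects_geometry = False

    def _validate_params(self) -> None:
        strength = self.params.get("strength", 1.0)
        if not (0.0 <= strength <= 1.0):
            raise ValueError("strength must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        strength = self.params.get("strength", 1.0)
        blended = image.astype(np.float32) * (1 - strength) + (255 - image.astype(np.float32)) * strength
        return AugmentationResult(
            image=blended.astype(np.uint8),
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"strength": strength},
        )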
packages/shared/shared/augmentation/config.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""
|
||||||
|
Augmentation configuration module.
|
||||||
|
|
||||||
|
Provides dataclasses for configuring document image augmentations.
|
||||||
|
All default values are document-safe (conservative) to preserve text readability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AugmentationParams:
|
||||||
|
"""
|
||||||
|
Parameters for a single augmentation type.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: Whether this augmentation is enabled.
|
||||||
|
probability: Probability of applying this augmentation (0.0 to 1.0).
|
||||||
|
params: Type-specific parameters dictionary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
probability: float = 0.5
|
||||||
|
params: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"probability": self.probability,
|
||||||
|
"params": dict(self.params),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
enabled=data.get("enabled", False),
|
||||||
|
probability=data.get("probability", 0.5),
|
||||||
|
params=dict(data.get("params", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_perspective_warp() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.3,
|
||||||
|
params={"max_warp": 0.02}, # Very conservative - 2% max distortion
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_wrinkle() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.3,
|
||||||
|
params={"intensity": 0.3, "num_wrinkles": (2, 5)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_edge_damage() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={"max_damage_ratio": 0.05}, # Max 5% of edge damaged
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_stain() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={
|
||||||
|
"num_stains": (1, 3),
|
||||||
|
"max_radius_ratio": 0.1,
|
||||||
|
"opacity": (0.1, 0.3),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_lighting_variation() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=True, # Safe default, commonly needed
|
||||||
|
probability=0.5,
|
||||||
|
params={
|
||||||
|
"brightness_range": (-0.1, 0.1),
|
||||||
|
"contrast_range": (0.9, 1.1),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_shadow() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.3,
|
||||||
|
params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_gaussian_blur() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_motion_blur() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_gaussian_noise() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.3,
|
||||||
|
params={"mean": 0, "std": (5, 15)}, # Conservative noise levels
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_salt_pepper() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={"amount": (0.001, 0.005)}, # Very sparse
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_paper_texture() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.3,
|
||||||
|
params={"texture_type": "random", "intensity": (0.05, 0.15)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_scanner_artifacts() -> AugmentationParams:
|
||||||
|
return AugmentationParams(
|
||||||
|
enabled=False,
|
||||||
|
probability=0.2,
|
||||||
|
params={"line_probability": 0.3, "dust_probability": 0.4},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AugmentationConfig:
|
||||||
|
"""
|
||||||
|
Complete augmentation configuration.
|
||||||
|
|
||||||
|
All augmentation types have document-safe defaults that preserve
|
||||||
|
text readability. Only lighting_variation is enabled by default.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
perspective_warp: Geometric perspective transform (affects bboxes).
|
||||||
|
wrinkle: Paper wrinkle/crease simulation.
|
||||||
|
edge_damage: Damaged/torn edge effects.
|
||||||
|
stain: Coffee stain/smudge effects.
|
||||||
|
lighting_variation: Brightness and contrast variation.
|
||||||
|
shadow: Shadow overlay effects.
|
||||||
|
gaussian_blur: Gaussian blur for focus issues.
|
||||||
|
motion_blur: Motion blur simulation.
|
||||||
|
gaussian_noise: Gaussian noise for sensor noise.
|
||||||
|
salt_pepper: Salt and pepper noise.
|
||||||
|
paper_texture: Paper texture overlay.
|
||||||
|
scanner_artifacts: Scanner line and dust artifacts.
|
||||||
|
preserve_bboxes: Whether to adjust bboxes for geometric transforms.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Geometric transforms (affects bboxes)
|
||||||
|
perspective_warp: AugmentationParams = field(
|
||||||
|
default_factory=_default_perspective_warp
|
||||||
|
)
|
||||||
|
|
||||||
|
# Degradation effects
|
||||||
|
wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
|
||||||
|
edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
|
||||||
|
stain: AugmentationParams = field(default_factory=_default_stain)
|
||||||
|
|
||||||
|
# Lighting effects
|
||||||
|
lighting_variation: AugmentationParams = field(
|
||||||
|
default_factory=_default_lighting_variation
|
||||||
|
)
|
||||||
|
shadow: AugmentationParams = field(default_factory=_default_shadow)
|
||||||
|
|
||||||
|
# Blur effects
|
||||||
|
gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
|
||||||
|
motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)
|
||||||
|
|
||||||
|
# Noise effects
|
||||||
|
gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
|
||||||
|
salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)
|
||||||
|
|
||||||
|
# Texture effects
|
||||||
|
paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
|
||||||
|
scanner_artifacts: AugmentationParams = field(
|
||||||
|
default_factory=_default_scanner_artifacts
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global settings
|
||||||
|
preserve_bboxes: bool = True
|
||||||
|
seed: int | None = None
|
||||||
|
|
||||||
|
# List of all augmentation field names
|
||||||
|
_AUGMENTATION_FIELDS: tuple[str, ...] = (
|
||||||
|
"perspective_warp",
|
||||||
|
"wrinkle",
|
||||||
|
"edge_damage",
|
||||||
|
"stain",
|
||||||
|
"lighting_variation",
|
||||||
|
"shadow",
|
||||||
|
"gaussian_blur",
|
||||||
|
"motion_blur",
|
||||||
|
"gaussian_noise",
|
||||||
|
"salt_pepper",
|
||||||
|
"paper_texture",
|
||||||
|
"scanner_artifacts",
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"preserve_bboxes": self.preserve_bboxes,
|
||||||
|
"seed": self.seed,
|
||||||
|
}
|
||||||
|
|
||||||
|
for field_name in self._AUGMENTATION_FIELDS:
|
||||||
|
params: AugmentationParams = getattr(self, field_name)
|
||||||
|
result[field_name] = params.to_dict()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"preserve_bboxes": data.get("preserve_bboxes", True),
|
||||||
|
"seed": data.get("seed"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for field_name in cls._AUGMENTATION_FIELDS:
|
||||||
|
if field_name in data:
|
||||||
|
field_data = data[field_name]
|
||||||
|
if isinstance(field_data, dict):
|
||||||
|
kwargs[field_name] = AugmentationParams.from_dict(field_data)
|
||||||
|
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
def get_enabled_augmentations(self) -> list[str]:
|
||||||
|
"""Get list of enabled augmentation names."""
|
||||||
|
enabled = []
|
||||||
|
for field_name in self._AUGMENTATION_FIELDS:
|
||||||
|
params: AugmentationParams = getattr(self, field_name)
|
||||||
|
if params.enabled:
|
||||||
|
enabled.append(field_name)
|
||||||
|
return enabled
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any configuration value is invalid.
|
||||||
|
"""
|
||||||
|
for field_name in self._AUGMENTATION_FIELDS:
|
||||||
|
params: AugmentationParams = getattr(self, field_name)
|
||||||
|
if not (0.0 <= params.probability <= 1.0):
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name}.probability must be between 0 and 1, "
|
||||||
|
f"got {params.probability}"
|
||||||
|
)
|
||||||
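Note: a short round-trip sketch of the config API above (values are arbitrary examples; with the defaults, only lighting_variation is enabled until others are switched on).

from shared.augmentation.config import AugmentationConfig

cfg = AugmentationConfig.from_dict({
    "gaussian_noise": {"enabled": True, "probability": 0.4, "params": {"std": (5, 10)}},
    "seed": 123,
})
assert cfg.get_enabled_augmentations() == ["lighting_variation", "gaussian_noise"]
cfg.validate()  # raises ValueError if any probability is outside [0, 1]

restored = AugmentationConfig.from_dict(cfg.to_dict())
assert restored.gaussian_noise.enabled is True
assert restored.seed == 123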
packages/shared/shared/augmentation/dataset_augmenter.py (new file, 206 lines)
@@ -0,0 +1,206 @@
"""
|
||||||
|
Dataset Augmenter Module.
|
||||||
|
|
||||||
|
Applies augmentation pipeline to YOLO datasets,
|
||||||
|
creating new augmented images and label files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetAugmenter:
|
||||||
|
"""
|
||||||
|
Augments YOLO datasets by creating new images and label files.
|
||||||
|
|
||||||
|
Reads images from dataset/images/train/ and labels from dataset/labels/train/,
|
||||||
|
applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: dict[str, Any],
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize augmenter with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Dictionary mapping augmentation names to their settings.
|
||||||
|
Each augmentation should have 'enabled', 'probability', and 'params'.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
"""
|
||||||
|
self._config_dict = config
|
||||||
|
self._seed = seed
|
||||||
|
self._config = self._build_config(config, seed)
|
||||||
|
|
||||||
|
def _build_config(
|
||||||
|
self,
|
||||||
|
config_dict: dict[str, Any],
|
||||||
|
seed: int | None,
|
||||||
|
) -> AugmentationConfig:
|
||||||
|
"""Build AugmentationConfig from dictionary."""
|
||||||
|
kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
|
||||||
|
|
||||||
|
for aug_name, aug_settings in config_dict.items():
|
||||||
|
if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
|
||||||
|
kwargs[aug_name] = AugmentationParams(
|
||||||
|
enabled=aug_settings.get("enabled", False),
|
||||||
|
probability=aug_settings.get("probability", 0.5),
|
||||||
|
params=aug_settings.get("params", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AugmentationConfig(**kwargs)
|
||||||
|
|
||||||
|
def augment_dataset(
|
||||||
|
self,
|
||||||
|
dataset_path: Path,
|
||||||
|
multiplier: int = 1,
|
||||||
|
split: str = "train",
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Augment a YOLO dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path: Path to dataset root (containing images/ and labels/).
|
||||||
|
multiplier: Number of augmented copies per original image.
|
||||||
|
split: Which split to augment (default: "train").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary dict with original_images, augmented_images, total_images.
|
||||||
|
"""
|
||||||
|
images_dir = dataset_path / "images" / split
|
||||||
|
labels_dir = dataset_path / "labels" / split
|
||||||
|
|
||||||
|
if not images_dir.exists():
|
||||||
|
raise ValueError(f"Images directory not found: {images_dir}")
|
||||||
|
|
||||||
|
# Find all images
|
||||||
|
image_extensions = ("*.png", "*.jpg", "*.jpeg")
|
||||||
|
image_files: list[Path] = []
|
||||||
|
for ext in image_extensions:
|
||||||
|
image_files.extend(images_dir.glob(ext))
|
||||||
|
|
||||||
|
original_count = len(image_files)
|
||||||
|
augmented_count = 0
|
||||||
|
|
||||||
|
if multiplier <= 0:
|
||||||
|
return {
|
||||||
|
"original_images": original_count,
|
||||||
|
"augmented_images": 0,
|
||||||
|
"total_images": original_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process each image
|
||||||
|
for img_path in image_files:
|
||||||
|
# Load image
|
||||||
|
pil_image = Image.open(img_path).convert("RGB")
|
||||||
|
image = np.array(pil_image)
|
||||||
|
|
||||||
|
# Load corresponding label
|
||||||
|
label_path = labels_dir / f"{img_path.stem}.txt"
|
||||||
|
bboxes = self._load_bboxes(label_path) if label_path.exists() else None
|
||||||
|
|
||||||
|
# Create multiple augmented versions
|
||||||
|
for aug_idx in range(multiplier):
|
||||||
|
# Create pipeline with adjusted seed for each augmentation
|
||||||
|
aug_seed = None
|
||||||
|
if self._seed is not None:
|
||||||
|
aug_seed = self._seed + aug_idx + hash(img_path.stem) % 10000
|
||||||
|
|
||||||
|
pipeline = AugmentationPipeline(
|
||||||
|
self._build_config(self._config_dict, aug_seed)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply augmentation
|
||||||
|
result = pipeline.apply(image, bboxes)
|
||||||
|
|
||||||
|
# Save augmented image
|
||||||
|
aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
|
||||||
|
aug_img_path = images_dir / aug_name
|
||||||
|
aug_pil = Image.fromarray(result.image)
|
||||||
|
aug_pil.save(aug_img_path)
|
||||||
|
|
||||||
|
# Save augmented label
|
||||||
|
aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
|
||||||
|
self._save_bboxes(aug_label_path, result.bboxes)
|
||||||
|
|
||||||
|
augmented_count += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Dataset augmentation complete: %d original, %d augmented",
|
||||||
|
original_count,
|
||||||
|
augmented_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"original_images": original_count,
|
||||||
|
"augmented_images": augmented_count,
|
||||||
|
"total_images": original_count + augmented_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
|
||||||
|
"""
|
||||||
|
Load bounding boxes from YOLO label file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label_path: Path to label file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of shape (N, 5) with class_id, x_center, y_center, width, height.
|
||||||
|
Returns None if file is empty or doesn't exist.
|
||||||
|
"""
|
||||||
|
if not label_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = label_path.read_text().strip()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bboxes = []
|
||||||
|
for line in content.split("\n"):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) == 5:
|
||||||
|
class_id = int(parts[0])
|
||||||
|
x_center = float(parts[1])
|
||||||
|
y_center = float(parts[2])
|
||||||
|
width = float(parts[3])
|
||||||
|
height = float(parts[4])
|
||||||
|
bboxes.append([class_id, x_center, y_center, width, height])
|
||||||
|
|
||||||
|
if not bboxes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return np.array(bboxes, dtype=np.float32)
|
||||||
|
|
||||||
|
def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
|
||||||
|
"""
|
||||||
|
Save bounding boxes to YOLO label file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label_path: Path to save label file.
|
||||||
|
bboxes: Array of shape (N, 5) or None for empty labels.
|
||||||
|
"""
|
||||||
|
if bboxes is None or len(bboxes) == 0:
|
||||||
|
label_path.write_text("")
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for bbox in bboxes:
|
||||||
|
class_id = int(bbox[0])
|
||||||
|
x_center = bbox[1]
|
||||||
|
y_center = bbox[2]
|
||||||
|
width = bbox[3]
|
||||||
|
height = bbox[4]
|
||||||
|
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
||||||
|
|
||||||
|
label_path.write_text("\n".join(lines))
|
||||||
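Note: typical offline usage of the augmenter (illustrative; the dataset path is an example and must contain images/train and labels/train in YOLO layout).

from pathlib import Path

from shared.augmentation.dataset_augmenter import DatasetAugmenter
from shared.augmentation.presets import get_preset_config

augmenter = DatasetAugmenter(config=get_preset_config("moderate"), seed=42)
summary = augmenter.augment_dataset(Path("datasets/invoices_v3"), multiplier=2, split="train")
# e.g. {"original_images": 120, "augmented_images": 240, "total_images": 360}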
packages/shared/shared/augmentation/pipeline.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""
|
||||||
|
Augmentation pipeline module.
|
||||||
|
|
||||||
|
Orchestrates multiple augmentations with proper ordering and
|
||||||
|
provides preview functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
|
||||||
|
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
|
||||||
|
from shared.augmentation.transforms.geometric import PerspectiveWarp
|
||||||
|
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
|
||||||
|
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
|
||||||
|
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
|
||||||
|
|
||||||
|
# Registry of augmentation classes
|
||||||
|
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
|
||||||
|
"perspective_warp": PerspectiveWarp,
|
||||||
|
"wrinkle": Wrinkle,
|
||||||
|
"edge_damage": EdgeDamage,
|
||||||
|
"stain": Stain,
|
||||||
|
"lighting_variation": LightingVariation,
|
||||||
|
"shadow": Shadow,
|
||||||
|
"gaussian_blur": GaussianBlur,
|
||||||
|
"motion_blur": MotionBlur,
|
||||||
|
"gaussian_noise": GaussianNoise,
|
||||||
|
"salt_pepper": SaltPepper,
|
||||||
|
"paper_texture": PaperTexture,
|
||||||
|
"scanner_artifacts": ScannerArtifacts,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AugmentationPipeline:
|
||||||
|
"""
|
||||||
|
Orchestrates multiple augmentations with proper ordering.
|
||||||
|
|
||||||
|
Augmentations are applied in the following order:
|
||||||
|
1. Geometric (perspective_warp) - affects bboxes
|
||||||
|
2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
|
||||||
|
3. Lighting (lighting_variation, shadow)
|
||||||
|
4. Texture (paper_texture, scanner_artifacts)
|
||||||
|
5. Blur (gaussian_blur, motion_blur)
|
||||||
|
6. Noise (gaussian_noise, salt_pepper) - applied last
|
||||||
|
"""
|
||||||
|
|
||||||
|
STAGE_ORDER = [
|
||||||
|
"geometric",
|
||||||
|
"degradation",
|
||||||
|
"lighting",
|
||||||
|
"texture",
|
||||||
|
"blur",
|
||||||
|
"noise",
|
||||||
|
]
|
||||||
|
|
||||||
|
STAGE_MAPPING = {
|
||||||
|
"perspective_warp": "geometric",
|
||||||
|
"wrinkle": "degradation",
|
||||||
|
"edge_damage": "degradation",
|
||||||
|
"stain": "degradation",
|
||||||
|
"lighting_variation": "lighting",
|
||||||
|
"shadow": "lighting",
|
||||||
|
"paper_texture": "texture",
|
||||||
|
"scanner_artifacts": "texture",
|
||||||
|
"gaussian_blur": "blur",
|
||||||
|
"motion_blur": "blur",
|
||||||
|
"gaussian_noise": "noise",
|
||||||
|
"salt_pepper": "noise",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: AugmentationConfig) -> None:
|
||||||
|
"""
|
||||||
|
Initialize pipeline with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Augmentation configuration.
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self._rng = np.random.default_rng(config.seed)
|
||||||
|
self._augmentations = self._build_augmentations()
|
||||||
|
|
||||||
|
def _build_augmentations(
|
||||||
|
self,
|
||||||
|
) -> list[tuple[str, BaseAugmentation, float]]:
|
||||||
|
"""Build ordered list of (name, augmentation, probability) tuples."""
|
||||||
|
augmentations: list[tuple[str, BaseAugmentation, float]] = []
|
||||||
|
|
||||||
|
for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||||
|
params: AugmentationParams = getattr(self.config, aug_name)
|
||||||
|
if params.enabled:
|
||||||
|
aug = aug_class(params.params)
|
||||||
|
augmentations.append((aug_name, aug, params.probability))
|
||||||
|
|
||||||
|
# Sort by stage order
|
||||||
|
def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
|
||||||
|
name, _, _ = item
|
||||||
|
stage = self.STAGE_MAPPING[name]
|
||||||
|
return self.STAGE_ORDER.index(stage)
|
||||||
|
|
||||||
|
return sorted(augmentations, key=sort_key)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
"""
|
||||||
|
Apply augmentation pipeline to image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image (H, W, C) as numpy array with dtype uint8.
|
||||||
|
bboxes: Optional bounding boxes in YOLO format (N, 5).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AugmentationResult with augmented image and optionally adjusted bboxes.
|
||||||
|
"""
|
||||||
|
current_image = image.copy()
|
||||||
|
current_bboxes = bboxes.copy() if bboxes is not None else None
|
||||||
|
applied_augmentations: list[str] = []
|
||||||
|
|
||||||
|
for name, aug, probability in self._augmentations:
|
||||||
|
if self._rng.random() < probability:
|
||||||
|
result = aug.apply(current_image, current_bboxes, self._rng)
|
||||||
|
current_image = result.image
|
||||||
|
if result.bboxes is not None and self.config.preserve_bboxes:
|
||||||
|
current_bboxes = result.bboxes
|
||||||
|
applied_augmentations.append(name)
|
||||||
|
|
||||||
|
return AugmentationResult(
|
||||||
|
image=current_image,
|
||||||
|
bboxes=current_bboxes,
|
||||||
|
metadata={"applied_augmentations": applied_augmentations},
|
||||||
|
)
|
||||||
|
|
||||||
|
def preview(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
augmentation_name: str,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Preview a single augmentation deterministically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image.
|
||||||
|
augmentation_name: Name of augmentation to preview.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Augmented image.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If augmentation_name is not recognized.
|
||||||
|
"""
|
||||||
|
if augmentation_name not in AUGMENTATION_REGISTRY:
|
||||||
|
raise ValueError(f"Unknown augmentation: {augmentation_name}")
|
||||||
|
|
||||||
|
params: AugmentationParams = getattr(self.config, augmentation_name)
|
||||||
|
aug = AUGMENTATION_REGISTRY[augmentation_name](params.params)
|
||||||
|
|
||||||
|
# Use deterministic RNG for preview
|
||||||
|
preview_rng = np.random.default_rng(42)
|
||||||
|
result = aug.apply(image.copy(), rng=preview_rng)
|
||||||
|
return result.image
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_augmentations() -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get list of available augmentations with metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries with augmentation info.
|
||||||
|
"""
|
||||||
|
augmentations = []
|
||||||
|
for name, aug_class in AUGMENTATION_REGISTRY.items():
|
||||||
|
augmentations.append({
|
||||||
|
"name": name,
|
||||||
|
"description": aug_class.__doc__ or "",
|
||||||
|
"affects_geometry": aug_class.affects_geometry,
|
||||||
|
"stage": AugmentationPipeline.STAGE_MAPPING[name],
|
||||||
|
})
|
||||||
|
return augmentations
|
||||||
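Note: a minimal end-to-end sketch of the pipeline API above (the image path is an example).

import numpy as np
from PIL import Image

from shared.augmentation.config import AugmentationConfig
from shared.augmentation.pipeline import AugmentationPipeline, get_available_augmentations

config = AugmentationConfig(seed=7)  # defaults: only lighting_variation enabled
pipeline = AugmentationPipeline(config)

image = np.array(Image.open("page_001.png").convert("RGB"))
result = pipeline.apply(image)
print(result.metadata["applied_augmentations"])  # subset of enabled names, in stage order

# Deterministic single-transform preview (fixed RNG inside preview()):
blurred = pipeline.preview(image, "gaussian_blur")

print([a["name"] for a in get_available_augmentations()])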
packages/shared/shared/augmentation/presets.py (new file, 212 lines)
@@ -0,0 +1,212 @@
"""
|
||||||
|
Predefined augmentation presets for common document scenarios.
|
||||||
|
|
||||||
|
Presets provide ready-to-use configurations optimized for different
|
||||||
|
use cases, from conservative (preserves text readability) to aggressive
|
||||||
|
(simulates poor document quality).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
|
||||||
|
|
||||||
|
PRESETS: dict[str, dict[str, Any]] = {
|
||||||
|
"conservative": {
|
||||||
|
"description": "Safe augmentations that preserve text readability",
|
||||||
|
"config": {
|
||||||
|
"lighting_variation": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.5,
|
||||||
|
"params": {
|
||||||
|
"brightness_range": (-0.1, 0.1),
|
||||||
|
"contrast_range": (0.9, 1.1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gaussian_noise": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"std": (3, 10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"moderate": {
|
||||||
|
"description": "Balanced augmentations for typical document degradation",
|
||||||
|
"config": {
|
||||||
|
"lighting_variation": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.5,
|
||||||
|
"params": {
|
||||||
|
"brightness_range": (-0.15, 0.15),
|
||||||
|
"contrast_range": (0.85, 1.15),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"shadow": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
|
||||||
|
},
|
||||||
|
"gaussian_noise": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"std": (5, 12)},
|
||||||
|
},
|
||||||
|
"gaussian_blur": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.2,
|
||||||
|
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
|
||||||
|
},
|
||||||
|
"paper_texture": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"intensity": (0.05, 0.12)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"aggressive": {
|
||||||
|
"description": "Heavy augmentations simulating poor scan quality",
|
||||||
|
"config": {
|
||||||
|
"perspective_warp": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"max_warp": 0.02},
|
||||||
|
},
|
||||||
|
"wrinkle": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.4,
|
||||||
|
"params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
|
||||||
|
},
|
||||||
|
"stain": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {
|
||||||
|
"num_stains": (1, 2),
|
||||||
|
"max_radius_ratio": 0.08,
|
||||||
|
"opacity": (0.1, 0.25),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"lighting_variation": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.6,
|
||||||
|
"params": {
|
||||||
|
"brightness_range": (-0.2, 0.2),
|
||||||
|
"contrast_range": (0.8, 1.2),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"shadow": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.4,
|
||||||
|
"params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
|
||||||
|
},
|
||||||
|
"gaussian_blur": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
|
||||||
|
},
|
||||||
|
"motion_blur": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.2,
|
||||||
|
"params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
|
||||||
|
},
|
||||||
|
"gaussian_noise": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.4,
|
||||||
|
"params": {"std": (8, 18)},
|
||||||
|
},
|
||||||
|
"paper_texture": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.4,
|
||||||
|
"params": {"intensity": (0.08, 0.15)},
|
||||||
|
},
|
||||||
|
"scanner_artifacts": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"line_probability": 0.4, "dust_probability": 0.5},
|
||||||
|
},
|
||||||
|
"edge_damage": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.2,
|
||||||
|
"params": {"max_damage_ratio": 0.04},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"scanned_document": {
|
||||||
|
"description": "Simulates typical scanned document artifacts",
|
||||||
|
"config": {
|
||||||
|
"scanner_artifacts": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.5,
|
||||||
|
"params": {"line_probability": 0.4, "dust_probability": 0.5},
|
||||||
|
},
|
||||||
|
"paper_texture": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.4,
|
||||||
|
"params": {"intensity": (0.05, 0.12)},
|
||||||
|
},
|
||||||
|
"lighting_variation": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {
|
||||||
|
"brightness_range": (-0.1, 0.1),
|
||||||
|
"contrast_range": (0.9, 1.1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gaussian_noise": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.3,
|
||||||
|
"params": {"std": (5, 12)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_preset_config(preset_name: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get the configuration dictionary for a preset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preset_name: Name of the preset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configuration dictionary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If preset is not found.
|
||||||
|
"""
|
||||||
|
if preset_name not in PRESETS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown preset: {preset_name}. "
|
||||||
|
f"Available presets: {list(PRESETS.keys())}"
|
||||||
|
)
|
||||||
|
return PRESETS[preset_name]["config"]
|
||||||
|
|
||||||
|
|
||||||
|
def create_config_from_preset(preset_name: str) -> AugmentationConfig:
|
||||||
|
"""
|
||||||
|
Create an AugmentationConfig from a preset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preset_name: Name of the preset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AugmentationConfig instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If preset is not found.
|
||||||
|
"""
|
||||||
|
config_dict = get_preset_config(preset_name)
|
||||||
|
return AugmentationConfig.from_dict(config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def list_presets() -> list[dict[str, str]]:
|
||||||
|
"""
|
||||||
|
List all available presets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries with name and description.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{"name": name, "description": preset["description"]}
|
||||||
|
for name, preset in PRESETS.items()
|
||||||
|
]
|
||||||
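Note: presets can be consumed as raw dictionaries (for the admin API) or as ready-made configs; a short sketch using the presets defined above.

from shared.augmentation.presets import create_config_from_preset, list_presets

for preset in list_presets():
    print(preset["name"], "-", preset["description"])

config = create_config_from_preset("scanned_document")
print(config.get_enabled_augmentations())
# ['lighting_variation', 'gaussian_noise', 'paper_texture', 'scanner_artifacts']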
packages/shared/shared/augmentation/transforms/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""
|
||||||
|
Augmentation transform implementations.
|
||||||
|
|
||||||
|
Each module contains related augmentation classes:
|
||||||
|
- geometric.py: Perspective warp and other geometric transforms
|
||||||
|
- degradation.py: Wrinkle, edge damage, stain effects
|
||||||
|
- lighting.py: Lighting variation and shadow effects
|
||||||
|
- blur.py: Gaussian and motion blur
|
||||||
|
- noise.py: Gaussian and salt-pepper noise
|
||||||
|
- texture.py: Paper texture and scanner artifacts
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Will be populated as transforms are implemented
|
||||||
packages/shared/shared/augmentation/transforms/blur.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""
|
||||||
|
Blur augmentation transforms.
|
||||||
|
|
||||||
|
Provides blur effects for document image augmentation:
|
||||||
|
- GaussianBlur: Simulates out-of-focus capture
|
||||||
|
- MotionBlur: Simulates camera/document movement during capture
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianBlur(BaseAugmentation):
    """
    Applies Gaussian blur to the image.

    Simulates out-of-focus capture or low-quality optics.
    Conservative defaults to preserve text readability.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
        sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
    """

    name = "gaussian_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        kernel_size = self.params.get("kernel_size", (3, 5))
        if isinstance(kernel_size, int):
            if kernel_size < 1 or kernel_size % 2 == 0:
                raise ValueError("kernel_size must be a positive odd integer")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
                raise ValueError("kernel_size tuple must be (min, max) with min >= 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        kernel_size = self.params.get("kernel_size", (3, 5))
        sigma = self.params.get("sigma", (0.5, 1.5))

        if isinstance(kernel_size, tuple):
            # Choose random odd kernel size
            min_k, max_k = kernel_size
            possible_sizes = [k for k in range(min_k, max_k + 1) if k % 2 == 1]
            if not possible_sizes:
                possible_sizes = [min_k if min_k % 2 == 1 else min_k + 1]
            kernel_size = rng.choice(possible_sizes)

        if isinstance(sigma, tuple):
            sigma = rng.uniform(sigma[0], sigma[1])

        # Ensure kernel size is odd
        if kernel_size % 2 == 0:
            kernel_size += 1

        # Apply Gaussian blur
        blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)

        return AugmentationResult(
            image=blurred,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"kernel_size": kernel_size, "sigma": sigma},
        )

    def get_preview_params(self) -> dict:
        return {"kernel_size": 5, "sigma": 1.5}


class MotionBlur(BaseAugmentation):
    """
    Applies motion blur to the image.

    Simulates camera shake or document movement during capture.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
        angle_range: Motion angle range in degrees (default: (-45, 45)).
    """

    name = "motion_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        kernel_size = self.params.get("kernel_size", (5, 9))
        if isinstance(kernel_size, int):
            if kernel_size < 3:
                raise ValueError("kernel_size must be at least 3")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 3:
                raise ValueError("kernel_size min must be at least 3")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        kernel_size = self.params.get("kernel_size", (5, 9))
        angle_range = self.params.get("angle_range", (-45, 45))

        if isinstance(kernel_size, tuple):
            kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)

        angle = rng.uniform(angle_range[0], angle_range[1])

        # Create motion blur kernel
        kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)

        # Draw a line in the center of the kernel
        center = kernel_size // 2
        angle_rad = np.deg2rad(angle)

        for i in range(kernel_size):
            offset = i - center
            x = int(center + offset * np.cos(angle_rad))
            y = int(center + offset * np.sin(angle_rad))
            if 0 <= x < kernel_size and 0 <= y < kernel_size:
                kernel[y, x] = 1.0

        # Normalize kernel
        kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel

        # Apply motion blur
        blurred = cv2.filter2D(image, -1, kernel)

        return AugmentationResult(
            image=blurred,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"kernel_size": kernel_size, "angle": angle},
        )

    def get_preview_params(self) -> dict:
        return {"kernel_size": 7, "angle_range": (-30, 30)}
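A minimal usage sketch for these blur transforms, assuming the BaseAugmentation constructor accepts its overrides as a params dict (only self.params usage is visible in this diff, so that signature is an assumption):

# Sketch: apply GaussianBlur to a synthetic page with a seeded generator.
import numpy as np

image = np.full((640, 480, 3), 255, dtype=np.uint8)  # blank white "page"
rng = np.random.default_rng(seed=42)                  # seeded for reproducibility

blur = GaussianBlur(params={"kernel_size": (3, 7), "sigma": (0.5, 2.0)})  # assumed constructor
result = blur.apply(image, bboxes=None, rng=rng)

print(result.metadata)     # e.g. {"kernel_size": 5, "sigma": 1.2}
print(result.image.shape)  # geometry unchanged: (640, 480, 3)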
259
packages/shared/shared/augmentation/transforms/degradation.py
Normal file
@@ -0,0 +1,259 @@
"""
Degradation augmentation transforms.

Provides degradation effects for document image augmentation:
- Wrinkle: Paper wrinkle/crease simulation
- EdgeDamage: Damaged/torn edge effects
- Stain: Coffee stain/smudge effects
"""

import cv2
import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class Wrinkle(BaseAugmentation):
    """
    Simulates paper wrinkles/creases using displacement mapping.

    Document-friendly: Uses subtle displacement to preserve text readability.

    Parameters:
        intensity: Wrinkle intensity (0-1) (default: 0.3).
        num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
    """

    name = "wrinkle"
    affects_geometry = False

    def _validate_params(self) -> None:
        intensity = self.params.get("intensity", 0.3)
        if not (0 < intensity <= 1):
            raise ValueError("intensity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        intensity = self.params.get("intensity", 0.3)
        num_wrinkles = self.params.get("num_wrinkles", (2, 5))

        if isinstance(num_wrinkles, tuple):
            num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)

        # Create displacement maps
        displacement_x = np.zeros((h, w), dtype=np.float32)
        displacement_y = np.zeros((h, w), dtype=np.float32)

        for _ in range(num_wrinkles):
            # Random wrinkle parameters
            angle = rng.uniform(0, np.pi)
            x0 = rng.uniform(0, w)
            y0 = rng.uniform(0, h)
            length = rng.uniform(0.3, 0.8) * min(h, w)
            width = rng.uniform(0.02, 0.05) * min(h, w)

            # Create coordinate grids
            xx, yy = np.meshgrid(np.arange(w), np.arange(h))

            # Distance from wrinkle line
            dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
            dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)

            # Gaussian falloff perpendicular to wrinkle
            mask = np.exp(-dy**2 / (2 * width**2))
            mask *= (np.abs(dx) < length / 2).astype(np.float32)

            # Displacement perpendicular to wrinkle
            disp_amount = intensity * rng.uniform(2, 8)
            displacement_x += mask * disp_amount * np.sin(angle)
            displacement_y += mask * disp_amount * np.cos(angle)

        # Create remap coordinates
        map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
        map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)

        # Apply displacement
        augmented = cv2.remap(
            image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
        )

        # Add subtle shading along wrinkles
        max_disp = np.max(np.abs(displacement_y)) + 1e-6
        shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
        shading = shading[:, :, np.newaxis]
        augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)

        return AugmentationResult(
            image=augmented,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        return {"intensity": 0.5, "num_wrinkles": 3}


class EdgeDamage(BaseAugmentation):
    """
    Adds damaged/torn edge effects to the image.

    Simulates worn or torn document edges.

    Parameters:
        max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
        edges: Which edges to potentially damage (default: all).
    """

    name = "edge_damage"
    affects_geometry = False

    def _validate_params(self) -> None:
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        if not (0 < max_damage_ratio <= 0.2):
            raise ValueError("max_damage_ratio must be between 0 and 0.2")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        edges = self.params.get("edges", ["top", "bottom", "left", "right"])

        output = image.copy()

        # Select random edge to damage
        edge = rng.choice(edges)
        damage_size = int(max_damage_ratio * min(h, w))

        if edge == "top":
            # Create irregular top edge
            for x in range(w):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    # Random color (white or darker)
                    color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
                    output[:depth, x] = color

        elif edge == "bottom":
            for x in range(w):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
                    output[h - depth:, x] = color

        elif edge == "left":
            for y in range(h):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
                    output[y, :depth] = color

        else:  # right
            for y in range(h):
                depth = rng.integers(0, damage_size + 1)
                if depth > 0:
                    color = rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150)
                    output[y, w - depth:] = color

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"edge": edge, "damage_size": damage_size},
        )

    def get_preview_params(self) -> dict:
        return {"max_damage_ratio": 0.08}


class Stain(BaseAugmentation):
    """
    Adds coffee stain/smudge effects to the image.

    Simulates accidental stains on documents.

    Parameters:
        num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
        max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
        opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
    """

    name = "stain"
    affects_geometry = False

    def _validate_params(self) -> None:
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(opacity, (int, float)):
            if not (0 < opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        num_stains = self.params.get("num_stains", (1, 3))
        max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
        opacity = self.params.get("opacity", (0.1, 0.3))

        if isinstance(num_stains, tuple):
            num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])

        output = image.astype(np.float32)
        max_radius = int(max_radius_ratio * min(h, w))

        for _ in range(num_stains):
            # Random stain position and size
            cx = rng.integers(max_radius, w - max_radius)
            cy = rng.integers(max_radius, h - max_radius)
            radius = rng.integers(max_radius // 3, max_radius)

            # Create stain mask with irregular edges
            yy, xx = np.ogrid[:h, :w]
            dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)

            # Add noise to make edges irregular
            noise = rng.uniform(0.8, 1.2, (h, w))
            mask = (dist < radius * noise).astype(np.float32)

            # Blur for soft edges
            mask = cv2.GaussianBlur(mask, (21, 21), 0)

            # Random stain color (brownish/yellowish)
            stain_color = np.array([
                rng.integers(180, 220),  # R
                rng.integers(160, 200),  # G
                rng.integers(120, 160),  # B
            ], dtype=np.float32)

            # Apply stain
            mask_3d = mask[:, :, np.newaxis]
            output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity

        output = np.clip(output, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_stains": num_stains, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}
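A sketch of chaining several of these degradations while threading one generator through the whole pipeline; the params-dict constructor is assumed, as above. None of the three transforms set affects_geometry, so the YOLO boxes pass through unchanged:

import numpy as np

rng = np.random.default_rng(seed=7)
image = np.full((1280, 960, 3), 255, dtype=np.uint8)
bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.1]], dtype=np.float32)  # class, xc, yc, w, h

pipeline = [
    Wrinkle(params={"intensity": 0.3, "num_wrinkles": (2, 4)}),
    Stain(params={"num_stains": (1, 2), "opacity": (0.1, 0.2)}),
    EdgeDamage(params={"max_damage_ratio": 0.05}),
]

for transform in pipeline:
    result = transform.apply(image, bboxes=bboxes, rng=rng)
    image, bboxes = result.image, result.bboxes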
145
packages/shared/shared/augmentation/transforms/geometric.py
Normal file
@@ -0,0 +1,145 @@
"""
Geometric augmentation transforms.

Provides geometric transforms for document image augmentation:
- PerspectiveWarp: Subtle perspective distortion
"""

import cv2
import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class PerspectiveWarp(BaseAugmentation):
    """
    Applies subtle perspective transformation to the image.

    Simulates viewing document at slight angle. Very conservative
    by default to preserve text readability.

    IMPORTANT: This transform affects bounding box coordinates.

    Parameters:
        max_warp: Maximum warp as proportion of image size (default: 0.02).
    """

    name = "perspective_warp"
    affects_geometry = True

    def _validate_params(self) -> None:
        max_warp = self.params.get("max_warp", 0.02)
        if not (0 < max_warp <= 0.1):
            raise ValueError("max_warp must be between 0 and 0.1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        max_warp = self.params.get("max_warp", 0.02)

        # Original corners
        src_pts = np.float32([
            [0, 0],
            [w, 0],
            [w, h],
            [0, h],
        ])

        # Add random perturbations to corners
        max_offset = max_warp * min(h, w)
        dst_pts = src_pts.copy()
        for i in range(4):
            dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
            dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)

        # Compute perspective transform matrix
        transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)

        # Apply perspective transform
        warped = cv2.warpPerspective(
            image, transform_matrix, (w, h),
            borderMode=cv2.BORDER_REPLICATE
        )

        # Transform bounding boxes if present
        transformed_bboxes = None
        if bboxes is not None:
            transformed_bboxes = self._transform_bboxes(
                bboxes, transform_matrix, w, h
            )

        return AugmentationResult(
            image=warped,
            bboxes=transformed_bboxes,
            transform_matrix=transform_matrix,
            metadata={"max_warp": max_warp},
        )

    def _transform_bboxes(
        self,
        bboxes: np.ndarray,
        transform_matrix: np.ndarray,
        w: int,
        h: int,
    ) -> np.ndarray:
        """Transform bounding boxes using perspective matrix."""
        if len(bboxes) == 0:
            return bboxes.copy()

        transformed = []
        for bbox in bboxes:
            class_id, x_center, y_center, width, height = bbox

            # Convert normalized coords to pixel coords
            x_center_px = x_center * w
            y_center_px = y_center * h
            width_px = width * w
            height_px = height * h

            # Get corner points
            x1 = x_center_px - width_px / 2
            y1 = y_center_px - height_px / 2
            x2 = x_center_px + width_px / 2
            y2 = y_center_px + height_px / 2

            # Transform all 4 corners
            corners = np.float32([
                [x1, y1],
                [x2, y1],
                [x2, y2],
                [x1, y2],
            ]).reshape(-1, 1, 2)

            transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
            transformed_corners = transformed_corners.reshape(-1, 2)

            # Get bounding box of transformed corners
            new_x1 = np.min(transformed_corners[:, 0])
            new_y1 = np.min(transformed_corners[:, 1])
            new_x2 = np.max(transformed_corners[:, 0])
            new_y2 = np.max(transformed_corners[:, 1])

            # Convert back to normalized center format
            new_width = (new_x2 - new_x1) / w
            new_height = (new_y2 - new_y1) / h
            new_x_center = ((new_x1 + new_x2) / 2) / w
            new_y_center = ((new_y1 + new_y2) / 2) / h

            # Clamp to valid range
            new_x_center = np.clip(new_x_center, 0, 1)
            new_y_center = np.clip(new_y_center, 0, 1)
            new_width = np.clip(new_width, 0, 1)
            new_height = np.clip(new_height, 0, 1)

            transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])

        return np.array(transformed, dtype=np.float32)

    def get_preview_params(self) -> dict:
        return {"max_warp": 0.03}
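A quick sanity-check sketch for the bbox handling above: with a small max_warp the normalized YOLO box center should move by at most a few percent. Constructor signature is assumed, as in the other sketches:

import numpy as np

rng = np.random.default_rng(seed=0)
image = np.full((800, 600, 3), 255, dtype=np.uint8)
bboxes = np.array([[0, 0.5, 0.5, 0.4, 0.2]], dtype=np.float32)

warp = PerspectiveWarp(params={"max_warp": 0.02})  # assumed constructor
result = warp.apply(image, bboxes=bboxes, rng=rng)

center_shift = np.abs(result.bboxes[0, 1:3] - bboxes[0, 1:3])
print(center_shift)  # expected on the order of max_warp (about 0.02) or less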
167
packages/shared/shared/augmentation/transforms/lighting.py
Normal file
@@ -0,0 +1,167 @@
"""
Lighting augmentation transforms.

Provides lighting effects for document image augmentation:
- LightingVariation: Adjusts brightness and contrast
- Shadow: Adds shadow overlay effects
"""

import cv2
import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class LightingVariation(BaseAugmentation):
    """
    Adjusts image brightness and contrast.

    Simulates different lighting conditions during document capture.
    Safe for documents with conservative default parameters.

    Parameters:
        brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
        contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
    """

    name = "lighting_variation"
    affects_geometry = False

    def _validate_params(self) -> None:
        brightness = self.params.get("brightness_range", (-0.1, 0.1))
        contrast = self.params.get("contrast_range", (0.9, 1.1))

        if not isinstance(brightness, tuple) or len(brightness) != 2:
            raise ValueError("brightness_range must be a (min, max) tuple")
        if not isinstance(contrast, tuple) or len(contrast) != 2:
            raise ValueError("contrast_range must be a (min, max) tuple")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
        contrast_range = self.params.get("contrast_range", (0.9, 1.1))

        # Random brightness and contrast
        brightness = rng.uniform(brightness_range[0], brightness_range[1])
        contrast = rng.uniform(contrast_range[0], contrast_range[1])

        # Apply adjustments
        adjusted = image.astype(np.float32)

        # Contrast adjustment (multiply around mean)
        mean = adjusted.mean()
        adjusted = (adjusted - mean) * contrast + mean

        # Brightness adjustment (add offset)
        adjusted = adjusted + brightness * 255

        # Clip and convert back
        adjusted = np.clip(adjusted, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=adjusted,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"brightness": brightness, "contrast": contrast},
        )

    def get_preview_params(self) -> dict:
        return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}


class Shadow(BaseAugmentation):
    """
    Adds shadow overlay effects to the image.

    Simulates shadows from objects or hands during document capture.

    Parameters:
        num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
        opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
    """

    name = "shadow"
    affects_geometry = False

    def _validate_params(self) -> None:
        opacity = self.params.get("opacity", (0.2, 0.4))
        if isinstance(opacity, (int, float)):
            if not (0 <= opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            if not (0 <= opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        num_shadows = self.params.get("num_shadows", (1, 2))
        opacity = self.params.get("opacity", (0.2, 0.4))

        if isinstance(num_shadows, tuple):
            num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])

        h, w = image.shape[:2]
        output = image.astype(np.float32)

        for _ in range(num_shadows):
            # Generate random shadow polygon
            num_vertices = rng.integers(3, 6)
            vertices = []

            # Start from a random edge
            edge = rng.integers(0, 4)
            if edge == 0:  # Top
                start = (rng.integers(0, w), 0)
            elif edge == 1:  # Right
                start = (w, rng.integers(0, h))
            elif edge == 2:  # Bottom
                start = (rng.integers(0, w), h)
            else:  # Left
                start = (0, rng.integers(0, h))

            vertices.append(start)

            # Add random vertices
            for _ in range(num_vertices - 1):
                x = rng.integers(0, w)
                y = rng.integers(0, h)
                vertices.append((x, y))

            # Create shadow mask
            mask = np.zeros((h, w), dtype=np.float32)
            pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [pts], 1.0)

            # Blur the mask for soft edges
            blur_size = max(31, min(h, w) // 10)
            if blur_size % 2 == 0:
                blur_size += 1
            mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)

            # Apply shadow
            shadow_factor = 1 - opacity * mask[:, :, np.newaxis]
            output = output * shadow_factor

        output = np.clip(output, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_shadows": num_shadows, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        return {"num_shadows": 1, "opacity": 0.3}
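Worked numbers for the contrast-around-mean formula above: take a pixel of value 200 in an image whose mean is 128, with contrast 1.1 and brightness 0.1. The contrast step gives (200 - 128) * 1.1 + 128 = 207.2, the brightness step adds 0.1 * 255 = 25.5 for 232.7, and the final astype(np.uint8) after clipping truncates the fraction, so the stored value is 232.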
142
packages/shared/shared/augmentation/transforms/noise.py
Normal file
@@ -0,0 +1,142 @@
"""
Noise augmentation transforms.

Provides noise effects for document image augmentation:
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
- SaltPepper: Adds salt and pepper noise for impulse noise effects
"""

from typing import Any

import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class GaussianNoise(BaseAugmentation):
    """
    Adds Gaussian noise to the image.

    Simulates sensor noise from cameras or scanners.
    Document-safe with conservative default parameters.

    Parameters:
        mean: Mean of the Gaussian noise (default: 0).
        std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
    """

    name = "gaussian_noise"
    affects_geometry = False

    def _validate_params(self) -> None:
        std = self.params.get("std", (5, 15))
        if isinstance(std, (int, float)):
            if std < 0:
                raise ValueError("std must be non-negative")
        elif isinstance(std, tuple):
            if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
                raise ValueError("std tuple must be (min, max) with 0 <= min <= max")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        mean = self.params.get("mean", 0)
        std = self.params.get("std", (5, 15))

        if isinstance(std, tuple):
            std = rng.uniform(std[0], std[1])

        # Generate noise
        noise = rng.normal(mean, std, image.shape).astype(np.float32)

        # Apply noise
        noisy = image.astype(np.float32) + noise
        noisy = np.clip(noisy, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=noisy,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"applied_std": std},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"mean": 0, "std": 15}


class SaltPepper(BaseAugmentation):
    """
    Adds salt and pepper (impulse) noise to the image.

    Simulates defects from damaged sensors or transmission errors.
    Very sparse by default to preserve document readability.

    Parameters:
        amount: Proportion of pixels to affect, can be float or (min, max) tuple.
            Default: (0.001, 0.005) for very sparse noise.
        salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
    """

    name = "salt_pepper"
    affects_geometry = False

    def _validate_params(self) -> None:
        amount = self.params.get("amount", (0.001, 0.005))
        if isinstance(amount, (int, float)):
            if not (0 <= amount <= 1):
                raise ValueError("amount must be between 0 and 1")
        elif isinstance(amount, tuple):
            if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
                raise ValueError("amount tuple must be (min, max) in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        amount = self.params.get("amount", (0.001, 0.005))
        salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)

        if isinstance(amount, tuple):
            amount = rng.uniform(amount[0], amount[1])

        # Copy image
        output = image.copy()
        h, w = image.shape[:2]
        total_pixels = h * w

        # Calculate number of salt and pepper pixels
        num_salt = int(total_pixels * amount * salt_vs_pepper)
        num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))

        # Add salt (white pixels)
        if num_salt > 0:
            salt_coords = (
                rng.integers(0, h, num_salt),
                rng.integers(0, w, num_salt),
            )
            output[salt_coords] = 255

        # Add pepper (black pixels)
        if num_pepper > 0:
            pepper_coords = (
                rng.integers(0, h, num_pepper),
                rng.integers(0, w, num_pepper),
            )
            output[pepper_coords] = 0

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"applied_amount": amount},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"amount": 0.01, "salt_vs_pepper": 0.5}
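Example of the salt/pepper count math above: on a 1280 x 960 page (1,228,800 pixels), with amount drawn as 0.004 and salt_vs_pepper at 0.5, num_salt = int(1228800 * 0.004 * 0.5) = 2457 and num_pepper = 2457, so roughly 0.4% of pixels are flipped in total. Coordinates are sampled independently with replacement, so the effective count can be slightly lower when positions collide.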
159
packages/shared/shared/augmentation/transforms/texture.py
Normal file
@@ -0,0 +1,159 @@
"""
Texture augmentation transforms.

Provides texture effects for document image augmentation:
- PaperTexture: Adds paper grain/texture
- ScannerArtifacts: Adds scanner line and dust artifacts
"""

import cv2
import numpy as np

from shared.augmentation.base import AugmentationResult, BaseAugmentation


class PaperTexture(BaseAugmentation):
    """
    Adds paper texture/grain to the image.

    Simulates different paper types and ages.

    Parameters:
        texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
        intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
    """

    name = "paper_texture"
    affects_geometry = False

    def _validate_params(self) -> None:
        intensity = self.params.get("intensity", (0.05, 0.15))
        if isinstance(intensity, (int, float)):
            if not (0 < intensity <= 1):
                raise ValueError("intensity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        texture_type = self.params.get("texture_type", "random")
        intensity = self.params.get("intensity", (0.05, 0.15))

        if texture_type == "random":
            texture_type = rng.choice(["fine", "coarse"])

        if isinstance(intensity, tuple):
            intensity = rng.uniform(intensity[0], intensity[1])

        # Generate base noise
        if texture_type == "fine":
            # Fine grain texture
            noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
            noise = cv2.GaussianBlur(noise, (3, 3), 0)
        else:
            # Coarse texture
            # Generate at lower resolution and upscale
            small_h, small_w = h // 4, w // 4
            noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
            noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
            noise = cv2.GaussianBlur(noise, (5, 5), 0)

        # Apply texture
        output = image.astype(np.float32)
        noise_3d = noise[:, :, np.newaxis] * intensity * 255
        output = output + noise_3d

        output = np.clip(output, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"texture_type": texture_type, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        return {"texture_type": "coarse", "intensity": 0.15}


class ScannerArtifacts(BaseAugmentation):
    """
    Adds scanner artifacts to the image.

    Simulates scanner imperfections like lines and dust spots.

    Parameters:
        line_probability: Probability of adding scan lines (default: 0.3).
        dust_probability: Probability of adding dust spots (default: 0.4).
    """

    name = "scanner_artifacts"
    affects_geometry = False

    def _validate_params(self) -> None:
        line_prob = self.params.get("line_probability", 0.3)
        dust_prob = self.params.get("dust_probability", 0.4)
        if not (0 <= line_prob <= 1):
            raise ValueError("line_probability must be between 0 and 1")
        if not (0 <= dust_prob <= 1):
            raise ValueError("dust_probability must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        line_probability = self.params.get("line_probability", 0.3)
        dust_probability = self.params.get("dust_probability", 0.4)

        output = image.copy()

        # Add scan lines
        if rng.random() < line_probability:
            num_lines = rng.integers(1, 4)
            for _ in range(num_lines):
                y = rng.integers(0, h)
                thickness = rng.integers(1, 3)
                # Light or dark line
                color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)

                # Make line partially transparent
                alpha = rng.uniform(0.3, 0.6)
                for dy in range(thickness):
                    if y + dy < h:
                        output[y + dy, :] = (
                            output[y + dy, :].astype(np.float32) * (1 - alpha) +
                            color * alpha
                        ).astype(np.uint8)

        # Add dust spots
        if rng.random() < dust_probability:
            num_dust = rng.integers(5, 20)
            for _ in range(num_dust):
                x = rng.integers(0, w)
                y = rng.integers(0, h)
                radius = rng.integers(1, 3)

                # Dark dust spot
                color = rng.integers(50, 120)
                cv2.circle(output, (x, y), radius, int(color), -1)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={
                "line_probability": line_probability,
                "dust_probability": dust_probability,
            },
        )

    def get_preview_params(self) -> dict:
        return {"line_probability": 0.8, "dust_probability": 0.8}
5
packages/shared/shared/training/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Shared training utilities."""

from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult

__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]
239
packages/shared/shared/training/yolo_trainer.py
Normal file
@@ -0,0 +1,239 @@
"""
Shared YOLO Training Module

Unified training logic for both CLI and Web API.
"""

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

logger = logging.getLogger(__name__)


@dataclass
class TrainingConfig:
    """Training configuration."""

    # Model settings
    model_path: str = "yolo11n.pt"  # Base model or path to trained model
    data_yaml: str = ""  # Path to data.yaml

    # Training hyperparameters
    epochs: int = 100
    batch_size: int = 16
    image_size: int = 640
    learning_rate: float = 0.01
    device: str = "0"

    # Output settings
    project: str = "runs/train"
    name: str = "invoice_fields"

    # Performance settings
    workers: int = 4
    cache: bool = False

    # Resume settings
    resume: bool = False
    resume_from: str | None = None  # Path to checkpoint

    # Document-specific augmentation (optimized for invoices)
    augmentation: dict[str, Any] = field(default_factory=lambda: {
        "degrees": 5.0,
        "translate": 0.05,
        "scale": 0.2,
        "shear": 0.0,
        "perspective": 0.0,
        "flipud": 0.0,
        "fliplr": 0.0,
        "mosaic": 0.0,
        "mixup": 0.0,
        "hsv_h": 0.0,
        "hsv_s": 0.1,
        "hsv_v": 0.2,
    })


@dataclass
class TrainingResult:
    """Training result."""

    success: bool
    model_path: str | None = None
    metrics: dict[str, float] = field(default_factory=dict)
    error: str | None = None
    save_dir: str | None = None


class YOLOTrainer:
    """Unified YOLO trainer for CLI and Web API."""

    def __init__(
        self,
        config: TrainingConfig,
        log_callback: Callable[[str, str], None] | None = None,
    ):
        """
        Initialize trainer.

        Args:
            config: Training configuration
            log_callback: Optional callback for logging (level, message)
        """
        self.config = config
        self._log_callback = log_callback

    def _log(self, level: str, message: str) -> None:
        """Log a message."""
        if self._log_callback:
            self._log_callback(level, message)
        if level == "INFO":
            logger.info(message)
        elif level == "ERROR":
            logger.error(message)
        elif level == "WARNING":
            logger.warning(message)

    def validate_config(self) -> tuple[bool, str | None]:
        """
        Validate training configuration.

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check model path
        model_path = Path(self.config.model_path)
        if not model_path.suffix == ".pt":
            # Could be a model name like "yolo11n.pt" which is downloaded
            if not model_path.name.startswith("yolo"):
                return False, f"Invalid model: {self.config.model_path}"
        elif not model_path.exists():
            return False, f"Model file not found: {self.config.model_path}"

        # Check data.yaml
        if not self.config.data_yaml:
            return False, "data_yaml is required"
        data_yaml = Path(self.config.data_yaml)
        if not data_yaml.exists():
            return False, f"data.yaml not found: {self.config.data_yaml}"

        return True, None

    def train(self) -> TrainingResult:
        """
        Run YOLO training.

        Returns:
            TrainingResult with model path and metrics
        """
        try:
            from ultralytics import YOLO
        except ImportError:
            return TrainingResult(
                success=False,
                error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
            )

        # Validate config
        is_valid, error = self.validate_config()
        if not is_valid:
            return TrainingResult(success=False, error=error)

        self._log("INFO", "Starting YOLO training")
        self._log("INFO", f"  Model: {self.config.model_path}")
        self._log("INFO", f"  Data: {self.config.data_yaml}")
        self._log("INFO", f"  Epochs: {self.config.epochs}")
        self._log("INFO", f"  Batch size: {self.config.batch_size}")
        self._log("INFO", f"  Image size: {self.config.image_size}")

        try:
            # Load model
            if self.config.resume and self.config.resume_from:
                resume_path = Path(self.config.resume_from)
                if resume_path.exists():
                    self._log("INFO", f"Resuming from: {resume_path}")
                    model = YOLO(str(resume_path))
                else:
                    model = YOLO(self.config.model_path)
            else:
                model = YOLO(self.config.model_path)

            # Build training arguments
            train_args = {
                "data": str(Path(self.config.data_yaml).absolute()),
                "epochs": self.config.epochs,
                "batch": self.config.batch_size,
                "imgsz": self.config.image_size,
                "lr0": self.config.learning_rate,
                "device": self.config.device,
                "project": self.config.project,
                "name": self.config.name,
                "exist_ok": True,
                "pretrained": True,
                "verbose": True,
                "workers": self.config.workers,
                "cache": self.config.cache,
                "resume": self.config.resume and self.config.resume_from is not None,
            }

            # Add augmentation settings
            train_args.update(self.config.augmentation)

            # Train
            results = model.train(**train_args)

            # Get best model path
            best_model = Path(results.save_dir) / "weights" / "best.pt"

            # Extract metrics
            metrics = {}
            if hasattr(results, "results_dict"):
                metrics = {
                    "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
                    "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
                    "precision": results.results_dict.get("metrics/precision(B)", 0),
                    "recall": results.results_dict.get("metrics/recall(B)", 0),
                }

            self._log("INFO", "Training completed successfully")
            self._log("INFO", f"  Best model: {best_model}")
            self._log("INFO", f"  mAP@0.5: {metrics.get('mAP50', 'N/A')}")

            return TrainingResult(
                success=True,
                model_path=str(best_model) if best_model.exists() else None,
                metrics=metrics,
                save_dir=str(results.save_dir),
            )

        except Exception as e:
            self._log("ERROR", f"Training failed: {e}")
            return TrainingResult(success=False, error=str(e))

    def validate(self, split: str = "val") -> dict[str, float]:
        """
        Run validation on trained model.

        Args:
            split: Dataset split to validate on ("val" or "test")

        Returns:
            Validation metrics
        """
        try:
            from ultralytics import YOLO

            model = YOLO(self.config.model_path)
            metrics = model.val(data=self.config.data_yaml, split=split)

            return {
                "mAP50": metrics.box.map50,
                "mAP50-95": metrics.box.map,
                "precision": metrics.box.mp,
                "recall": metrics.box.mr,
            }
        except Exception as e:
            self._log("ERROR", f"Validation failed: {e}")
            return {}
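A minimal sketch of driving the shared trainer from the Web API side, using only the public surface shown above (TrainingConfig, YOLOTrainer, TrainingResult); the log sink and dataset path are hypothetical stand-ins:

from shared.training import YOLOTrainer, TrainingConfig

def log_sink(level: str, message: str) -> None:
    # Forward trainer logs to e.g. a websocket or job record (hypothetical sink).
    print(f"[{level}] {message}")

config = TrainingConfig(
    model_path="yolo11n.pt",
    data_yaml="/data/dataset/dataset.yaml",  # hypothetical path
    epochs=50,
    batch_size=8,
    image_size=1280,
    device="0",
)

trainer = YOLOTrainer(config=config, log_callback=log_sink)
result = trainer.train()
if result.success:
    print(result.model_path, result.metrics.get("mAP50"))
else:
    print("Training failed:", result.error)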
@@ -199,67 +199,63 @@ def main():
         db.close()
         return
 
-    # Start training
+    # Start training using shared trainer
     print("\n" + "=" * 60)
     print("Starting YOLO Training")
     print("=" * 60)
 
-    from ultralytics import YOLO
+    from shared.training import YOLOTrainer, TrainingConfig
 
-    # Load model
+    # Determine resume checkpoint
     last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
-    if args.resume and last_checkpoint.exists():
-        print(f"Resuming from: {last_checkpoint}")
-        model = YOLO(str(last_checkpoint))
-    else:
-        model = YOLO(args.model)
+    resume_from = str(last_checkpoint) if args.resume and last_checkpoint.exists() else None
 
-    # Training arguments
+    # Create training config
     data_yaml = dataset_dir / 'dataset.yaml'
-    train_args = {
-        'data': str(data_yaml.absolute()),
-        'epochs': args.epochs,
-        'batch': args.batch,
-        'imgsz': args.imgsz,
-        'project': args.project,
-        'name': args.name,
-        'device': args.device,
-        'exist_ok': True,
-        'pretrained': True,
-        'verbose': True,
-        'workers': args.workers,
-        'cache': args.cache,
-        'resume': args.resume and last_checkpoint.exists(),
-        # Document-specific augmentation settings
-        'degrees': 5.0,
-        'translate': 0.05,
-        'scale': 0.2,
-        'shear': 0.0,
-        'perspective': 0.0,
-        'flipud': 0.0,
-        'fliplr': 0.0,
-        'mosaic': 0.0,
-        'mixup': 0.0,
-        'hsv_h': 0.0,
-        'hsv_s': 0.1,
-        'hsv_v': 0.2,
-    }
+    config = TrainingConfig(
+        model_path=args.model,
+        data_yaml=str(data_yaml),
+        epochs=args.epochs,
+        batch_size=args.batch,
+        image_size=args.imgsz,
+        device=args.device,
+        project=args.project,
+        name=args.name,
+        workers=args.workers,
+        cache=args.cache,
+        resume=args.resume,
+        resume_from=resume_from,
+    )
 
-    # Train
-    results = model.train(**train_args)
+    # Run training
+    trainer = YOLOTrainer(config=config)
+    result = trainer.train()
+
+    if not result.success:
+        print(f"\nError: Training failed - {result.error}")
+        db.close()
+        sys.exit(1)
 
     # Print results
     print("\n" + "=" * 60)
     print("Training Complete")
     print("=" * 60)
-    print(f"Best model: {args.project}/{args.name}/weights/best.pt")
-    print(f"Last model: {args.project}/{args.name}/weights/last.pt")
+    print(f"Best model: {result.model_path}")
+    print(f"Save directory: {result.save_dir}")
+    if result.metrics:
+        print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
+        print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
 
     # Validate on test set
     print("\nRunning validation on test set...")
-    metrics = model.val(split='test')
-    print(f"mAP50: {metrics.box.map50:.4f}")
-    print(f"mAP50-95: {metrics.box.map:.4f}")
+    if result.model_path:
+        config.model_path = result.model_path
+        config.data_yaml = str(data_yaml)
+        test_trainer = YOLOTrainer(config=config)
+        test_metrics = test_trainer.validate(split='test')
+        if test_metrics:
+            print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
+            print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
 
     # Close database
     db.close()
106
runs_backup/train/invoice_fields/args.yaml
Normal file
@@ -0,0 +1,106 @@
task: detect
mode: train
model: runs/train/invoice_fields/weights/last.pt
data: /home/kai/invoice-data/dataset/dataset.yaml
epochs: 100
time: null
patience: 100
batch: 8
imgsz: 1280
save: true
save_period: -1
cache: false
device: '0'
workers: 8
project: runs/train
name: invoice_fields
exist_ok: true
pretrained: true
optimizer: auto
verbose: true
seed: 0
deterministic: true
single_cls: false
rect: false
cos_lr: false
close_mosaic: 10
resume: runs/train/invoice_fields/weights/last.pt
amp: true
fraction: 1.0
profile: false
freeze: null
multi_scale: false
compile: false
overlap_mask: true
mask_ratio: 4
dropout: 0.0
val: true
split: val
save_json: false
conf: null
iou: 0.7
max_det: 300
half: false
dnn: false
plots: true
source: null
vid_stride: 1
stream_buffer: false
visualize: false
augment: false
agnostic_nms: false
classes: null
retina_masks: false
embed: null
show: false
save_frames: false
save_txt: false
save_conf: false
save_crop: false
show_labels: true
show_conf: true
show_boxes: true
line_width: null
format: torchscript
keras: false
optimize: false
int8: false
dynamic: false
simplify: true
opset: null
workspace: null
nms: false
lr0: 0.01
lrf: 0.01
momentum: 0.937
weight_decay: 0.0005
warmup_epochs: 3.0
warmup_momentum: 0.8
warmup_bias_lr: 0.0
box: 7.5
cls: 0.5
dfl: 1.5
pose: 12.0
kobj: 1.0
nbs: 64
hsv_h: 0.0
hsv_s: 0.1
hsv_v: 0.2
degrees: 5.0
translate: 0.05
scale: 0.2
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.0
bgr: 0.0
mosaic: 0.0
mixup: 0.0
cutmix: 0.0
copy_paste: 0.0
copy_paste_mode: flip
auto_augment: randaugment
erasing: 0.4
cfg: null
tracker: botsort.yaml
save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields
101
runs_backup/train/invoice_fields/results.csv
Normal file
@@ -0,0 +1,101 @@
epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
|
||||||
|
1,507.213,0.97459,2.44056,1.09407,0.62221,0.6847,0.6581,0.53505,0.66652,1.13762,0.89651,0.00333127,0.00333127,0.00333127
|
||||||
|
2,960.144,0.65449,1.06648,0.90865,0.62998,0.73947,0.70066,0.58144,0.63394,1.03248,0.89183,0.00659863,0.00659863,0.00659863
|
||||||
|
3,1415.64,0.64921,0.97725,0.90901,0.70271,0.7713,0.76974,0.60362,0.7453,0.92005,0.9044,0.00979998,0.00979998,0.00979998
|
||||||
|
4,1871.6,0.61159,0.90785,0.90192,0.7264,0.77507,0.78958,0.66909,0.57561,0.83662,0.87737,0.009703,0.009703,0.009703
|
||||||
|
5,2317.68,0.55503,0.81107,0.88628,0.75742,0.81685,0.8287,0.70772,0.57193,0.74866,0.87531,0.009604,0.009604,0.009604
|
||||||
|
6,2756.85,0.52443,0.75067,0.8767,0.77362,0.82743,0.84758,0.73413,0.53604,0.70643,0.86947,0.009505,0.009505,0.009505
|
||||||
|
7,3211.38,0.49889,0.70526,0.8696,0.79041,0.83537,0.86367,0.75099,0.52784,0.67021,0.86686,0.009406,0.009406,0.009406
|
||||||
|
8,3657.41,0.47889,0.66715,0.86261,0.80402,0.85067,0.87834,0.7573,0.5365,0.64019,0.86754,0.009307,0.009307,0.009307
|
||||||
|
9,4100.61,0.46417,0.63485,0.85761,0.81776,0.84993,0.88612,0.77638,0.50423,0.61379,0.85539,0.009208,0.009208,0.009208
|
||||||
|
10,4544.47,0.45154,0.61098,0.85334,0.82044,0.86313,0.89168,0.7685,0.54578,0.59853,0.8668,0.009109,0.009109,0.009109
|
||||||
|
11,4983.87,0.44,0.58833,0.84961,0.82995,0.87408,0.90194,0.7972,0.48628,0.57037,0.84686,0.00901,0.00901,0.00901
|
||||||
|
12,5419.2,0.42826,0.57051,0.84622,0.83627,0.87342,0.90482,0.79702,0.49356,0.55869,0.84836,0.008911,0.008911,0.008911
|
||||||
|
13,5854.78,0.42052,0.55192,0.84327,0.83405,0.88243,0.90883,0.79807,0.50137,0.54503,0.85114,0.008812,0.008812,0.008812
|
||||||
|
14,6288.44,0.41497,0.53873,0.84105,0.83687,0.8769,0.90767,0.79881,0.49873,0.54655,0.85078,0.008713,0.008713,0.008713
|
||||||
|
15,6724.06,0.40759,0.52501,0.83904,0.84211,0.88472,0.91329,0.81111,0.47871,0.52982,0.84461,0.008614,0.008614,0.008614
|
||||||
|
16,7164.64,0.39956,0.51303,0.83591,0.85078,0.88609,0.91766,0.80948,0.49199,0.52052,0.84634,0.008515,0.008515,0.008515
|
||||||
|
17,7626.93,0.39651,0.50255,0.83453,0.84894,0.89151,0.92036,0.80185,0.5227,0.52229,0.8514,0.008416,0.008416,0.008416
|
||||||
|
18,490.978,0.3953,0.49497,0.83435,0.86113,0.87974,0.92114,0.80521,0.51284,0.51491,0.84926,0.008317,0.008317,0.008317
|
||||||
|
19,976.368,0.39234,0.49242,0.8328,0.86521,0.88431,0.92347,0.80898,0.5037,0.50744,0.84824,0.008218,0.008218,0.008218
|
||||||
|
20,1457.68,0.3864,0.48591,0.8308,0.86361,0.88895,0.92337,0.80562,0.51844,0.50692,0.8497,0.008119,0.008119,0.008119
|
||||||
|
21,1938.95,0.3826,0.47627,0.83086,0.85282,0.89488,0.92276,0.80107,0.53348,0.50706,0.85418,0.00802,0.00802,0.00802
|
||||||
|
22,2418,0.37996,0.46715,0.82911,0.86698,0.89705,0.92757,0.81493,0.5038,0.49105,0.84342,0.007921,0.007921,0.007921
|
||||||
|
23,2900.33,0.37555,0.46089,0.82769,0.86871,0.89566,0.92949,0.82251,0.48233,0.48473,0.84078,0.007822,0.007822,0.007822
|
||||||
|
24,3381.33,0.3717,0.45531,0.82676,0.87203,0.89509,0.93002,0.80968,0.52892,0.49231,0.85017,0.007723,0.007723,0.007723
|
||||||
|
25,3867.95,0.36913,0.44591,0.82544,0.87636,0.89074,0.93018,0.81083,0.5286,0.4882,0.84889,0.007624,0.007624,0.007624
|
||||||
|
26,4354.94,0.3662,0.44016,0.82483,0.86957,0.89807,0.92963,0.80554,0.5486,0.49546,0.8527,0.007525,0.007525,0.007525
|
||||||
|
27,4839.34,0.36364,0.43794,0.82368,0.87515,0.89602,0.93102,0.80927,0.54175,0.49001,0.85159,0.007426,0.007426,0.007426
|
||||||
|
28,5324.36,0.36043,0.42951,0.8234,0.87074,0.90178,0.93175,0.8107,0.53742,0.4862,0.85096,0.007327,0.007327,0.007327
|
||||||
|
29,5810.68,0.35852,0.42834,0.82242,0.87574,0.8992,0.93108,0.80885,0.53965,0.48712,0.85055,0.007228,0.007228,0.007228
|
||||||
|
30,6294.55,0.35572,0.42015,0.82155,0.87635,0.90084,0.93172,0.81312,0.5252,0.48089,0.84711,0.007129,0.007129,0.007129
|
||||||
|
31,6778.25,0.3539,0.41622,0.82025,0.8777,0.90106,0.93275,0.81359,0.52726,0.47864,0.84712,0.00703,0.00703,0.00703
|
||||||
|
32,7261.88,0.35095,0.4105,0.82048,0.87684,0.90364,0.9337,0.81598,0.52269,0.47547,0.84625,0.006931,0.006931,0.006931
|
||||||
|
33,7745.73,0.34835,0.40539,0.81904,0.87629,0.90454,0.933,0.81501,0.52381,0.47649,0.84713,0.006832,0.006832,0.006832
|
||||||
|
34,8234.78,0.34699,0.402,0.8182,0.8755,0.90506,0.93336,0.81532,0.52435,0.4756,0.84749,0.006733,0.006733,0.006733
|
||||||
|
35,8716.45,0.34487,0.39711,0.81704,0.87551,0.90432,0.93325,0.81529,0.52489,0.47558,0.84763,0.006634,0.006634,0.006634
|
||||||
|
36,9200.09,0.34343,0.39638,0.81749,0.87551,0.90495,0.93343,0.81477,0.52659,0.4765,0.84798,0.006535,0.006535,0.006535
|
||||||
|
37,9683.76,0.34013,0.3907,0.81647,0.87738,0.90416,0.93371,0.8152,0.52755,0.47568,0.84807,0.006436,0.006436,0.006436
|
||||||
|
38,10164.4,0.33826,0.38626,0.81468,0.87923,0.90342,0.93399,0.81571,0.52756,0.47527,0.84794,0.006337,0.006337,0.006337
|
||||||
|
39,10648.6,0.3366,0.3812,0.81517,0.8786,0.90333,0.93368,0.81448,0.53035,0.47601,0.84857,0.006238,0.006238,0.006238
|
||||||
|
40,11130.6,0.3353,0.37879,0.8151,0.87974,0.90405,0.93411,0.81544,0.52739,0.47333,0.84768,0.006139,0.006139,0.006139
|
||||||
|
41,11612.6,0.33395,0.37397,0.8143,0.88034,0.90315,0.93432,0.81514,0.529,0.47304,0.84765,0.00604,0.00604,0.00604
|
||||||
|
42,12097.5,0.33104,0.37164,0.81448,0.87942,0.90449,0.93429,0.81484,0.53077,0.47386,0.84799,0.005941,0.005941,0.005941
|
||||||
|
43,12579.8,0.33016,0.3681,0.81305,0.87964,0.90412,0.93457,0.81473,0.53129,0.47382,0.84789,0.005842,0.005842,0.005842
|
||||||
|
44,13065.3,0.32845,0.36431,0.81191,0.88092,0.90312,0.9348,0.81538,0.53025,0.47337,0.84751,0.005743,0.005743,0.005743
|
||||||
|
45,13550.4,0.32642,0.36127,0.81295,0.8805,0.905,0.93502,0.81544,0.53124,0.47333,0.8477,0.005644,0.005644,0.005644
|
||||||
|
46,14034.3,0.32483,0.35761,0.81158,0.87898,0.90636,0.935,0.81556,0.53135,0.47317,0.84772,0.005545,0.005545,0.005545
|
||||||
|
47,14517.3,0.32236,0.35337,0.81014,0.88018,0.90502,0.93493,0.81547,0.53228,0.473,0.8478,0.005446,0.005446,0.005446
|
||||||
|
48,14998.6,0.3211,0.35051,0.81064,0.87941,0.9055,0.93481,0.81473,0.5353,0.47335,0.84839,0.005347,0.005347,0.005347
|
||||||
|
49,15479.2,0.32043,0.34797,0.8097,0.87884,0.90584,0.93482,0.81429,0.53741,0.47359,0.84867,0.005248,0.005248,0.005248
|
||||||
|
50,15962,0.3182,0.34589,0.80867,0.87776,0.90777,0.93476,0.81395,0.53841,0.47358,0.84884,0.005149,0.005149,0.005149
|
||||||
|
51,16445.4,0.31722,0.34332,0.80879,0.87932,0.90605,0.93488,0.81439,0.53827,0.47301,0.84874,0.00505,0.00505,0.00505
|
||||||
|
52,16925.3,0.31437,0.33963,0.80846,0.87925,0.90688,0.93521,0.81468,0.53772,0.47254,0.84861,0.004951,0.004951,0.004951
|
||||||
|
53,17410.4,0.31435,0.33632,0.80792,0.88015,0.9069,0.93538,0.8148,0.53789,0.47193,0.84865,0.004852,0.004852,0.004852
|
||||||
|
54,17893.2,0.31341,0.33352,0.80833,0.88132,0.90605,0.93552,0.81465,0.53934,0.47218,0.84889,0.004753,0.004753,0.004753
|
||||||
|
55,18375.2,0.31091,0.33239,0.80688,0.88151,0.90617,0.93547,0.81453,0.54019,0.47227,0.84885,0.004654,0.004654,0.004654
|
||||||
|
56,18857.4,0.30906,0.32753,0.80616,0.88348,0.90504,0.93565,0.81495,0.53974,0.47179,0.84864,0.004555,0.004555,0.004555
|
||||||
|
57,19340.6,0.30722,0.32334,0.80582,0.88502,0.90326,0.93558,0.81484,0.53977,0.47174,0.84859,0.004456,0.004456,0.004456
|
||||||
|
58,19824.8,0.30575,0.3195,0.80592,0.88348,0.90487,0.93566,0.81511,0.53921,0.47158,0.84832,0.004357,0.004357,0.004357
|
||||||
|
59,20305.7,0.30426,0.31846,0.80494,0.88419,0.90477,0.93575,0.81534,0.5387,0.47138,0.84819,0.004258,0.004258,0.004258
|
||||||
|
60,20785.1,0.30295,0.3154,0.80494,0.88302,0.90624,0.93568,0.81572,0.53769,0.47106,0.84788,0.004159,0.004159,0.004159
|
||||||
|
61,21263.8,0.3013,0.3131,0.80436,0.88438,0.90545,0.93572,0.81606,0.53622,0.47079,0.84762,0.00406,0.00406,0.00406
|
||||||
|
62,21746,0.30019,0.31077,0.80391,0.88296,0.90732,0.93578,0.8165,0.53455,0.47011,0.84724,0.003961,0.003961,0.003961
|
||||||
|
63,22225,0.29841,0.30656,0.80379,0.88244,0.90779,0.93591,0.81693,0.53417,0.47003,0.84715,0.003862,0.003862,0.003862
|
||||||
|
64,22704.7,0.29696,0.30489,0.80284,0.88493,0.90596,0.9359,0.81716,0.53362,0.47013,0.84707,0.003763,0.003763,0.003763
|
||||||
|
65,23182.9,0.2952,0.30022,0.80288,0.88366,0.90731,0.93594,0.81737,0.533,0.47024,0.84695,0.003664,0.003664,0.003664
|
||||||
|
66,23663.6,0.29337,0.29898,0.80273,0.88514,0.90611,0.93609,0.81805,0.53189,0.47015,0.84664,0.003565,0.003565,0.003565
|
||||||
|
67,24149.7,0.29248,0.29492,0.80242,0.88664,0.90536,0.93607,0.81783,0.53208,0.4704,0.84665,0.003466,0.003466,0.003466
|
||||||
|
68,24629,0.28987,0.29155,0.8009,0.89014,0.90207,0.93611,0.81811,0.53203,0.47046,0.84656,0.003367,0.003367,0.003367
|
||||||
|
69,25109.9,0.28942,0.29004,0.8011,0.88939,0.90353,0.93619,0.81872,0.53127,0.4704,0.84631,0.003268,0.003268,0.003268
|
||||||
|
70,25590.5,0.28752,0.28571,0.80059,0.88926,0.90393,0.93627,0.81909,0.53074,0.47023,0.84624,0.003169,0.003169,0.003169
|
||||||
|
71,26072.8,0.28546,0.28301,0.7999,0.88844,0.90476,0.93631,0.81967,0.52999,0.47005,0.84612,0.00307,0.00307,0.00307
|
||||||
|
72,26552.9,0.2842,0.28027,0.79942,0.88801,0.90505,0.93622,0.81978,0.52939,0.46994,0.84607,0.002971,0.002971,0.002971
|
||||||
|
73,27035.5,0.28297,0.27907,0.79956,0.88781,0.90499,0.93615,0.82032,0.528,0.4694,0.84578,0.002872,0.002872,0.002872
|
||||||
|
74,27518.8,0.2812,0.27446,0.79886,0.88848,0.90493,0.93611,0.82061,0.52675,0.46906,0.84549,0.002773,0.002773,0.002773
|
||||||
|
75,28007.4,0.27866,0.27202,0.79684,0.88889,0.90467,0.9361,0.82099,0.5257,0.4692,0.84529,0.002674,0.002674,0.002674
|
||||||
|
76,28499.1,0.27708,0.26798,0.7978,0.88807,0.9054,0.93615,0.82138,0.52574,0.46928,0.84523,0.002575,0.002575,0.002575
|
||||||
|
77,28993.2,0.27398,0.2644,0.79612,0.88825,0.9055,0.93616,0.82161,0.52491,0.46925,0.84496,0.002476,0.002476,0.002476
|
||||||
|
78,29480.5,0.27359,0.26209,0.79678,0.88876,0.90547,0.93617,0.82172,0.52467,0.4691,0.84498,0.002377,0.002377,0.002377
|
||||||
|
79,29970.7,0.27153,0.25905,0.79585,0.88942,0.90548,0.93613,0.82211,0.52407,0.46892,0.84474,0.002278,0.002278,0.002278
|
||||||
|
80,30453.8,0.2696,0.25647,0.79513,0.88897,0.90665,0.93617,0.82246,0.52298,0.46881,0.84445,0.002179,0.002179,0.002179
|
||||||
|
81,30936.1,0.26895,0.25375,0.79466,0.88799,0.90791,0.93617,0.8226,0.52253,0.46904,0.84445,0.00208,0.00208,0.00208
|
||||||
|
82,31418.6,0.26733,0.25025,0.79474,0.88945,0.90695,0.93608,0.82293,0.52172,0.4694,0.84434,0.001981,0.001981,0.001981
|
||||||
|
83,31911.4,0.26537,0.24754,0.79479,0.89112,0.90496,0.93604,0.82338,0.52094,0.46932,0.84421,0.001882,0.001882,0.001882
|
||||||
|
84,32402,0.26344,0.24514,0.79369,0.8928,0.90319,0.93598,0.82353,0.52015,0.46957,0.84407,0.001783,0.001783,0.001783
|
||||||
|
85,32903.1,0.26045,0.24052,0.79226,0.89211,0.90347,0.93615,0.82427,0.51861,0.46958,0.84372,0.001684,0.001684,0.001684
|
||||||
|
86,33413.6,0.25867,0.23781,0.79209,0.89286,0.90279,0.9362,0.82493,0.51664,0.47018,0.84338,0.001585,0.001585,0.001585
|
||||||
|
87,33923.6,0.257,0.23463,0.792,0.89299,0.90297,0.93614,0.8254,0.5147,0.46974,0.84305,0.001486,0.001486,0.001486
|
||||||
|
88,34436.4,0.25569,0.23153,0.79149,0.89278,0.90277,0.93609,0.82622,0.51242,0.4697,0.84266,0.001387,0.001387,0.001387
|
||||||
|
89,34949.6,0.25343,0.22868,0.791,0.89137,0.90434,0.93599,0.82675,0.51036,0.46947,0.84227,0.001288,0.001288,0.001288
|
||||||
|
90,35449.6,0.25194,0.22502,0.79051,0.89128,0.90489,0.93591,0.82729,0.5092,0.46975,0.84219,0.001189,0.001189,0.001189
|
||||||
|
91,35960.5,0.2502,0.22225,0.78959,0.8898,0.90646,0.93586,0.82781,0.50761,0.46999,0.84186,0.00109,0.00109,0.00109
|
||||||
|
92,36452.1,0.24777,0.21844,0.78906,0.89054,0.9057,0.93593,0.82831,0.50603,0.47043,0.84154,0.000991,0.000991,0.000991
|
||||||
|
93,36942.9,0.24554,0.21503,0.78861,0.88979,0.90679,0.93584,0.82858,0.50495,0.4703,0.84125,0.000892,0.000892,0.000892
|
||||||
|
94,37430.3,0.2434,0.21193,0.78799,0.88928,0.90756,0.93566,0.8288,0.5041,0.47075,0.84109,0.000793,0.000793,0.000793
|
||||||
|
95,37918.7,0.2413,0.20892,0.78736,0.8899,0.90683,0.93567,0.82882,0.50339,0.47152,0.84105,0.000694,0.000694,0.000694
|
||||||
|
96,38404.3,0.2405,0.20619,0.78595,0.88999,0.90713,0.9355,0.82912,0.50244,0.47239,0.84104,0.000595,0.000595,0.000595
|
||||||
|
97,38893.8,0.23808,0.2031,0.78683,0.88982,0.90634,0.93531,0.82938,0.50187,0.47281,0.84116,0.000496,0.000496,0.000496
|
||||||
|
98,39382.8,0.23581,0.20034,0.78643,0.89144,0.9045,0.93517,0.82959,0.50119,0.47383,0.84128,0.000397,0.000397,0.000397
|
||||||
|
99,39871.7,0.23432,0.19778,0.78568,0.89187,0.90415,0.93488,0.82953,0.50058,0.47452,0.84126,0.000298,0.000298,0.000298
|
||||||
|
100,40359.7,0.233,0.19485,0.78528,0.89228,0.90349,0.93471,0.82961,0.50029,0.47497,0.84139,0.000199,0.000199,0.000199
106
runs_backup/train/invoice_yolo11n_full/args.yaml
Normal file
@@ -0,0 +1,106 @@
task: detect
mode: train
model: yolo11n.pt
data: /home/kai/invoice-data/dataset/dataset.yaml
epochs: 100
time: null
patience: 100
batch: 16
imgsz: 1280
save: true
save_period: -1
cache: false
device: '0'
workers: 8
project: runs/train
name: invoice_yolo11n_full
exist_ok: true
pretrained: true
optimizer: auto
verbose: true
seed: 0
deterministic: true
single_cls: false
rect: false
cos_lr: false
close_mosaic: 10
resume: false
amp: true
fraction: 1.0
profile: false
freeze: null
multi_scale: false
compile: false
overlap_mask: true
mask_ratio: 4
dropout: 0.0
val: true
split: val
save_json: false
conf: null
iou: 0.7
max_det: 300
half: false
dnn: false
plots: true
source: null
vid_stride: 1
stream_buffer: false
visualize: false
augment: false
agnostic_nms: false
classes: null
retina_masks: false
embed: null
show: false
save_frames: false
save_txt: false
save_conf: false
save_crop: false
show_labels: true
show_conf: true
show_boxes: true
line_width: null
format: torchscript
keras: false
optimize: false
int8: false
dynamic: false
simplify: true
opset: null
workspace: null
nms: false
lr0: 0.01
lrf: 0.01
momentum: 0.937
weight_decay: 0.0005
warmup_epochs: 3.0
warmup_momentum: 0.8
warmup_bias_lr: 0.1
box: 7.5
cls: 0.5
dfl: 1.5
pose: 12.0
kobj: 1.0
nbs: 64
hsv_h: 0.0
hsv_s: 0.1
hsv_v: 0.2
degrees: 5.0
translate: 0.05
scale: 0.2
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.0
bgr: 0.0
mosaic: 0.0
mixup: 0.0
cutmix: 0.0
copy_paste: 0.0
copy_paste_mode: flip
auto_augment: randaugment
erasing: 0.4
cfg: null
tracker: botsort.yaml
save_dir: /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_yolo11n_full
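For reference, the configuration above maps roughly onto a single Ultralytics API call along these lines (a minimal sketch; it assumes training was launched through the standard ultralytics Python package, and the actual launcher script is not part of this commit):

from ultralytics import YOLO  # hypothetical launcher, not a file in this diff

model = YOLO("yolo11n.pt")
model.train(
    data="/home/kai/invoice-data/dataset/dataset.yaml",
    epochs=100,
    imgsz=1280,
    batch=16,
    device="0",
    project="runs/train",
    name="invoice_yolo11n_full",
    exist_ok=True,
    # Document-safe augmentation matching the hsv_*/degrees/fliplr/mosaic values above:
    # mild HSV jitter, small rotation/translation/scale, no flips, no mosaic.
    hsv_h=0.0, hsv_s=0.1, hsv_v=0.2,
    degrees=5.0, translate=0.05, scale=0.2,
    flipud=0.0, fliplr=0.0, mosaic=0.0,
)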
101
runs_backup/train/invoice_yolo11n_full/results.csv
Normal file
@@ -0,0 +1,101 @@
epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
|
||||||
|
1,217.641,0.79856,2.56507,1.01986,0.8921,0.84545,0.90033,0.81508,0.49347,0.92369,0.83815,0.00332991,0.00332991,0.00332991
|
||||||
|
2,410.275,0.506,1.04596,0.85661,0.9164,0.85418,0.92852,0.79726,0.63851,0.72277,0.87404,0.00659728,0.00659728,0.00659728
|
||||||
|
3,598.713,0.49618,0.68647,0.85624,0.93775,0.80014,0.91481,0.80383,0.5798,0.76691,0.86835,0.00979865,0.00979865,0.00979865
|
||||||
|
4,782.868,0.44059,0.53532,0.84299,0.94668,0.90421,0.96101,0.8832,0.43961,0.49649,0.84298,0.009703,0.009703,0.009703
|
||||||
|
5,967.898,0.37596,0.44308,0.82667,0.88316,0.81492,0.91376,0.82272,0.50616,0.70202,0.85825,0.009604,0.009604,0.009604
|
||||||
|
6,1152.03,0.33999,0.39482,0.81661,0.81567,0.73691,0.82644,0.75085,0.43451,0.92038,0.86158,0.009505,0.009505,0.009505
|
||||||
|
7,1335.35,0.31114,0.35971,0.80992,0.95256,0.89807,0.96383,0.84241,0.60248,0.48455,0.88156,0.009406,0.009406,0.009406
|
||||||
|
8,1518.68,0.29176,0.33987,0.80516,0.97058,0.91185,0.97221,0.85356,0.58408,0.43771,0.86239,0.009307,0.009307,0.009307
|
||||||
|
9,1702.03,0.27683,0.3214,0.80166,0.96403,0.91663,0.9736,0.85118,0.6359,0.43055,0.88091,0.009208,0.009208,0.009208
|
||||||
|
10,1891.76,0.26487,0.30796,0.79943,0.96201,0.92669,0.97715,0.894,0.46381,0.37437,0.84314,0.009109,0.009109,0.009109
|
||||||
|
11,2081.79,0.25744,0.29846,0.79614,0.96562,0.92382,0.97554,0.79415,0.81302,0.46321,0.93576,0.00901,0.00901,0.00901
|
||||||
|
12,2273.34,0.24726,0.28842,0.79445,0.96248,0.92901,0.97544,0.7642,0.93193,0.48807,0.98769,0.008911,0.008911,0.008911
|
||||||
|
13,2461.71,0.24266,0.27619,0.79413,0.9672,0.93016,0.97834,0.69208,1.24698,0.59927,1.19042,0.008812,0.008812,0.008812
|
||||||
|
14,2649.81,0.23391,0.26941,0.79165,0.96579,0.93247,0.98028,0.72182,1.0815,0.50846,1.08505,0.008713,0.008713,0.008713
|
||||||
|
15,2837.95,0.22893,0.2651,0.79082,0.9639,0.93414,0.9807,0.8522,0.64306,0.38931,0.88909,0.008614,0.008614,0.008614
|
||||||
|
16,3021.27,0.22269,0.25369,0.78809,0.97667,0.92754,0.98283,0.79198,0.89233,0.43512,0.98623,0.008515,0.008515,0.008515
|
||||||
|
17,3209.3,0.21937,0.24886,0.78797,0.96559,0.93193,0.98178,0.67198,1.35518,0.59949,1.30325,0.008416,0.008416,0.008416
|
||||||
|
18,3400.76,0.21415,0.24489,0.78789,0.95973,0.94489,0.98156,0.63227,1.45967,0.63486,1.42858,0.008317,0.008317,0.008317
|
||||||
|
19,3590.96,0.21227,0.23986,0.78736,0.96136,0.94369,0.98263,0.83035,0.76379,0.39831,0.92541,0.008218,0.008218,0.008218
|
||||||
|
20,3779.15,0.20834,0.23506,0.78475,0.96214,0.93563,0.97908,0.5024,1.92081,0.81282,1.99717,0.008119,0.008119,0.008119
|
||||||
|
21,3976.3,0.20592,0.23055,0.78534,0.9636,0.94141,0.98186,0.71087,1.20783,0.53596,1.17998,0.00802,0.00802,0.00802
|
||||||
|
22,4165.69,0.20195,0.22554,0.78431,0.96621,0.94394,0.98458,0.86353,0.6245,0.35194,0.86978,0.007921,0.007921,0.007921
|
||||||
|
23,4357.69,0.19847,0.22066,0.78362,0.9745,0.93877,0.98501,0.84365,0.71155,0.37717,0.91395,0.007822,0.007822,0.007822
|
||||||
|
24,4548.46,0.19715,0.21991,0.78423,0.96456,0.94907,0.98541,0.77136,0.9901,0.45056,1.02401,0.007723,0.007723,0.007723
|
||||||
|
25,4738.2,0.19207,0.21284,0.7821,0.97136,0.94417,0.98568,0.8139,0.83053,0.41526,0.94475,0.007624,0.007624,0.007624
|
||||||
|
26,4926.95,0.19124,0.21138,0.7823,0.9712,0.94333,0.98466,0.78106,0.94702,0.44977,1.0068,0.007525,0.007525,0.007525
|
||||||
|
27,5115.34,0.18944,0.21166,0.78245,0.97207,0.9347,0.98325,0.57941,1.64865,0.72474,1.70082,0.007426,0.007426,0.007426
|
||||||
|
28,5303.92,0.18817,0.20777,0.7814,0.96837,0.9519,0.98672,0.77596,0.96734,0.43218,1.01592,0.007327,0.007327,0.007327
|
||||||
|
29,5493.53,0.18565,0.20231,0.78154,0.9719,0.94481,0.98552,0.67875,1.27094,0.57309,1.27411,0.007228,0.007228,0.007228
|
||||||
|
30,5682.31,0.18424,0.19916,0.77989,0.96714,0.95269,0.98588,0.73712,1.08764,0.52054,1.09884,0.007129,0.007129,0.007129
|
||||||
|
31,5870.11,0.1812,0.19544,0.78013,0.9698,0.95028,0.98687,0.72258,1.15995,0.50918,1.14823,0.00703,0.00703,0.00703
|
||||||
|
32,6060.08,0.1801,0.19571,0.7799,0.9699,0.95342,0.98761,0.83372,0.70788,0.36986,0.89674,0.006931,0.006931,0.006931
|
||||||
|
33,6244.87,0.17816,0.19373,0.7789,0.96278,0.95546,0.98684,0.85265,0.66817,0.36063,0.88589,0.006832,0.006832,0.006832
|
||||||
|
34,6427.13,0.176,0.1909,0.77864,0.96962,0.95053,0.98707,0.85351,0.70717,0.36925,0.90913,0.006733,0.006733,0.006733
|
||||||
|
35,6619.57,0.17377,0.18513,0.77794,0.97418,0.94725,0.98662,0.83346,0.77385,0.3987,0.92964,0.006634,0.006634,0.006634
|
||||||
|
36,6807.72,0.17359,0.18454,0.77885,0.97363,0.95,0.98703,0.85072,0.72917,0.36745,0.91556,0.006535,0.006535,0.006535
|
||||||
|
37,7001.26,0.17126,0.18179,0.77796,0.96337,0.95646,0.98744,0.79259,0.86734,0.4072,0.96666,0.006436,0.006436,0.006436
|
||||||
|
38,7186.91,0.16989,0.17967,0.77791,0.97277,0.94891,0.98737,0.8268,0.78213,0.38811,0.9293,0.006337,0.006337,0.006337
|
||||||
|
39,7372.72,0.16823,0.17959,0.77698,0.96961,0.95206,0.98714,0.86035,0.69764,0.35696,0.90589,0.006238,0.006238,0.006238
|
||||||
|
40,7558.51,0.16639,0.17648,0.77676,0.96471,0.9592,0.98756,0.85427,0.70458,0.35248,0.90042,0.006139,0.006139,0.006139
|
||||||
|
41,7747.05,0.16641,0.17472,0.77686,0.96565,0.95698,0.98751,0.76422,1.01303,0.45638,1.04876,0.00604,0.00604,0.00604
|
||||||
|
42,7930.66,0.16528,0.17295,0.77783,0.98029,0.94412,0.98709,0.68134,1.27076,0.55229,1.27714,0.005941,0.005941,0.005941
|
||||||
|
43,8113.31,0.16304,0.17093,0.77627,0.96909,0.95623,0.98692,0.77479,0.97729,0.45248,1.03101,0.005842,0.005842,0.005842
|
||||||
|
44,8298.3,0.16163,0.16817,0.77509,0.96809,0.9575,0.98709,0.80448,0.85637,0.40945,0.96298,0.005743,0.005743,0.005743
|
||||||
|
45,8485.25,0.16053,0.16768,0.77535,0.97311,0.95211,0.98726,0.8047,0.85835,0.40888,0.96605,0.005644,0.005644,0.005644
|
||||||
|
46,8669.88,0.15959,0.16634,0.77576,0.97431,0.95186,0.98739,0.797,0.87218,0.41446,0.97186,0.005545,0.005545,0.005545
|
||||||
|
47,8853.93,0.15778,0.16234,0.77599,0.97532,0.95052,0.98702,0.78582,0.92511,0.43102,1.0039,0.005446,0.005446,0.005446
|
||||||
|
48,9037.52,0.15602,0.16175,0.77439,0.97529,0.94998,0.9874,0.84071,0.77064,0.38361,0.93995,0.005347,0.005347,0.005347
|
||||||
|
49,9223.9,0.15478,0.1604,0.77364,0.97345,0.95143,0.98729,0.84662,0.73185,0.37143,0.92248,0.005248,0.005248,0.005248
|
||||||
|
50,9411.45,0.15431,0.1584,0.77449,0.98033,0.94592,0.98733,0.86995,0.63137,0.34173,0.8816,0.005149,0.005149,0.005149
|
||||||
|
51,9595.98,0.15305,0.15648,0.77414,0.97318,0.95431,0.98753,0.86938,0.63298,0.34305,0.88158,0.00505,0.00505,0.00505
|
||||||
|
52,9779.61,0.15291,0.15561,0.77441,0.97824,0.94936,0.98777,0.87333,0.60174,0.33831,0.87135,0.004951,0.004951,0.004951
|
||||||
|
53,9963.17,0.15193,0.15454,0.77361,0.97077,0.95579,0.98737,0.87038,0.62864,0.34404,0.87819,0.004852,0.004852,0.004852
|
||||||
|
54,10151.9,0.14978,0.15218,0.77431,0.97892,0.94812,0.9874,0.86664,0.6559,0.35033,0.88752,0.004753,0.004753,0.004753
|
||||||
|
55,10335.2,0.14867,0.14954,0.77318,0.98227,0.94556,0.98734,0.86535,0.66191,0.35233,0.89052,0.004654,0.004654,0.004654
|
||||||
|
56,10517.7,0.14781,0.1504,0.77387,0.97187,0.95472,0.98731,0.85393,0.70291,0.36551,0.90638,0.004555,0.004555,0.004555
|
||||||
|
57,10704.4,0.14704,0.14654,0.77286,0.96973,0.95539,0.98731,0.84386,0.75517,0.38619,0.92774,0.004456,0.004456,0.004456
|
||||||
|
58,10888.6,0.14478,0.14588,0.77324,0.9792,0.94676,0.9872,0.84023,0.76095,0.38846,0.93011,0.004357,0.004357,0.004357
|
||||||
|
59,11071.2,0.14408,0.14418,0.7724,0.9709,0.95553,0.98729,0.8499,0.71784,0.37089,0.91332,0.004258,0.004258,0.004258
|
||||||
|
60,11256.4,0.1427,0.14165,0.77106,0.96919,0.95682,0.98729,0.85156,0.70256,0.36509,0.90774,0.004159,0.004159,0.004159
|
||||||
|
61,11444.8,0.14194,0.14087,0.77269,0.96601,0.96121,0.98731,0.85107,0.70839,0.36753,0.90948,0.00406,0.00406,0.00406
|
||||||
|
62,11630.9,0.14062,0.13882,0.77215,0.96628,0.96081,0.98736,0.84858,0.73074,0.3762,0.92033,0.003961,0.003961,0.003961
|
||||||
|
63,11816.5,0.13938,0.13865,0.77152,0.96711,0.95961,0.98744,0.85214,0.70862,0.36754,0.91079,0.003862,0.003862,0.003862
|
||||||
|
64,12005.3,0.13858,0.13687,0.77045,0.96702,0.9595,0.98748,0.85574,0.69672,0.36084,0.90588,0.003763,0.003763,0.003763
|
||||||
|
65,12191.8,0.13775,0.13411,0.77132,0.96785,0.95943,0.9874,0.85729,0.68875,0.35766,0.90221,0.003664,0.003664,0.003664
|
||||||
|
66,12379.6,0.13556,0.13271,0.77167,0.96725,0.96005,0.98735,0.85898,0.68174,0.3561,0.89887,0.003565,0.003565,0.003565
|
||||||
|
67,12565.4,0.13463,0.13108,0.77009,0.97381,0.95338,0.98732,0.86031,0.67263,0.35399,0.89484,0.003466,0.003466,0.003466
|
||||||
|
68,12752.5,0.13515,0.1311,0.77095,0.96906,0.95916,0.98725,0.86029,0.66717,0.35274,0.89292,0.003367,0.003367,0.003367
|
||||||
|
69,12940.8,0.13415,0.12957,0.76963,0.97126,0.95685,0.9873,0.86049,0.6644,0.35306,0.8918,0.003268,0.003268,0.003268
|
||||||
|
70,13133.2,0.13179,0.12737,0.76937,0.97287,0.95478,0.98727,0.86047,0.6632,0.35246,0.89193,0.003169,0.003169,0.003169
|
||||||
|
71,13319.4,0.13185,0.1274,0.77079,0.97267,0.95587,0.98722,0.86193,0.65949,0.35213,0.89086,0.00307,0.00307,0.00307
|
||||||
|
72,13504.8,0.12947,0.12446,0.76998,0.97199,0.95686,0.98725,0.86401,0.64895,0.34877,0.88741,0.002971,0.002971,0.002971
|
||||||
|
73,13695.1,0.12876,0.12321,0.76883,0.9723,0.9569,0.98725,0.86643,0.64091,0.3473,0.88447,0.002872,0.002872,0.002872
|
||||||
|
74,13882,0.12828,0.12194,0.76915,0.97256,0.95686,0.98732,0.86847,0.6322,0.34702,0.88109,0.002773,0.002773,0.002773
|
||||||
|
75,14075,0.12664,0.11944,0.76878,0.97277,0.95678,0.98726,0.86861,0.63123,0.3482,0.88086,0.002674,0.002674,0.002674
|
||||||
|
76,14259.9,0.12587,0.11965,0.7692,0.9727,0.95673,0.98717,0.86916,0.62721,0.34811,0.87932,0.002575,0.002575,0.002575
|
||||||
|
77,14451.2,0.12433,0.1174,0.76838,0.97267,0.95663,0.9872,0.87057,0.62032,0.34709,0.8769,0.002476,0.002476,0.002476
|
||||||
|
78,14636.3,0.12352,0.11507,0.76971,0.97087,0.95829,0.98721,0.87189,0.61271,0.34667,0.87445,0.002377,0.002377,0.002377
|
||||||
|
79,14821.1,0.1231,0.11454,0.76897,0.97195,0.95752,0.98722,0.87292,0.60714,0.34596,0.87271,0.002278,0.002278,0.002278
|
||||||
|
80,15007.3,0.12117,0.11285,0.76864,0.97146,0.95789,0.98726,0.8735,0.6031,0.34515,0.87163,0.002179,0.002179,0.002179
|
||||||
|
81,15199.5,0.12029,0.11158,0.76708,0.97018,0.95938,0.9872,0.87378,0.60116,0.34538,0.87113,0.00208,0.00208,0.00208
|
||||||
|
82,15390.8,0.11877,0.10949,0.76719,0.97021,0.95964,0.98721,0.87422,0.59897,0.3457,0.87057,0.001981,0.001981,0.001981
|
||||||
|
83,15577.6,0.11812,0.10818,0.76749,0.97013,0.95951,0.98722,0.87429,0.59878,0.34524,0.87054,0.001882,0.001882,0.001882
|
||||||
|
84,15761.5,0.11687,0.10634,0.76703,0.97155,0.9583,0.98713,0.87407,0.59964,0.34532,0.8709,0.001783,0.001783,0.001783
|
||||||
|
85,15946.2,0.11551,0.10455,0.7672,0.9717,0.95797,0.9871,0.87367,0.60049,0.34569,0.87136,0.001684,0.001684,0.001684
|
||||||
|
86,16130.6,0.11474,0.10479,0.76737,0.97183,0.95808,0.98712,0.87406,0.5981,0.34504,0.87076,0.001585,0.001585,0.001585
|
||||||
|
87,16324.1,0.11337,0.10221,0.76695,0.97137,0.95881,0.98708,0.87382,0.59851,0.34519,0.87106,0.001486,0.001486,0.001486
|
||||||
|
88,16517.1,0.11185,0.10043,0.76513,0.97121,0.95899,0.98707,0.87379,0.59906,0.34583,0.87135,0.001387,0.001387,0.001387
|
||||||
|
89,16708.5,0.11103,0.09846,0.76565,0.97113,0.95904,0.98709,0.87369,0.59838,0.34599,0.87138,0.001288,0.001288,0.001288
|
||||||
|
90,16896.9,0.11054,0.0982,0.76703,0.97095,0.95892,0.98712,0.87377,0.59757,0.34552,0.87126,0.001189,0.001189,0.001189
|
||||||
|
91,17091.5,0.10967,0.09616,0.76665,0.97037,0.9595,0.98704,0.87361,0.59635,0.34561,0.87111,0.00109,0.00109,0.00109
|
||||||
|
92,17282.9,0.10834,0.09481,0.76509,0.9726,0.95743,0.98704,0.87372,0.5956,0.34572,0.87108,0.000991,0.000991,0.000991
|
||||||
|
93,17471.1,0.10692,0.09247,0.76461,0.97255,0.95738,0.9871,0.87368,0.59467,0.34689,0.8707,0.000892,0.000892,0.000892
|
||||||
|
94,17654.7,0.10578,0.09076,0.76573,0.97167,0.95786,0.9872,0.87367,0.59367,0.34732,0.87049,0.000793,0.000793,0.000793
|
||||||
|
95,17858.1,0.10457,0.08903,0.7648,0.97097,0.95816,0.98718,0.87394,0.59295,0.34757,0.87044,0.000694,0.000694,0.000694
|
||||||
|
96,18048,0.10283,0.08802,0.76437,0.97358,0.95577,0.98712,0.8737,0.59392,0.34877,0.87087,0.000595,0.000595,0.000595
|
||||||
|
97,18233,0.10269,0.08685,0.76468,0.97469,0.95492,0.98712,0.8741,0.59227,0.34903,0.87042,0.000496,0.000496,0.000496
|
||||||
|
98,18418.2,0.10143,0.0852,0.7644,0.97473,0.95512,0.98709,0.87397,0.59171,0.35007,0.8704,0.000397,0.000397,0.000397
|
||||||
|
99,18605.1,0.10052,0.08363,0.76442,0.97443,0.95526,0.98712,0.87396,0.5922,0.35121,0.87087,0.000298,0.000298,0.000298
|
||||||
|
100,18790,0.09925,0.08228,0.76465,0.97498,0.95493,0.98711,0.8737,0.59312,0.35293,0.87138,0.000199,0.000199,0.000199
1
tests/shared/augmentation/__init__.py
Normal file
@@ -0,0 +1 @@
# Tests for augmentation module
347
tests/shared/augmentation/test_base.py
Normal file
@@ -0,0 +1,347 @@
"""
|
||||||
|
Tests for augmentation base module.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationResult:
|
||||||
|
"""Tests for AugmentationResult dataclass."""
|
||||||
|
|
||||||
|
def test_minimal_result(self) -> None:
|
||||||
|
"""Test creating result with only required field."""
|
||||||
|
from shared.augmentation.base import AugmentationResult
|
||||||
|
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
result = AugmentationResult(image=image)
|
||||||
|
|
||||||
|
assert result.image is image
|
||||||
|
assert result.bboxes is None
|
||||||
|
assert result.transform_matrix is None
|
||||||
|
assert result.applied is True
|
||||||
|
assert result.metadata is None
|
||||||
|
|
||||||
|
def test_full_result(self) -> None:
|
||||||
|
"""Test creating result with all fields."""
|
||||||
|
from shared.augmentation.base import AugmentationResult
|
||||||
|
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||||
|
transform = np.eye(3)
|
||||||
|
metadata = {"applied_transform": "wrinkle"}
|
||||||
|
|
||||||
|
result = AugmentationResult(
|
||||||
|
image=image,
|
||||||
|
bboxes=bboxes,
|
||||||
|
transform_matrix=transform,
|
||||||
|
applied=True,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.image is image
|
||||||
|
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||||
|
np.testing.assert_array_equal(result.transform_matrix, transform)
|
||||||
|
assert result.applied is True
|
||||||
|
assert result.metadata == {"applied_transform": "wrinkle"}
|
||||||
|
|
||||||
|
def test_not_applied(self) -> None:
|
||||||
|
"""Test result when augmentation was not applied."""
|
||||||
|
from shared.augmentation.base import AugmentationResult
|
||||||
|
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
result = AugmentationResult(image=image, applied=False)
|
||||||
|
|
||||||
|
assert result.applied is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseAugmentation:
|
||||||
|
"""Tests for BaseAugmentation abstract class."""
|
||||||
|
|
||||||
|
def test_cannot_instantiate_directly(self) -> None:
|
||||||
|
"""Test that BaseAugmentation cannot be instantiated."""
|
||||||
|
from shared.augmentation.base import BaseAugmentation
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
BaseAugmentation({}) # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_apply(self) -> None:
|
||||||
|
"""Test that subclass must implement apply method."""
|
||||||
|
from shared.augmentation.base import BaseAugmentation
|
||||||
|
|
||||||
|
class IncompleteAugmentation(BaseAugmentation):
|
||||||
|
name = "incomplete"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Missing apply method
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteAugmentation({}) # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_validate_params(self) -> None:
|
||||||
|
"""Test that subclass must implement _validate_params."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class IncompleteAugmentation(BaseAugmentation):
|
||||||
|
name = "incomplete"
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
return AugmentationResult(image=image)
|
||||||
|
|
||||||
|
# Missing _validate_params method
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteAugmentation({}) # type: ignore
|
||||||
|
|
||||||
|
def test_valid_subclass(self) -> None:
|
||||||
|
"""Test creating a valid subclass."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class DummyAugmentation(BaseAugmentation):
|
||||||
|
name = "dummy"
|
||||||
|
affects_geometry = False
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
return AugmentationResult(image=image, bboxes=bboxes)
|
||||||
|
|
||||||
|
aug = DummyAugmentation({"param1": "value1"})
|
||||||
|
|
||||||
|
assert aug.name == "dummy"
|
||||||
|
assert aug.affects_geometry is False
|
||||||
|
assert aug.params == {"param1": "value1"}
|
||||||
|
|
||||||
|
def test_apply_returns_augmentation_result(self) -> None:
|
||||||
|
"""Test that apply returns AugmentationResult."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class DummyAugmentation(BaseAugmentation):
|
||||||
|
name = "dummy"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
# Simple pass-through
|
||||||
|
return AugmentationResult(image=image, bboxes=bboxes)
|
||||||
|
|
||||||
|
aug = DummyAugmentation({})
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]])
|
||||||
|
|
||||||
|
result = aug.apply(image, bboxes)
|
||||||
|
|
||||||
|
assert isinstance(result, AugmentationResult)
|
||||||
|
assert result.image is image
|
||||||
|
np.testing.assert_array_equal(result.bboxes, bboxes)
|
||||||
|
|
||||||
|
def test_affects_geometry_default(self) -> None:
|
||||||
|
"""Test that affects_geometry defaults to False."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class DummyAugmentation(BaseAugmentation):
|
||||||
|
name = "dummy"
|
||||||
|
# Not setting affects_geometry
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
return AugmentationResult(image=image)
|
||||||
|
|
||||||
|
aug = DummyAugmentation({})
|
||||||
|
|
||||||
|
assert aug.affects_geometry is False
|
||||||
|
|
||||||
|
def test_validate_params_called_on_init(self) -> None:
|
||||||
|
"""Test that _validate_params is called during initialization."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
validation_called = {"called": False}
|
||||||
|
|
||||||
|
class ValidatingAugmentation(BaseAugmentation):
|
||||||
|
name = "validating"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
validation_called["called"] = True
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
return AugmentationResult(image=image)
|
||||||
|
|
||||||
|
ValidatingAugmentation({})
|
||||||
|
|
||||||
|
assert validation_called["called"] is True
|
||||||
|
|
||||||
|
def test_validate_params_raises_on_invalid(self) -> None:
|
||||||
|
"""Test that _validate_params can raise ValueError."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class StrictAugmentation(BaseAugmentation):
|
||||||
|
name = "strict"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
if "required_param" not in self.params:
|
||||||
|
raise ValueError("required_param is required")
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
return AugmentationResult(image=image)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="required_param"):
|
||||||
|
StrictAugmentation({})
|
||||||
|
|
||||||
|
# Should work with required param
|
||||||
|
aug = StrictAugmentation({"required_param": "value"})
|
||||||
|
assert aug.params["required_param"] == "value"
|
||||||
|
|
||||||
|
def test_rng_usage(self) -> None:
|
||||||
|
"""Test that random generator can be passed and used."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class RandomAugmentation(BaseAugmentation):
|
||||||
|
name = "random"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
if rng is None:
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
# Use rng to generate a random value
|
||||||
|
random_value = rng.random()
|
||||||
|
return AugmentationResult(
|
||||||
|
image=image,
|
||||||
|
metadata={"random_value": random_value},
|
||||||
|
)
|
||||||
|
|
||||||
|
aug = RandomAugmentation({})
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
# With same seed, should get same result
|
||||||
|
rng1 = np.random.default_rng(42)
|
||||||
|
rng2 = np.random.default_rng(42)
|
||||||
|
|
||||||
|
result1 = aug.apply(image, rng=rng1)
|
||||||
|
result2 = aug.apply(image, rng=rng2)
|
||||||
|
|
||||||
|
assert result1.metadata is not None
|
||||||
|
assert result2.metadata is not None
|
||||||
|
assert result1.metadata["random_value"] == result2.metadata["random_value"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationResultImmutability:
|
||||||
|
"""Tests for ensuring result doesn't mutate input."""
|
||||||
|
|
||||||
|
def test_image_not_modified(self) -> None:
|
||||||
|
"""Test that original image is not modified."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class ModifyingAugmentation(BaseAugmentation):
|
||||||
|
name = "modifying"
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
# Should copy before modifying
|
||||||
|
modified = image.copy()
|
||||||
|
modified[:] = 255
|
||||||
|
return AugmentationResult(image=modified)
|
||||||
|
|
||||||
|
aug = ModifyingAugmentation({})
|
||||||
|
original = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
original_copy = original.copy()
|
||||||
|
|
||||||
|
result = aug.apply(original)
|
||||||
|
|
||||||
|
# Original should be unchanged
|
||||||
|
np.testing.assert_array_equal(original, original_copy)
|
||||||
|
# Result should be modified
|
||||||
|
assert np.all(result.image == 255)
|
||||||
|
|
||||||
|
def test_bboxes_not_modified(self) -> None:
|
||||||
|
"""Test that original bboxes are not modified."""
|
||||||
|
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||||
|
|
||||||
|
class BboxModifyingAugmentation(BaseAugmentation):
|
||||||
|
name = "bbox_modifying"
|
||||||
|
affects_geometry = True
|
||||||
|
|
||||||
|
def _validate_params(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: np.ndarray | None = None,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
) -> AugmentationResult:
|
||||||
|
if bboxes is not None:
|
||||||
|
# Should copy before modifying
|
||||||
|
modified_bboxes = bboxes.copy()
|
||||||
|
modified_bboxes[:, 1:] *= 0.5 # Scale down
|
||||||
|
return AugmentationResult(image=image, bboxes=modified_bboxes)
|
||||||
|
return AugmentationResult(image=image)
|
||||||
|
|
||||||
|
aug = BboxModifyingAugmentation({})
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
original_bboxes = np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32)
|
||||||
|
original_bboxes_copy = original_bboxes.copy()
|
||||||
|
|
||||||
|
result = aug.apply(image, original_bboxes)
|
||||||
|
|
||||||
|
# Original should be unchanged
|
||||||
|
np.testing.assert_array_equal(original_bboxes, original_bboxes_copy)
|
||||||
|
# Result should be modified
|
||||||
|
assert result.bboxes is not None
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result.bboxes, np.array([[0, 0.25, 0.25, 0.1, 0.1]])
|
||||||
|
)
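The tests above are written TDD-first against shared/augmentation/base.py. A minimal shape of that module consistent with what the tests assert could look roughly like this (an illustrative sketch, not necessarily the actual implementation):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class AugmentationResult:
    """Output of a single augmentation step."""

    image: np.ndarray
    bboxes: np.ndarray | None = None
    transform_matrix: np.ndarray | None = None
    applied: bool = True
    metadata: dict[str, Any] | None = None


class BaseAugmentation(ABC):
    """Abstract base; subclasses must define apply() and _validate_params()."""

    name: str = "base"
    affects_geometry: bool = False

    def __init__(self, params: dict[str, Any]) -> None:
        self.params = params
        self._validate_params()  # may raise ValueError for invalid params

    @abstractmethod
    def _validate_params(self) -> None: ...

    @abstractmethod
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult: ...

Because both methods are abstract, instantiating BaseAugmentation directly, or a subclass that omits either method, raises TypeError, which is exactly what the incomplete-subclass tests assert.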
283
tests/shared/augmentation/test_config.py
Normal file
@@ -0,0 +1,283 @@
"""
|
||||||
|
Tests for augmentation configuration module.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationParams:
|
||||||
|
"""Tests for AugmentationParams dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test default parameter values."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
params = AugmentationParams()
|
||||||
|
|
||||||
|
assert params.enabled is False
|
||||||
|
assert params.probability == 0.5
|
||||||
|
assert params.params == {}
|
||||||
|
|
||||||
|
def test_custom_values(self) -> None:
|
||||||
|
"""Test creating params with custom values."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
params = AugmentationParams(
|
||||||
|
enabled=True,
|
||||||
|
probability=0.8,
|
||||||
|
params={"intensity": 0.5, "num_wrinkles": (2, 5)},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert params.enabled is True
|
||||||
|
assert params.probability == 0.8
|
||||||
|
assert params.params["intensity"] == 0.5
|
||||||
|
assert params.params["num_wrinkles"] == (2, 5)
|
||||||
|
|
||||||
|
def test_immutability_params_dict(self) -> None:
|
||||||
|
"""Test that params dict is independent between instances."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
params1 = AugmentationParams()
|
||||||
|
params2 = AugmentationParams()
|
||||||
|
|
||||||
|
# Modifying one should not affect the other
|
||||||
|
params1.params["test"] = "value"
|
||||||
|
|
||||||
|
assert "test" not in params2.params
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
params = AugmentationParams(
|
||||||
|
enabled=True,
|
||||||
|
probability=0.7,
|
||||||
|
params={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = params.to_dict()
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.7,
|
||||||
|
"params": {"key": "value"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_from_dict(self) -> None:
|
||||||
|
"""Test creation from dictionary."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.6,
|
||||||
|
"params": {"intensity": 0.3},
|
||||||
|
}
|
||||||
|
|
||||||
|
params = AugmentationParams.from_dict(data)
|
||||||
|
|
||||||
|
assert params.enabled is True
|
||||||
|
assert params.probability == 0.6
|
||||||
|
assert params.params == {"intensity": 0.3}
|
||||||
|
|
||||||
|
def test_from_dict_with_defaults(self) -> None:
|
||||||
|
"""Test creation from partial dictionary uses defaults."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
data: dict[str, Any] = {"enabled": True}
|
||||||
|
|
||||||
|
params = AugmentationParams.from_dict(data)
|
||||||
|
|
||||||
|
assert params.enabled is True
|
||||||
|
assert params.probability == 0.5 # default
|
||||||
|
assert params.params == {} # default
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationConfig:
|
||||||
|
"""Tests for AugmentationConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test that all augmentation types have defaults."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
|
||||||
|
# All augmentation types should exist
|
||||||
|
augmentation_types = [
|
||||||
|
"perspective_warp",
|
||||||
|
"wrinkle",
|
||||||
|
"edge_damage",
|
||||||
|
"stain",
|
||||||
|
"lighting_variation",
|
||||||
|
"shadow",
|
||||||
|
"gaussian_blur",
|
||||||
|
"motion_blur",
|
||||||
|
"gaussian_noise",
|
||||||
|
"salt_pepper",
|
||||||
|
"paper_texture",
|
||||||
|
"scanner_artifacts",
|
||||||
|
]
|
||||||
|
|
||||||
|
for aug_type in augmentation_types:
|
||||||
|
assert hasattr(config, aug_type), f"Missing augmentation type: {aug_type}"
|
||||||
|
params = getattr(config, aug_type)
|
||||||
|
assert hasattr(params, "enabled")
|
||||||
|
assert hasattr(params, "probability")
|
||||||
|
assert hasattr(params, "params")
|
||||||
|
|
||||||
|
def test_global_settings_defaults(self) -> None:
|
||||||
|
"""Test global settings default values."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
|
||||||
|
assert config.preserve_bboxes is True
|
||||||
|
assert config.seed is None
|
||||||
|
|
||||||
|
def test_custom_seed(self) -> None:
|
||||||
|
"""Test setting custom seed for reproducibility."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config = AugmentationConfig(seed=42)
|
||||||
|
|
||||||
|
assert config.seed == 42
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config = AugmentationConfig(seed=123, preserve_bboxes=False)
|
||||||
|
|
||||||
|
result = config.to_dict()
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["seed"] == 123
|
||||||
|
assert result["preserve_bboxes"] is False
|
||||||
|
assert "perspective_warp" in result
|
||||||
|
assert "wrinkle" in result
|
||||||
|
|
||||||
|
def test_from_dict(self) -> None:
|
||||||
|
"""Test creation from dictionary."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"seed": 456,
|
||||||
|
"preserve_bboxes": False,
|
||||||
|
"wrinkle": {
|
||||||
|
"enabled": True,
|
||||||
|
"probability": 0.8,
|
||||||
|
"params": {"intensity": 0.5},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AugmentationConfig.from_dict(data)
|
||||||
|
|
||||||
|
assert config.seed == 456
|
||||||
|
assert config.preserve_bboxes is False
|
||||||
|
assert config.wrinkle.enabled is True
|
||||||
|
assert config.wrinkle.probability == 0.8
|
||||||
|
assert config.wrinkle.params["intensity"] == 0.5
|
||||||
|
|
||||||
|
def test_from_dict_with_partial_data(self) -> None:
|
||||||
|
"""Test creation from partial dictionary uses defaults."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"wrinkle": {"enabled": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = AugmentationConfig.from_dict(data)
|
||||||
|
|
||||||
|
# Explicitly set value
|
||||||
|
assert config.wrinkle.enabled is True
|
||||||
|
# Default values
|
||||||
|
assert config.preserve_bboxes is True
|
||||||
|
assert config.seed is None
|
||||||
|
assert config.gaussian_blur.enabled is False
|
||||||
|
|
||||||
|
def test_get_enabled_augmentations(self) -> None:
|
||||||
|
"""Test getting list of enabled augmentations."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
|
||||||
|
config = AugmentationConfig(
|
||||||
|
wrinkle=AugmentationParams(enabled=True),
|
||||||
|
stain=AugmentationParams(enabled=True),
|
||||||
|
gaussian_blur=AugmentationParams(enabled=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
enabled = config.get_enabled_augmentations()
|
||||||
|
|
||||||
|
assert "wrinkle" in enabled
|
||||||
|
assert "stain" in enabled
|
||||||
|
assert "gaussian_blur" not in enabled
|
||||||
|
|
||||||
|
def test_document_safe_defaults(self) -> None:
|
||||||
|
"""Test that default params are document-safe (conservative)."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
|
||||||
|
# Perspective warp should be very conservative
|
||||||
|
assert config.perspective_warp.params.get("max_warp", 0.02) <= 0.05
|
||||||
|
|
||||||
|
# Noise should be subtle
|
||||||
|
noise_std = config.gaussian_noise.params.get("std", (5, 15))
|
||||||
|
if isinstance(noise_std, tuple):
|
||||||
|
assert noise_std[1] <= 20 # Max std should be low
|
||||||
|
|
||||||
|
def test_immutability_between_instances(self) -> None:
|
||||||
|
"""Test that config instances are independent."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
|
||||||
|
config1 = AugmentationConfig()
|
||||||
|
config2 = AugmentationConfig()
|
||||||
|
|
||||||
|
# Modifying one should not affect the other
|
||||||
|
config1.wrinkle.params["test"] = "value"
|
||||||
|
|
||||||
|
assert "test" not in config2.wrinkle.params
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationConfigValidation:
|
||||||
|
"""Tests for configuration validation."""
|
||||||
|
|
||||||
|
def test_probability_range_validation(self) -> None:
|
||||||
|
"""Test that probability values are validated."""
|
||||||
|
from shared.augmentation.config import AugmentationParams
|
||||||
|
|
||||||
|
# Valid range
|
||||||
|
params = AugmentationParams(probability=0.5)
|
||||||
|
assert params.probability == 0.5
|
||||||
|
|
||||||
|
# Edge cases
|
||||||
|
params_zero = AugmentationParams(probability=0.0)
|
||||||
|
assert params_zero.probability == 0.0
|
||||||
|
|
||||||
|
params_one = AugmentationParams(probability=1.0)
|
||||||
|
assert params_one.probability == 1.0
|
||||||
|
|
||||||
|
def test_config_validate_method(self) -> None:
|
||||||
|
"""Test the validate method catches invalid configurations."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
|
||||||
|
# Invalid probability
|
||||||
|
config = AugmentationConfig(
|
||||||
|
wrinkle=AugmentationParams(probability=1.5), # Invalid
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="probability"):
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
def test_config_validate_negative_probability(self) -> None:
|
||||||
|
"""Test validation catches negative probability."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
|
||||||
|
config = AugmentationConfig(
|
||||||
|
wrinkle=AugmentationParams(probability=-0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="probability"):
|
||||||
|
config.validate()
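As with test_base.py, these tests pin down the intended shape of shared/augmentation/config.py. A sketch of AugmentationParams consistent with what is asserted above (illustrative only, not necessarily the committed code):

from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class AugmentationParams:
    """Per-augmentation toggle, probability, and free-form parameters."""

    enabled: bool = False
    probability: float = 0.5
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
        # Missing keys fall back to the dataclass defaults.
        return cls(
            enabled=data.get("enabled", False),
            probability=data.get("probability", 0.5),
            params=dict(data.get("params", {})),
        )

AugmentationConfig would follow the same pattern: one AugmentationParams field per augmentation type (perspective_warp, wrinkle, stain, gaussian_noise, ...), plus preserve_bboxes and seed, with to_dict/from_dict and a validate() that rejects any probability outside [0, 1].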
338
tests/shared/augmentation/test_pipeline.py
Normal file
@@ -0,0 +1,338 @@
"""
|
||||||
|
Tests for augmentation pipeline module.
|
||||||
|
|
||||||
|
TDD Phase 2: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestAugmentationPipeline:
|
||||||
|
"""Tests for AugmentationPipeline class."""
|
||||||
|
|
||||||
|
def test_create_with_config(self) -> None:
|
||||||
|
"""Test creating pipeline with config."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
assert pipeline.config is config
|
||||||
|
|
||||||
|
def test_create_with_seed(self) -> None:
|
||||||
|
"""Test creating pipeline with seed for reproducibility."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig(seed=42)
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
assert pipeline.config.seed == 42
|
||||||
|
|
||||||
|
def test_apply_returns_augmentation_result(self) -> None:
|
||||||
|
"""Test that apply returns AugmentationResult."""
|
||||||
|
from shared.augmentation.base import AugmentationResult
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
result = pipeline.apply(image)
|
||||||
|
|
||||||
|
assert isinstance(result, AugmentationResult)
|
||||||
|
assert result.image is not None
|
||||||
|
assert result.image.shape == image.shape
|
||||||
|
|
||||||
|
def test_apply_with_bboxes(self) -> None:
|
||||||
|
"""Test apply with bounding boxes."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig()
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
bboxes = np.array([[0, 0.5, 0.5, 0.1, 0.1]], dtype=np.float32)
|
||||||
|
|
||||||
|
result = pipeline.apply(image, bboxes)
|
||||||
|
|
||||||
|
# Bboxes should be preserved when preserve_bboxes=True
|
||||||
|
assert result.bboxes is not None
|
||||||
|
|
||||||
|
def test_apply_no_augmentations_enabled(self) -> None:
|
||||||
|
"""Test apply when no augmentations are enabled."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
# Disable all augmentations
|
||||||
|
config = AugmentationConfig(
|
||||||
|
lighting_variation=AugmentationParams(enabled=False),
|
||||||
|
)
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||||
|
result = pipeline.apply(image)
|
||||||
|
|
||||||
|
# Image should be unchanged (or a copy)
|
||||||
|
np.testing.assert_array_equal(result.image, image)
|
||||||
|
|
||||||
|
def test_apply_does_not_mutate_input(self) -> None:
|
||||||
|
"""Test that apply does not mutate input image."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig(
|
||||||
|
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
||||||
|
)
|
||||||
|
pipeline = AugmentationPipeline(config)
|
||||||
|
|
||||||
|
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||||
|
original_copy = image.copy()
|
||||||
|
|
||||||
|
pipeline.apply(image)
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(image, original_copy)
|
||||||
|
|
||||||
|
def test_reproducibility_with_seed(self) -> None:
|
||||||
|
"""Test that same seed produces same results."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config1 = AugmentationConfig(
|
||||||
|
seed=42,
|
||||||
|
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||||
|
)
|
||||||
|
config2 = AugmentationConfig(
|
||||||
|
seed=42,
|
||||||
|
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline1 = AugmentationPipeline(config1)
|
||||||
|
pipeline2 = AugmentationPipeline(config2)
|
||||||
|
|
||||||
|
image = np.full((100, 100, 3), 128, dtype=np.uint8)
|
||||||
|
|
||||||
|
result1 = pipeline1.apply(image.copy())
|
||||||
|
result2 = pipeline2.apply(image.copy())
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(result1.image, result2.image)
|
||||||
|
|
||||||
|
def test_metadata_contains_applied_augmentations(self) -> None:
|
||||||
|
"""Test that metadata lists applied augmentations."""
|
||||||
|
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||||
|
from shared.augmentation.pipeline import AugmentationPipeline
|
||||||
|
|
||||||
|
config = AugmentationConfig(
|
||||||
|
seed=42,
|
||||||
|
gaussian_noise=AugmentationParams(enabled=True, probability=1.0),
|
||||||
|
lighting_variation=AugmentationParams(enabled=True, probability=1.0),
|
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        result = pipeline.apply(image)

        assert result.metadata is not None
        assert "applied_augmentations" in result.metadata
        # Both should be applied with probability=1.0
        assert "gaussian_noise" in result.metadata["applied_augmentations"]
        assert "lighting_variation" in result.metadata["applied_augmentations"]


class TestAugmentationPipelineStageOrder:
    """Tests for pipeline stage ordering."""

    def test_stage_order_defined(self) -> None:
        """Test that stage order is defined."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_ORDER")
        expected_stages = [
            "geometric",
            "degradation",
            "lighting",
            "texture",
            "blur",
            "noise",
        ]
        assert AugmentationPipeline.STAGE_ORDER == expected_stages

    def test_stage_mapping_defined(self) -> None:
        """Test that all augmentation types are mapped to stages."""
        from shared.augmentation.pipeline import AugmentationPipeline

        assert hasattr(AugmentationPipeline, "STAGE_MAPPING")

        expected_mappings = {
            "perspective_warp": "geometric",
            "wrinkle": "degradation",
            "edge_damage": "degradation",
            "stain": "degradation",
            "lighting_variation": "lighting",
            "shadow": "lighting",
            "paper_texture": "texture",
            "scanner_artifacts": "texture",
            "gaussian_blur": "blur",
            "motion_blur": "blur",
            "gaussian_noise": "noise",
            "salt_pepper": "noise",
        }

        for aug_name, stage in expected_mappings.items():
            assert aug_name in AugmentationPipeline.STAGE_MAPPING
            assert AugmentationPipeline.STAGE_MAPPING[aug_name] == stage

    def test_geometric_before_degradation(self) -> None:
        """Test that geometric transforms are applied before degradation."""
        from shared.augmentation.pipeline import AugmentationPipeline

        stages = AugmentationPipeline.STAGE_ORDER
        geometric_idx = stages.index("geometric")
        degradation_idx = stages.index("degradation")

        assert geometric_idx < degradation_idx

    def test_noise_applied_last(self) -> None:
        """Test that noise is applied last."""
        from shared.augmentation.pipeline import AugmentationPipeline

        stages = AugmentationPipeline.STAGE_ORDER
        assert stages[-1] == "noise"


class TestAugmentationRegistry:
    """Tests for augmentation registry."""

    def test_registry_exists(self) -> None:
        """Test that augmentation registry exists."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        assert isinstance(AUGMENTATION_REGISTRY, dict)

    def test_registry_contains_all_types(self) -> None:
        """Test that registry contains all augmentation types."""
        from shared.augmentation.pipeline import AUGMENTATION_REGISTRY

        expected_types = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for aug_type in expected_types:
            assert aug_type in AUGMENTATION_REGISTRY, f"Missing: {aug_type}"


class TestPipelinePreview:
    """Tests for pipeline preview functionality."""

    def test_preview_single_augmentation(self) -> None:
        """Test previewing a single augmentation."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(
                enabled=True, probability=1.0, params={"std": (10, 10)}
            ),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)
        preview = pipeline.preview(image, "gaussian_noise")

        assert preview.shape == image.shape
        assert preview.dtype == np.uint8
        # Preview should modify the image
        assert not np.array_equal(preview, image)

    def test_preview_unknown_augmentation_raises(self) -> None:
        """Test that previewing unknown augmentation raises error."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig()
        pipeline = AugmentationPipeline(config)

        image = np.zeros((100, 100, 3), dtype=np.uint8)

        with pytest.raises(ValueError, match="Unknown augmentation"):
            pipeline.preview(image, "non_existent_augmentation")

    def test_preview_is_deterministic(self) -> None:
        """Test that preview produces deterministic results."""
        from shared.augmentation.config import AugmentationConfig, AugmentationParams
        from shared.augmentation.pipeline import AugmentationPipeline

        config = AugmentationConfig(
            gaussian_noise=AugmentationParams(enabled=True),
        )
        pipeline = AugmentationPipeline(config)

        image = np.full((100, 100, 3), 128, dtype=np.uint8)

        preview1 = pipeline.preview(image, "gaussian_noise")
        preview2 = pipeline.preview(image, "gaussian_noise")

        np.testing.assert_array_equal(preview1, preview2)


class TestPipelineGetAvailableAugmentations:
    """Tests for getting available augmentations."""

    def test_get_available_augmentations(self) -> None:
        """Test getting list of available augmentations."""
        from shared.augmentation.pipeline import get_available_augmentations

        augmentations = get_available_augmentations()

        assert isinstance(augmentations, list)
        assert len(augmentations) == 12

        # Each item should have name, description, affects_geometry
        for aug in augmentations:
            assert "name" in aug
            assert "description" in aug
            assert "affects_geometry" in aug
            assert "stage" in aug

    def test_get_available_augmentations_includes_all_types(self) -> None:
        """Test that all augmentation types are included."""
        from shared.augmentation.pipeline import get_available_augmentations

        augmentations = get_available_augmentations()
        names = [aug["name"] for aug in augmentations]

        expected = [
            "perspective_warp",
            "wrinkle",
            "edge_damage",
            "stain",
            "lighting_variation",
            "shadow",
            "gaussian_blur",
            "motion_blur",
            "gaussian_noise",
            "salt_pepper",
            "paper_texture",
            "scanner_artifacts",
        ]

        for name in expected:
            assert name in names
102
tests/shared/augmentation/test_presets.py
Normal file
@@ -0,0 +1,102 @@
"""
Tests for augmentation presets module.

TDD Phase 4: RED - Write tests first, then implement to pass.
"""

import pytest


class TestPresets:
    """Tests for augmentation presets."""

    def test_presets_dict_exists(self) -> None:
        """Test that PRESETS dictionary exists."""
        from shared.augmentation.presets import PRESETS

        assert isinstance(PRESETS, dict)
        assert len(PRESETS) > 0

    def test_expected_presets_exist(self) -> None:
        """Test that expected presets are defined."""
        from shared.augmentation.presets import PRESETS

        expected_presets = ["conservative", "moderate", "aggressive", "scanned_document"]

        for preset_name in expected_presets:
            assert preset_name in PRESETS, f"Missing preset: {preset_name}"

    def test_preset_structure(self) -> None:
        """Test that each preset has required structure."""
        from shared.augmentation.presets import PRESETS

        for name, preset in PRESETS.items():
            assert "description" in preset, f"Preset {name} missing description"
            assert "config" in preset, f"Preset {name} missing config"
            assert isinstance(preset["description"], str)
            assert isinstance(preset["config"], dict)

    def test_get_preset_config(self) -> None:
        """Test getting config from preset."""
        from shared.augmentation.presets import get_preset_config

        config = get_preset_config("conservative")

        assert config is not None
        # Should have at least some augmentations defined
        assert len(config) > 0

    def test_get_preset_config_unknown_raises(self) -> None:
        """Test that getting unknown preset raises error."""
        from shared.augmentation.presets import get_preset_config

        with pytest.raises(ValueError, match="Unknown preset"):
            get_preset_config("nonexistent_preset")

    def test_create_config_from_preset(self) -> None:
        """Test creating AugmentationConfig from preset."""
        from shared.augmentation.config import AugmentationConfig
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("moderate")

        assert isinstance(config, AugmentationConfig)

    def test_conservative_preset_is_safe(self) -> None:
        """Test that conservative preset only enables safe augmentations."""
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("conservative")

        # Should NOT enable geometric transforms
        assert config.perspective_warp.enabled is False

        # Should NOT enable heavy degradation
        assert config.wrinkle.enabled is False
        assert config.edge_damage.enabled is False
        assert config.stain.enabled is False

    def test_aggressive_preset_enables_more(self) -> None:
        """Test that aggressive preset enables more augmentations."""
        from shared.augmentation.presets import create_config_from_preset

        config = create_config_from_preset("aggressive")

        enabled = config.get_enabled_augmentations()

        # Should enable multiple augmentation types
        assert len(enabled) >= 6

    def test_list_presets(self) -> None:
        """Test listing available presets."""
        from shared.augmentation.presets import list_presets

        presets = list_presets()

        assert isinstance(presets, list)
        assert len(presets) >= 4

        # Each item should have name and description
        for preset in presets:
            assert "name" in preset
            assert "description" in preset
1
tests/shared/augmentation/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
# Tests for augmentation transforms
293
tests/shared/test_dataset_augmenter.py
Normal file
@@ -0,0 +1,293 @@
"""
Tests for DatasetAugmenter.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import tempfile
from pathlib import Path

import numpy as np
import pytest
from PIL import Image


class TestDatasetAugmenter:
    """Tests for DatasetAugmenter class."""

    @pytest.fixture
    def sample_dataset(self, tmp_path: Path) -> Path:
        """Create a sample YOLO dataset structure."""
        dataset_dir = tmp_path / "dataset"

        # Create directory structure
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True)
            (dataset_dir / "labels" / split).mkdir(parents=True)

        # Create sample images and labels
        for i in range(3):
            # Create 100x100 white image
            img = Image.new("RGB", (100, 100), color="white")
            img_path = dataset_dir / "images" / "train" / f"doc_{i}.png"
            img.save(img_path)

            # Create label with 2 bboxes
            # Format: class_id x_center y_center width height
            label_content = "0 0.5 0.3 0.2 0.1\n1 0.7 0.6 0.15 0.2\n"
            label_path = dataset_dir / "labels" / "train" / f"doc_{i}.txt"
            label_path.write_text(label_content)

        # Create data.yaml
        data_yaml = dataset_dir / "data.yaml"
        data_yaml.write_text(
            "path: .\n"
            "train: images/train\n"
            "val: images/val\n"
            "test: images/test\n"
            "nc: 10\n"
            "names: [class0, class1, class2, class3, class4, class5, class6, class7, class8, class9]\n"
        )

        return dataset_dir

    @pytest.fixture
    def augmentation_config(self) -> dict:
        """Create a sample augmentation config."""
        return {
            "gaussian_noise": {
                "enabled": True,
                "probability": 1.0,
                "params": {"std": 10},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 1.0,
                "params": {"kernel_size": 3},
            },
        }

    def test_augmenter_creates_additional_images(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates new augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)

        # Count original images
        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert original_count == 3

        # Apply augmentation with multiplier=2
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Should now have original + 2x augmented = 3 + 6 = 9 images
        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == 9
        assert result["augmented_images"] == 6

    def test_augmenter_creates_matching_labels(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter creates label files for each augmented image."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Check that each image has a matching label file
        images = list((sample_dataset / "images" / "train").glob("*.png"))
        labels = list((sample_dataset / "labels" / "train").glob("*.txt"))

        assert len(images) == len(labels)

        # Check that augmented images have corresponding labels
        for img_path in images:
            label_path = sample_dataset / "labels" / "train" / f"{img_path.stem}.txt"
            assert label_path.exists(), f"Missing label for {img_path.name}"

    def test_augmented_labels_have_valid_format(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented label files have valid YOLO format."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all label files
        for label_path in (sample_dataset / "labels" / "train").glob("*.txt"):
            content = label_path.read_text().strip()
            if not content:
                continue  # Empty labels are valid (background images)

            for line in content.split("\n"):
                parts = line.split()
                assert len(parts) == 5, f"Invalid label format in {label_path.name}"

                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])

                # Check values are in valid range
                assert 0 <= class_id < 100, f"Invalid class_id: {class_id}"
                assert 0 <= x_center <= 1, f"Invalid x_center: {x_center}"
                assert 0 <= y_center <= 1, f"Invalid y_center: {y_center}"
                assert 0 <= width <= 1, f"Invalid width: {width}"
                assert 0 <= height <= 1, f"Invalid height: {height}"

    def test_augmented_images_are_different(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images are actually different from originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Load original image
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = np.array(Image.open(original_path))

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Find augmented version
        aug_path = sample_dataset / "images" / "train" / "doc_0_aug0.png"
        assert aug_path.exists()

        aug_img = np.array(Image.open(aug_path))

        # Images should be different (due to noise/blur)
        assert not np.array_equal(original_img, aug_img)

    def test_augmented_images_same_size(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmented images have same size as originals."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Get original size
        original_path = sample_dataset / "images" / "train" / "doc_0.png"
        original_img = Image.open(original_path)
        original_size = original_img.size

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Check all augmented images have same size
        for img_path in (sample_dataset / "images" / "train").glob("*_aug*.png"):
            img = Image.open(img_path)
            assert img.size == original_size, f"{img_path.name} has wrong size"

    def test_perspective_warp_updates_bboxes(self, sample_dataset: Path):
        """Test that perspective_warp augmentation updates bbox coordinates."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        config = {
            "perspective_warp": {
                "enabled": True,
                "probability": 1.0,
                "params": {"max_warp": 0.05},  # Use larger warp for visible difference
            },
        }

        # Read original label
        original_label = (sample_dataset / "labels" / "train" / "doc_0.txt").read_text()
        original_bboxes = [line.split() for line in original_label.strip().split("\n")]

        augmenter = DatasetAugmenter(config)
        augmenter.augment_dataset(sample_dataset, multiplier=1)

        # Read augmented label
        aug_label = (sample_dataset / "labels" / "train" / "doc_0_aug0.txt").read_text()
        aug_bboxes = [line.split() for line in aug_label.strip().split("\n")]

        # Same number of bboxes
        assert len(original_bboxes) == len(aug_bboxes)

        # At least one bbox should have different coordinates
        # (perspective warp changes geometry)
        differences_found = False
        for orig, aug in zip(original_bboxes, aug_bboxes):
            # Class ID should be same
            assert orig[0] == aug[0]
            # Coordinates might differ
            if orig[1:] != aug[1:]:
                differences_found = True

        assert differences_found, "Perspective warp should change bbox coordinates"

    def test_augmenter_only_processes_train_split(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter only processes train split by default."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Add a val image
        val_img = Image.new("RGB", (100, 100), color="white")
        val_img.save(sample_dataset / "images" / "val" / "val_doc.png")
        (sample_dataset / "labels" / "val" / "val_doc.txt").write_text("0 0.5 0.5 0.1 0.1\n")

        augmenter = DatasetAugmenter(augmentation_config)
        augmenter.augment_dataset(sample_dataset, multiplier=2)

        # Val should still have only 1 image
        val_count = len(list((sample_dataset / "images" / "val").glob("*.png")))
        assert val_count == 1

    def test_augmenter_with_multiplier_zero_does_nothing(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that multiplier=0 creates no augmented images."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        original_count = len(list((sample_dataset / "images" / "train").glob("*.png")))

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=0)

        new_count = len(list((sample_dataset / "images" / "train").glob("*.png")))
        assert new_count == original_count
        assert result["augmented_images"] == 0

    def test_augmenter_with_seed_is_reproducible(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that same seed produces same augmentation results."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        # Create two separate datasets
        import shutil
        dataset1 = sample_dataset
        dataset2 = sample_dataset.parent / "dataset2"
        shutil.copytree(dataset1, dataset2)

        # Augment both with same seed
        augmenter1 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter1.augment_dataset(dataset1, multiplier=1)

        augmenter2 = DatasetAugmenter(augmentation_config, seed=42)
        augmenter2.augment_dataset(dataset2, multiplier=1)

        # Compare augmented images
        aug1 = np.array(Image.open(dataset1 / "images" / "train" / "doc_0_aug0.png"))
        aug2 = np.array(Image.open(dataset2 / "images" / "train" / "doc_0_aug0.png"))

        assert np.array_equal(aug1, aug2), "Same seed should produce same augmentation"

    def test_augmenter_returns_summary(
        self, sample_dataset: Path, augmentation_config: dict
    ):
        """Test that augmenter returns a summary of what was done."""
        from shared.augmentation.dataset_augmenter import DatasetAugmenter

        augmenter = DatasetAugmenter(augmentation_config)
        result = augmenter.augment_dataset(sample_dataset, multiplier=2)

        assert "original_images" in result
        assert "augmented_images" in result
        assert "total_images" in result
        assert result["original_images"] == 3
        assert result["augmented_images"] == 6
        assert result["total_images"] == 9
261
tests/web/test_augmentation_routes.py
Normal file
@@ -0,0 +1,261 @@
"""
Tests for augmentation API routes.

TDD Phase 5: RED - Write tests first, then implement to pass.
"""

import pytest
from fastapi.testclient import TestClient


class TestAugmentationTypesEndpoint:
    """Tests for GET /admin/augmentation/types endpoint."""

    def test_list_augmentation_types(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing available augmentation types."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/types",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()

        assert "augmentation_types" in data
        assert len(data["augmentation_types"]) == 12

        # Check structure
        aug_type = data["augmentation_types"][0]
        assert "name" in aug_type
        assert "description" in aug_type
        assert "affects_geometry" in aug_type
        assert "stage" in aug_type

    def test_list_augmentation_types_unauthorized(
        self, admin_client: TestClient
    ) -> None:
        """Test that unauthorized request is rejected."""
        response = admin_client.get("/api/v1/admin/augmentation/types")

        assert response.status_code == 401


class TestAugmentationPresetsEndpoint:
    """Tests for GET /admin/augmentation/presets endpoint."""

    def test_list_presets(self, admin_client: TestClient, admin_token: str) -> None:
        """Test listing available presets."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/presets",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()

        assert "presets" in data
        assert len(data["presets"]) >= 4

        # Check expected presets exist
        preset_names = [p["name"] for p in data["presets"]]
        assert "conservative" in preset_names
        assert "moderate" in preset_names
        assert "aggressive" in preset_names
        assert "scanned_document" in preset_names


class TestAugmentationPreviewEndpoint:
    """Tests for POST /admin/augmentation/preview/{document_id} endpoint."""

    def test_preview_augmentation(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
    ) -> None:
        """Test previewing augmentation on a document."""
        response = admin_client.post(
            f"/api/v1/admin/augmentation/preview/{sample_document_id}",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "gaussian_noise",
                "params": {"std": 15},
            },
        )

        assert response.status_code == 200
        data = response.json()

        assert "preview_url" in data
        assert "original_url" in data
        assert "applied_params" in data

    def test_preview_invalid_augmentation_type(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
    ) -> None:
        """Test that invalid augmentation type returns error."""
        response = admin_client.post(
            f"/api/v1/admin/augmentation/preview/{sample_document_id}",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "nonexistent",
                "params": {},
            },
        )

        assert response.status_code == 400

    def test_preview_nonexistent_document(
        self,
        admin_client: TestClient,
        admin_token: str,
    ) -> None:
        """Test that nonexistent document returns 404."""
        response = admin_client.post(
            "/api/v1/admin/augmentation/preview/00000000-0000-0000-0000-000000000000",
            headers={"X-Admin-Token": admin_token},
            json={
                "augmentation_type": "gaussian_noise",
                "params": {},
            },
        )

        assert response.status_code == 404


class TestAugmentationPreviewConfigEndpoint:
    """Tests for POST /admin/augmentation/preview-config/{document_id} endpoint."""

    def test_preview_config(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_document_id: str,
    ) -> None:
        """Test previewing full config on a document."""
        response = admin_client.post(
            f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
            headers={"X-Admin-Token": admin_token},
            json={
                "gaussian_noise": {"enabled": True, "probability": 1.0},
                "lighting_variation": {"enabled": True, "probability": 1.0},
                "preserve_bboxes": True,
                "seed": 42,
            },
        )

        assert response.status_code == 200
        data = response.json()

        assert "preview_url" in data
        assert "original_url" in data


class TestAugmentationBatchEndpoint:
    """Tests for POST /admin/augmentation/batch endpoint."""

    def test_create_augmented_dataset(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
    ) -> None:
        """Test creating augmented dataset."""
        response = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json={
                "dataset_id": sample_dataset_id,
                "config": {
                    "gaussian_noise": {"enabled": True, "probability": 0.5},
                    "preserve_bboxes": True,
                },
                "output_name": "test_augmented_dataset",
                "multiplier": 2,
            },
        )

        assert response.status_code == 200
        data = response.json()

        assert "task_id" in data
        assert "status" in data
        assert "estimated_images" in data

    def test_create_augmented_dataset_invalid_multiplier(
        self,
        admin_client: TestClient,
        admin_token: str,
        sample_dataset_id: str,
    ) -> None:
        """Test that invalid multiplier is rejected."""
        response = admin_client.post(
            "/api/v1/admin/augmentation/batch",
            headers={"X-Admin-Token": admin_token},
            json={
                "dataset_id": sample_dataset_id,
                "config": {},
                "output_name": "test",
                "multiplier": 100,  # Too high
            },
        )

        assert response.status_code == 422  # Validation error


class TestAugmentedDatasetsListEndpoint:
    """Tests for GET /admin/augmentation/datasets endpoint."""

    def test_list_augmented_datasets(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test listing augmented datasets."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
        )

        assert response.status_code == 200
        data = response.json()

        assert "total" in data
        assert "limit" in data
        assert "offset" in data
        assert "datasets" in data
        assert isinstance(data["datasets"], list)

    def test_list_augmented_datasets_pagination(
        self, admin_client: TestClient, admin_token: str
    ) -> None:
        """Test pagination parameters."""
        response = admin_client.get(
            "/api/v1/admin/augmentation/datasets",
            headers={"X-Admin-Token": admin_token},
            params={"limit": 5, "offset": 0},
        )

        assert response.status_code == 200
        data = response.json()

        assert data["limit"] == 5
        assert data["offset"] == 0


# Fixtures for tests
@pytest.fixture
def sample_document_id() -> str:
    """Provide a sample document ID for testing."""
    # This would need to be created in test setup
    return "test-document-id"


@pytest.fixture
def sample_dataset_id() -> str:
    """Provide a sample dataset ID for testing."""
    # This would need to be created in test setup
    return "test-dataset-id"
@@ -329,3 +329,414 @@ class TestDatasetBuilder:
|
|||||||
results.append([(d["document_id"], d["split"]) for d in docs])
|
results.append([(d["document_id"], d["split"]) for d in docs])
|
||||||
|
|
||||||
assert results[0] == results[1]
|
assert results[0] == results[1]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssignSplitsByGroup:
|
||||||
|
"""Tests for _assign_splits_by_group method with group_key logic."""
|
||||||
|
|
||||||
|
def _make_mock_doc(self, doc_id, group_key=None):
|
||||||
|
"""Create a mock AdminDocument with document_id and group_key."""
|
||||||
|
doc = MagicMock(spec=AdminDocument)
|
||||||
|
doc.document_id = doc_id
|
||||||
|
doc.group_key = group_key
|
||||||
|
doc.page_count = 1
|
||||||
|
return doc
|
||||||
|
|
||||||
|
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
|
||||||
|
"""Documents with unique group_key are distributed across splits."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
# 3 documents, each with unique group_key
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-C"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
# With 3 groups: 70% train = 2, 20% val = 1 (at least 1)
|
||||||
|
train_count = sum(1 for s in result.values() if s == "train")
|
||||||
|
val_count = sum(1 for s in result.values() if s == "val")
|
||||||
|
assert train_count >= 1
|
||||||
|
assert val_count >= 1 # Ensure val is not empty
|
||||||
|
|
||||||
|
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
|
||||||
|
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key=None),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=""),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=None),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
# Each null/empty group_key doc is independent, distributed across splits
|
||||||
|
# With 3 docs: ensure at least 1 in train and 1 in val
|
||||||
|
train_count = sum(1 for s in result.values() if s == "train")
|
||||||
|
val_count = sum(1 for s in result.values() if s == "val")
|
||||||
|
assert train_count >= 1
|
||||||
|
assert val_count >= 1
|
||||||
|
|
||||||
|
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
|
||||||
|
"""Documents with same group_key should be assigned to the same split."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
# 6 documents in 2 groups
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="supplier-B"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
|
||||||
|
|
||||||
|
# All docs in supplier-A should have same split
|
||||||
|
splits_a = [result[str(d.document_id)] for d in docs[:3]]
|
||||||
|
assert len(set(splits_a)) == 1, "All docs in supplier-A should be in same split"
|
||||||
|
|
||||||
|
# All docs in supplier-B should have same split
|
||||||
|
splits_b = [result[str(d.document_id)] for d in docs[3:]]
|
||||||
|
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
|
||||||
|
|
||||||
|
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
|
||||||
|
"""Multi-doc groups should be split according to train/val/test ratios."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
# 10 groups with 2 docs each
|
||||||
|
docs = []
|
||||||
|
for i in range(10):
|
||||||
|
group_key = f"group-{i}"
|
||||||
|
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
|
||||||
|
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
# Count groups per split
|
||||||
|
group_splits = {}
|
||||||
|
for doc in docs:
|
||||||
|
split = result[str(doc.document_id)]
|
||||||
|
if doc.group_key not in group_splits:
|
||||||
|
group_splits[doc.group_key] = split
|
||||||
|
else:
|
||||||
|
# Verify same group has same split
|
||||||
|
assert group_splits[doc.group_key] == split
|
||||||
|
|
||||||
|
split_counts = {"train": 0, "val": 0, "test": 0}
|
||||||
|
for split in group_splits.values():
|
||||||
|
split_counts[split] += 1
|
||||||
|
|
||||||
|
# With 10 groups, 70/20/10 -> ~7 train, ~2 val, ~1 test
|
||||||
|
assert split_counts["train"] >= 6
|
||||||
|
assert split_counts["train"] <= 8
|
||||||
|
assert split_counts["val"] >= 1
|
||||||
|
assert split_counts["val"] <= 3
|
||||||
|
|
||||||
|
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||||
|
"""Mix of single-doc and multi-doc groups should be handled correctly."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
# Single-doc groups
|
||||||
|
self._make_mock_doc(uuid4(), group_key="single-1"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="single-2"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=None),
|
||||||
|
# Multi-doc groups
|
||||||
|
self._make_mock_doc(uuid4(), group_key="multi-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="multi-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="multi-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="multi-B"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.5, seed=42)
|
||||||
|
|
||||||
|
# All groups are shuffled and distributed
|
||||||
|
# Ensure at least 1 in train and 1 in val
|
||||||
|
train_count = sum(1 for s in result.values() if s == "train")
|
||||||
|
val_count = sum(1 for s in result.values() if s == "val")
|
||||||
|
assert train_count >= 1
|
||||||
|
assert val_count >= 1
|
||||||
|
|
||||||
|
# Multi-doc groups stay together
|
||||||
|
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
|
||||||
|
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
|
||||||
|
|
||||||
|
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
|
||||||
|
"""Same seed should produce same split assignments."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-B"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-C"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-C"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
|
||||||
|
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=123)
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
|
||||||
|
"""Different seeds should potentially produce different split assignments."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
# Many groups to increase chance of different results
|
||||||
|
docs = []
|
||||||
|
for i in range(20):
|
||||||
|
group_key = f"group-{i}"
|
||||||
|
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
|
||||||
|
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
|
||||||
|
|
||||||
|
result1 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=1)
|
||||||
|
result2 = builder._assign_splits_by_group(docs, train_ratio=0.5, val_ratio=0.3, seed=999)
|
||||||
|
|
||||||
|
# Results should be different (very likely with 20 groups)
|
||||||
|
assert result1 != result2
|
||||||
|
|
||||||
|
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
|
||||||
|
"""Every document should be assigned a split."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=None),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="single"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
assert len(result) == len(docs)
|
||||||
|
for doc in docs:
|
||||||
|
assert str(doc.document_id) in result
|
||||||
|
assert result[str(doc.document_id)] in ["train", "val", "test"]
|
||||||
|
|
||||||
|
def test_empty_documents_list(self, tmp_path, mock_admin_db):
|
||||||
|
"""Empty document list should return empty result."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||||
|
"""When all groups have multiple docs, splits should follow ratios."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
# 5 groups with 3 docs each
|
||||||
|
docs = []
|
||||||
|
for i in range(5):
|
||||||
|
group_key = f"group-{i}"
|
||||||
|
for _ in range(3):
|
||||||
|
docs.append(self._make_mock_doc(uuid4(), group_key=group_key))
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
# Group splits
|
||||||
|
group_splits = {}
|
||||||
|
for doc in docs:
|
||||||
|
if doc.group_key not in group_splits:
|
||||||
|
group_splits[doc.group_key] = result[str(doc.document_id)]
|
||||||
|
|
||||||
|
split_counts = {"train": 0, "val": 0, "test": 0}
|
||||||
|
for split in group_splits.values():
|
||||||
|
split_counts[split] += 1
|
||||||
|
|
||||||
|
# With 5 groups, 60/20/20 -> 3 train, 1 val, 1 test
|
||||||
|
assert split_counts["train"] >= 2
|
||||||
|
assert split_counts["train"] <= 4
|
||||||
|
|
||||||
|
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
|
||||||
|
"""When all groups have single doc, they are distributed across splits."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
self._make_mock_doc(uuid4(), group_key="unique-1"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="unique-2"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key="unique-3"),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=None),
|
||||||
|
self._make_mock_doc(uuid4(), group_key=""),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = builder._assign_splits_by_group(docs, train_ratio=0.6, val_ratio=0.2, seed=42)
|
||||||
|
|
||||||
|
# With 5 groups: 60% train = 3, 20% val = 1 (at least 1)
|
||||||
|
train_count = sum(1 for s in result.values() if s == "train")
|
||||||
|
val_count = sum(1 for s in result.values() if s == "val")
|
||||||
|
assert train_count >= 2
|
||||||
|
assert val_count >= 1 # Ensure val is not empty
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildDatasetWithGroupKey:
|
||||||
|
"""Integration tests for build_dataset with group_key logic."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def grouped_documents(self, tmp_path):
|
||||||
|
"""Create documents with various group_key configurations."""
|
||||||
|
doc_ids = []
|
||||||
|
docs = []
|
||||||
|
|
||||||
|
# Create 3 groups: 2 multi-doc groups + 2 single-doc groups
|
||||||
|
group_configs = [
|
||||||
|
("supplier-A", 3), # Multi-doc group: 3 docs
|
||||||
|
("supplier-B", 2), # Multi-doc group: 2 docs
|
||||||
|
("unique-1", 1), # Single-doc group
|
||||||
|
(None, 1), # Null group_key
|
||||||
|
]
|
||||||
|
|
||||||
|
for group_key, count in group_configs:
|
||||||
|
for _ in range(count):
|
||||||
|
doc_id = uuid4()
|
||||||
|
doc_ids.append(doc_id)
|
||||||
|
|
||||||
|
# Create image files
|
||||||
|
doc_dir = tmp_path / "admin_images" / str(doc_id)
|
||||||
|
doc_dir.mkdir(parents=True)
|
||||||
|
for page in range(1, 3):
|
||||||
|
(doc_dir / f"page_{page}.png").write_bytes(b"fake-png")
|
||||||
|
|
||||||
|
# Create mock document
|
||||||
|
doc = MagicMock(spec=AdminDocument)
|
||||||
|
doc.document_id = doc_id
|
||||||
|
doc.filename = f"{doc_id}.pdf"
|
||||||
|
doc.page_count = 2
|
||||||
|
doc.group_key = group_key
|
||||||
|
doc.file_path = str(doc_dir)
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
return tmp_path, docs
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def grouped_annotations(self, grouped_documents):
|
||||||
|
"""Create annotations for grouped documents."""
|
||||||
|
tmp_path, docs = grouped_documents
|
||||||
|
annotations = {}
|
||||||
|
for doc in docs:
|
||||||
|
doc_anns = []
|
||||||
|
for page in range(1, 3):
|
||||||
|
ann = MagicMock(spec=AdminAnnotation)
|
||||||
|
ann.document_id = doc.document_id
|
||||||
|
ann.page_number = page
|
||||||
|
ann.class_id = 0
|
||||||
|
ann.class_name = "invoice_number"
|
||||||
|
ann.x_center = 0.5
|
||||||
|
ann.y_center = 0.3
|
||||||
|
ann.width = 0.2
|
||||||
|
ann.height = 0.05
|
||||||
|
doc_anns.append(ann)
|
||||||
|
annotations[str(doc.document_id)] = doc_anns
|
||||||
|
return annotations
|
||||||
|
|
||||||
|
def test_build_respects_group_key_splits(
|
||||||
|
self, grouped_documents, grouped_annotations, mock_admin_db
|
||||||
|
):
|
||||||
|
"""build_dataset should use group_key for split assignment."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
tmp_path, docs = grouped_documents
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||||
|
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||||
|
grouped_annotations.get(str(doc_id), [])
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = mock_admin_db.create_dataset.return_value
|
||||||
|
builder.build_dataset(
|
||||||
|
dataset_id=str(dataset.dataset_id),
|
||||||
|
document_ids=[str(d.document_id) for d in docs],
|
||||||
|
train_ratio=0.5,
|
||||||
|
val_ratio=0.5,
|
||||||
|
seed=42,
|
||||||
|
admin_images_dir=tmp_path / "admin_images",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the document splits from add_dataset_documents call
|
||||||
|
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||||
|
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||||
|
|
||||||
|
# Build mapping of doc_id -> split
|
||||||
|
doc_split_map = {d["document_id"]: d["split"] for d in docs_added}
|
||||||
|
|
||||||
|
# Verify all docs are assigned a valid split
|
||||||
|
for doc_id in doc_split_map:
|
||||||
|
assert doc_split_map[doc_id] in ("train", "val", "test")
|
||||||
|
|
||||||
|
# Verify multi-doc groups stay together
|
||||||
|
supplier_a_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-A"]
|
||||||
|
supplier_a_splits = [doc_split_map[doc_id] for doc_id in supplier_a_ids]
|
||||||
|
assert len(set(supplier_a_splits)) == 1, "supplier-A docs should be in same split"
|
||||||
|
|
||||||
|
supplier_b_ids = [str(d.document_id) for d in docs if d.group_key == "supplier-B"]
|
||||||
|
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
|
||||||
|
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
|
||||||
|
|
||||||
|
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
|
||||||
|
"""All docs with same group_key should go to same split."""
|
||||||
|
from inference.web.services.dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
# Create 5 docs all with same group_key
|
||||||
|
docs = []
|
||||||
|
for i in range(5):
|
||||||
|
doc_id = uuid4()
|
||||||
|
doc_dir = tmp_path / "admin_images" / str(doc_id)
|
||||||
|
doc_dir.mkdir(parents=True)
|
||||||
|
(doc_dir / "page_1.png").write_bytes(b"fake-png")
|
||||||
|
|
||||||
|
doc = MagicMock(spec=AdminDocument)
|
||||||
|
doc.document_id = doc_id
|
||||||
|
doc.filename = f"{doc_id}.pdf"
|
||||||
|
doc.page_count = 1
|
||||||
|
doc.group_key = "same-group"
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||||
|
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||||
|
mock_admin_db.get_annotations_for_document.return_value = []
|
||||||
|
|
||||||
|
dataset = mock_admin_db.create_dataset.return_value
|
||||||
|
builder.build_dataset(
|
||||||
|
dataset_id=str(dataset.dataset_id),
|
||||||
|
document_ids=[str(d.document_id) for d in docs],
|
||||||
|
train_ratio=0.6,
|
||||||
|
val_ratio=0.2,
|
||||||
|
seed=42,
|
||||||
|
admin_images_dir=tmp_path / "admin_images",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||||
|
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||||
|
|
||||||
|
splits = [d["split"] for d in docs_added]
|
||||||
|
# All should be in the same split (one group)
|
||||||
|
assert len(set(splits)) == 1, "All docs with same group_key should be in same split"
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012"
 TEST_TOKEN = "test-admin-token-12345"
 TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
 
+# Generate 10 unique UUIDs for minimum document count tests
+TEST_DOC_UUIDS = [f"990e8400-e29b-41d4-a716-4466554400{i:02d}" for i in range(10, 20)]
+
 
 def _make_dataset(**overrides) -> MagicMock:
     defaults = dict(
@@ -83,14 +86,14 @@ class TestCreateDatasetRoute:
 
         mock_builder = MagicMock()
         mock_builder.build_dataset.return_value = {
-            "total_documents": 2,
-            "total_images": 4,
-            "total_annotations": 10,
+            "total_documents": 10,
+            "total_images": 20,
+            "total_annotations": 50,
         }
 
         request = DatasetCreateRequest(
             name="test-dataset",
-            document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
+            document_ids=TEST_DOC_UUIDS,  # Use 10 documents to meet minimum
         )
 
         with patch(
@@ -104,6 +107,73 @@ class TestCreateDatasetRoute:
|
|||||||
assert result.dataset_id == TEST_DATASET_UUID
|
assert result.dataset_id == TEST_DATASET_UUID
|
||||||
assert result.name == "test-dataset"
|
assert result.name == "test-dataset"
|
||||||
|
|
||||||
|
def test_create_dataset_fails_with_less_than_10_documents(self):
|
||||||
|
"""Test that creating dataset fails if fewer than 10 documents provided."""
|
||||||
|
fn = _find_endpoint("create_dataset")
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
# Only 2 documents - should fail
|
||||||
|
request = DatasetCreateRequest(
|
||||||
|
name="test-dataset",
|
||||||
|
document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2],
|
||||||
|
)
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||||
|
assert "got 2" in exc_info.value.detail
|
||||||
|
# Ensure DB was never called since validation failed first
|
||||||
|
mock_db.create_dataset.assert_not_called()
|
||||||
|
|
||||||
|
def test_create_dataset_fails_with_9_documents(self):
|
||||||
|
"""Test boundary condition: 9 documents should fail."""
|
||||||
|
fn = _find_endpoint("create_dataset")
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
# 9 documents - just under the limit
|
||||||
|
request = DatasetCreateRequest(
|
||||||
|
name="test-dataset",
|
||||||
|
document_ids=TEST_DOC_UUIDS[:9],
|
||||||
|
)
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||||
|
|
||||||
|
def test_create_dataset_succeeds_with_exactly_10_documents(self):
|
||||||
|
"""Test boundary condition: exactly 10 documents should succeed."""
|
||||||
|
fn = _find_endpoint("create_dataset")
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||||
|
|
||||||
|
mock_builder = MagicMock()
|
||||||
|
|
||||||
|
# Exactly 10 documents - should pass
|
||||||
|
request = DatasetCreateRequest(
|
||||||
|
name="test-dataset",
|
||||||
|
document_ids=TEST_DOC_UUIDS[:10],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||||
|
return_value=mock_builder,
|
||||||
|
):
|
||||||
|
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||||
|
|
||||||
|
mock_db.create_dataset.assert_called_once()
|
||||||
|
assert result.dataset_id == TEST_DATASET_UUID
|
||||||
|
|
||||||
|
|
||||||
class TestListDatasetsRoute:
    """Tests for GET /admin/training/datasets."""
@@ -198,3 +268,53 @@ class TestTrainFromDatasetRoute:
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400

    def test_incremental_training_with_base_model(self):
        """Test training with base_model_version_id for incremental training."""
        fn = _find_endpoint("train_from_dataset")

        mock_model_version = MagicMock()
        mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
        mock_model_version.version = "1.0.0"

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = mock_model_version
        mock_db.create_training_task.return_value = TEST_TASK_UUID

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        # Verify model version was looked up
        mock_db.get_model_version.assert_called_once_with(base_model_uuid)

        # Verify task was created with finetune type
        call_kwargs = mock_db.create_training_task.call_args[1]
        assert call_kwargs["task_type"] == "finetune"
        assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
        assert call_kwargs["config"]["base_model_version"] == "1.0.0"

        assert result.task_id == TEST_TASK_UUID
        assert "Incremental training" in result.message

    def test_incremental_training_with_invalid_base_model_fails(self):
        """Test that training fails if base_model_version_id doesn't exist."""
        fn = _find_endpoint("train_from_dataset")

        mock_db = MagicMock()
        mock_db.get_dataset.return_value = _make_dataset(status="ready")
        mock_db.get_model_version.return_value = None

        base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
        config = TrainingConfig(base_model_version_id=base_model_uuid)
        request = DatasetTrainRequest(name="incremental-train", config=config)

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        assert exc_info.value.status_code == 404
        assert "Base model version not found" in exc_info.value.detail
399
tests/web/test_model_versions.py
Normal file
@@ -0,0 +1,399 @@
"""
Tests for Model Version API routes.
"""

import asyncio
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID

import pytest

from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
    ModelVersionCreateRequest,
    ModelVersionUpdateRequest,
)


TEST_VERSION_UUID = "880e8400-e29b-41d4-a716-446655440020"
TEST_VERSION_UUID_2 = "880e8400-e29b-41d4-a716-446655440021"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010"
TEST_TOKEN = "test-admin-token-12345"


def _make_model_version(**overrides) -> MagicMock:
    """Create a mock ModelVersion."""
    defaults = dict(
        version_id=UUID(TEST_VERSION_UUID),
        version="1.0.0",
        name="test-model-v1",
        description="Test model version",
        model_path="/models/test-model-v1.pt",
        status="inactive",
        is_active=False,
        task_id=UUID(TEST_TASK_UUID),
        dataset_id=UUID(TEST_DATASET_UUID),
        metrics_mAP=0.935,
        metrics_precision=0.92,
        metrics_recall=0.88,
        document_count=100,
        training_config={"epochs": 100, "batch_size": 16},
        file_size=52428800,
        trained_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
        activated_at=None,
        created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
        updated_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
    )
    defaults.update(overrides)
    model = MagicMock(spec=ModelVersion)
    for k, v in defaults.items():
        setattr(model, k, v)
    return model


def _find_endpoint(name: str):
    """Find endpoint function by name."""
    router = create_training_router()
    for route in router.routes:
        if hasattr(route, "endpoint") and route.endpoint.__name__ == name:
            return route.endpoint
    raise AssertionError(f"Endpoint {name} not found")
class TestModelVersionRouterRegistration:
    """Tests that model version endpoints are registered."""

    def test_router_has_model_endpoints(self):
        router = create_training_router()
        paths = [route.path for route in router.routes]
        assert any("models" in p for p in paths)

    def test_has_create_model_version_endpoint(self):
        endpoint = _find_endpoint("create_model_version")
        assert endpoint is not None

    def test_has_list_model_versions_endpoint(self):
        endpoint = _find_endpoint("list_model_versions")
        assert endpoint is not None

    def test_has_get_active_model_endpoint(self):
        endpoint = _find_endpoint("get_active_model")
        assert endpoint is not None

    def test_has_activate_model_version_endpoint(self):
        endpoint = _find_endpoint("activate_model_version")
        assert endpoint is not None


class TestCreateModelVersionRoute:
    """Tests for POST /admin/training/models."""

    def test_create_model_version(self):
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            description="Test model",
            metrics_mAP=0.935,
            document_count=100,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.create_model_version.assert_called_once()
        assert result.version_id == TEST_VERSION_UUID
        assert result.status == "inactive"
        assert result.message == "Model version created successfully"

    def test_create_model_version_with_task_and_dataset(self):
        fn = _find_endpoint("create_model_version")

        mock_db = MagicMock()
        mock_db.create_model_version.return_value = _make_model_version()

        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model-v1",
            model_path="/models/test-model-v1.pt",
            task_id=TEST_TASK_UUID,
            dataset_id=TEST_DATASET_UUID,
        )

        result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))

        call_kwargs = mock_db.create_model_version.call_args[1]
        assert call_kwargs["task_id"] == TEST_TASK_UUID
        assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
class TestListModelVersionsRoute:
    """Tests for GET /admin/training/models."""

    def test_list_model_versions(self):
        fn = _find_endpoint("list_model_versions")

        mock_db = MagicMock()
        mock_db.get_model_versions.return_value = (
            [_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
            2,
        )

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

        assert result.total == 2
        assert len(result.models) == 2
        assert result.models[0].version == "1.0.0"

    def test_list_model_versions_with_status_filter(self):
        fn = _find_endpoint("list_model_versions")

        mock_db = MagicMock()
        mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))

        mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
        assert result.total == 1
        assert result.models[0].status == "active"


class TestGetActiveModelRoute:
    """Tests for GET /admin/training/models/active."""

    def test_get_active_model_when_exists(self):
        fn = _find_endpoint("get_active_model")

        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is True
        assert result.model is not None
        assert result.model.is_active is True

    def test_get_active_model_when_none(self):
        fn = _find_endpoint("get_active_model")

        mock_db = MagicMock()
        mock_db.get_active_model_version.return_value = None

        result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))

        assert result.has_active_model is False
        assert result.model is None
class TestGetModelVersionRoute:
    """Tests for GET /admin/training/models/{version_id}."""

    def test_get_model_version(self):
        fn = _find_endpoint("get_model_version")

        mock_db = MagicMock()
        mock_db.get_model_version.return_value = _make_model_version()

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.version_id == TEST_VERSION_UUID
        assert result.version == "1.0.0"
        assert result.name == "test-model-v1"
        assert result.metrics_mAP == 0.935

    def test_get_model_version_not_found(self):
        fn = _find_endpoint("get_model_version")

        mock_db = MagicMock()
        mock_db.get_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestUpdateModelVersionRoute:
    """Tests for PATCH /admin/training/models/{version_id}."""

    def test_update_model_version(self):
        fn = _find_endpoint("update_model_version")

        mock_db = MagicMock()
        mock_db.update_model_version.return_value = _make_model_version(name="updated-name")

        request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.update_model_version.assert_called_once_with(
            version_id=TEST_VERSION_UUID,
            name="updated-name",
            description="Updated description",
            status=None,
        )
        assert result.message == "Model version updated successfully"

    def test_update_model_version_not_found(self):
        fn = _find_endpoint("update_model_version")

        mock_db = MagicMock()
        mock_db.update_model_version.return_value = None

        request = ModelVersionUpdateRequest(name="updated-name")

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/activate."""

    def test_activate_model_version(self):
        fn = _find_endpoint("activate_model_version")

        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result.status == "active"
        assert result.message == "Model version activated for inference"

    def test_activate_model_version_not_found(self):
        fn = _find_endpoint("activate_model_version")

        mock_db = MagicMock()
        mock_db.activate_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404


class TestDeactivateModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/deactivate."""

    def test_deactivate_model_version(self):
        fn = _find_endpoint("deactivate_model_version")

        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "inactive"
        assert result.message == "Model version deactivated"

    def test_deactivate_model_version_not_found(self):
        fn = _find_endpoint("deactivate_model_version")

        mock_db = MagicMock()
        mock_db.deactivate_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
    """Tests for POST /admin/training/models/{version_id}/archive."""

    def test_archive_model_version(self):
        fn = _find_endpoint("archive_model_version")

        mock_db = MagicMock()
        mock_db.archive_model_version.return_value = _make_model_version(status="archived")

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        assert result.status == "archived"
        assert result.message == "Model version archived"

    def test_archive_active_model_fails(self):
        fn = _find_endpoint("archive_model_version")

        mock_db = MagicMock()
        mock_db.archive_model_version.return_value = None

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400


class TestDeleteModelVersionRoute:
    """Tests for DELETE /admin/training/models/{version_id}."""

    def test_delete_model_version(self):
        fn = _find_endpoint("delete_model_version")

        mock_db = MagicMock()
        mock_db.delete_model_version.return_value = True

        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))

        mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
        assert result["message"] == "Model version deleted"

    def test_delete_active_model_fails(self):
        fn = _find_endpoint("delete_model_version")

        mock_db = MagicMock()
        mock_db.delete_model_version.return_value = False

        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
        assert exc_info.value.status_code == 400
class TestModelVersionSchemas:
    """Tests for model version Pydantic schemas."""

    def test_create_request_validation(self):
        request = ModelVersionCreateRequest(
            version="1.0.0",
            name="test-model",
            model_path="/models/test.pt",
        )
        assert request.version == "1.0.0"
        assert request.name == "test-model"
        assert request.document_count == 0

    def test_create_request_with_metrics(self):
        request = ModelVersionCreateRequest(
            version="2.0.0",
            name="test-model-v2",
            model_path="/models/v2.pt",
            metrics_mAP=0.95,
            metrics_precision=0.92,
            metrics_recall=0.88,
            document_count=500,
        )
        assert request.metrics_mAP == 0.95
        assert request.document_count == 500

    def test_update_request_partial(self):
        request = ModelVersionUpdateRequest(name="new-name")
        assert request.name == "new-name"
        assert request.description is None
        assert request.status is None