From 33ada0350da7200161adfed92a65e8dfbefc038f Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Fri, 30 Jan 2026 00:44:21 +0100 Subject: [PATCH] WIP --- frontend/src/App.tsx | 10 +- .../src/api/endpoints/augmentation.test.ts | 118 ++++ frontend/src/api/endpoints/augmentation.ts | 144 +++++ frontend/src/api/endpoints/datasets.ts | 52 ++ frontend/src/api/endpoints/documents.ts | 20 +- frontend/src/api/endpoints/index.ts | 3 + frontend/src/api/endpoints/models.ts | 55 ++ frontend/src/api/types.ts | 168 ++++++ .../components/AugmentationConfig.test.tsx | 251 ++++++++ .../src/components/AugmentationConfig.tsx | 136 +++++ frontend/src/components/Dashboard.tsx | 10 +- frontend/src/components/DatasetDetail.tsx | 122 ++++ frontend/src/components/DocumentDetail.tsx | 67 ++- frontend/src/components/Models.tsx | 202 +++++-- frontend/src/components/Training.tsx | 561 +++++++++++++++--- frontend/src/components/UploadModal.tsx | 34 +- frontend/src/hooks/index.ts | 3 + frontend/src/hooks/useAugmentation.test.tsx | 226 +++++++ frontend/src/hooks/useAugmentation.ts | 121 ++++ frontend/src/hooks/useDatasets.ts | 84 +++ frontend/src/hooks/useDocuments.ts | 14 +- frontend/src/hooks/useModels.ts | 98 +++ migrations/005_add_group_key.sql | 8 + migrations/006_model_versions.sql | 49 ++ .../007_training_tasks_extra_columns.sql | 46 ++ migrations/008_fix_model_versions_fk.sql | 14 + packages/inference/inference/data/admin_db.py | 226 +++++++ .../inference/inference/data/admin_models.py | 52 ++ packages/inference/inference/data/database.py | 108 ++++ .../inference/web/api/v1/admin/__init__.py | 2 + .../web/api/v1/admin/augmentation/__init__.py | 15 + .../web/api/v1/admin/augmentation/routes.py | 162 +++++ .../inference/web/api/v1/admin/documents.py | 61 ++ .../web/api/v1/admin/training/__init__.py | 2 + .../web/api/v1/admin/training/datasets.py | 55 +- .../web/api/v1/admin/training/documents.py | 8 +- .../web/api/v1/admin/training/models.py | 333 +++++++++++ 
packages/inference/inference/web/app.py | 25 +- .../inference/inference/web/core/scheduler.py | 234 ++++++-- .../inference/web/schemas/admin/__init__.py | 1 + .../web/schemas/admin/augmentation.py | 187 ++++++ .../inference/web/schemas/admin/datasets.py | 2 + .../inference/web/schemas/admin/documents.py | 3 + .../inference/web/schemas/admin/models.py | 95 +++ .../inference/web/schemas/admin/training.py | 19 +- .../web/services/augmentation_service.py | 317 ++++++++++ .../inference/web/services/dataset_builder.py | 99 +++- .../inference/web/services/inference.py | 69 ++- .../shared/shared/augmentation/__init__.py | 24 + packages/shared/shared/augmentation/base.py | 108 ++++ packages/shared/shared/augmentation/config.py | 274 +++++++++ .../shared/augmentation/dataset_augmenter.py | 206 +++++++ .../shared/shared/augmentation/pipeline.py | 184 ++++++ .../shared/shared/augmentation/presets.py | 212 +++++++ .../augmentation/transforms/__init__.py | 13 + .../shared/augmentation/transforms/blur.py | 144 +++++ .../augmentation/transforms/degradation.py | 259 ++++++++ .../augmentation/transforms/geometric.py | 145 +++++ .../augmentation/transforms/lighting.py | 167 ++++++ .../shared/augmentation/transforms/noise.py | 142 +++++ .../shared/augmentation/transforms/texture.py | 159 +++++ packages/shared/shared/training/__init__.py | 5 + .../shared/shared/training/yolo_trainer.py | 239 ++++++++ packages/training/training/cli/train.py | 84 ++- runs_backup/train/invoice_fields/args.yaml | 106 ++++ runs_backup/train/invoice_fields/results.csv | 101 ++++ .../train/invoice_yolo11n_full/args.yaml | 106 ++++ .../train/invoice_yolo11n_full/results.csv | 101 ++++ tests/shared/augmentation/__init__.py | 1 + tests/shared/augmentation/test_base.py | 347 +++++++++++ tests/shared/augmentation/test_config.py | 283 +++++++++ tests/shared/augmentation/test_pipeline.py | 338 +++++++++++ tests/shared/augmentation/test_presets.py | 102 ++++ .../augmentation/transforms/__init__.py | 1 + 
tests/shared/test_dataset_augmenter.py | 293 +++++++++ tests/web/test_augmentation_routes.py | 261 ++++++++ tests/web/test_dataset_builder.py | 411 +++++++++++++ tests/web/test_dataset_routes.py | 128 +++- tests/web/test_model_versions.py | 399 +++++++++++++ 79 files changed, 9737 insertions(+), 297 deletions(-) create mode 100644 frontend/src/api/endpoints/augmentation.test.ts create mode 100644 frontend/src/api/endpoints/augmentation.ts create mode 100644 frontend/src/api/endpoints/datasets.ts create mode 100644 frontend/src/api/endpoints/models.ts create mode 100644 frontend/src/components/AugmentationConfig.test.tsx create mode 100644 frontend/src/components/AugmentationConfig.tsx create mode 100644 frontend/src/components/DatasetDetail.tsx create mode 100644 frontend/src/hooks/useAugmentation.test.tsx create mode 100644 frontend/src/hooks/useAugmentation.ts create mode 100644 frontend/src/hooks/useDatasets.ts create mode 100644 frontend/src/hooks/useModels.ts create mode 100644 migrations/005_add_group_key.sql create mode 100644 migrations/006_model_versions.sql create mode 100644 migrations/007_training_tasks_extra_columns.sql create mode 100644 migrations/008_fix_model_versions_fk.sql create mode 100644 packages/inference/inference/web/api/v1/admin/augmentation/__init__.py create mode 100644 packages/inference/inference/web/api/v1/admin/augmentation/routes.py create mode 100644 packages/inference/inference/web/api/v1/admin/training/models.py create mode 100644 packages/inference/inference/web/schemas/admin/augmentation.py create mode 100644 packages/inference/inference/web/schemas/admin/models.py create mode 100644 packages/inference/inference/web/services/augmentation_service.py create mode 100644 packages/shared/shared/augmentation/__init__.py create mode 100644 packages/shared/shared/augmentation/base.py create mode 100644 packages/shared/shared/augmentation/config.py create mode 100644 packages/shared/shared/augmentation/dataset_augmenter.py create mode 
100644 packages/shared/shared/augmentation/pipeline.py create mode 100644 packages/shared/shared/augmentation/presets.py create mode 100644 packages/shared/shared/augmentation/transforms/__init__.py create mode 100644 packages/shared/shared/augmentation/transforms/blur.py create mode 100644 packages/shared/shared/augmentation/transforms/degradation.py create mode 100644 packages/shared/shared/augmentation/transforms/geometric.py create mode 100644 packages/shared/shared/augmentation/transforms/lighting.py create mode 100644 packages/shared/shared/augmentation/transforms/noise.py create mode 100644 packages/shared/shared/augmentation/transforms/texture.py create mode 100644 packages/shared/shared/training/__init__.py create mode 100644 packages/shared/shared/training/yolo_trainer.py create mode 100644 runs_backup/train/invoice_fields/args.yaml create mode 100644 runs_backup/train/invoice_fields/results.csv create mode 100644 runs_backup/train/invoice_yolo11n_full/args.yaml create mode 100644 runs_backup/train/invoice_yolo11n_full/results.csv create mode 100644 tests/shared/augmentation/__init__.py create mode 100644 tests/shared/augmentation/test_base.py create mode 100644 tests/shared/augmentation/test_config.py create mode 100644 tests/shared/augmentation/test_pipeline.py create mode 100644 tests/shared/augmentation/test_presets.py create mode 100644 tests/shared/augmentation/transforms/__init__.py create mode 100644 tests/shared/test_dataset_augmenter.py create mode 100644 tests/web/test_augmentation_routes.py create mode 100644 tests/web/test_model_versions.py diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 93a0a08..ba97090 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -4,6 +4,7 @@ import { DashboardOverview } from './components/DashboardOverview' import { Dashboard } from './components/Dashboard' import { DocumentDetail } from './components/DocumentDetail' import { Training } from './components/Training' +import { DatasetDetail 
} from './components/DatasetDetail' import { Models } from './components/Models' import { Login } from './components/Login' import { InferenceDemo } from './components/InferenceDemo' @@ -55,7 +56,14 @@ const App: React.FC = () => { case 'demo': return case 'training': - return + return + case 'dataset-detail': + return ( + setCurrentView('training')} + /> + ) case 'models': return default: diff --git a/frontend/src/api/endpoints/augmentation.test.ts b/frontend/src/api/endpoints/augmentation.test.ts new file mode 100644 index 0000000..9a0b7c2 --- /dev/null +++ b/frontend/src/api/endpoints/augmentation.test.ts @@ -0,0 +1,118 @@ +/** + * Tests for augmentation API endpoints. + * + * TDD Phase 1: RED - Write tests first, then implement to pass. + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { augmentationApi } from './augmentation' +import apiClient from '../client' + +// Mock the API client +vi.mock('../client', () => ({ + default: { + get: vi.fn(), + post: vi.fn(), + }, +})) + +describe('augmentationApi', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('getTypes', () => { + it('should fetch augmentation types', async () => { + const mockResponse = { + data: { + augmentation_types: [ + { + name: 'gaussian_noise', + description: 'Adds Gaussian noise', + affects_geometry: false, + stage: 'noise', + default_params: { mean: 0, std: 15 }, + }, + ], + }, + } + vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse) + + const result = await augmentationApi.getTypes() + + expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types') + expect(result.augmentation_types).toHaveLength(1) + expect(result.augmentation_types[0].name).toBe('gaussian_noise') + }) + }) + + describe('getPresets', () => { + it('should fetch augmentation presets', async () => { + const mockResponse = { + data: { + presets: [ + { name: 'conservative', description: 'Safe augmentations' }, + { name: 'moderate', description: 'Balanced 
augmentations' }, + ], + }, + } + vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse) + + const result = await augmentationApi.getPresets() + + expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets') + expect(result.presets).toHaveLength(2) + }) + }) + + describe('preview', () => { + it('should preview single augmentation', async () => { + const mockResponse = { + data: { + preview_url: '', + original_url: '', + applied_params: { std: 15 }, + }, + } + vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse) + + const result = await augmentationApi.preview('doc-123', { + augmentation_type: 'gaussian_noise', + params: { std: 15 }, + }) + + expect(apiClient.post).toHaveBeenCalledWith( + '/api/v1/admin/augmentation/preview/doc-123', + { + augmentation_type: 'gaussian_noise', + params: { std: 15 }, + }, + { params: { page: 1 } } + ) + expect(result.preview_url).toBe('') + }) + + it('should support custom page number', async () => { + const mockResponse = { + data: { + preview_url: '', + original_url: '', + applied_params: {}, + }, + } + vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse) + + await augmentationApi.preview( + 'doc-123', + { augmentation_type: 'gaussian_noise', params: {} }, + 2 + ) + + expect(apiClient.post).toHaveBeenCalledWith( + '/api/v1/admin/augmentation/preview/doc-123', + expect.anything(), + { params: { page: 2 } } + ) + }) + }) +}) diff --git a/frontend/src/api/endpoints/augmentation.ts b/frontend/src/api/endpoints/augmentation.ts new file mode 100644 index 0000000..52cac76 --- /dev/null +++ b/frontend/src/api/endpoints/augmentation.ts @@ -0,0 +1,144 @@ +/** + * Augmentation API endpoints. + * + * Provides functions for fetching augmentation types, presets, and previewing augmentations. 
+ */ + +import apiClient from '../client' + +// Types +export interface AugmentationTypeInfo { + name: string + description: string + affects_geometry: boolean + stage: string + default_params: Record +} + +export interface AugmentationTypesResponse { + augmentation_types: AugmentationTypeInfo[] +} + +export interface PresetInfo { + name: string + description: string + config?: Record +} + +export interface PresetsResponse { + presets: PresetInfo[] +} + +export interface PreviewRequest { + augmentation_type: string + params: Record +} + +export interface PreviewResponse { + preview_url: string + original_url: string + applied_params: Record +} + +export interface AugmentationParams { + enabled: boolean + probability: number + params: Record +} + +export interface AugmentationConfig { + perspective_warp?: AugmentationParams + wrinkle?: AugmentationParams + edge_damage?: AugmentationParams + stain?: AugmentationParams + lighting_variation?: AugmentationParams + shadow?: AugmentationParams + gaussian_blur?: AugmentationParams + motion_blur?: AugmentationParams + gaussian_noise?: AugmentationParams + salt_pepper?: AugmentationParams + paper_texture?: AugmentationParams + scanner_artifacts?: AugmentationParams + preserve_bboxes?: boolean + seed?: number | null +} + +export interface BatchRequest { + dataset_id: string + config: AugmentationConfig + output_name: string + multiplier: number +} + +export interface BatchResponse { + task_id: string + status: string + message: string + estimated_images: number +} + +// API functions +export const augmentationApi = { + /** + * Fetch available augmentation types. + */ + async getTypes(): Promise { + const response = await apiClient.get( + '/api/v1/admin/augmentation/types' + ) + return response.data + }, + + /** + * Fetch augmentation presets. 
+ */ + async getPresets(): Promise { + const response = await apiClient.get( + '/api/v1/admin/augmentation/presets' + ) + return response.data + }, + + /** + * Preview a single augmentation on a document page. + */ + async preview( + documentId: string, + request: PreviewRequest, + page: number = 1 + ): Promise { + const response = await apiClient.post( + `/api/v1/admin/augmentation/preview/${documentId}`, + request, + { params: { page } } + ) + return response.data + }, + + /** + * Preview full augmentation config on a document page. + */ + async previewConfig( + documentId: string, + config: AugmentationConfig, + page: number = 1 + ): Promise { + const response = await apiClient.post( + `/api/v1/admin/augmentation/preview-config/${documentId}`, + config, + { params: { page } } + ) + return response.data + }, + + /** + * Create an augmented dataset. + */ + async createBatch(request: BatchRequest): Promise { + const response = await apiClient.post( + '/api/v1/admin/augmentation/batch', + request + ) + return response.data + }, +} diff --git a/frontend/src/api/endpoints/datasets.ts b/frontend/src/api/endpoints/datasets.ts new file mode 100644 index 0000000..42d1d67 --- /dev/null +++ b/frontend/src/api/endpoints/datasets.ts @@ -0,0 +1,52 @@ +import apiClient from '../client' +import type { + DatasetCreateRequest, + DatasetDetailResponse, + DatasetListResponse, + DatasetResponse, + DatasetTrainRequest, + TrainingTaskResponse, +} from '../types' + +export const datasetsApi = { + list: async (params?: { + status?: string + limit?: number + offset?: number + }): Promise => { + const { data } = await apiClient.get('/api/v1/admin/training/datasets', { + params, + }) + return data + }, + + create: async (req: DatasetCreateRequest): Promise => { + const { data } = await apiClient.post('/api/v1/admin/training/datasets', req) + return data + }, + + getDetail: async (datasetId: string): Promise => { + const { data } = await apiClient.get( + 
`/api/v1/admin/training/datasets/${datasetId}` + ) + return data + }, + + remove: async (datasetId: string): Promise<{ message: string }> => { + const { data } = await apiClient.delete( + `/api/v1/admin/training/datasets/${datasetId}` + ) + return data + }, + + trainFromDataset: async ( + datasetId: string, + req: DatasetTrainRequest + ): Promise => { + const { data } = await apiClient.post( + `/api/v1/admin/training/datasets/${datasetId}/train`, + req + ) + return data + }, +} diff --git a/frontend/src/api/endpoints/documents.ts b/frontend/src/api/endpoints/documents.ts index 56e5627..75367ed 100644 --- a/frontend/src/api/endpoints/documents.ts +++ b/frontend/src/api/endpoints/documents.ts @@ -21,14 +21,20 @@ export const documentsApi = { return data }, - upload: async (file: File): Promise => { + upload: async (file: File, groupKey?: string): Promise => { const formData = new FormData() formData.append('file', file) + const params: Record = {} + if (groupKey) { + params.group_key = groupKey + } + const { data } = await apiClient.post('/api/v1/admin/documents', formData, { headers: { 'Content-Type': 'multipart/form-data', }, + params, }) return data }, @@ -77,4 +83,16 @@ export const documentsApi = { ) return data }, + + updateGroupKey: async ( + documentId: string, + groupKey: string | null + ): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => { + const { data } = await apiClient.patch( + `/api/v1/admin/documents/${documentId}/group-key`, + null, + { params: { group_key: groupKey } } + ) + return data + }, } diff --git a/frontend/src/api/endpoints/index.ts b/frontend/src/api/endpoints/index.ts index f24f2f5..554ac30 100644 --- a/frontend/src/api/endpoints/index.ts +++ b/frontend/src/api/endpoints/index.ts @@ -2,3 +2,6 @@ export { documentsApi } from './documents' export { annotationsApi } from './annotations' export { trainingApi } from './training' export { inferenceApi } from './inference' +export { datasetsApi } 
from './datasets' +export { augmentationApi } from './augmentation' +export { modelsApi } from './models' diff --git a/frontend/src/api/endpoints/models.ts b/frontend/src/api/endpoints/models.ts new file mode 100644 index 0000000..46b71c0 --- /dev/null +++ b/frontend/src/api/endpoints/models.ts @@ -0,0 +1,55 @@ +import apiClient from '../client' +import type { + ModelVersionListResponse, + ModelVersionDetailResponse, + ModelVersionResponse, + ActiveModelResponse, +} from '../types' + +export const modelsApi = { + list: async (params?: { + status?: string + limit?: number + offset?: number + }): Promise => { + const { data } = await apiClient.get('/api/v1/admin/training/models', { + params, + }) + return data + }, + + getDetail: async (versionId: string): Promise => { + const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`) + return data + }, + + getActive: async (): Promise => { + const { data } = await apiClient.get('/api/v1/admin/training/models/active') + return data + }, + + activate: async (versionId: string): Promise => { + const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`) + return data + }, + + deactivate: async (versionId: string): Promise => { + const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`) + return data + }, + + archive: async (versionId: string): Promise => { + const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`) + return data + }, + + delete: async (versionId: string): Promise<{ message: string }> => { + const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`) + return data + }, + + reload: async (): Promise<{ message: string; reloaded: boolean }> => { + const { data } = await apiClient.post('/api/v1/admin/training/models/reload') + return data + }, +} diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts index c668a59..73908ca 100644 --- 
a/frontend/src/api/types.ts +++ b/frontend/src/api/types.ts @@ -8,6 +8,7 @@ export interface DocumentItem { auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null auto_label_error: string | null upload_source: string + group_key: string | null created_at: string updated_at: string annotation_count?: number @@ -59,6 +60,7 @@ export interface DocumentDetailResponse { auto_label_error: string | null upload_source: string batch_id: string | null + group_key: string | null csv_field_values: Record | null can_annotate: boolean annotation_lock_until: string | null @@ -113,7 +115,11 @@ export interface ErrorResponse { export interface UploadDocumentResponse { document_id: string filename: string + file_size: number + page_count: number status: string + group_key: string | null + auto_label_started: boolean message: string } @@ -171,3 +177,165 @@ export interface InferenceResult { export interface InferenceResponse { result: InferenceResult } + +// Dataset types + +export interface DatasetCreateRequest { + name: string + description?: string + document_ids: string[] + train_ratio?: number + val_ratio?: number + seed?: number +} + +export interface DatasetResponse { + dataset_id: string + name: string + status: string + message: string +} + +export interface DatasetDocumentItem { + document_id: string + split: string + page_count: number + annotation_count: number +} + +export interface DatasetListItem { + dataset_id: string + name: string + description: string | null + status: string + training_status: string | null + active_training_task_id: string | null + total_documents: number + total_images: number + total_annotations: number + created_at: string +} + +export interface DatasetListResponse { + total: number + limit: number + offset: number + datasets: DatasetListItem[] +} + +export interface DatasetDetailResponse { + dataset_id: string + name: string + description: string | null + status: string + train_ratio: number + val_ratio: number + seed: number 
+ total_documents: number + total_images: number + total_annotations: number + dataset_path: string | null + error_message: string | null + documents: DatasetDocumentItem[] + created_at: string + updated_at: string +} + +export interface AugmentationParams { + enabled: boolean + probability: number + params: Record +} + +export interface AugmentationTrainingConfig { + gaussian_noise?: AugmentationParams + perspective_warp?: AugmentationParams + wrinkle?: AugmentationParams + edge_damage?: AugmentationParams + stain?: AugmentationParams + lighting_variation?: AugmentationParams + shadow?: AugmentationParams + gaussian_blur?: AugmentationParams + motion_blur?: AugmentationParams + salt_pepper?: AugmentationParams + paper_texture?: AugmentationParams + scanner_artifacts?: AugmentationParams + preserve_bboxes?: boolean + seed?: number | null +} + +export interface DatasetTrainRequest { + name: string + config: { + model_name?: string + base_model_version_id?: string | null + epochs?: number + batch_size?: number + image_size?: number + learning_rate?: number + device?: string + augmentation?: AugmentationTrainingConfig + augmentation_multiplier?: number + } +} + +export interface TrainingTaskResponse { + task_id: string + status: string + message: string +} + +// Model Version types + +export interface ModelVersionItem { + version_id: string + version: string + name: string + status: string + is_active: boolean + metrics_mAP: number | null + document_count: number + trained_at: string | null + activated_at: string | null + created_at: string +} + +export interface ModelVersionDetailResponse { + version_id: string + version: string + name: string + description: string | null + model_path: string + status: string + is_active: boolean + task_id: string | null + dataset_id: string | null + metrics_mAP: number | null + metrics_precision: number | null + metrics_recall: number | null + document_count: number + training_config: Record | null + file_size: number | null + 
trained_at: string | null + activated_at: string | null + created_at: string + updated_at: string +} + +export interface ModelVersionListResponse { + total: number + limit: number + offset: number + models: ModelVersionItem[] +} + +export interface ModelVersionResponse { + version_id: string + status: string + message: string +} + +export interface ActiveModelResponse { + has_active_model: boolean + model: ModelVersionItem | null +} diff --git a/frontend/src/components/AugmentationConfig.test.tsx b/frontend/src/components/AugmentationConfig.test.tsx new file mode 100644 index 0000000..f79a1be --- /dev/null +++ b/frontend/src/components/AugmentationConfig.test.tsx @@ -0,0 +1,251 @@ +/** + * Tests for AugmentationConfig component. + * + * TDD Phase 1: RED - Write tests first, then implement to pass. + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { AugmentationConfig } from './AugmentationConfig' +import { augmentationApi } from '../api/endpoints/augmentation' +import type { ReactNode } from 'react' + +// Mock the API +vi.mock('../api/endpoints/augmentation', () => ({ + augmentationApi: { + getTypes: vi.fn(), + getPresets: vi.fn(), + preview: vi.fn(), + previewConfig: vi.fn(), + createBatch: vi.fn(), + }, +})) + +// Default mock data +const mockTypes = { + augmentation_types: [ + { + name: 'gaussian_noise', + description: 'Adds Gaussian noise to simulate sensor noise', + affects_geometry: false, + stage: 'noise', + default_params: { mean: 0, std: 15 }, + }, + { + name: 'perspective_warp', + description: 'Applies perspective transformation', + affects_geometry: true, + stage: 'geometric', + default_params: { max_warp: 0.02 }, + }, + { + name: 'gaussian_blur', + description: 'Applies Gaussian blur', + affects_geometry: false, + stage: 
'blur', + default_params: { kernel_size: 5 }, + }, + ], +} + +const mockPresets = { + presets: [ + { name: 'conservative', description: 'Safe augmentations for high-quality documents' }, + { name: 'moderate', description: 'Balanced augmentation settings' }, + { name: 'aggressive', description: 'Strong augmentations for data diversity' }, + ], +} + +// Test wrapper with QueryClient +const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => ( + {children} + ) +} + +describe('AugmentationConfig', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes) + vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets) + }) + + describe('rendering', () => { + it('should render enable checkbox', async () => { + render( + , + { wrapper: createWrapper() } + ) + + expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument() + }) + + it('should be collapsed when disabled', () => { + render( + , + { wrapper: createWrapper() } + ) + + // Config options should not be visible + expect(screen.queryByText(/preset/i)).not.toBeInTheDocument() + }) + + it('should expand when enabled', async () => { + render( + , + { wrapper: createWrapper() } + ) + + await waitFor(() => { + expect(screen.getByText(/preset/i)).toBeInTheDocument() + }) + }) + }) + + describe('preset selection', () => { + it('should display available presets', async () => { + render( + , + { wrapper: createWrapper() } + ) + + await waitFor(() => { + expect(screen.getByText('conservative')).toBeInTheDocument() + expect(screen.getByText('moderate')).toBeInTheDocument() + expect(screen.getByText('aggressive')).toBeInTheDocument() + }) + }) + + it('should call onConfigChange when preset is selected', async () => { + const user = userEvent.setup() + const onConfigChange = vi.fn() + + render( + , + { 
wrapper: createWrapper() } + ) + + await waitFor(() => { + expect(screen.getByText('moderate')).toBeInTheDocument() + }) + + await user.click(screen.getByText('moderate')) + + expect(onConfigChange).toHaveBeenCalled() + }) + }) + + describe('enable toggle', () => { + it('should call onEnabledChange when checkbox is toggled', async () => { + const user = userEvent.setup() + const onEnabledChange = vi.fn() + + render( + , + { wrapper: createWrapper() } + ) + + await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i })) + + expect(onEnabledChange).toHaveBeenCalledWith(true) + }) + }) + + describe('augmentation types', () => { + it('should display augmentation types when in custom mode', async () => { + render( + , + { wrapper: createWrapper() } + ) + + await waitFor(() => { + expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument() + expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument() + }) + }) + + it('should indicate which augmentations affect geometry', async () => { + render( + , + { wrapper: createWrapper() } + ) + + await waitFor(() => { + // perspective_warp affects geometry + const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div') + expect(perspectiveItem).toHaveTextContent(/affects bbox/i) + }) + }) + }) + + describe('loading state', () => { + it('should show loading indicator while fetching types', () => { + vi.mocked(augmentationApi.getTypes).mockImplementation( + () => new Promise(() => {}) + ) + + render( + , + { wrapper: createWrapper() } + ) + + expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument() + }) + }) +}) diff --git a/frontend/src/components/AugmentationConfig.tsx b/frontend/src/components/AugmentationConfig.tsx new file mode 100644 index 0000000..07ffdd6 --- /dev/null +++ b/frontend/src/components/AugmentationConfig.tsx @@ -0,0 +1,136 @@ +/** + * AugmentationConfig component for configuring image augmentation during training. 
+ * + * Provides preset selection and optional custom augmentation type configuration. + */ + +import React from 'react' +import { Loader2, AlertTriangle } from 'lucide-react' +import { useAugmentation } from '../hooks/useAugmentation' +import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation' + +interface AugmentationConfigProps { + enabled: boolean + onEnabledChange: (enabled: boolean) => void + config: Partial + onConfigChange: (config: Partial) => void + showCustomOptions?: boolean +} + +export const AugmentationConfig: React.FC = ({ + enabled, + onEnabledChange, + config, + onConfigChange, + showCustomOptions = false, +}) => { + const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation() + + const isLoading = isLoadingTypes || isLoadingPresets + + const handlePresetSelect = (presetName: string) => { + const preset = presets.find((p) => p.name === presetName) + if (preset && preset.config) { + onConfigChange(preset.config as Partial) + } else { + // Apply a basic config based on preset name + const presetConfigs: Record> = { + conservative: { + gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } }, + gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } }, + }, + moderate: { + gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } }, + gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } }, + lighting_variation: { enabled: true, probability: 0.3, params: {} }, + perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } }, + }, + aggressive: { + gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } }, + gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } }, + motion_blur: { enabled: true, probability: 0.3, params: {} }, + lighting_variation: { enabled: true, probability: 0.5, params: {} }, + shadow: { enabled: true, probability: 0.3, params: {} }, + 
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } }, + wrinkle: { enabled: true, probability: 0.2, params: {} }, + stain: { enabled: true, probability: 0.2, params: {} }, + }, + } + onConfigChange(presetConfigs[presetName] || {}) + } + } + + return ( +
+ {/* Enable checkbox */} + + + {/* Expanded content when enabled */} + {enabled && ( +
+ {isLoading ? ( +
+ + Loading augmentation options... +
+ ) : ( + <> + {/* Preset selection */} +
+ +
+ {presets.map((preset) => ( + + ))} +
+
+ + {/* Custom options (if enabled) */} + {showCustomOptions && ( +
+

Augmentation Types

+
+ {augmentationTypes.map((type) => ( +
+
+ {type.name} + {type.affects_geometry && ( + + + affects bbox + + )} +
+ {type.stage} +
+ ))} +
+
+ )} + + )} +
+ )} +
+ ) +} diff --git a/frontend/src/components/Dashboard.tsx b/frontend/src/components/Dashboard.tsx index 02517d7..5572734 100644 --- a/frontend/src/components/Dashboard.tsx +++ b/frontend/src/components/Dashboard.tsx @@ -144,6 +144,9 @@ export const Dashboard: React.FC = ({ onNavigate }) => { Annotations + + Group + Auto-label @@ -153,13 +156,13 @@ export const Dashboard: React.FC = ({ onNavigate }) => { {isLoading ? ( - + Loading documents... ) : documents.length === 0 ? ( - + No documents found. Upload your first document to get started. @@ -213,6 +216,9 @@ export const Dashboard: React.FC = ({ onNavigate }) => { {doc.annotation_count || 0} annotations + + {doc.group_key || '-'} + {doc.auto_label_status === 'running' && progress && (
diff --git a/frontend/src/components/DatasetDetail.tsx b/frontend/src/components/DatasetDetail.tsx new file mode 100644 index 0000000..b70b704 --- /dev/null +++ b/frontend/src/components/DatasetDetail.tsx @@ -0,0 +1,122 @@ +import React from 'react' +import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react' +import { Button } from './Button' +import { useDatasetDetail } from '../hooks/useDatasets' + +interface DatasetDetailProps { + datasetId: string + onBack: () => void +} + +const SPLIT_STYLES: Record = { + train: 'bg-warm-state-info/10 text-warm-state-info', + val: 'bg-warm-state-warning/10 text-warm-state-warning', + test: 'bg-warm-state-success/10 text-warm-state-success', +} + +export const DatasetDetail: React.FC = ({ datasetId, onBack }) => { + const { dataset, isLoading, error } = useDatasetDetail(datasetId) + + if (isLoading) { + return ( +
+ Loading dataset... +
+ ) + } + + if (error || !dataset) { + return ( +
+ +

Failed to load dataset.

+
+ ) + } + + const statusIcon = dataset.status === 'ready' + ? + : dataset.status === 'failed' + ? + : + + return ( +
+ {/* Header */} + + +
+
+

+ {dataset.name} {statusIcon} +

+ {dataset.description && ( +

{dataset.description}

+ )} +
+ {dataset.status === 'ready' && ( + + )} +
+ + {dataset.error_message && ( +
+ {dataset.error_message} +
+ )} + + {/* Stats */} +
+ {[ + ['Documents', dataset.total_documents], + ['Images', dataset.total_images], + ['Annotations', dataset.total_annotations], + ['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`], + ].map(([label, value]) => ( +
+

{label}

+

{value}

+
+ ))} +
+ + {/* Document list */} +

Documents

+
+ + + + + + + + + + + {dataset.documents.map(doc => ( + + + + + + + ))} + +
Document IDSplitPagesAnnotations
{doc.document_id.slice(0, 8)}... + + {doc.split} + + {doc.page_count}{doc.annotation_count}
+
+ +

+ Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()} + {dataset.dataset_path && <> | Path: {dataset.dataset_path}} +

+
+ ) +} diff --git a/frontend/src/components/DocumentDetail.tsx b/frontend/src/components/DocumentDetail.tsx index cf4da73..258a4ba 100644 --- a/frontend/src/components/DocumentDetail.tsx +++ b/frontend/src/components/DocumentDetail.tsx @@ -1,8 +1,9 @@ import React, { useState, useRef, useEffect } from 'react' -import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react' +import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react' import { Button } from './Button' import { useDocumentDetail } from '../hooks/useDocumentDetail' import { useAnnotations } from '../hooks/useAnnotations' +import { useDocuments } from '../hooks/useDocuments' import { documentsApi } from '../api/endpoints/documents' import type { AnnotationItem } from '../api/types' @@ -26,7 +27,7 @@ const FIELD_CLASSES: Record = { } export const DocumentDetail: React.FC = ({ docId, onBack }) => { - const { document, annotations, isLoading } = useDocumentDetail(docId) + const { document, annotations, isLoading, refetch } = useDocumentDetail(docId) const { createAnnotation, updateAnnotation, @@ -34,10 +35,13 @@ export const DocumentDetail: React.FC = ({ docId, onBack }) isCreating, isDeleting, } = useAnnotations(docId) + const { updateGroupKey, isUpdatingGroupKey } = useDocuments({}) const [selectedId, setSelectedId] = useState(null) const [zoom, setZoom] = useState(100) const [isDrawing, setIsDrawing] = useState(false) + const [isEditingGroupKey, setIsEditingGroupKey] = useState(false) + const [editGroupKeyValue, setEditGroupKeyValue] = useState('') const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null) const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null) const [selectedClassId, setSelectedClassId] = useState(0) @@ -426,6 +430,65 @@ export const DocumentDetail: React.FC = ({ docId, onBack }) {new Date(document.created_at).toLocaleDateString()}
+
+ Group + {isEditingGroupKey ? ( +
+ setEditGroupKeyValue(e.target.value)} + className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info" + placeholder="group key" + autoFocus + /> + + +
+ ) : ( +
+ + {document.group_key || '-'} + + +
+ )} +
diff --git a/frontend/src/components/Models.tsx b/frontend/src/components/Models.tsx index c35052f..bfe2222 100644 --- a/frontend/src/components/Models.tsx +++ b/frontend/src/components/Models.tsx @@ -1,73 +1,108 @@ -import React from 'react'; +import React, { useState } from 'react'; import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts'; +import { Loader2, Power, CheckCircle } from 'lucide-react'; import { Button } from './Button'; +import { useModels, useModelDetail } from '../hooks'; +import type { ModelVersionItem } from '../api/types'; -const CHART_DATA = [ - { name: 'Model A', value: 75 }, - { name: 'Model B', value: 82 }, - { name: 'Model C', value: 95 }, - { name: 'Model D', value: 68 }, -]; - -const METRICS_DATA = [ - { name: 'Precision', value: 88 }, - { name: 'Recall', value: 76 }, - { name: 'F1 Score', value: 91 }, - { name: 'Accuracy', value: 82 }, -]; - -const JOBS = [ - { id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 }, - { id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 }, - { id: 3, name: 'Model Training Compentr 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 }, -]; +const formatDate = (dateString: string | null): string => { + if (!dateString) return 'N/A'; + return new Date(dateString).toLocaleString(); +}; export const Models: React.FC = () => { + const [selectedModel, setSelectedModel] = useState(null); + const { models, isLoading, activateModel, isActivating } = useModels(); + const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null); + + // Build chart data from selected model's metrics + const metricsData = modelDetail ? [ + { name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 }, + { name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 }, + { name: 'mAP', value: (modelDetail.metrics_mAP ?? 
0) * 100 }, + ] : [ + { name: 'Precision', value: 0 }, + { name: 'Recall', value: 0 }, + { name: 'mAP', value: 0 }, + ]; + + // Build comparison chart from all models (with placeholder if empty) + const chartData = models.length > 0 + ? models.slice(0, 4).map(m => ({ + name: m.version, + value: (m.metrics_mAP ?? 0) * 100, + })) + : [ + { name: 'Model A', value: 0 }, + { name: 'Model B', value: 0 }, + { name: 'Model C', value: 0 }, + { name: 'Model D', value: 0 }, + ]; + return (
{/* Left: Job History */}

Models & History

-

Recent Training Jobs

+

Model Versions

-
- {JOBS.map(job => ( -
-
-
-

{job.name}

-

Started {job.date}

-
- - {job.status} - -
- - {job.status === 'Running' ? ( -
-
-
-
+ {isLoading ? ( +
+ +
+ ) : models.length === 0 ? ( +
+ No model versions found. Complete a training task to create a model version. +
+ ) : ( +
+ {models.map(model => ( +
setSelectedModel(model)} + className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${ + selectedModel?.version_id === model.version_id + ? 'border-warm-text-secondary' + : 'border-warm-border hover:border-warm-divider' + }`} + > +
+
+

+ {model.name} + {model.is_active && } +

+

Trained {formatDate(model.trained_at)}

+
+ + {model.is_active ? 'Active' : model.status} +
- ) : ( +
- Success - {job.success} + Documents + {model.document_count}
- Performance - {job.metrics}% + mAP + + {model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'} +
- Completed - 100% + Version + {model.version}
- )} -
- ))} -
+
+ ))} +
+ )}
{/* Right: Model Detail */} @@ -75,27 +110,34 @@ export const Models: React.FC = () => {

Model Detail

- Completed + + {selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'} +

Model name

-

Invoices Q4 v2.1

+

+ {selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'} +

{/* Chart 1 */}
-

Bar Rate Metrics

+

Model Comparison (mAP)

- + - + - [`${value.toFixed(1)}%`, 'mAP']} /> @@ -105,14 +147,17 @@ export const Models: React.FC = () => { {/* Chart 2 */}
-

Entity Extraction Accuracy

+

Performance Metrics

- + - + [`${value.toFixed(1)}%`, 'Score']} + /> @@ -121,14 +166,43 @@ export const Models: React.FC = () => {
- + {selectedModel && !selectedModel.is_active ? ( + + ) : ( + + )}
- - + +
); -}; \ No newline at end of file +}; diff --git a/frontend/src/components/Training.tsx b/frontend/src/components/Training.tsx index 39a9976..13c63eb 100644 --- a/frontend/src/components/Training.tsx +++ b/frontend/src/components/Training.tsx @@ -1,113 +1,482 @@ -import React, { useState } from 'react'; -import { Check, AlertCircle } from 'lucide-react'; -import { Button } from './Button'; -import { DocumentStatus } from '../types'; +import React, { useState, useMemo } from 'react' +import { useQuery } from '@tanstack/react-query' +import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react' +import { Button } from './Button' +import { AugmentationConfig } from './AugmentationConfig' +import { useDatasets } from '../hooks/useDatasets' +import { useTrainingDocuments } from '../hooks/useTraining' +import { trainingApi } from '../api/endpoints' +import type { DatasetListItem } from '../api/types' +import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation' -export const Training: React.FC = () => { - const [split, setSplit] = useState(80); +type Tab = 'datasets' | 'create' - const docs = [ - { id: '1', name: 'Document Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED }, - { id: '2', name: 'Document Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED }, - { id: '3', name: 'Document Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED }, - { id: '4', name: 'Document Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL }, - { id: '5', name: 'Document Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL }, - { id: '6', name: 'Document Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL }, - { id: '8', name: 'Document Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED }, - ]; +interface TrainingProps { + onNavigate?: (view: string, id?: string) => void +} + +const STATUS_STYLES: Record = { + ready: 
'bg-warm-state-success/10 text-warm-state-success', + building: 'bg-warm-state-info/10 text-warm-state-info', + training: 'bg-warm-state-info/10 text-warm-state-info', + failed: 'bg-warm-state-error/10 text-warm-state-error', + pending: 'bg-warm-state-warning/10 text-warm-state-warning', + scheduled: 'bg-warm-state-warning/10 text-warm-state-warning', + running: 'bg-warm-state-info/10 text-warm-state-info', +} + +const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => { + // If there's an active training task, show training status + const displayStatus = trainingStatus === 'running' + ? 'training' + : trainingStatus === 'pending' || trainingStatus === 'scheduled' + ? 'pending' + : status return ( -
- {/* Document Selection List */} -
-

Document Selection

- -
-
- - - - - - - - - - - {docs.map(doc => ( - - - - - - - ))} - -
Document nameDateStatus
{doc.name}{doc.date} - {doc.status === DocumentStatus.VERIFIED ? ( -
-
- -
- Verified -
- ) : ( -
-
- ! -
- Partial -
- )} -
-
-
-
+ + {(displayStatus === 'building' || displayStatus === 'training') && } + {displayStatus === 'ready' && } + {displayStatus === 'failed' && } + {displayStatus} + + ) +} - {/* Configuration Panel */} -
-
-

Training Configuration

- -
-
- - void + onSubmit: (config: { + name: string + config: { + model_name?: string + base_model_version_id?: string | null + epochs: number + batch_size: number + augmentation?: AugmentationConfigType + augmentation_multiplier?: number + } + }) => void + isPending: boolean +} + +const TrainDialog: React.FC = ({ dataset, onClose, onSubmit, isPending }) => { + const [name, setName] = useState(`train-${dataset.name}`) + const [epochs, setEpochs] = useState(100) + const [batchSize, setBatchSize] = useState(16) + const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained') + const [baseModelVersionId, setBaseModelVersionId] = useState(null) + const [augmentationEnabled, setAugmentationEnabled] = useState(false) + const [augmentationConfig, setAugmentationConfig] = useState>({}) + const [augmentationMultiplier, setAugmentationMultiplier] = useState(2) + + // Fetch available trained models + const { data: modelsData } = useQuery({ + queryKey: ['training', 'models', 'completed'], + queryFn: () => trainingApi.getModels({ status: 'completed' }), + }) + const completedModels = modelsData?.models ?? [] + + const handleSubmit = () => { + onSubmit({ + name, + config: { + model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined, + base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null, + epochs, + batch_size: batchSize, + augmentation: augmentationEnabled + ? (augmentationConfig as AugmentationConfigType) + : undefined, + augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined, + }, + }) + } + + return ( +
+
e.stopPropagation()}> +

Start Training

+

+ Dataset: {dataset.name} + {' '}({dataset.total_images} images, {dataset.total_annotations} annotations) +

+ +
+
+ + setName(e.target.value)} + className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" /> +
+ + {/* Base Model Selection */} +
+ + +

+ {baseModelType === 'pretrained' + ? 'Start from pretrained YOLO model' + : 'Continue training from an existing model (incremental training)'} +

+
+ +
+
+ + setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))} className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
- -
- - -
- -
-
- - {split}% / {100-split}% -
- setSplit(parseInt(e.target.value))} - className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info" +
+ + setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))} + className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
+
+ + {/* Augmentation Configuration */} + + + {/* Augmentation Multiplier - only shown when augmentation is enabled */} + {augmentationEnabled && ( +
+ + setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))} + className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" + /> +

+ Number of augmented copies per original image (1-10) +

+
+ )} +
+ +
+ + +
+
+
+ ) +} + +// --- Dataset List --- + +/** + * Dataset list tab: reads the datasets plus the delete/train actions from useDatasets() and renders a loading message, an empty state, or a table with one row per dataset. + * trainTarget holds the dataset whose dataset_id is sent to trainFromDataset (null when none is selected); the element at the end of the component is rendered only while it is set. + * NOTE(review): JSX element tags are missing from this extract, so which elements use onNavigate, onSwitchTab, deleteDataset and isDeleting cannot be confirmed from here. + */ +const DatasetList: React.FC<{ + onNavigate?: (view: string, id?: string) => void + onSwitchTab: (tab: Tab) => void +}> = ({ onNavigate, onSwitchTab }) => { + const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets() + const [trainTarget, setTrainTarget] = useState(null) + + /** + * Submits a training request for the dataset held in trainTarget. + * Does nothing when trainTarget is null; otherwise forwards name and config unchanged to trainFromDataset and resets trainTarget to null on success. + */ + const handleTrain = (config: { + name: string + config: { + model_name?: string + base_model_version_id?: string | null + epochs: number + batch_size: number + augmentation?: AugmentationConfigType + augmentation_multiplier?: number + } + }) => { + if (!trainTarget) return + // Pass config to the training API + const trainRequest = { + name: config.name, + config: config.config, + } + trainFromDataset( + { datasetId: trainTarget.dataset_id, req: trainRequest }, + { onSuccess: () => setTrainTarget(null) }, + ) + } + + /* While useDatasets() reports isLoading, render only the placeholder message. */ + if (isLoading) { + return
Loading datasets...
+ } + + /* No datasets: render the empty state instead of the table. */ + if (datasets.length === 0) { + return ( +
+ +

No datasets yet

+

Create a dataset to start training

+ +
+ ) + } + + /* Otherwise render the dataset table, one row per entry in datasets. */ + return ( + <> +
+ + + + + + + + + + + + + + {datasets.map(ds => ( + + + + + + + + + + ))} + +
NameStatusDocsImagesAnnotationsCreatedActions
{ds.name}{ds.total_documents}{ds.total_images}{ds.total_annotations}{new Date(ds.created_at).toLocaleDateString()} +
+ + {ds.status === 'ready' && ( + + )} + +
+
+
+ + {trainTarget && ( + setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} /> + )} + + ) +} + +// --- Create Dataset --- + +const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => { + const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true }) + const { createDatasetAsync, isCreating } = useDatasets() + + const [selectedIds, setSelectedIds] = useState>(new Set()) + const [name, setName] = useState('') + const [description, setDescription] = useState('') + const [trainRatio, setTrainRatio] = useState(0.7) + const [valRatio, setValRatio] = useState(0.2) + + const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio]) + + const toggleDoc = (id: string) => { + setSelectedIds(prev => { + const next = new Set(prev) + if (next.has(id)) { next.delete(id) } else { next.add(id) } + return next + }) + } + + const toggleAll = () => { + if (selectedIds.size === documents.length) { + setSelectedIds(new Set()) + } else { + setSelectedIds(new Set(documents.map((d) => d.document_id))) + } + } + + const handleCreate = async () => { + await createDatasetAsync({ + name, + description: description || undefined, + document_ids: [...selectedIds], + train_ratio: trainRatio, + val_ratio: valRatio, + }) + onSwitchTab('datasets') + } + + return ( +
+ {/* Document selection */} +
+

Select Documents

+ {isLoadingDocs ? ( +
Loading...
+ ) : ( +
+
+ + + + + + + + + + + {documents.map((doc) => ( + toggleDoc(doc.document_id)}> + + + + + + ))} + +
+ 0} + onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" /> + Document IDPagesAnnotations
+ + {doc.document_id.slice(0, 8)}...{doc.page_count}{doc.annotation_count ?? 0}
+
+
+ )} +

{selectedIds.size} of {documents.length} documents selected

+
+ + {/* Config panel */} +
+
+

Dataset Configuration

+
+
+ + setName(e.target.value)} placeholder="e.g. invoice-dataset-v1" + className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" /> +
+
+ +