WIP
This commit is contained in:
@@ -4,6 +4,7 @@ import { DashboardOverview } from './components/DashboardOverview'
|
||||
import { Dashboard } from './components/Dashboard'
|
||||
import { DocumentDetail } from './components/DocumentDetail'
|
||||
import { Training } from './components/Training'
|
||||
import { DatasetDetail } from './components/DatasetDetail'
|
||||
import { Models } from './components/Models'
|
||||
import { Login } from './components/Login'
|
||||
import { InferenceDemo } from './components/InferenceDemo'
|
||||
@@ -55,7 +56,14 @@ const App: React.FC = () => {
|
||||
case 'demo':
|
||||
return <InferenceDemo />
|
||||
case 'training':
|
||||
return <Training />
|
||||
return <Training onNavigate={handleNavigate} />
|
||||
case 'dataset-detail':
|
||||
return (
|
||||
<DatasetDetail
|
||||
datasetId={selectedDocId || ''}
|
||||
onBack={() => setCurrentView('training')}
|
||||
/>
|
||||
)
|
||||
case 'models':
|
||||
return <Models />
|
||||
default:
|
||||
|
||||
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
/**
|
||||
* Tests for augmentation API endpoints.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { augmentationApi } from './augmentation'
|
||||
import apiClient from '../client'
|
||||
|
||||
// Mock the API client
|
||||
vi.mock('../client', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('augmentationApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getTypes', () => {
|
||||
it('should fetch augmentation types', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getTypes()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
|
||||
expect(result.augmentation_types).toHaveLength(1)
|
||||
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPresets', () => {
|
||||
it('should fetch augmentation presets', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations' },
|
||||
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getPresets()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
|
||||
expect(result.presets).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview', () => {
|
||||
it('should preview single augmentation', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: 'data:image/png;base64,xxx',
|
||||
original_url: 'data:image/png;base64,yyy',
|
||||
applied_params: { std: 15 },
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.preview('doc-123', {
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
})
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
{
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
},
|
||||
{ params: { page: 1 } }
|
||||
)
|
||||
expect(result.preview_url).toBe('data:image/png;base64,xxx')
|
||||
})
|
||||
|
||||
it('should support custom page number', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: 'data:image/png;base64,xxx',
|
||||
original_url: 'data:image/png;base64,yyy',
|
||||
applied_params: {},
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
await augmentationApi.preview(
|
||||
'doc-123',
|
||||
{ augmentation_type: 'gaussian_noise', params: {} },
|
||||
2
|
||||
)
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
expect.anything(),
|
||||
{ params: { page: 2 } }
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
144
frontend/src/api/endpoints/augmentation.ts
Normal file
144
frontend/src/api/endpoints/augmentation.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Augmentation API endpoints.
|
||||
*
|
||||
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||
*/
|
||||
|
||||
import apiClient from '../client'
|
||||
|
||||
// Types
|
||||
export interface AugmentationTypeInfo {
|
||||
name: string
|
||||
description: string
|
||||
affects_geometry: boolean
|
||||
stage: string
|
||||
default_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTypesResponse {
|
||||
augmentation_types: AugmentationTypeInfo[]
|
||||
}
|
||||
|
||||
export interface PresetInfo {
|
||||
name: string
|
||||
description: string
|
||||
config?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PresetsResponse {
|
||||
presets: PresetInfo[]
|
||||
}
|
||||
|
||||
export interface PreviewRequest {
|
||||
augmentation_type: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PreviewResponse {
|
||||
preview_url: string
|
||||
original_url: string
|
||||
applied_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationConfig {
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
gaussian_noise?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface BatchRequest {
|
||||
dataset_id: string
|
||||
config: AugmentationConfig
|
||||
output_name: string
|
||||
multiplier: number
|
||||
}
|
||||
|
||||
export interface BatchResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
estimated_images: number
|
||||
}
|
||||
|
||||
// API functions
|
||||
export const augmentationApi = {
|
||||
/**
|
||||
* Fetch available augmentation types.
|
||||
*/
|
||||
async getTypes(): Promise<AugmentationTypesResponse> {
|
||||
const response = await apiClient.get<AugmentationTypesResponse>(
|
||||
'/api/v1/admin/augmentation/types'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Fetch augmentation presets.
|
||||
*/
|
||||
async getPresets(): Promise<PresetsResponse> {
|
||||
const response = await apiClient.get<PresetsResponse>(
|
||||
'/api/v1/admin/augmentation/presets'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview a single augmentation on a document page.
|
||||
*/
|
||||
async preview(
|
||||
documentId: string,
|
||||
request: PreviewRequest,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview/${documentId}`,
|
||||
request,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview full augmentation config on a document page.
|
||||
*/
|
||||
async previewConfig(
|
||||
documentId: string,
|
||||
config: AugmentationConfig,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview-config/${documentId}`,
|
||||
config,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Create an augmented dataset.
|
||||
*/
|
||||
async createBatch(request: BatchRequest): Promise<BatchResponse> {
|
||||
const response = await apiClient.post<BatchResponse>(
|
||||
'/api/v1/admin/augmentation/batch',
|
||||
request
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
52
frontend/src/api/endpoints/datasets.ts
Normal file
52
frontend/src/api/endpoints/datasets.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetListResponse,
|
||||
DatasetResponse,
|
||||
DatasetTrainRequest,
|
||||
TrainingTaskResponse,
|
||||
} from '../types'
|
||||
|
||||
export const datasetsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DatasetListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
remove: async (datasetId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
trainFromDataset: async (
|
||||
datasetId: string,
|
||||
req: DatasetTrainRequest
|
||||
): Promise<TrainingTaskResponse> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/datasets/${datasetId}/train`,
|
||||
req
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
@@ -21,14 +21,20 @@ export const documentsApi = {
|
||||
return data
|
||||
},
|
||||
|
||||
upload: async (file: File): Promise<UploadDocumentResponse> => {
|
||||
upload: async (file: File, groupKey?: string): Promise<UploadDocumentResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const params: Record<string, string> = {}
|
||||
if (groupKey) {
|
||||
params.group_key = groupKey
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
@@ -77,4 +83,16 @@ export const documentsApi = {
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
updateGroupKey: async (
|
||||
documentId: string,
|
||||
groupKey: string | null
|
||||
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/group-key`,
|
||||
null,
|
||||
{ params: { group_key: groupKey } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2,3 +2,6 @@ export { documentsApi } from './documents'
|
||||
export { annotationsApi } from './annotations'
|
||||
export { trainingApi } from './training'
|
||||
export { inferenceApi } from './inference'
|
||||
export { datasetsApi } from './datasets'
|
||||
export { augmentationApi } from './augmentation'
|
||||
export { modelsApi } from './models'
|
||||
|
||||
55
frontend/src/api/endpoints/models.ts
Normal file
55
frontend/src/api/endpoints/models.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
} from '../types'
|
||||
|
||||
export const modelsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<ModelVersionListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
getActive: async (): Promise<ActiveModelResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
|
||||
return data
|
||||
},
|
||||
|
||||
activate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
|
||||
return data
|
||||
},
|
||||
|
||||
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
|
||||
return data
|
||||
},
|
||||
|
||||
archive: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (versionId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
|
||||
return data
|
||||
},
|
||||
}
|
||||
@@ -8,6 +8,7 @@ export interface DocumentItem {
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
group_key: string | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
annotation_count?: number
|
||||
@@ -59,6 +60,7 @@ export interface DocumentDetailResponse {
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
batch_id: string | null
|
||||
group_key: string | null
|
||||
csv_field_values: Record<string, string> | null
|
||||
can_annotate: boolean
|
||||
annotation_lock_until: string | null
|
||||
@@ -113,7 +115,11 @@ export interface ErrorResponse {
|
||||
export interface UploadDocumentResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
page_count: number
|
||||
status: string
|
||||
group_key: string | null
|
||||
auto_label_started: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
@@ -171,3 +177,165 @@ export interface InferenceResult {
|
||||
export interface InferenceResponse {
|
||||
result: InferenceResult
|
||||
}
|
||||
|
||||
// Dataset types
|
||||
|
||||
export interface DatasetCreateRequest {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
train_ratio?: number
|
||||
val_ratio?: number
|
||||
seed?: number
|
||||
}
|
||||
|
||||
export interface DatasetResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface DatasetDocumentItem {
|
||||
document_id: string
|
||||
split: string
|
||||
page_count: number
|
||||
annotation_count: number
|
||||
}
|
||||
|
||||
export interface DatasetListItem {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
training_status: string | null
|
||||
active_training_task_id: string | null
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DatasetListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
datasets: DatasetListItem[]
|
||||
}
|
||||
|
||||
export interface DatasetDetailResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
train_ratio: number
|
||||
val_ratio: number
|
||||
seed: number
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
dataset_path: string | null
|
||||
error_message: string | null
|
||||
documents: DatasetDocumentItem[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTrainingConfig {
|
||||
gaussian_noise?: AugmentationParams
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface DatasetTrainRequest {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
image_size?: number
|
||||
learning_rate?: number
|
||||
device?: string
|
||||
augmentation?: AugmentationTrainingConfig
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface TrainingTaskResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
// Model Version types
|
||||
|
||||
export interface ModelVersionItem {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
metrics_mAP: number | null
|
||||
document_count: number
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionDetailResponse {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
description: string | null
|
||||
model_path: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
task_id: string | null
|
||||
dataset_id: string | null
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
document_count: number
|
||||
training_config: Record<string, unknown> | null
|
||||
file_size: number | null
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
models: ModelVersionItem[]
|
||||
}
|
||||
|
||||
export interface ModelVersionResponse {
|
||||
version_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface ActiveModelResponse {
|
||||
has_active_model: boolean
|
||||
model: ModelVersionItem | null
|
||||
}
|
||||
|
||||
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* Tests for AugmentationConfig component.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||
import type { ReactNode } from 'react'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('../api/endpoints/augmentation', () => ({
|
||||
augmentationApi: {
|
||||
getTypes: vi.fn(),
|
||||
getPresets: vi.fn(),
|
||||
preview: vi.fn(),
|
||||
previewConfig: vi.fn(),
|
||||
createBatch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Default mock data
|
||||
const mockTypes = {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise to simulate sensor noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
{
|
||||
name: 'perspective_warp',
|
||||
description: 'Applies perspective transformation',
|
||||
affects_geometry: true,
|
||||
stage: 'geometric',
|
||||
default_params: { max_warp: 0.02 },
|
||||
},
|
||||
{
|
||||
name: 'gaussian_blur',
|
||||
description: 'Applies Gaussian blur',
|
||||
affects_geometry: false,
|
||||
stage: 'blur',
|
||||
default_params: { kernel_size: 5 },
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const mockPresets = {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
|
||||
{ name: 'moderate', description: 'Balanced augmentation settings' },
|
||||
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
|
||||
],
|
||||
}
|
||||
|
||||
// Test wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('AugmentationConfig', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render enable checkbox', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should be collapsed when disabled', () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
// Config options should not be visible
|
||||
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should expand when enabled', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/preset/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('preset selection', () => {
|
||||
it('should display available presets', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('conservative')).toBeInTheDocument()
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
expect(screen.getByText('aggressive')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onConfigChange when preset is selected', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfigChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={onConfigChange}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('moderate'))
|
||||
|
||||
expect(onConfigChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('enable toggle', () => {
|
||||
it('should call onEnabledChange when checkbox is toggled', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onEnabledChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={onEnabledChange}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
|
||||
|
||||
expect(onEnabledChange).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('augmentation types', () => {
|
||||
it('should display augmentation types when in custom mode', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should indicate which augmentations affect geometry', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
// perspective_warp affects geometry
|
||||
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
|
||||
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('loading state', () => {
|
||||
it('should show loading indicator while fetching types', () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
)
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
136
frontend/src/components/AugmentationConfig.tsx
Normal file
136
frontend/src/components/AugmentationConfig.tsx
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* AugmentationConfig component for configuring image augmentation during training.
|
||||
*
|
||||
* Provides preset selection and optional custom augmentation type configuration.
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { Loader2, AlertTriangle } from 'lucide-react'
|
||||
import { useAugmentation } from '../hooks/useAugmentation'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
interface AugmentationConfigProps {
|
||||
enabled: boolean
|
||||
onEnabledChange: (enabled: boolean) => void
|
||||
config: Partial<AugmentationConfigType>
|
||||
onConfigChange: (config: Partial<AugmentationConfigType>) => void
|
||||
showCustomOptions?: boolean
|
||||
}
|
||||
|
||||
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
|
||||
enabled,
|
||||
onEnabledChange,
|
||||
config,
|
||||
onConfigChange,
|
||||
showCustomOptions = false,
|
||||
}) => {
|
||||
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
|
||||
|
||||
const isLoading = isLoadingTypes || isLoadingPresets
|
||||
|
||||
const handlePresetSelect = (presetName: string) => {
|
||||
const preset = presets.find((p) => p.name === presetName)
|
||||
if (preset && preset.config) {
|
||||
onConfigChange(preset.config as Partial<AugmentationConfigType>)
|
||||
} else {
|
||||
// Apply a basic config based on preset name
|
||||
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
|
||||
conservative: {
|
||||
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
|
||||
},
|
||||
moderate: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
|
||||
lighting_variation: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
|
||||
},
|
||||
aggressive: {
|
||||
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
|
||||
motion_blur: { enabled: true, probability: 0.3, params: {} },
|
||||
lighting_variation: { enabled: true, probability: 0.5, params: {} },
|
||||
shadow: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
|
||||
wrinkle: { enabled: true, probability: 0.2, params: {} },
|
||||
stain: { enabled: true, probability: 0.2, params: {} },
|
||||
},
|
||||
}
|
||||
onConfigChange(presetConfigs[presetName] || {})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
|
||||
{/* Enable checkbox */}
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
onChange={(e) => onEnabledChange(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
|
||||
aria-label="Enable augmentation"
|
||||
/>
|
||||
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
|
||||
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
|
||||
</label>
|
||||
|
||||
{/* Expanded content when enabled */}
|
||||
{enabled && (
|
||||
<div className="mt-4 space-y-4">
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
|
||||
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
|
||||
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* Preset selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{presets.map((preset) => (
|
||||
<button
|
||||
key={preset.name}
|
||||
onClick={() => handlePresetSelect(preset.name)}
|
||||
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
|
||||
title={preset.description}
|
||||
>
|
||||
{preset.name}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Custom options (if enabled) */}
|
||||
{showCustomOptions && (
|
||||
<div className="border-t border-warm-divider pt-4">
|
||||
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
|
||||
<div className="grid gap-2">
|
||||
{augmentationTypes.map((type) => (
|
||||
<div
|
||||
key={type.name}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-warm-text-primary">{type.name}</span>
|
||||
{type.affects_geometry && (
|
||||
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
|
||||
<AlertTriangle size={12} />
|
||||
affects bbox
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs text-warm-text-muted">{type.stage}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -144,6 +144,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Annotations
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Group
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
||||
Auto-label
|
||||
</th>
|
||||
@@ -153,13 +156,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
<tbody>
|
||||
{isLoading ? (
|
||||
<tr>
|
||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
||||
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
|
||||
Loading documents...
|
||||
</td>
|
||||
</tr>
|
||||
) : documents.length === 0 ? (
|
||||
<tr>
|
||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
||||
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
|
||||
No documents found. Upload your first document to get started.
|
||||
</td>
|
||||
</tr>
|
||||
@@ -213,6 +216,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
||||
{doc.annotation_count || 0} annotations
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-muted">
|
||||
{doc.group_key || '-'}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
{doc.auto_label_status === 'running' && progress && (
|
||||
<div className="w-full">
|
||||
|
||||
122
frontend/src/components/DatasetDetail.tsx
Normal file
122
frontend/src/components/DatasetDetail.tsx
Normal file
@@ -0,0 +1,122 @@
|
||||
import React from 'react'
|
||||
import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDatasetDetail } from '../hooks/useDatasets'
|
||||
|
||||
interface DatasetDetailProps {
|
||||
datasetId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
const SPLIT_STYLES: Record<string, string> = {
|
||||
train: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
val: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
test: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
}
|
||||
|
||||
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
|
||||
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-20 text-warm-text-muted">
|
||||
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error || !dataset) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back
|
||||
</button>
|
||||
<p className="text-warm-state-error">Failed to load dataset.</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const statusIcon = dataset.status === 'ready'
|
||||
? <Check size={14} className="text-warm-state-success" />
|
||||
: dataset.status === 'failed'
|
||||
? <AlertCircle size={14} className="text-warm-state-error" />
|
||||
: <Loader2 size={14} className="animate-spin text-warm-state-info" />
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
{/* Header */}
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back to Datasets
|
||||
</button>
|
||||
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
|
||||
{dataset.name} {statusIcon}
|
||||
</h2>
|
||||
{dataset.description && (
|
||||
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
|
||||
)}
|
||||
</div>
|
||||
{dataset.status === 'ready' && (
|
||||
<Button><Play size={14} className="mr-1" />Start Training</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{dataset.error_message && (
|
||||
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
|
||||
{dataset.error_message}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Stats */}
|
||||
<div className="grid grid-cols-4 gap-4 mb-8">
|
||||
{[
|
||||
['Documents', dataset.total_documents],
|
||||
['Images', dataset.total_images],
|
||||
['Annotations', dataset.total_annotations],
|
||||
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
|
||||
].map(([label, value]) => (
|
||||
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
|
||||
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Document list */}
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{dataset.documents.map(doc => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4">
|
||||
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{doc.split}
|
||||
</span>
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<p className="text-xs text-warm-text-muted mt-4">
|
||||
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
|
||||
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
import React, { useState, useRef, useEffect } from 'react'
|
||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
|
||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
||||
import { useAnnotations } from '../hooks/useAnnotations'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { documentsApi } from '../api/endpoints/documents'
|
||||
import type { AnnotationItem } from '../api/types'
|
||||
|
||||
@@ -26,7 +27,7 @@ const FIELD_CLASSES: Record<number, string> = {
|
||||
}
|
||||
|
||||
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
||||
const { document, annotations, isLoading } = useDocumentDetail(docId)
|
||||
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
@@ -34,10 +35,13 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
|
||||
isCreating,
|
||||
isDeleting,
|
||||
} = useAnnotations(docId)
|
||||
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
|
||||
|
||||
const [selectedId, setSelectedId] = useState<string | null>(null)
|
||||
const [zoom, setZoom] = useState(100)
|
||||
const [isDrawing, setIsDrawing] = useState(false)
|
||||
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
|
||||
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
|
||||
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
||||
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
||||
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
||||
@@ -426,6 +430,65 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
|
||||
{new Date(document.created_at).toLocaleDateString()}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-warm-text-muted">Group</span>
|
||||
{isEditingGroupKey ? (
|
||||
<div className="flex items-center gap-1">
|
||||
<input
|
||||
type="text"
|
||||
value={editGroupKeyValue}
|
||||
onChange={(e) => setEditGroupKeyValue(e.target.value)}
|
||||
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
placeholder="group key"
|
||||
autoFocus
|
||||
/>
|
||||
<button
|
||||
onClick={() => {
|
||||
updateGroupKey(
|
||||
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
|
||||
{
|
||||
onSuccess: () => {
|
||||
setIsEditingGroupKey(false)
|
||||
refetch()
|
||||
},
|
||||
onError: () => {
|
||||
alert('Failed to update group key. Please try again.')
|
||||
},
|
||||
}
|
||||
)
|
||||
}}
|
||||
disabled={isUpdatingGroupKey}
|
||||
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Check size={14} />
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setIsEditingGroupKey(false)
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
}}
|
||||
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
|
||||
>
|
||||
<X size={14} />
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{document.group_key || '-'}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
setIsEditingGroupKey(true)
|
||||
}}
|
||||
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Edit2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,73 +1,108 @@
|
||||
import React from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
||||
import { Loader2, Power, CheckCircle } from 'lucide-react';
|
||||
import { Button } from './Button';
|
||||
import { useModels, useModelDetail } from '../hooks';
|
||||
import type { ModelVersionItem } from '../api/types';
|
||||
|
||||
const CHART_DATA = [
|
||||
{ name: 'Model A', value: 75 },
|
||||
{ name: 'Model B', value: 82 },
|
||||
{ name: 'Model C', value: 95 },
|
||||
{ name: 'Model D', value: 68 },
|
||||
];
|
||||
|
||||
const METRICS_DATA = [
|
||||
{ name: 'Precision', value: 88 },
|
||||
{ name: 'Recall', value: 76 },
|
||||
{ name: 'F1 Score', value: 91 },
|
||||
{ name: 'Accuracy', value: 82 },
|
||||
];
|
||||
|
||||
const JOBS = [
|
||||
{ id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
|
||||
{ id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
|
||||
{ id: 3, name: 'Model Training Compentr 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
|
||||
];
|
||||
const formatDate = (dateString: string | null): string => {
|
||||
if (!dateString) return 'N/A';
|
||||
return new Date(dateString).toLocaleString();
|
||||
};
|
||||
|
||||
export const Models: React.FC = () => {
|
||||
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
|
||||
const { models, isLoading, activateModel, isActivating } = useModels();
|
||||
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
|
||||
|
||||
// Build chart data from selected model's metrics
|
||||
const metricsData = modelDetail ? [
|
||||
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
|
||||
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
|
||||
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
|
||||
] : [
|
||||
{ name: 'Precision', value: 0 },
|
||||
{ name: 'Recall', value: 0 },
|
||||
{ name: 'mAP', value: 0 },
|
||||
];
|
||||
|
||||
// Build comparison chart from all models (with placeholder if empty)
|
||||
const chartData = models.length > 0
|
||||
? models.slice(0, 4).map(m => ({
|
||||
name: m.version,
|
||||
value: (m.metrics_mAP ?? 0) * 100,
|
||||
}))
|
||||
: [
|
||||
{ name: 'Model A', value: 0 },
|
||||
{ name: 'Model B', value: 0 },
|
||||
{ name: 'Model C', value: 0 },
|
||||
{ name: 'Model D', value: 0 },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto flex gap-8">
|
||||
{/* Left: Job History */}
|
||||
<div className="flex-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
|
||||
|
||||
<div className="space-y-4">
|
||||
{JOBS.map(job => (
|
||||
<div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div>
|
||||
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
|
||||
<p className="text-sm text-warm-text-muted">Started {job.date}</p>
|
||||
</div>
|
||||
<span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
|
||||
{job.status}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{job.status === 'Running' ? (
|
||||
<div className="mt-4">
|
||||
<div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
|
||||
<div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
|
||||
</div>
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
|
||||
</div>
|
||||
) : models.length === 0 ? (
|
||||
<div className="text-center py-12 text-warm-text-muted">
|
||||
No model versions found. Complete a training task to create a model version.
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
{models.map(model => (
|
||||
<div
|
||||
key={model.version_id}
|
||||
onClick={() => setSelectedModel(model)}
|
||||
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
|
||||
selectedModel?.version_id === model.version_id
|
||||
? 'border-warm-text-secondary'
|
||||
: 'border-warm-border hover:border-warm-divider'
|
||||
}`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div>
|
||||
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
|
||||
{model.name}
|
||||
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
|
||||
</h4>
|
||||
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
|
||||
</div>
|
||||
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
|
||||
model.is_active
|
||||
? 'bg-warm-state-info/10 text-warm-state-info'
|
||||
: 'bg-warm-selected text-warm-state-success'
|
||||
}`}>
|
||||
{model.is_active ? 'Active' : model.status}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
|
||||
<div className="mt-4 flex gap-8">
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">
|
||||
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">100%</span>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right: Model Detail */}
|
||||
@@ -75,27 +110,34 @@ export const Models: React.FC = () => {
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
|
||||
<span className="text-sm font-medium text-warm-state-success">Completed</span>
|
||||
<span className={`text-sm font-medium ${
|
||||
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
|
||||
}`}>
|
||||
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mb-8">
|
||||
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
|
||||
<p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
|
||||
<p className="font-medium text-warm-text-primary">
|
||||
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-8">
|
||||
{/* Chart 1 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Bar Rate Metrics</h4>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={CHART_DATA}>
|
||||
<BarChart data={chartData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" hide />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
@@ -105,14 +147,17 @@ export const Models: React.FC = () => {
|
||||
|
||||
{/* Chart 2 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={METRICS_DATA}>
|
||||
<BarChart data={metricsData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip cursor={{fill: '#F1F0ED'}} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
@@ -121,14 +166,43 @@ export const Models: React.FC = () => {
|
||||
</div>
|
||||
|
||||
<div className="mt-8 space-y-3">
|
||||
<Button className="w-full">Download Model</Button>
|
||||
{selectedModel && !selectedModel.is_active ? (
|
||||
<Button
|
||||
className="w-full"
|
||||
onClick={() => activateModel(selectedModel.version_id)}
|
||||
disabled={isActivating}
|
||||
>
|
||||
{isActivating ? (
|
||||
<>
|
||||
<Loader2 size={16} className="mr-2 animate-spin" />
|
||||
Activating...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Power size={16} className="mr-2" />
|
||||
Activate for Inference
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
) : (
|
||||
<Button className="w-full" disabled={!selectedModel}>
|
||||
{selectedModel?.is_active ? (
|
||||
<>
|
||||
<CheckCircle size={16} className="mr-2" />
|
||||
Currently Active
|
||||
</>
|
||||
) : (
|
||||
'Select a Model'
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" className="flex-1">View Logs</Button>
|
||||
<Button variant="secondary" className="flex-1">Use as Base</Button>
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,113 +1,482 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Check, AlertCircle } from 'lucide-react';
|
||||
import { Button } from './Button';
|
||||
import { DocumentStatus } from '../types';
|
||||
import React, { useState, useMemo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { useDatasets } from '../hooks/useDatasets'
|
||||
import { useTrainingDocuments } from '../hooks/useTraining'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { DatasetListItem } from '../api/types'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
export const Training: React.FC = () => {
|
||||
const [split, setSplit] = useState(80);
|
||||
type Tab = 'datasets' | 'create'
|
||||
|
||||
const docs = [
|
||||
{ id: '1', name: 'Document Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
|
||||
{ id: '2', name: 'Document Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
|
||||
{ id: '3', name: 'Document Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
|
||||
{ id: '4', name: 'Document Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
|
||||
{ id: '5', name: 'Document Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
|
||||
{ id: '6', name: 'Document Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
|
||||
{ id: '8', name: 'Document Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
|
||||
];
|
||||
interface TrainingProps {
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<string, string> = {
|
||||
ready: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
building: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
training: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
failed: 'bg-warm-state-error/10 text-warm-state-error',
|
||||
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
running: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
}
|
||||
|
||||
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
|
||||
// If there's an active training task, show training status
|
||||
const displayStatus = trainingStatus === 'running'
|
||||
? 'training'
|
||||
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
|
||||
? 'pending'
|
||||
: status
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
|
||||
{/* Document Selection List */}
|
||||
<div className="flex-1 flex flex-col">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>
|
||||
|
||||
<div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
|
||||
<div className="overflow-auto flex-1">
|
||||
<table className="w-full text-left">
|
||||
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
|
||||
<tr>
|
||||
<th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{docs.map(doc => (
|
||||
<tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
|
||||
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
|
||||
<td className="py-3 px-4">
|
||||
{doc.status === DocumentStatus.VERIFIED ? (
|
||||
<div className="flex items-center text-warm-state-success text-sm font-medium">
|
||||
<div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
|
||||
<Check size={12} strokeWidth={3}/>
|
||||
</div>
|
||||
Verified
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center text-warm-text-muted text-sm">
|
||||
<div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
|
||||
<span className="font-bold text-[10px]">!</span>
|
||||
</div>
|
||||
Partial
|
||||
</div>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
|
||||
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
|
||||
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
|
||||
{displayStatus}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
{/* Configuration Panel */}
|
||||
<div className="w-96">
|
||||
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>
|
||||
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="e.g. Invoices Q4"
|
||||
// --- Train Dialog ---
|
||||
|
||||
interface TrainDialogProps {
|
||||
dataset: DatasetListItem
|
||||
onClose: () => void
|
||||
onSubmit: (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => void
|
||||
isPending: boolean
|
||||
}
|
||||
|
||||
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
|
||||
const [name, setName] = useState(`train-${dataset.name}`)
|
||||
const [epochs, setEpochs] = useState(100)
|
||||
const [batchSize, setBatchSize] = useState(16)
|
||||
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
|
||||
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
|
||||
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
|
||||
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
|
||||
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
|
||||
|
||||
// Fetch available trained models
|
||||
const { data: modelsData } = useQuery({
|
||||
queryKey: ['training', 'models', 'completed'],
|
||||
queryFn: () => trainingApi.getModels({ status: 'completed' }),
|
||||
})
|
||||
const completedModels = modelsData?.models ?? []
|
||||
|
||||
const handleSubmit = () => {
|
||||
onSubmit({
|
||||
name,
|
||||
config: {
|
||||
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
|
||||
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
|
||||
epochs,
|
||||
batch_size: batchSize,
|
||||
augmentation: augmentationEnabled
|
||||
? (augmentationConfig as AugmentationConfigType)
|
||||
: undefined,
|
||||
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
|
||||
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
|
||||
<p className="text-sm text-warm-text-muted mb-4">
|
||||
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
|
||||
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
|
||||
</p>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
|
||||
{/* Base Model Selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
|
||||
<select
|
||||
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
|
||||
onChange={e => {
|
||||
if (e.target.value === 'pretrained') {
|
||||
setBaseModelType('pretrained')
|
||||
setBaseModelVersionId(null)
|
||||
} else {
|
||||
setBaseModelType('existing')
|
||||
setBaseModelVersionId(e.target.value)
|
||||
}
|
||||
}}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
<option value="pretrained">yolo11n.pt (Pretrained)</option>
|
||||
{completedModels.map(m => (
|
||||
<option key={m.task_id} value={m.task_id}>
|
||||
{m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
{baseModelType === 'pretrained'
|
||||
? 'Start from pretrained YOLO model'
|
||||
: 'Continue training from an existing model (incremental training)'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-4">
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
|
||||
<input
|
||||
id="train-epochs"
|
||||
type="number"
|
||||
min={1}
|
||||
max={1000}
|
||||
value={epochs}
|
||||
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
|
||||
<select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
|
||||
<option>LayoutLMv3 (Standard)</option>
|
||||
<option>Donut (Beta)</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="flex justify-between mb-2">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
|
||||
<span className="text-xs font-mono text-warm-text-muted">{split}% / {100-split}%</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min="50"
|
||||
max="95"
|
||||
value={split}
|
||||
onChange={(e) => setSplit(parseInt(e.target.value))}
|
||||
className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
|
||||
<input
|
||||
id="train-batch-size"
|
||||
type="number"
|
||||
min={1}
|
||||
max={128}
|
||||
value={batchSize}
|
||||
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Augmentation Configuration */}
|
||||
<AugmentationConfig
|
||||
enabled={augmentationEnabled}
|
||||
onEnabledChange={setAugmentationEnabled}
|
||||
config={augmentationConfig}
|
||||
onConfigChange={setAugmentationConfig}
|
||||
/>
|
||||
|
||||
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
|
||||
{augmentationEnabled && (
|
||||
<div>
|
||||
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
|
||||
Augmentation Multiplier
|
||||
</label>
|
||||
<input
|
||||
id="aug-multiplier"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10}
|
||||
value={augmentationMultiplier}
|
||||
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Number of augmented copies per original image (1-10)
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
|
||||
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
|
||||
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Dataset List ---
|
||||
|
||||
const DatasetList: React.FC<{
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
onSwitchTab: (tab: Tab) => void
|
||||
}> = ({ onNavigate, onSwitchTab }) => {
|
||||
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
|
||||
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
|
||||
|
||||
const handleTrain = (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => {
|
||||
if (!trainTarget) return
|
||||
// Pass config to the training API
|
||||
const trainRequest = {
|
||||
name: config.name,
|
||||
config: config.config,
|
||||
}
|
||||
trainFromDataset(
|
||||
{ datasetId: trainTarget.dataset_id, req: trainRequest },
|
||||
{ onSuccess: () => setTrainTarget(null) },
|
||||
)
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
|
||||
}
|
||||
|
||||
if (datasets.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
|
||||
<Database size={48} className="mb-4 opacity-40" />
|
||||
<p className="text-lg mb-2">No datasets yet</p>
|
||||
<p className="text-sm mb-4">Create a dataset to start training</p>
|
||||
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{datasets.map(ds => (
|
||||
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
|
||||
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
|
||||
<td className="py-3 px-4">
|
||||
<div className="flex gap-1">
|
||||
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
|
||||
<Eye size={14} />
|
||||
</button>
|
||||
{ds.status === 'ready' && (
|
||||
<button title="Train" onClick={() => setTrainTarget(ds)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
|
||||
<Play size={14} />
|
||||
</button>
|
||||
)}
|
||||
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
|
||||
disabled={isDeleting}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
{trainTarget && (
|
||||
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Create Dataset ---
|
||||
|
||||
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
|
||||
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
|
||||
const { createDatasetAsync, isCreating } = useDatasets()
|
||||
|
||||
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [trainRatio, setTrainRatio] = useState(0.7)
|
||||
const [valRatio, setValRatio] = useState(0.2)
|
||||
|
||||
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
|
||||
|
||||
const toggleDoc = (id: string) => {
|
||||
setSelectedIds(prev => {
|
||||
const next = new Set(prev)
|
||||
if (next.has(id)) { next.delete(id) } else { next.add(id) }
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
const toggleAll = () => {
|
||||
if (selectedIds.size === documents.length) {
|
||||
setSelectedIds(new Set())
|
||||
} else {
|
||||
setSelectedIds(new Set(documents.map((d) => d.document_id)))
|
||||
}
|
||||
}
|
||||
|
||||
const handleCreate = async () => {
|
||||
await createDatasetAsync({
|
||||
name,
|
||||
description: description || undefined,
|
||||
document_ids: [...selectedIds],
|
||||
train_ratio: trainRatio,
|
||||
val_ratio: valRatio,
|
||||
})
|
||||
onSwitchTab('datasets')
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-8">
|
||||
{/* Document selection */}
|
||||
<div className="flex-1 flex flex-col">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
|
||||
{isLoadingDocs ? (
|
||||
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
|
||||
) : (
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
|
||||
<div className="overflow-auto max-h-[calc(100vh-240px)]">
|
||||
<table className="w-full text-left">
|
||||
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
|
||||
<tr>
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
|
||||
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{documents.map((doc) => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
|
||||
onClick={() => toggleDoc(doc.document_id)}>
|
||||
<td className="py-3 pl-6 pr-4">
|
||||
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
|
||||
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
|
||||
</div>
|
||||
|
||||
{/* Config panel */}
|
||||
<div className="w-80">
|
||||
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
|
||||
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
|
||||
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
|
||||
<div className="flex gap-2 text-sm">
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Train</span>
|
||||
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Val</span>
|
||||
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Test</span>
|
||||
<input type="number" value={testRatio} readOnly
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="pt-4 border-t border-warm-border">
|
||||
<Button className="w-full h-12">Start Training</Button>
|
||||
{selectedIds.size > 0 && selectedIds.size < 10 && (
|
||||
<p className="text-xs text-warm-state-warning mb-2">
|
||||
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
|
||||
</p>
|
||||
)}
|
||||
<Button className="w-full h-11" onClick={handleCreate}
|
||||
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
|
||||
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
)
|
||||
}
|
||||
|
||||
// --- Main Training Component ---
|
||||
|
||||
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
const [activeTab, setActiveTab] = useState<Tab>('datasets')
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
|
||||
</div>
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="flex gap-1 mb-6 border-b border-warm-border">
|
||||
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
|
||||
<button key={key} onClick={() => setActiveTab(key)}
|
||||
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
|
||||
activeTab === key
|
||||
? 'border-warm-state-info text-warm-state-info'
|
||||
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
|
||||
}`}>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
|
||||
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ interface UploadModalProps {
|
||||
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
|
||||
const [groupKey, setGroupKey] = useState('')
|
||||
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
|
||||
const [errorMessage, setErrorMessage] = useState('')
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
@@ -61,10 +62,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
|
||||
// Upload files one by one
|
||||
for (const file of selectedFiles) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
uploadDocument(file, {
|
||||
onSuccess: () => resolve(),
|
||||
onError: (error: Error) => reject(error),
|
||||
})
|
||||
uploadDocument(
|
||||
{ file, groupKey: groupKey || undefined },
|
||||
{
|
||||
onSuccess: () => resolve(),
|
||||
onError: (error: Error) => reject(error),
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -72,6 +76,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
|
||||
setTimeout(() => {
|
||||
onClose()
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setUploadStatus('idle')
|
||||
}, 1500)
|
||||
} catch (error) {
|
||||
@@ -85,6 +90,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
|
||||
return // Prevent closing during upload
|
||||
}
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
onClose()
|
||||
@@ -173,6 +179,26 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Group Key Input */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Group Key (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={groupKey}
|
||||
onChange={(e) => setGroupKey(e.target.value)}
|
||||
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
|
||||
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Use group keys to organize documents into logical groups
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Messages */}
|
||||
{uploadStatus === 'success' && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
|
||||
|
||||
@@ -2,3 +2,6 @@ export { useDocuments } from './useDocuments'
|
||||
export { useDocumentDetail } from './useDocumentDetail'
|
||||
export { useAnnotations } from './useAnnotations'
|
||||
export { useTraining, useTrainingDocuments } from './useTraining'
|
||||
export { useDatasets, useDatasetDetail } from './useDatasets'
|
||||
export { useAugmentation } from './useAugmentation'
|
||||
export { useModels, useModelDetail, useActiveModel } from './useModels'
|
||||
|
||||
226
frontend/src/hooks/useAugmentation.test.tsx
Normal file
226
frontend/src/hooks/useAugmentation.test.tsx
Normal file
@@ -0,0 +1,226 @@
|
||||
/**
|
||||
* Tests for useAugmentation hook.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { renderHook, waitFor } from '@testing-library/react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||
import { useAugmentation } from './useAugmentation'
|
||||
import type { ReactNode } from 'react'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('../api/endpoints/augmentation', () => ({
|
||||
augmentationApi: {
|
||||
getTypes: vi.fn(),
|
||||
getPresets: vi.fn(),
|
||||
preview: vi.fn(),
|
||||
previewConfig: vi.fn(),
|
||||
createBatch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Test wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('useAugmentation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getTypes', () => {
|
||||
it('should fetch augmentation types', async () => {
|
||||
const mockTypes = {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
{
|
||||
name: 'perspective_warp',
|
||||
description: 'Applies perspective warp',
|
||||
affects_geometry: true,
|
||||
stage: 'geometric',
|
||||
default_params: { max_warp: 0.02 },
|
||||
},
|
||||
],
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.augmentationTypes).toHaveLength(2)
|
||||
expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
|
||||
})
|
||||
|
||||
it('should handle error when fetching types', async () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.typesError).toBeTruthy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPresets', () => {
|
||||
it('should fetch augmentation presets', async () => {
|
||||
const mockPresets = {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations' },
|
||||
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||
{ name: 'aggressive', description: 'Strong augmentations' },
|
||||
],
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingPresets).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.presets).toHaveLength(3)
|
||||
expect(result.current.presets[0].name).toBe('conservative')
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview', () => {
|
||||
it('should preview single augmentation', async () => {
|
||||
const mockPreview = {
|
||||
preview_url: 'data:image/png;base64,xxx',
|
||||
original_url: 'data:image/png;base64,yyy',
|
||||
applied_params: { std: 15 },
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
// Call preview mutation
|
||||
result.current.preview({
|
||||
documentId: 'doc-123',
|
||||
augmentationType: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
page: 1,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(augmentationApi.preview).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
{ augmentation_type: 'gaussian_noise', params: { std: 15 } },
|
||||
1
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should track preview loading state', async () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.preview).mockImplementation(
|
||||
() => new Promise((resolve) => setTimeout(resolve, 100))
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.isPreviewing).toBe(false)
|
||||
|
||||
result.current.preview({
|
||||
documentId: 'doc-123',
|
||||
augmentationType: 'gaussian_noise',
|
||||
params: {},
|
||||
page: 1,
|
||||
})
|
||||
|
||||
// State update happens asynchronously
|
||||
await waitFor(() => {
|
||||
expect(result.current.isPreviewing).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('createBatch', () => {
|
||||
it('should create augmented dataset', async () => {
|
||||
const mockResponse = {
|
||||
task_id: 'task-123',
|
||||
status: 'pending',
|
||||
message: 'Augmentation task queued',
|
||||
estimated_images: 100,
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
result.current.createBatch({
|
||||
dataset_id: 'dataset-123',
|
||||
config: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
|
||||
},
|
||||
output_name: 'augmented-dataset',
|
||||
multiplier: 2,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(augmentationApi.createBatch).toHaveBeenCalledWith({
|
||||
dataset_id: 'dataset-123',
|
||||
config: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
|
||||
},
|
||||
output_name: 'augmented-dataset',
|
||||
multiplier: 2,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
121
frontend/src/hooks/useAugmentation.ts
Normal file
121
frontend/src/hooks/useAugmentation.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Hook for managing augmentation operations.
|
||||
*
|
||||
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||
*/
|
||||
|
||||
import { useQuery, useMutation } from '@tanstack/react-query'
|
||||
import {
|
||||
augmentationApi,
|
||||
type AugmentationTypesResponse,
|
||||
type PresetsResponse,
|
||||
type PreviewResponse,
|
||||
type BatchRequest,
|
||||
type BatchResponse,
|
||||
type AugmentationConfig,
|
||||
} from '../api/endpoints/augmentation'
|
||||
|
||||
interface PreviewParams {
|
||||
documentId: string
|
||||
augmentationType: string
|
||||
params: Record<string, unknown>
|
||||
page?: number
|
||||
}
|
||||
|
||||
interface PreviewConfigParams {
|
||||
documentId: string
|
||||
config: AugmentationConfig
|
||||
page?: number
|
||||
}
|
||||
|
||||
export const useAugmentation = () => {
|
||||
// Fetch augmentation types
|
||||
const {
|
||||
data: typesData,
|
||||
isLoading: isLoadingTypes,
|
||||
error: typesError,
|
||||
} = useQuery<AugmentationTypesResponse>({
|
||||
queryKey: ['augmentation', 'types'],
|
||||
queryFn: () => augmentationApi.getTypes(),
|
||||
staleTime: 5 * 60 * 1000, // Cache for 5 minutes
|
||||
})
|
||||
|
||||
// Fetch presets
|
||||
const {
|
||||
data: presetsData,
|
||||
isLoading: isLoadingPresets,
|
||||
error: presetsError,
|
||||
} = useQuery<PresetsResponse>({
|
||||
queryKey: ['augmentation', 'presets'],
|
||||
queryFn: () => augmentationApi.getPresets(),
|
||||
staleTime: 5 * 60 * 1000,
|
||||
})
|
||||
|
||||
// Preview single augmentation mutation
|
||||
const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
|
||||
mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
|
||||
augmentationApi.preview(
|
||||
documentId,
|
||||
{ augmentation_type: augmentationType, params },
|
||||
page
|
||||
),
|
||||
onError: (error) => {
|
||||
console.error('Preview augmentation failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
// Preview full config mutation
|
||||
const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
|
||||
mutationFn: ({ documentId, config, page = 1 }) =>
|
||||
augmentationApi.previewConfig(documentId, config, page),
|
||||
onError: (error) => {
|
||||
console.error('Preview config failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
// Create augmented dataset mutation
|
||||
const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
|
||||
mutationFn: (request) => augmentationApi.createBatch(request),
|
||||
onError: (error) => {
|
||||
console.error('Create augmented dataset failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
// Types data
|
||||
augmentationTypes: typesData?.augmentation_types || [],
|
||||
isLoadingTypes,
|
||||
typesError,
|
||||
|
||||
// Presets data
|
||||
presets: presetsData?.presets || [],
|
||||
isLoadingPresets,
|
||||
presetsError,
|
||||
|
||||
// Preview single augmentation
|
||||
preview: previewMutation.mutate,
|
||||
previewAsync: previewMutation.mutateAsync,
|
||||
isPreviewing: previewMutation.isPending,
|
||||
previewData: previewMutation.data,
|
||||
previewError: previewMutation.error,
|
||||
|
||||
// Preview full config
|
||||
previewConfig: previewConfigMutation.mutate,
|
||||
previewConfigAsync: previewConfigMutation.mutateAsync,
|
||||
isPreviewingConfig: previewConfigMutation.isPending,
|
||||
previewConfigData: previewConfigMutation.data,
|
||||
previewConfigError: previewConfigMutation.error,
|
||||
|
||||
// Create batch
|
||||
createBatch: createBatchMutation.mutate,
|
||||
createBatchAsync: createBatchMutation.mutateAsync,
|
||||
isCreatingBatch: createBatchMutation.isPending,
|
||||
batchData: createBatchMutation.data,
|
||||
batchError: createBatchMutation.error,
|
||||
|
||||
// Reset functions for clearing stale mutation state
|
||||
resetPreview: previewMutation.reset,
|
||||
resetPreviewConfig: previewConfigMutation.reset,
|
||||
resetBatch: createBatchMutation.reset,
|
||||
}
|
||||
}
|
||||
84
frontend/src/hooks/useDatasets.ts
Normal file
84
frontend/src/hooks/useDatasets.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { datasetsApi } from '../api/endpoints'
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetListResponse,
|
||||
DatasetTrainRequest,
|
||||
} from '../api/types'
|
||||
|
||||
export const useDatasets = (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
|
||||
queryKey: ['datasets', params],
|
||||
queryFn: () => datasetsApi.list(params),
|
||||
staleTime: 30000,
|
||||
// Poll every 5 seconds when there's an active training task
|
||||
refetchInterval: (query) => {
|
||||
const datasets = query.state.data?.datasets ?? []
|
||||
const hasActiveTraining = datasets.some(
|
||||
d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
|
||||
)
|
||||
return hasActiveTraining ? 5000 : false
|
||||
},
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
},
|
||||
})
|
||||
|
||||
const trainMutation = useMutation({
|
||||
mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
|
||||
datasetsApi.trainFromDataset(datasetId, req),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
datasets: data?.datasets ?? [],
|
||||
total: data?.total ?? 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
createDataset: createMutation.mutate,
|
||||
createDatasetAsync: createMutation.mutateAsync,
|
||||
isCreating: createMutation.isPending,
|
||||
deleteDataset: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
trainFromDataset: trainMutation.mutate,
|
||||
trainFromDatasetAsync: trainMutation.mutateAsync,
|
||||
isTraining: trainMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useDatasetDetail = (datasetId: string | null) => {
|
||||
const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
|
||||
queryKey: ['datasets', datasetId],
|
||||
queryFn: () => datasetsApi.getDetail(datasetId!),
|
||||
enabled: !!datasetId,
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
dataset: data ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,16 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
|
||||
})
|
||||
|
||||
const uploadMutation = useMutation({
|
||||
mutationFn: (file: File) => documentsApi.upload(file),
|
||||
mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
|
||||
documentsApi.upload(file, groupKey),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateGroupKeyMutation = useMutation({
|
||||
mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
|
||||
documentsApi.updateGroupKey(documentId, groupKey),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
@@ -74,5 +83,8 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
|
||||
isUpdatingStatus: updateStatusMutation.isPending,
|
||||
triggerAutoLabel: triggerAutoLabelMutation.mutate,
|
||||
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
|
||||
updateGroupKey: updateGroupKeyMutation.mutate,
|
||||
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
|
||||
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
98
frontend/src/hooks/useModels.ts
Normal file
98
frontend/src/hooks/useModels.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { modelsApi } from '../api/endpoints'
|
||||
import type {
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ActiveModelResponse,
|
||||
} from '../api/types'
|
||||
|
||||
export const useModels = (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
|
||||
queryKey: ['models', params],
|
||||
queryFn: () => modelsApi.list(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
const activateMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.activate(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deactivateMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
|
||||
},
|
||||
})
|
||||
|
||||
const archiveMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.archive(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.delete(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
models: data?.models ?? [],
|
||||
total: data?.total ?? 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
activateModel: activateMutation.mutate,
|
||||
activateModelAsync: activateMutation.mutateAsync,
|
||||
isActivating: activateMutation.isPending,
|
||||
deactivateModel: deactivateMutation.mutate,
|
||||
isDeactivating: deactivateMutation.isPending,
|
||||
archiveModel: archiveMutation.mutate,
|
||||
isArchiving: archiveMutation.isPending,
|
||||
deleteModel: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useModelDetail = (versionId: string | null) => {
|
||||
const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
|
||||
queryKey: ['models', versionId],
|
||||
queryFn: () => modelsApi.getDetail(versionId!),
|
||||
enabled: !!versionId,
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
model: data ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
export const useActiveModel = () => {
|
||||
const { data, isLoading, error } = useQuery<ActiveModelResponse>({
|
||||
queryKey: ['models', 'active'],
|
||||
queryFn: () => modelsApi.getActive(),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
hasActiveModel: data?.has_active_model ?? false,
|
||||
activeModel: data?.model ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user