This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -4,6 +4,7 @@ import { DashboardOverview } from './components/DashboardOverview'
import { Dashboard } from './components/Dashboard'
import { DocumentDetail } from './components/DocumentDetail'
import { Training } from './components/Training'
import { DatasetDetail } from './components/DatasetDetail'
import { Models } from './components/Models'
import { Login } from './components/Login'
import { InferenceDemo } from './components/InferenceDemo'
@@ -55,7 +56,14 @@ const App: React.FC = () => {
case 'demo':
return <InferenceDemo />
case 'training':
return <Training />
return <Training onNavigate={handleNavigate} />
case 'dataset-detail':
return (
<DatasetDetail
datasetId={selectedDocId || ''}
onBack={() => setCurrentView('training')}
/>
)
case 'models':
return <Models />
default:

View File

@@ -0,0 +1,118 @@
/**
* Tests for augmentation API endpoints.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { augmentationApi } from './augmentation'
import apiClient from '../client'
// Mock the API client
vi.mock('../client', () => ({
default: {
get: vi.fn(),
post: vi.fn(),
},
}))
describe('augmentationApi', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockResponse = {
data: {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getTypes()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
expect(result.augmentation_types).toHaveLength(1)
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockResponse = {
data: {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getPresets()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
expect(result.presets).toHaveLength(2)
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockResponse = {
data: {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: { std: 15 },
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.preview('doc-123', {
augmentation_type: 'gaussian_noise',
params: { std: 15 },
})
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
{
augmentation_type: 'gaussian_noise',
params: { std: 15 },
},
{ params: { page: 1 } }
)
expect(result.preview_url).toBe('data:image/png;base64,xxx')
})
it('should support custom page number', async () => {
const mockResponse = {
data: {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: {},
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
await augmentationApi.preview(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: {} },
2
)
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
expect.anything(),
{ params: { page: 2 } }
)
})
})
})

View File

@@ -0,0 +1,144 @@
/**
* Augmentation API endpoints.
*
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
*/
import apiClient from '../client'
// Types
export interface AugmentationTypeInfo {
name: string
description: string
affects_geometry: boolean
stage: string
default_params: Record<string, unknown>
}
export interface AugmentationTypesResponse {
augmentation_types: AugmentationTypeInfo[]
}
export interface PresetInfo {
name: string
description: string
config?: Record<string, unknown>
}
export interface PresetsResponse {
presets: PresetInfo[]
}
export interface PreviewRequest {
augmentation_type: string
params: Record<string, unknown>
}
export interface PreviewResponse {
preview_url: string
original_url: string
applied_params: Record<string, unknown>
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationConfig {
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
gaussian_noise?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface BatchRequest {
dataset_id: string
config: AugmentationConfig
output_name: string
multiplier: number
}
export interface BatchResponse {
task_id: string
status: string
message: string
estimated_images: number
}
// API functions
export const augmentationApi = {
/**
* Fetch available augmentation types.
*/
async getTypes(): Promise<AugmentationTypesResponse> {
const response = await apiClient.get<AugmentationTypesResponse>(
'/api/v1/admin/augmentation/types'
)
return response.data
},
/**
* Fetch augmentation presets.
*/
async getPresets(): Promise<PresetsResponse> {
const response = await apiClient.get<PresetsResponse>(
'/api/v1/admin/augmentation/presets'
)
return response.data
},
/**
* Preview a single augmentation on a document page.
*/
async preview(
documentId: string,
request: PreviewRequest,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview/${documentId}`,
request,
{ params: { page } }
)
return response.data
},
/**
* Preview full augmentation config on a document page.
*/
async previewConfig(
documentId: string,
config: AugmentationConfig,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview-config/${documentId}`,
config,
{ params: { page } }
)
return response.data
},
/**
* Create an augmented dataset.
*/
async createBatch(request: BatchRequest): Promise<BatchResponse> {
const response = await apiClient.post<BatchResponse>(
'/api/v1/admin/augmentation/batch',
request
)
return response.data
},
}

View File

@@ -0,0 +1,52 @@
import apiClient from '../client'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingTaskResponse,
} from '../types'
export const datasetsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<DatasetListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
params,
})
return data
},
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
return data
},
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
const { data } = await apiClient.get(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
remove: async (datasetId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
trainFromDataset: async (
datasetId: string,
req: DatasetTrainRequest
): Promise<TrainingTaskResponse> => {
const { data } = await apiClient.post(
`/api/v1/admin/training/datasets/${datasetId}/train`,
req
)
return data
},
}

View File

@@ -21,14 +21,20 @@ export const documentsApi = {
return data
},
upload: async (file: File): Promise<UploadDocumentResponse> => {
upload: async (file: File, groupKey?: string): Promise<UploadDocumentResponse> => {
const formData = new FormData()
formData.append('file', file)
const params: Record<string, string> = {}
if (groupKey) {
params.group_key = groupKey
}
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
params,
})
return data
},
@@ -77,4 +83,16 @@ export const documentsApi = {
)
return data
},
updateGroupKey: async (
documentId: string,
groupKey: string | null
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/group-key`,
null,
{ params: { group_key: groupKey } }
)
return data
},
}

View File

@@ -2,3 +2,6 @@ export { documentsApi } from './documents'
export { annotationsApi } from './annotations'
export { trainingApi } from './training'
export { inferenceApi } from './inference'
export { datasetsApi } from './datasets'
export { augmentationApi } from './augmentation'
export { modelsApi } from './models'

View File

@@ -0,0 +1,55 @@
import apiClient from '../client'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
} from '../types'
export const modelsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<ModelVersionListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models', {
params,
})
return data
},
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
return data
},
getActive: async (): Promise<ActiveModelResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
return data
},
activate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
return data
},
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
return data
},
archive: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
return data
},
delete: async (versionId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
return data
},
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
return data
},
}

View File

@@ -8,6 +8,7 @@ export interface DocumentItem {
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
group_key: string | null
created_at: string
updated_at: string
annotation_count?: number
@@ -59,6 +60,7 @@ export interface DocumentDetailResponse {
auto_label_error: string | null
upload_source: string
batch_id: string | null
group_key: string | null
csv_field_values: Record<string, string> | null
can_annotate: boolean
annotation_lock_until: string | null
@@ -113,7 +115,11 @@ export interface ErrorResponse {
export interface UploadDocumentResponse {
document_id: string
filename: string
file_size: number
page_count: number
status: string
group_key: string | null
auto_label_started: boolean
message: string
}
@@ -171,3 +177,165 @@ export interface InferenceResult {
export interface InferenceResponse {
result: InferenceResult
}
// Dataset types
export interface DatasetCreateRequest {
name: string
description?: string
document_ids: string[]
train_ratio?: number
val_ratio?: number
seed?: number
}
export interface DatasetResponse {
dataset_id: string
name: string
status: string
message: string
}
export interface DatasetDocumentItem {
document_id: string
split: string
page_count: number
annotation_count: number
}
export interface DatasetListItem {
dataset_id: string
name: string
description: string | null
status: string
training_status: string | null
active_training_task_id: string | null
total_documents: number
total_images: number
total_annotations: number
created_at: string
}
export interface DatasetListResponse {
total: number
limit: number
offset: number
datasets: DatasetListItem[]
}
export interface DatasetDetailResponse {
dataset_id: string
name: string
description: string | null
status: string
train_ratio: number
val_ratio: number
seed: number
total_documents: number
total_images: number
total_annotations: number
dataset_path: string | null
error_message: string | null
documents: DatasetDocumentItem[]
created_at: string
updated_at: string
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationTrainingConfig {
gaussian_noise?: AugmentationParams
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface DatasetTrainRequest {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs?: number
batch_size?: number
image_size?: number
learning_rate?: number
device?: string
augmentation?: AugmentationTrainingConfig
augmentation_multiplier?: number
}
}
export interface TrainingTaskResponse {
task_id: string
status: string
message: string
}
// Model Version types
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface ModelVersionDetailResponse {
version_id: string
version: string
name: string
description: string | null
model_path: string
status: string
is_active: boolean
task_id: string | null
dataset_id: string | null
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
document_count: number
training_config: Record<string, unknown> | null
file_size: number | null
trained_at: string | null
activated_at: string | null
created_at: string
updated_at: string
}
export interface ModelVersionListResponse {
total: number
limit: number
offset: number
models: ModelVersionItem[]
}
export interface ModelVersionResponse {
version_id: string
status: string
message: string
}
export interface ActiveModelResponse {
has_active_model: boolean
model: ModelVersionItem | null
}

View File

@@ -0,0 +1,251 @@
/**
* Tests for AugmentationConfig component.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { AugmentationConfig } from './AugmentationConfig'
import { augmentationApi } from '../api/endpoints/augmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Default mock data
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise to simulate sensor noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective transformation',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
{
name: 'gaussian_blur',
description: 'Applies Gaussian blur',
affects_geometry: false,
stage: 'blur',
default_params: { kernel_size: 5 },
},
],
}
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
{ name: 'moderate', description: 'Balanced augmentation settings' },
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
],
}
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('AugmentationConfig', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
})
describe('rendering', () => {
it('should render enable checkbox', async () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
})
it('should be collapsed when disabled', () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
// Config options should not be visible
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
})
it('should expand when enabled', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/preset/i)).toBeInTheDocument()
})
})
})
describe('preset selection', () => {
it('should display available presets', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('conservative')).toBeInTheDocument()
expect(screen.getByText('moderate')).toBeInTheDocument()
expect(screen.getByText('aggressive')).toBeInTheDocument()
})
})
it('should call onConfigChange when preset is selected', async () => {
const user = userEvent.setup()
const onConfigChange = vi.fn()
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={onConfigChange}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('moderate')).toBeInTheDocument()
})
await user.click(screen.getByText('moderate'))
expect(onConfigChange).toHaveBeenCalled()
})
})
describe('enable toggle', () => {
it('should call onEnabledChange when checkbox is toggled', async () => {
const user = userEvent.setup()
const onEnabledChange = vi.fn()
render(
<AugmentationConfig
enabled={false}
onEnabledChange={onEnabledChange}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
expect(onEnabledChange).toHaveBeenCalledWith(true)
})
})
describe('augmentation types', () => {
it('should display augmentation types when in custom mode', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
})
})
it('should indicate which augmentations affect geometry', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
// perspective_warp affects geometry
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
})
})
})
describe('loading state', () => {
it('should show loading indicator while fetching types', () => {
vi.mocked(augmentationApi.getTypes).mockImplementation(
() => new Promise(() => {})
)
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,136 @@
/**
* AugmentationConfig component for configuring image augmentation during training.
*
* Provides preset selection and optional custom augmentation type configuration.
*/
import React from 'react'
import { Loader2, AlertTriangle } from 'lucide-react'
import { useAugmentation } from '../hooks/useAugmentation'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
interface AugmentationConfigProps {
enabled: boolean
onEnabledChange: (enabled: boolean) => void
config: Partial<AugmentationConfigType>
onConfigChange: (config: Partial<AugmentationConfigType>) => void
showCustomOptions?: boolean
}
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
enabled,
onEnabledChange,
config,
onConfigChange,
showCustomOptions = false,
}) => {
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
const isLoading = isLoadingTypes || isLoadingPresets
const handlePresetSelect = (presetName: string) => {
const preset = presets.find((p) => p.name === presetName)
if (preset && preset.config) {
onConfigChange(preset.config as Partial<AugmentationConfigType>)
} else {
// Apply a basic config based on preset name
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
conservative: {
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
},
moderate: {
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
lighting_variation: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
},
aggressive: {
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
motion_blur: { enabled: true, probability: 0.3, params: {} },
lighting_variation: { enabled: true, probability: 0.5, params: {} },
shadow: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
wrinkle: { enabled: true, probability: 0.2, params: {} },
stain: { enabled: true, probability: 0.2, params: {} },
},
}
onConfigChange(presetConfigs[presetName] || {})
}
}
return (
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
{/* Enable checkbox */}
<label className="flex items-center gap-2 cursor-pointer">
<input
type="checkbox"
checked={enabled}
onChange={(e) => onEnabledChange(e.target.checked)}
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
aria-label="Enable augmentation"
/>
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
</label>
{/* Expanded content when enabled */}
{enabled && (
<div className="mt-4 space-y-4">
{isLoading ? (
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
</div>
) : (
<>
{/* Preset selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
<div className="flex flex-wrap gap-2">
{presets.map((preset) => (
<button
key={preset.name}
onClick={() => handlePresetSelect(preset.name)}
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
title={preset.description}
>
{preset.name}
</button>
))}
</div>
</div>
{/* Custom options (if enabled) */}
{showCustomOptions && (
<div className="border-t border-warm-divider pt-4">
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
<div className="grid gap-2">
{augmentationTypes.map((type) => (
<div
key={type.name}
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
>
<div className="flex items-center gap-2">
<span className="text-sm text-warm-text-primary">{type.name}</span>
{type.affects_geometry && (
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
<AlertTriangle size={12} />
affects bbox
</span>
)}
</div>
<span className="text-xs text-warm-text-muted">{type.stage}</span>
</div>
))}
</div>
</div>
)}
</>
)}
</div>
)}
</div>
)
}

View File

@@ -144,6 +144,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Annotations
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Group
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
Auto-label
</th>
@@ -153,13 +156,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<tbody>
{isLoading ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
Loading documents...
</td>
</tr>
) : documents.length === 0 ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
<td colSpan={8} className="py-8 text-center text-warm-text-muted">
No documents found. Upload your first document to get started.
</td>
</tr>
@@ -213,6 +216,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<td className="py-4 px-4 text-sm text-warm-text-secondary">
{doc.annotation_count || 0} annotations
</td>
<td className="py-4 px-4 text-sm text-warm-text-muted">
{doc.group_key || '-'}
</td>
<td className="py-4 px-4">
{doc.auto_label_status === 'running' && progress && (
<div className="w-full">

View File

@@ -0,0 +1,122 @@
import React from 'react'
import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
import { Button } from './Button'
import { useDatasetDetail } from '../hooks/useDatasets'
interface DatasetDetailProps {
datasetId: string
onBack: () => void
}
const SPLIT_STYLES: Record<string, string> = {
train: 'bg-warm-state-info/10 text-warm-state-info',
val: 'bg-warm-state-warning/10 text-warm-state-warning',
test: 'bg-warm-state-success/10 text-warm-state-success',
}
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
if (isLoading) {
return (
<div className="flex items-center justify-center py-20 text-warm-text-muted">
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
</div>
)
}
if (error || !dataset) {
return (
<div className="p-8 max-w-7xl mx-auto">
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back
</button>
<p className="text-warm-state-error">Failed to load dataset.</p>
</div>
)
}
const statusIcon = dataset.status === 'ready'
? <Check size={14} className="text-warm-state-success" />
: dataset.status === 'failed'
? <AlertCircle size={14} className="text-warm-state-error" />
: <Loader2 size={14} className="animate-spin text-warm-state-info" />
return (
<div className="p-8 max-w-7xl mx-auto">
{/* Header */}
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back to Datasets
</button>
<div className="flex items-center justify-between mb-6">
<div>
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
{dataset.name} {statusIcon}
</h2>
{dataset.description && (
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
)}
</div>
{dataset.status === 'ready' && (
<Button><Play size={14} className="mr-1" />Start Training</Button>
)}
</div>
{dataset.error_message && (
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
{dataset.error_message}
</div>
)}
{/* Stats */}
<div className="grid grid-cols-4 gap-4 mb-8">
{[
['Documents', dataset.total_documents],
['Images', dataset.total_images],
['Annotations', dataset.total_annotations],
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
].map(([label, value]) => (
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
</div>
))}
</div>
{/* Document list */}
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{dataset.documents.map(doc => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4">
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
{doc.split}
</span>
</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
</tr>
))}
</tbody>
</table>
</div>
<p className="text-xs text-warm-text-muted mt-4">
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
</p>
</div>
)
}

View File

@@ -1,8 +1,9 @@
import React, { useState, useRef, useEffect } from 'react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
import { useDocuments } from '../hooks/useDocuments'
import { documentsApi } from '../api/endpoints/documents'
import type { AnnotationItem } from '../api/types'
@@ -26,7 +27,7 @@ const FIELD_CLASSES: Record<number, string> = {
}
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
const { document, annotations, isLoading } = useDocumentDetail(docId)
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
const {
createAnnotation,
updateAnnotation,
@@ -34,10 +35,13 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
isCreating,
isDeleting,
} = useAnnotations(docId)
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
const [selectedId, setSelectedId] = useState<string | null>(null)
const [zoom, setZoom] = useState(100)
const [isDrawing, setIsDrawing] = useState(false)
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
const [selectedClassId, setSelectedClassId] = useState<number>(0)
@@ -426,6 +430,65 @@ export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack })
{new Date(document.created_at).toLocaleDateString()}
</span>
</div>
<div className="flex justify-between items-center text-xs">
<span className="text-warm-text-muted">Group</span>
{isEditingGroupKey ? (
<div className="flex items-center gap-1">
<input
type="text"
value={editGroupKeyValue}
onChange={(e) => setEditGroupKeyValue(e.target.value)}
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
placeholder="group key"
autoFocus
/>
<button
onClick={() => {
updateGroupKey(
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
{
onSuccess: () => {
setIsEditingGroupKey(false)
refetch()
},
onError: () => {
alert('Failed to update group key. Please try again.')
},
}
)
}}
disabled={isUpdatingGroupKey}
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
>
<Check size={14} />
</button>
<button
onClick={() => {
setIsEditingGroupKey(false)
setEditGroupKeyValue(document.group_key || '')
}}
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
>
<X size={14} />
</button>
</div>
) : (
<div className="flex items-center gap-1">
<span className="text-warm-text-secondary font-medium">
{document.group_key || '-'}
</span>
<button
onClick={() => {
setEditGroupKeyValue(document.group_key || '')
setIsEditingGroupKey(true)
}}
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
>
<Edit2 size={12} />
</button>
</div>
)}
</div>
</div>
</div>
</div>

View File

@@ -1,73 +1,108 @@
import React from 'react';
import React, { useState } from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
import { Loader2, Power, CheckCircle } from 'lucide-react';
import { Button } from './Button';
import { useModels, useModelDetail } from '../hooks';
import type { ModelVersionItem } from '../api/types';
const CHART_DATA = [
{ name: 'Model A', value: 75 },
{ name: 'Model B', value: 82 },
{ name: 'Model C', value: 95 },
{ name: 'Model D', value: 68 },
];
const METRICS_DATA = [
{ name: 'Precision', value: 88 },
{ name: 'Recall', value: 76 },
{ name: 'F1 Score', value: 91 },
{ name: 'Accuracy', value: 82 },
];
const JOBS = [
{ id: 1, name: 'Training Job Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
{ id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
{ id: 3, name: 'Model Training Compentr 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
];
const formatDate = (dateString: string | null): string => {
if (!dateString) return 'N/A';
return new Date(dateString).toLocaleString();
};
export const Models: React.FC = () => {
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
const { models, isLoading, activateModel, isActivating } = useModels();
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
// Build chart data from selected model's metrics
const metricsData = modelDetail ? [
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
] : [
{ name: 'Precision', value: 0 },
{ name: 'Recall', value: 0 },
{ name: 'mAP', value: 0 },
];
// Build comparison chart from all models (with placeholder if empty)
const chartData = models.length > 0
? models.slice(0, 4).map(m => ({
name: m.version,
value: (m.metrics_mAP ?? 0) * 100,
}))
: [
{ name: 'Model A', value: 0 },
{ name: 'Model B', value: 0 },
{ name: 'Model C', value: 0 },
{ name: 'Model D', value: 0 },
];
return (
<div className="p-8 max-w-7xl mx-auto flex gap-8">
{/* Left: Job History */}
<div className="flex-1">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
<div className="space-y-4">
{JOBS.map(job => (
<div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
<p className="text-sm text-warm-text-muted">Started {job.date}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
{job.status}
</span>
</div>
{job.status === 'Running' ? (
<div className="mt-4">
<div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
<div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
</div>
{isLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
</div>
) : models.length === 0 ? (
<div className="text-center py-12 text-warm-text-muted">
No model versions found. Complete a training task to create a model version.
</div>
) : (
<div className="space-y-4">
{models.map(model => (
<div
key={model.version_id}
onClick={() => setSelectedModel(model)}
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
selectedModel?.version_id === model.version_id
? 'border-warm-text-secondary'
: 'border-warm-border hover:border-warm-divider'
}`}
>
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
{model.name}
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
</h4>
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
model.is_active
? 'bg-warm-state-info/10 text-warm-state-info'
: 'bg-warm-selected text-warm-state-success'
}`}>
{model.is_active ? 'Active' : model.status}
</span>
</div>
) : (
<div className="mt-4 flex gap-8">
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
<span className="text-lg font-mono text-warm-text-secondary">
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
<span className="text-lg font-mono text-warm-text-secondary">100%</span>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
</div>
</div>
)}
</div>
))}
</div>
</div>
))}
</div>
)}
</div>
{/* Right: Model Detail */}
@@ -75,27 +110,34 @@ export const Models: React.FC = () => {
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
<div className="flex justify-between items-center mb-6">
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
<span className="text-sm font-medium text-warm-state-success">Completed</span>
<span className={`text-sm font-medium ${
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
}`}>
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
</span>
</div>
<div className="mb-8">
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
<p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
<p className="font-medium text-warm-text-primary">
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
</p>
</div>
<div className="space-y-8">
{/* Chart 1 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Bar Rate Metrics</h4>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={CHART_DATA}>
<BarChart data={chartData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" hide />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
<Tooltip
cursor={{fill: '#F1F0ED'}}
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
@@ -105,14 +147,17 @@ export const Models: React.FC = () => {
{/* Chart 2 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={METRICS_DATA}>
<BarChart data={metricsData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip cursor={{fill: '#F1F0ED'}} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
@@ -121,14 +166,43 @@ export const Models: React.FC = () => {
</div>
<div className="mt-8 space-y-3">
<Button className="w-full">Download Model</Button>
{selectedModel && !selectedModel.is_active ? (
<Button
className="w-full"
onClick={() => activateModel(selectedModel.version_id)}
disabled={isActivating}
>
{isActivating ? (
<>
<Loader2 size={16} className="mr-2 animate-spin" />
Activating...
</>
) : (
<>
<Power size={16} className="mr-2" />
Activate for Inference
</>
)}
</Button>
) : (
<Button className="w-full" disabled={!selectedModel}>
{selectedModel?.is_active ? (
<>
<CheckCircle size={16} className="mr-2" />
Currently Active
</>
) : (
'Select a Model'
)}
</Button>
)}
<div className="flex gap-3">
<Button variant="secondary" className="flex-1">View Logs</Button>
<Button variant="secondary" className="flex-1">Use as Base</Button>
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
</div>
</div>
</div>
</div>
</div>
);
};
};

View File

@@ -1,113 +1,482 @@
import React, { useState } from 'react';
import { Check, AlertCircle } from 'lucide-react';
import { Button } from './Button';
import { DocumentStatus } from '../types';
import React, { useState, useMemo } from 'react'
import { useQuery } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { AugmentationConfig } from './AugmentationConfig'
import { useDatasets } from '../hooks/useDatasets'
import { useTrainingDocuments } from '../hooks/useTraining'
import { trainingApi } from '../api/endpoints'
import type { DatasetListItem } from '../api/types'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
export const Training: React.FC = () => {
const [split, setSplit] = useState(80);
type Tab = 'datasets' | 'create'
const docs = [
{ id: '1', name: 'Document Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
{ id: '2', name: 'Document Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '3', name: 'Document Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '4', name: 'Document Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '5', name: 'Document Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '6', name: 'Document Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '8', name: 'Document Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
];
interface TrainingProps {
onNavigate?: (view: string, id?: string) => void
}
const STATUS_STYLES: Record<string, string> = {
ready: 'bg-warm-state-success/10 text-warm-state-success',
building: 'bg-warm-state-info/10 text-warm-state-info',
training: 'bg-warm-state-info/10 text-warm-state-info',
failed: 'bg-warm-state-error/10 text-warm-state-error',
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
running: 'bg-warm-state-info/10 text-warm-state-info',
}
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
// If there's an active training task, show training status
const displayStatus = trainingStatus === 'running'
? 'training'
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
? 'pending'
: status
return (
<div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
{/* Document Selection List */}
<div className="flex-1 flex flex-col">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>
<div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
<div className="overflow-auto flex-1">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
</tr>
</thead>
<tbody>
{docs.map(doc => (
<tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
<td className="py-3 px-4">
{doc.status === DocumentStatus.VERIFIED ? (
<div className="flex items-center text-warm-state-success text-sm font-medium">
<div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
<Check size={12} strokeWidth={3}/>
</div>
Verified
</div>
) : (
<div className="flex items-center text-warm-text-muted text-sm">
<div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
<span className="font-bold text-[10px]">!</span>
</div>
Partial
</div>
)}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
{displayStatus}
</span>
)
}
{/* Configuration Panel */}
<div className="w-96">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>
<div className="space-y-6">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
<input
type="text"
placeholder="e.g. Invoices Q4"
// --- Train Dialog ---
interface TrainDialogProps {
dataset: DatasetListItem
onClose: () => void
onSubmit: (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => void
isPending: boolean
}
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
const [name, setName] = useState(`train-${dataset.name}`)
const [epochs, setEpochs] = useState(100)
const [batchSize, setBatchSize] = useState(16)
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
// Fetch available trained models
const { data: modelsData } = useQuery({
queryKey: ['training', 'models', 'completed'],
queryFn: () => trainingApi.getModels({ status: 'completed' }),
})
const completedModels = modelsData?.models ?? []
const handleSubmit = () => {
onSubmit({
name,
config: {
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
epochs,
batch_size: batchSize,
augmentation: augmentationEnabled
? (augmentationConfig as AugmentationConfigType)
: undefined,
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
},
})
}
return (
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
<p className="text-sm text-warm-text-muted mb-4">
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
</p>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
{/* Base Model Selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
<select
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
onChange={e => {
if (e.target.value === 'pretrained') {
setBaseModelType('pretrained')
setBaseModelVersionId(null)
} else {
setBaseModelType('existing')
setBaseModelVersionId(e.target.value)
}
}}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
<option value="pretrained">yolo11n.pt (Pretrained)</option>
{completedModels.map(m => (
<option key={m.task_id} value={m.task_id}>
{m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
</option>
))}
</select>
<p className="text-xs text-warm-text-muted mt-1">
{baseModelType === 'pretrained'
? 'Start from pretrained YOLO model'
: 'Continue training from an existing model (incremental training)'}
</p>
</div>
<div className="flex gap-4">
<div className="flex-1">
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
<input
id="train-epochs"
type="number"
min={1}
max={1000}
value={epochs}
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
<select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
<option>LayoutLMv3 (Standard)</option>
<option>Donut (Beta)</option>
</select>
</div>
<div>
<div className="flex justify-between mb-2">
<label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
<span className="text-xs font-mono text-warm-text-muted">{split}% / {100-split}%</span>
</div>
<input
type="range"
min="50"
max="95"
value={split}
onChange={(e) => setSplit(parseInt(e.target.value))}
className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
<div className="flex-1">
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
<input
id="train-batch-size"
type="number"
min={1}
max={128}
value={batchSize}
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
</div>
{/* Augmentation Configuration */}
<AugmentationConfig
enabled={augmentationEnabled}
onEnabledChange={setAugmentationEnabled}
config={augmentationConfig}
onConfigChange={setAugmentationConfig}
/>
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
{augmentationEnabled && (
<div>
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
Augmentation Multiplier
</label>
<input
id="aug-multiplier"
type="number"
min={1}
max={10}
value={augmentationMultiplier}
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
<p className="text-xs text-warm-text-muted mt-1">
Number of augmented copies per original image (1-10)
</p>
</div>
)}
</div>
<div className="flex justify-end gap-3 mt-6">
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
</Button>
</div>
</div>
</div>
)
}
// --- Dataset List ---
const DatasetList: React.FC<{
onNavigate?: (view: string, id?: string) => void
onSwitchTab: (tab: Tab) => void
}> = ({ onNavigate, onSwitchTab }) => {
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
const handleTrain = (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => {
if (!trainTarget) return
// Pass config to the training API
const trainRequest = {
name: config.name,
config: config.config,
}
trainFromDataset(
{ datasetId: trainTarget.dataset_id, req: trainRequest },
{ onSuccess: () => setTrainTarget(null) },
)
}
if (isLoading) {
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
}
if (datasets.length === 0) {
return (
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
<Database size={48} className="mb-4 opacity-40" />
<p className="text-lg mb-2">No datasets yet</p>
<p className="text-sm mb-4">Create a dataset to start training</p>
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
</div>
)
}
return (
<>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
</tr>
</thead>
<tbody>
{datasets.map(ds => (
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
<td className="py-3 px-4">
<div className="flex gap-1">
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
<Eye size={14} />
</button>
{ds.status === 'ready' && (
<button title="Train" onClick={() => setTrainTarget(ds)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
<Play size={14} />
</button>
)}
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
disabled={isDeleting}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
<Trash2 size={14} />
</button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
{trainTarget && (
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
)}
</>
)
}
// --- Create Dataset ---
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
const { createDatasetAsync, isCreating } = useDatasets()
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
const [name, setName] = useState('')
const [description, setDescription] = useState('')
const [trainRatio, setTrainRatio] = useState(0.7)
const [valRatio, setValRatio] = useState(0.2)
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
const toggleDoc = (id: string) => {
setSelectedIds(prev => {
const next = new Set(prev)
if (next.has(id)) { next.delete(id) } else { next.add(id) }
return next
})
}
const toggleAll = () => {
if (selectedIds.size === documents.length) {
setSelectedIds(new Set())
} else {
setSelectedIds(new Set(documents.map((d) => d.document_id)))
}
}
const handleCreate = async () => {
await createDatasetAsync({
name,
description: description || undefined,
document_ids: [...selectedIds],
train_ratio: trainRatio,
val_ratio: valRatio,
})
onSwitchTab('datasets')
}
return (
<div className="flex gap-8">
{/* Document selection */}
<div className="flex-1 flex flex-col">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
{isLoadingDocs ? (
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
) : (
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
<div className="overflow-auto max-h-[calc(100vh-240px)]">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12">
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{documents.map((doc) => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
onClick={() => toggleDoc(doc.document_id)}>
<td className="py-3 pl-6 pr-4">
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
</td>
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
)}
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
</div>
{/* Config panel */}
<div className="w-80">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
<div className="flex gap-2 text-sm">
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Train</span>
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Val</span>
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Test</span>
<input type="number" value={testRatio} readOnly
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
</div>
</div>
</div>
<div className="pt-4 border-t border-warm-border">
<Button className="w-full h-12">Start Training</Button>
{selectedIds.size > 0 && selectedIds.size < 10 && (
<p className="text-xs text-warm-state-warning mb-2">
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
</p>
)}
<Button className="w-full h-11" onClick={handleCreate}
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
</Button>
</div>
</div>
</div>
</div>
</div>
);
};
)
}
// --- Main Training Component ---
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
const [activeTab, setActiveTab] = useState<Tab>('datasets')
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="flex items-center justify-between mb-6">
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
</div>
{/* Tabs */}
<div className="flex gap-1 mb-6 border-b border-warm-border">
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
<button key={key} onClick={() => setActiveTab(key)}
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
activeTab === key
? 'border-warm-state-info text-warm-state-info'
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
}`}>
{label}
</button>
))}
</div>
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
</div>
)
}

View File

@@ -11,6 +11,7 @@ interface UploadModalProps {
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
const [groupKey, setGroupKey] = useState('')
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef<HTMLInputElement>(null)
@@ -61,10 +62,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
// Upload files one by one
for (const file of selectedFiles) {
await new Promise<void>((resolve, reject) => {
uploadDocument(file, {
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
})
uploadDocument(
{ file, groupKey: groupKey || undefined },
{
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
}
)
})
}
@@ -72,6 +76,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
setTimeout(() => {
onClose()
setSelectedFiles([])
setGroupKey('')
setUploadStatus('idle')
}, 1500)
} catch (error) {
@@ -85,6 +90,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
return // Prevent closing during upload
}
setSelectedFiles([])
setGroupKey('')
setUploadStatus('idle')
setErrorMessage('')
onClose()
@@ -173,6 +179,26 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
</div>
)}
{/* Group Key Input */}
{selectedFiles.length > 0 && (
<div className="mb-6">
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
Group Key (optional)
</label>
<input
type="text"
value={groupKey}
onChange={(e) => setGroupKey(e.target.value)}
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
disabled={uploadStatus === 'uploading'}
/>
<p className="text-xs text-warm-text-muted mt-1">
Use group keys to organize documents into logical groups
</p>
</div>
)}
{/* Status Messages */}
{uploadStatus === 'success' && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">

View File

@@ -2,3 +2,6 @@ export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'
export { useDatasets, useDatasetDetail } from './useDatasets'
export { useAugmentation } from './useAugmentation'
export { useModels, useModelDetail, useActiveModel } from './useModels'

View File

@@ -0,0 +1,226 @@
/**
* Tests for useAugmentation hook.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { augmentationApi } from '../api/endpoints/augmentation'
import { useAugmentation } from './useAugmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('useAugmentation', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective warp',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.augmentationTypes).toHaveLength(2)
expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
})
it('should handle error when fetching types', async () => {
vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.typesError).toBeTruthy()
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
{ name: 'aggressive', description: 'Strong augmentations' },
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingPresets).toBe(false)
})
expect(result.current.presets).toHaveLength(3)
expect(result.current.presets[0].name).toBe('conservative')
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockPreview = {
preview_url: 'data:image/png;base64,xxx',
original_url: 'data:image/png;base64,yyy',
applied_params: { std: 15 },
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
// Call preview mutation
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: { std: 15 },
page: 1,
})
await waitFor(() => {
expect(augmentationApi.preview).toHaveBeenCalledWith(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: { std: 15 } },
1
)
})
})
it('should track preview loading state', async () => {
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockImplementation(
() => new Promise((resolve) => setTimeout(resolve, 100))
)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.isPreviewing).toBe(false)
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: {},
page: 1,
})
// State update happens asynchronously
await waitFor(() => {
expect(result.current.isPreviewing).toBe(true)
})
})
})
describe('createBatch', () => {
it('should create augmented dataset', async () => {
const mockResponse = {
task_id: 'task-123',
status: 'pending',
message: 'Augmentation task queued',
estimated_images: 100,
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
result.current.createBatch({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
await waitFor(() => {
expect(augmentationApi.createBatch).toHaveBeenCalledWith({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
})
})
})
})

View File

@@ -0,0 +1,121 @@
/**
* Hook for managing augmentation operations.
*
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
*/
import { useQuery, useMutation } from '@tanstack/react-query'
import {
augmentationApi,
type AugmentationTypesResponse,
type PresetsResponse,
type PreviewResponse,
type BatchRequest,
type BatchResponse,
type AugmentationConfig,
} from '../api/endpoints/augmentation'
interface PreviewParams {
documentId: string
augmentationType: string
params: Record<string, unknown>
page?: number
}
interface PreviewConfigParams {
documentId: string
config: AugmentationConfig
page?: number
}
export const useAugmentation = () => {
// Fetch augmentation types
const {
data: typesData,
isLoading: isLoadingTypes,
error: typesError,
} = useQuery<AugmentationTypesResponse>({
queryKey: ['augmentation', 'types'],
queryFn: () => augmentationApi.getTypes(),
staleTime: 5 * 60 * 1000, // Cache for 5 minutes
})
// Fetch presets
const {
data: presetsData,
isLoading: isLoadingPresets,
error: presetsError,
} = useQuery<PresetsResponse>({
queryKey: ['augmentation', 'presets'],
queryFn: () => augmentationApi.getPresets(),
staleTime: 5 * 60 * 1000,
})
// Preview single augmentation mutation
const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
augmentationApi.preview(
documentId,
{ augmentation_type: augmentationType, params },
page
),
onError: (error) => {
console.error('Preview augmentation failed:', error)
},
})
// Preview full config mutation
const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
mutationFn: ({ documentId, config, page = 1 }) =>
augmentationApi.previewConfig(documentId, config, page),
onError: (error) => {
console.error('Preview config failed:', error)
},
})
// Create augmented dataset mutation
const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
mutationFn: (request) => augmentationApi.createBatch(request),
onError: (error) => {
console.error('Create augmented dataset failed:', error)
},
})
return {
// Types data
augmentationTypes: typesData?.augmentation_types || [],
isLoadingTypes,
typesError,
// Presets data
presets: presetsData?.presets || [],
isLoadingPresets,
presetsError,
// Preview single augmentation
preview: previewMutation.mutate,
previewAsync: previewMutation.mutateAsync,
isPreviewing: previewMutation.isPending,
previewData: previewMutation.data,
previewError: previewMutation.error,
// Preview full config
previewConfig: previewConfigMutation.mutate,
previewConfigAsync: previewConfigMutation.mutateAsync,
isPreviewingConfig: previewConfigMutation.isPending,
previewConfigData: previewConfigMutation.data,
previewConfigError: previewConfigMutation.error,
// Create batch
createBatch: createBatchMutation.mutate,
createBatchAsync: createBatchMutation.mutateAsync,
isCreatingBatch: createBatchMutation.isPending,
batchData: createBatchMutation.data,
batchError: createBatchMutation.error,
// Reset functions for clearing stale mutation state
resetPreview: previewMutation.reset,
resetPreviewConfig: previewConfigMutation.reset,
resetBatch: createBatchMutation.reset,
}
}

View File

@@ -0,0 +1,84 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { datasetsApi } from '../api/endpoints'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetTrainRequest,
} from '../api/types'
export const useDatasets = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
queryKey: ['datasets', params],
queryFn: () => datasetsApi.list(params),
staleTime: 30000,
// Poll every 5 seconds when there's an active training task
refetchInterval: (query) => {
const datasets = query.state.data?.datasets ?? []
const hasActiveTraining = datasets.some(
d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
)
return hasActiveTraining ? 5000 : false
},
})
const createMutation = useMutation({
mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const deleteMutation = useMutation({
mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const trainMutation = useMutation({
mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
datasetsApi.trainFromDataset(datasetId, req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
return {
datasets: data?.datasets ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
createDataset: createMutation.mutate,
createDatasetAsync: createMutation.mutateAsync,
isCreating: createMutation.isPending,
deleteDataset: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
trainFromDataset: trainMutation.mutate,
trainFromDatasetAsync: trainMutation.mutateAsync,
isTraining: trainMutation.isPending,
}
}
export const useDatasetDetail = (datasetId: string | null) => {
const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
queryKey: ['datasets', datasetId],
queryFn: () => datasetsApi.getDetail(datasetId!),
enabled: !!datasetId,
staleTime: 30000,
})
return {
dataset: data ?? null,
isLoading,
error,
}
}

View File

@@ -18,7 +18,16 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
})
const uploadMutation = useMutation({
mutationFn: (file: File) => documentsApi.upload(file),
mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
documentsApi.upload(file, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const updateGroupKeyMutation = useMutation({
mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
documentsApi.updateGroupKey(documentId, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
@@ -74,5 +83,8 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
isUpdatingStatus: updateStatusMutation.isPending,
triggerAutoLabel: triggerAutoLabelMutation.mutate,
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
updateGroupKey: updateGroupKeyMutation.mutate,
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
}
}

View File

@@ -0,0 +1,98 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { modelsApi } from '../api/endpoints'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ActiveModelResponse,
} from '../api/types'
export const useModels = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
queryKey: ['models', params],
queryFn: () => modelsApi.list(params),
staleTime: 30000,
})
const activateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.activate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const deactivateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const archiveMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.archive(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
const deleteMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.delete(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
return {
models: data?.models ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
activateModel: activateMutation.mutate,
activateModelAsync: activateMutation.mutateAsync,
isActivating: activateMutation.isPending,
deactivateModel: deactivateMutation.mutate,
isDeactivating: deactivateMutation.isPending,
archiveModel: archiveMutation.mutate,
isArchiving: archiveMutation.isPending,
deleteModel: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
}
}
export const useModelDetail = (versionId: string | null) => {
const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
queryKey: ['models', versionId],
queryFn: () => modelsApi.getDetail(versionId!),
enabled: !!versionId,
staleTime: 30000,
})
return {
model: data ?? null,
isLoading,
error,
}
}
export const useActiveModel = () => {
const { data, isLoading, error } = useQuery<ActiveModelResponse>({
queryKey: ['models', 'active'],
queryFn: () => modelsApi.getActive(),
staleTime: 30000,
})
return {
hasActiveModel: data?.has_active_model ?? false,
activeModel: data?.model ?? null,
isLoading,
error,
}
}