650 lines
30 KiB
TypeScript
650 lines
30 KiB
TypeScript
import React, { useState, useMemo } from 'react'
|
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
|
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle, Shield, CheckCircle, XCircle } from 'lucide-react'
|
|
import { Button } from './Button'
|
|
import { AugmentationConfig } from './AugmentationConfig'
|
|
import { useDatasets } from '../hooks/useDatasets'
|
|
import { useTrainingDocuments } from '../hooks/useTraining'
|
|
import { trainingApi, poolApi } from '../api/endpoints'
|
|
import type { DatasetListItem, PoolEntryItem } from '../api/types'
|
|
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
|
|
|
/** Identifies which panel of the Training page is currently shown. */
type Tab =
  | 'datasets'
  | 'create'
  | 'pool'
|
|
|
|
/** Props accepted by the top-level Training page component. */
interface TrainingProps {
  /**
   * Optional navigation callback. Receives a view identifier and, where
   * relevant, the id of the entity to open (e.g. `'dataset-detail'` plus a
   * dataset id).
   */
  onNavigate?: (target: string, entityId?: string) => void
}
|
|
|
|
/**
 * Tailwind badge classes keyed by dataset/training status. Statuses with no
 * entry here yield `undefined`, so callers must supply their own fallback.
 */
const STATUS_STYLES: Record<string, string> = (() => {
  const success = 'bg-warm-state-success/10 text-warm-state-success'
  const info = 'bg-warm-state-info/10 text-warm-state-info'
  const danger = 'bg-warm-state-error/10 text-warm-state-error'
  const caution = 'bg-warm-state-warning/10 text-warm-state-warning'

  return {
    ready: success,
    building: info,
    training: info,
    failed: danger,
    pending: caution,
    scheduled: caution,
    running: info,
  }
})()
|
|
|
|
/**
 * Pill-shaped status badge for a dataset row.
 *
 * An active training task takes precedence over the dataset status:
 * `running` is shown as "training", and `pending`/`scheduled` are both shown
 * as "pending". Otherwise the dataset's own status is displayed. Unknown
 * statuses get a neutral style.
 */
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
  let label = status
  if (trainingStatus === 'running') {
    label = 'training'
  } else if (trainingStatus === 'pending' || trainingStatus === 'scheduled') {
    label = 'pending'
  }

  const palette = STATUS_STYLES[label] ?? 'bg-warm-border text-warm-text-muted'
  const inProgress = label === 'building' || label === 'training'

  return (
    <span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${palette}`}>
      {inProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
      {label === 'ready' && <Check size={12} className="mr-1" />}
      {label === 'failed' && <AlertCircle size={12} className="mr-1" />}
      {label}
    </span>
  )
}
|
|
|
|
// --- Train Dialog ---
|
|
|
|
/** Props for the modal that configures and launches a training run. */
interface TrainDialogProps {
  /** Dataset the training run will be started from. */
  dataset: DatasetListItem
  /** True while a submitted training request is in flight. */
  isPending: boolean
  /** Dismisses the dialog without submitting. */
  onClose: () => void
  /**
   * Receives the task name and training configuration when the user confirms.
   * `model_name` identifies pretrained weights to start from, while
   * `base_model_version_id` identifies an existing model version to fine-tune.
   */
  onSubmit: (request: {
    name: string
    config: {
      model_name?: string
      base_model_version_id?: string | null
      epochs: number
      batch_size: number
      augmentation?: AugmentationConfigType
      augmentation_multiplier?: number
    }
  }) => void
}
|
|
|
|
/**
 * Modal dialog that collects a training configuration for `dataset` and passes
 * it to `onSubmit`.
 *
 * The base model is either the pretrained `yolo26s.pt` weights or an existing,
 * non-archived base model version (fine-tune mode). Switching the base model
 * resets the epoch count to that mode's default (100 pretrained, 10 fine-tune).
 * Augmentation settings are only included in the submitted config while
 * augmentation is enabled. The task name is submitted trimmed.
 */
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
  const [name, setName] = useState(`train-${dataset.name}`)
  const [epochs, setEpochs] = useState(100)
  const [batchSize, setBatchSize] = useState(16)
  const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
  const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
  const [augmentationEnabled, setAugmentationEnabled] = useState(false)
  const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
  const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)

  const isFineTune = baseModelType === 'existing'
  const trimmedName = name.trim()

  // Fetch trained models; archived ones are filtered out below.
  const { data: modelsData } = useQuery({
    queryKey: ['training', 'models', 'available'],
    queryFn: () => trainingApi.getModels(),
  })
  // Only offer base models (not fine-tuned ones) so fine-tunes are never chained.
  const availableModels = (modelsData?.models ?? []).filter(
    m => m.status !== 'archived' && (m.model_type ?? 'base') === 'base'
  )

  const handleSubmit = () => {
    // The submit button is already disabled for blank names; guard anyway.
    if (!trimmedName) return
    onSubmit({
      name: trimmedName,
      config: {
        model_name: isFineTune ? undefined : 'yolo26s.pt',
        base_model_version_id: isFineTune ? baseModelVersionId : null,
        epochs,
        batch_size: batchSize,
        // NOTE(review): the partial config is cast to the full type; presumably
        // the backend fills defaults for omitted fields — confirm.
        augmentation: augmentationEnabled
          ? (augmentationConfig as AugmentationConfigType)
          : undefined,
        augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
      },
    })
  }

  // Backdrop clicks dismiss the dialog, but not while a submission is in
  // flight (the Cancel button is disabled in that state as well).
  const handleBackdropClick = () => {
    if (!isPending) onClose()
  }

  const inputClass = 'w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info'

  return (
    <div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={handleBackdropClick}>
      <div
        role="dialog"
        aria-modal="true"
        aria-labelledby="train-dialog-title"
        className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6"
        onClick={e => e.stopPropagation()}
      >
        <h3 id="train-dialog-title" className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
        <p className="text-sm text-warm-text-muted mb-4">
          Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
          {' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
        </p>

        <div className="space-y-4">
          <div>
            <label htmlFor="train-task-name" className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
            <input
              id="train-task-name"
              type="text"
              value={name}
              onChange={e => setName(e.target.value)}
              className={inputClass}
            />
          </div>

          {/* Base Model Selection */}
          <div>
            <label htmlFor="train-base-model" className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
            <select
              id="train-base-model"
              value={isFineTune ? (baseModelVersionId ?? '') : 'pretrained'}
              onChange={e => {
                if (e.target.value === 'pretrained') {
                  setBaseModelType('pretrained')
                  setBaseModelVersionId(null)
                  setEpochs(100)
                } else {
                  setBaseModelType('existing')
                  setBaseModelVersionId(e.target.value)
                  setEpochs(10) // Fine-tune: fewer epochs per best practices
                }
              }}
              className={inputClass}
            >
              <option value="pretrained">yolo26s.pt (Pretrained)</option>
              {availableModels.map(m => (
                <option key={m.version_id} value={m.version_id}>
                  {/* `!= null` so a model with an mAP of exactly 0 still shows its metric. */}
                  {m.name} v{m.version} ({m.metrics_mAP != null ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
                </option>
              ))}
            </select>
            <p className="text-xs text-warm-text-muted mt-1">
              {isFineTune
                ? 'Fine-tune from base model (freeze=10, cos_lr, data mixing)'
                : 'Start from pretrained YOLO model'}
            </p>
          </div>

          {/* Fine-tune info panel */}
          {isFineTune && (
            <div className="bg-warm-state-info/5 border border-warm-state-info/20 rounded-lg p-3 text-xs text-warm-text-secondary">
              <p className="font-medium text-warm-state-info mb-1">Fine-Tune Mode</p>
              <ul className="space-y-0.5 text-warm-text-muted">
                <li>Epochs: 10 (auto-set), Backbone frozen (10 layers)</li>
                <li>Cosine LR scheduler, Pool data mixed with old data</li>
                <li>Requires 50+ verified pool entries</li>
                <li>Deployment gating runs automatically after training</li>
              </ul>
            </div>
          )}

          <div className="flex gap-4">
            <div className="flex-1">
              <label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
              <input
                id="train-epochs"
                type="number"
                min={1}
                max={1000}
                value={epochs}
                onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
                className={inputClass}
              />
            </div>
            <div className="flex-1">
              <label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
              <input
                id="train-batch-size"
                type="number"
                min={1}
                max={128}
                value={batchSize}
                onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
                className={inputClass}
              />
            </div>
          </div>

          {/* Augmentation Configuration */}
          <AugmentationConfig
            enabled={augmentationEnabled}
            onEnabledChange={setAugmentationEnabled}
            config={augmentationConfig}
            onConfigChange={setAugmentationConfig}
          />

          {/* Augmentation Multiplier - only shown when augmentation is enabled */}
          {augmentationEnabled && (
            <div>
              <label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
                Augmentation Multiplier
              </label>
              <input
                id="aug-multiplier"
                type="number"
                min={1}
                max={10}
                value={augmentationMultiplier}
                onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
                className={inputClass}
              />
              <p className="text-xs text-warm-text-muted mt-1">
                Number of augmented copies per original image (1-10)
              </p>
            </div>
          )}
        </div>

        <div className="flex justify-end gap-3 mt-6">
          <Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
          <Button onClick={handleSubmit} disabled={isPending || !trimmedName}>
            {isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
          </Button>
        </div>
      </div>
    </div>
  )
}
|
|
|
|
// --- Dataset List ---
|
|
|
|
/**
 * Table of existing datasets with per-row View / Train / Delete actions.
 *
 * Shows a spinner while datasets load and an empty state (linking to the
 * Create tab) when there are none. Train is only offered for datasets in
 * `ready` status and opens the training dialog for that row. Delete is
 * disabled for datasets in `pending`/`building` status and while any delete
 * is already in flight.
 */
const DatasetList: React.FC<{
  onNavigate?: (view: string, id?: string) => void
  onSwitchTab: (tab: Tab) => void
}> = ({ onNavigate, onSwitchTab }) => {
  const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
  const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)

  // Forward the dialog's request unchanged. The dialog is closed only once the
  // request succeeds, so a failure keeps the user's settings on screen.
  const handleTrain = (request: Parameters<TrainDialogProps['onSubmit']>[0]) => {
    if (!trainTarget) return
    trainFromDataset(
      { datasetId: trainTarget.dataset_id, req: request },
      { onSuccess: () => setTrainTarget(null) },
    )
  }

  if (isLoading) {
    return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
  }

  if (datasets.length === 0) {
    return (
      <div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
        <Database size={48} className="mb-4 opacity-40" />
        <p className="text-lg mb-2">No datasets yet</p>
        <p className="text-sm mb-4">Create a dataset to start training</p>
        <Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
      </div>
    )
  }

  return (
    <>
      <div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
        <table className="w-full text-left">
          <thead className="bg-white border-b border-warm-border">
            <tr>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
            </tr>
          </thead>
          <tbody>
            {datasets.map(ds => {
              // One flag drives both `disabled` and the styling, so a button
              // disabled by an in-flight delete no longer looks clickable.
              const deleteDisabled = isDeleting || ds.status === 'pending' || ds.status === 'building'
              return (
                <tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
                  <td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
                  <td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
                  <td className="py-3 px-4">
                    <div className="flex gap-1">
                      <button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
                        className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
                        <Eye size={14} />
                      </button>
                      {ds.status === 'ready' && (
                        <button title="Train" onClick={() => setTrainTarget(ds)}
                          className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
                          <Play size={14} />
                        </button>
                      )}
                      <button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
                        disabled={deleteDisabled}
                        className={`p-1.5 rounded transition-colors ${
                          deleteDisabled
                            ? 'text-warm-text-muted/40 cursor-not-allowed'
                            : 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
                        }`}>
                        <Trash2 size={14} />
                      </button>
                    </div>
                  </td>
                </tr>
              )
            })}
          </tbody>
        </table>
      </div>

      {trainTarget && (
        <TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
      )}
    </>
  )
}
|
|
|
|
// --- Create Dataset ---
|
|
|
|
/**
 * Form for assembling a new dataset from annotated training documents.
 *
 * Left: a selectable table of documents fetched with `has_annotations: true`.
 * Right: name, optional description and the train/val/test split, where the
 * test ratio is derived as `1 - train - val`. Creation requires at least 10
 * selected documents, a non-blank name and a valid split. On success the view
 * switches back to the Datasets tab. On failure the form stays open and shows
 * the error.
 */
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
  // Minimum number of selected documents before a dataset can be created.
  const minDocuments = 10

  const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
  const { createDatasetAsync, isCreating } = useDatasets()

  const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
  const [name, setName] = useState('')
  const [description, setDescription] = useState('')
  const [trainRatio, setTrainRatio] = useState(0.7)
  const [valRatio, setValRatio] = useState(0.2)
  const [createError, setCreateError] = useState<string | null>(null)

  const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])

  // The number inputs' min/max attributes do not stop typed-in values (an
  // emptied field becomes 0), so the split is validated here against the same
  // bounds. The epsilon absorbs floating-point error in train + val.
  const splitError = useMemo(() => {
    if (!Number.isFinite(trainRatio) || trainRatio < 0.1 || trainRatio > 0.9) {
      return 'Train ratio must be between 0.1 and 0.9'
    }
    if (!Number.isFinite(valRatio) || valRatio < 0 || valRatio > 0.5) {
      return 'Val ratio must be between 0 and 0.5'
    }
    if (trainRatio + valRatio > 1 + 1e-9) {
      return 'Train + Val must not exceed 1'
    }
    return null
  }, [trainRatio, valRatio])

  const trimmedName = name.trim()
  const canCreate = selectedIds.size >= minDocuments && trimmedName.length > 0 && splitError === null
  const allSelected = documents.length > 0 && selectedIds.size === documents.length

  const toggleDoc = (id: string) => {
    setSelectedIds(prev => {
      const next = new Set(prev)
      if (next.has(id)) { next.delete(id) } else { next.add(id) }
      return next
    })
  }

  const toggleAll = () => {
    setSelectedIds(allSelected ? new Set<string>() : new Set(documents.map((d) => d.document_id)))
  }

  const handleCreate = async () => {
    if (!canCreate) return
    setCreateError(null)
    try {
      await createDatasetAsync({
        name: trimmedName,
        description: description || undefined,
        document_ids: [...selectedIds],
        train_ratio: trainRatio,
        val_ratio: valRatio,
      })
      onSwitchTab('datasets')
    } catch (err: unknown) {
      // Stay on this tab so the user can adjust the form and retry.
      setCreateError(err instanceof Error ? err.message : 'Failed to create dataset')
    }
  }

  return (
    <div className="flex gap-8">
      {/* Document selection */}
      <div className="flex-1 flex flex-col">
        <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
        {isLoadingDocs ? (
          <div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
        ) : (
          <div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
            <div className="overflow-auto max-h-[calc(100vh-240px)]">
              <table className="w-full text-left">
                <thead className="sticky top-0 bg-white border-b border-warm-border z-10">
                  <tr>
                    <th className="py-3 pl-6 pr-4 w-12">
                      <input type="checkbox" checked={allSelected}
                        onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
                    </th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
                    <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
                  </tr>
                </thead>
                <tbody>
                  {documents.map((doc) => (
                    <tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
                      onClick={() => toggleDoc(doc.document_id)}>
                      <td className="py-3 pl-6 pr-4">
                        <input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
                          className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
                      </td>
                      <td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
                      <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
                      <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          </div>
        )}
        <p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
      </div>

      {/* Config panel */}
      <div className="w-80">
        <div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
          <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
          <div className="space-y-4">
            <div>
              <label htmlFor="dataset-name" className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
              <input id="dataset-name" type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
            </div>
            <div>
              <label htmlFor="dataset-description" className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
              <textarea id="dataset-description" value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
                className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
            </div>
            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
              <div className="flex gap-2 text-sm">
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Train</span>
                  <input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
                </div>
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Val</span>
                  <input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
                </div>
                <div className="flex-1">
                  <span className="text-xs text-warm-text-muted">Test</span>
                  <input type="number" value={testRatio} readOnly
                    className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
                </div>
              </div>
              {splitError && (
                <p className="text-xs text-warm-state-error mt-1">{splitError}</p>
              )}
            </div>

            <div className="pt-4 border-t border-warm-border">
              {selectedIds.size > 0 && selectedIds.size < minDocuments && (
                <p className="text-xs text-warm-state-warning mb-2">
                  Minimum {minDocuments} documents required for training ({selectedIds.size}/{minDocuments} selected)
                </p>
              )}
              {createError && (
                <p className="text-xs text-warm-state-error mb-2">{createError}</p>
              )}
              <Button className="w-full h-11" onClick={() => { void handleCreate() }}
                disabled={isCreating || !canCreate}>
                {isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
              </Button>
            </div>
          </div>
        </div>
      </div>
    </div>
  )
}
|
|
|
|
// --- Fine-Tune Pool ---
|
|
|
|
/**
 * Fine-tune pool view: four summary stat cards plus a table of pool entries
 * (first 50, as requested from `poolApi.listEntries`) with Verify/Remove actions.
 *
 * Both mutations invalidate every query under the `['pool']` key on success,
 * so stats and entries refresh together. Query failures are shown as error
 * states rather than as zero counts or an "empty pool" message.
 */
const FineTunePool: React.FC = () => {
  const queryClient = useQueryClient()

  const { data: stats, isLoading: isLoadingStats, isError: isStatsError } = useQuery({
    queryKey: ['pool', 'stats'],
    queryFn: () => poolApi.getStats(),
  })

  const { data: entriesData, isLoading: isLoadingEntries, isError: isEntriesError } = useQuery({
    queryKey: ['pool', 'entries'],
    queryFn: () => poolApi.listEntries({ limit: 50 }),
  })

  // Fire-and-forget refresh; the mutation does not wait for the refetch.
  const refreshPool = () => {
    void queryClient.invalidateQueries({ queryKey: ['pool'] })
  }

  const verifyMutation = useMutation({
    mutationFn: (entryId: string) => poolApi.verifyEntry(entryId),
    onSuccess: refreshPool,
  })

  const removeMutation = useMutation({
    mutationFn: (entryId: string) => poolApi.removeEntry(entryId),
    onSuccess: refreshPool,
  })

  const entries = entriesData?.entries ?? []
  // Clamp so the shortfall never renders as a negative number.
  const remaining = Math.max(0, (stats?.min_required ?? 50) - (stats?.verified_entries ?? 0))

  return (
    <div className="space-y-6">
      {/* Pool Stats */}
      <div className="grid grid-cols-4 gap-4">
        {isLoadingStats ? (
          <div className="col-span-4 flex items-center justify-center py-8 text-warm-text-muted">
            <Loader2 size={20} className="animate-spin mr-2" />Loading stats...
          </div>
        ) : isStatsError ? (
          <div className="col-span-4 flex items-center justify-center py-8 text-warm-state-error">
            <XCircle size={20} className="mr-2" />Failed to load pool stats
          </div>
        ) : (
          <>
            <div className="bg-warm-card border border-warm-border rounded-lg p-4">
              <p className="text-xs text-warm-text-muted uppercase mb-1">Total Entries</p>
              <p className="text-2xl font-bold font-mono text-warm-text-primary">{stats?.total_entries ?? 0}</p>
            </div>
            <div className="bg-warm-card border border-warm-border rounded-lg p-4">
              <p className="text-xs text-warm-text-muted uppercase mb-1">Verified</p>
              <p className="text-2xl font-bold font-mono text-warm-state-success">{stats?.verified_entries ?? 0}</p>
            </div>
            <div className="bg-warm-card border border-warm-border rounded-lg p-4">
              <p className="text-xs text-warm-text-muted uppercase mb-1">Unverified</p>
              <p className="text-2xl font-bold font-mono text-warm-state-warning">{stats?.unverified_entries ?? 0}</p>
            </div>
            <div className="bg-warm-card border border-warm-border rounded-lg p-4">
              <p className="text-xs text-warm-text-muted uppercase mb-1">Ready for Fine-Tune</p>
              <div className="flex items-center gap-2">
                {stats?.is_ready ? (
                  <CheckCircle size={20} className="text-warm-state-success" />
                ) : (
                  <AlertCircle size={20} className="text-warm-state-warning" />
                )}
                <p className="text-lg font-medium text-warm-text-primary">
                  {stats?.is_ready ? 'Yes' : `Need ${remaining} more`}
                </p>
              </div>
            </div>
          </>
        )}
      </div>

      {/* Pool Entries Table */}
      {isLoadingEntries ? (
        <div className="flex items-center justify-center py-12 text-warm-text-muted">
          <Loader2 size={20} className="animate-spin mr-2" />Loading pool entries...
        </div>
      ) : isEntriesError ? (
        <div className="flex items-center justify-center py-12 text-warm-state-error">
          <XCircle size={20} className="mr-2" />Failed to load pool entries
        </div>
      ) : entries.length === 0 ? (
        <div className="flex flex-col items-center justify-center py-16 text-warm-text-muted">
          <Shield size={48} className="mb-4 opacity-40" />
          <p className="text-lg mb-2">Fine-tune pool is empty</p>
          <p className="text-sm">Add documents with extraction failures to the pool for future fine-tuning.</p>
        </div>
      ) : (
        <div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
          <table className="w-full text-left">
            <thead className="bg-white border-b border-warm-border">
              <tr>
                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Reason</th>
                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Added</th>
                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
              </tr>
            </thead>
            <tbody>
              {entries.map((entry: PoolEntryItem) => (
                <tr key={entry.entry_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
                  <td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{entry.document_id.slice(0, 8)}...</td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted">{entry.reason ?? '-'}</td>
                  <td className="py-3 px-4">
                    <span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${
                      entry.is_verified
                        ? 'bg-warm-state-success/10 text-warm-state-success'
                        : 'bg-warm-state-warning/10 text-warm-state-warning'
                    }`}>
                      {entry.is_verified ? <Check size={12} className="mr-1" /> : <AlertCircle size={12} className="mr-1" />}
                      {entry.is_verified ? 'Verified' : 'Unverified'}
                    </span>
                  </td>
                  <td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(entry.created_at).toLocaleDateString()}</td>
                  <td className="py-3 px-4">
                    <div className="flex gap-1">
                      {!entry.is_verified && (
                        <button
                          title="Verify"
                          onClick={() => verifyMutation.mutate(entry.entry_id)}
                          disabled={verifyMutation.isPending}
                          className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
                        >
                          <CheckCircle size={14} />
                        </button>
                      )}
                      <button
                        title="Remove"
                        onClick={() => removeMutation.mutate(entry.entry_id)}
                        disabled={removeMutation.isPending}
                        className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
                      >
                        <Trash2 size={14} />
                      </button>
                    </div>
                  </td>
                </tr>
              ))}
            </tbody>
          </table>
        </div>
      )}
    </div>
  )
}
|
|
|
|
// --- Main Training Component ---
|
|
|
|
/** Tab keys and their visible labels, in display order. */
const TRAINING_TABS: ReadonlyArray<readonly [Tab, string]> = [
  ['datasets', 'Datasets'],
  ['create', 'Create Dataset'],
  ['pool', 'Fine-Tune Pool'],
]

/**
 * Training page: a tab bar switching between the dataset list, the dataset
 * creation form and the fine-tune pool. Starts on the Datasets tab.
 *
 * @param onNavigate - forwarded to the dataset list for row navigation.
 */
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
  const [activeTab, setActiveTab] = useState<Tab>('datasets')

  const tabClassName = (tab: Tab) =>
    `px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
      tab === activeTab
        ? 'border-warm-state-info text-warm-state-info'
        : 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
    }`

  let panel: React.ReactNode = null
  switch (activeTab) {
    case 'datasets':
      panel = <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />
      break
    case 'create':
      panel = <CreateDataset onSwitchTab={setActiveTab} />
      break
    case 'pool':
      panel = <FineTunePool />
      break
  }

  return (
    <div className="p-8 max-w-7xl mx-auto">
      <div className="flex items-center justify-between mb-6">
        <h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
      </div>

      {/* Tabs */}
      <div className="flex gap-1 mb-6 border-b border-warm-border">
        {TRAINING_TABS.map(([tab, label]) => (
          <button key={tab} onClick={() => setActiveTab(tab)} className={tabClassName(tab)}>
            {label}
          </button>
        ))}
      </div>

      {panel}
    </div>
  )
}
|