WIP
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
import React, { useState, useMemo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle, Shield, CheckCircle, XCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { useDatasets } from '../hooks/useDatasets'
|
||||
import { useTrainingDocuments } from '../hooks/useTraining'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { DatasetListItem } from '../api/types'
|
||||
import { trainingApi, poolApi } from '../api/endpoints'
|
||||
import type { DatasetListItem, PoolEntryItem } from '../api/types'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
type Tab = 'datasets' | 'create'
|
||||
type Tab = 'datasets' | 'create' | 'pool'
|
||||
|
||||
interface TrainingProps {
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
@@ -72,19 +72,23 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
|
||||
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
|
||||
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
|
||||
|
||||
const isFineTune = baseModelType === 'existing'
|
||||
|
||||
// Fetch available trained models (active or inactive, not archived)
|
||||
const { data: modelsData } = useQuery({
|
||||
queryKey: ['training', 'models', 'available'],
|
||||
queryFn: () => trainingApi.getModels(),
|
||||
})
|
||||
// Filter out archived models - only show active/inactive models for base model selection
|
||||
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
|
||||
// Only show base models (not fine-tuned) for selection - prevents chaining fine-tunes
|
||||
const availableModels = (modelsData?.models ?? []).filter(
|
||||
m => m.status !== 'archived' && (m.model_type ?? 'base') === 'base'
|
||||
)
|
||||
|
||||
const handleSubmit = () => {
|
||||
onSubmit({
|
||||
name,
|
||||
config: {
|
||||
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
|
||||
model_name: baseModelType === 'pretrained' ? 'yolo26s.pt' : undefined,
|
||||
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
|
||||
epochs,
|
||||
batch_size: batchSize,
|
||||
@@ -121,14 +125,16 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
|
||||
if (e.target.value === 'pretrained') {
|
||||
setBaseModelType('pretrained')
|
||||
setBaseModelVersionId(null)
|
||||
setEpochs(100)
|
||||
} else {
|
||||
setBaseModelType('existing')
|
||||
setBaseModelVersionId(e.target.value)
|
||||
setEpochs(10) // Fine-tune: fewer epochs per best practices
|
||||
}
|
||||
}}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
<option value="pretrained">yolo11n.pt (Pretrained)</option>
|
||||
<option value="pretrained">yolo26s.pt (Pretrained)</option>
|
||||
{availableModels.map(m => (
|
||||
<option key={m.version_id} value={m.version_id}>
|
||||
{m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
|
||||
@@ -138,10 +144,23 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
{baseModelType === 'pretrained'
|
||||
? 'Start from pretrained YOLO model'
|
||||
: 'Continue training from an existing model (incremental training)'}
|
||||
: 'Fine-tune from base model (freeze=10, cos_lr, data mixing)'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Fine-tune info panel */}
|
||||
{isFineTune && (
|
||||
<div className="bg-warm-state-info/5 border border-warm-state-info/20 rounded-lg p-3 text-xs text-warm-text-secondary">
|
||||
<p className="font-medium text-warm-state-info mb-1">Fine-Tune Mode</p>
|
||||
<ul className="space-y-0.5 text-warm-text-muted">
|
||||
<li>Epochs: 10 (auto-set), Backbone frozen (10 layers)</li>
|
||||
<li>Cosine LR scheduler, Pool data mixed with old data</li>
|
||||
<li>Requires 50+ verified pool entries</li>
|
||||
<li>Deployment gating runs automatically after training</li>
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-4">
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
|
||||
@@ -455,6 +474,148 @@ const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitch
|
||||
)
|
||||
}
|
||||
|
||||
// --- Fine-Tune Pool ---
|
||||
|
||||
const FineTunePool: React.FC = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data: statsData, isLoading: isLoadingStats } = useQuery({
|
||||
queryKey: ['pool', 'stats'],
|
||||
queryFn: () => poolApi.getStats(),
|
||||
})
|
||||
|
||||
const { data: entriesData, isLoading: isLoadingEntries } = useQuery({
|
||||
queryKey: ['pool', 'entries'],
|
||||
queryFn: () => poolApi.listEntries({ limit: 50 }),
|
||||
})
|
||||
|
||||
const verifyMutation = useMutation({
|
||||
mutationFn: (entryId: string) => poolApi.verifyEntry(entryId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['pool'] })
|
||||
},
|
||||
})
|
||||
|
||||
const removeMutation = useMutation({
|
||||
mutationFn: (entryId: string) => poolApi.removeEntry(entryId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['pool'] })
|
||||
},
|
||||
})
|
||||
|
||||
const stats = statsData
|
||||
const entries = entriesData?.entries ?? []
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Pool Stats */}
|
||||
<div className="grid grid-cols-4 gap-4">
|
||||
{isLoadingStats ? (
|
||||
<div className="col-span-4 flex items-center justify-center py-8 text-warm-text-muted">
|
||||
<Loader2 size={20} className="animate-spin mr-2" />Loading stats...
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase mb-1">Total Entries</p>
|
||||
<p className="text-2xl font-bold font-mono text-warm-text-primary">{stats?.total_entries ?? 0}</p>
|
||||
</div>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase mb-1">Verified</p>
|
||||
<p className="text-2xl font-bold font-mono text-warm-state-success">{stats?.verified_entries ?? 0}</p>
|
||||
</div>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase mb-1">Unverified</p>
|
||||
<p className="text-2xl font-bold font-mono text-warm-state-warning">{stats?.unverified_entries ?? 0}</p>
|
||||
</div>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase mb-1">Ready for Fine-Tune</p>
|
||||
<div className="flex items-center gap-2">
|
||||
{stats?.is_ready ? (
|
||||
<CheckCircle size={20} className="text-warm-state-success" />
|
||||
) : (
|
||||
<AlertCircle size={20} className="text-warm-state-warning" />
|
||||
)}
|
||||
<p className="text-lg font-medium text-warm-text-primary">
|
||||
{stats?.is_ready ? 'Yes' : `Need ${(stats?.min_required ?? 50) - (stats?.verified_entries ?? 0)} more`}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Pool Entries Table */}
|
||||
{isLoadingEntries ? (
|
||||
<div className="flex items-center justify-center py-12 text-warm-text-muted">
|
||||
<Loader2 size={20} className="animate-spin mr-2" />Loading pool entries...
|
||||
</div>
|
||||
) : entries.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center py-16 text-warm-text-muted">
|
||||
<Shield size={48} className="mb-4 opacity-40" />
|
||||
<p className="text-lg mb-2">Fine-tune pool is empty</p>
|
||||
<p className="text-sm">Add documents with extraction failures to the pool for future fine-tuning.</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Reason</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Added</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{entries.map((entry: PoolEntryItem) => (
|
||||
<tr key={entry.entry_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{entry.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted">{entry.reason ?? '-'}</td>
|
||||
<td className="py-3 px-4">
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${
|
||||
entry.is_verified
|
||||
? 'bg-warm-state-success/10 text-warm-state-success'
|
||||
: 'bg-warm-state-warning/10 text-warm-state-warning'
|
||||
}`}>
|
||||
{entry.is_verified ? <Check size={12} className="mr-1" /> : <AlertCircle size={12} className="mr-1" />}
|
||||
{entry.is_verified ? 'Verified' : 'Unverified'}
|
||||
</span>
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(entry.created_at).toLocaleDateString()}</td>
|
||||
<td className="py-3 px-4">
|
||||
<div className="flex gap-1">
|
||||
{!entry.is_verified && (
|
||||
<button
|
||||
title="Verify"
|
||||
onClick={() => verifyMutation.mutate(entry.entry_id)}
|
||||
disabled={verifyMutation.isPending}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors"
|
||||
>
|
||||
<CheckCircle size={14} />
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
title="Remove"
|
||||
onClick={() => removeMutation.mutate(entry.entry_id)}
|
||||
disabled={removeMutation.isPending}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors"
|
||||
>
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Main Training Component ---
|
||||
|
||||
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
@@ -468,7 +629,7 @@ export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="flex gap-1 mb-6 border-b border-warm-border">
|
||||
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
|
||||
{([['datasets', 'Datasets'], ['create', 'Create Dataset'], ['pool', 'Fine-Tune Pool']] as const).map(([key, label]) => (
|
||||
<button key={key} onClick={() => setActiveTab(key)}
|
||||
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
|
||||
activeTab === key
|
||||
@@ -482,6 +643,7 @@ export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
|
||||
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
|
||||
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
|
||||
{activeTab === 'pool' && <FineTunePool />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user