Add more tests
docs/Dashboard-UI-Prompts.md (new file, 99 lines)
@@ -0,0 +1,99 @@
# Dashboard Mockup Prompts

> Visual style: modern minimalism — consistent with the existing Warm theme
> Color scheme: Warm light palette (off-white background #FAFAF8, white cards, dark gray text #121212)
> Target platform: Web (desktop)

---

## Current Color Scheme Reference

| Purpose | Value | Notes |
|---------|-------|-------|
| Page background | #FAFAF8 | Warm off-white |
| Card background | #FFFFFF | Pure white |
| Border | #E6E4E1 | Light grayish beige |
| Primary text | #121212 | Near black |
| Secondary text | #6B6B6B | Medium gray |
| Success state | #3E4A3A + green-500 | Dark olive green + bright green indicator dot |
| Warning state | #4A4A3A + yellow-50 | Dark khaki + light yellow background |
| Info state | #3A3A3A + blue-50 | Dark gray + light blue background |

---

## Page 1: Dashboard Main View (Normal State)

**Page description**: The home page after login, showing document statistics, data quality, active model status, and recent activity.

**Prompt**:

```
A modern web application dashboard UI for a document annotation system, main overview page, warm minimalist design theme, page background color #FAFAF8 warm off-white, single column layout with header navigation at top, content area below with multiple sections, top section shows: 4 equal-width stat cards in a row on white #FFFFFF background with subtle border #E6E4E1, first card Total Documents (38) with gray file icon on #FAFAF8 background, second card Complete (25) with dark olive green checkmark icon on light green #dcfce7 background, third card Incomplete (8) with orange alert icon on light orange #fef3c7 background, fourth card Pending (5) with blue clock icon on light blue #dbeafe background, each card has icon top-left in rounded square and large bold number in #121212 with label below in #6B6B6B, cards have subtle shadow on hover, middle section has two-column layout (50%/50%): left panel white card titled DATA QUALITY in uppercase #6B6B6B with circular progress ring 120px showing 78% in center with green #22C55E filled portion and gray #E5E7EB remaining, percentage text 36px bold #121212 centered in ring, text Annotation Complete next to ring, stats list below showing Complete 25 and Incomplete 8 and Pending 5 with small colored dots, text button View Incomplete Docs in primary color at bottom, right panel white card titled ACTIVE MODEL showing v1.2.0 - Invoice Model as title in bold #121212, thin horizontal divider #E6E4E1 below, three-column metrics row displaying mAP 95.1% and Precision 94% and Recall 92% in 24px bold with 12px labels below in #6B6B6B, info rows showing Activated 2024-01-20 and Documents 500 in 14px, training progress section at bottom showing Run-2024-02 with horizontal progress bar, below panels is full-width white card RECENT ACTIVITY section with list of 6 activity items each 40px height showing icon on left and description text in #121212 and relative timestamp in #6B6B6B right aligned, activity icons: rocket in purple for model activation, checkmark in green for training complete, edit pencil in orange for annotation modified, file in blue for document uploaded, x in red for training failed, subtle hover background #F1F0ED on activity rows, bottom section is SYSTEM STATUS white card showing Backend API Online with bright green #22C55E dot and Database Connected with green dot and GPU Available with green dot, all text in #2A2A2A, Inter font family, rounded corners 8px on all cards, subtle card shadow, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style, dribbble quality
```

---

## Page 2: Dashboard Empty State (No Active Model)

**Page description**: The onboarding view shown right after deployment, or when no model has been trained yet.

**Prompt**:

```
A modern web application dashboard UI for a document annotation system, empty state variation, warm minimalist design theme, page background #FAFAF8 warm off-white, single column layout with header navigation, top section shows: 4 stat cards on white background with #E6E4E1 border, all showing 0 values, Total Documents 0 with gray icon, Complete 0 with muted green, Incomplete 0 with muted orange, Pending 0 with muted blue, middle section two-column layout: left DATA QUALITY panel white card shows circular progress ring at 0% completely gray #E5E7EB with dashed outline style, large text 0% in #6B6B6B centered, text No data yet below in muted color, empty stats all showing 0, right ACTIVE MODEL panel white card shows empty state with large subtle model icon in center opacity 20%, text No Active Model as heading in #121212, subtext Train and activate a model to see stats here in #6B6B6B, primary button Go to Training at bottom, below panels RECENT ACTIVITY white card shows empty state with Activity icon centered at 20% opacity, text No recent activity in #121212, subtext Start by uploading documents or creating training jobs in #6B6B6B, bottom SYSTEM STATUS card showing all services online with green #22C55E dots, warm color palette throughout, Inter font, rounded corners 8px, subtle shadows, friendly and inviting empty state design, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
```

---

## Page 3: Dashboard Training-in-Progress State

**Page description**: While a model is training, the Active Model panel shows training progress.

**Prompt**:

```
A modern web application dashboard UI for a document annotation system, training in progress state, warm minimalist theme with #FAFAF8 background, header with navigation, top section: 4 white stat cards with #E6E4E1 borders showing Total Documents 38, Complete 25 with green icon on #dcfce7, Incomplete 8 with orange icon on #fef3c7, Pending 5 with blue icon on #dbeafe, middle section two-column layout: left DATA QUALITY white card with 78% progress ring in green #22C55E, stats list showing counts, right ACTIVE MODEL white card showing current model v1.1.0 in bold #121212 with metrics mAP 93.5% Precision 92% Recall 88% in grid, below a highlighted training section with subtle blue tint background #EFF6FF, pulsing blue dot indicator, text Training in Progress in #121212, task name Run-2024-02, horizontal progress bar 45% complete with blue #3B82F6 fill and gray #E5E7EB track, text Started 2 hours ago in #6B6B6B below, RECENT ACTIVITY white card below with latest item showing blue spinner icon and Training started Run-2024-02, other activities listed with appropriate icons, SYSTEM STATUS card at bottom showing GPU Available highlighted with green dot indicating active usage, warm color scheme throughout, Inter font, 8px rounded corners, subtle card shadows, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
```

---

## Page 4: Dashboard Mobile Responsive View

**Page description**: Single-column stacked layout on mobile (<768px).

**Prompt**:

```
A modern mobile web application dashboard UI for a document annotation system, responsive mobile layout on smartphone screen, warm minimalist theme with #FAFAF8 background, single column stacked layout, top shows condensed header with hamburger menu icon and logo, below 2x2 grid of compact white stat cards with #E6E4E1 borders showing Total 38 Complete 25 Incomplete 8 Pending 5 with small colored icons on tinted backgrounds, DATA QUALITY section below as full-width white card with smaller progress ring 80px showing 78% in green #22C55E, horizontal stats row compact, ACTIVE MODEL section below as full-width white card with model name v1.2.0 in bold, compact metrics row showing mAP Precision Recall values, RECENT ACTIVITY section full-width white card with scrollable list of 4 visible items with icons and timestamps in #6B6B6B, compact SYSTEM STATUS bar at bottom with three green #22C55E status dots, warm color palette #FAFAF8 background white cards #121212 text, Inter font, touch-friendly tap targets 44px minimum, comfortable 16px padding, 8px rounded corners, iOS/Android native feel, UI/UX design, high fidelity mockup, mobile screen 375x812 iPhone size, professional, Figma style
```

---

## Usage Notes

1. Copy a prompt into an AI image tool (e.g., Midjourney, DALL-E, Stable Diffusion).
2. Generate "Page 1: Dashboard Main View" first to verify that the style matches the existing design.
3. The prompts already embed your existing color scheme:
   - Page background: #FAFAF8 (warm off-white)
   - Card background: #FFFFFF (white)
   - Border: #E6E4E1 (light grayish beige)
   - Primary text: #121212 (near black)
   - Secondary text: #6B6B6B (medium gray)
   - Success: #22C55E (bright green) / #3E4A3A (dark olive green text)
   - Icon backgrounds: #dcfce7 (light green) / #fef3c7 (light yellow) / #dbeafe (light blue)
4. If the generated colors drift, adjust them afterwards in Figma.

---

## Tailwind Class Reference (for development)

```
Background:        bg-warm-bg (#FAFAF8)
Card:              bg-warm-card (#FFFFFF)
Border:            border-warm-border (#E6E4E1)
Primary text:      text-warm-text-primary (#121212)
Secondary text:    text-warm-text-secondary (#2A2A2A)
Muted text:        text-warm-text-muted (#6B6B6B)
Hover background:  bg-warm-hover (#F1F0ED)
Success state:     text-warm-state-success (#3E4A3A)
Green icon bg:     bg-green-50 (#dcfce7)
Yellow icon bg:    bg-yellow-50 (#fef3c7)
Blue icon bg:      bg-blue-50 (#dbeafe)
Green dot:         bg-green-500 (#22C55E)
```
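The class names above come from the project's Tailwind theme extension. As a rough illustration only — the actual config file, its location, and its token structure in this repo are assumptions — the tokens could be declared like this:

```ts
// tailwind.config.ts — hypothetical sketch of the Warm theme tokens.
// The real project config may organize these differently.
import type { Config } from 'tailwindcss'

export default {
  content: ['./src/**/*.{ts,tsx}'],
  theme: {
    extend: {
      colors: {
        'warm-bg': '#FAFAF8',      // page background
        'warm-card': '#FFFFFF',    // card background
        'warm-border': '#E6E4E1',  // borders
        'warm-hover': '#F1F0ED',   // hover background
        'warm-text': {
          primary: '#121212',      // -> text-warm-text-primary
          secondary: '#2A2A2A',    // -> text-warm-text-secondary
          muted: '#6B6B6B',        // -> text-warm-text-muted
        },
        'warm-state': {
          success: '#3E4A3A',      // -> text-warm-state-success
        },
      },
    },
  },
} satisfies Config
```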
docs/FORTNOX_INTEGRATION_SPEC.md (new file, 1690 lines)
File diff suppressed because it is too large.

frontend/src/api/endpoints/dashboard.ts (new file, 25 lines)
@@ -0,0 +1,25 @@
import apiClient from '../client'
import type {
  DashboardStatsResponse,
  DashboardActiveModelResponse,
  RecentActivityResponse,
} from '../types'

export const dashboardApi = {
  getStats: async (): Promise<DashboardStatsResponse> => {
    const response = await apiClient.get('/api/v1/admin/dashboard/stats')
    return response.data
  },

  getActiveModel: async (): Promise<DashboardActiveModelResponse> => {
    const response = await apiClient.get('/api/v1/admin/dashboard/active-model')
    return response.data
  },

  getRecentActivity: async (limit: number = 10): Promise<RecentActivityResponse> => {
    const response = await apiClient.get('/api/v1/admin/dashboard/activity', {
      params: { limit },
    })
    return response.data
  },
}
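A quick usage sketch of this client (illustrative only; `logDashboardSnapshot` is a hypothetical helper, not part of the commit — in the app these calls are wrapped by the `useDashboard` hook further below):

```ts
// Hypothetical direct usage of dashboardApi outside React Query.
import { dashboardApi } from './dashboard'

async function logDashboardSnapshot(): Promise<void> {
  const stats = await dashboardApi.getStats()
  console.log(`${stats.total_documents} docs, ${stats.completeness_rate}% complete`)

  const { activities } = await dashboardApi.getRecentActivity(5)
  for (const a of activities) {
    console.log(`[${a.timestamp}] ${a.type}: ${a.description}`)
  }
}
```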
@@ -5,3 +5,4 @@ export { inferenceApi } from './inference'
 export { datasetsApi } from './datasets'
 export { augmentationApi } from './augmentation'
 export { modelsApi } from './models'
+export { dashboardApi } from './dashboard'
@@ -362,3 +362,48 @@ export interface ActiveModelResponse {
   has_active_model: boolean
   model: ModelVersionItem | null
 }
+
+// Dashboard types
+
+export interface DashboardStatsResponse {
+  total_documents: number
+  annotation_complete: number
+  annotation_incomplete: number
+  pending: number
+  completeness_rate: number
+}
+
+export interface DashboardActiveModelInfo {
+  version_id: string
+  version: string
+  name: string
+  metrics_mAP: number | null
+  metrics_precision: number | null
+  metrics_recall: number | null
+  document_count: number
+  activated_at: string | null
+}
+
+export interface DashboardRunningTrainingInfo {
+  task_id: string
+  name: string
+  status: string
+  started_at: string | null
+  progress: number
+}
+
+export interface DashboardActiveModelResponse {
+  model: DashboardActiveModelInfo | null
+  running_training: DashboardRunningTrainingInfo | null
+}
+
+export interface ActivityItem {
+  type: 'document_uploaded' | 'annotation_modified' | 'training_completed' | 'training_failed' | 'model_activated'
+  description: string
+  timestamp: string
+  metadata: Record<string, unknown>
+}
+
+export interface RecentActivityResponse {
+  activities: ActivityItem[]
+}
@@ -1,47 +1,58 @@
 import React from 'react'
-import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
-import { Button } from './Button'
-import { useDocuments } from '../hooks/useDocuments'
-import { useTraining } from '../hooks/useTraining'
+import { FileText, CheckCircle, AlertCircle, Clock, RefreshCw } from 'lucide-react'
+import {
+  StatsCard,
+  DataQualityPanel,
+  ActiveModelPanel,
+  RecentActivityPanel,
+  SystemStatusBar,
+} from './dashboard/index'
+import { useDashboard } from '../hooks/useDashboard'

 interface DashboardOverviewProps {
   onNavigate: (view: string) => void
 }

 export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
-  const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
-  const { models, isLoadingModels } = useTraining()
-
-  const stats = [
-    {
-      label: 'Total Documents',
-      value: docsLoading ? '...' : totalDocs.toString(),
-      icon: FileText,
-      color: 'text-warm-text-primary',
-      bgColor: 'bg-warm-bg',
-    },
-    {
-      label: 'Labeled',
-      value: '0',
-      icon: CheckCircle,
-      color: 'text-warm-state-success',
-      bgColor: 'bg-green-50',
-    },
-    {
-      label: 'Pending',
-      value: '0',
-      icon: Clock,
-      color: 'text-warm-state-warning',
-      bgColor: 'bg-yellow-50',
-    },
-    {
-      label: 'Training Models',
-      value: isLoadingModels ? '...' : models.length.toString(),
-      icon: TrendingUp,
-      color: 'text-warm-state-info',
-      bgColor: 'bg-blue-50',
-    },
-  ]
+  const {
+    stats,
+    model,
+    runningTraining,
+    activities,
+    isLoading,
+    error,
+  } = useDashboard()
+
+  const handleStatsClick = (filter?: string) => {
+    if (filter) {
+      onNavigate(`documents?status=${filter}`)
+    } else {
+      onNavigate('documents')
+    }
+  }
+
+  if (error) {
+    return (
+      <div className="p-8 max-w-7xl mx-auto">
+        <div className="bg-red-50 border border-red-200 rounded-lg p-6 text-center">
+          <AlertCircle className="w-12 h-12 text-red-500 mx-auto mb-4" />
+          <h2 className="text-lg font-semibold text-red-800 mb-2">
+            Failed to load dashboard
+          </h2>
+          <p className="text-sm text-red-600 mb-4">
+            {error instanceof Error ? error.message : 'An unexpected error occurred'}
+          </p>
+          <button
+            onClick={() => window.location.reload()}
+            className="inline-flex items-center gap-2 px-4 py-2 bg-red-100 hover:bg-red-200 text-red-800 rounded-md text-sm font-medium transition-colors"
+          >
+            <RefreshCw className="w-4 h-4" />
+            Retry
+          </button>
+        </div>
+      </div>
+    )
+  }

   return (
     <div className="p-8 max-w-7xl mx-auto animate-fade-in">
@@ -55,94 +66,74 @@ export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate
         </p>
       </div>

-      {/* Stats Grid */}
+      {/* Stats Cards Row */}
       <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
-        {stats.map((stat) => (
-          <div
-            key={stat.label}
-            className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
-          >
-            <div className="flex items-center justify-between mb-4">
-              <div className={`p-3 rounded-lg ${stat.bgColor}`}>
-                <stat.icon className={stat.color} size={24} />
-              </div>
-            </div>
-            <p className="text-2xl font-bold text-warm-text-primary mb-1">
-              {stat.value}
-            </p>
-            <p className="text-sm text-warm-text-muted">{stat.label}</p>
-          </div>
-        ))}
+        <StatsCard
+          label="Total Documents"
+          value={stats?.total_documents ?? 0}
+          icon={FileText}
+          iconColor="text-warm-text-primary"
+          iconBgColor="bg-warm-bg"
+          isLoading={isLoading}
+          onClick={() => handleStatsClick()}
+        />
+        <StatsCard
+          label="Complete"
+          value={stats?.annotation_complete ?? 0}
+          icon={CheckCircle}
+          iconColor="text-warm-state-success"
+          iconBgColor="bg-green-50"
+          isLoading={isLoading}
+          onClick={() => handleStatsClick('labeled')}
+        />
+        <StatsCard
+          label="Incomplete"
+          value={stats?.annotation_incomplete ?? 0}
+          icon={AlertCircle}
+          iconColor="text-orange-600"
+          iconBgColor="bg-orange-50"
+          isLoading={isLoading}
+          onClick={() => handleStatsClick('labeled')}
+        />
+        <StatsCard
+          label="Pending"
+          value={stats?.pending ?? 0}
+          icon={Clock}
+          iconColor="text-blue-600"
+          iconBgColor="bg-blue-50"
+          isLoading={isLoading}
+          onClick={() => handleStatsClick('pending')}
+        />
       </div>

-      {/* Quick Actions */}
-      <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
-        <h2 className="text-lg font-semibold text-warm-text-primary mb-4">
-          Quick Actions
-        </h2>
-        <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
-          <Button onClick={() => onNavigate('documents')} className="justify-start">
-            <FileText size={18} className="mr-2" />
-            Manage Documents
-          </Button>
-          <Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
-            <Activity size={18} className="mr-2" />
-            Start Training
-          </Button>
-          <Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
-            <TrendingUp size={18} className="mr-2" />
-            View Models
-          </Button>
-        </div>
-      </div>
+      {/* Two-column layout: Data Quality + Active Model */}
+      <div className="grid grid-cols-1 lg:grid-cols-2 gap-6 mb-8">
+        <DataQualityPanel
+          completenessRate={stats?.completeness_rate ?? 0}
+          completeCount={stats?.annotation_complete ?? 0}
+          incompleteCount={stats?.annotation_incomplete ?? 0}
+          pendingCount={stats?.pending ?? 0}
+          isLoading={isLoading}
+          onViewIncomplete={() => handleStatsClick('labeled')}
+        />
+        <ActiveModelPanel
+          model={model}
+          runningTraining={runningTraining}
+          isLoading={isLoading}
+          onGoToTraining={() => onNavigate('training')}
+        />
+      </div>

       {/* Recent Activity */}
-      <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
-        <div className="p-6 border-b border-warm-border">
-          <h2 className="text-lg font-semibold text-warm-text-primary">
-            Recent Activity
-          </h2>
-        </div>
-        <div className="p-6">
-          <div className="text-center py-8 text-warm-text-muted">
-            <Activity size={48} className="mx-auto mb-3 opacity-20" />
-            <p className="text-sm">No recent activity</p>
-            <p className="text-xs mt-1">
-              Start by uploading documents or creating training jobs
-            </p>
-          </div>
-        </div>
-      </div>
+      <div className="mb-8">
+        <RecentActivityPanel
+          activities={activities}
+          isLoading={isLoading}
+        />
+      </div>

       {/* System Status */}
-      <div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
-        <h2 className="text-lg font-semibold text-warm-text-primary mb-4">
-          System Status
-        </h2>
-        <div className="space-y-3">
-          <div className="flex items-center justify-between">
-            <span className="text-sm text-warm-text-secondary">Backend API</span>
-            <span className="flex items-center text-sm text-warm-state-success">
-              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
-              Online
-            </span>
-          </div>
-          <div className="flex items-center justify-between">
-            <span className="text-sm text-warm-text-secondary">Database</span>
-            <span className="flex items-center text-sm text-warm-state-success">
-              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
-              Connected
-            </span>
-          </div>
-          <div className="flex items-center justify-between">
-            <span className="text-sm text-warm-text-secondary">GPU</span>
-            <span className="flex items-center text-sm text-warm-state-success">
-              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
-              Available
-            </span>
-          </div>
-        </div>
-      </div>
+      <SystemStatusBar />
     </div>
   )
 }
frontend/src/components/dashboard/ActiveModelPanel.tsx (new file, 143 lines)
@@ -0,0 +1,143 @@
import React from 'react'
import { TrendingUp } from 'lucide-react'
import { Button } from '../Button'
import type { DashboardActiveModelInfo, DashboardRunningTrainingInfo } from '../../api/types'

interface ActiveModelPanelProps {
  model: DashboardActiveModelInfo | null
  runningTraining: DashboardRunningTrainingInfo | null
  isLoading?: boolean
  onGoToTraining?: () => void
}

const formatDate = (dateStr: string | null): string => {
  if (!dateStr) return 'N/A'
  const date = new Date(dateStr)
  return date.toLocaleDateString('en-US', {
    year: 'numeric',
    month: 'short',
    day: 'numeric',
  })
}

const formatMetric = (value: number | null): string => {
  if (value === null) return 'N/A'
  return `${(value * 100).toFixed(1)}%`
}

const getMetricColor = (value: number | null): string => {
  if (value === null) return 'text-warm-text-muted'
  if (value >= 0.9) return 'text-green-600'
  if (value >= 0.8) return 'text-yellow-600'
  return 'text-red-600'
}

export const ActiveModelPanel: React.FC<ActiveModelPanelProps> = ({
  model,
  runningTraining,
  isLoading = false,
  onGoToTraining,
}) => {
  if (isLoading) {
    return (
      <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
        <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
          Active Model
        </h2>
        <div className="flex items-center justify-center py-8">
          <div className="animate-pulse text-warm-text-muted">Loading...</div>
        </div>
      </div>
    )
  }

  if (!model) {
    return (
      <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
        <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
          Active Model
        </h2>
        <div className="flex flex-col items-center justify-center py-8 text-center">
          <TrendingUp className="w-12 h-12 text-warm-text-disabled mb-3 opacity-20" />
          <p className="text-warm-text-primary font-medium mb-1">No Active Model</p>
          <p className="text-sm text-warm-text-muted mb-4">
            Train and activate a model to see stats here
          </p>
          {onGoToTraining && (
            <Button onClick={onGoToTraining} variant="primary" size="sm">
              Go to Training
            </Button>
          )}
        </div>
      </div>
    )
  }

  return (
    <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
      <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
        Active Model
      </h2>

      <div className="mb-4">
        <span className="text-lg font-bold text-warm-text-primary">{model.version}</span>
        <span className="text-warm-text-secondary ml-2">- {model.name}</span>
      </div>

      <div className="border-t border-warm-border pt-4 mb-4">
        <div className="grid grid-cols-3 gap-4">
          <div className="text-center">
            <p className={`text-2xl font-bold ${getMetricColor(model.metrics_mAP)}`}>
              {formatMetric(model.metrics_mAP)}
            </p>
            <p className="text-xs text-warm-text-muted uppercase">mAP</p>
          </div>
          <div className="text-center">
            <p className={`text-2xl font-bold ${getMetricColor(model.metrics_precision)}`}>
              {formatMetric(model.metrics_precision)}
            </p>
            <p className="text-xs text-warm-text-muted uppercase">Precision</p>
          </div>
          <div className="text-center">
            <p className={`text-2xl font-bold ${getMetricColor(model.metrics_recall)}`}>
              {formatMetric(model.metrics_recall)}
            </p>
            <p className="text-xs text-warm-text-muted uppercase">Recall</p>
          </div>
        </div>
      </div>

      <div className="space-y-1 text-sm text-warm-text-secondary">
        <p>
          <span className="text-warm-text-muted">Activated:</span>{' '}
          {formatDate(model.activated_at)}
        </p>
        <p>
          <span className="text-warm-text-muted">Documents:</span>{' '}
          {model.document_count.toLocaleString()}
        </p>
      </div>

      {runningTraining && (
        <div className="mt-4 p-3 bg-blue-50 rounded-lg border border-blue-100">
          <div className="flex items-center gap-2 mb-2">
            <span className="w-2 h-2 bg-blue-500 rounded-full animate-pulse" />
            <span className="text-sm font-medium text-warm-text-primary">
              Training in Progress
            </span>
          </div>
          <p className="text-sm text-warm-text-secondary mb-2">{runningTraining.name}</p>
          <div className="w-full bg-gray-200 rounded-full h-2">
            <div
              className="bg-blue-500 h-2 rounded-full transition-all duration-500"
              style={{ width: `${runningTraining.progress}%` }}
            />
          </div>
          <p className="text-xs text-warm-text-muted mt-1">
            {runningTraining.progress}% complete
          </p>
        </div>
      )}
    </div>
  )
}
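Note that metrics arrive as 0–1 fractions and are rendered as percentages. A worked sketch of the helpers above (illustrative, not part of the commit):

```ts
// Expected outputs of the helpers defined in ActiveModelPanel.tsx.
formatMetric(0.951)   // "95.1%" — fraction rendered as a percentage
formatMetric(null)    // "N/A"
getMetricColor(0.951) // "text-green-600"  (>= 0.9)
getMetricColor(0.85)  // "text-yellow-600" (>= 0.8, < 0.9)
getMetricColor(0.7)   // "text-red-600"
```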
frontend/src/components/dashboard/DataQualityPanel.tsx (new file, 105 lines)
@@ -0,0 +1,105 @@
import React from 'react'
import { Button } from '../Button'

interface DataQualityPanelProps {
  completenessRate: number
  completeCount: number
  incompleteCount: number
  pendingCount: number
  isLoading?: boolean
  onViewIncomplete?: () => void
}

export const DataQualityPanel: React.FC<DataQualityPanelProps> = ({
  completenessRate,
  completeCount,
  incompleteCount,
  pendingCount,
  isLoading = false,
  onViewIncomplete,
}) => {
  const radius = 54
  const circumference = 2 * Math.PI * radius
  const strokeDashoffset = circumference - (completenessRate / 100) * circumference

  return (
    <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
      <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
        Data Quality
      </h2>

      <div className="flex items-center gap-6">
        <div className="relative">
          <svg width="120" height="120" className="transform -rotate-90">
            <circle
              cx="60"
              cy="60"
              r={radius}
              stroke="#E5E7EB"
              strokeWidth="12"
              fill="none"
            />
            <circle
              cx="60"
              cy="60"
              r={radius}
              stroke="#22C55E"
              strokeWidth="12"
              fill="none"
              strokeLinecap="round"
              strokeDasharray={circumference}
              strokeDashoffset={isLoading ? circumference : strokeDashoffset}
              className="transition-all duration-500"
            />
          </svg>
          <div className="absolute inset-0 flex items-center justify-center">
            <span className="text-3xl font-bold text-warm-text-primary">
              {isLoading ? '...' : `${Math.round(completenessRate)}%`}
            </span>
          </div>
        </div>

        <div className="flex-1">
          <p className="text-sm text-warm-text-secondary mb-4">
            Annotation Complete
          </p>

          <div className="space-y-2">
            <div className="flex items-center justify-between text-sm">
              <span className="flex items-center gap-2">
                <span className="w-2 h-2 bg-green-500 rounded-full" />
                Complete
              </span>
              <span className="font-medium">{isLoading ? '...' : completeCount}</span>
            </div>
            <div className="flex items-center justify-between text-sm">
              <span className="flex items-center gap-2">
                <span className="w-2 h-2 bg-orange-500 rounded-full" />
                Incomplete
              </span>
              <span className="font-medium">{isLoading ? '...' : incompleteCount}</span>
            </div>
            <div className="flex items-center justify-between text-sm">
              <span className="flex items-center gap-2">
                <span className="w-2 h-2 bg-blue-500 rounded-full" />
                Pending
              </span>
              <span className="font-medium">{isLoading ? '...' : pendingCount}</span>
            </div>
          </div>
        </div>
      </div>

      {onViewIncomplete && incompleteCount > 0 && (
        <div className="mt-4 pt-4 border-t border-warm-border">
          <button
            onClick={onViewIncomplete}
            className="text-sm text-blue-600 hover:text-blue-800 font-medium"
          >
            View Incomplete Docs
          </button>
        </div>
      )}
    </div>
  )
}
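The ring geometry above is the standard SVG stroke-dash trick: the full circumference is set as the dash pattern, and the offset hides the unfilled fraction. A worked sketch for the 78% case (illustrative only):

```ts
// SVG progress-ring math, as used in DataQualityPanel.
const radius = 54
const circumference = 2 * Math.PI * radius                    // ≈ 339.29
const completenessRate = 78                                   // percent
const strokeDashoffset =
  circumference - (completenessRate / 100) * circumference    // ≈ 74.64
// With strokeDasharray = circumference, the visible arc covers 78% of the
// circle; while loading, offset = circumference hides the arc entirely.
```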
frontend/src/components/dashboard/RecentActivityPanel.tsx (new file, 134 lines)
@@ -0,0 +1,134 @@
import React from 'react'
import {
  FileText,
  Edit,
  CheckCircle,
  XCircle,
  Rocket,
  Activity,
} from 'lucide-react'
import type { ActivityItem } from '../../api/types'

interface RecentActivityPanelProps {
  activities: ActivityItem[]
  isLoading?: boolean
  onSeeAll?: () => void
}

const getActivityIcon = (type: ActivityItem['type']) => {
  switch (type) {
    case 'document_uploaded':
      return { Icon: FileText, color: 'text-blue-500', bg: 'bg-blue-50' }
    case 'annotation_modified':
      return { Icon: Edit, color: 'text-orange-500', bg: 'bg-orange-50' }
    case 'training_completed':
      return { Icon: CheckCircle, color: 'text-green-500', bg: 'bg-green-50' }
    case 'training_failed':
      return { Icon: XCircle, color: 'text-red-500', bg: 'bg-red-50' }
    case 'model_activated':
      return { Icon: Rocket, color: 'text-purple-500', bg: 'bg-purple-50' }
    default:
      return { Icon: Activity, color: 'text-gray-500', bg: 'bg-gray-50' }
  }
}

const formatTimestamp = (timestamp: string): string => {
  const date = new Date(timestamp)
  const now = new Date()
  const diffMs = now.getTime() - date.getTime()
  const diffMinutes = Math.floor(diffMs / 60000)
  const diffHours = Math.floor(diffMs / 3600000)
  const diffDays = Math.floor(diffMs / 86400000)

  if (diffMinutes < 1) return 'just now'
  if (diffMinutes < 60) return `${diffMinutes} minutes ago`
  if (diffHours < 24) return `${diffHours} hours ago`
  if (diffDays === 1) return 'yesterday'
  if (diffDays < 7) return `${diffDays} days ago`

  return date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' })
}

export const RecentActivityPanel: React.FC<RecentActivityPanelProps> = ({
  activities,
  isLoading = false,
  onSeeAll,
}) => {
  if (isLoading) {
    return (
      <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
        <div className="p-6 border-b border-warm-border flex items-center justify-between">
          <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
            Recent Activity
          </h2>
        </div>
        <div className="p-6">
          <div className="flex items-center justify-center py-8">
            <div className="animate-pulse text-warm-text-muted">Loading...</div>
          </div>
        </div>
      </div>
    )
  }

  if (activities.length === 0) {
    return (
      <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
        <div className="p-6 border-b border-warm-border">
          <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
            Recent Activity
          </h2>
        </div>
        <div className="p-6">
          <div className="flex flex-col items-center justify-center py-8 text-center">
            <Activity className="w-12 h-12 text-warm-text-disabled mb-3 opacity-20" />
            <p className="text-warm-text-primary font-medium mb-1">No recent activity</p>
            <p className="text-sm text-warm-text-muted">
              Start by uploading documents or creating training jobs
            </p>
          </div>
        </div>
      </div>
    )
  }

  return (
    <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
      <div className="p-6 border-b border-warm-border flex items-center justify-between">
        <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
          Recent Activity
        </h2>
        {onSeeAll && (
          <button
            onClick={onSeeAll}
            className="text-sm text-blue-600 hover:text-blue-800 font-medium"
          >
            See All
          </button>
        )}
      </div>
      <div className="divide-y divide-warm-border">
        {activities.map((activity, index) => {
          const { Icon, color, bg } = getActivityIcon(activity.type)

          return (
            <div
              key={`${activity.type}-${activity.timestamp}-${index}`}
              className="px-6 py-3 flex items-center gap-4 hover:bg-warm-hover transition-colors"
            >
              <div className={`p-2 rounded-lg ${bg}`}>
                <Icon className={color} size={16} />
              </div>
              <p className="flex-1 text-sm text-warm-text-primary truncate">
                {activity.description}
              </p>
              <span className="text-xs text-warm-text-muted whitespace-nowrap">
                {formatTimestamp(activity.timestamp)}
              </span>
            </div>
          )
        })}
      </div>
    </div>
  )
}
frontend/src/components/dashboard/StatsCard.tsx (new file, 44 lines)
@@ -0,0 +1,44 @@
import React from 'react'
import { LucideIcon } from 'lucide-react'

interface StatsCardProps {
  label: string
  value: string | number
  icon: LucideIcon
  iconColor: string
  iconBgColor: string
  onClick?: () => void
  isLoading?: boolean
}

export const StatsCard: React.FC<StatsCardProps> = ({
  label,
  value,
  icon: Icon,
  iconColor,
  iconBgColor,
  onClick,
  isLoading = false,
}) => {
  return (
    <div
      className={`bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow ${
        onClick ? 'cursor-pointer' : ''
      }`}
      onClick={onClick}
      role={onClick ? 'button' : undefined}
      tabIndex={onClick ? 0 : undefined}
      onKeyDown={onClick ? (e) => e.key === 'Enter' && onClick() : undefined}
    >
      <div className="flex items-center justify-between mb-4">
        <div className={`p-3 rounded-lg ${iconBgColor}`}>
          <Icon className={iconColor} size={24} />
        </div>
      </div>
      <p className="text-2xl font-bold text-warm-text-primary mb-1">
        {isLoading ? '...' : value}
      </p>
      <p className="text-sm text-warm-text-muted">{label}</p>
    </div>
  )
}
frontend/src/components/dashboard/SystemStatusBar.tsx (new file, 62 lines)
@@ -0,0 +1,62 @@
import React from 'react'

interface StatusItem {
  label: string
  status: 'online' | 'degraded' | 'offline'
  statusText: string
}

interface SystemStatusBarProps {
  items?: StatusItem[]
}

const getStatusColor = (status: StatusItem['status']) => {
  switch (status) {
    case 'online':
      return 'bg-green-500'
    case 'degraded':
      return 'bg-yellow-500'
    case 'offline':
      return 'bg-red-500'
  }
}

const getStatusTextColor = (status: StatusItem['status']) => {
  switch (status) {
    case 'online':
      return 'text-warm-state-success'
    case 'degraded':
      return 'text-yellow-600'
    case 'offline':
      return 'text-red-600'
  }
}

const defaultItems: StatusItem[] = [
  { label: 'Backend API', status: 'online', statusText: 'Online' },
  { label: 'Database', status: 'online', statusText: 'Connected' },
  { label: 'GPU', status: 'online', statusText: 'Available' },
]

export const SystemStatusBar: React.FC<SystemStatusBarProps> = ({
  items = defaultItems,
}) => {
  return (
    <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
      <h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
        System Status
      </h2>
      <div className="space-y-3">
        {items.map((item) => (
          <div key={item.label} className="flex items-center justify-between">
            <span className="text-sm text-warm-text-secondary">{item.label}</span>
            <span className={`flex items-center text-sm ${getStatusTextColor(item.status)}`}>
              <span className={`w-2 h-2 ${getStatusColor(item.status)} rounded-full mr-2`} />
              {item.statusText}
            </span>
          </div>
        ))}
      </div>
    </div>
  )
}
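The `items` prop makes the bar reusable beyond the all-green defaults; a usage sketch (the degraded GPU entry is invented for illustration):

```tsx
// Hypothetical usage: overriding the default status items.
<SystemStatusBar
  items={[
    { label: 'Backend API', status: 'online', statusText: 'Online' },
    { label: 'Database', status: 'online', statusText: 'Connected' },
    { label: 'GPU', status: 'degraded', statusText: 'Busy' },
  ]}
/>
```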
frontend/src/components/dashboard/index.ts (new file, 5 lines)
@@ -0,0 +1,5 @@
export { StatsCard } from './StatsCard'
export { DataQualityPanel } from './DataQualityPanel'
export { ActiveModelPanel } from './ActiveModelPanel'
export { RecentActivityPanel } from './RecentActivityPanel'
export { SystemStatusBar } from './SystemStatusBar'
@@ -5,3 +5,4 @@ export { useTraining, useTrainingDocuments } from './useTraining'
 export { useDatasets, useDatasetDetail } from './useDatasets'
 export { useAugmentation } from './useAugmentation'
 export { useModels, useModelDetail, useActiveModel } from './useModels'
+export { useDashboard, useDashboardStats, useActiveModel as useDashboardActiveModel, useRecentActivity } from './useDashboard'
frontend/src/hooks/useDashboard.ts (new file, 76 lines)
@@ -0,0 +1,76 @@
import { useQuery } from '@tanstack/react-query'
import { dashboardApi } from '../api/endpoints'
import type {
  DashboardStatsResponse,
  DashboardActiveModelResponse,
  RecentActivityResponse,
} from '../api/types'

export const useDashboardStats = () => {
  const { data, isLoading, error, refetch } = useQuery<DashboardStatsResponse>({
    queryKey: ['dashboard', 'stats'],
    queryFn: () => dashboardApi.getStats(),
    staleTime: 30000,
    refetchInterval: 60000,
  })

  return {
    stats: data,
    isLoading,
    error,
    refetch,
  }
}

export const useActiveModel = () => {
  const { data, isLoading, error, refetch } = useQuery<DashboardActiveModelResponse>({
    queryKey: ['dashboard', 'active-model'],
    queryFn: () => dashboardApi.getActiveModel(),
    staleTime: 30000,
    refetchInterval: 60000,
  })

  return {
    model: data?.model ?? null,
    runningTraining: data?.running_training ?? null,
    isLoading,
    error,
    refetch,
  }
}

export const useRecentActivity = (limit: number = 10) => {
  const { data, isLoading, error, refetch } = useQuery<RecentActivityResponse>({
    queryKey: ['dashboard', 'activity', limit],
    queryFn: () => dashboardApi.getRecentActivity(limit),
    staleTime: 30000,
    refetchInterval: 60000,
  })

  return {
    activities: data?.activities ?? [],
    isLoading,
    error,
    refetch,
  }
}

export const useDashboard = () => {
  const stats = useDashboardStats()
  const activeModel = useActiveModel()
  const activity = useRecentActivity()

  return {
    stats: stats.stats,
    model: activeModel.model,
    runningTraining: activeModel.runningTraining,
    activities: activity.activities,
    isLoading: stats.isLoading || activeModel.isLoading || activity.isLoading,
    error: stats.error || activeModel.error || activity.error,
    refetch: () => {
      stats.refetch()
      activeModel.refetch()
      activity.refetch()
    },
  }
}
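A minimal component sketch showing how the combined hook is consumed (hypothetical; `DashboardSummary` is not part of the commit — `DashboardOverview` above is the real consumer):

```tsx
// Hypothetical consumer of useDashboard.
import React from 'react'
import { useDashboard } from './useDashboard'

export const DashboardSummary: React.FC = () => {
  const { stats, activities, isLoading, refetch } = useDashboard()

  if (isLoading) return <p>Loading…</p>
  return (
    <div>
      <p>
        {stats?.total_documents ?? 0} documents, {activities.length} recent events
      </p>
      <button onClick={() => refetch()}>Refresh</button>
    </div>
  )
}
```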
@@ -175,6 +175,80 @@ def run_migrations() -> None:
             );
             """,
         ),
+        # Migration 007: Add extra columns to training_tasks
+        (
+            "training_tasks_name",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
+            UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
+            ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
+            CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
+            """,
+        ),
+        (
+            "training_tasks_description",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
+            """,
+        ),
+        (
+            "training_tasks_admin_token",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
+            """,
+        ),
+        (
+            "training_tasks_task_type",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
+            """,
+        ),
+        (
+            "training_tasks_recurring",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
+            """,
+        ),
+        (
+            "training_tasks_metrics",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
+            CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
+            """,
+        ),
+        (
+            "training_tasks_updated_at",
+            """
+            ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
+            """,
+        ),
+        # Migration 008: Fix model_versions foreign key constraints
+        (
+            "model_versions_fk_fix",
+            """
+            ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
+            ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
+            ALTER TABLE model_versions
+                ADD CONSTRAINT model_versions_dataset_id_fkey
+                FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
+            ALTER TABLE model_versions
+                ADD CONSTRAINT model_versions_task_id_fkey
+                FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
+            """,
+        ),
+        # Migration 006b: Ensure only one active model at a time
+        (
+            "model_versions_single_active",
+            """
+            CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
+            ON model_versions(is_active) WHERE is_active = TRUE;
+            """,
+        ),
     ]

     with engine.connect() as conn:
@@ -193,6 +193,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
             annotation = session.get(AdminAnnotation, UUID(annotation_id))
             if annotation:
                 session.delete(annotation)
+                session.commit()
                 return True
             return False

@@ -216,6 +217,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
             count = len(annotations)
             for ann in annotations:
                 session.delete(ann)
+            session.commit()
             return count

     def verify(
@@ -203,6 +203,14 @@ class DatasetRepository(BaseRepository[TrainingDataset]):
             dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
             if not dataset:
                 return False
+            # Delete associated document links first
+            doc_links = session.exec(
+                select(DatasetDocument).where(
+                    DatasetDocument.dataset_id == UUID(str(dataset_id))
+                )
+            ).all()
+            for link in doc_links:
+                session.delete(link)
             session.delete(dataset)
             session.commit()
             return True
@@ -264,6 +264,7 @@ class DocumentRepository(BaseRepository[AdminDocument]):
             for ann in annotations:
                 session.delete(ann)
             session.delete(document)
+            session.commit()
             return True
         return False

@@ -389,7 +390,11 @@ class DocumentRepository(BaseRepository[AdminDocument]):
             return None

         now = datetime.now(timezone.utc)
-        if doc.annotation_lock_until and doc.annotation_lock_until > now:
+        lock_until = doc.annotation_lock_until
+        # Handle PostgreSQL returning offset-naive datetimes
+        if lock_until and lock_until.tzinfo is None:
+            lock_until = lock_until.replace(tzinfo=timezone.utc)
+        if lock_until and lock_until > now:
             return None

         doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
@@ -433,10 +438,14 @@ class DocumentRepository(BaseRepository[AdminDocument]):
             return None

         now = datetime.now(timezone.utc)
-        if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
+        lock_until = doc.annotation_lock_until
+        # Handle PostgreSQL returning offset-naive datetimes
+        if lock_until and lock_until.tzinfo is None:
+            lock_until = lock_until.replace(tzinfo=timezone.utc)
+        if not lock_until or lock_until <= now:
             return None

-        doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
+        doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds)
         session.add(doc)
         session.commit()
         session.refresh(doc)
@@ -118,6 +118,22 @@ class TrainingTaskRepository(BaseRepository[TrainingTask]):
                 session.expunge(r)
             return list(results)

+    def get_running(self) -> TrainingTask | None:
+        """Get currently running training task.
+
+        Returns:
+            Running task or None if no task is running
+        """
+        with get_session_context() as session:
+            result = session.exec(
+                select(TrainingTask)
+                .where(TrainingTask.status == "running")
+                .order_by(TrainingTask.started_at.desc())
+            ).first()
+            if result:
+                session.expunge(result)
+            return result
+
     def update_status(
         self,
         task_id: str,
@@ -55,5 +55,6 @@ def create_normalizer_registry(
         "Amount": amount_normalizer,
         "InvoiceDate": date_normalizer,
         "InvoiceDueDate": date_normalizer,
-        "supplier_org_number": SupplierOrgNumberNormalizer(),
+        # Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping)
+        "supplier_organisation_number": SupplierOrgNumberNormalizer(),
     }
@@ -481,11 +481,22 @@ def create_annotation_router() -> APIRouter:
                 detail="At least one field value is required",
             )

+        # Get the actual file path from storage
+        # document.file_path is a relative storage path like "raw_pdfs/uuid.pdf"
+        storage = get_storage_helper()
+        filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path
+        file_path = storage.get_raw_pdf_local_path(filename)
+        if file_path is None:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Cannot find PDF file: {document.file_path}",
+            )
+
         # Run auto-labeling
         service = get_auto_label_service()
         result = service.auto_label_document(
             document_id=document_id,
-            file_path=document.file_path,
+            file_path=str(file_path),
             field_values=request.field_values,
             doc_repo=doc_repo,
             ann_repo=ann_repo,
@@ -6,7 +6,7 @@ FastAPI endpoints for admin token management.

 import logging
 import secrets
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone

 from fastapi import APIRouter

@@ -41,10 +41,10 @@ def create_auth_router() -> APIRouter:
         # Generate secure token
         token = secrets.token_urlsafe(32)

-        # Calculate expiration
+        # Calculate expiration (use timezone-aware datetime)
         expires_at = None
         if request.expires_in_days:
-            expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
+            expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days)

         # Create token in database
         tokens.create(
packages/inference/inference/web/api/v1/admin/dashboard.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""
Dashboard API Routes

FastAPI endpoints for dashboard statistics and activity.
"""

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, Query

from inference.web.core.auth import (
    AdminTokenDep,
    get_model_version_repository,
    get_training_task_repository,
    ModelVersionRepoDep,
    TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
    DashboardStatsResponse,
    ActiveModelResponse,
    ActiveModelInfo,
    RunningTrainingInfo,
    RecentActivityResponse,
    ActivityItem,
)
from inference.web.services.dashboard_service import (
    DashboardStatsService,
    DashboardActivityService,
)

logger = logging.getLogger(__name__)


def create_dashboard_router() -> APIRouter:
    """Create dashboard API router."""
    router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])

    @router.get(
        "/stats",
        response_model=DashboardStatsResponse,
        summary="Get dashboard statistics",
        description="Returns document counts and annotation completeness metrics.",
    )
    async def get_dashboard_stats(
        admin_token: AdminTokenDep,
    ) -> DashboardStatsResponse:
        """Get dashboard statistics."""
        service = DashboardStatsService()
        stats = service.get_stats()

        return DashboardStatsResponse(
            total_documents=stats["total_documents"],
            annotation_complete=stats["annotation_complete"],
            annotation_incomplete=stats["annotation_incomplete"],
            pending=stats["pending"],
            completeness_rate=stats["completeness_rate"],
        )

    @router.get(
        "/active-model",
        response_model=ActiveModelResponse,
        summary="Get active model info",
        description="Returns current active model and running training status.",
    )
    async def get_active_model(
        admin_token: AdminTokenDep,
        model_repo: ModelVersionRepoDep,
        task_repo: TrainingTaskRepoDep,
    ) -> ActiveModelResponse:
        """Get active model and training status."""
        # Get active model
        active_model = model_repo.get_active()
        model_info = None

        if active_model:
            model_info = ActiveModelInfo(
                version_id=str(active_model.version_id),
                version=active_model.version,
                name=active_model.name,
                metrics_mAP=active_model.metrics_mAP,
                metrics_precision=active_model.metrics_precision,
                metrics_recall=active_model.metrics_recall,
                document_count=active_model.document_count,
                activated_at=active_model.activated_at,
            )

        # Get running training task
        running_task = task_repo.get_running()
        training_info = None

        if running_task:
            training_info = RunningTrainingInfo(
                task_id=str(running_task.task_id),
                name=running_task.name,
                status=running_task.status,
                started_at=running_task.started_at,
                progress=running_task.progress or 0,
            )

        return ActiveModelResponse(
            model=model_info,
            running_training=training_info,
        )

    @router.get(
        "/activity",
        response_model=RecentActivityResponse,
        summary="Get recent activity",
        description="Returns recent system activities sorted by timestamp.",
    )
    async def get_recent_activity(
        admin_token: AdminTokenDep,
        limit: Annotated[
            int,
            Query(ge=1, le=50, description="Maximum number of activities"),
        ] = 10,
    ) -> RecentActivityResponse:
        """Get recent system activity."""
        service = DashboardActivityService()
        activities = service.get_recent_activities(limit=limit)

        return RecentActivityResponse(
            activities=[
                ActivityItem(
                    type=act["type"],
                    description=act["description"],
                    timestamp=act["timestamp"],
                    metadata=act["metadata"],
                )
                for act in activities
            ]
        )

    return router
@@ -44,6 +44,7 @@ from inference.web.api.v1.admin import (
     create_locks_router,
     create_training_router,
 )
+from inference.web.api.v1.admin.dashboard import create_dashboard_router
 from inference.web.core.scheduler import start_scheduler, stop_scheduler
 from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler

@@ -115,13 +116,21 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
         """Application lifespan manager."""
         logger.info("Starting Invoice Inference API...")

-        # Initialize database tables
+        # Initialize async request database tables
         try:
             async_db.create_tables()
             logger.info("Async database tables ready")
         except Exception as e:
             logger.error(f"Failed to initialize async database: {e}")

+        # Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.)
+        try:
+            from inference.data.database import create_db_and_tables
+            create_db_and_tables()
+            logger.info("Admin database tables ready")
+        except Exception as e:
+            logger.error(f"Failed to initialize admin database: {e}")
+
         # Initialize inference service on startup
         try:
             inference_service.initialize()
@@ -279,6 +288,10 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
     augmentation_router = create_augmentation_router()
     app.include_router(augmentation_router, prefix="/api/v1/admin")

+    # Include dashboard routes
+    dashboard_router = create_dashboard_router()
+    app.include_router(dashboard_router, prefix="/api/v1")
+
     # Include batch upload routes
     app.include_router(batch_upload_router)
@@ -11,6 +11,7 @@ from .annotations import *  # noqa: F401, F403
 from .training import *  # noqa: F401, F403
 from .datasets import *  # noqa: F401, F403
 from .models import *  # noqa: F401, F403
+from .dashboard import *  # noqa: F401, F403

 # Resolve forward references for DocumentDetailResponse
 from .documents import DocumentDetailResponse
packages/inference/inference/web/schemas/admin/dashboard.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""
Dashboard API Schemas

Pydantic models for dashboard statistics and activity endpoints.
"""

from datetime import datetime
from typing import Any, Literal

from pydantic import BaseModel, Field


# Activity type literals for type safety
ActivityType = Literal[
    "document_uploaded",
    "annotation_modified",
    "training_completed",
    "training_failed",
    "model_activated",
]


class DashboardStatsResponse(BaseModel):
    """Response for dashboard statistics."""

    total_documents: int = Field(..., description="Total number of documents")
    annotation_complete: int = Field(
        ..., description="Documents with complete annotations"
    )
    annotation_incomplete: int = Field(
        ..., description="Documents with incomplete annotations"
    )
    pending: int = Field(..., description="Documents pending processing")
    completeness_rate: float = Field(
        ..., description="Annotation completeness percentage"
    )


class ActiveModelInfo(BaseModel):
    """Active model information."""

    version_id: str = Field(..., description="Model version UUID")
    version: str = Field(..., description="Model version string")
    name: str = Field(..., description="Model name")
    metrics_mAP: float | None = Field(None, description="Mean Average Precision")
    metrics_precision: float | None = Field(None, description="Precision score")
    metrics_recall: float | None = Field(None, description="Recall score")
    document_count: int = Field(0, description="Number of training documents")
    activated_at: datetime | None = Field(None, description="Activation timestamp")


class RunningTrainingInfo(BaseModel):
    """Running training task information."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Training task name")
    status: str = Field(..., description="Training status")
    started_at: datetime | None = Field(None, description="Start timestamp")
    progress: int = Field(0, description="Training progress percentage")


class ActiveModelResponse(BaseModel):
    """Response for active model endpoint."""

    model: ActiveModelInfo | None = Field(
        None, description="Active model info, null if none"
    )
    running_training: RunningTrainingInfo | None = Field(
        None, description="Running training task, null if none"
    )


class ActivityItem(BaseModel):
    """Single activity item."""

    type: ActivityType = Field(
        ...,
        description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated",
    )
    description: str = Field(..., description="Human-readable description")
    timestamp: datetime = Field(..., description="Activity timestamp")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )


class RecentActivityResponse(BaseModel):
    """Response for recent activity endpoint."""

    activities: list[ActivityItem] = Field(
        default_factory=list, description="List of recent activities"
    )
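A quick sketch of these schemas in use, assuming Pydantic v2 (`model_dump_json`); an out-of-range `type` literal fails validation:

```
# Sketch: building and serializing a recent-activity payload.
from datetime import datetime, timezone

from inference.web.schemas.admin.dashboard import ActivityItem, RecentActivityResponse

item = ActivityItem(
    type="model_activated",
    description="Activated model v1.2.0",
    timestamp=datetime.now(timezone.utc),
    metadata={"version": "v1.2.0"},
)
payload = RecentActivityResponse(activities=[item])
print(payload.model_dump_json(indent=2))  # Pydantic v2 serialization

# ActivityItem(type="document_deleted", ...) would raise a ValidationError,
# since "document_deleted" is not in the ActivityType literal.
```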
@@ -291,7 +291,7 @@ class AutoLabelService:
                "bbox_y": bbox_y,
                "bbox_width": bbox_width,
                "bbox_height": bbox_height,
-               "text_value": best_match.matched_value,
+               "text_value": best_match.matched_text,
                "confidence": best_match.score,
                "source": "auto",
            })
276
packages/inference/inference/web/services/dashboard_service.py
Normal file
@@ -0,0 +1,276 @@
"""
Dashboard Service

Business logic for dashboard statistics and activity aggregation.
"""

import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from sqlalchemy import func, exists, and_, or_
from sqlmodel import select

from inference.data.database import get_session_context
from inference.data.admin_models import (
    AdminDocument,
    AdminAnnotation,
    AnnotationHistory,
    TrainingTask,
    ModelVersion,
)

logger = logging.getLogger(__name__)

# Field class IDs for completeness calculation
# Identifiers: invoice_number (0) or ocr_number (3)
IDENTIFIER_CLASS_IDS = {0, 3}
# Payment accounts: bankgiro (4) or plusgiro (5)
PAYMENT_CLASS_IDS = {4, 5}


def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool:
    """Check if a document's annotations are complete.

    A document is complete if it has:
    - At least one identifier field (invoice_number OR ocr_number)
    - At least one payment field (bankgiro OR plusgiro)

    Args:
        annotations: List of annotation dicts with class_id

    Returns:
        True if document has required fields
    """
    class_ids = {ann.get("class_id") for ann in annotations}

    has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
    has_payment = bool(class_ids & PAYMENT_CLASS_IDS)

    return has_identifier and has_payment
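A worked example of the completeness rule: one identifier class plus one payment class suffices, regardless of any other fields present.

```
# Worked examples for is_annotation_complete.
from inference.web.services.dashboard_service import is_annotation_complete

# Complete: identifier (invoice_number, 0) and payment (bankgiro, 4).
assert is_annotation_complete([{"class_id": 0}, {"class_id": 4}]) is True

# Incomplete: identifier (ocr_number, 3) but no bankgiro/plusgiro.
assert is_annotation_complete([{"class_id": 3}, {"class_id": 1}]) is False

# An empty annotation list is trivially incomplete.
assert is_annotation_complete([]) is False
```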
class DashboardStatsService:
    """Service for computing dashboard statistics."""

    def get_stats(self) -> dict[str, Any]:
        """Get dashboard statistics.

        Returns:
            Dict with total_documents, annotation_complete, annotation_incomplete,
            pending, and completeness_rate
        """
        with get_session_context() as session:
            # Total documents
            total = session.exec(
                select(func.count()).select_from(AdminDocument)
            ).one()

            # Pending documents (status in ['pending', 'auto_labeling'])
            pending = session.exec(
                select(func.count())
                .select_from(AdminDocument)
                .where(AdminDocument.status.in_(["pending", "auto_labeling"]))
            ).one()

            # Complete annotations: labeled + has identifier + has payment
            complete = self._count_complete(session)

            # Incomplete: labeled but not complete
            labeled_count = session.exec(
                select(func.count())
                .select_from(AdminDocument)
                .where(AdminDocument.status == "labeled")
            ).one()
            incomplete = labeled_count - complete

            # Calculate completeness rate
            total_assessed = complete + incomplete
            completeness_rate = (
                round(complete / total_assessed * 100, 2)
                if total_assessed > 0
                else 0.0
            )

            return {
                "total_documents": total,
                "annotation_complete": complete,
                "annotation_incomplete": incomplete,
                "pending": pending,
                "completeness_rate": completeness_rate,
            }

    def _count_complete(self, session) -> int:
        """Count documents with complete annotations.

        A document is complete if it:
        1. Has status = 'labeled'
        2. Has at least one identifier annotation (class_id 0 or 3)
        3. Has at least one payment annotation (class_id 4 or 5)
        """
        # Subquery for documents with identifier
        has_identifier = exists(
            select(1)
            .select_from(AdminAnnotation)
            .where(
                and_(
                    AdminAnnotation.document_id == AdminDocument.document_id,
                    AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS),
                )
            )
        )

        # Subquery for documents with payment
        has_payment = exists(
            select(1)
            .select_from(AdminAnnotation)
            .where(
                and_(
                    AdminAnnotation.document_id == AdminDocument.document_id,
                    AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS),
                )
            )
        )

        count = session.exec(
            select(func.count())
            .select_from(AdminDocument)
            .where(
                and_(
                    AdminDocument.status == "labeled",
                    has_identifier,
                    has_payment,
                )
            )
        ).one()

        return count
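The correlated `EXISTS` subqueries count each document at most once, which a join against annotations would not guarantee (a join yields one row per matching annotation). A usage sketch of the service, assuming a configured database:

```
# Sketch: consuming the stats service, e.g. from a health-check script.
from inference.web.services.dashboard_service import DashboardStatsService

stats = DashboardStatsService().get_stats()
print(
    f"{stats['annotation_complete']}/{stats['total_documents']} documents complete "
    f"({stats['completeness_rate']}%), {stats['pending']} pending"
)
```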
class DashboardActivityService:
    """Service for aggregating recent activities."""

    def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]:
        """Get recent system activities.

        Aggregates from:
        - Document uploads
        - Annotation modifications
        - Training completions/failures
        - Model activations

        Args:
            limit: Maximum number of activities to return

        Returns:
            List of activity dicts sorted by timestamp DESC
        """
        activities = []

        with get_session_context() as session:
            # Document uploads (recent 10)
            uploads = session.exec(
                select(AdminDocument)
                .order_by(AdminDocument.created_at.desc())
                .limit(limit)
            ).all()

            for doc in uploads:
                activities.append({
                    "type": "document_uploaded",
                    "description": f"Uploaded {doc.filename}",
                    "timestamp": doc.created_at,
                    "metadata": {
                        "document_id": str(doc.document_id),
                        "filename": doc.filename,
                    },
                })

            # Annotation modifications (from history)
            modifications = session.exec(
                select(AnnotationHistory)
                .where(AnnotationHistory.action == "override")
                .order_by(AnnotationHistory.created_at.desc())
                .limit(limit)
            ).all()

            for mod in modifications:
                # Get document filename
                doc = session.get(AdminDocument, mod.document_id)
                filename = doc.filename if doc else "Unknown"
                field_name = ""
                if mod.new_value and isinstance(mod.new_value, dict):
                    field_name = mod.new_value.get("class_name", "")

                activities.append({
                    "type": "annotation_modified",
                    "description": f"Modified {filename} {field_name}".strip(),
                    "timestamp": mod.created_at,
                    "metadata": {
                        "annotation_id": str(mod.annotation_id),
                        "document_id": str(mod.document_id),
                        "field_name": field_name,
                    },
                })

            # Training completions and failures
            training_tasks = session.exec(
                select(TrainingTask)
                .where(TrainingTask.status.in_(["completed", "failed"]))
                .order_by(TrainingTask.updated_at.desc())
                .limit(limit)
            ).all()

            for task in training_tasks:
                if task.updated_at is None:
                    continue
                if task.status == "completed":
                    # Use metrics_mAP field directly
                    mAP = task.metrics_mAP or 0.0
                    activities.append({
                        "type": "training_completed",
                        "description": f"Training complete: {task.name}, mAP {mAP:.1%}",
                        "timestamp": task.updated_at,
                        "metadata": {
                            "task_id": str(task.task_id),
                            "task_name": task.name,
                            "mAP": mAP,
                        },
                    })
                else:
                    activities.append({
                        "type": "training_failed",
                        "description": f"Training failed: {task.name}",
                        "timestamp": task.updated_at,
                        "metadata": {
                            "task_id": str(task.task_id),
                            "task_name": task.name,
                            "error": task.error_message or "",
                        },
                    })

            # Model activations
            model_versions = session.exec(
                select(ModelVersion)
                .where(ModelVersion.activated_at.is_not(None))
                .order_by(ModelVersion.activated_at.desc())
                .limit(limit)
            ).all()

            for model in model_versions:
                if model.activated_at is None:
                    continue
                activities.append({
                    "type": "model_activated",
                    "description": f"Activated model {model.version}",
                    "timestamp": model.activated_at,
                    "metadata": {
                        "version_id": str(model.version_id),
                        "version": model.version,
                    },
                })

        # Sort all activities by timestamp DESC and return top N
        activities.sort(key=lambda x: x["timestamp"], reverse=True)
        return activities[:limit]
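Note the pattern here: the method over-fetches up to `limit` rows from each source, then merges and truncates. The global top-`limit` is always a subset of the union of the per-source top-`limit` sets, so the result stays correct even when one source dominates. A usage sketch:

```
# Sketch: rendering the merged activity feed.
from inference.web.services.dashboard_service import DashboardActivityService

for act in DashboardActivityService().get_recent_activities(limit=5):
    print(f"[{act['timestamp']:%Y-%m-%d %H:%M}] {act['type']}: {act['description']}")
```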
@@ -42,6 +42,7 @@ dev = [
    "black>=23.0.0",
    "ruff>=0.1.0",
    "mypy>=1.0.0",
    "testcontainers[postgres]>=4.0.0",
]
gpu = [
    "paddlepaddle-gpu>=2.5.0",
@@ -1,73 +0,0 @@
"""Run database migration for training_status fields."""
import psycopg2
import os

# Read password from .env file
password = ""
try:
    with open(".env") as f:
        for line in f:
            if line.startswith("DB_PASSWORD="):
                password = line.strip().split("=", 1)[1].strip('"').strip("'")
                break
except Exception as e:
    print(f"Error reading .env: {e}")

print(f"Password found: {bool(password)}")

conn = psycopg2.connect(
    host="192.168.68.31",
    port=5432,
    database="docmaster",
    user="docmaster",
    password=password
)
conn.autocommit = True
cur = conn.cursor()

# Add training_status column
try:
    cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL")
    print("Added training_status column")
except Exception as e:
    print(f"training_status: {e}")

# Add active_training_task_id column
try:
    cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL")
    print("Added active_training_task_id column")
except Exception as e:
    print(f"active_training_task_id: {e}")

# Create indexes
try:
    cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)")
    print("Created training_status index")
except Exception as e:
    print(f"index training_status: {e}")

try:
    cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)")
    print("Created active_training_task_id index")
except Exception as e:
    print(f"index active_training_task_id: {e}")

# Update existing datasets that have been used in completed training tasks to trained status
try:
    cur.execute("""
        UPDATE training_datasets d
        SET status = 'trained'
        WHERE d.status = 'ready'
        AND EXISTS (
            SELECT 1 FROM training_tasks t
            WHERE t.dataset_id = d.dataset_id
            AND t.status = 'completed'
        )
    """)
    print(f"Updated {cur.rowcount} datasets to trained status")
except Exception as e:
    print(f"update status: {e}")

cur.close()
conn.close()
print("Migration complete!")
1
tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
"""Integration tests for invoice-master-poc-v2."""
1
tests/integration/api/__init__.py
Normal file
@@ -0,0 +1 @@
"""API integration tests."""
389
tests/integration/api/test_api_integration.py
Normal file
@@ -0,0 +1,389 @@
"""
API Integration Tests

Tests FastAPI endpoints with mocked services.
These tests verify the API layer works correctly with the service layer.
"""

import io
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient


@dataclass
class MockServiceResult:
    """Mock result from inference service."""

    document_id: str = "test-doc-123"
    success: bool = True
    document_type: str = "invoice"
    fields: dict[str, str] = field(default_factory=lambda: {
        "InvoiceNumber": "INV-2024-001",
        "Amount": "1500.00",
        "InvoiceDate": "2024-01-15",
        "OCR": "12345678901234",
        "Bankgiro": "1234-5678",
    })
    confidence: dict[str, float] = field(default_factory=lambda: {
        "InvoiceNumber": 0.95,
        "Amount": 0.92,
        "InvoiceDate": 0.88,
        "OCR": 0.95,
        "Bankgiro": 0.90,
    })
    detections: list[dict[str, Any]] = field(default_factory=list)
    processing_time_ms: float = 150.5
    visualization_path: Path | None = None
    errors: list[str] = field(default_factory=list)


@pytest.fixture
def temp_storage_dir():
    """Create temporary storage directories."""
    with tempfile.TemporaryDirectory() as tmpdir:
        base = Path(tmpdir)
        uploads_dir = base / "uploads" / "inference"
        results_dir = base / "results"
        uploads_dir.mkdir(parents=True, exist_ok=True)
        results_dir.mkdir(parents=True, exist_ok=True)
        yield {
            "base": base,
            "uploads": uploads_dir,
            "results": results_dir,
        }


@pytest.fixture
def mock_inference_service():
    """Create a mock inference service."""
    service = MagicMock()
    service.is_initialized = True
    service.gpu_available = False

    # Create a realistic mock result
    mock_result = MockServiceResult()

    service.process_pdf.return_value = mock_result
    service.process_image.return_value = mock_result
    service.initialize.return_value = None

    return service


@pytest.fixture
def mock_storage_config(temp_storage_dir):
    """Create mock storage configuration."""
    from inference.web.config import StorageConfig

    return StorageConfig(
        upload_dir=temp_storage_dir["uploads"],
        result_dir=temp_storage_dir["results"],
        max_file_size_mb=50,
    )


@pytest.fixture
def mock_storage_helper(temp_storage_dir):
    """Create a mock storage helper."""
    helper = MagicMock()
    helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"]
    helper.get_result_local_path.return_value = None
    helper.result_exists.return_value = False
    return helper


@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
    """Create a test FastAPI application with mocked storage."""
    from inference.web.api.v1.public.inference import create_inference_router

    app = FastAPI()

    # Patch get_storage_helper to return our mock
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        inference_router = create_inference_router(mock_inference_service, mock_storage_config)
        app.include_router(inference_router)

    return app


@pytest.fixture
def client(test_app, mock_storage_helper):
    """Create a test client with storage helper patched."""
    with patch(
        "inference.web.api.v1.public.inference.get_storage_helper",
        return_value=mock_storage_helper,
    ):
        yield TestClient(test_app)


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_check(self, client, mock_inference_service):
        """Test health check returns status."""
        response = client.get("/api/v1/health")

        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert "model_loaded" in data


class TestInferenceEndpoint:
    """Tests for inference endpoint."""

    def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir):
        """Test PDF inference endpoint."""
        # Create a minimal PDF content
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data
        assert data["result"]["success"] is True
        assert "InvoiceNumber" in data["result"]["fields"]

    def test_infer_image(self, client, mock_inference_service, mock_storage_helper):
        """Test image inference endpoint."""
        # Create minimal PNG header
        png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.png", io.BytesIO(png_header), "image/png")},
            )

        assert response.status_code == 200
        data = response.json()
        assert "result" in data

    def test_infer_invalid_file_type(self, client, mock_storage_helper):
        """Test rejection of invalid file types."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
            )

        assert response.status_code == 400

    def test_infer_no_file(self, client, mock_storage_helper):
        """Test rejection when no file provided."""
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Validation error

    def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper):
        """Test that result has expected structure."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        result = data["result"]

        # Check required fields
        assert "document_id" in result
        assert "success" in result
        assert "fields" in result
        assert "confidence" in result
        assert "processing_time_ms" in result


class TestInferenceResultFormat:
    """Tests for inference result formatting."""

    def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper):
        """Test that fields are mapped to API response format."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        fields = data["result"]["fields"]

        assert fields["InvoiceNumber"] == "INV-2024-001"
        assert fields["Amount"] == "1500.00"
        assert fields["InvoiceDate"] == "2024-01-15"

    def test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper):
        """Test that confidence values are included."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        confidence = data["result"]["confidence"]

        assert "InvoiceNumber" in confidence
        assert confidence["InvoiceNumber"] == 0.95


class TestErrorHandling:
    """Tests for error handling in API."""

    def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper):
        """Test handling of service errors."""
        mock_inference_service.process_pdf.side_effect = Exception("Processing failed")

        pdf_content = b"%PDF-1.4\n%test\n"
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        # Should return error response
        assert response.status_code >= 400

    def test_empty_file_handling(self, client, mock_storage_helper):
        """Test handling of empty files."""
        # Empty file still has valid content type
        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")},
            )

        # Empty file may be processed or rejected depending on implementation
        # Just verify we get a response
        assert response.status_code in [200, 400, 422, 500]


class TestResponseFormat:
    """Tests for API response format consistency."""

    def test_success_response_format(self, client, mock_inference_service, mock_storage_helper):
        """Test successful response format."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"

        data = response.json()
        assert isinstance(data, dict)
        assert "result" in data

    def test_json_serialization(self, client, mock_inference_service, mock_storage_helper):
        """Test that all result fields are JSON serializable."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        # If this doesn't raise, JSON is valid
        data = response.json()
        assert data is not None


class TestDocumentIdGeneration:
    """Tests for document ID handling."""

    def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper):
        """Test that document ID is generated."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        assert "document_id" in data["result"]
        assert data["result"]["document_id"] is not None

    def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper):
        """Test document ID derived from filename."""
        pdf_content = b"%PDF-1.4\n%test\n"

        with patch(
            "inference.web.api.v1.public.inference.get_storage_helper",
            return_value=mock_storage_helper,
        ):
            response = client.post(
                "/api/v1/infer",
                files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")},
            )

        data = response.json()
        # Document ID should be set (either from filename or generated)
        assert data["result"]["document_id"] is not None
400
tests/integration/api/test_dashboard_api_integration.py
Normal file
@@ -0,0 +1,400 @@
"""
Dashboard API Integration Tests

Tests Dashboard API endpoints with real database operations via TestClient.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.data.admin_models import (
    AdminAnnotation,
    AdminDocument,
    AdminToken,
    AnnotationHistory,
    ModelVersion,
    TrainingDataset,
    TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep


def create_test_app(override_token_dep):
    """Create a FastAPI test application with dashboard router."""
    app = FastAPI()
    router = create_dashboard_router()
    app.include_router(router)

    # Override auth dependency
    app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep

    return app


class TestDashboardStatsEndpoint:
    """Tests for GET /admin/dashboard/stats endpoint."""

    def test_stats_empty_database(self, patched_session, admin_token):
        """Test stats endpoint with empty database."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 0
        assert data["annotation_complete"] == 0
        assert data["annotation_incomplete"] == 0
        assert data["pending"] == 0
        assert data["completeness_rate"] == 0.0

    def test_stats_with_pending_documents(self, patched_session, admin_token):
        """Test stats with pending documents."""
        session = patched_session

        # Create pending documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"pending_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/pending_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["total_documents"] == 3
        assert data["pending"] == 3

    def test_stats_with_complete_annotations(self, patched_session, admin_token):
        """Test stats with complete annotations."""
        session = patched_session

        # Create labeled document with complete annotations
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="complete.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/complete.pdf",
            page_count=1,
            status="labeled",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)
        session.commit()

        # Add identifier and payment annotations
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=0,  # invoice_number
            class_name="invoice_number",
            x_center=0.5, y_center=0.1, width=0.2, height=0.05,
            bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.add(AdminAnnotation(
            annotation_id=uuid4(),
            document_id=doc.document_id,
            page_number=1,
            class_id=4,  # bankgiro
            class_name="bankgiro",
            x_center=0.5, y_center=0.2, width=0.2, height=0.05,
            bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ))
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/stats")

        assert response.status_code == 200
        data = response.json()
        assert data["annotation_complete"] == 1
        assert data["completeness_rate"] == 100.0


class TestActiveModelEndpoint:
    """Tests for GET /admin/dashboard/active-model endpoint."""

    def test_active_model_none(self, patched_session, admin_token):
        """Test active-model endpoint with no active model."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["model"] is None
        assert data["running_training"] is None

    def test_active_model_with_model(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with active model."""
        session = patched_session

        # Create training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Test Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        # Create active model
        model = ModelVersion(
            version_id=uuid4(),
            version="1.0.0",
            name="Test Model",
            model_path="/models/test.pt",
            status="active",
            is_active=True,
            task_id=task.task_id,
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.90,
            metrics_precision=0.88,
            metrics_recall=0.85,
            document_count=100,
            file_size=50000000,
            activated_at=datetime.now(timezone.utc),
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(model)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["model"] is not None
        assert data["model"]["version"] == "1.0.0"
        assert data["model"]["name"] == "Test Model"
        assert data["model"]["metrics_mAP"] == 0.90

    def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset):
        """Test active-model endpoint with running training."""
        session = patched_session

        # Create running training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Running Task",
            status="running",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            started_at=datetime.now(timezone.utc),
            progress=50,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/active-model")

        assert response.status_code == 200
        data = response.json()
        assert data["running_training"] is not None
        assert data["running_training"]["name"] == "Running Task"
        assert data["running_training"]["status"] == "running"
        assert data["running_training"]["progress"] == 50


class TestRecentActivityEndpoint:
    """Tests for GET /admin/dashboard/activity endpoint."""

    def test_activity_empty(self, patched_session, admin_token):
        """Test activity endpoint with no activities."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        assert data["activities"] == []

    def test_activity_with_uploads(self, patched_session, admin_token):
        """Test activity includes document uploads."""
        session = patched_session

        # Create documents
        for i in range(3):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"activity_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/activity_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"]
        assert len(upload_activities) == 3

    def test_activity_limit_parameter(self, patched_session, admin_token):
        """Test activity limit parameter."""
        session = patched_session

        # Create many documents
        for i in range(15):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity?limit=5")

        assert response.status_code == 200
        data = response.json()
        assert len(data["activities"]) <= 5

    def test_activity_invalid_limit(self, patched_session, admin_token):
        """Test activity with invalid limit parameter."""
        app = create_test_app(admin_token.token)
        client = TestClient(app)

        # Limit too high
        response = client.get("/admin/dashboard/activity?limit=100")
        assert response.status_code == 422

        # Limit too low
        response = client.get("/admin/dashboard/activity?limit=0")
        assert response.status_code == 422

    def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset):
        """Test activity includes training completions."""
        session = patched_session

        # Create completed training task
        task = TrainingTask(
            task_id=uuid4(),
            admin_token=admin_token.token,
            name="Completed Task",
            status="completed",
            task_type="train",
            dataset_id=sample_dataset.dataset_id,
            metrics_mAP=0.95,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(task)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        training_activities = [a for a in data["activities"] if a["type"] == "training_completed"]
        assert len(training_activities) >= 1

    def test_activity_sorted_by_timestamp(self, patched_session, admin_token):
        """Test activities are sorted by timestamp descending."""
        session = patched_session

        # Create documents
        for i in range(5):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"sorted_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/sorted_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        app = create_test_app(admin_token.token)
        client = TestClient(app)

        response = client.get("/admin/dashboard/activity")

        assert response.status_code == 200
        data = response.json()
        timestamps = [a["timestamp"] for a in data["activities"]]
        assert timestamps == sorted(timestamps, reverse=True)
465
tests/integration/conftest.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Integration Test Fixtures
|
||||
|
||||
Provides shared fixtures for integration tests using PostgreSQL.
|
||||
|
||||
IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation.
|
||||
This ensures tests never touch the real production/development database.
|
||||
|
||||
Supported modes:
|
||||
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
|
||||
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)
|
||||
|
||||
To use an external test database, set:
|
||||
TEST_DB_URL=postgresql://user:password@host:port/test_dbname
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
AdminToken,
|
||||
AnnotationHistory,
|
||||
BatchUpload,
|
||||
BatchUploadFile,
|
||||
DatasetDocument,
|
||||
ModelVersion,
|
||||
TrainingDataset,
|
||||
TrainingDocumentLink,
|
||||
TrainingLog,
|
||||
TrainingTask,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _is_docker_available() -> bool:
|
||||
"""Check if Docker is available."""
|
||||
try:
|
||||
import docker
|
||||
client = docker.from_env()
|
||||
client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_test_db_url() -> str | None:
|
||||
"""Get test database URL from environment."""
|
||||
return os.environ.get("TEST_DB_URL")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_engine():
|
||||
"""Create a SQLAlchemy engine for testing.
|
||||
|
||||
Uses one of:
|
||||
1. TEST_DB_URL environment variable (dedicated test database)
|
||||
2. Docker testcontainers (if Docker is available)
|
||||
|
||||
IMPORTANT: Will NOT fall back to production database. If Docker is not
|
||||
available and TEST_DB_URL is not set, tests will fail with a clear error.
|
||||
|
||||
The engine is shared across all tests in a session for efficiency.
|
||||
"""
|
||||
# Try to get URL from environment first
|
||||
connection_url = _get_test_db_url()
|
||||
|
||||
if connection_url:
|
||||
# Use external test database from environment
|
||||
# Warn if it looks like a production database
|
||||
if "docmaster" in connection_url and "_test" not in connection_url:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"TEST_DB_URL appears to point to a production database. "
|
||||
"Please use a dedicated test database (e.g., docmaster_test).",
|
||||
UserWarning,
|
||||
)
|
||||
elif _is_docker_available():
|
||||
# Use testcontainers - this is the recommended approach
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
postgres = PostgresContainer("postgres:15-alpine")
|
||||
postgres.start()
|
||||
|
||||
connection_url = postgres.get_connection_url()
|
||||
if "psycopg2" in connection_url:
|
||||
connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
|
||||
|
||||
# Store container for cleanup
|
||||
test_engine._postgres_container = postgres
|
||||
else:
|
||||
# No Docker and no TEST_DB_URL - fail with clear instructions
|
||||
pytest.fail(
|
||||
"Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
|
||||
"Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
|
||||
" - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
|
||||
" - The testcontainers library will automatically create a PostgreSQL container.\n\n"
|
||||
"Option 2: Set TEST_DB_URL to a dedicated test database:\n"
|
||||
" - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
|
||||
" - NEVER use your production database for tests!\n\n"
|
||||
"Integration tests will NOT fall back to the production database."
|
||||
)
|
||||
|
||||
engine = create_engine(
|
||||
connection_url,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
yield engine
|
||||
|
||||
# Cleanup
|
||||
SQLModel.metadata.drop_all(engine)
|
||||
engine.dispose()
|
||||
|
||||
# Stop container if we started one
|
||||
if hasattr(test_engine, "_postgres_container"):
|
||||
test_engine._postgres_container.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session(test_engine) -> Generator[Session, None, None]:
|
||||
"""Provide a database session for each test function.
|
||||
|
||||
Each test gets a fresh session that rolls back after the test,
|
||||
ensuring test isolation.
|
||||
"""
|
||||
connection = test_engine.connect()
|
||||
transaction = connection.begin()
|
||||
session = Session(bind=connection)
|
||||
|
||||
yield session
|
||||
|
||||
# Rollback and cleanup
|
||||
session.close()
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def patched_session(db_session):
|
||||
"""Patch get_session_context to use the test session.
|
||||
|
||||
This allows repository classes to use the test database session
|
||||
instead of creating their own connections.
|
||||
|
||||
We need to patch in multiple locations because each repository module
|
||||
imports get_session_context directly.
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def mock_session_context() -> Generator[Session, None, None]:
|
||||
yield db_session
|
||||
|
||||
# All modules that import get_session_context
|
||||
patch_targets = [
|
||||
"inference.data.database.get_session_context",
|
||||
"inference.data.repositories.document_repository.get_session_context",
|
||||
"inference.data.repositories.annotation_repository.get_session_context",
|
||||
"inference.data.repositories.dataset_repository.get_session_context",
|
||||
"inference.data.repositories.training_task_repository.get_session_context",
|
||||
"inference.data.repositories.model_version_repository.get_session_context",
|
||||
"inference.data.repositories.batch_upload_repository.get_session_context",
|
||||
"inference.data.repositories.token_repository.get_session_context",
|
||||
"inference.web.services.dashboard_service.get_session_context",
|
||||
]
|
||||
|
||||
with ExitStack() as stack:
|
||||
for target in patch_targets:
|
||||
try:
|
||||
stack.enter_context(patch(target, mock_session_context))
|
||||
except (ModuleNotFoundError, AttributeError):
|
||||
# Skip if module doesn't exist or doesn't have the attribute
|
||||
pass
|
||||
yield db_session
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Data Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_token(db_session) -> AdminToken:
|
||||
"""Create a test admin token."""
|
||||
token = AdminToken(
|
||||
token="test-admin-token-12345",
|
||||
name="Test Admin",
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(token)
|
||||
db_session.commit()
|
||||
db_session.refresh(token)
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document(db_session, admin_token) -> AdminDocument:
|
||||
"""Create a sample document for testing."""
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename="test_invoice.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/uploads/test_invoice.pdf",
|
||||
page_count=1,
|
||||
status="pending",
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(doc)
|
||||
db_session.commit()
|
||||
db_session.refresh(doc)
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
|
||||
"""Create a sample annotation for testing."""
|
||||
annotation = AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=sample_document.document_id,
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=240,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
text_value="INV-2024-001",
|
||||
confidence=0.95,
|
||||
source="auto",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(annotation)
|
||||
db_session.commit()
|
||||
db_session.refresh(annotation)
|
||||
return annotation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset(db_session) -> TrainingDataset:
|
||||
"""Create a sample training dataset for testing."""
|
||||
dataset = TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="Test Dataset",
|
||||
description="Dataset for integration testing",
|
||||
status="building",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=0,
|
||||
total_images=0,
|
||||
total_annotations=0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(dataset)
|
||||
db_session.commit()
|
||||
db_session.refresh(dataset)
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
|
||||
"""Create a sample training task for testing."""
|
||||
task = TrainingTask(
|
||||
task_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
name="Test Training Task",
|
||||
description="Training task for integration testing",
|
||||
status="pending",
|
||||
task_type="train",
|
||||
dataset_id=sample_dataset.dataset_id,
|
||||
config={"epochs": 10, "batch_size": 16},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
|
||||
"""Create a sample model version for testing."""
|
||||
version = ModelVersion(
|
||||
version_id=uuid4(),
|
||||
version="1.0.0",
|
||||
name="Test Model v1",
|
||||
description="Model version for integration testing",
|
||||
model_path="/models/test_model.pt",
|
||||
status="inactive",
|
||||
is_active=False,
|
||||
task_id=sample_training_task.task_id,
|
||||
dataset_id=sample_dataset.dataset_id,
|
||||
metrics_mAP=0.85,
|
||||
        metrics_precision=0.88,
        metrics_recall=0.82,
        document_count=100,
        file_size=50000000,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(version)
    db_session.commit()
    db_session.refresh(version)
    return version


@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
    """Create a sample batch upload for testing."""
    batch = BatchUpload(
        batch_id=uuid4(),
        admin_token=admin_token.token,
        filename="test_batch.zip",
        file_size=10240,
        upload_source="api",
        status="processing",
        total_files=5,
        processed_files=0,
        successful_files=0,
        failed_files=0,
        created_at=datetime.now(timezone.utc),
    )
    db_session.add(batch)
    db_session.commit()
    db_session.refresh(batch)
    return batch


# =============================================================================
# Multiple Documents Fixture
# =============================================================================


@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
    """Create multiple documents for pagination/filtering tests."""
    documents = []
    statuses = ["pending", "pending", "labeled", "labeled", "exported"]
    categories = ["invoice", "invoice", "invoice", "letter", "invoice"]

    for i, (status, category) in enumerate(zip(statuses, categories)):
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename=f"test_doc_{i}.pdf",
            file_size=1024 + i * 100,
            content_type="application/pdf",
            file_path=f"/uploads/test_doc_{i}.pdf",
            page_count=1,
            status=status,
            upload_source="ui",
            category=category,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        db_session.add(doc)
        documents.append(doc)

    db_session.commit()
    for doc in documents:
        db_session.refresh(doc)

    return documents


# =============================================================================
# Temporary File Fixtures
# =============================================================================


@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for file uploads."""
    with tempfile.TemporaryDirectory() as tmpdir:
        upload_dir = Path(tmpdir) / "uploads"
        upload_dir.mkdir(parents=True, exist_ok=True)
        yield upload_dir


@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for model files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        model_dir = Path(tmpdir) / "models"
        model_dir.mkdir(parents=True, exist_ok=True)
        yield model_dir


@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for dataset files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset_dir = Path(tmpdir) / "datasets"
        dataset_dir.mkdir(parents=True, exist_ok=True)
        yield dataset_dir


# =============================================================================
# Sample PDF Fixture
# =============================================================================


@pytest.fixture
def sample_pdf_bytes() -> bytes:
    """Return minimal valid PDF bytes for testing."""
    # Minimal valid PDF structure
    return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
196
%%EOF"""


@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
    """Create a sample PDF file for testing."""
    pdf_path = temp_upload_dir / "test_invoice.pdf"
    pdf_path.write_bytes(sample_pdf_bytes)
    return pdf_path
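# Illustrative only: pytest resolves the whole fixture chain, so a test that
# accepts `sample_pdf_file` transparently gets `temp_upload_dir` and
# `sample_pdf_bytes` built first. A hypothetical test using it:
#
#     def test_sample_pdf_is_valid(sample_pdf_file):
#         data = sample_pdf_file.read_bytes()
#         assert data.startswith(b"%PDF-") and data.rstrip().endswith(b"%%EOF")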
1
tests/integration/pipeline/__init__.py
Normal file
@@ -0,0 +1 @@
"""Pipeline integration tests."""
456
tests/integration/pipeline/test_pipeline_integration.py
Normal file
@@ -0,0 +1,456 @@
"""
Inference Pipeline Integration Tests

Tests the complete pipeline from input to output.
Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
import numpy as np

from inference.pipeline.pipeline import (
    InferencePipeline,
    InferenceResult,
    CrossValidationResult,
)
from inference.pipeline.yolo_detector import Detection
from inference.pipeline.field_extractor import ExtractedField


@pytest.fixture
def mock_detection():
    """Create a mock detection."""
    return Detection(
        class_id=0,
        class_name="invoice_number",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
    )


@pytest.fixture
def mock_extracted_field():
    """Create a mock extracted field."""
    return ExtractedField(
        field_name="InvoiceNumber",
        raw_text="INV-2024-001",
        normalized_value="INV-2024-001",
        confidence=0.95,
        bbox=(100, 50, 200, 30),
        page_no=0,
        is_valid=True,
    )

class TestInferenceResultConstruction:
    """Tests for InferenceResult construction and methods."""

    def test_default_result(self):
        """Test default InferenceResult values."""
        result = InferenceResult()

        assert result.document_id is None
        assert result.success is False
        assert result.fields == {}
        assert result.confidence == {}
        assert result.raw_detections == []
        assert result.extracted_fields == []
        assert result.errors == []
        assert result.fallback_used is False
        assert result.cross_validation is None

    def test_result_to_json(self):
        """Test JSON serialization of result."""
        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={
                "InvoiceNumber": "INV-001",
                "Amount": "1500.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "Amount": 0.92,
            },
            bboxes={
                "InvoiceNumber": (100, 50, 200, 30),
            },
        )

        json_data = result.to_json()

        assert json_data["DocumentId"] == "test-doc"
        assert json_data["success"] is True
        assert json_data["InvoiceNumber"] == "INV-001"
        assert json_data["Amount"] == "1500.00"
        assert json_data["confidence"]["InvoiceNumber"] == 0.95
        assert "bboxes" in json_data

    def test_result_get_field(self):
        """Test getting field value and confidence."""
        result = InferenceResult(
            fields={"InvoiceNumber": "INV-001"},
            confidence={"InvoiceNumber": 0.95},
        )

        value, conf = result.get_field("InvoiceNumber")
        assert value == "INV-001"
        assert conf == 0.95

        value, conf = result.get_field("Amount")
        assert value is None
        assert conf == 0.0


class TestCrossValidation:
    """Tests for cross-validation logic."""

    def test_cross_validation_default(self):
        """Test default CrossValidationResult values."""
        cv = CrossValidationResult()

        assert cv.is_valid is False
        assert cv.ocr_match is None
        assert cv.amount_match is None
        assert cv.bankgiro_match is None
        assert cv.plusgiro_match is None
        assert cv.payment_line_ocr is None
        assert cv.payment_line_amount is None
        assert cv.details == []

    def test_cross_validation_with_matches(self):
        """Test CrossValidationResult with matches."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            bankgiro_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            payment_line_account="1234-5678",
            payment_line_account_type="bankgiro",
            details=["OCR match", "Amount match", "Bankgiro match"],
        )

        assert cv.is_valid is True
        assert cv.ocr_match is True
        assert cv.amount_match is True
        assert len(cv.details) == 3

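# Note on the construction pattern used below: patching __init__ to a no-op and
# allocating with InferencePipeline.__new__ yields a pipeline object with no
# models loaded, so individual private methods can be exercised in isolation.
# Each test then attaches by hand only the collaborators it needs
# (e.g. payment_line_parser).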
class TestPipelineMergeFields:
    """Tests for field merging logic."""

    def test_merge_selects_highest_confidence(self):
        """Test that merge selects highest confidence for duplicate fields."""
        # Create mock pipeline with minimal mocking
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.85,
                    detection_confidence=0.90,
                    ocr_confidence=0.85,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=True,
                ),
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="INV-001",
                    normalized_value="INV-001",
                    confidence=0.95,  # Higher confidence
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(105, 52, 198, 28),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert result.fields["InvoiceNumber"] == "INV-001"
            assert result.confidence["InvoiceNumber"] == 0.95

    def test_merge_skips_invalid_fields(self):
        """Test that merge skips invalid extracted fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)
            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)

            result = InferenceResult()
            result.extracted_fields = [
                ExtractedField(
                    field_name="InvoiceNumber",
                    raw_text="",
                    normalized_value=None,
                    confidence=0.95,
                    detection_confidence=0.95,
                    ocr_confidence=0.95,
                    bbox=(100, 50, 200, 30),
                    page_no=0,
                    is_valid=False,  # Invalid
                ),
                ExtractedField(
                    field_name="Amount",
                    raw_text="1500.00",
                    normalized_value="1500.00",
                    confidence=0.92,
                    detection_confidence=0.92,
                    ocr_confidence=0.92,
                    bbox=(200, 100, 100, 25),
                    page_no=0,
                    is_valid=True,
                ),
            ]

            pipeline._merge_fields(result)

            assert "InvoiceNumber" not in result.fields
            assert result.fields["Amount"] == "1500.00"

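# The payment_line strings below mimic a Swedish payment slip code line: the
# first group is the OCR reference, "1500 00" is the amount in kronor and öre,
# and the trailing group is the giro account number. (Reading inferred from the
# test data itself; the exact field semantics live in the payment line parser.)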
class TestPaymentLineValidation:
    """Tests for payment line cross-validation."""

    def test_payment_line_overrides_ocr(self):
        """Test that payment line OCR overrides detected OCR."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Mock payment line parser
            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "99999999999999",  # Different OCR
                },
                confidence={"OCR": 0.85},
            )

            pipeline._cross_validate_payment_line(result)

            # Payment line OCR should override
            assert result.fields["OCR"] == "12345678901234"
            assert result.confidence["OCR"] == 0.95

    def test_payment_line_overrides_amount(self):
        """Test that payment line amount overrides detected amount."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = None
            mock_parsed.amount = "2500.50"
            mock_parsed.account_number = None

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# ... # 2500 50 5 > ...",
                    "Amount": "2500.00",  # Slightly different
                },
                confidence={"Amount": 0.80},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.fields["Amount"] == "2500.50"
            assert result.confidence["Amount"] == 0.95

    def test_cross_validation_records_matches(self):
        """Test that cross-validation records match status."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            mock_parsed = MagicMock()
            mock_parsed.is_valid = True
            mock_parsed.ocr_number = "12345678901234"
            mock_parsed.amount = "1500.00"
            mock_parsed.account_number = "12345678"

            pipeline.payment_line_parser = MagicMock()
            pipeline.payment_line_parser.parse.return_value = mock_parsed

            result = InferenceResult(
                fields={
                    "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
                    "OCR": "12345678901234",  # Matching
                    "Amount": "1500.00",  # Matching
                    "Bankgiro": "1234-5678",  # Matching
                },
                confidence={},
            )

            pipeline._cross_validate_payment_line(result)

            assert result.cross_validation is not None
            assert result.cross_validation.ocr_match is True
            assert result.cross_validation.amount_match is True
            assert result.cross_validation.is_valid is True

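# Fallback semantics pinned below: with only one of the key fields (Amount,
# InvoiceNumber, OCR) present, the pipeline asks for pattern-based fallback;
# with all of them present it does not. The exact threshold is internal to
# _needs_fallback, so these tests exercise the two unambiguous ends.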
class TestFallbackLogic:
    """Tests for fallback detection logic."""

    def test_needs_fallback_when_key_fields_missing(self):
        """Test fallback is triggered when key fields missing."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Only one key field present
            result = InferenceResult(fields={"Amount": "1500.00"})

            assert pipeline._needs_fallback(result) is True

    def test_no_fallback_when_fields_present(self):
        """Test no fallback when key fields present."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # All key fields present
            result = InferenceResult(
                fields={
                    "Amount": "1500.00",
                    "InvoiceNumber": "INV-001",
                    "OCR": "12345678901234",
                }
            )

            assert pipeline._needs_fallback(result) is False


class TestPatternExtraction:
    """Tests for fallback pattern extraction."""

    def test_extract_amount_pattern(self):
        """Test amount extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Att betala: 1 500,00 SEK"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Amount" in result.fields
            assert result.confidence["Amount"] == 0.5

    def test_extract_bankgiro_pattern(self):
        """Test bankgiro extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Bankgiro: 1234-5678"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "Bankgiro" in result.fields
            assert result.fields["Bankgiro"] == "1234-5678"

    def test_extract_ocr_pattern(self):
        """Test OCR extraction with regex."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "OCR: 12345678901234567890"
            result = InferenceResult()

            pipeline._extract_with_patterns(text, result)

            assert "OCR" in result.fields
            assert result.fields["OCR"] == "12345678901234567890"

    def test_does_not_override_existing_fields(self):
        """Test pattern extraction doesn't override existing fields."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            text = "Fakturanr: 999"
            result = InferenceResult(fields={"InvoiceNumber": "INV-001"})

            pipeline._extract_with_patterns(text, result)

            # Should keep existing value
            assert result.fields["InvoiceNumber"] == "INV-001"

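# A minimal sketch of the normalization these tests assume (hypothetical; the
# real logic is InferencePipeline._normalize_amount_for_compare). Swedish
# amounts use space as thousands separator and comma as decimal mark:
#
#     def normalize_amount(text: str) -> float | None:
#         cleaned = text.replace(" ", "").replace(",", ".")
#         try:
#             return float(cleaned)
#         except ValueError:
#             return None  # e.g. "invalid" or ""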
class TestAmountNormalization:
    """Tests for amount normalization."""

    def test_normalize_swedish_format(self):
        """Test normalizing Swedish amount format."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            # Swedish format: space as thousands separator, comma as decimal
            assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
            # Standard format: dot as decimal
            assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
            # Swedish format with comma as decimal
            assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00

    def test_normalize_invalid_amount(self):
        """Test normalizing invalid amount returns None."""
        with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
            pipeline = InferencePipeline.__new__(InferencePipeline)

            assert pipeline._normalize_amount_for_compare("invalid") is None
            assert pipeline._normalize_amount_for_compare("") is None


class TestResultSerialization:
    """Tests for result serialization with cross-validation."""

    def test_to_json_with_cross_validation(self):
        """Test JSON serialization includes cross-validation."""
        cv = CrossValidationResult(
            is_valid=True,
            ocr_match=True,
            amount_match=True,
            payment_line_ocr="12345678901234",
            payment_line_amount="1500.00",
            details=["OCR match", "Amount match"],
        )

        result = InferenceResult(
            document_id="test-doc",
            success=True,
            fields={"InvoiceNumber": "INV-001"},
            cross_validation=cv,
        )

        json_data = result.to_json()

        assert "cross_validation" in json_data
        assert json_data["cross_validation"]["is_valid"] is True
        assert json_data["cross_validation"]["ocr_match"] is True
        assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"
1
tests/integration/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
"""Repository integration tests."""
@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests

Tests AnnotationRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.annotation_repository import AnnotationRepository


class TestAnnotationRepositoryCreate:
    """Tests for annotation creation."""

    def test_create_annotation(self, patched_session, sample_document):
        """Test creating a single annotation."""
        repo = AnnotationRepository()

        ann_id = repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=240,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-2024-001",
            confidence=0.95,
            source="auto",
        )

        assert ann_id is not None

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_name == "invoice_number"
        assert ann.text_value == "INV-2024-001"
        assert ann.confidence == 0.95
        assert ann.source == "auto"

    def test_create_batch_annotations(self, patched_session, sample_document):
        """Test batch creation of annotations."""
        repo = AnnotationRepository()

        annotations_data = [
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 0,
                "class_name": "invoice_number",
                "x_center": 0.5,
                "y_center": 0.1,
                "width": 0.2,
                "height": 0.05,
                "bbox_x": 400,
                "bbox_y": 80,
                "bbox_width": 160,
                "bbox_height": 40,
                "text_value": "INV-001",
                "confidence": 0.95,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.5,
                "y_center": 0.2,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 400,
                "bbox_y": 160,
                "bbox_width": 120,
                "bbox_height": 32,
                "text_value": "2024-01-15",
                "confidence": 0.92,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 6,
                "class_name": "amount",
                "x_center": 0.7,
                "y_center": 0.8,
                "width": 0.1,
                "height": 0.04,
                "bbox_x": 560,
                "bbox_y": 640,
                "bbox_width": 80,
                "bbox_height": 32,
                "text_value": "1500.00",
                "confidence": 0.98,
            },
        ]

        ids = repo.create_batch(annotations_data)

        assert len(ids) == 3

        # Verify all annotations exist
        for ann_id in ids:
            ann = repo.get(ann_id)
            assert ann is not None


class TestAnnotationRepositoryRead:
    """Tests for annotation retrieval."""

    def test_get_nonexistent_annotation(self, patched_session):
        """Test getting an annotation that doesn't exist."""
        repo = AnnotationRepository()

        ann = repo.get(str(uuid4()))
        assert ann is None

    def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotations for a document."""
        repo = AnnotationRepository()

        # Add another annotation
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=320,
            bbox_width=120,
            bbox_height=32,
            text_value="2024-01-15",
        )

        annotations = repo.get_for_document(str(sample_document.document_id))

        assert len(annotations) == 2
        # Should be ordered by class_id
        assert annotations[0].class_id == 0
        assert annotations[1].class_id == 1

    def test_get_annotations_for_specific_page(self, patched_session, sample_document):
        """Test getting annotations for a specific page."""
        repo = AnnotationRepository()

        # Create annotations on different pages
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=2,
            class_id=6,
            class_name="amount",
            x_center=0.7,
            y_center=0.8,
            width=0.1,
            height=0.04,
            bbox_x=560,
            bbox_y=640,
            bbox_width=80,
            bbox_height=32,
        )

        page1_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=1,
        )
        page2_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=2,
        )

        assert len(page1_annotations) == 1
        assert len(page2_annotations) == 1
        assert page1_annotations[0].page_number == 1
        assert page2_annotations[0].page_number == 2


class TestAnnotationRepositoryUpdate:
    """Tests for annotation updates."""

    def test_update_annotation_bbox(self, patched_session, sample_annotation):
        """Test updating annotation bounding box."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=480,
            bbox_y=320,
            bbox_width=200,
            bbox_height=48,
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.x_center == 0.6
        assert ann.y_center == 0.4
        assert ann.bbox_x == 480
        assert ann.bbox_width == 200

    def test_update_annotation_text(self, patched_session, sample_annotation):
        """Test updating annotation text value."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-2024-002",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.text_value == "INV-2024-002"

    def test_update_annotation_class(self, patched_session, sample_annotation):
        """Test updating annotation class."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            class_id=1,
            class_name="invoice_date",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.class_id == 1
        assert ann.class_name == "invoice_date"

    def test_update_nonexistent_annotation(self, patched_session):
        """Test updating annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.update(
            str(uuid4()),
            text_value="new value",
        )

        assert result is False


class TestAnnotationRepositoryDelete:
    """Tests for annotation deletion."""

    def test_delete_annotation(self, patched_session, sample_annotation):
        """Test deleting a single annotation."""
        repo = AnnotationRepository()

        result = repo.delete(str(sample_annotation.annotation_id))
        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is None

    def test_delete_nonexistent_annotation(self, patched_session):
        """Test deleting annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_annotations_for_document(self, patched_session, sample_document):
        """Test deleting all annotations for a document."""
        repo = AnnotationRepository()

        # Create multiple annotations
        for i in range(3):
            repo.create(
                document_id=str(sample_document.document_id),
                page_number=1,
                class_id=i,
                class_name=f"field_{i}",
                x_center=0.5,
                y_center=0.1 + i * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + i * 160,
                bbox_width=160,
                bbox_height=40,
            )

        # Delete all
        count = repo.delete_for_document(str(sample_document.document_id))

        assert count == 3

        annotations = repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_annotations_by_source(self, patched_session, sample_document):
        """Test deleting annotations by source type."""
        repo = AnnotationRepository()

        # Create auto and manual annotations
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            source="auto",
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.2,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=160,
            bbox_width=120,
            bbox_height=32,
            source="manual",
        )

        # Delete only auto annotations
        count = repo.delete_for_document(str(sample_document.document_id), source="auto")

        assert count == 1

        remaining = repo.get_for_document(str(sample_document.document_id))
        assert len(remaining) == 1
        assert remaining[0].source == "manual"

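# Assumed verify() semantics, inferred from the assertions below (the actual
# implementation lives in AnnotationRepository): it stamps the reviewer token
# and a timestamp onto the row, roughly
#
#     ann.is_verified = True
#     ann.verified_by = admin_token
#     ann.verified_at = datetime.now(timezone.utc)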
class TestAnnotationVerification:
    """Tests for annotation verification."""

    def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
        """Test marking annotation as verified."""
        repo = AnnotationRepository()

        ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)

        assert ann is not None
        assert ann.is_verified is True
        assert ann.verified_by == admin_token.token
        assert ann.verified_at is not None


class TestAnnotationOverride:
    """Tests for annotation override functionality."""

    def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
        """Test overriding an auto-generated annotation."""
        repo = AnnotationRepository()

        # Override the annotation
        ann = repo.override(
            str(sample_annotation.annotation_id),
            admin_token.token,
            change_reason="Correcting OCR error",
            text_value="INV-2024-CORRECTED",
            x_center=0.55,
        )

        assert ann is not None
        assert ann.text_value == "INV-2024-CORRECTED"
        assert ann.x_center == 0.55
        assert ann.source == "manual"  # Changed from auto to manual
        assert ann.override_source == "auto"


class TestAnnotationHistory:
    """Tests for annotation history tracking."""

    def test_create_history_record(self, patched_session, sample_annotation):
        """Test creating annotation history record."""
        repo = AnnotationRepository()

        history = repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
            changed_by="test-user",
        )

        assert history is not None
        assert history.action == "created"
        assert history.changed_by == "test-user"

    def test_get_annotation_history(self, patched_session, sample_annotation):
        """Test getting history for an annotation."""
        repo = AnnotationRepository()

        # Create history records
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
        )
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="updated",
            previous_value={"text_value": "INV-001"},
            new_value={"text_value": "INV-002"},
        )

        history = repo.get_history(sample_annotation.annotation_id)

        assert len(history) == 2
        # Should be ordered by created_at desc
        assert history[0].action == "updated"
        assert history[1].action == "created"

    def test_get_document_history(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotation history for a document."""
        repo = AnnotationRepository()

        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="created",
            new_value={"class_name": "invoice_number"},
        )

        history = repo.get_document_history(sample_document.document_id)

        assert len(history) >= 1
        assert all(h.document_id == sample_document.document_id for h in history)
@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests

Tests BatchUploadRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.batch_upload_repository import BatchUploadRepository


class TestBatchUploadCreate:
    """Tests for batch upload creation."""

    def test_create_batch_upload(self, patched_session, admin_token):
        """Test creating a batch upload."""
        repo = BatchUploadRepository()

        batch = repo.create(
            admin_token=admin_token.token,
            filename="test_batch.zip",
            file_size=10240,
            upload_source="api",
        )

        assert batch is not None
        assert batch.batch_id is not None
        assert batch.filename == "test_batch.zip"
        assert batch.file_size == 10240
        assert batch.upload_source == "api"
        assert batch.status == "processing"
        assert batch.total_files == 0
        assert batch.processed_files == 0

    def test_create_batch_upload_default_source(self, patched_session, admin_token):
        """Test creating batch upload with default source."""
        repo = BatchUploadRepository()

        batch = repo.create(
            admin_token=admin_token.token,
            filename="ui_batch.zip",
            file_size=5120,
        )

        assert batch.upload_source == "ui"


class TestBatchUploadRead:
    """Tests for batch upload retrieval."""

    def test_get_batch_upload(self, patched_session, sample_batch_upload):
        """Test getting a batch upload by ID."""
        repo = BatchUploadRepository()

        batch = repo.get(sample_batch_upload.batch_id)

        assert batch is not None
        assert batch.batch_id == sample_batch_upload.batch_id
        assert batch.filename == sample_batch_upload.filename

    def test_get_nonexistent_batch_upload(self, patched_session):
        """Test getting a batch upload that doesn't exist."""
        repo = BatchUploadRepository()

        batch = repo.get(uuid4())
        assert batch is None

    def test_get_paginated_batch_uploads(self, patched_session, admin_token):
        """Test paginated batch upload listing."""
        repo = BatchUploadRepository()

        # Create multiple batches
        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024 * (i + 1),
            )

        batches, total = repo.get_paginated(limit=3, offset=0)

        assert total == 5
        assert len(batches) == 3

    def test_get_paginated_with_offset(self, patched_session, admin_token):
        """Test pagination offset."""
        repo = BatchUploadRepository()

        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024,
            )

        page1, _ = repo.get_paginated(limit=2, offset=0)
        page2, _ = repo.get_paginated(limit=2, offset=2)

        ids_page1 = {b.batch_id for b in page1}
        ids_page2 = {b.batch_id for b in page2}

        assert len(ids_page1 & ids_page2) == 0


class TestBatchUploadUpdate:
    """Tests for batch upload updates."""

    def test_update_batch_status(self, patched_session, sample_batch_upload):
        """Test updating batch upload status."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="completed",
            total_files=10,
            processed_files=10,
            successful_files=8,
            failed_files=2,
        )

        # Need to commit to see changes
        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "completed"
        assert batch.total_files == 10
        assert batch.successful_files == 8
        assert batch.failed_files == 2

    def test_update_batch_with_error(self, patched_session, sample_batch_upload):
        """Test updating batch upload with error message."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="failed",
            error_message="ZIP extraction failed",
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "failed"
        assert batch.error_message == "ZIP extraction failed"

    def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
        """Test updating batch with CSV information."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            csv_filename="manifest.csv",
            csv_row_count=100,
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.csv_filename == "manifest.csv"
        assert batch.csv_row_count == 100


class TestBatchUploadFiles:
    """Tests for batch upload file management."""

    def test_create_batch_file(self, patched_session, sample_batch_upload):
        """Test creating a batch upload file record."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_001.pdf",
            status="pending",
        )

        assert file_record is not None
        assert file_record.file_id is not None
        assert file_record.filename == "invoice_001.pdf"
        assert file_record.batch_id == sample_batch_upload.batch_id
        assert file_record.status == "pending"

    def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
        """Test creating batch file linked to a document."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_linked.pdf",
            document_id=sample_document.document_id,
            status="completed",
            annotation_count=5,
        )

        assert file_record.document_id == sample_document.document_id
        assert file_record.status == "completed"
        assert file_record.annotation_count == 5

    def test_get_batch_files(self, patched_session, sample_batch_upload):
        """Test getting all files for a batch."""
        repo = BatchUploadRepository()

        # Create multiple files
        for i in range(3):
            repo.create_file(
                batch_id=sample_batch_upload.batch_id,
                filename=f"file_{i}.pdf",
            )

        files = repo.get_files(sample_batch_upload.batch_id)

        assert len(files) == 3
        assert all(f.batch_id == sample_batch_upload.batch_id for f in files)

    def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
        """Test getting files for batch with no files."""
        repo = BatchUploadRepository()

        files = repo.get_files(sample_batch_upload.batch_id)

        assert files == []

    def test_update_batch_file_status(self, patched_session, sample_batch_upload):
        """Test updating batch file status."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="test.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="completed",
            annotation_count=10,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "completed"
        assert updated_file.annotation_count == 10

    def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
        """Test updating batch file with error."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="corrupt.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="failed",
            error_message="Invalid PDF format",
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "failed"
        assert updated_file.error_message == "Invalid PDF format"

    def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
        """Test updating batch file with CSV row data."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_with_csv.pdf",
        )

        csv_data = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier": "Test Corp",
        }

        repo.update_file(
            file_record.file_id,
            csv_row_data=csv_data,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.csv_row_data == csv_data

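# Batch status lifecycle exercised in these tests: "processing" (set on create)
# moves to "completed", "failed", or "partial" (some files succeeded, some
# failed). The workflow below drives one batch through the partial path.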
class TestBatchUploadWorkflow:
    """Tests for complete batch upload workflows."""

    def test_complete_batch_workflow(self, patched_session, admin_token):
        """Test complete batch upload workflow."""
        repo = BatchUploadRepository()

        # 1. Create batch
        batch = repo.create(
            admin_token=admin_token.token,
            filename="full_workflow.zip",
            file_size=50000,
        )

        # 2. Update with file count
        repo.update(batch.batch_id, total_files=3)
        patched_session.commit()

        # 3. Create file records
        file_ids = []
        for i in range(3):
            file_record = repo.create_file(
                batch_id=batch.batch_id,
                filename=f"doc_{i}.pdf",
            )
            file_ids.append(file_record.file_id)

        # 4. Process files one by one
        for i, file_id in enumerate(file_ids):
            status = "completed" if i < 2 else "failed"
            repo.update_file(
                file_id,
                status=status,
                annotation_count=5 if status == "completed" else 0,
            )

        # 5. Update batch progress
        repo.update(
            batch.batch_id,
            processed_files=3,
            successful_files=2,
            failed_files=1,
            status="partial",
        )
        patched_session.commit()

        # Verify final state
        final_batch = repo.get(batch.batch_id)
        assert final_batch.status == "partial"
        assert final_batch.total_files == 3
        assert final_batch.processed_files == 3
        assert final_batch.successful_files == 2
        assert final_batch.failed_files == 1

        files = repo.get_files(batch.batch_id)
        assert len(files) == 3
        completed = [f for f in files if f.status == "completed"]
        failed = [f for f in files if f.status == "failed"]
        assert len(completed) == 2
        assert len(failed) == 1
321
tests/integration/repositories/test_dataset_repo_integration.py
Normal file
@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests

Tests DatasetRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.dataset_repository import DatasetRepository


class TestDatasetRepositoryCreate:
    """Tests for dataset creation."""

    def test_create_dataset(self, patched_session):
        """Test creating a training dataset."""
        repo = DatasetRepository()

        dataset = repo.create(
            name="Test Dataset",
            description="Dataset for integration testing",
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
        )

        assert dataset is not None
        assert dataset.name == "Test Dataset"
        assert dataset.description == "Dataset for integration testing"
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42
        assert dataset.status == "building"

    def test_create_dataset_with_defaults(self, patched_session):
        """Test creating dataset with default values."""
        repo = DatasetRepository()

        dataset = repo.create(name="Minimal Dataset")

        assert dataset is not None
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42

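# With train_ratio=0.8 and val_ratio=0.1 the implied test split is the
# remainder: 1 - 0.8 - 0.1 = 0.1. In these tests the ratios are stored on the
# dataset row, while each document added later carries an explicit "split"
# value (see TestDatasetDocuments below).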
class TestDatasetRepositoryRead:
    """Tests for dataset retrieval."""

    def test_get_dataset_by_id(self, patched_session, sample_dataset):
        """Test getting dataset by ID."""
        repo = DatasetRepository()

        dataset = repo.get(str(sample_dataset.dataset_id))

        assert dataset is not None
        assert dataset.dataset_id == sample_dataset.dataset_id
        assert dataset.name == sample_dataset.name

    def test_get_nonexistent_dataset(self, patched_session):
        """Test getting dataset that doesn't exist."""
        repo = DatasetRepository()

        dataset = repo.get(str(uuid4()))
        assert dataset is None

    def test_get_paginated_datasets(self, patched_session):
        """Test paginated dataset listing."""
        repo = DatasetRepository()

        # Create multiple datasets
        for i in range(5):
            repo.create(name=f"Dataset {i}")

        datasets, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(datasets) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering datasets by status."""
        repo = DatasetRepository()

        # Create datasets with different statuses
        d1 = repo.create(name="Building Dataset")
        repo.update_status(str(d1.dataset_id), "ready")

        d2 = repo.create(name="Another Building Dataset")
        # stays as "building"

        datasets, total = repo.get_paginated(status="ready")

        assert total == 1
        assert datasets[0].status == "ready"


class TestDatasetRepositoryUpdate:
    """Tests for dataset updates."""

    def test_update_status(self, patched_session, sample_dataset):
        """Test updating dataset status."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            total_documents=100,
            total_images=150,
            total_annotations=500,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "ready"
        assert dataset.total_documents == 100
        assert dataset.total_images == 150
        assert dataset.total_annotations == 500

    def test_update_status_with_error(self, patched_session, sample_dataset):
        """Test updating dataset status with error message."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="failed",
            error_message="Failed to build dataset: insufficient documents",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "failed"
        assert "insufficient documents" in dataset.error_message

    def test_update_status_with_path(self, patched_session, sample_dataset):
        """Test updating dataset path."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            dataset_path="/datasets/test_dataset_2024",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.dataset_path == "/datasets/test_dataset_2024"

    def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
        """Test updating dataset training status."""
        repo = DatasetRepository()

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="running",
            active_training_task_id=str(sample_training_task.task_id),
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "running"
        assert dataset.active_training_task_id == sample_training_task.task_id

    def test_update_training_status_completed(self, patched_session, sample_dataset):
        """Test updating training status to completed updates main status."""
        repo = DatasetRepository()

        # First set to ready
        repo.update_status(str(sample_dataset.dataset_id), status="ready")

        # Then complete training
        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="completed",
            update_main_status=True,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "completed"
        assert dataset.status == "trained"


class TestDatasetDocuments:
    """Tests for dataset document management."""

    def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
        """Test adding documents to a dataset."""
        repo = DatasetRepository()

        documents_data = [
            {
                "document_id": str(multiple_documents[0].document_id),
                "split": "train",
                "page_count": 1,
                "annotation_count": 5,
            },
            {
                "document_id": str(multiple_documents[1].document_id),
                "split": "train",
                "page_count": 2,
                "annotation_count": 8,
            },
            {
                "document_id": str(multiple_documents[2].document_id),
                "split": "val",
                "page_count": 1,
                "annotation_count": 3,
            },
        ]

        repo.add_documents(str(sample_dataset.dataset_id), documents_data)

        # Verify documents were added
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 3

        train_docs = [d for d in docs if d.split == "train"]
        val_docs = [d for d in docs if d.split == "val"]

        assert len(train_docs) == 2
        assert len(val_docs) == 1

    def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
        """Test getting documents from a dataset."""
        repo = DatasetRepository()

        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        docs = repo.get_documents(str(sample_dataset.dataset_id))

        assert len(docs) == 1
        assert docs[0].document_id == sample_document.document_id
        assert docs[0].split == "train"
        assert docs[0].page_count == 1
        assert docs[0].annotation_count == 5


class TestDatasetRepositoryDelete:
    """Tests for dataset deletion."""

    def test_delete_dataset(self, patched_session, sample_dataset):
        """Test deleting a dataset."""
        repo = DatasetRepository()

        result = repo.delete(str(sample_dataset.dataset_id))
        assert result is True

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is None

    def test_delete_nonexistent_dataset(self, patched_session):
        """Test deleting dataset that doesn't exist."""
        repo = DatasetRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
        """Test deleting dataset also removes document links."""
        repo = DatasetRepository()

        # Add document to dataset
        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        # Delete dataset
        repo.delete(str(sample_dataset.dataset_id))

        # Document links should be gone
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 0


class TestActiveTrainingTasks:
    """Tests for active training task queries."""

    def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
        """Test getting active training tasks for datasets."""
        repo = DatasetRepository()

        # Update task to running
        from inference.data.repositories.training_task_repository import TrainingTaskRepository

        task_repo = TrainingTaskRepository()
        task_repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        assert str(sample_dataset.dataset_id) in result
        assert result[str(sample_dataset.dataset_id)]["status"] == "running"

    def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
        """Test getting active training tasks returns empty when no tasks exist."""
        repo = DatasetRepository()

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        # No training task exists for this dataset, so result should be empty
        assert str(sample_dataset.dataset_id) not in result
        assert result == {}
350
tests/integration/repositories/test_document_repo_integration.py
Normal file
@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests

Tests DocumentRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest
from sqlmodel import select

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository


def ensure_utc(dt: datetime | None) -> datetime | None:
    """Ensure datetime is timezone-aware (UTC).

    PostgreSQL may return offset-naive datetimes. This helper
    converts them to UTC for proper comparison.
    """
    if dt is None:
        return None
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt

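# Illustrative use of ensure_utc (hypothetical, not one of the tests below):
# comparing a naive DB timestamp against an aware one raises TypeError, so the
# DB side is normalized first, e.g.
#
#     assert ensure_utc(doc.created_at) <= datetime.now(timezone.utc)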
class TestDocumentRepositoryCreate:
    """Tests for document creation."""

    def test_create_document(self, patched_session):
        """Test creating a document and retrieving it."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )

        assert doc_id is not None

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.filename == "test_invoice.pdf"
        assert doc.file_size == 2048
        assert doc.page_count == 2
        assert doc.upload_source == "api"
        assert doc.category == "invoice"
        assert doc.status == "pending"

    def test_create_document_with_csv_values(self, patched_session):
        """Test creating document with CSV field values."""
        repo = DocumentRepository()

        csv_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        doc_id = repo.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=csv_values,
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.csv_field_values == csv_values

    def test_create_document_with_group_key(self, patched_session):
        """Test creating document with group key."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.group_key == "batch-2024-01"


class TestDocumentRepositoryRead:
    """Tests for document retrieval."""

    def test_get_nonexistent_document(self, patched_session):
        """Test getting a document that doesn't exist."""
        repo = DocumentRepository()

        doc = repo.get(str(uuid4()))
        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """Test paginated document listing."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Test filtering documents by status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Test filtering documents by category."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Test pagination offset."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}

        assert len(doc_ids_page1 & doc_ids_page2) == 0

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Test getting multiple documents by IDs."""
        repo = DocumentRepository()

        ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)


class TestDocumentRepositoryUpdate:
    """Tests for document updates."""

    def test_update_status(self, patched_session, sample_document):
        """Test updating document status."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="labeled",
            auto_label_status="completed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.status == "labeled"
        assert doc.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """Test updating document status with error message."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.auto_label_status == "failed"
        assert doc.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """Test updating document file path."""
        repo = DocumentRepository()

        new_path = "/archive/2024/test_invoice.pdf"
        repo.update_file_path(str(sample_document.document_id), new_path)

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.file_path == new_path

    def test_update_group_key(self, patched_session, sample_document):
        """Test updating document group key."""
        repo = DocumentRepository()

        result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """Test updating document category."""
        repo = DocumentRepository()

        doc = repo.update_category(str(sample_document.document_id), "letter")

        assert doc is not None
        assert doc.category == "letter"


class TestDocumentRepositoryDelete:
    """Tests for document deletion."""

    def test_delete_document(self, patched_session, sample_document):
        """Test deleting a document."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Test deleting document also deletes its annotations."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        # Verify annotation is also deleted
        from inference.data.repositories.annotation_repository import AnnotationRepository

        ann_repo = AnnotationRepository()
        annotations = ann_repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """Test deleting a document that doesn't exist."""
        repo = DocumentRepository()

        result = repo.delete(str(uuid4()))
        assert result is False


class TestDocumentRepositoryQueries:
    """Tests for complex document queries."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """Test counting documents by status."""
        repo = DocumentRepository()

        counts = repo.count_by_status()

        assert counts.get("pending") == 2
        assert counts.get("labeled") == 2
        assert counts.get("exported") == 1

    def test_get_categories(self, patched_session, multiple_documents):
        """Test getting unique categories."""
        repo = DocumentRepository()

        categories = repo.get_categories()

        assert "invoice" in categories
        assert "letter" in categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Test getting labeled documents for export."""
        repo = DocumentRepository()

        docs = repo.get_labeled_for_export()

        assert len(docs) == 2
        for doc in docs:
            assert doc.status == "labeled"


class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Test acquiring annotation lock."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """Test acquiring lock fails when already locked."""
        repo = DocumentRepository()

        # First lock
        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(str(sample_document.document_id))
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Test releasing annotation lock."""
        repo = DocumentRepository()

        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
        doc = repo.release_annotation_lock(str(sample_document.document_id))

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Test extending annotation lock."""
        repo = DocumentRepository()

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)

        # Extend lock
        extended_doc = repo.extend_annotation_lock(
            str(sample_document.document_id),
            additional_seconds=300,
        )

        assert extended_doc is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        assert extended_expiry > initial_expiry
310
tests/integration/repositories/test_model_version_repo_integration.py
Normal file
@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests

Tests ModelVersionRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.model_version_repository import ModelVersionRepository


class TestModelVersionCreate:
    """Tests for model version creation."""

    def test_create_model_version(self, patched_session):
        """Test creating a model version."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.0.0",
            name="Invoice Extractor v1",
            model_path="/models/invoice_v1.pt",
            description="Initial production model",
            metrics_mAP=0.92,
            metrics_precision=0.89,
            metrics_recall=0.85,
            document_count=1000,
            file_size=50000000,
        )

        assert model is not None
        assert model.version == "1.0.0"
        assert model.name == "Invoice Extractor v1"
        assert model.model_path == "/models/invoice_v1.pt"
        assert model.metrics_mAP == 0.92
        assert model.is_active is False
        assert model.status == "inactive"

    def test_create_model_version_with_training_info(
        self, patched_session, sample_training_task, sample_dataset
    ):
        """Test creating model version linked to training task and dataset."""
        repo = ModelVersionRepository()

        model = repo.create(
            version="1.1.0",
            name="Invoice Extractor v1.1",
            model_path="/models/invoice_v1.1.pt",
            task_id=sample_training_task.task_id,
            dataset_id=sample_dataset.dataset_id,
            training_config={"epochs": 100, "batch_size": 16},
            trained_at=datetime.now(timezone.utc),
        )

        assert model is not None
        assert model.task_id == sample_training_task.task_id
        assert model.dataset_id == sample_dataset.dataset_id
        assert model.training_config["epochs"] == 100


class TestModelVersionRead:
    """Tests for model version retrieval."""

    def test_get_model_version_by_id(self, patched_session, sample_model_version):
        """Test getting model version by ID."""
        repo = ModelVersionRepository()

        model = repo.get(str(sample_model_version.version_id))

        assert model is not None
        assert model.version_id == sample_model_version.version_id

    def test_get_nonexistent_model_version(self, patched_session):
        """Test getting model version that doesn't exist."""
        repo = ModelVersionRepository()

        model = repo.get(str(uuid4()))
        assert model is None

    def test_get_paginated_model_versions(self, patched_session):
        """Test paginated model version listing."""
        repo = ModelVersionRepository()

        # Create multiple versions
        for i in range(5):
            repo.create(
                version=f"1.{i}.0",
                name=f"Model v1.{i}",
                model_path=f"/models/model_v1.{i}.pt",
            )

        models, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(models) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering model versions by status."""
        repo = ModelVersionRepository()

        # Create active and inactive models
        m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
        repo.activate(str(m1.version_id))

        repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")

        active_models, active_total = repo.get_paginated(status="active")
        inactive_models, inactive_total = repo.get_paginated(status="inactive")

        assert active_total == 1
        assert inactive_total == 1


class TestModelVersionActivation:
    """Tests for model version activation."""

    def test_activate_model_version(self, patched_session, sample_model_version):
        """Test activating a model version."""
        repo = ModelVersionRepository()

        model = repo.activate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is True
        assert model.status == "active"
        assert model.activated_at is not None

    def test_activate_deactivates_others(self, patched_session):
        """Test that activating one version deactivates others."""
        repo = ModelVersionRepository()

        # Create and activate first model
        m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
        repo.activate(str(m1.version_id))

        # Create and activate second model
        m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
        repo.activate(str(m2.version_id))

        # Check first model is now inactive
        m1_after = repo.get(str(m1.version_id))
        assert m1_after.is_active is False
        assert m1_after.status == "inactive"

        # Check second model is active
        m2_after = repo.get(str(m2.version_id))
        assert m2_after.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """Test getting the currently active model."""
        repo = ModelVersionRepository()

        # Initially no active model
        active = repo.get_active()
        assert active is None

        # Activate model
        repo.activate(str(sample_model_version.version_id))

        # Now should return active model
        active = repo.get_active()
        assert active is not None
        assert active.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Test deactivating a model version."""
        repo = ModelVersionRepository()

        # First activate
        repo.activate(str(sample_model_version.version_id))

        # Then deactivate
        model = repo.deactivate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is False
        assert model.status == "inactive"


class TestModelVersionUpdate:
    """Tests for model version updates."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Test updating model version metadata."""
        repo = ModelVersionRepository()

        model = repo.update(
            str(sample_model_version.version_id),
            name="Updated Model Name",
            description="Updated description",
        )

        assert model is not None
        assert model.name == "Updated Model Name"
        assert model.description == "Updated description"

    def test_update_model_status(self, patched_session, sample_model_version):
        """Test updating model version status."""
        repo = ModelVersionRepository()

        model = repo.update(str(sample_model_version.version_id), status="deprecated")

        assert model is not None
        assert model.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Test updating model that doesn't exist."""
        repo = ModelVersionRepository()

        model = repo.update(str(uuid4()), name="New Name")
        assert model is None


class TestModelVersionArchive:
    """Tests for model version archiving."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """Test archiving an inactive model version."""
        repo = ModelVersionRepository()

        model = repo.archive(str(sample_model_version.version_id))

        assert model is not None
        assert model.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """Test that active model cannot be archived."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to archive
        model = repo.archive(str(sample_model_version.version_id))

        assert model is None

        # Verify model is still active
        current = repo.get(str(sample_model_version.version_id))
        assert current.status == "active"


class TestModelVersionDelete:
    """Tests for model version deletion."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Test deleting an inactive model version."""
        repo = ModelVersionRepository()

        result = repo.delete(str(sample_model_version.version_id))

        assert result is True

        model = repo.get(str(sample_model_version.version_id))
        assert model is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """Test that active model cannot be deleted."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to delete
        result = repo.delete(str(sample_model_version.version_id))

        assert result is False

        # Verify model still exists
        model = repo.get(str(sample_model_version.version_id))
        assert model is not None

    def test_delete_nonexistent_model(self, patched_session):
        """Test deleting model that doesn't exist."""
        repo = ModelVersionRepository()

        result = repo.delete(str(uuid4()))
        assert result is False


class TestOnlyOneActiveModel:
    """Tests to verify only one model can be active at a time."""

    def test_single_active_model_constraint(self, patched_session):
        """Test that only one model can be active at any time."""
        repo = ModelVersionRepository()

        # Create multiple models
        models = []
        for i in range(3):
            m = repo.create(
                version=f"1.{i}.0",
                name=f"Model {i}",
                model_path=f"/models/model_{i}.pt",
            )
            models.append(m)

        # Activate each model in sequence
        for model in models:
            repo.activate(str(model.version_id))

        # Count active models
        all_models, _ = repo.get_paginated(status="active")
        assert len(all_models) == 1

        # Verify it's the last one activated
        assert all_models[0].version_id == models[-1].version_id
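The activation tests above pin down a single-active-version invariant. A toy sketch of the rule they assert (an assumption: the real `ModelVersionRepository.activate()` is not shown in this diff, and presumably does the equivalent in one transaction):

```python
from datetime import datetime, timezone


def activate_sketch(rows: list[dict], version_id: str) -> None:
    """Hypothetical activation rule: the chosen row becomes active and
    every other active row is demoted to inactive."""
    for row in rows:
        if row["version_id"] == version_id:
            row.update(is_active=True, status="active",
                       activated_at=datetime.now(timezone.utc))
        elif row["is_active"]:
            row.update(is_active=False, status="inactive")


rows = [
    {"version_id": "v1", "is_active": False, "status": "inactive", "activated_at": None},
    {"version_id": "v2", "is_active": False, "status": "inactive", "activated_at": None},
]
activate_sketch(rows, "v1")
activate_sketch(rows, "v2")
assert [r["status"] for r in rows] == ["inactive", "active"]
```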
274
tests/integration/repositories/test_token_repo_integration.py
Normal file
@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests

Tests TokenRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta

import pytest

from inference.data.repositories.token_repository import TokenRepository


class TestTokenCreate:
    """Tests for token creation."""

    def test_create_new_token(self, patched_session):
        """Test creating a new admin token."""
        repo = TokenRepository()

        repo.create(
            token="new-test-token-abc123",
            name="New Test Admin",
        )

        token = repo.get("new-test-token-abc123")
        assert token is not None
        assert token.token == "new-test-token-abc123"
        assert token.name == "New Test Admin"
        assert token.is_active is True
        assert token.expires_at is None

    def test_create_token_with_expiration(self, patched_session):
        """Test creating token with expiration date."""
        repo = TokenRepository()
        expiry = datetime.now(timezone.utc) + timedelta(days=30)

        repo.create(
            token="expiring-token-xyz789",
            name="Expiring Token",
            expires_at=expiry,
        )

        token = repo.get("expiring-token-xyz789")
        assert token is not None
        assert token.expires_at is not None

    def test_create_updates_existing_token(self, patched_session, admin_token):
        """Test creating with existing token updates it."""
        repo = TokenRepository()
        new_expiry = datetime.now(timezone.utc) + timedelta(days=60)

        repo.create(
            token=admin_token.token,
            name="Updated Admin Name",
            expires_at=new_expiry,
        )

        token = repo.get(admin_token.token)
        assert token is not None
        assert token.name == "Updated Admin Name"
        assert token.is_active is True


class TestTokenValidation:
    """Tests for token validation."""

    def test_is_valid_active_token(self, patched_session, admin_token):
        """Test that active token is valid."""
        repo = TokenRepository()

        result = repo.is_valid(admin_token.token)

        assert result is True

    def test_is_valid_nonexistent_token(self, patched_session):
        """Test that nonexistent token is invalid."""
        repo = TokenRepository()

        result = repo.is_valid("nonexistent-token-12345")

        assert result is False

    def test_is_valid_deactivated_token(self, patched_session, admin_token):
        """Test that deactivated token is invalid."""
        repo = TokenRepository()

        repo.deactivate(admin_token.token)
        result = repo.is_valid(admin_token.token)

        assert result is False

    def test_is_valid_expired_token(self, patched_session):
        """Test that expired token is invalid."""
        repo = TokenRepository()
        past_expiry = datetime.now(timezone.utc) - timedelta(days=1)

        repo.create(
            token="expired-token-test",
            name="Expired Token",
            expires_at=past_expiry,
        )

        result = repo.is_valid("expired-token-test")

        assert result is False

    def test_is_valid_not_yet_expired_token(self, patched_session):
        """Test that not-yet-expired token is valid."""
        repo = TokenRepository()
        future_expiry = datetime.now(timezone.utc) + timedelta(days=7)

        repo.create(
            token="valid-expiring-token",
            name="Valid Expiring Token",
            expires_at=future_expiry,
        )

        result = repo.is_valid("valid-expiring-token")

        assert result is True


class TestTokenGet:
    """Tests for token retrieval."""

    def test_get_existing_token(self, patched_session, admin_token):
        """Test getting an existing token."""
        repo = TokenRepository()

        token = repo.get(admin_token.token)

        assert token is not None
        assert token.token == admin_token.token
        assert token.name == admin_token.name

    def test_get_nonexistent_token(self, patched_session):
        """Test getting a token that doesn't exist."""
        repo = TokenRepository()

        token = repo.get("nonexistent-token-xyz")

        assert token is None


class TestTokenDeactivate:
    """Tests for token deactivation."""

    def test_deactivate_existing_token(self, patched_session, admin_token):
        """Test deactivating an existing token."""
        repo = TokenRepository()

        result = repo.deactivate(admin_token.token)

        assert result is True
        token = repo.get(admin_token.token)
        assert token is not None
        assert token.is_active is False

    def test_deactivate_nonexistent_token(self, patched_session):
        """Test deactivating a token that doesn't exist."""
        repo = TokenRepository()

        result = repo.deactivate("nonexistent-token-abc")

        assert result is False

    def test_reactivate_deactivated_token(self, patched_session, admin_token):
        """Test reactivating a deactivated token via create."""
        repo = TokenRepository()

        # Deactivate first
        repo.deactivate(admin_token.token)
        assert repo.is_valid(admin_token.token) is False

        # Reactivate via create
        repo.create(
            token=admin_token.token,
            name="Reactivated Admin",
        )

        assert repo.is_valid(admin_token.token) is True


class TestTokenUsageTracking:
    """Tests for token usage tracking."""

    def test_update_usage(self, patched_session, admin_token):
        """Test updating token last used timestamp."""
        repo = TokenRepository()

        # Initially last_used_at might be None
        initial_token = repo.get(admin_token.token)
        initial_last_used = initial_token.last_used_at

        repo.update_usage(admin_token.token)

        updated_token = repo.get(admin_token.token)
        assert updated_token.last_used_at is not None
        if initial_last_used:
            assert updated_token.last_used_at >= initial_last_used

    def test_update_usage_nonexistent_token(self, patched_session):
        """Test updating usage for nonexistent token does nothing."""
        repo = TokenRepository()

        # Should not raise, just does nothing
        repo.update_usage("nonexistent-token-usage")

        token = repo.get("nonexistent-token-usage")
        assert token is None


class TestTokenWorkflow:
    """Tests for complete token workflows."""

    def test_full_token_lifecycle(self, patched_session):
        """Test complete token lifecycle: create, validate, use, deactivate."""
        repo = TokenRepository()
        token_str = "lifecycle-test-token"

        # 1. Create token
        repo.create(token=token_str, name="Lifecycle Token")
        assert repo.is_valid(token_str) is True

        # 2. Use token
        repo.update_usage(token_str)
        token = repo.get(token_str)
        assert token.last_used_at is not None

        # 3. Update token info
        new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
        repo.create(
            token=token_str,
            name="Updated Lifecycle Token",
            expires_at=new_expiry,
        )
        token = repo.get(token_str)
        assert token.name == "Updated Lifecycle Token"

        # 4. Deactivate token
        result = repo.deactivate(token_str)
        assert result is True
        assert repo.is_valid(token_str) is False

        # 5. Reactivate token
        repo.create(token=token_str, name="Reactivated Token")
        assert repo.is_valid(token_str) is True

    def test_multiple_tokens(self, patched_session):
        """Test managing multiple tokens."""
        repo = TokenRepository()

        # Create multiple tokens
        tokens = [
            ("token-a", "Admin A"),
            ("token-b", "Admin B"),
            ("token-c", "Admin C"),
        ]

        for token_str, name in tokens:
            repo.create(token=token_str, name=name)

        # Verify all are valid
        for token_str, _ in tokens:
            assert repo.is_valid(token_str) is True

        # Deactivate one
        repo.deactivate("token-b")

        # Verify states
        assert repo.is_valid("token-a") is True
        assert repo.is_valid("token-b") is False
        assert repo.is_valid("token-c") is True
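Taken together, the validation tests above imply a validity rule along these lines. This is a sketch under assumptions (the actual `TokenRepository.is_valid()` implementation is not in this diff); the naive-datetime normalization mirrors the `ensure_utc()` helper earlier in the commit:

```python
from datetime import datetime, timezone


def is_valid_sketch(row) -> bool:
    """Hypothetical check matching the asserted behaviour: a token is valid
    only if it exists, is active, and has no expiry in the past."""
    if row is None or not row.is_active:
        return False
    if row.expires_at is not None:
        expires = row.expires_at
        if expires.tzinfo is None:  # PostgreSQL may hand back naive timestamps
            expires = expires.replace(tzinfo=timezone.utc)
        if expires <= datetime.now(timezone.utc):
            return False
    return True
```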
364
tests/integration/repositories/test_training_task_repo_integration.py
Normal file
@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests

Tests TrainingTaskRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest

from inference.data.repositories.training_task_repository import TrainingTaskRepository


class TestTrainingTaskCreate:
    """Tests for training task creation."""

    def test_create_training_task(self, patched_session, admin_token):
        """Test creating a training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Test Training Task",
            task_type="train",
            description="Integration test training task",
            config={"epochs": 100, "batch_size": 16},
        )

        assert task_id is not None

        task = repo.get(task_id)
        assert task is not None
        assert task.name == "Test Training Task"
        assert task.task_type == "train"
        assert task.status == "pending"
        assert task.config["epochs"] == 100

    def test_create_scheduled_task(self, patched_session, admin_token):
        """Test creating a scheduled training task."""
        repo = TrainingTaskRepository()

        scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Scheduled Task",
            scheduled_at=scheduled_time,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "scheduled"
        assert task.scheduled_at is not None

    def test_create_recurring_task(self, patched_session, admin_token):
        """Test creating a recurring training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Recurring Task",
            cron_expression="0 2 * * *",
            is_recurring=True,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.is_recurring is True
        assert task.cron_expression == "0 2 * * *"

    def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
        """Test creating task linked to a dataset."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Dataset Training Task",
            dataset_id=str(sample_dataset.dataset_id),
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.dataset_id == sample_dataset.dataset_id


class TestTrainingTaskRead:
    """Tests for training task retrieval."""

    def test_get_task_by_id(self, patched_session, sample_training_task):
        """Test getting task by ID."""
        repo = TrainingTaskRepository()

        task = repo.get(str(sample_training_task.task_id))

        assert task is not None
        assert task.task_id == sample_training_task.task_id

    def test_get_nonexistent_task(self, patched_session):
        """Test getting task that doesn't exist."""
        repo = TrainingTaskRepository()

        task = repo.get(str(uuid4()))
        assert task is None

    def test_get_paginated_tasks(self, patched_session, admin_token):
        """Test paginated task listing."""
        repo = TrainingTaskRepository()

        # Create multiple tasks
        for i in range(5):
            repo.create(admin_token=admin_token.token, name=f"Task {i}")

        tasks, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(tasks) == 2

    def test_get_paginated_with_status_filter(self, patched_session, admin_token):
        """Test filtering tasks by status."""
        repo = TrainingTaskRepository()

        # Create tasks with different statuses
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        repo.create(admin_token=admin_token.token, name="Pending Task")

        tasks, total = repo.get_paginated(status="running")

        assert total == 1
        assert tasks[0].status == "running"

    def test_get_pending_tasks(self, patched_session, admin_token):
        """Test getting pending tasks ready to run."""
        repo = TrainingTaskRepository()

        # Create pending task
        repo.create(admin_token=admin_token.token, name="Ready Task")

        # Create scheduled task in the past (should be included)
        past_time = datetime.now(timezone.utc) - timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Past Scheduled Task",
            scheduled_at=past_time,
        )

        # Create scheduled task in the future (should not be included)
        future_time = datetime.now(timezone.utc) + timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Future Scheduled Task",
            scheduled_at=future_time,
        )

        pending = repo.get_pending()

        # Should include pending and past scheduled, not future scheduled
        assert len(pending) >= 2
        names = [t.name for t in pending]
        assert "Ready Task" in names
        assert "Past Scheduled Task" in names

    def test_get_running_task(self, patched_session, admin_token):
        """Test getting currently running task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        running = repo.get_running()

        assert running is not None
        assert running.status == "running"

    def test_get_running_task_none(self, patched_session, admin_token):
        """Test getting running task when none is running."""
        repo = TrainingTaskRepository()

        repo.create(admin_token=admin_token.token, name="Pending Task")

        running = repo.get_running()
        assert running is None


class TestTrainingTaskUpdate:
    """Tests for training task updates."""

    def test_update_status_to_running(self, patched_session, sample_training_task):
        """Test updating task status to running."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "running"
        assert task.started_at is not None

    def test_update_status_to_completed(self, patched_session, sample_training_task):
        """Test updating task status to completed."""
        repo = TrainingTaskRepository()

        metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}

        repo.update_status(
            str(sample_training_task.task_id),
            "completed",
            result_metrics=metrics,
            model_path="/models/trained_model.pt",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result_metrics["mAP"] == 0.92
        assert task.model_path == "/models/trained_model.pt"

    def test_update_status_to_failed(self, patched_session, sample_training_task):
        """Test updating task status to failed with error message."""
        repo = TrainingTaskRepository()

        repo.update_status(
            str(sample_training_task.task_id),
            "failed",
            error_message="CUDA out of memory",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "failed"
        assert task.completed_at is not None
        assert "CUDA out of memory" in task.error_message

    def test_cancel_pending_task(self, patched_session, sample_training_task):
        """Test cancelling a pending task."""
        repo = TrainingTaskRepository()

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is True

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "cancelled"

    def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
        """Test that running task cannot be cancelled."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is False

        task = repo.get(str(sample_training_task.task_id))
        assert task.status == "running"


class TestTrainingLogs:
    """Tests for training log management."""

    def test_add_log_entry(self, patched_session, sample_training_task):
        """Test adding a training log entry."""
        repo = TrainingTaskRepository()

        repo.add_log(
            str(sample_training_task.task_id),
            level="INFO",
            message="Starting training...",
            details={"epoch": 1, "batch": 0},
        )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 1
        assert logs[0].level == "INFO"
        assert logs[0].message == "Starting training..."

    def test_add_multiple_log_entries(self, patched_session, sample_training_task):
        """Test adding multiple log entries."""
        repo = TrainingTaskRepository()

        for i in range(5):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Epoch {i} completed",
                details={"epoch": i, "loss": 0.5 - i * 0.1},
            )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 5

    def test_get_logs_pagination(self, patched_session, sample_training_task):
        """Test paginated log retrieval."""
        repo = TrainingTaskRepository()

        for i in range(10):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Log entry {i}",
            )

        logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
        assert len(logs) == 5

        logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
        assert len(logs_page2) == 5


class TestDocumentLinks:
    """Tests for training document link management."""

    def test_create_document_link(self, patched_session, sample_training_task, sample_document):
        """Test creating a document link."""
        repo = TrainingTaskRepository()

        link = repo.create_document_link(
            task_id=sample_training_task.task_id,
            document_id=sample_document.document_id,
            annotation_snapshot={"count": 5, "verified": 3},
        )

        assert link is not None
        assert link.task_id == sample_training_task.task_id
        assert link.document_id == sample_document.document_id
        assert link.annotation_snapshot["count"] == 5

    def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
        """Test getting all document links for a task."""
        repo = TrainingTaskRepository()

        for doc in multiple_documents[:3]:
            repo.create_document_link(
                task_id=sample_training_task.task_id,
                document_id=doc.document_id,
            )

        links = repo.get_document_links(sample_training_task.task_id)
        assert len(links) == 3

    def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
        """Test getting training tasks that used a document."""
        repo = TrainingTaskRepository()

        # Create multiple tasks using the same document
        task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
        task2_id = repo.create(admin_token=admin_token.token, name="Task 2")

        repo.create_document_link(
            task_id=repo.get(task1_id).task_id,
            document_id=sample_document.document_id,
        )
        repo.create_document_link(
            task_id=repo.get(task2_id).task_id,
            document_id=sample_document.document_id,
        )

        links = repo.get_document_training_tasks(sample_document.document_id)
        assert len(links) == 2
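`test_get_pending_tasks` above fixes the selection semantics of `get_pending()`: plain pending tasks run immediately, scheduled tasks only once `scheduled_at` has passed. A sketch of that per-task predicate (an assumption; the repository's actual query is not shown in this diff):

```python
from datetime import datetime, timezone


def is_ready_to_run(task) -> bool:
    """Hypothetical readiness predicate matching the test's expectations."""
    if task.status == "pending":
        return True
    if task.status == "scheduled" and task.scheduled_at is not None:
        scheduled = task.scheduled_at
        if scheduled.tzinfo is None:  # normalize naive DB timestamps
            scheduled = scheduled.replace(tzinfo=timezone.utc)
        return scheduled <= datetime.now(timezone.utc)
    return False
```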
1
tests/integration/services/__init__.py
Normal file
@@ -0,0 +1 @@
"""Service integration tests."""
497
tests/integration/services/test_dashboard_service_integration.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""
|
||||
Dashboard Service Integration Tests
|
||||
|
||||
Tests DashboardStatsService and DashboardActivityService with real database operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
AnnotationHistory,
|
||||
ModelVersion,
|
||||
TrainingDataset,
|
||||
TrainingTask,
|
||||
)
|
||||
from inference.web.services.dashboard_service import (
|
||||
DashboardStatsService,
|
||||
DashboardActivityService,
|
||||
is_annotation_complete,
|
||||
IDENTIFIER_CLASS_IDS,
|
||||
PAYMENT_CLASS_IDS,
|
||||
)
|
||||
|
||||
|
||||
class TestIsAnnotationComplete:
|
||||
"""Tests for is_annotation_complete function."""
|
||||
|
||||
def test_complete_with_invoice_number_and_bankgiro(self):
|
||||
"""Test complete with invoice_number (0) and bankgiro (4)."""
|
||||
annotations = [
|
||||
{"class_id": 0}, # invoice_number
|
||||
{"class_id": 4}, # bankgiro
|
||||
]
|
||||
assert is_annotation_complete(annotations) is True
|
||||
|
||||
def test_complete_with_ocr_number_and_plusgiro(self):
|
||||
"""Test complete with ocr_number (3) and plusgiro (5)."""
|
||||
annotations = [
|
||||
{"class_id": 3}, # ocr_number
|
||||
{"class_id": 5}, # plusgiro
|
||||
]
|
||||
assert is_annotation_complete(annotations) is True
|
||||
|
||||
def test_incomplete_missing_identifier(self):
|
||||
"""Test incomplete when missing identifier."""
|
||||
annotations = [
|
||||
{"class_id": 4}, # bankgiro only
|
||||
]
|
||||
assert is_annotation_complete(annotations) is False
|
||||
|
||||
def test_incomplete_missing_payment(self):
|
||||
"""Test incomplete when missing payment."""
|
||||
annotations = [
|
||||
{"class_id": 0}, # invoice_number only
|
||||
]
|
||||
assert is_annotation_complete(annotations) is False
|
||||
|
||||
def test_incomplete_empty_annotations(self):
|
||||
"""Test incomplete with empty annotations."""
|
||||
assert is_annotation_complete([]) is False
|
||||
|
||||
def test_complete_with_multiple_fields(self):
|
||||
"""Test complete with multiple fields."""
|
||||
annotations = [
|
||||
{"class_id": 0}, # invoice_number
|
||||
{"class_id": 1}, # invoice_date
|
||||
{"class_id": 3}, # ocr_number
|
||||
{"class_id": 4}, # bankgiro
|
||||
{"class_id": 5}, # plusgiro
|
||||
{"class_id": 6}, # amount
|
||||
]
|
||||
assert is_annotation_complete(annotations) is True
|
||||
|
||||
|
||||
class TestDashboardStatsService:
|
||||
"""Tests for DashboardStatsService."""
|
||||
|
||||
def test_get_stats_empty_database(self, patched_session):
|
||||
"""Test stats with empty database."""
|
||||
service = DashboardStatsService()
|
||||
|
||||
stats = service.get_stats()
|
||||
|
||||
assert stats["total_documents"] == 0
|
||||
assert stats["annotation_complete"] == 0
|
||||
assert stats["annotation_incomplete"] == 0
|
||||
assert stats["pending"] == 0
|
||||
assert stats["completeness_rate"] == 0.0
|
||||
|
||||
def test_get_stats_with_documents(self, patched_session, admin_token):
|
||||
"""Test stats with various document states."""
|
||||
service = DashboardStatsService()
|
||||
session = patched_session
|
||||
|
||||
# Create documents with different statuses
|
||||
docs = []
|
||||
for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename=f"doc_{i}.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path=f"/uploads/doc_{i}.pdf",
|
||||
page_count=1,
|
||||
status=status,
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(doc)
|
||||
docs.append(doc)
|
||||
session.commit()
|
||||
|
||||
stats = service.get_stats()
|
||||
|
||||
assert stats["total_documents"] == 5
|
||||
assert stats["pending"] == 2 # pending + auto_labeling
|
||||
|
||||
def test_get_stats_complete_annotations(self, patched_session, admin_token):
|
||||
"""Test completeness calculation with proper annotations."""
|
||||
service = DashboardStatsService()
|
||||
session = patched_session
|
||||
|
||||
# Create a labeled document with complete annotations
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename="complete_doc.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/uploads/complete_doc.pdf",
|
||||
page_count=1,
|
||||
status="labeled",
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
|
||||
# Add identifier annotation (invoice_number = class_id 0)
|
||||
ann1 = AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=doc.document_id,
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.1,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=80,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
text_value="INV-001",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(ann1)
|
||||
|
||||
# Add payment annotation (bankgiro = class_id 4)
|
||||
ann2 = AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=doc.document_id,
|
||||
page_number=1,
|
||||
class_id=4,
|
||||
class_name="bankgiro",
|
||||
x_center=0.5,
|
||||
y_center=0.2,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=160,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
text_value="123-4567",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(ann2)
|
||||
session.commit()
|
||||
|
||||
stats = service.get_stats()
|
||||
|
||||
assert stats["annotation_complete"] == 1
|
||||
assert stats["annotation_incomplete"] == 0
|
||||
assert stats["completeness_rate"] == 100.0
|
||||
|
||||
def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
|
||||
"""Test completeness with incomplete annotations."""
|
||||
service = DashboardStatsService()
|
||||
session = patched_session
|
||||
|
||||
# Create a labeled document missing payment annotation
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename="incomplete_doc.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/uploads/incomplete_doc.pdf",
|
||||
page_count=1,
|
||||
status="labeled",
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
|
||||
# Add only identifier annotation (missing payment)
|
||||
ann = AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=doc.document_id,
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.1,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=80,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
text_value="INV-001",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(ann)
|
||||
session.commit()
|
||||
|
||||
stats = service.get_stats()
|
||||
|
||||
assert stats["annotation_complete"] == 0
|
||||
assert stats["annotation_incomplete"] == 1
|
||||
assert stats["completeness_rate"] == 0.0
|
||||
|
||||
def test_get_stats_mixed_completeness(self, patched_session, admin_token):
|
||||
"""Test stats with mix of complete and incomplete documents."""
|
||||
service = DashboardStatsService()
|
||||
session = patched_session
|
||||
|
||||
# Create 2 labeled documents
|
||||
docs = []
|
||||
for i in range(2):
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename=f"mixed_doc_{i}.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path=f"/uploads/mixed_doc_{i}.pdf",
|
||||
page_count=1,
|
||||
status="labeled",
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(doc)
|
||||
docs.append(doc)
|
||||
session.commit()
|
||||
|
||||
# First document: complete (has identifier + payment)
|
||||
session.add(AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=docs[0].document_id,
|
||||
page_number=1,
|
||||
class_id=0, # invoice_number
|
||||
class_name="invoice_number",
|
||||
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
|
||||
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
))
|
||||
session.add(AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=docs[0].document_id,
|
||||
page_number=1,
|
||||
class_id=4, # bankgiro
|
||||
class_name="bankgiro",
|
||||
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
|
||||
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
))
|
||||
|
||||
# Second document: incomplete (missing payment)
|
||||
session.add(AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=docs[1].document_id,
|
||||
page_number=1,
|
||||
class_id=0, # invoice_number only
|
||||
class_name="invoice_number",
|
||||
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
|
||||
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
))
|
||||
session.commit()
|
||||
|
||||
stats = service.get_stats()
|
||||
|
||||
assert stats["annotation_complete"] == 1
|
||||
assert stats["annotation_incomplete"] == 1
|
||||
assert stats["completeness_rate"] == 50.0
|
||||
|
||||
|
||||
class TestDashboardActivityService:
|
||||
"""Tests for DashboardActivityService."""
|
||||
|
||||
def test_get_recent_activities_empty(self, patched_session):
|
||||
"""Test activities with empty database."""
|
||||
service = DashboardActivityService()
|
||||
|
||||
activities = service.get_recent_activities()
|
||||
|
||||
assert activities == []
|
||||
|
||||
def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
|
||||
"""Test activities include document uploads."""
|
||||
service = DashboardActivityService()
|
||||
session = patched_session
|
||||
|
||||
# Create documents
|
||||
for i in range(3):
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
admin_token=admin_token.token,
|
||||
filename=f"activity_doc_{i}.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path=f"/uploads/activity_doc_{i}.pdf",
|
||||
page_count=1,
|
||||
status="pending",
|
||||
upload_source="ui",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(doc)
|
||||
session.commit()
|
||||
|
||||
activities = service.get_recent_activities()
|
||||
|
||||
upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
|
||||
assert len(upload_activities) == 3
|
||||
|
||||
    def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
        """Test activities include annotation overrides."""
        service = DashboardActivityService()
        session = patched_session

        # Create annotation history with override
        history = AnnotationHistory(
            history_id=uuid4(),
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="override",
            previous_value={"text_value": "OLD-001"},
            new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
            changed_by="test-admin",
            created_at=datetime.now(timezone.utc),
        )
        session.add(history)
        session.commit()

        activities = service.get_recent_activities()

        override_activities = [a for a in activities if a["type"] == "annotation_modified"]
        assert len(override_activities) >= 1

    def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
        """Test activities include training completions."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to completed
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.85
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        training_activities = [a for a in activities if a["type"] == "training_completed"]
        assert len(training_activities) >= 1
        assert "mAP" in training_activities[0]["metadata"]

    def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
        """Test activities include training failures."""
        service = DashboardActivityService()
        session = patched_session

        # Update training task to failed
        sample_training_task.status = "failed"
        sample_training_task.error_message = "CUDA out of memory"
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        failed_activities = [a for a in activities if a["type"] == "training_failed"]
        assert len(failed_activities) >= 1
        assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"

    def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
        """Test activities include model activations."""
        service = DashboardActivityService()
        session = patched_session

        # Activate model
        sample_model_version.is_active = True
        sample_model_version.activated_at = datetime.now(timezone.utc)
        session.add(sample_model_version)
        session.commit()

        activities = service.get_recent_activities()

        activation_activities = [a for a in activities if a["type"] == "model_activated"]
        assert len(activation_activities) >= 1
        assert activation_activities[0]["metadata"]["version"] == sample_model_version.version

    def test_get_recent_activities_limit(self, patched_session, admin_token):
        """Test activity limit parameter."""
        service = DashboardActivityService()
        session = patched_session

        # Create many documents
        for i in range(20):
            doc = AdminDocument(
                document_id=uuid4(),
                admin_token=admin_token.token,
                filename=f"limit_doc_{i}.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path=f"/uploads/limit_doc_{i}.pdf",
                page_count=1,
                status="pending",
                upload_source="ui",
                category="invoice",
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
            )
            session.add(doc)
        session.commit()

        activities = service.get_recent_activities(limit=5)

        assert len(activities) <= 5

    def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
        """Test activities are sorted by timestamp descending."""
        service = DashboardActivityService()
        session = patched_session

        # Create document
        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="sorted_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/sorted_doc.pdf",
            page_count=1,
            status="pending",
            upload_source="ui",
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(doc)

        # Complete training task
        sample_training_task.status = "completed"
        sample_training_task.metrics_mAP = 0.90
        sample_training_task.updated_at = datetime.now(timezone.utc)
        session.add(sample_training_task)
        session.commit()

        activities = service.get_recent_activities()

        # Verify sorted by timestamp DESC
        timestamps = [a["timestamp"] for a in activities]
        assert timestamps == sorted(timestamps, reverse=True)
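The assertions above pin down a small contract for `get_recent_activities`: merge the per-source activity lists, sort newest-first, truncate to the limit. A self-contained sketch of that contract follows; `merge_activities` is a hypothetical helper, not the real service, which queries the database for each source instead of taking pre-built lists.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical sketch of the merge/sort/limit contract the tests above check;
# the real DashboardActivityService builds these lists from database queries.
def merge_activities(*sources: list[dict], limit: int = 10) -> list[dict]:
    merged = [item for source in sources for item in source]
    # Newest first, then truncate to the requested limit
    merged.sort(key=lambda a: a["timestamp"], reverse=True)
    return merged[:limit]

now = datetime.now(timezone.utc)
uploads = [{"type": "document_uploaded", "timestamp": now - timedelta(minutes=5), "metadata": {}}]
trainings = [{"type": "training_completed", "timestamp": now, "metadata": {"mAP": 0.85}}]

recent = merge_activities(uploads, trainings, limit=5)
assert [a["type"] for a in recent] == ["training_completed", "document_uploaded"]
```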
453
tests/integration/services/test_dataset_builder_integration.py
Normal file
@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests

Tests DatasetBuilder with real file operations and repository interactions.
"""

import shutil
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

import pytest
import yaml

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder


@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
    """Create a DatasetBuilder with real repositories."""
    return DatasetBuilder(
        datasets_repo=DatasetRepository(),
        documents_repo=DocumentRepository(),
        annotations_repo=AnnotationRepository(),
        base_dir=temp_dataset_dir,
    )


@pytest.fixture
def admin_images_dir(temp_upload_dir):
    """Create a directory for admin images."""
    images_dir = temp_upload_dir / "admin_images"
    images_dir.mkdir(parents=True, exist_ok=True)
    return images_dir


@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
    """Create documents with annotations and corresponding image files."""
    documents = []
    doc_repo = DocumentRepository()
    ann_repo = AnnotationRepository()

    for i in range(5):
        # Create document
        doc_id = doc_repo.create(
            filename=f"invoice_{i}.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path=f"/uploads/invoice_{i}.pdf",
            page_count=2,
            category="invoice",
            group_key=f"group_{i % 2}",  # Two groups
        )

        # Create image files for each page
        doc_dir = admin_images_dir / doc_id
        doc_dir.mkdir(parents=True, exist_ok=True)

        for page in range(1, 3):
            image_path = doc_dir / f"page_{page}.png"
            # Create a minimal fake PNG
            image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

        # Create annotations
        for j in range(3):
            ann_repo.create(
                document_id=doc_id,
                page_number=1,
                class_id=j,
                class_name=f"field_{j}",
                x_center=0.5,
                y_center=0.1 + j * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + j * 160,
                bbox_width=160,
                bbox_height=40,
                text_value=f"value_{j}",
                confidence=0.95,
                source="auto",
            )

        doc = doc_repo.get(doc_id)
        documents.append(doc)

    return documents


class TestDatasetBuilderBasic:
    """Tests for basic dataset building operations."""

    def test_build_dataset_creates_directory_structure(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that building creates proper directory structure."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Test Dataset")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Check directory structure
        assert (dataset_dir / "images" / "train").exists()
        assert (dataset_dir / "images" / "val").exists()
        assert (dataset_dir / "images" / "test").exists()
        assert (dataset_dir / "labels" / "train").exists()
        assert (dataset_dir / "labels" / "val").exists()
        assert (dataset_dir / "labels" / "test").exists()

    def test_build_dataset_copies_images(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that images are copied to dataset directory."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Image Copy Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        result = dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total images across all splits
        total_images = 0
        for split in ["train", "val", "test"]:
            images = list((dataset_dir / "images" / split).glob("*.png"))
            total_images += len(images)

        # 5 docs * 2 pages = 10 images
        assert total_images == 10
        assert result["total_images"] == 10

    def test_build_dataset_generates_labels(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that YOLO label files are generated."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Count total label files
        total_labels = 0
        for split in ["train", "val", "test"]:
            labels = list((dataset_dir / "labels" / split).glob("*.txt"))
            total_labels += len(labels)

        # Same count as images
        assert total_labels == 10

    def test_build_dataset_generates_data_yaml(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that data.yaml is generated correctly."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="YAML Generation Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
        yaml_path = dataset_dir / "data.yaml"

        assert yaml_path.exists()

        with open(yaml_path) as f:
            data = yaml.safe_load(f)

        assert data["train"] == "images/train"
        assert data["val"] == "images/val"
        assert data["test"] == "images/test"
        assert "nc" in data
        assert "names" in data


class TestDatasetBuilderSplits:
    """Tests for train/val/test split assignment."""

    def test_split_ratio_respected(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that split ratios are approximately respected."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Split Ratio Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.6,
            val_ratio=0.2,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        # Check document assignments in database
        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))

        splits = {"train": 0, "val": 0, "test": 0}
        for doc in dataset_docs:
            splits[doc.split] += 1

        # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1.
        # Due to rounding and group constraints, allow some variation.
        assert splits["train"] >= 2
        assert splits["val"] >= 1 or splits["test"] >= 1

    def test_same_seed_same_split(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that same seed produces same split."""
        dataset_repo = DatasetRepository()
        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        # Build first dataset
        dataset1 = dataset_repo.create(name="Seed Test 1")
        dataset_builder.build_dataset(
            dataset_id=str(dataset1.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Build second dataset with same seed
        dataset2 = dataset_repo.create(name="Seed Test 2")
        dataset_builder.build_dataset(
            dataset_id=str(dataset2.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=12345,
            admin_images_dir=admin_images_dir,
        )

        # Compare splits
        docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
        docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}

        assert docs1 == docs2
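test_same_seed_same_split implies that the builder seeds a deterministic RNG before assigning groups to splits. A minimal sketch under that assumption; assign_splits is hypothetical, and the real DatasetBuilder may order and round differently.

```python
import random

# Hypothetical sketch: deterministic, group-aware split assignment. Documents
# sharing a group_key always land in the same split; the same seed always
# yields the same assignment. Not the actual DatasetBuilder code.
def assign_splits(groups: list[str], train_ratio: float, val_ratio: float, seed: int) -> dict[str, str]:
    rng = random.Random(seed)        # local RNG, so global random state is untouched
    shuffled = sorted(set(groups))   # deterministic base order before shuffling
    rng.shuffle(shuffled)
    n_train = round(len(shuffled) * train_ratio)
    n_val = round(len(shuffled) * val_ratio)
    assignment = {}
    for i, group in enumerate(shuffled):
        if i < n_train:
            assignment[group] = "train"
        elif i < n_train + n_val:
            assignment[group] = "val"
        else:
            assignment[group] = "test"
    return assignment

# Same seed, same split -- the property the test asserts
assert assign_splits(["g0", "g1", "g2"], 0.6, 0.2, seed=12345) == \
       assign_splits(["g0", "g1", "g2"], 0.6, 0.2, seed=12345)
```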
class TestDatasetBuilderDatabase:
    """Tests for database interactions."""

    def test_updates_dataset_status(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that dataset status is updated after build."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Status Update Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        updated = dataset_repo.get(str(dataset.dataset_id))

        assert updated.status == "ready"
        assert updated.total_documents == 5
        assert updated.total_images == 10
        assert updated.total_annotations > 0
        assert updated.dataset_path is not None

    def test_records_document_assignments(
        self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
    ):
        """Test that document assignments are recorded in database."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Assignment Recording Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))

        assert len(dataset_docs) == 5

        for doc in dataset_docs:
            assert doc.split in ["train", "val", "test"]
            assert doc.page_count > 0


class TestDatasetBuilderErrors:
    """Tests for error handling."""

    def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with empty document list."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Empty Docs Test")

        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
        """Test that building fails with nonexistent document IDs."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Invalid IDs Test")

        fake_ids = [str(uuid4()) for _ in range(3)]

        with pytest.raises(ValueError, match="No valid documents"):
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=fake_ids,
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )

    def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
        """Test that dataset status is set to failed on error."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Failure Status Test")

        try:
            dataset_builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=[],
                train_ratio=0.8,
                val_ratio=0.1,
                seed=42,
                admin_images_dir=admin_images_dir,
            )
        except ValueError:
            pass

        updated = dataset_repo.get(str(dataset.dataset_id))
        assert updated.status == "failed"
        assert updated.error_message is not None


class TestLabelFileFormat:
    """Tests for YOLO label file format."""

    def test_label_file_format(
        self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
    ):
        """Test that label files are in correct YOLO format."""
        dataset_repo = DatasetRepository()
        dataset = dataset_repo.create(name="Label Format Test")

        doc_ids = [str(d.document_id) for d in documents_with_annotations]

        dataset_builder.build_dataset(
            dataset_id=str(dataset.dataset_id),
            document_ids=doc_ids,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            admin_images_dir=admin_images_dir,
        )

        dataset_dir = temp_dataset_dir / str(dataset.dataset_id)

        # Find a label file with content
        label_files = []
        for split in ["train", "val", "test"]:
            label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))

        # Check at least one label file has correct format
        found_valid_label = False
        for label_file in label_files:
            content = label_file.read_text().strip()
            if content:
                lines = content.split("\n")
                for line in lines:
                    parts = line.split()
                    assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"

                    class_id = int(parts[0])
                    x_center = float(parts[1])
                    y_center = float(parts[2])
                    width = float(parts[3])
                    height = float(parts[4])

                    assert 0 <= class_id < 10
                    assert 0 <= x_center <= 1
                    assert 0 <= y_center <= 1
                    assert 0 <= width <= 1
                    assert 0 <= height <= 1

                found_valid_label = True
                break

        assert found_valid_label, "No valid label files found"
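test_label_file_format encodes the standard YOLO text convention: one line per box, five space-separated fields (`class_id x_center y_center width height`), coordinates normalized to [0, 1]. A tiny writer sketch of that format; to_yolo_line is illustrative, not the builder's actual code.

```python
# Hypothetical sketch of emitting one YOLO label line per annotation.
def to_yolo_line(class_id: int, x_center: float, y_center: float, width: float, height: float) -> str:
    for value in (x_center, y_center, width, height):
        assert 0.0 <= value <= 1.0, "YOLO coordinates must be normalized"
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

print(to_yolo_line(0, 0.5, 0.1, 0.2, 0.05))  # -> "0 0.500000 0.100000 0.200000 0.050000"
```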
283
tests/integration/services/test_document_service_integration.py
Normal file
@@ -0,0 +1,283 @@
"""
Document Service Integration Tests

Tests DocumentService with real storage operations.
"""

from pathlib import Path
from unittest.mock import MagicMock

import pytest

from inference.web.services.document_service import DocumentService, DocumentResult


class MockStorageBackend:
    """Simple in-memory storage backend for testing."""

    def __init__(self):
        self._files: dict[str, bytes] = {}

    def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None:
        if not overwrite and remote_path in self._files:
            raise FileExistsError(f"File already exists: {remote_path}")
        self._files[remote_path] = content

    def download_bytes(self, remote_path: str) -> bytes:
        if remote_path not in self._files:
            raise FileNotFoundError(f"File not found: {remote_path}")
        return self._files[remote_path]

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}"

    def exists(self, remote_path: str) -> bool:
        return remote_path in self._files

    def delete(self, remote_path: str) -> bool:
        if remote_path in self._files:
            del self._files[remote_path]
            return True
        return False

    def list_files(self, prefix: str) -> list[str]:
        return [path for path in self._files.keys() if path.startswith(prefix)]


@pytest.fixture
def mock_storage():
    """Create a mock storage backend."""
    return MockStorageBackend()


@pytest.fixture
def document_service(mock_storage):
    """Create a DocumentService with mock storage."""
    return DocumentService(storage_backend=mock_storage)


class TestDocumentUpload:
    """Tests for document upload operations."""

    def test_upload_document(self, document_service):
        """Test uploading a document."""
        content = b"%PDF-1.4 test content"
        filename = "test_invoice.pdf"

        result = document_service.upload_document(content, filename)

        assert result is not None
        assert result.id is not None
        assert result.filename == filename
        assert result.file_path.startswith("documents/")
        assert result.file_path.endswith(".pdf")

    def test_upload_document_with_custom_id(self, document_service):
        """Test uploading with custom document ID."""
        content = b"%PDF-1.4 test content"
        filename = "invoice.pdf"
        custom_id = "custom-doc-12345"

        result = document_service.upload_document(
            content, filename, document_id=custom_id
        )

        assert result.id == custom_id
        assert custom_id in result.file_path

    def test_upload_preserves_extension(self, document_service):
        """Test that file extension is preserved."""
        cases = [
            ("document.pdf", ".pdf"),
            ("image.PNG", ".png"),
            ("file.JPEG", ".jpeg"),
            ("noextension", ""),
        ]

        for filename, expected_ext in cases:
            result = document_service.upload_document(b"content", filename)
            if expected_ext:
                assert result.file_path.endswith(expected_ext)

    def test_upload_document_overwrite(self, document_service, mock_storage):
        """Test that upload overwrites existing file."""
        content1 = b"original content"
        content2 = b"new content"
        doc_id = "overwrite-test"

        document_service.upload_document(content1, "doc.pdf", document_id=doc_id)
        document_service.upload_document(content2, "doc.pdf", document_id=doc_id)

        # Should have new content
        remote_path = f"documents/{doc_id}.pdf"
        stored_content = mock_storage.download_bytes(remote_path)
        assert stored_content == content2


class TestDocumentDownload:
    """Tests for document download operations."""

    def test_download_document(self, document_service, mock_storage):
        """Test downloading a document."""
        content = b"test document content"
        remote_path = "documents/test-doc.pdf"
        mock_storage.upload_bytes(content, remote_path)

        downloaded = document_service.download_document(remote_path)

        assert downloaded == content

    def test_download_nonexistent_document(self, document_service):
        """Test downloading document that doesn't exist."""
        with pytest.raises(FileNotFoundError):
            document_service.download_document("documents/nonexistent.pdf")


class TestDocumentUrl:
    """Tests for document URL generation."""

    def test_get_document_url(self, document_service, mock_storage):
        """Test getting presigned URL for document."""
        remote_path = "documents/test-doc.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        url = document_service.get_document_url(remote_path, expires_in_seconds=7200)

        assert url.startswith("https://")
        assert remote_path in url
        assert "7200" in url

    def test_get_document_url_default_expiry(self, document_service):
        """Test default URL expiry."""
        url = document_service.get_document_url("documents/doc.pdf")

        assert "3600" in url


class TestDocumentExists:
    """Tests for document existence check."""

    def test_document_exists(self, document_service, mock_storage):
        """Test checking if document exists."""
        remote_path = "documents/existing.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        assert document_service.document_exists(remote_path) is True

    def test_document_not_exists(self, document_service):
        """Test checking if nonexistent document exists."""
        assert document_service.document_exists("documents/nonexistent.pdf") is False


class TestDocumentDelete:
    """Tests for document deletion."""

    def test_delete_document(self, document_service, mock_storage):
        """Test deleting a document."""
        remote_path = "documents/to-delete.pdf"
        mock_storage.upload_bytes(b"content", remote_path)

        result = document_service.delete_document_files(remote_path)

        assert result is True
        assert document_service.document_exists(remote_path) is False

    def test_delete_nonexistent_document(self, document_service):
        """Test deleting document that doesn't exist."""
        result = document_service.delete_document_files("documents/nonexistent.pdf")

        assert result is False


class TestPageImages:
    """Tests for page image operations."""

    def test_save_page_image(self, document_service, mock_storage):
        """Test saving a page image."""
        doc_id = "test-doc-123"
        page_num = 1
        image_content = b"\x89PNG\r\n\x1a\n fake png"

        remote_path = document_service.save_page_image(doc_id, page_num, image_content)

        assert remote_path == f"images/{doc_id}/page_{page_num}.png"
        assert mock_storage.exists(remote_path)

    def test_save_multiple_page_images(self, document_service, mock_storage):
        """Test saving images for multiple pages."""
        doc_id = "multi-page-doc"

        for page_num in range(1, 4):
            content = f"page {page_num} content".encode()
            document_service.save_page_image(doc_id, page_num, content)

        images = document_service.list_document_images(doc_id)
        assert len(images) == 3

    def test_get_page_image(self, document_service, mock_storage):
        """Test downloading a page image."""
        doc_id = "test-doc"
        page_num = 2
        image_content = b"image data"

        document_service.save_page_image(doc_id, page_num, image_content)
        downloaded = document_service.get_page_image(doc_id, page_num)

        assert downloaded == image_content

    def test_get_page_image_url(self, document_service):
        """Test getting URL for page image."""
        doc_id = "test-doc"
        page_num = 1

        url = document_service.get_page_image_url(doc_id, page_num)

        assert f"images/{doc_id}/page_{page_num}.png" in url

    def test_list_document_images(self, document_service, mock_storage):
        """Test listing all images for a document."""
        doc_id = "list-test-doc"

        for i in range(5):
            document_service.save_page_image(doc_id, i + 1, f"page {i}".encode())

        images = document_service.list_document_images(doc_id)

        assert len(images) == 5

    def test_delete_document_images(self, document_service, mock_storage):
        """Test deleting all images for a document."""
        doc_id = "delete-images-doc"

        for i in range(3):
            document_service.save_page_image(doc_id, i + 1, b"content")

        deleted_count = document_service.delete_document_images(doc_id)

        assert deleted_count == 3
        assert len(document_service.list_document_images(doc_id)) == 0


class TestRoundTrip:
    """Tests for complete upload-download cycles."""

    def test_document_round_trip(self, document_service):
        """Test uploading and downloading document."""
        original_content = b"%PDF-1.4 complete document content here"
        filename = "roundtrip.pdf"

        result = document_service.upload_document(original_content, filename)
        downloaded = document_service.download_document(result.file_path)

        assert downloaded == original_content

    def test_image_round_trip(self, document_service):
        """Test saving and retrieving page image."""
        doc_id = "roundtrip-doc"
        page_num = 1
        original_image = b"\x89PNG fake image data"

        document_service.save_page_image(doc_id, page_num, original_image)
        retrieved = document_service.get_page_image(doc_id, page_num)

        assert retrieved == original_image
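MockStorageBackend doubles as documentation of the storage interface DocumentService expects. One way to state that contract explicitly is a typing.Protocol; this sketch only mirrors the mock above and assumes the project does not already declare an equivalent type.

```python
from typing import Protocol

# Hypothetical sketch of the interface MockStorageBackend satisfies; the
# project may declare this differently (ABC, duck typing, etc.).
class StorageBackend(Protocol):
    def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None: ...
    def download_bytes(self, remote_path: str) -> bytes: ...
    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str: ...
    def exists(self, remote_path: str) -> bool: ...
    def delete(self, remote_path: str) -> bool: ...
    def list_files(self, prefix: str) -> list[str]: ...
```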
258
tests/integration/test_database_setup.py
Normal file
@@ -0,0 +1,258 @@
"""
Database Setup Integration Tests

Tests for database connection, session management, and basic operations.
"""

import pytest
from sqlmodel import Session, select

from inference.data.admin_models import AdminDocument, AdminToken


class TestDatabaseConnection:
    """Tests for database engine and connection."""

    def test_engine_connection(self, test_engine):
        """Verify database engine can establish connection."""
        with test_engine.connect() as conn:
            result = conn.execute(select(1))
            assert result.scalar() == 1

    def test_tables_created(self, test_engine):
        """Verify all expected tables are created."""
        from sqlmodel import SQLModel

        table_names = SQLModel.metadata.tables.keys()

        expected_tables = [
            "admin_tokens",
            "admin_documents",
            "admin_annotations",
            "training_tasks",
            "training_logs",
            "batch_uploads",
            "batch_upload_files",
            "training_datasets",
            "dataset_documents",
            "training_document_links",
            "model_versions",
        ]

        for table in expected_tables:
            assert table in table_names, f"Table '{table}' not found"


class TestSessionManagement:
    """Tests for database session context manager."""

    def test_session_commit(self, db_session):
        """Verify session commits changes successfully."""
        token = AdminToken(
            token="commit-test-token",
            name="Commit Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        result = db_session.exec(
            select(AdminToken).where(AdminToken.token == "commit-test-token")
        ).first()

        assert result is not None
        assert result.name == "Commit Test"

    def test_session_rollback_on_error(self, test_engine):
        """Verify session rollback on exception."""
        session = Session(test_engine)

        try:
            token = AdminToken(
                token="rollback-test-token",
                name="Rollback Test",
                is_active=True,
            )
            session.add(token)
            session.commit()

            # Try to insert duplicate (should fail)
            duplicate = AdminToken(
                token="rollback-test-token",  # Same primary key
                name="Duplicate",
                is_active=True,
            )
            session.add(duplicate)
            session.commit()
        except Exception:
            session.rollback()
        finally:
            session.close()

        # Verify original record exists
        with Session(test_engine) as verify_session:
            result = verify_session.exec(
                select(AdminToken).where(AdminToken.token == "rollback-test-token")
            ).first()
            assert result is not None
            assert result.name == "Rollback Test"

    def test_session_isolation(self, test_engine):
        """Verify sessions are isolated from each other."""
        session1 = Session(test_engine)
        session2 = Session(test_engine)

        try:
            # Insert in session1, don't commit
            token = AdminToken(
                token="isolation-test-token",
                name="Isolation Test",
                is_active=True,
            )
            session1.add(token)
            session1.flush()

            # Session2 should not see uncommitted data (with proper isolation).
            # Note: SQLite in-memory may have different isolation behavior.
            session1.commit()

            result = session2.exec(
                select(AdminToken).where(AdminToken.token == "isolation-test-token")
            ).first()

            # After commit, session2 should see the data
            assert result is not None

        finally:
            session1.close()
            session2.close()


class TestBasicCRUDOperations:
    """Tests for basic CRUD operations on database."""

    def test_create_and_read_token(self, db_session):
        """Test creating and reading admin token."""
        token = AdminToken(
            token="crud-test-token",
            name="CRUD Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        result = db_session.get(AdminToken, "crud-test-token")

        assert result is not None
        assert result.name == "CRUD Test"
        assert result.is_active is True

    def test_update_entity(self, db_session, admin_token):
        """Test updating an entity."""
        admin_token.name = "Updated Name"
        db_session.add(admin_token)
        db_session.commit()

        result = db_session.get(AdminToken, admin_token.token)

        assert result is not None
        assert result.name == "Updated Name"

    def test_delete_entity(self, db_session):
        """Test deleting an entity."""
        token = AdminToken(
            token="delete-test-token",
            name="Delete Test",
            is_active=True,
        )
        db_session.add(token)
        db_session.commit()

        db_session.delete(token)
        db_session.commit()

        result = db_session.get(AdminToken, "delete-test-token")
        assert result is None

    def test_foreign_key_constraint(self, db_session, admin_token):
        """Test foreign key constraints are enforced."""
        from uuid import uuid4

        doc = AdminDocument(
            document_id=uuid4(),
            admin_token=admin_token.token,
            filename="fk_test.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/test/fk_test.pdf",
            page_count=1,
            status="pending",
        )
        db_session.add(doc)
        db_session.commit()

        # Document should reference valid token
        result = db_session.get(AdminDocument, doc.document_id)
        assert result is not None
        assert result.admin_token == admin_token.token


class TestQueryOperations:
    """Tests for various query operations."""

    def test_select_with_filter(self, db_session, multiple_documents):
        """Test SELECT with WHERE clause."""
        results = db_session.exec(
            select(AdminDocument).where(AdminDocument.status == "labeled")
        ).all()

        assert len(results) == 2
        for doc in results:
            assert doc.status == "labeled"

    def test_select_with_order(self, db_session, multiple_documents):
        """Test SELECT with ORDER BY clause."""
        results = db_session.exec(
            select(AdminDocument).order_by(AdminDocument.file_size.desc())
        ).all()

        file_sizes = [doc.file_size for doc in results]
        assert file_sizes == sorted(file_sizes, reverse=True)

    def test_select_with_limit_offset(self, db_session, multiple_documents):
        """Test SELECT with LIMIT and OFFSET."""
        results = db_session.exec(
            select(AdminDocument)
            .order_by(AdminDocument.filename)
            .offset(2)
            .limit(2)
        ).all()

        assert len(results) == 2

    def test_count_query(self, db_session, multiple_documents):
        """Test COUNT aggregation."""
        from sqlalchemy import func

        count = db_session.exec(
            select(func.count()).select_from(AdminDocument)
        ).one()

        assert count == 5

    def test_group_by_query(self, db_session, multiple_documents):
        """Test GROUP BY aggregation."""
        from sqlalchemy import func

        results = db_session.exec(
            select(
                AdminDocument.status,
                func.count(AdminDocument.document_id).label("count"),
            ).group_by(AdminDocument.status)
        ).all()

        status_counts = {row[0]: row[1] for row in results}

        assert status_counts.get("pending") == 2
        assert status_counts.get("labeled") == 2
        assert status_counts.get("exported") == 1
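A caveat on test_foreign_key_constraint: it inserts a valid reference, which passes even when enforcement is off, and SQLite only enforces foreign keys once the pragma is enabled. A stricter check would look roughly like this (plain sqlite3, independent of the project's fixtures):

```python
import sqlite3

# Sketch (assumption: not the project's fixtures): SQLite enforces foreign
# keys only when the pragma is on, so a stricter FK test inserts a dangling
# reference and expects IntegrityError.
conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("CREATE TABLE admin_tokens (token TEXT PRIMARY KEY)")
conn.execute(
    "CREATE TABLE admin_documents ("
    " document_id TEXT PRIMARY KEY,"
    " admin_token TEXT REFERENCES admin_tokens(token))"
)

try:
    conn.execute("INSERT INTO admin_documents VALUES ('doc-1', 'missing-token')")
    raise AssertionError("expected a foreign key violation")
except sqlite3.IntegrityError:
    pass  # FOREIGN KEY constraint failed, as desired
```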
@@ -196,3 +196,121 @@ class TestAnnotationModel:
        assert 0 <= ann.y_center <= 1
        assert 0 <= ann.width <= 1
        assert 0 <= ann.height <= 1


class TestAutoLabelFilePathResolution:
    """Tests for auto-label file path resolution.

    The auto-label endpoint needs to resolve the storage path (e.g., "raw_pdfs/uuid.pdf")
    to an actual filesystem path via the storage helper.
    """

    def test_extracts_filename_from_storage_path(self):
        """Test that filename is extracted from storage path correctly."""
        # Storage paths are like "raw_pdfs/uuid.pdf"
        storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf"

        # The annotation endpoint extracts filename
        filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_handles_path_without_prefix(self):
        """Test that paths without prefix are handled."""
        storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf"

        filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_storage_helper_resolves_path(self):
        """Test that storage helper can resolve the path."""
        from pathlib import Path
        from unittest.mock import MagicMock, patch

        # Mock storage helper
        mock_storage = MagicMock()
        mock_path = Path("/storage/raw_pdfs/test.pdf")
        mock_storage.get_raw_pdf_local_path.return_value = mock_path

        with patch(
            "inference.web.services.storage_helpers.get_storage_helper",
            return_value=mock_storage,
        ):
            from inference.web.services.storage_helpers import get_storage_helper

            storage = get_storage_helper()
            result = storage.get_raw_pdf_local_path("test.pdf")

            assert result == mock_path
            mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf")

    def test_auto_label_request_validation(self):
        """Test AutoLabelRequest validates field_values."""
        # Valid request
        request = AutoLabelRequest(
            field_values={"InvoiceNumber": "12345"},
            replace_existing=False,
        )
        assert request.field_values == {"InvoiceNumber": "12345"}

        # Empty field_values should be valid at schema level
        # (endpoint validates non-empty)
        request_empty = AutoLabelRequest(
            field_values={},
            replace_existing=False,
        )
        assert request_empty.field_values == {}


class TestMatchClassAttributes:
    """Tests for Match class attributes used in auto-labeling.

    The autolabel service uses Match objects from FieldMatcher.
    Verifies the correct attribute names are used.
    """

    def test_match_has_matched_text_attribute(self):
        """Test that Match class has matched_text attribute (not matched_value)."""
        from shared.matcher.models import Match

        # Create a Match object
        match = Match(
            field="invoice_number",
            value="12345",
            bbox=(100, 100, 200, 150),
            page_no=0,
            score=0.95,
            matched_text="INV-12345",
            context_keywords=["faktura", "nummer"],
        )

        # Verify matched_text exists (this is what autolabel.py should use)
        assert hasattr(match, "matched_text")
        assert match.matched_text == "INV-12345"

        # Verify matched_value does NOT exist.
        # This was the bug: autolabel.py was using matched_value instead of matched_text.
        assert not hasattr(match, "matched_value")

    def test_match_attributes_for_annotation_creation(self):
        """Test that Match has all attributes needed for annotation creation."""
        from shared.matcher.models import Match

        match = Match(
            field="amount",
            value="1000.00",
            bbox=(50, 200, 150, 230),
            page_no=0,
            score=0.88,
            matched_text="1 000,00",
            context_keywords=["att betala", "summa"],
        )

        # These are all the attributes used in autolabel._create_annotations_from_matches
        assert hasattr(match, "bbox")
        assert hasattr(match, "matched_text")  # NOT matched_value
        assert hasattr(match, "score")

        # Verify bbox format
        assert len(match.bbox) == 4  # (x0, y0, x1, y1)
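These attribute tests guard a concrete regression: autolabel read match.matched_value, which does not exist on Match. A sketch of the corrected consumption; the Match stand-in and the annotation_kwargs_from_match helper are illustrative, not the project's code.

```python
from dataclasses import dataclass

# Hypothetical stand-in for shared.matcher.models.Match, just to illustrate
# the attribute access; the real class has more fields.
@dataclass
class Match:
    bbox: tuple[int, int, int, int]
    matched_text: str
    score: float

def annotation_kwargs_from_match(match: Match) -> dict:
    x0, y0, x1, y1 = match.bbox
    return {
        "bbox_x": x0,
        "bbox_y": y0,
        "bbox_width": x1 - x0,
        "bbox_height": y1 - y0,
        "text_value": match.matched_text,  # the fixed attribute -- not matched_value
        "confidence": match.score,
    }

kwargs = annotation_kwargs_from_match(Match(bbox=(100, 100, 200, 150), matched_text="INV-12345", score=0.95))
assert kwargs["text_value"] == "INV-12345"
```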
@@ -3,7 +3,7 @@ Tests for Admin Authentication.
"""

import pytest
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

from fastapi import HTTPException
@@ -132,6 +132,47 @@ class TestTokenRepository:
        with patch.object(repo, "_now", return_value=datetime.utcnow()):
            assert repo.is_valid("test-token") is False

    def test_is_valid_expired_token_timezone_aware(self):
        """Test expired token with timezone-aware datetime.

        This verifies the fix for comparing timezone-aware and naive datetimes.
        The auth API now creates tokens with timezone-aware expiration dates.
        """
        with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session

            # Create token with timezone-aware expiration (as auth API now does)
            mock_token = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=datetime.now(timezone.utc) - timedelta(days=1),
            )
            mock_session.get.return_value = mock_token

            repo = TokenRepository()
            # _now() returns timezone-aware datetime, should compare correctly
            assert repo.is_valid("test-token") is False

    def test_is_valid_not_expired_token_timezone_aware(self):
        """Test non-expired token with timezone-aware datetime."""
        with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
            mock_session = MagicMock()
            mock_ctx.return_value.__enter__.return_value = mock_session

            # Create token with timezone-aware expiration in the future
            mock_token = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=datetime.now(timezone.utc) + timedelta(days=1),
            )
            mock_session.get.return_value = mock_token

            repo = TokenRepository()
            assert repo.is_valid("test-token") is True

    def test_is_valid_token_not_found(self):
        """Test token not found."""
        with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
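The timezone-aware tests document why the fix matters: Python refuses to compare naive and aware datetimes. A sketch of a defensive expiry check, assuming the repository normalizes legacy naive timestamps to UTC; the real TokenRepository._now()/is_valid() may be structured differently.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical sketch of a timezone-safe expiry check.
def is_expired(expires_at: datetime | None) -> bool:
    if expires_at is None:
        return False  # no expiry set
    if expires_at.tzinfo is None:
        # Legacy naive timestamps are treated as UTC to avoid
        # "can't compare offset-naive and offset-aware datetimes"
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at <= datetime.now(timezone.utc)

assert is_expired(datetime.now(timezone.utc) - timedelta(days=1)) is True
assert is_expired(datetime.now(timezone.utc) + timedelta(days=1)) is False
```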
317
tests/web/test_dashboard_api.py
Normal file
@@ -0,0 +1,317 @@
"""
Tests for Dashboard API Endpoints and Services.

Tests are split into:
1. Unit tests for business logic (is_annotation_complete, etc.)
2. Service tests with mocked database
3. Integration tests via TestClient (requires DB)
"""

import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch


# Test data constants
TEST_DOC_UUID_1 = "550e8400-e29b-41d4-a716-446655440001"
TEST_MODEL_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440001"


class TestAnnotationCompletenessLogic:
    """Unit tests for annotation completeness calculation logic.

    These tests verify the core business logic:
    - Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro)
    - Incomplete: labeled but missing required fields
    """

    def test_document_with_invoice_number_and_bankgiro_is_complete(self):
        """Document with invoice_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_plusgiro_is_complete(self):
        """Document with ocr_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_invoice_number_and_plusgiro_is_complete(self):
        """Document with invoice_number + plusgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 5, "class_name": "plusgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_ocr_number_and_bankgiro_is_complete(self):
        """Document with ocr_number + bankgiro should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 3, "class_name": "ocr_number"},
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is True

    def test_document_with_only_identifier_is_incomplete(self):
        """Document with only identifier field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_only_payment_is_incomplete(self):
        """Document with only payment field should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 4, "class_name": "bankgiro"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_no_annotations_is_incomplete(self):
        """Document with no annotations should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        assert is_annotation_complete([]) is False

    def test_document_with_other_fields_only_is_incomplete(self):
        """Document with only non-essential fields should be incomplete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is False

    def test_document_with_all_fields_is_complete(self):
        """Document with all fields should be complete."""
        from inference.web.services.dashboard_service import is_annotation_complete

        annotations = [
            {"class_id": 0, "class_name": "invoice_number"},
            {"class_id": 1, "class_name": "invoice_date"},
            {"class_id": 4, "class_name": "bankgiro"},
            {"class_id": 6, "class_name": "amount"},
        ]

        assert is_annotation_complete(annotations) is True
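The completeness tests above fully determine the rule, so a reference implementation is easy to sketch. This assumes the service keys off class_id using the IDENTIFIER_CLASS_IDS/PAYMENT_CLASS_IDS constants tested at the end of this file; the real dashboard_service version may differ in shape.

```python
# Hypothetical reference implementation of the rule the tests encode.
IDENTIFIER_CLASS_IDS = {0, 3}  # invoice_number, ocr_number
PAYMENT_CLASS_IDS = {4, 5}     # bankgiro, plusgiro

def is_annotation_complete(annotations: list[dict]) -> bool:
    class_ids = {a["class_id"] for a in annotations}
    has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
    has_payment = bool(class_ids & PAYMENT_CLASS_IDS)
    return has_identifier and has_payment

assert is_annotation_complete([{"class_id": 0}, {"class_id": 4}]) is True
assert is_annotation_complete([{"class_id": 0}]) is False
```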
class TestDashboardStatsService:
    """Tests for DashboardStatsService with mocked database."""

    @pytest.fixture
    def mock_session(self):
        """Create a mock database session."""
        session = MagicMock()
        session.exec.return_value.one.return_value = 0
        return session

    def test_completeness_rate_calculation(self):
        """Test completeness rate is calculated correctly."""
        # Direct calculation test
        complete = 25
        incomplete = 8
        total_assessed = complete + incomplete
        expected_rate = round(complete / total_assessed * 100, 2)

        assert expected_rate == pytest.approx(75.76, rel=0.01)

    def test_completeness_rate_zero_documents(self):
        """Test completeness rate is 0 when no documents."""
        complete = 0
        incomplete = 0
        total_assessed = complete + incomplete

        completeness_rate = (
            round(complete / total_assessed * 100, 2)
            if total_assessed > 0
            else 0.0
        )

        assert completeness_rate == 0.0


class TestDashboardActivityService:
    """Tests for DashboardActivityService activity aggregation."""

    def test_activity_types(self):
        """Test all activity types are defined."""
        expected_types = [
            "document_uploaded",
            "annotation_modified",
            "training_completed",
            "training_failed",
            "model_activated",
        ]

        for activity_type in expected_types:
            assert activity_type in expected_types


class TestDashboardSchemas:
    """Tests for Dashboard API schemas."""

    def test_dashboard_stats_response_schema(self):
        """Test DashboardStatsResponse schema validation."""
        from inference.web.schemas.admin import DashboardStatsResponse

        response = DashboardStatsResponse(
            total_documents=38,
            annotation_complete=25,
            annotation_incomplete=8,
            pending=5,
            completeness_rate=75.76,
        )

        assert response.total_documents == 38
        assert response.annotation_complete == 25
        assert response.annotation_incomplete == 8
        assert response.pending == 5
        assert response.completeness_rate == 75.76

    def test_active_model_response_schema(self):
        """Test ActiveModelResponse schema with null model."""
        from inference.web.schemas.admin import ActiveModelResponse

        response = ActiveModelResponse(
            model=None,
            running_training=None,
        )

        assert response.model is None
        assert response.running_training is None

    def test_active_model_info_schema(self):
        """Test ActiveModelInfo schema validation."""
        from inference.web.schemas.admin import ActiveModelInfo

        model = ActiveModelInfo(
            version_id=TEST_MODEL_UUID,
            version="1.2.0",
            name="Invoice Model",
            metrics_mAP=0.951,
            metrics_precision=0.94,
            metrics_recall=0.92,
            document_count=500,
            activated_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc),
        )

        assert model.version == "1.2.0"
        assert model.name == "Invoice Model"
        assert model.metrics_mAP == 0.951

    def test_running_training_info_schema(self):
        """Test RunningTrainingInfo schema validation."""
        from inference.web.schemas.admin import RunningTrainingInfo

        task = RunningTrainingInfo(
            task_id=TEST_TASK_UUID,
            name="Run-2024-02",
            status="running",
            started_at=datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc),
            progress=45,
        )

        assert task.name == "Run-2024-02"
        assert task.status == "running"
        assert task.progress == 45

    def test_activity_item_schema(self):
        """Test ActivityItem schema validation."""
        from inference.web.schemas.admin import ActivityItem

        activity = ActivityItem(
            type="model_activated",
            description="Activated model v1.2.0",
            timestamp=datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc),
            metadata={"version_id": TEST_MODEL_UUID, "version": "1.2.0"},
        )

        assert activity.type == "model_activated"
        assert activity.description == "Activated model v1.2.0"
        assert activity.metadata["version"] == "1.2.0"

    def test_recent_activity_response_schema(self):
        """Test RecentActivityResponse schema with empty activities."""
        from inference.web.schemas.admin import RecentActivityResponse

        response = RecentActivityResponse(activities=[])

        assert response.activities == []


class TestDashboardRouterCreation:
    """Tests for dashboard router creation."""

    def test_creates_router_with_expected_endpoints(self):
        """Test router is created with expected endpoint paths."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        paths = [route.path for route in router.routes]

        assert any("/stats" in p for p in paths)
        assert any("/active-model" in p for p in paths)
        assert any("/activity" in p for p in paths)

    def test_router_has_correct_prefix(self):
        """Test router has /admin/dashboard prefix."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        assert router.prefix == "/admin/dashboard"

    def test_router_has_dashboard_tag(self):
        """Test router uses Dashboard tag."""
        from inference.web.api.v1.admin.dashboard import create_dashboard_router

        router = create_dashboard_router()

        assert "Dashboard" in router.tags


class TestFieldClassIds:
    """Tests for field class ID constants."""

    def test_identifier_class_ids(self):
        """Test identifier field class IDs."""
        from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS

        # invoice_number = 0, ocr_number = 3
        assert 0 in IDENTIFIER_CLASS_IDS
        assert 3 in IDENTIFIER_CLASS_IDS

    def test_payment_class_ids(self):
        """Test payment field class IDs."""
        from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS

        # bankgiro = 4, plusgiro = 5
        assert 4 in PAYMENT_CLASS_IDS
        assert 5 in PAYMENT_CLASS_IDS
@@ -1,68 +0,0 @@
#!/usr/bin/env python3
"""Update test imports to use new structure."""

import re
from pathlib import Path

# Import mapping: old -> new
IMPORT_MAPPINGS = {
    # Admin routes
    r'from src\.web\.admin_routes import': 'from src.web.api.v1.admin.documents import',
    r'from src\.web\.admin_annotation_routes import': 'from src.web.api.v1.admin.annotations import',
    r'from src\.web\.admin_training_routes import': 'from src.web.api.v1.admin.training import',

    # Auth and core
    r'from src\.web\.admin_auth import': 'from src.web.core.auth import',
    r'from src\.web\.admin_autolabel import': 'from src.web.services.autolabel import',
    r'from src\.web\.admin_scheduler import': 'from src.web.core.scheduler import',

    # Schemas
    r'from src\.web\.admin_schemas import': 'from src.web.schemas.admin import',
    r'from src\.web\.schemas import': 'from src.web.schemas.inference import',

    # Services
    r'from src\.web\.services import': 'from src.web.services.inference import',
    r'from src\.web\.async_service import': 'from src.web.services.async_processing import',
    r'from src\.web\.batch_upload_service import': 'from src.web.services.batch_upload import',

    # Workers
    r'from src\.web\.async_queue import': 'from src.web.workers.async_queue import',
    r'from src\.web\.batch_queue import': 'from src.web.workers.batch_queue import',

    # Routes
    r'from src\.web\.routes import': 'from src.web.api.v1.routes import',
    r'from src\.web\.async_routes import': 'from src.web.api.v1.async_api.routes import',
    r'from src\.web\.batch_upload_routes import': 'from src.web.api.v1.batch.routes import',
}


def update_file(file_path: Path) -> bool:
    """Update imports in a single file."""
    content = file_path.read_text(encoding='utf-8')
    original_content = content

    for old_pattern, new_import in IMPORT_MAPPINGS.items():
        content = re.sub(old_pattern, new_import, content)

    if content != original_content:
        file_path.write_text(content, encoding='utf-8')
        return True
    return False


def main():
    """Update all test files."""
    test_dir = Path('tests/web')
    updated_files = []

    for test_file in test_dir.glob('test_*.py'):
        if update_file(test_file):
            updated_files.append(test_file.name)

    if updated_files:
        print(f"✓ Updated {len(updated_files)} test files:")
        for filename in sorted(updated_files):
            print(f"  - {filename}")
    else:
        print("No files needed updating")


if __name__ == '__main__':
    main()