Add more tests

Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

.coverage (binary file, not shown)


@@ -0,0 +1,99 @@
# Dashboard Mockup Prompts
> Visual style: modern minimalism, consistent with the existing Warm theme
> Color scheme: Warm light palette (off-white background #FAFAF8, white cards, dark gray text #121212)
> Target platform: Web (desktop)
---
## Current Color Scheme Reference
| Usage | Value | Notes |
|------|--------|------|
| Page background | #FAFAF8 | Warm off-white |
| Card background | #FFFFFF | Pure white |
| Border | #E6E4E1 | Light grayish beige |
| Primary text | #121212 | Near black |
| Secondary text | #6B6B6B | Mid gray |
| Success state | #3E4A3A + green-500 | Dark olive green + bright green indicator dot |
| Warning state | #4A4A3A + yellow-50 | Dark yellow-brown + light yellow background |
| Info state | #3A3A3A + blue-50 | Dark gray + light blue background |
---
## Page 1: Dashboard Main View (Normal State)
**Page description**: Home page after login, showing document statistics, data quality, active model status, and recent activity
**Prompt**:
```
A modern web application dashboard UI for a document annotation system, main overview page, warm minimalist design theme, page background color #FAFAF8 warm off-white, single column layout with header navigation at top, content area below with multiple sections, top section shows: 4 equal-width stat cards in a row on white #FFFFFF background with subtle border #E6E4E1, first card Total Documents (38) with gray file icon on #FAFAF8 background, second card Complete (25) with dark olive green checkmark icon on light green #dcfce7 background, third card Incomplete (8) with orange alert icon on light orange #fef3c7 background, fourth card Pending (5) with blue clock icon on light blue #dbeafe background, each card has icon top-left in rounded square and large bold number in #121212 with label below in #6B6B6B, cards have subtle shadow on hover, middle section has two-column layout (50%/50%): left panel white card titled DATA QUALITY in uppercase #6B6B6B with circular progress ring 120px showing 78% in center with green #22C55E filled portion and gray #E5E7EB remaining, percentage text 36px bold #121212 centered in ring, text Annotation Complete next to ring, stats list below showing Complete 25 and Incomplete 8 and Pending 5 with small colored dots, text button View Incomplete Docs in primary color at bottom, right panel white card titled ACTIVE MODEL showing v1.2.0 - Invoice Model as title in bold #121212, thin horizontal divider #E6E4E1 below, three-column metrics row displaying mAP 95.1% and Precision 94% and Recall 92% in 24px bold with 12px labels below in #6B6B6B, info rows showing Activated 2024-01-20 and Documents 500 in 14px, training progress section at bottom showing Run-2024-02 with horizontal progress bar, below panels is full-width white card RECENT ACTIVITY section with list of 6 activity items each 40px height showing icon on left and description text in #121212 and relative timestamp in #6B6B6B right aligned, activity icons: rocket in purple for model activation, checkmark in green for training complete, edit pencil in orange for annotation modified, file in blue for document uploaded, x in red for training failed, subtle hover background #F1F0ED on activity rows, bottom section is SYSTEM STATUS white card showing Backend API Online with bright green #22C55E dot and Database Connected with green dot and GPU Available with green dot, all text in #2A2A2A, Inter font family, rounded corners 8px on all cards, subtle card shadow, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style, dribbble quality
```
---
## Page 2: Dashboard Empty State (No Active Model)
**Page description**: Onboarding view shown when the system is freshly deployed or has no trained model
**Prompt**:
```
A modern web application dashboard UI for a document annotation system, empty state variation, warm minimalist design theme, page background #FAFAF8 warm off-white, single column layout with header navigation, top section shows: 4 stat cards on white background with #E6E4E1 border, all showing 0 values, Total Documents 0 with gray icon, Complete 0 with muted green, Incomplete 0 with muted orange, Pending 0 with muted blue, middle section two-column layout: left DATA QUALITY panel white card shows circular progress ring at 0% completely gray #E5E7EB with dashed outline style, large text 0% in #6B6B6B centered, text No data yet below in muted color, empty stats all showing 0, right ACTIVE MODEL panel white card shows empty state with large subtle model icon in center opacity 20%, text No Active Model as heading in #121212, subtext Train and activate a model to see stats here in #6B6B6B, primary button Go to Training at bottom, below panels RECENT ACTIVITY white card shows empty state with Activity icon centered at 20% opacity, text No recent activity in #121212, subtext Start by uploading documents or creating training jobs in #6B6B6B, bottom SYSTEM STATUS card showing all services online with green #22C55E dots, warm color palette throughout, Inter font, rounded corners 8px, subtle shadows, friendly and inviting empty state design, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
```
---
## Page 3: Dashboard Training-in-Progress State
**Page description**: While a model is training, the Active Model panel shows training progress
**Prompt**:
```
A modern web application dashboard UI for a document annotation system, training in progress state, warm minimalist theme with #FAFAF8 background, header with navigation, top section: 4 white stat cards with #E6E4E1 borders showing Total Documents 38, Complete 25 with green icon on #dcfce7, Incomplete 8 with orange icon on #fef3c7, Pending 5 with blue icon on #dbeafe, middle section two-column layout: left DATA QUALITY white card with 78% progress ring in green #22C55E, stats list showing counts, right ACTIVE MODEL white card showing current model v1.1.0 in bold #121212 with metrics mAP 93.5% Precision 92% Recall 88% in grid, below a highlighted training section with subtle blue tint background #EFF6FF, pulsing blue dot indicator, text Training in Progress in #121212, task name Run-2024-02, horizontal progress bar 45% complete with blue #3B82F6 fill and gray #E5E7EB track, text Started 2 hours ago in #6B6B6B below, RECENT ACTIVITY white card below with latest item showing blue spinner icon and Training started Run-2024-02, other activities listed with appropriate icons, SYSTEM STATUS card at bottom showing GPU Available highlighted with green dot indicating active usage, warm color scheme throughout, Inter font, 8px rounded corners, subtle card shadows, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
```
---
## Page 4: Dashboard Mobile Responsive
**Page description**: Single-column stacked layout on mobile (<768px)
**Prompt**:
```
A modern mobile web application dashboard UI for a document annotation system, responsive mobile layout on smartphone screen, warm minimalist theme with #FAFAF8 background, single column stacked layout, top shows condensed header with hamburger menu icon and logo, below 2x2 grid of compact white stat cards with #E6E4E1 borders showing Total 38 Complete 25 Incomplete 8 Pending 5 with small colored icons on tinted backgrounds, DATA QUALITY section below as full-width white card with smaller progress ring 80px showing 78% in green #22C55E, horizontal stats row compact, ACTIVE MODEL section below as full-width white card with model name v1.2.0 in bold, compact metrics row showing mAP Precision Recall values, RECENT ACTIVITY section full-width white card with scrollable list of 4 visible items with icons and timestamps in #6B6B6B, compact SYSTEM STATUS bar at bottom with three green #22C55E status dots, warm color palette #FAFAF8 background white cards #121212 text, Inter font, touch-friendly tap targets 44px minimum, comfortable 16px padding, 8px rounded corners, iOS/Android native feel, UI/UX design, high fidelity mockup, mobile screen 375x812 iPhone size, professional, Figma style
```
---
## Usage Notes
1. Copy a prompt into an AI image tool (Midjourney, DALL-E, Stable Diffusion)
2. Generate Page 1 (the main view) first to verify the style matches the existing design
3. The prompts already include the existing color scheme:
   - Page background: #FAFAF8 (warm off-white)
   - Card background: #FFFFFF (white)
   - Border: #E6E4E1 (light grayish beige)
   - Primary text: #121212 (near black)
   - Secondary text: #6B6B6B (mid gray)
   - Success: #22C55E (bright green) / #3E4A3A (dark olive green text)
   - Icon backgrounds: #dcfce7 (light green) / #fef3c7 (light yellow) / #dbeafe (light blue)
4. If the generated colors are off, adjust them afterwards in Figma
---
## Tailwind Class Reference (for development)
```
Background: bg-warm-bg (#FAFAF8)
Card: bg-warm-card (#FFFFFF)
Border: border-warm-border (#E6E4E1)
Primary text: text-warm-text-primary (#121212)
Secondary text: text-warm-text-secondary (#2A2A2A)
Muted text: text-warm-text-muted (#6B6B6B)
Hover background: bg-warm-hover (#F1F0ED)
Success state: text-warm-state-success (#3E4A3A)
Green icon background: bg-green-50 (#dcfce7)
Yellow icon background: bg-yellow-50 (#fef3c7)
Blue icon background: bg-blue-50 (#dbeafe)
Green indicator dot: bg-green-500 (#22C55E)
```
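
For development, a minimal sketch of how these tokens compose (assuming the `warm-*` classes above are defined in the project's Tailwind config; `WarmCard` is a hypothetical name, not a component in this commit):
```
import React from 'react'

// Minimal sketch of a warm-themed card using the tokens above.
export const WarmCard: React.FC<{ title: string; value: string }> = ({ title, value }) => (
  <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:bg-warm-hover">
    <p className="text-2xl font-bold text-warm-text-primary">{value}</p>
    <p className="text-sm text-warm-text-muted">{title}</p>
  </div>
)
```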

File diff suppressed because it is too large.


@@ -0,0 +1,25 @@
import apiClient from '../client'
import type {
DashboardStatsResponse,
DashboardActiveModelResponse,
RecentActivityResponse,
} from '../types'
export const dashboardApi = {
getStats: async (): Promise<DashboardStatsResponse> => {
const response = await apiClient.get('/api/v1/admin/dashboard/stats')
return response.data
},
getActiveModel: async (): Promise<DashboardActiveModelResponse> => {
const response = await apiClient.get('/api/v1/admin/dashboard/active-model')
return response.data
},
getRecentActivity: async (limit: number = 10): Promise<RecentActivityResponse> => {
const response = await apiClient.get('/api/v1/admin/dashboard/activity', {
params: { limit },
})
return response.data
},
}
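
As a rough usage sketch (the helper name is hypothetical, and `apiClient` is assumed to attach the admin token via an interceptor), all three endpoints can be fetched in parallel:
```
import { dashboardApi } from './dashboard'

// Hypothetical helper: load the three dashboard payloads in parallel.
async function loadDashboard() {
  const [stats, activeModel, activity] = await Promise.all([
    dashboardApi.getStats(),
    dashboardApi.getActiveModel(),
    dashboardApi.getRecentActivity(5),
  ])
  console.log(`Completeness: ${stats.completeness_rate}%`)
  console.log('Active model:', activeModel.model?.version ?? 'none')
  console.log(`Recent activities: ${activity.activities.length}`)
}
```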


@@ -5,3 +5,4 @@ export { inferenceApi } from './inference'
export { datasetsApi } from './datasets'
export { augmentationApi } from './augmentation'
export { modelsApi } from './models'
export { dashboardApi } from './dashboard'


@@ -362,3 +362,48 @@ export interface ActiveModelResponse {
has_active_model: boolean
model: ModelVersionItem | null
}
// Dashboard types
export interface DashboardStatsResponse {
total_documents: number
annotation_complete: number
annotation_incomplete: number
pending: number
completeness_rate: number
}
export interface DashboardActiveModelInfo {
version_id: string
version: string
name: string
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
document_count: number
activated_at: string | null
}
export interface DashboardRunningTrainingInfo {
task_id: string
name: string
status: string
started_at: string | null
progress: number
}
export interface DashboardActiveModelResponse {
model: DashboardActiveModelInfo | null
running_training: DashboardRunningTrainingInfo | null
}
export interface ActivityItem {
type: 'document_uploaded' | 'annotation_modified' | 'training_completed' | 'training_failed' | 'model_activated'
description: string
timestamp: string
metadata: Record<string, unknown>
}
export interface RecentActivityResponse {
activities: ActivityItem[]
}


@@ -1,47 +1,58 @@
import React from 'react'
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'
import { useTraining } from '../hooks/useTraining'
import { FileText, CheckCircle, AlertCircle, Clock, RefreshCw } from 'lucide-react'
import {
StatsCard,
DataQualityPanel,
ActiveModelPanel,
RecentActivityPanel,
SystemStatusBar,
} from './dashboard/index'
import { useDashboard } from '../hooks/useDashboard'
interface DashboardOverviewProps {
onNavigate: (view: string) => void
}
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
const { models, isLoadingModels } = useTraining()
const {
stats,
model,
runningTraining,
activities,
isLoading,
error,
} = useDashboard()
const stats = [
{
label: 'Total Documents',
value: docsLoading ? '...' : totalDocs.toString(),
icon: FileText,
color: 'text-warm-text-primary',
bgColor: 'bg-warm-bg',
},
{
label: 'Labeled',
value: '0',
icon: CheckCircle,
color: 'text-warm-state-success',
bgColor: 'bg-green-50',
},
{
label: 'Pending',
value: '0',
icon: Clock,
color: 'text-warm-state-warning',
bgColor: 'bg-yellow-50',
},
{
label: 'Training Models',
value: isLoadingModels ? '...' : models.length.toString(),
icon: TrendingUp,
color: 'text-warm-state-info',
bgColor: 'bg-blue-50',
},
]
const handleStatsClick = (filter?: string) => {
if (filter) {
onNavigate(`documents?status=${filter}`)
} else {
onNavigate('documents')
}
}
if (error) {
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="bg-red-50 border border-red-200 rounded-lg p-6 text-center">
<AlertCircle className="w-12 h-12 text-red-500 mx-auto mb-4" />
<h2 className="text-lg font-semibold text-red-800 mb-2">
Failed to load dashboard
</h2>
<p className="text-sm text-red-600 mb-4">
{error instanceof Error ? error.message : 'An unexpected error occurred'}
</p>
<button
onClick={() => window.location.reload()}
className="inline-flex items-center gap-2 px-4 py-2 bg-red-100 hover:bg-red-200 text-red-800 rounded-md text-sm font-medium transition-colors"
>
<RefreshCw className="w-4 h-4" />
Retry
</button>
</div>
</div>
)
}
return (
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
@@ -55,94 +66,74 @@ export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate
</p>
</div>
{/* Stats Grid */}
{/* Stats Cards Row */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
{stats.map((stat) => (
<div
key={stat.label}
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
>
<div className="flex items-center justify-between mb-4">
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
<stat.icon className={stat.color} size={24} />
</div>
</div>
<p className="text-2xl font-bold text-warm-text-primary mb-1">
{stat.value}
</p>
<p className="text-sm text-warm-text-muted">{stat.label}</p>
</div>
))}
<StatsCard
label="Total Documents"
value={stats?.total_documents ?? 0}
icon={FileText}
iconColor="text-warm-text-primary"
iconBgColor="bg-warm-bg"
isLoading={isLoading}
onClick={() => handleStatsClick()}
/>
<StatsCard
label="Complete"
value={stats?.annotation_complete ?? 0}
icon={CheckCircle}
iconColor="text-warm-state-success"
iconBgColor="bg-green-50"
isLoading={isLoading}
onClick={() => handleStatsClick('labeled')}
/>
<StatsCard
label="Incomplete"
value={stats?.annotation_incomplete ?? 0}
icon={AlertCircle}
iconColor="text-orange-600"
iconBgColor="bg-orange-50"
isLoading={isLoading}
onClick={() => handleStatsClick('labeled')}
/>
<StatsCard
label="Pending"
value={stats?.pending ?? 0}
icon={Clock}
iconColor="text-blue-600"
iconBgColor="bg-blue-50"
isLoading={isLoading}
onClick={() => handleStatsClick('pending')}
/>
</div>
{/* Quick Actions */}
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
Quick Actions
</h2>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<Button onClick={() => onNavigate('documents')} className="justify-start">
<FileText size={18} className="mr-2" />
Manage Documents
</Button>
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
<Activity size={18} className="mr-2" />
Start Training
</Button>
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
<TrendingUp size={18} className="mr-2" />
View Models
</Button>
</div>
{/* Two-column layout: Data Quality + Active Model */}
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6 mb-8">
<DataQualityPanel
completenessRate={stats?.completeness_rate ?? 0}
completeCount={stats?.annotation_complete ?? 0}
incompleteCount={stats?.annotation_incomplete ?? 0}
pendingCount={stats?.pending ?? 0}
isLoading={isLoading}
onViewIncomplete={() => handleStatsClick('labeled')}
/>
<ActiveModelPanel
model={model}
runningTraining={runningTraining}
isLoading={isLoading}
onGoToTraining={() => onNavigate('training')}
/>
</div>
{/* Recent Activity */}
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border">
<h2 className="text-lg font-semibold text-warm-text-primary">
Recent Activity
</h2>
</div>
<div className="p-6">
<div className="text-center py-8 text-warm-text-muted">
<Activity size={48} className="mx-auto mb-3 opacity-20" />
<p className="text-sm">No recent activity</p>
<p className="text-xs mt-1">
Start by uploading documents or creating training jobs
</p>
</div>
</div>
<div className="mb-8">
<RecentActivityPanel
activities={activities}
isLoading={isLoading}
/>
</div>
{/* System Status */}
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
System Status
</h2>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Backend API</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Online
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Database</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Connected
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">GPU</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Available
</span>
</div>
</div>
</div>
<SystemStatusBar />
</div>
)
}


@@ -0,0 +1,143 @@
import React from 'react'
import { TrendingUp } from 'lucide-react'
import { Button } from '../Button'
import type { DashboardActiveModelInfo, DashboardRunningTrainingInfo } from '../../api/types'
interface ActiveModelPanelProps {
model: DashboardActiveModelInfo | null
runningTraining: DashboardRunningTrainingInfo | null
isLoading?: boolean
onGoToTraining?: () => void
}
const formatDate = (dateStr: string | null): string => {
if (!dateStr) return 'N/A'
const date = new Date(dateStr)
return date.toLocaleDateString('en-US', {
year: 'numeric',
month: 'short',
day: 'numeric',
})
}
const formatMetric = (value: number | null): string => {
if (value === null) return 'N/A'
return `${(value * 100).toFixed(1)}%`
}
const getMetricColor = (value: number | null): string => {
if (value === null) return 'text-warm-text-muted'
if (value >= 0.9) return 'text-green-600'
if (value >= 0.8) return 'text-yellow-600'
return 'text-red-600'
}
export const ActiveModelPanel: React.FC<ActiveModelPanelProps> = ({
model,
runningTraining,
isLoading = false,
onGoToTraining,
}) => {
if (isLoading) {
return (
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
Active Model
</h2>
<div className="flex items-center justify-center py-8">
<div className="animate-pulse text-warm-text-muted">Loading...</div>
</div>
</div>
)
}
if (!model) {
return (
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
Active Model
</h2>
<div className="flex flex-col items-center justify-center py-8 text-center">
<TrendingUp className="w-12 h-12 text-warm-text-disabled mb-3 opacity-20" />
<p className="text-warm-text-primary font-medium mb-1">No Active Model</p>
<p className="text-sm text-warm-text-muted mb-4">
Train and activate a model to see stats here
</p>
{onGoToTraining && (
<Button onClick={onGoToTraining} variant="primary" size="sm">
Go to Training
</Button>
)}
</div>
</div>
)
}
return (
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
Active Model
</h2>
<div className="mb-4">
<span className="text-lg font-bold text-warm-text-primary">{model.version}</span>
<span className="text-warm-text-secondary ml-2">- {model.name}</span>
</div>
<div className="border-t border-warm-border pt-4 mb-4">
<div className="grid grid-cols-3 gap-4">
<div className="text-center">
<p className={`text-2xl font-bold ${getMetricColor(model.metrics_mAP)}`}>
{formatMetric(model.metrics_mAP)}
</p>
<p className="text-xs text-warm-text-muted uppercase">mAP</p>
</div>
<div className="text-center">
<p className={`text-2xl font-bold ${getMetricColor(model.metrics_precision)}`}>
{formatMetric(model.metrics_precision)}
</p>
<p className="text-xs text-warm-text-muted uppercase">Precision</p>
</div>
<div className="text-center">
<p className={`text-2xl font-bold ${getMetricColor(model.metrics_recall)}`}>
{formatMetric(model.metrics_recall)}
</p>
<p className="text-xs text-warm-text-muted uppercase">Recall</p>
</div>
</div>
</div>
<div className="space-y-1 text-sm text-warm-text-secondary">
<p>
<span className="text-warm-text-muted">Activated:</span>{' '}
{formatDate(model.activated_at)}
</p>
<p>
<span className="text-warm-text-muted">Documents:</span>{' '}
{model.document_count.toLocaleString()}
</p>
</div>
{runningTraining && (
<div className="mt-4 p-3 bg-blue-50 rounded-lg border border-blue-100">
<div className="flex items-center gap-2 mb-2">
<span className="w-2 h-2 bg-blue-500 rounded-full animate-pulse" />
<span className="text-sm font-medium text-warm-text-primary">
Training in Progress
</span>
</div>
<p className="text-sm text-warm-text-secondary mb-2">{runningTraining.name}</p>
<div className="w-full bg-gray-200 rounded-full h-2">
<div
className="bg-blue-500 h-2 rounded-full transition-all duration-500"
style={{ width: `${runningTraining.progress}%` }}
/>
</div>
<p className="text-xs text-warm-text-muted mt-1">
{runningTraining.progress}% complete
</p>
</div>
)}
</div>
)
}


@@ -0,0 +1,105 @@
import React from 'react'
import { Button } from '../Button'
interface DataQualityPanelProps {
completenessRate: number
completeCount: number
incompleteCount: number
pendingCount: number
isLoading?: boolean
onViewIncomplete?: () => void
}
export const DataQualityPanel: React.FC<DataQualityPanelProps> = ({
completenessRate,
completeCount,
incompleteCount,
pendingCount,
isLoading = false,
onViewIncomplete,
}) => {
const radius = 54
const circumference = 2 * Math.PI * radius
const strokeDashoffset = circumference - (completenessRate / 100) * circumference
return (
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
Data Quality
</h2>
<div className="flex items-center gap-6">
<div className="relative">
<svg width="120" height="120" className="transform -rotate-90">
<circle
cx="60"
cy="60"
r={radius}
stroke="#E5E7EB"
strokeWidth="12"
fill="none"
/>
<circle
cx="60"
cy="60"
r={radius}
stroke="#22C55E"
strokeWidth="12"
fill="none"
strokeLinecap="round"
strokeDasharray={circumference}
strokeDashoffset={isLoading ? circumference : strokeDashoffset}
className="transition-all duration-500"
/>
</svg>
<div className="absolute inset-0 flex items-center justify-center">
<span className="text-3xl font-bold text-warm-text-primary">
{isLoading ? '...' : `${Math.round(completenessRate)}%`}
</span>
</div>
</div>
<div className="flex-1">
<p className="text-sm text-warm-text-secondary mb-4">
Annotation Complete
</p>
<div className="space-y-2">
<div className="flex items-center justify-between text-sm">
<span className="flex items-center gap-2">
<span className="w-2 h-2 bg-green-500 rounded-full" />
Complete
</span>
<span className="font-medium">{isLoading ? '...' : completeCount}</span>
</div>
<div className="flex items-center justify-between text-sm">
<span className="flex items-center gap-2">
<span className="w-2 h-2 bg-orange-500 rounded-full" />
Incomplete
</span>
<span className="font-medium">{isLoading ? '...' : incompleteCount}</span>
</div>
<div className="flex items-center justify-between text-sm">
<span className="flex items-center gap-2">
<span className="w-2 h-2 bg-blue-500 rounded-full" />
Pending
</span>
<span className="font-medium">{isLoading ? '...' : pendingCount}</span>
</div>
</div>
</div>
</div>
{onViewIncomplete && incompleteCount > 0 && (
<div className="mt-4 pt-4 border-t border-warm-border">
<button
onClick={onViewIncomplete}
className="text-sm text-blue-600 hover:text-blue-800 font-medium"
>
View Incomplete Docs
</button>
</div>
)}
</div>
)
}


@@ -0,0 +1,134 @@
import React from 'react'
import {
FileText,
Edit,
CheckCircle,
XCircle,
Rocket,
Activity,
} from 'lucide-react'
import type { ActivityItem } from '../../api/types'
interface RecentActivityPanelProps {
activities: ActivityItem[]
isLoading?: boolean
onSeeAll?: () => void
}
const getActivityIcon = (type: ActivityItem['type']) => {
switch (type) {
case 'document_uploaded':
return { Icon: FileText, color: 'text-blue-500', bg: 'bg-blue-50' }
case 'annotation_modified':
return { Icon: Edit, color: 'text-orange-500', bg: 'bg-orange-50' }
case 'training_completed':
return { Icon: CheckCircle, color: 'text-green-500', bg: 'bg-green-50' }
case 'training_failed':
return { Icon: XCircle, color: 'text-red-500', bg: 'bg-red-50' }
case 'model_activated':
return { Icon: Rocket, color: 'text-purple-500', bg: 'bg-purple-50' }
default:
return { Icon: Activity, color: 'text-gray-500', bg: 'bg-gray-50' }
}
}
const formatTimestamp = (timestamp: string): string => {
const date = new Date(timestamp)
const now = new Date()
const diffMs = now.getTime() - date.getTime()
const diffMinutes = Math.floor(diffMs / 60000)
const diffHours = Math.floor(diffMs / 3600000)
const diffDays = Math.floor(diffMs / 86400000)
if (diffMinutes < 1) return 'just now'
if (diffMinutes < 60) return `${diffMinutes} minutes ago`
if (diffHours < 24) return `${diffHours} hours ago`
if (diffDays === 1) return 'yesterday'
if (diffDays < 7) return `${diffDays} days ago`
return date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' })
}
export const RecentActivityPanel: React.FC<RecentActivityPanelProps> = ({
activities,
isLoading = false,
onSeeAll,
}) => {
if (isLoading) {
return (
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border flex items-center justify-between">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
Recent Activity
</h2>
</div>
<div className="p-6">
<div className="flex items-center justify-center py-8">
<div className="animate-pulse text-warm-text-muted">Loading...</div>
</div>
</div>
</div>
)
}
if (activities.length === 0) {
return (
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
Recent Activity
</h2>
</div>
<div className="p-6">
<div className="flex flex-col items-center justify-center py-8 text-center">
<Activity className="w-12 h-12 text-warm-text-disabled mb-3 opacity-20" />
<p className="text-warm-text-primary font-medium mb-1">No recent activity</p>
<p className="text-sm text-warm-text-muted">
Start by uploading documents or creating training jobs
</p>
</div>
</div>
</div>
)
}
return (
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border flex items-center justify-between">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
Recent Activity
</h2>
{onSeeAll && (
<button
onClick={onSeeAll}
className="text-sm text-blue-600 hover:text-blue-800 font-medium"
>
See All
</button>
)}
</div>
<div className="divide-y divide-warm-border">
{activities.map((activity, index) => {
const { Icon, color, bg } = getActivityIcon(activity.type)
return (
<div
key={`${activity.type}-${activity.timestamp}-${index}`}
className="px-6 py-3 flex items-center gap-4 hover:bg-warm-hover transition-colors"
>
<div className={`p-2 rounded-lg ${bg}`}>
<Icon className={color} size={16} />
</div>
<p className="flex-1 text-sm text-warm-text-primary truncate">
{activity.description}
</p>
<span className="text-xs text-warm-text-muted whitespace-nowrap">
{formatTimestamp(activity.timestamp)}
</span>
</div>
)
})}
</div>
</div>
)
}


@@ -0,0 +1,44 @@
import React from 'react'
import { LucideIcon } from 'lucide-react'
interface StatsCardProps {
label: string
value: string | number
icon: LucideIcon
iconColor: string
iconBgColor: string
onClick?: () => void
isLoading?: boolean
}
export const StatsCard: React.FC<StatsCardProps> = ({
label,
value,
icon: Icon,
iconColor,
iconBgColor,
onClick,
isLoading = false,
}) => {
return (
<div
className={`bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow ${
onClick ? 'cursor-pointer' : ''
}`}
onClick={onClick}
role={onClick ? 'button' : undefined}
tabIndex={onClick ? 0 : undefined}
onKeyDown={onClick ? (e) => e.key === 'Enter' && onClick() : undefined}
>
<div className="flex items-center justify-between mb-4">
<div className={`p-3 rounded-lg ${iconBgColor}`}>
<Icon className={iconColor} size={24} />
</div>
</div>
<p className="text-2xl font-bold text-warm-text-primary mb-1">
{isLoading ? '...' : value}
</p>
<p className="text-sm text-warm-text-muted">{label}</p>
</div>
)
}


@@ -0,0 +1,62 @@
import React from 'react'
interface StatusItem {
label: string
status: 'online' | 'degraded' | 'offline'
statusText: string
}
interface SystemStatusBarProps {
items?: StatusItem[]
}
const getStatusColor = (status: StatusItem['status']) => {
switch (status) {
case 'online':
return 'bg-green-500'
case 'degraded':
return 'bg-yellow-500'
case 'offline':
return 'bg-red-500'
}
}
const getStatusTextColor = (status: StatusItem['status']) => {
switch (status) {
case 'online':
return 'text-warm-state-success'
case 'degraded':
return 'text-yellow-600'
case 'offline':
return 'text-red-600'
}
}
const defaultItems: StatusItem[] = [
{ label: 'Backend API', status: 'online', statusText: 'Online' },
{ label: 'Database', status: 'online', statusText: 'Connected' },
{ label: 'GPU', status: 'online', statusText: 'Available' },
]
export const SystemStatusBar: React.FC<SystemStatusBarProps> = ({
items = defaultItems,
}) => {
return (
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-4">
System Status
</h2>
<div className="space-y-3">
{items.map((item) => (
<div key={item.label} className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">{item.label}</span>
<span className={`flex items-center text-sm ${getStatusTextColor(item.status)}`}>
<span className={`w-2 h-2 ${getStatusColor(item.status)} rounded-full mr-2`} />
{item.statusText}
</span>
</div>
))}
</div>
</div>
)
}
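
Because `items` is optional, callers get the three default rows for free; a hypothetical override that surfaces a degraded GPU might look like:
```
// Hypothetical usage: override the default items to report a degraded GPU.
<SystemStatusBar
  items={[
    { label: 'Backend API', status: 'online', statusText: 'Online' },
    { label: 'Database', status: 'online', statusText: 'Connected' },
    { label: 'GPU', status: 'degraded', statusText: 'High load' },
  ]}
/>
```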


@@ -0,0 +1,5 @@
export { StatsCard } from './StatsCard'
export { DataQualityPanel } from './DataQualityPanel'
export { ActiveModelPanel } from './ActiveModelPanel'
export { RecentActivityPanel } from './RecentActivityPanel'
export { SystemStatusBar } from './SystemStatusBar'


@@ -5,3 +5,4 @@ export { useTraining, useTrainingDocuments } from './useTraining'
export { useDatasets, useDatasetDetail } from './useDatasets'
export { useAugmentation } from './useAugmentation'
export { useModels, useModelDetail, useActiveModel } from './useModels'
export { useDashboard, useDashboardStats, useActiveModel as useDashboardActiveModel, useRecentActivity } from './useDashboard'


@@ -0,0 +1,76 @@
import { useQuery } from '@tanstack/react-query'
import { dashboardApi } from '../api/endpoints'
import type {
DashboardStatsResponse,
DashboardActiveModelResponse,
RecentActivityResponse,
} from '../api/types'
export const useDashboardStats = () => {
const { data, isLoading, error, refetch } = useQuery<DashboardStatsResponse>({
queryKey: ['dashboard', 'stats'],
queryFn: () => dashboardApi.getStats(),
staleTime: 30000,
refetchInterval: 60000,
})
return {
stats: data,
isLoading,
error,
refetch,
}
}
export const useActiveModel = () => {
const { data, isLoading, error, refetch } = useQuery<DashboardActiveModelResponse>({
queryKey: ['dashboard', 'active-model'],
queryFn: () => dashboardApi.getActiveModel(),
staleTime: 30000,
refetchInterval: 60000,
})
return {
model: data?.model ?? null,
runningTraining: data?.running_training ?? null,
isLoading,
error,
refetch,
}
}
export const useRecentActivity = (limit: number = 10) => {
const { data, isLoading, error, refetch } = useQuery<RecentActivityResponse>({
queryKey: ['dashboard', 'activity', limit],
queryFn: () => dashboardApi.getRecentActivity(limit),
staleTime: 30000,
refetchInterval: 60000,
})
return {
activities: data?.activities ?? [],
isLoading,
error,
refetch,
}
}
export const useDashboard = () => {
const stats = useDashboardStats()
const activeModel = useActiveModel()
const activity = useRecentActivity()
return {
stats: stats.stats,
model: activeModel.model,
runningTraining: activeModel.runningTraining,
activities: activity.activities,
isLoading: stats.isLoading || activeModel.isLoading || activity.isLoading,
error: stats.error || activeModel.error || activity.error,
refetch: () => {
stats.refetch()
activeModel.refetch()
activity.refetch()
},
}
}
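
The combined `useDashboard` hook backs the `DashboardOverview` component above; the standalone hooks also expose `refetch` for targeted refreshes. A minimal consumption sketch (the badge component is hypothetical):
```
import React from 'react'
import { useDashboardStats } from './useDashboard'

// Hypothetical widget: shows the completeness rate and refetches on click.
export const CompletenessBadge: React.FC = () => {
  const { stats, isLoading, error, refetch } = useDashboardStats()
  if (error) return <span>Stats unavailable</span>
  return (
    <button onClick={() => refetch()}>
      {isLoading ? '...' : `${stats?.completeness_rate ?? 0}% complete`}
    </button>
  )
}
```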


@@ -175,6 +175,80 @@ def run_migrations() -> None:
);
""",
),
# Migration 007: Add extra columns to training_tasks
(
"training_tasks_name",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
""",
),
(
"training_tasks_description",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
""",
),
(
"training_tasks_admin_token",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
""",
),
(
"training_tasks_task_type",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
""",
),
(
"training_tasks_recurring",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
""",
),
(
"training_tasks_metrics",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
""",
),
(
"training_tasks_updated_at",
"""
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
""",
),
# Migration 008: Fix model_versions foreign key constraints
(
"model_versions_fk_fix",
"""
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_dataset_id_fkey
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_task_id_fkey
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
""",
),
# Migration 006b: Ensure only one active model at a time
(
"model_versions_single_active",
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
ON model_versions(is_active) WHERE is_active = TRUE;
""",
),
]
with engine.connect() as conn:


@@ -193,6 +193,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
annotation = session.get(AdminAnnotation, UUID(annotation_id))
if annotation:
session.delete(annotation)
session.commit()
return True
return False
@@ -216,6 +217,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]):
count = len(annotations)
for ann in annotations:
session.delete(ann)
session.commit()
return count
def verify(


@@ -203,6 +203,14 @@ class DatasetRepository(BaseRepository[TrainingDataset]):
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return False
# Delete associated document links first
doc_links = session.exec(
select(DatasetDocument).where(
DatasetDocument.dataset_id == UUID(str(dataset_id))
)
).all()
for link in doc_links:
session.delete(link)
session.delete(dataset)
session.commit()
return True


@@ -264,6 +264,7 @@ class DocumentRepository(BaseRepository[AdminDocument]):
for ann in annotations:
session.delete(ann)
session.delete(document)
session.commit()
return True
return False
@@ -389,7 +390,11 @@ class DocumentRepository(BaseRepository[AdminDocument]):
return None
now = datetime.now(timezone.utc)
if doc.annotation_lock_until and doc.annotation_lock_until > now:
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if lock_until and lock_until > now:
return None
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
@@ -433,10 +438,14 @@ class DocumentRepository(BaseRepository[AdminDocument]):
return None
now = datetime.now(timezone.utc)
if not doc.annotation_lock_until or doc.annotation_lock_until <= now:
lock_until = doc.annotation_lock_until
# Handle PostgreSQL returning offset-naive datetimes
if lock_until and lock_until.tzinfo is None:
lock_until = lock_until.replace(tzinfo=timezone.utc)
if not lock_until or lock_until <= now:
return None
doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds)
doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds)
session.add(doc)
session.commit()
session.refresh(doc)


@@ -118,6 +118,22 @@ class TrainingTaskRepository(BaseRepository[TrainingTask]):
session.expunge(r)
return list(results)
def get_running(self) -> TrainingTask | None:
"""Get currently running training task.
Returns:
Running task or None if no task is running
"""
with get_session_context() as session:
result = session.exec(
select(TrainingTask)
.where(TrainingTask.status == "running")
.order_by(TrainingTask.started_at.desc())
).first()
if result:
session.expunge(result)
return result
def update_status(
self,
task_id: str,


@@ -55,5 +55,6 @@ def create_normalizer_registry(
"Amount": amount_normalizer,
"InvoiceDate": date_normalizer,
"InvoiceDueDate": date_normalizer,
"supplier_org_number": SupplierOrgNumberNormalizer(),
# Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping)
"supplier_organisation_number": SupplierOrgNumberNormalizer(),
}


@@ -481,11 +481,22 @@ def create_annotation_router() -> APIRouter:
detail="At least one field value is required",
)
# Get the actual file path from storage
# document.file_path is a relative storage path like "raw_pdfs/uuid.pdf"
storage = get_storage_helper()
filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path
file_path = storage.get_raw_pdf_local_path(filename)
if file_path is None:
raise HTTPException(
status_code=500,
detail=f"Cannot find PDF file: {document.file_path}",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=document.file_path,
file_path=str(file_path),
field_values=request.field_values,
doc_repo=doc_repo,
ann_repo=ann_repo,


@@ -6,7 +6,7 @@ FastAPI endpoints for admin token management.
import logging
import secrets
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter
@@ -41,10 +41,10 @@ def create_auth_router() -> APIRouter:
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration
# Calculate expiration (use timezone-aware datetime)
expires_at = None
if request.expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days)
# Create token in database
tokens.create(


@@ -0,0 +1,135 @@
"""
Dashboard API Routes
FastAPI endpoints for dashboard statistics and activity.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from inference.web.core.auth import (
AdminTokenDep,
get_model_version_repository,
get_training_task_repository,
ModelVersionRepoDep,
TrainingTaskRepoDep,
)
from inference.web.schemas.admin import (
DashboardStatsResponse,
ActiveModelResponse,
ActiveModelInfo,
RunningTrainingInfo,
RecentActivityResponse,
ActivityItem,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
)
logger = logging.getLogger(__name__)
def create_dashboard_router() -> APIRouter:
"""Create dashboard API router."""
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
@router.get(
"/stats",
response_model=DashboardStatsResponse,
summary="Get dashboard statistics",
description="Returns document counts and annotation completeness metrics.",
)
async def get_dashboard_stats(
admin_token: AdminTokenDep,
) -> DashboardStatsResponse:
"""Get dashboard statistics."""
service = DashboardStatsService()
stats = service.get_stats()
return DashboardStatsResponse(
total_documents=stats["total_documents"],
annotation_complete=stats["annotation_complete"],
annotation_incomplete=stats["annotation_incomplete"],
pending=stats["pending"],
completeness_rate=stats["completeness_rate"],
)
@router.get(
"/active-model",
response_model=ActiveModelResponse,
summary="Get active model info",
description="Returns current active model and running training status.",
)
async def get_active_model(
admin_token: AdminTokenDep,
model_repo: ModelVersionRepoDep,
task_repo: TrainingTaskRepoDep,
) -> ActiveModelResponse:
"""Get active model and training status."""
# Get active model
active_model = model_repo.get_active()
model_info = None
if active_model:
model_info = ActiveModelInfo(
version_id=str(active_model.version_id),
version=active_model.version,
name=active_model.name,
metrics_mAP=active_model.metrics_mAP,
metrics_precision=active_model.metrics_precision,
metrics_recall=active_model.metrics_recall,
document_count=active_model.document_count,
activated_at=active_model.activated_at,
)
# Get running training task
running_task = task_repo.get_running()
training_info = None
if running_task:
training_info = RunningTrainingInfo(
task_id=str(running_task.task_id),
name=running_task.name,
status=running_task.status,
started_at=running_task.started_at,
progress=running_task.progress or 0,
)
return ActiveModelResponse(
model=model_info,
running_training=training_info,
)
@router.get(
"/activity",
response_model=RecentActivityResponse,
summary="Get recent activity",
description="Returns recent system activities sorted by timestamp.",
)
async def get_recent_activity(
admin_token: AdminTokenDep,
limit: Annotated[
int,
Query(ge=1, le=50, description="Maximum number of activities"),
] = 10,
) -> RecentActivityResponse:
"""Get recent system activity."""
service = DashboardActivityService()
activities = service.get_recent_activities(limit=limit)
return RecentActivityResponse(
activities=[
ActivityItem(
type=act["type"],
description=act["description"],
timestamp=act["timestamp"],
metadata=act["metadata"],
)
for act in activities
]
)
return router


@@ -44,6 +44,7 @@ from inference.web.api.v1.admin import (
create_locks_router,
create_training_router,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.scheduler import start_scheduler, stop_scheduler
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
@@ -115,13 +116,21 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize database tables
# Initialize async request database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.)
try:
from inference.data.database import create_db_and_tables
create_db_and_tables()
logger.info("Admin database tables ready")
except Exception as e:
logger.error(f"Failed to initialize admin database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
@@ -279,6 +288,10 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
augmentation_router = create_augmentation_router()
app.include_router(augmentation_router, prefix="/api/v1/admin")
# Include dashboard routes
dashboard_router = create_dashboard_router()
app.include_router(dashboard_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)


@@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .dashboard import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse


@@ -0,0 +1,92 @@
"""
Dashboard API Schemas
Pydantic models for dashboard statistics and activity endpoints.
"""
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, Field
# Activity type literals for type safety
ActivityType = Literal[
"document_uploaded",
"annotation_modified",
"training_completed",
"training_failed",
"model_activated",
]
class DashboardStatsResponse(BaseModel):
"""Response for dashboard statistics."""
total_documents: int = Field(..., description="Total number of documents")
annotation_complete: int = Field(
..., description="Documents with complete annotations"
)
annotation_incomplete: int = Field(
..., description="Documents with incomplete annotations"
)
pending: int = Field(..., description="Documents pending processing")
completeness_rate: float = Field(
..., description="Annotation completeness percentage"
)
class ActiveModelInfo(BaseModel):
"""Active model information."""
version_id: str = Field(..., description="Model version UUID")
version: str = Field(..., description="Model version string")
name: str = Field(..., description="Model name")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
metrics_precision: float | None = Field(None, description="Precision score")
metrics_recall: float | None = Field(None, description="Recall score")
document_count: int = Field(0, description="Number of training documents")
activated_at: datetime | None = Field(None, description="Activation timestamp")
class RunningTrainingInfo(BaseModel):
"""Running training task information."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
status: str = Field(..., description="Training status")
started_at: datetime | None = Field(None, description="Start timestamp")
progress: int = Field(0, description="Training progress percentage")
class ActiveModelResponse(BaseModel):
"""Response for active model endpoint."""
model: ActiveModelInfo | None = Field(
None, description="Active model info, null if none"
)
running_training: RunningTrainingInfo | None = Field(
None, description="Running training task, null if none"
)
class ActivityItem(BaseModel):
"""Single activity item."""
type: ActivityType = Field(
...,
description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated",
)
description: str = Field(..., description="Human-readable description")
timestamp: datetime = Field(..., description="Activity timestamp")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class RecentActivityResponse(BaseModel):
"""Response for recent activity endpoint."""
activities: list[ActivityItem] = Field(
default_factory=list, description="List of recent activities"
)
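
These Pydantic models mirror the TypeScript dashboard types added earlier; for illustration, a `/activity` response would deserialize on the client into a shape like this (import path and field values are invented):
```
import type { RecentActivityResponse } from '../types'

// Illustrative payload only; values are invented.
const example: RecentActivityResponse = {
  activities: [
    {
      type: 'model_activated',
      description: 'Activated model v1.2.0',
      timestamp: '2026-02-01T21:40:41Z',
      metadata: { version: 'v1.2.0' },
    },
  ],
}
```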


@@ -291,7 +291,7 @@ class AutoLabelService:
"bbox_y": bbox_y,
"bbox_width": bbox_width,
"bbox_height": bbox_height,
"text_value": best_match.matched_value,
"text_value": best_match.matched_text,
"confidence": best_match.score,
"source": "auto",
})


@@ -0,0 +1,276 @@
"""
Dashboard Service
Business logic for dashboard statistics and activity aggregation.
"""
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import func, exists, and_, or_
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import (
AdminDocument,
AdminAnnotation,
AnnotationHistory,
TrainingTask,
ModelVersion,
)
logger = logging.getLogger(__name__)
# Field class IDs for completeness calculation
# Identifiers: invoice_number (0) or ocr_number (3)
IDENTIFIER_CLASS_IDS = {0, 3}
# Payment accounts: bankgiro (4) or plusgiro (5)
PAYMENT_CLASS_IDS = {4, 5}
def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool:
"""Check if a document's annotations are complete.
A document is complete if it has:
- At least one identifier field (invoice_number OR ocr_number)
- At least one payment field (bankgiro OR plusgiro)
Args:
annotations: List of annotation dicts with class_id
Returns:
True if document has required fields
"""
class_ids = {ann.get("class_id") for ann in annotations}
has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS)
has_payment = bool(class_ids & PAYMENT_CLASS_IDS)
return has_identifier and has_payment
class DashboardStatsService:
"""Service for computing dashboard statistics."""
def get_stats(self) -> dict[str, Any]:
"""Get dashboard statistics.
Returns:
Dict with total_documents, annotation_complete, annotation_incomplete,
pending, and completeness_rate
"""
with get_session_context() as session:
# Total documents
total = session.exec(
select(func.count()).select_from(AdminDocument)
).one()
# Pending documents (status in ['pending', 'auto_labeling'])
pending = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status.in_(["pending", "auto_labeling"]))
).one()
# Complete annotations: labeled + has identifier + has payment
complete = self._count_complete(session)
# Incomplete: labeled but not complete
labeled_count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(AdminDocument.status == "labeled")
).one()
incomplete = labeled_count - complete
# Calculate completeness rate
total_assessed = complete + incomplete
completeness_rate = (
round(complete / total_assessed * 100, 2)
if total_assessed > 0
else 0.0
)
return {
"total_documents": total,
"annotation_complete": complete,
"annotation_incomplete": incomplete,
"pending": pending,
"completeness_rate": completeness_rate,
}
def _count_complete(self, session) -> int:
"""Count documents with complete annotations.
A document is complete if it:
1. Has status = 'labeled'
2. Has at least one identifier annotation (class_id 0 or 3)
3. Has at least one payment annotation (class_id 4 or 5)
"""
# Subquery for documents with identifier
has_identifier = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS),
)
)
)
# Subquery for documents with payment
has_payment = exists(
select(1)
.select_from(AdminAnnotation)
.where(
and_(
AdminAnnotation.document_id == AdminDocument.document_id,
AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS),
)
)
)
count = session.exec(
select(func.count())
.select_from(AdminDocument)
.where(
and_(
AdminDocument.status == "labeled",
has_identifier,
has_payment,
)
)
).one()
return count
class DashboardActivityService:
"""Service for aggregating recent activities."""
def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]:
"""Get recent system activities.
Aggregates from:
- Document uploads
- Annotation modifications
- Training completions/failures
- Model activations
Args:
limit: Maximum number of activities to return
Returns:
List of activity dicts sorted by timestamp DESC
"""
activities = []
with get_session_context() as session:
# Document uploads (recent 10)
uploads = session.exec(
select(AdminDocument)
.order_by(AdminDocument.created_at.desc())
.limit(limit)
).all()
for doc in uploads:
activities.append({
"type": "document_uploaded",
"description": f"Uploaded {doc.filename}",
"timestamp": doc.created_at,
"metadata": {
"document_id": str(doc.document_id),
"filename": doc.filename,
},
})
# Annotation modifications (from history)
modifications = session.exec(
select(AnnotationHistory)
.where(AnnotationHistory.action == "override")
.order_by(AnnotationHistory.created_at.desc())
.limit(limit)
).all()
for mod in modifications:
# Get document filename
doc = session.get(AdminDocument, mod.document_id)
filename = doc.filename if doc else "Unknown"
field_name = ""
if mod.new_value and isinstance(mod.new_value, dict):
field_name = mod.new_value.get("class_name", "")
activities.append({
"type": "annotation_modified",
"description": f"Modified {filename} {field_name}".strip(),
"timestamp": mod.created_at,
"metadata": {
"annotation_id": str(mod.annotation_id),
"document_id": str(mod.document_id),
"field_name": field_name,
},
})
# Training completions and failures
training_tasks = session.exec(
select(TrainingTask)
.where(TrainingTask.status.in_(["completed", "failed"]))
.order_by(TrainingTask.updated_at.desc())
.limit(limit)
).all()
for task in training_tasks:
if task.updated_at is None:
continue
if task.status == "completed":
# Use metrics_mAP field directly
mAP = task.metrics_mAP or 0.0
activities.append({
"type": "training_completed",
"description": f"Training complete: {task.name}, mAP {mAP:.1%}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"mAP": mAP,
},
})
else:
activities.append({
"type": "training_failed",
"description": f"Training failed: {task.name}",
"timestamp": task.updated_at,
"metadata": {
"task_id": str(task.task_id),
"task_name": task.name,
"error": task.error_message or "",
},
})
# Model activations
model_versions = session.exec(
select(ModelVersion)
.where(ModelVersion.activated_at.is_not(None))
.order_by(ModelVersion.activated_at.desc())
.limit(limit)
).all()
for model in model_versions:
if model.activated_at is None:
continue
activities.append({
"type": "model_activated",
"description": f"Activated model {model.version}",
"timestamp": model.activated_at,
"metadata": {
"version_id": str(model.version_id),
"version": model.version,
},
})
# Sort all activities by timestamp DESC and return top N
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return activities[:limit]


@@ -42,6 +42,7 @@ dev = [
"black>=23.0.0",
"ruff>=0.1.0",
"mypy>=1.0.0",
"testcontainers[postgres]>=4.0.0",
]
gpu = [
"paddlepaddle-gpu>=2.5.0",


@@ -1,73 +0,0 @@
"""Run database migration for training_status fields."""
import psycopg2
import os
# Read password from .env file
password = ""
try:
with open(".env") as f:
for line in f:
if line.startswith("DB_PASSWORD="):
password = line.strip().split("=", 1)[1].strip('"').strip("'")
break
except Exception as e:
print(f"Error reading .env: {e}")
print(f"Password found: {bool(password)}")
conn = psycopg2.connect(
host="192.168.68.31",
port=5432,
database="docmaster",
user="docmaster",
password=password
)
conn.autocommit = True
cur = conn.cursor()
# Add training_status column
try:
cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL")
print("Added training_status column")
except Exception as e:
print(f"training_status: {e}")
# Add active_training_task_id column
try:
cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL")
print("Added active_training_task_id column")
except Exception as e:
print(f"active_training_task_id: {e}")
# Create indexes
try:
cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)")
print("Created training_status index")
except Exception as e:
print(f"index training_status: {e}")
try:
cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)")
print("Created active_training_task_id index")
except Exception as e:
print(f"index active_training_task_id: {e}")
# Update existing datasets that have been used in completed training tasks to trained status
try:
cur.execute("""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
)
""")
print(f"Updated {cur.rowcount} datasets to trained status")
except Exception as e:
print(f"update status: {e}")
cur.close()
conn.close()
print("Migration complete!")

View File

@@ -0,0 +1 @@
"""Integration tests for invoice-master-poc-v2."""

View File

@@ -0,0 +1 @@
"""API integration tests."""

View File

@@ -0,0 +1,389 @@
"""
API Integration Tests
Tests FastAPI endpoints with mocked services.
These tests verify the API layer works correctly with the service layer.
"""
import io
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
@dataclass
class MockServiceResult:
"""Mock result from inference service."""
document_id: str = "test-doc-123"
success: bool = True
document_type: str = "invoice"
fields: dict[str, str] = field(default_factory=lambda: {
"InvoiceNumber": "INV-2024-001",
"Amount": "1500.00",
"InvoiceDate": "2024-01-15",
"OCR": "12345678901234",
"Bankgiro": "1234-5678",
})
confidence: dict[str, float] = field(default_factory=lambda: {
"InvoiceNumber": 0.95,
"Amount": 0.92,
"InvoiceDate": 0.88,
"OCR": 0.95,
"Bankgiro": 0.90,
})
detections: list[dict[str, Any]] = field(default_factory=list)
processing_time_ms: float = 150.5
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
@pytest.fixture
def temp_storage_dir():
"""Create temporary storage directories."""
with tempfile.TemporaryDirectory() as tmpdir:
base = Path(tmpdir)
uploads_dir = base / "uploads" / "inference"
results_dir = base / "results"
uploads_dir.mkdir(parents=True, exist_ok=True)
results_dir.mkdir(parents=True, exist_ok=True)
yield {
"base": base,
"uploads": uploads_dir,
"results": results_dir,
}
@pytest.fixture
def mock_inference_service():
"""Create a mock inference service."""
service = MagicMock()
service.is_initialized = True
service.gpu_available = False
# Create a realistic mock result
mock_result = MockServiceResult()
service.process_pdf.return_value = mock_result
service.process_image.return_value = mock_result
service.initialize.return_value = None
return service
@pytest.fixture
def mock_storage_config(temp_storage_dir):
"""Create mock storage configuration."""
from inference.web.config import StorageConfig
return StorageConfig(
upload_dir=temp_storage_dir["uploads"],
result_dir=temp_storage_dir["results"],
max_file_size_mb=50,
)
@pytest.fixture
def mock_storage_helper(temp_storage_dir):
"""Create a mock storage helper."""
helper = MagicMock()
helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"]
helper.get_result_local_path.return_value = None
helper.result_exists.return_value = False
return helper
@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
"""Create a test FastAPI application with mocked storage."""
from inference.web.api.v1.public.inference import create_inference_router
app = FastAPI()
# Patch get_storage_helper to return our mock
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
inference_router = create_inference_router(mock_inference_service, mock_storage_config)
app.include_router(inference_router)
return app
@pytest.fixture
def client(test_app, mock_storage_helper):
"""Create a test client with storage helper patched."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
yield TestClient(test_app)
class TestHealthEndpoint:
"""Tests for health check endpoint."""
def test_health_check(self, client, mock_inference_service):
"""Test health check returns status."""
response = client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "model_loaded" in data
class TestInferenceEndpoint:
"""Tests for inference endpoint."""
def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir):
"""Test PDF inference endpoint."""
# Create minimal PDF content
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
assert data["result"]["success"] is True
assert "InvoiceNumber" in data["result"]["fields"]
def test_infer_image(self, client, mock_inference_service, mock_storage_helper):
"""Test image inference endpoint."""
# Create minimal PNG header
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.png", io.BytesIO(png_header), "image/png")},
)
assert response.status_code == 200
data = response.json()
assert "result" in data
def test_infer_invalid_file_type(self, client, mock_storage_helper):
"""Test rejection of invalid file types."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
)
assert response.status_code == 400
def test_infer_no_file(self, client, mock_storage_helper):
"""Test rejection when no file provided."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post("/api/v1/infer")
assert response.status_code == 422 # Validation error
def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper):
"""Test that result has expected structure."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
result = data["result"]
# Check required fields
assert "document_id" in result
assert "success" in result
assert "fields" in result
assert "confidence" in result
assert "processing_time_ms" in result
class TestInferenceResultFormat:
"""Tests for inference result formatting."""
def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper):
"""Test that fields are mapped to API response format."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
fields = data["result"]["fields"]
assert fields["InvoiceNumber"] == "INV-2024-001"
assert fields["Amount"] == "1500.00"
assert fields["InvoiceDate"] == "2024-01-15"
def test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper):
"""Test that confidence values are included."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
confidence = data["result"]["confidence"]
assert "InvoiceNumber" in confidence
assert confidence["InvoiceNumber"] == 0.95
class TestErrorHandling:
"""Tests for error handling in API."""
def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper):
"""Test handling of service errors."""
mock_inference_service.process_pdf.side_effect = Exception("Processing failed")
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
# Should return error response
assert response.status_code >= 400
def test_empty_file_handling(self, client, mock_storage_helper):
"""Test handling of empty files."""
# Empty file still has valid content type
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")},
)
# Empty file may be processed or rejected depending on implementation
# Just verify we get a response
assert response.status_code in [200, 400, 422, 500]
class TestResponseFormat:
"""Tests for API response format consistency."""
def test_success_response_format(self, client, mock_inference_service, mock_storage_helper):
"""Test successful response format."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
data = response.json()
assert isinstance(data, dict)
assert "result" in data
def test_json_serialization(self, client, mock_inference_service, mock_storage_helper):
"""Test that all result fields are JSON serializable."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
# If this doesn't raise, JSON is valid
data = response.json()
assert data is not None
class TestDocumentIdGeneration:
"""Tests for document ID handling."""
def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper):
"""Test that document ID is generated."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
assert "document_id" in data["result"]
assert data["result"]["document_id"] is not None
def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper):
"""Test document ID derived from filename."""
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
"/api/v1/infer",
files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")},
)
data = response.json()
# Document ID should be set (either from filename or generated)
assert data["result"]["document_id"] is not None

View File

@@ -0,0 +1,400 @@
"""
Dashboard API Integration Tests
Tests Dashboard API endpoints with real database operations via TestClient.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep
def create_test_app(override_token_dep):
"""Create a FastAPI test application with dashboard router."""
app = FastAPI()
router = create_dashboard_router()
app.include_router(router)
# Override auth dependency
app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep
return app
class TestDashboardStatsEndpoint:
"""Tests for GET /admin/dashboard/stats endpoint."""
def test_stats_empty_database(self, patched_session, admin_token):
"""Test stats endpoint with empty database."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["total_documents"] == 0
assert data["annotation_complete"] == 0
assert data["annotation_incomplete"] == 0
assert data["pending"] == 0
assert data["completeness_rate"] == 0.0
def test_stats_with_pending_documents(self, patched_session, admin_token):
"""Test stats with pending documents."""
session = patched_session
# Create pending documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"pending_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/pending_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["total_documents"] == 3
assert data["pending"] == 3
def test_stats_with_complete_annotations(self, patched_session, admin_token):
"""Test stats with complete annotations."""
session = patched_session
# Create labeled document with complete annotations
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="complete.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/complete.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add identifier and payment annotations
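# (assumption: the stats service counts a document as complete once it has both
# an identifier field such as invoice_number and a payment field such as bankgiro)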
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0, # invoice_number
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=4, # bankgiro
class_name="bankgiro",
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/stats")
assert response.status_code == 200
data = response.json()
assert data["annotation_complete"] == 1
assert data["completeness_rate"] == 100.0
class TestActiveModelEndpoint:
"""Tests for GET /admin/dashboard/active-model endpoint."""
def test_active_model_none(self, patched_session, admin_token):
"""Test active-model endpoint with no active model."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["model"] is None
assert data["running_training"] is None
def test_active_model_with_model(self, patched_session, admin_token, sample_dataset):
"""Test active-model endpoint with active model."""
session = patched_session
# Create training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Test Task",
status="completed",
task_type="train",
dataset_id=sample_dataset.dataset_id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
# Create active model
model = ModelVersion(
version_id=uuid4(),
version="1.0.0",
name="Test Model",
model_path="/models/test.pt",
status="active",
is_active=True,
task_id=task.task_id,
dataset_id=sample_dataset.dataset_id,
metrics_mAP=0.90,
metrics_precision=0.88,
metrics_recall=0.85,
document_count=100,
file_size=50000000,
activated_at=datetime.now(timezone.utc),
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(model)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["model"] is not None
assert data["model"]["version"] == "1.0.0"
assert data["model"]["name"] == "Test Model"
assert data["model"]["metrics_mAP"] == 0.90
def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset):
"""Test active-model endpoint with running training."""
session = patched_session
# Create running training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Running Task",
status="running",
task_type="train",
dataset_id=sample_dataset.dataset_id,
started_at=datetime.now(timezone.utc),
progress=50,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/active-model")
assert response.status_code == 200
data = response.json()
assert data["running_training"] is not None
assert data["running_training"]["name"] == "Running Task"
assert data["running_training"]["status"] == "running"
assert data["running_training"]["progress"] == 50
class TestRecentActivityEndpoint:
"""Tests for GET /admin/dashboard/activity endpoint."""
def test_activity_empty(self, patched_session, admin_token):
"""Test activity endpoint with no activities."""
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
assert data["activities"] == []
def test_activity_with_uploads(self, patched_session, admin_token):
"""Test activity includes document uploads."""
session = patched_session
# Create documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"activity_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/activity_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"]
assert len(upload_activities) == 3
def test_activity_limit_parameter(self, patched_session, admin_token):
"""Test activity limit parameter."""
session = patched_session
# Create many documents
for i in range(15):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"limit_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/limit_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity?limit=5")
assert response.status_code == 200
data = response.json()
assert len(data["activities"]) <= 5
def test_activity_invalid_limit(self, patched_session, admin_token):
"""Test activity with invalid limit parameter."""
app = create_test_app(admin_token.token)
client = TestClient(app)
# Limit too high
response = client.get("/admin/dashboard/activity?limit=100")
assert response.status_code == 422
# Limit too low
response = client.get("/admin/dashboard/activity?limit=0")
assert response.status_code == 422
def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset):
"""Test activity includes training completions."""
session = patched_session
# Create completed training task
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Completed Task",
status="completed",
task_type="train",
dataset_id=sample_dataset.dataset_id,
metrics_mAP=0.95,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(task)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
training_activities = [a for a in data["activities"] if a["type"] == "training_completed"]
assert len(training_activities) >= 1
def test_activity_sorted_by_timestamp(self, patched_session, admin_token):
"""Test activities are sorted by timestamp descending."""
session = patched_session
# Create documents
for i in range(5):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"sorted_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/sorted_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
app = create_test_app(admin_token.token)
client = TestClient(app)
response = client.get("/admin/dashboard/activity")
assert response.status_code == 200
data = response.json()
timestamps = [a["timestamp"] for a in data["activities"]]
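# ISO-8601 timestamps with a common UTC offset compare lexicographically in
# time order, so string sorting is a valid proxy for chronological sorting here.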
assert timestamps == sorted(timestamps, reverse=True)

View File

@@ -0,0 +1,465 @@
"""
Integration Test Fixtures
Provides shared fixtures for integration tests using PostgreSQL.
IMPORTANT: Integration tests MUST run against an isolated database so they
never touch the real production/development database.
Supported modes:
1. Docker testcontainers (default): Automatically starts a PostgreSQL container
2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!)
To use an external test database, set:
TEST_DB_URL=postgresql://user:password@host:port/test_dbname
"""
import os
import tempfile
from contextlib import contextmanager, ExitStack
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
AnnotationHistory,
BatchUpload,
BatchUploadFile,
DatasetDocument,
ModelVersion,
TrainingDataset,
TrainingDocumentLink,
TrainingLog,
TrainingTask,
)
# =============================================================================
# Database Fixtures
# =============================================================================
def _is_docker_available() -> bool:
"""Check if Docker is available."""
try:
import docker
client = docker.from_env()
client.ping()
return True
except Exception:
return False
def _get_test_db_url() -> str | None:
"""Get test database URL from environment."""
return os.environ.get("TEST_DB_URL")
@pytest.fixture(scope="session")
def test_engine():
"""Create a SQLAlchemy engine for testing.
Uses one of:
1. TEST_DB_URL environment variable (dedicated test database)
2. Docker testcontainers (if Docker is available)
IMPORTANT: Will NOT fall back to production database. If Docker is not
available and TEST_DB_URL is not set, tests will fail with a clear error.
The engine is shared across all tests in a session for efficiency.
"""
# Try to get URL from environment first
connection_url = _get_test_db_url()
if connection_url:
# Use external test database from environment
# Warn if it looks like a production database
if "docmaster" in connection_url and "_test" not in connection_url:
import warnings
warnings.warn(
"TEST_DB_URL appears to point to a production database. "
"Please use a dedicated test database (e.g., docmaster_test).",
UserWarning,
)
elif _is_docker_available():
# Use testcontainers - this is the recommended approach
from testcontainers.postgres import PostgresContainer
postgres = PostgresContainer("postgres:15-alpine")
postgres.start()
connection_url = postgres.get_connection_url()
if "psycopg2" in connection_url:
connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
# Store container for cleanup
test_engine._postgres_container = postgres
else:
# No Docker and no TEST_DB_URL - fail with clear instructions
pytest.fail(
"Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
"Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
" - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
" - The testcontainers library will automatically create a PostgreSQL container.\n\n"
"Option 2: Set TEST_DB_URL to a dedicated test database:\n"
" - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
" - NEVER use your production database for tests!\n\n"
"Integration tests will NOT fall back to the production database."
)
engine = create_engine(
connection_url,
echo=False,
pool_pre_ping=True,
)
# Create all tables
SQLModel.metadata.create_all(engine)
yield engine
# Cleanup
SQLModel.metadata.drop_all(engine)
engine.dispose()
# Stop container if we started one
if hasattr(test_engine, "_postgres_container"):
test_engine._postgres_container.stop()
@pytest.fixture(scope="function")
def db_session(test_engine) -> Generator[Session, None, None]:
"""Provide a database session for each test function.
Each test gets a fresh session that rolls back after the test,
ensuring test isolation.
"""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
# Rollback and cleanup
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(scope="function")
def patched_session(db_session):
"""Patch get_session_context to use the test session.
This allows repository classes to use the test database session
instead of creating their own connections.
We need to patch in multiple locations because each repository module
imports get_session_context directly.
"""
@contextmanager
def mock_session_context() -> Generator[Session, None, None]:
yield db_session
# All modules that import get_session_context
patch_targets = [
"inference.data.database.get_session_context",
"inference.data.repositories.document_repository.get_session_context",
"inference.data.repositories.annotation_repository.get_session_context",
"inference.data.repositories.dataset_repository.get_session_context",
"inference.data.repositories.training_task_repository.get_session_context",
"inference.data.repositories.model_version_repository.get_session_context",
"inference.data.repositories.batch_upload_repository.get_session_context",
"inference.data.repositories.token_repository.get_session_context",
"inference.web.services.dashboard_service.get_session_context",
]
with ExitStack() as stack:
for target in patch_targets:
try:
stack.enter_context(patch(target, mock_session_context))
except (ModuleNotFoundError, AttributeError):
# Skip if module doesn't exist or doesn't have the attribute
pass
yield db_session
# =============================================================================
# Test Data Fixtures
# =============================================================================
@pytest.fixture
def admin_token(db_session) -> AdminToken:
"""Create a test admin token."""
token = AdminToken(
token="test-admin-token-12345",
name="Test Admin",
is_active=True,
created_at=datetime.now(timezone.utc),
)
db_session.add(token)
db_session.commit()
db_session.refresh(token)
return token
@pytest.fixture
def sample_document(db_session, admin_token) -> AdminDocument:
"""Create a sample document for testing."""
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="test_invoice.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/test_invoice.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(doc)
db_session.commit()
db_session.refresh(doc)
return doc
@pytest.fixture
def sample_annotation(db_session, sample_document) -> AdminAnnotation:
"""Create a sample annotation for testing."""
annotation = AdminAnnotation(
annotation_id=uuid4(),
document_id=sample_document.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.3,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=240,
bbox_width=160,
bbox_height=40,
text_value="INV-2024-001",
confidence=0.95,
source="auto",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(annotation)
db_session.commit()
db_session.refresh(annotation)
return annotation
@pytest.fixture
def sample_dataset(db_session) -> TrainingDataset:
"""Create a sample training dataset for testing."""
dataset = TrainingDataset(
dataset_id=uuid4(),
name="Test Dataset",
description="Dataset for integration testing",
status="building",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=0,
total_images=0,
total_annotations=0,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(dataset)
db_session.commit()
db_session.refresh(dataset)
return dataset
@pytest.fixture
def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask:
"""Create a sample training task for testing."""
task = TrainingTask(
task_id=uuid4(),
admin_token=admin_token.token,
name="Test Training Task",
description="Training task for integration testing",
status="pending",
task_type="train",
dataset_id=sample_dataset.dataset_id,
config={"epochs": 10, "batch_size": 16},
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
return task
@pytest.fixture
def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion:
"""Create a sample model version for testing."""
version = ModelVersion(
version_id=uuid4(),
version="1.0.0",
name="Test Model v1",
description="Model version for integration testing",
model_path="/models/test_model.pt",
status="inactive",
is_active=False,
task_id=sample_training_task.task_id,
dataset_id=sample_dataset.dataset_id,
metrics_mAP=0.85,
metrics_precision=0.88,
metrics_recall=0.82,
document_count=100,
file_size=50000000,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(version)
db_session.commit()
db_session.refresh(version)
return version
@pytest.fixture
def sample_batch_upload(db_session, admin_token) -> BatchUpload:
"""Create a sample batch upload for testing."""
batch = BatchUpload(
batch_id=uuid4(),
admin_token=admin_token.token,
filename="test_batch.zip",
file_size=10240,
upload_source="api",
status="processing",
total_files=5,
processed_files=0,
successful_files=0,
failed_files=0,
created_at=datetime.now(timezone.utc),
)
db_session.add(batch)
db_session.commit()
db_session.refresh(batch)
return batch
# =============================================================================
# Multiple Documents Fixture
# =============================================================================
@pytest.fixture
def multiple_documents(db_session, admin_token) -> list[AdminDocument]:
"""Create multiple documents for pagination/filtering tests."""
documents = []
statuses = ["pending", "pending", "labeled", "labeled", "exported"]
categories = ["invoice", "invoice", "invoice", "letter", "invoice"]
for i, (status, category) in enumerate(zip(statuses, categories)):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"test_doc_{i}.pdf",
file_size=1024 + i * 100,
content_type="application/pdf",
file_path=f"/uploads/test_doc_{i}.pdf",
page_count=1,
status=status,
upload_source="ui",
category=category,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(doc)
documents.append(doc)
db_session.commit()
for doc in documents:
db_session.refresh(doc)
return documents
# =============================================================================
# Temporary File Fixtures
# =============================================================================
@pytest.fixture
def temp_upload_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for file uploads."""
with tempfile.TemporaryDirectory() as tmpdir:
upload_dir = Path(tmpdir) / "uploads"
upload_dir.mkdir(parents=True, exist_ok=True)
yield upload_dir
@pytest.fixture
def temp_model_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for model files."""
with tempfile.TemporaryDirectory() as tmpdir:
model_dir = Path(tmpdir) / "models"
model_dir.mkdir(parents=True, exist_ok=True)
yield model_dir
@pytest.fixture
def temp_dataset_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for dataset files."""
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir) / "datasets"
dataset_dir.mkdir(parents=True, exist_ok=True)
yield dataset_dir
# =============================================================================
# Sample PDF Fixture
# =============================================================================
@pytest.fixture
def sample_pdf_bytes() -> bytes:
"""Return minimal valid PDF bytes for testing."""
# Minimal valid PDF structure
return b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >>
endobj
xref
0 4
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
trailer
<< /Size 4 /Root 1 0 R >>
startxref
196
%%EOF"""
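# Note: the xref byte offsets above are nominal rather than exact; these tests
# only need PDF-shaped bytes, not a strictly well-formed file (assumption).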
@pytest.fixture
def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path:
"""Create a sample PDF file for testing."""
pdf_path = temp_upload_dir / "test_invoice.pdf"
pdf_path.write_bytes(sample_pdf_bytes)
return pdf_path

View File

@@ -0,0 +1 @@
"""Pipeline integration tests."""

View File

@@ -0,0 +1,456 @@
"""
Inference Pipeline Integration Tests
Tests the complete pipeline from input to output.
Note: These tests use mocks for YOLO and OCR to avoid requiring actual models,
but test the integration of pipeline components.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import numpy as np
from inference.pipeline.pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
)
from inference.pipeline.yolo_detector import Detection
from inference.pipeline.field_extractor import ExtractedField
@pytest.fixture
def mock_detection():
"""Create a mock detection."""
return Detection(
class_id=0,
class_name="invoice_number",
confidence=0.95,
bbox=(100, 50, 200, 30),
page_no=0,
)
@pytest.fixture
def mock_extracted_field():
"""Create a mock extracted field."""
return ExtractedField(
field_name="InvoiceNumber",
raw_text="INV-2024-001",
normalized_value="INV-2024-001",
confidence=0.95,
bbox=(100, 50, 200, 30),
page_no=0,
is_valid=True,
)
class TestInferenceResultConstruction:
"""Tests for InferenceResult construction and methods."""
def test_default_result(self):
"""Test default InferenceResult values."""
result = InferenceResult()
assert result.document_id is None
assert result.success is False
assert result.fields == {}
assert result.confidence == {}
assert result.raw_detections == []
assert result.extracted_fields == []
assert result.errors == []
assert result.fallback_used is False
assert result.cross_validation is None
def test_result_to_json(self):
"""Test JSON serialization of result."""
result = InferenceResult(
document_id="test-doc",
success=True,
fields={
"InvoiceNumber": "INV-001",
"Amount": "1500.00",
},
confidence={
"InvoiceNumber": 0.95,
"Amount": 0.92,
},
bboxes={
"InvoiceNumber": (100, 50, 200, 30),
},
)
json_data = result.to_json()
assert json_data["DocumentId"] == "test-doc"
assert json_data["success"] is True
assert json_data["InvoiceNumber"] == "INV-001"
assert json_data["Amount"] == "1500.00"
assert json_data["confidence"]["InvoiceNumber"] == 0.95
assert "bboxes" in json_data
def test_result_get_field(self):
"""Test getting field value and confidence."""
result = InferenceResult(
fields={"InvoiceNumber": "INV-001"},
confidence={"InvoiceNumber": 0.95},
)
value, conf = result.get_field("InvoiceNumber")
assert value == "INV-001"
assert conf == 0.95
value, conf = result.get_field("Amount")
assert value is None
assert conf == 0.0
class TestCrossValidation:
"""Tests for cross-validation logic."""
def test_cross_validation_default(self):
"""Test default CrossValidationResult values."""
cv = CrossValidationResult()
assert cv.is_valid is False
assert cv.ocr_match is None
assert cv.amount_match is None
assert cv.bankgiro_match is None
assert cv.plusgiro_match is None
assert cv.payment_line_ocr is None
assert cv.payment_line_amount is None
assert cv.details == []
def test_cross_validation_with_matches(self):
"""Test CrossValidationResult with matches."""
cv = CrossValidationResult(
is_valid=True,
ocr_match=True,
amount_match=True,
bankgiro_match=True,
payment_line_ocr="12345678901234",
payment_line_amount="1500.00",
payment_line_account="1234-5678",
payment_line_account_type="bankgiro",
details=["OCR match", "Amount match", "Bankgiro match"],
)
assert cv.is_valid is True
assert cv.ocr_match is True
assert cv.amount_match is True
assert len(cv.details) == 3
class TestPipelineMergeFields:
"""Tests for field merging logic."""
def test_merge_selects_highest_confidence(self):
"""Test that merge selects highest confidence for duplicate fields."""
# Create mock pipeline with minimal mocking
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
pipeline.payment_line_parser = MagicMock()
pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name="InvoiceNumber",
raw_text="INV-001",
normalized_value="INV-001",
confidence=0.85,
detection_confidence=0.90,
ocr_confidence=0.85,
bbox=(100, 50, 200, 30),
page_no=0,
is_valid=True,
),
ExtractedField(
field_name="InvoiceNumber",
raw_text="INV-001",
normalized_value="INV-001",
confidence=0.95, # Higher confidence
detection_confidence=0.95,
ocr_confidence=0.95,
bbox=(105, 52, 198, 28),
page_no=0,
is_valid=True,
),
]
pipeline._merge_fields(result)
assert result.fields["InvoiceNumber"] == "INV-001"
assert result.confidence["InvoiceNumber"] == 0.95
def test_merge_skips_invalid_fields(self):
"""Test that merge skips invalid extracted fields."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
pipeline.payment_line_parser = MagicMock()
pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False)
result = InferenceResult()
result.extracted_fields = [
ExtractedField(
field_name="InvoiceNumber",
raw_text="",
normalized_value=None,
confidence=0.95,
detection_confidence=0.95,
ocr_confidence=0.95,
bbox=(100, 50, 200, 30),
page_no=0,
is_valid=False, # Invalid
),
ExtractedField(
field_name="Amount",
raw_text="1500.00",
normalized_value="1500.00",
confidence=0.92,
detection_confidence=0.92,
ocr_confidence=0.92,
bbox=(200, 100, 100, 25),
page_no=0,
is_valid=True,
),
]
pipeline._merge_fields(result)
assert "InvoiceNumber" not in result.fields
assert result.fields["Amount"] == "1500.00"
class TestPaymentLineValidation:
"""Tests for payment line cross-validation."""
def test_payment_line_overrides_ocr(self):
"""Test that payment line OCR overrides detected OCR."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# Mock payment line parser
mock_parsed = MagicMock()
mock_parsed.is_valid = True
mock_parsed.ocr_number = "12345678901234"
mock_parsed.amount = "1500.00"
mock_parsed.account_number = "12345678"
pipeline.payment_line_parser = MagicMock()
pipeline.payment_line_parser.parse.return_value = mock_parsed
result = InferenceResult(
fields={
"payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
"OCR": "99999999999999", # Different OCR
},
confidence={"OCR": 0.85},
)
pipeline._cross_validate_payment_line(result)
# Payment line OCR should override
assert result.fields["OCR"] == "12345678901234"
assert result.confidence["OCR"] == 0.95
def test_payment_line_overrides_amount(self):
"""Test that payment line amount overrides detected amount."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
mock_parsed = MagicMock()
mock_parsed.is_valid = True
mock_parsed.ocr_number = None
mock_parsed.amount = "2500.50"
mock_parsed.account_number = None
pipeline.payment_line_parser = MagicMock()
pipeline.payment_line_parser.parse.return_value = mock_parsed
result = InferenceResult(
fields={
"payment_line": "# ... # 2500 50 5 > ...",
"Amount": "2500.00", # Slightly different
},
confidence={"Amount": 0.80},
)
pipeline._cross_validate_payment_line(result)
assert result.fields["Amount"] == "2500.50"
assert result.confidence["Amount"] == 0.95
def test_cross_validation_records_matches(self):
"""Test that cross-validation records match status."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
mock_parsed = MagicMock()
mock_parsed.is_valid = True
mock_parsed.ocr_number = "12345678901234"
mock_parsed.amount = "1500.00"
mock_parsed.account_number = "12345678"
pipeline.payment_line_parser = MagicMock()
pipeline.payment_line_parser.parse.return_value = mock_parsed
result = InferenceResult(
fields={
"payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#",
"OCR": "12345678901234", # Matching
"Amount": "1500.00", # Matching
"Bankgiro": "1234-5678", # Matching
},
confidence={},
)
pipeline._cross_validate_payment_line(result)
assert result.cross_validation is not None
assert result.cross_validation.ocr_match is True
assert result.cross_validation.amount_match is True
assert result.cross_validation.is_valid is True
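# Payment line layout on Swedish giro slips (illustrative; the exact grammar is
# an assumption and belongs to PaymentLineParser):
#   # <OCR reference> # <kronor> <öre> <check digit> > <account>#<record type>#
# e.g. "# 12345678901234 # 1500 00 5 > 12345678#41#" carries OCR 12345678901234,
# amount 1500.00 and account 12345678.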
class TestFallbackLogic:
"""Tests for fallback detection logic."""
def test_needs_fallback_when_key_fields_missing(self):
"""Test fallback is triggered when key fields missing."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# Only one key field present
result = InferenceResult(fields={"Amount": "1500.00"})
assert pipeline._needs_fallback(result) is True
def test_no_fallback_when_fields_present(self):
"""Test no fallback when key fields present."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# All key fields present
result = InferenceResult(
fields={
"Amount": "1500.00",
"InvoiceNumber": "INV-001",
"OCR": "12345678901234",
}
)
assert pipeline._needs_fallback(result) is False
class TestPatternExtraction:
"""Tests for fallback pattern extraction."""
def test_extract_amount_pattern(self):
"""Test amount extraction with regex."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
text = "Att betala: 1 500,00 SEK"
result = InferenceResult()
pipeline._extract_with_patterns(text, result)
assert "Amount" in result.fields
assert result.confidence["Amount"] == 0.5
def test_extract_bankgiro_pattern(self):
"""Test bankgiro extraction with regex."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
text = "Bankgiro: 1234-5678"
result = InferenceResult()
pipeline._extract_with_patterns(text, result)
assert "Bankgiro" in result.fields
assert result.fields["Bankgiro"] == "1234-5678"
def test_extract_ocr_pattern(self):
"""Test OCR extraction with regex."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
text = "OCR: 12345678901234567890"
result = InferenceResult()
pipeline._extract_with_patterns(text, result)
assert "OCR" in result.fields
assert result.fields["OCR"] == "12345678901234567890"
def test_does_not_override_existing_fields(self):
"""Test pattern extraction doesn't override existing fields."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
text = "Fakturanr: 999"
result = InferenceResult(fields={"InvoiceNumber": "INV-001"})
pipeline._extract_with_patterns(text, result)
# Should keep existing value
assert result.fields["InvoiceNumber"] == "INV-001"
class TestAmountNormalization:
"""Tests for amount normalization."""
def test_normalize_swedish_format(self):
"""Test normalizing Swedish amount format."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
# Swedish format: space as thousands separator, comma as decimal
assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00
# Standard format: dot as decimal
assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00
# Swedish format with comma as decimal
assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00
def test_normalize_invalid_amount(self):
"""Test normalizing invalid amount returns None."""
with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None):
pipeline = InferencePipeline.__new__(InferencePipeline)
assert pipeline._normalize_amount_for_compare("invalid") is None
assert pipeline._normalize_amount_for_compare("") is None
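# Minimal sketch of the comparator these tests exercise (assumption: the real
# _normalize_amount_for_compare may handle more formats):
#
# def _normalize_amount_for_compare(text: str) -> float | None:
#     cleaned = text.replace(" ", "").replace(",", ".")
#     try:
#         return float(cleaned)
#     except ValueError:
#         return None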
class TestResultSerialization:
"""Tests for result serialization with cross-validation."""
def test_to_json_with_cross_validation(self):
"""Test JSON serialization includes cross-validation."""
cv = CrossValidationResult(
is_valid=True,
ocr_match=True,
amount_match=True,
payment_line_ocr="12345678901234",
payment_line_amount="1500.00",
details=["OCR match", "Amount match"],
)
result = InferenceResult(
document_id="test-doc",
success=True,
fields={"InvoiceNumber": "INV-001"},
cross_validation=cv,
)
json_data = result.to_json()
assert "cross_validation" in json_data
assert json_data["cross_validation"]["is_valid"] is True
assert json_data["cross_validation"]["ocr_match"] is True
assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234"

View File

@@ -0,0 +1 @@
"""Repository integration tests."""

View File

@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests
Tests AnnotationRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepositoryCreate:
"""Tests for annotation creation."""
def test_create_annotation(self, patched_session, sample_document):
"""Test creating a single annotation."""
repo = AnnotationRepository()
ann_id = repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.3,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=240,
bbox_width=160,
bbox_height=40,
text_value="INV-2024-001",
confidence=0.95,
source="auto",
)
assert ann_id is not None
ann = repo.get(ann_id)
assert ann is not None
assert ann.class_name == "invoice_number"
assert ann.text_value == "INV-2024-001"
assert ann.confidence == 0.95
assert ann.source == "auto"
def test_create_batch_annotations(self, patched_session, sample_document):
"""Test batch creation of annotations."""
repo = AnnotationRepository()
annotations_data = [
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 0,
"class_name": "invoice_number",
"x_center": 0.5,
"y_center": 0.1,
"width": 0.2,
"height": 0.05,
"bbox_x": 400,
"bbox_y": 80,
"bbox_width": 160,
"bbox_height": 40,
"text_value": "INV-001",
"confidence": 0.95,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 1,
"class_name": "invoice_date",
"x_center": 0.5,
"y_center": 0.2,
"width": 0.15,
"height": 0.04,
"bbox_x": 400,
"bbox_y": 160,
"bbox_width": 120,
"bbox_height": 32,
"text_value": "2024-01-15",
"confidence": 0.92,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 6,
"class_name": "amount",
"x_center": 0.7,
"y_center": 0.8,
"width": 0.1,
"height": 0.04,
"bbox_x": 560,
"bbox_y": 640,
"bbox_width": 80,
"bbox_height": 32,
"text_value": "1500.00",
"confidence": 0.98,
},
]
ids = repo.create_batch(annotations_data)
assert len(ids) == 3
# Verify all annotations exist
for ann_id in ids:
ann = repo.get(ann_id)
assert ann is not None
class TestAnnotationRepositoryRead:
"""Tests for annotation retrieval."""
def test_get_nonexistent_annotation(self, patched_session):
"""Test getting an annotation that doesn't exist."""
repo = AnnotationRepository()
ann = repo.get(str(uuid4()))
assert ann is None
def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotations for a document."""
repo = AnnotationRepository()
# Add another annotation
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.4,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=320,
bbox_width=120,
bbox_height=32,
text_value="2024-01-15",
)
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 2
# Should be ordered by class_id
assert annotations[0].class_id == 0
assert annotations[1].class_id == 1
def test_get_annotations_for_specific_page(self, patched_session, sample_document):
"""Test getting annotations for a specific page."""
repo = AnnotationRepository()
# Create annotations on different pages
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
)
repo.create(
document_id=str(sample_document.document_id),
page_number=2,
class_id=6,
class_name="amount",
x_center=0.7,
y_center=0.8,
width=0.1,
height=0.04,
bbox_x=560,
bbox_y=640,
bbox_width=80,
bbox_height=32,
)
page1_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=1,
)
page2_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=2,
)
assert len(page1_annotations) == 1
assert len(page2_annotations) == 1
assert page1_annotations[0].page_number == 1
assert page2_annotations[0].page_number == 2
class TestAnnotationRepositoryUpdate:
"""Tests for annotation updates."""
def test_update_annotation_bbox(self, patched_session, sample_annotation):
"""Test updating annotation bounding box."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
x_center=0.6,
y_center=0.4,
width=0.25,
height=0.06,
bbox_x=480,
bbox_y=320,
bbox_width=200,
bbox_height=48,
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.x_center == 0.6
assert ann.y_center == 0.4
assert ann.bbox_x == 480
assert ann.bbox_width == 200
def test_update_annotation_text(self, patched_session, sample_annotation):
"""Test updating annotation text value."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
text_value="INV-2024-002",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.text_value == "INV-2024-002"
def test_update_annotation_class(self, patched_session, sample_annotation):
"""Test updating annotation class."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
class_id=1,
class_name="invoice_date",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.class_id == 1
assert ann.class_name == "invoice_date"
def test_update_nonexistent_annotation(self, patched_session):
"""Test updating annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.update(
str(uuid4()),
text_value="new value",
)
assert result is False
class TestAnnotationRepositoryDelete:
"""Tests for annotation deletion."""
def test_delete_annotation(self, patched_session, sample_annotation):
"""Test deleting a single annotation."""
repo = AnnotationRepository()
result = repo.delete(str(sample_annotation.annotation_id))
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is None
def test_delete_nonexistent_annotation(self, patched_session):
"""Test deleting annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_annotations_for_document(self, patched_session, sample_document):
"""Test deleting all annotations for a document."""
repo = AnnotationRepository()
# Create multiple annotations
for i in range(3):
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=i,
class_name=f"field_{i}",
x_center=0.5,
y_center=0.1 + i * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + i * 160,
bbox_width=160,
bbox_height=40,
)
# Delete all
count = repo.delete_for_document(str(sample_document.document_id))
assert count == 3
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_annotations_by_source(self, patched_session, sample_document):
"""Test deleting annotations by source type."""
repo = AnnotationRepository()
# Create auto and manual annotations
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
source="auto",
)
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.2,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=160,
bbox_width=120,
bbox_height=32,
source="manual",
)
# Delete only auto annotations
count = repo.delete_for_document(str(sample_document.document_id), source="auto")
assert count == 1
remaining = repo.get_for_document(str(sample_document.document_id))
assert len(remaining) == 1
assert remaining[0].source == "manual"
class TestAnnotationVerification:
"""Tests for annotation verification."""
def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
"""Test marking annotation as verified."""
repo = AnnotationRepository()
ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)
assert ann is not None
assert ann.is_verified is True
assert ann.verified_by == admin_token.token
assert ann.verified_at is not None
class TestAnnotationOverride:
"""Tests for annotation override functionality."""
def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
"""Test overriding an auto-generated annotation."""
repo = AnnotationRepository()
# Override the annotation
ann = repo.override(
str(sample_annotation.annotation_id),
admin_token.token,
change_reason="Correcting OCR error",
text_value="INV-2024-CORRECTED",
x_center=0.55,
)
assert ann is not None
assert ann.text_value == "INV-2024-CORRECTED"
assert ann.x_center == 0.55
assert ann.source == "manual" # Changed from auto to manual
assert ann.override_source == "auto"
class TestAnnotationHistory:
"""Tests for annotation history tracking."""
def test_create_history_record(self, patched_session, sample_annotation):
"""Test creating annotation history record."""
repo = AnnotationRepository()
history = repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
changed_by="test-user",
)
assert history is not None
assert history.action == "created"
assert history.changed_by == "test-user"
def test_get_annotation_history(self, patched_session, sample_annotation):
"""Test getting history for an annotation."""
repo = AnnotationRepository()
# Create history records
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
)
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="updated",
previous_value={"text_value": "INV-001"},
new_value={"text_value": "INV-002"},
)
history = repo.get_history(sample_annotation.annotation_id)
assert len(history) == 2
# Should be ordered by created_at desc
assert history[0].action == "updated"
assert history[1].action == "created"
def test_get_document_history(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotation history for a document."""
repo = AnnotationRepository()
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="created",
new_value={"class_name": "invoice_number"},
)
history = repo.get_document_history(sample_document.document_id)
assert len(history) >= 1
assert all(h.document_id == sample_document.document_id for h in history)
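
The override and source-filtered delete behaviour asserted above can be summarised in a short sketch. This is a minimal, hypothetical implementation of the source filter (field names are taken from the assertions; the real `AnnotationRepository` internals are not part of this diff):

```python
from sqlmodel import Session, select

def delete_for_document(session: Session, model, document_id: str,
                        source: str | None = None) -> int:
    """Delete a document's annotations, optionally only those from one source."""
    query = select(model).where(model.document_id == document_id)
    if source is not None:
        query = query.where(model.source == source)   # e.g. "auto" or "manual"
    rows = session.exec(query).all()
    for row in rows:
        session.delete(row)
    session.commit()
    return len(rows)  # the deleted-row count the tests assert on
```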


@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests
Tests BatchUploadRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadCreate:
"""Tests for batch upload creation."""
def test_create_batch_upload(self, patched_session, admin_token):
"""Test creating a batch upload."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="test_batch.zip",
file_size=10240,
upload_source="api",
)
assert batch is not None
assert batch.batch_id is not None
assert batch.filename == "test_batch.zip"
assert batch.file_size == 10240
assert batch.upload_source == "api"
assert batch.status == "processing"
assert batch.total_files == 0
assert batch.processed_files == 0
def test_create_batch_upload_default_source(self, patched_session, admin_token):
"""Test creating batch upload with default source."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="ui_batch.zip",
file_size=5120,
)
assert batch.upload_source == "ui"
class TestBatchUploadRead:
"""Tests for batch upload retrieval."""
def test_get_batch_upload(self, patched_session, sample_batch_upload):
"""Test getting a batch upload by ID."""
repo = BatchUploadRepository()
batch = repo.get(sample_batch_upload.batch_id)
assert batch is not None
assert batch.batch_id == sample_batch_upload.batch_id
assert batch.filename == sample_batch_upload.filename
def test_get_nonexistent_batch_upload(self, patched_session):
"""Test getting a batch upload that doesn't exist."""
repo = BatchUploadRepository()
batch = repo.get(uuid4())
assert batch is None
def test_get_paginated_batch_uploads(self, patched_session, admin_token):
"""Test paginated batch upload listing."""
repo = BatchUploadRepository()
# Create multiple batches
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024 * (i + 1),
)
batches, total = repo.get_paginated(limit=3, offset=0)
assert total == 5
assert len(batches) == 3
def test_get_paginated_with_offset(self, patched_session, admin_token):
"""Test pagination offset."""
repo = BatchUploadRepository()
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024,
)
page1, _ = repo.get_paginated(limit=2, offset=0)
page2, _ = repo.get_paginated(limit=2, offset=2)
ids_page1 = {b.batch_id for b in page1}
ids_page2 = {b.batch_id for b in page2}
assert len(ids_page1 & ids_page2) == 0
class TestBatchUploadUpdate:
"""Tests for batch upload updates."""
def test_update_batch_status(self, patched_session, sample_batch_upload):
"""Test updating batch upload status."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="completed",
total_files=10,
processed_files=10,
successful_files=8,
failed_files=2,
)
# Need to commit to see changes
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "completed"
assert batch.total_files == 10
assert batch.successful_files == 8
assert batch.failed_files == 2
def test_update_batch_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch upload with error message."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="failed",
error_message="ZIP extraction failed",
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "failed"
assert batch.error_message == "ZIP extraction failed"
def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
"""Test updating batch with CSV information."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
csv_filename="manifest.csv",
csv_row_count=100,
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.csv_filename == "manifest.csv"
assert batch.csv_row_count == 100
class TestBatchUploadFiles:
"""Tests for batch upload file management."""
def test_create_batch_file(self, patched_session, sample_batch_upload):
"""Test creating a batch upload file record."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_001.pdf",
status="pending",
)
assert file_record is not None
assert file_record.file_id is not None
assert file_record.filename == "invoice_001.pdf"
assert file_record.batch_id == sample_batch_upload.batch_id
assert file_record.status == "pending"
def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
"""Test creating batch file linked to a document."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_linked.pdf",
document_id=sample_document.document_id,
status="completed",
annotation_count=5,
)
assert file_record.document_id == sample_document.document_id
assert file_record.status == "completed"
assert file_record.annotation_count == 5
def test_get_batch_files(self, patched_session, sample_batch_upload):
"""Test getting all files for a batch."""
repo = BatchUploadRepository()
# Create multiple files
for i in range(3):
repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename=f"file_{i}.pdf",
)
files = repo.get_files(sample_batch_upload.batch_id)
assert len(files) == 3
assert all(f.batch_id == sample_batch_upload.batch_id for f in files)
def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
"""Test getting files for batch with no files."""
repo = BatchUploadRepository()
files = repo.get_files(sample_batch_upload.batch_id)
assert files == []
def test_update_batch_file_status(self, patched_session, sample_batch_upload):
"""Test updating batch file status."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="test.pdf",
)
repo.update_file(
file_record.file_id,
status="completed",
annotation_count=10,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "completed"
assert updated_file.annotation_count == 10
def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch file with error."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="corrupt.pdf",
)
repo.update_file(
file_record.file_id,
status="failed",
error_message="Invalid PDF format",
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "failed"
assert updated_file.error_message == "Invalid PDF format"
def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
"""Test updating batch file with CSV row data."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_with_csv.pdf",
)
csv_data = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier": "Test Corp",
}
repo.update_file(
file_record.file_id,
csv_row_data=csv_data,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.csv_row_data == csv_data
class TestBatchUploadWorkflow:
"""Tests for complete batch upload workflows."""
def test_complete_batch_workflow(self, patched_session, admin_token):
"""Test complete batch upload workflow."""
repo = BatchUploadRepository()
# 1. Create batch
batch = repo.create(
admin_token=admin_token.token,
filename="full_workflow.zip",
file_size=50000,
)
# 2. Update with file count
repo.update(batch.batch_id, total_files=3)
patched_session.commit()
# 3. Create file records
file_ids = []
for i in range(3):
file_record = repo.create_file(
batch_id=batch.batch_id,
filename=f"doc_{i}.pdf",
)
file_ids.append(file_record.file_id)
# 4. Process files one by one
for i, file_id in enumerate(file_ids):
status = "completed" if i < 2 else "failed"
repo.update_file(
file_id,
status=status,
annotation_count=5 if status == "completed" else 0,
)
# 5. Update batch progress
repo.update(
batch.batch_id,
processed_files=3,
successful_files=2,
failed_files=1,
status="partial",
)
patched_session.commit()
# Verify final state
final_batch = repo.get(batch.batch_id)
assert final_batch.status == "partial"
assert final_batch.total_files == 3
assert final_batch.processed_files == 3
assert final_batch.successful_files == 2
assert final_batch.failed_files == 1
files = repo.get_files(batch.batch_id)
assert len(files) == 3
completed = [f for f in files if f.status == "completed"]
failed = [f for f in files if f.status == "failed"]
assert len(completed) == 2
assert len(failed) == 1
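
The workflow test ends with a `partial` status; a plausible rule for deriving that final status from per-file outcomes (an assumption, since the repository's own logic is not shown in this diff) is:

```python
def derive_batch_status(successful_files: int, failed_files: int) -> str:
    """Map per-file results onto a batch status."""
    if failed_files == 0:
        return "completed"   # every file processed cleanly
    if successful_files == 0:
        return "failed"      # nothing succeeded
    return "partial"         # mixed results, as in test_complete_batch_workflow

assert derive_batch_status(2, 1) == "partial"
```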


@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests
Tests DatasetRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepositoryCreate:
"""Tests for dataset creation."""
def test_create_dataset(self, patched_session):
"""Test creating a training dataset."""
repo = DatasetRepository()
dataset = repo.create(
name="Test Dataset",
description="Dataset for integration testing",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
)
assert dataset is not None
assert dataset.name == "Test Dataset"
assert dataset.description == "Dataset for integration testing"
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
assert dataset.status == "building"
def test_create_dataset_with_defaults(self, patched_session):
"""Test creating dataset with default values."""
repo = DatasetRepository()
dataset = repo.create(name="Minimal Dataset")
assert dataset is not None
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
class TestDatasetRepositoryRead:
"""Tests for dataset retrieval."""
def test_get_dataset_by_id(self, patched_session, sample_dataset):
"""Test getting dataset by ID."""
repo = DatasetRepository()
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_id == sample_dataset.dataset_id
assert dataset.name == sample_dataset.name
def test_get_nonexistent_dataset(self, patched_session):
"""Test getting dataset that doesn't exist."""
repo = DatasetRepository()
dataset = repo.get(str(uuid4()))
assert dataset is None
def test_get_paginated_datasets(self, patched_session):
"""Test paginated dataset listing."""
repo = DatasetRepository()
# Create multiple datasets
for i in range(5):
repo.create(name=f"Dataset {i}")
datasets, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(datasets) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering datasets by status."""
repo = DatasetRepository()
# Create datasets with different statuses
d1 = repo.create(name="Building Dataset")
repo.update_status(str(d1.dataset_id), "ready")
d2 = repo.create(name="Another Building Dataset")
# stays as "building"
datasets, total = repo.get_paginated(status="ready")
assert total == 1
assert datasets[0].status == "ready"
class TestDatasetRepositoryUpdate:
"""Tests for dataset updates."""
def test_update_status(self, patched_session, sample_dataset):
"""Test updating dataset status."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
total_documents=100,
total_images=150,
total_annotations=500,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "ready"
assert dataset.total_documents == 100
assert dataset.total_images == 150
assert dataset.total_annotations == 500
def test_update_status_with_error(self, patched_session, sample_dataset):
"""Test updating dataset status with error message."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="failed",
error_message="Failed to build dataset: insufficient documents",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "failed"
assert "insufficient documents" in dataset.error_message
def test_update_status_with_path(self, patched_session, sample_dataset):
"""Test updating dataset path."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
dataset_path="/datasets/test_dataset_2024",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_path == "/datasets/test_dataset_2024"
def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
"""Test updating dataset training status."""
repo = DatasetRepository()
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="running",
active_training_task_id=str(sample_training_task.task_id),
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "running"
assert dataset.active_training_task_id == sample_training_task.task_id
def test_update_training_status_completed(self, patched_session, sample_dataset):
"""Test updating training status to completed updates main status."""
repo = DatasetRepository()
# First set to ready
repo.update_status(str(sample_dataset.dataset_id), status="ready")
# Then complete training
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="completed",
update_main_status=True,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "completed"
assert dataset.status == "trained"
class TestDatasetDocuments:
"""Tests for dataset document management."""
def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
"""Test adding documents to a dataset."""
repo = DatasetRepository()
documents_data = [
{
"document_id": str(multiple_documents[0].document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
},
{
"document_id": str(multiple_documents[1].document_id),
"split": "train",
"page_count": 2,
"annotation_count": 8,
},
{
"document_id": str(multiple_documents[2].document_id),
"split": "val",
"page_count": 1,
"annotation_count": 3,
},
]
repo.add_documents(str(sample_dataset.dataset_id), documents_data)
# Verify documents were added
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 3
train_docs = [d for d in docs if d.split == "train"]
val_docs = [d for d in docs if d.split == "val"]
assert len(train_docs) == 2
assert len(val_docs) == 1
def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
"""Test getting documents from a dataset."""
repo = DatasetRepository()
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 1
assert docs[0].document_id == sample_document.document_id
assert docs[0].split == "train"
assert docs[0].page_count == 1
assert docs[0].annotation_count == 5
class TestDatasetRepositoryDelete:
"""Tests for dataset deletion."""
def test_delete_dataset(self, patched_session, sample_dataset):
"""Test deleting a dataset."""
repo = DatasetRepository()
result = repo.delete(str(sample_dataset.dataset_id))
assert result is True
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is None
def test_delete_nonexistent_dataset(self, patched_session):
"""Test deleting dataset that doesn't exist."""
repo = DatasetRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
"""Test deleting dataset also removes document links."""
repo = DatasetRepository()
# Add document to dataset
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
# Delete dataset
repo.delete(str(sample_dataset.dataset_id))
# Document links should be gone
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 0
class TestActiveTrainingTasks:
"""Tests for active training task queries."""
def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
"""Test getting active training tasks for datasets."""
repo = DatasetRepository()
# Update task to running
from inference.data.repositories.training_task_repository import TrainingTaskRepository
task_repo = TrainingTaskRepository()
task_repo.update_status(str(sample_training_task.task_id), "running")
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
assert str(sample_dataset.dataset_id) in result
assert result[str(sample_dataset.dataset_id)]["status"] == "running"
def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
"""Test getting active training tasks returns empty when no tasks exist."""
repo = DatasetRepository()
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
# No training task exists for this dataset, so result should be empty
assert str(sample_dataset.dataset_id) not in result
assert result == {}
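
The `train_ratio`/`val_ratio`/`seed` fields above suggest deterministic splitting. Here is a minimal sketch of such a split, assuming the remainder goes to a test split; the dataset builder itself is not included in this diff:

```python
import random

def split_documents(doc_ids: list[str], train_ratio: float = 0.8,
                    val_ratio: float = 0.1, seed: int = 42) -> dict[str, str]:
    """Assign each document to train/val/test, reproducibly for a given seed."""
    rng = random.Random(seed)      # fixed seed -> identical splits on re-runs
    shuffled = list(doc_ids)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    splits: dict[str, str] = {}
    for i, doc_id in enumerate(shuffled):
        if i < n_train:
            splits[doc_id] = "train"
        elif i < n_train + n_val:
            splits[doc_id] = "val"
        else:
            splits[doc_id] = "test"  # remainder (assumed)
    return splits
```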


@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests
Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from sqlmodel import select
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
"""Ensure datetime is timezone-aware (UTC).
PostgreSQL may return offset-naive datetimes. This helper
converts them to UTC for proper comparison.
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
class TestDocumentRepositoryCreate:
"""Tests for document creation."""
def test_create_document(self, patched_session):
"""Test creating a document and retrieving it."""
repo = DocumentRepository()
doc_id = repo.create(
filename="test_invoice.pdf",
file_size=2048,
content_type="application/pdf",
file_path="/uploads/test_invoice.pdf",
page_count=2,
upload_source="api",
category="invoice",
)
assert doc_id is not None
doc = repo.get(doc_id)
assert doc is not None
assert doc.filename == "test_invoice.pdf"
assert doc.file_size == 2048
assert doc.page_count == 2
assert doc.upload_source == "api"
assert doc.category == "invoice"
assert doc.status == "pending"
def test_create_document_with_csv_values(self, patched_session):
"""Test creating document with CSV field values."""
repo = DocumentRepository()
csv_values = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier_name": "Test Supplier AB",
}
doc_id = repo.create(
filename="invoice_with_csv.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/invoice_with_csv.pdf",
csv_field_values=csv_values,
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.csv_field_values == csv_values
def test_create_document_with_group_key(self, patched_session):
"""Test creating document with group key."""
repo = DocumentRepository()
doc_id = repo.create(
filename="grouped_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/grouped_doc.pdf",
group_key="batch-2024-01",
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.group_key == "batch-2024-01"
class TestDocumentRepositoryRead:
"""Tests for document retrieval."""
def test_get_nonexistent_document(self, patched_session):
"""Test getting a document that doesn't exist."""
repo = DocumentRepository()
doc = repo.get(str(uuid4()))
assert doc is None
def test_get_paginated_documents(self, patched_session, multiple_documents):
"""Test paginated document listing."""
repo = DocumentRepository()
docs, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(docs) == 2
def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
"""Test filtering documents by status."""
repo = DocumentRepository()
docs, total = repo.get_paginated(status="labeled")
assert total == 2
for doc in docs:
assert doc.status == "labeled"
def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
"""Test filtering documents by category."""
repo = DocumentRepository()
docs, total = repo.get_paginated(category="letter")
assert total == 1
assert docs[0].category == "letter"
def test_get_paginated_with_offset(self, patched_session, multiple_documents):
"""Test pagination offset."""
repo = DocumentRepository()
docs_page1, _ = repo.get_paginated(limit=2, offset=0)
docs_page2, _ = repo.get_paginated(limit=2, offset=2)
doc_ids_page1 = {str(d.document_id) for d in docs_page1}
doc_ids_page2 = {str(d.document_id) for d in docs_page2}
assert len(doc_ids_page1 & doc_ids_page2) == 0
def test_get_by_ids(self, patched_session, multiple_documents):
"""Test getting multiple documents by IDs."""
repo = DocumentRepository()
ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
docs = repo.get_by_ids(ids_to_fetch)
assert len(docs) == 2
fetched_ids = {str(d.document_id) for d in docs}
assert fetched_ids == set(ids_to_fetch)
class TestDocumentRepositoryUpdate:
"""Tests for document updates."""
def test_update_status(self, patched_session, sample_document):
"""Test updating document status."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="labeled",
auto_label_status="completed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.status == "labeled"
assert doc.auto_label_status == "completed"
def test_update_status_with_error(self, patched_session, sample_document):
"""Test updating document status with error message."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="pending",
auto_label_status="failed",
auto_label_error="OCR extraction failed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.auto_label_status == "failed"
assert doc.auto_label_error == "OCR extraction failed"
def test_update_file_path(self, patched_session, sample_document):
"""Test updating document file path."""
repo = DocumentRepository()
new_path = "/archive/2024/test_invoice.pdf"
repo.update_file_path(str(sample_document.document_id), new_path)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.file_path == new_path
def test_update_group_key(self, patched_session, sample_document):
"""Test updating document group key."""
repo = DocumentRepository()
result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.group_key == "new-group-key"
def test_update_category(self, patched_session, sample_document):
"""Test updating document category."""
repo = DocumentRepository()
doc = repo.update_category(str(sample_document.document_id), "letter")
assert doc is not None
assert doc.category == "letter"
class TestDocumentRepositoryDelete:
"""Tests for document deletion."""
def test_delete_document(self, patched_session, sample_document):
"""Test deleting a document."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is None
def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
"""Test deleting document also deletes its annotations."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
# Verify annotation is also deleted
from inference.data.repositories.annotation_repository import AnnotationRepository
ann_repo = AnnotationRepository()
annotations = ann_repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_nonexistent_document(self, patched_session):
"""Test deleting a document that doesn't exist."""
repo = DocumentRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestDocumentRepositoryQueries:
"""Tests for complex document queries."""
def test_count_by_status(self, patched_session, multiple_documents):
"""Test counting documents by status."""
repo = DocumentRepository()
counts = repo.count_by_status()
assert counts.get("pending") == 2
assert counts.get("labeled") == 2
assert counts.get("exported") == 1
def test_get_categories(self, patched_session, multiple_documents):
"""Test getting unique categories."""
repo = DocumentRepository()
categories = repo.get_categories()
assert "invoice" in categories
assert "letter" in categories
def test_get_labeled_for_export(self, patched_session, multiple_documents):
"""Test getting labeled documents for export."""
repo = DocumentRepository()
docs = repo.get_labeled_for_export()
assert len(docs) == 2
for doc in docs:
assert doc.status == "labeled"
class TestDocumentAnnotationLocking:
"""Tests for annotation locking mechanism."""
def test_acquire_annotation_lock(self, patched_session, sample_document):
"""Test acquiring annotation lock."""
repo = DocumentRepository()
doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
assert doc is not None
assert doc.annotation_lock_until is not None
lock_until = ensure_utc(doc.annotation_lock_until)
assert lock_until > datetime.now(timezone.utc)
def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
"""Test acquiring lock fails when already locked."""
repo = DocumentRepository()
# First lock
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
# Second lock attempt should fail
result = repo.acquire_annotation_lock(str(sample_document.document_id))
assert result is None
def test_release_annotation_lock(self, patched_session, sample_document):
"""Test releasing annotation lock."""
repo = DocumentRepository()
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
doc = repo.release_annotation_lock(str(sample_document.document_id))
assert doc is not None
assert doc.annotation_lock_until is None
def test_extend_annotation_lock(self, patched_session, sample_document):
"""Test extending annotation lock."""
repo = DocumentRepository()
# Acquire initial lock
initial_doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
# Extend lock
extended_doc = repo.extend_annotation_lock(
str(sample_document.document_id),
additional_seconds=300,
)
assert extended_doc is not None
extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
assert extended_expiry > initial_expiry
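
The locking tests encode a simple rule: acquisition succeeds only when no unexpired lock exists. A hedged sketch of that check (the real repository presumably enforces this atomically in SQL; this is illustration only):

```python
from datetime import datetime, timedelta, timezone

def try_acquire_lock(annotation_lock_until: datetime | None,
                     duration_seconds: int = 300) -> datetime | None:
    """Return the new lock expiry if the lock is free or expired, else None."""
    now = datetime.now(timezone.utc)
    if annotation_lock_until is not None and annotation_lock_until > now:
        return None  # still locked: the second acquire in the tests fails
    return now + timedelta(seconds=duration_seconds)
```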


@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests
Tests ModelVersionRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:
"""Tests for model version creation."""
def test_create_model_version(self, patched_session):
"""Test creating a model version."""
repo = ModelVersionRepository()
model = repo.create(
version="1.0.0",
name="Invoice Extractor v1",
model_path="/models/invoice_v1.pt",
description="Initial production model",
metrics_mAP=0.92,
metrics_precision=0.89,
metrics_recall=0.85,
document_count=1000,
file_size=50000000,
)
assert model is not None
assert model.version == "1.0.0"
assert model.name == "Invoice Extractor v1"
assert model.model_path == "/models/invoice_v1.pt"
assert model.metrics_mAP == 0.92
assert model.is_active is False
assert model.status == "inactive"
def test_create_model_version_with_training_info(
self, patched_session, sample_training_task, sample_dataset
):
"""Test creating model version linked to training task and dataset."""
repo = ModelVersionRepository()
model = repo.create(
version="1.1.0",
name="Invoice Extractor v1.1",
model_path="/models/invoice_v1.1.pt",
task_id=sample_training_task.task_id,
dataset_id=sample_dataset.dataset_id,
training_config={"epochs": 100, "batch_size": 16},
trained_at=datetime.now(timezone.utc),
)
assert model is not None
assert model.task_id == sample_training_task.task_id
assert model.dataset_id == sample_dataset.dataset_id
assert model.training_config["epochs"] == 100
class TestModelVersionRead:
"""Tests for model version retrieval."""
def test_get_model_version_by_id(self, patched_session, sample_model_version):
"""Test getting model version by ID."""
repo = ModelVersionRepository()
model = repo.get(str(sample_model_version.version_id))
assert model is not None
assert model.version_id == sample_model_version.version_id
def test_get_nonexistent_model_version(self, patched_session):
"""Test getting model version that doesn't exist."""
repo = ModelVersionRepository()
model = repo.get(str(uuid4()))
assert model is None
def test_get_paginated_model_versions(self, patched_session):
"""Test paginated model version listing."""
repo = ModelVersionRepository()
# Create multiple versions
for i in range(5):
repo.create(
version=f"1.{i}.0",
name=f"Model v1.{i}",
model_path=f"/models/model_v1.{i}.pt",
)
models, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(models) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering model versions by status."""
repo = ModelVersionRepository()
# Create active and inactive models
m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
repo.activate(str(m1.version_id))
repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")
active_models, active_total = repo.get_paginated(status="active")
inactive_models, inactive_total = repo.get_paginated(status="inactive")
assert active_total == 1
assert inactive_total == 1
class TestModelVersionActivation:
"""Tests for model version activation."""
def test_activate_model_version(self, patched_session, sample_model_version):
"""Test activating a model version."""
repo = ModelVersionRepository()
model = repo.activate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is True
assert model.status == "active"
assert model.activated_at is not None
def test_activate_deactivates_others(self, patched_session):
"""Test that activating one version deactivates others."""
repo = ModelVersionRepository()
# Create and activate first model
m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
repo.activate(str(m1.version_id))
# Create and activate second model
m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
repo.activate(str(m2.version_id))
# Check first model is now inactive
m1_after = repo.get(str(m1.version_id))
assert m1_after.is_active is False
assert m1_after.status == "inactive"
# Check second model is active
m2_after = repo.get(str(m2.version_id))
assert m2_after.is_active is True
def test_get_active_model(self, patched_session, sample_model_version):
"""Test getting the currently active model."""
repo = ModelVersionRepository()
# Initially no active model
active = repo.get_active()
assert active is None
# Activate model
repo.activate(str(sample_model_version.version_id))
# Now should return active model
active = repo.get_active()
assert active is not None
assert active.version_id == sample_model_version.version_id
def test_deactivate_model_version(self, patched_session, sample_model_version):
"""Test deactivating a model version."""
repo = ModelVersionRepository()
# First activate
repo.activate(str(sample_model_version.version_id))
# Then deactivate
model = repo.deactivate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is False
assert model.status == "inactive"
class TestModelVersionUpdate:
"""Tests for model version updates."""
def test_update_model_metadata(self, patched_session, sample_model_version):
"""Test updating model version metadata."""
repo = ModelVersionRepository()
model = repo.update(
str(sample_model_version.version_id),
name="Updated Model Name",
description="Updated description",
)
assert model is not None
assert model.name == "Updated Model Name"
assert model.description == "Updated description"
def test_update_model_status(self, patched_session, sample_model_version):
"""Test updating model version status."""
repo = ModelVersionRepository()
model = repo.update(str(sample_model_version.version_id), status="deprecated")
assert model is not None
assert model.status == "deprecated"
def test_update_nonexistent_model(self, patched_session):
"""Test updating model that doesn't exist."""
repo = ModelVersionRepository()
model = repo.update(str(uuid4()), name="New Name")
assert model is None
class TestModelVersionArchive:
"""Tests for model version archiving."""
def test_archive_model_version(self, patched_session, sample_model_version):
"""Test archiving an inactive model version."""
repo = ModelVersionRepository()
model = repo.archive(str(sample_model_version.version_id))
assert model is not None
assert model.status == "archived"
def test_cannot_archive_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be archived."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to archive
model = repo.archive(str(sample_model_version.version_id))
assert model is None
# Verify model is still active
current = repo.get(str(sample_model_version.version_id))
assert current.status == "active"
class TestModelVersionDelete:
"""Tests for model version deletion."""
def test_delete_inactive_model(self, patched_session, sample_model_version):
"""Test deleting an inactive model version."""
repo = ModelVersionRepository()
result = repo.delete(str(sample_model_version.version_id))
assert result is True
model = repo.get(str(sample_model_version.version_id))
assert model is None
def test_cannot_delete_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be deleted."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to delete
result = repo.delete(str(sample_model_version.version_id))
assert result is False
# Verify model still exists
model = repo.get(str(sample_model_version.version_id))
assert model is not None
def test_delete_nonexistent_model(self, patched_session):
"""Test deleting model that doesn't exist."""
repo = ModelVersionRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestOnlyOneActiveModel:
"""Tests to verify only one model can be active at a time."""
def test_single_active_model_constraint(self, patched_session):
"""Test that only one model can be active at any time."""
repo = ModelVersionRepository()
# Create multiple models
models = []
for i in range(3):
m = repo.create(
version=f"1.{i}.0",
name=f"Model {i}",
model_path=f"/models/model_{i}.pt",
)
models.append(m)
# Activate each model in sequence
for model in models:
repo.activate(str(model.version_id))
# Count active models
all_models, _ = repo.get_paginated(status="active")
assert len(all_models) == 1
# Verify it's the last one activated
assert all_models[0].version_id == models[-1].version_id
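
The single-active-model invariant these tests assert is commonly enforced by deactivating all rows before activating the target. A minimal sketch under that assumption (field names come from the assertions; this is not the project's actual `activate()`):

```python
from sqlmodel import Session, select

def activate(session: Session, model_cls, version_id):
    """Activate one model version, deactivating any currently active one."""
    for other in session.exec(select(model_cls).where(model_cls.is_active == True)).all():
        other.is_active = False
        other.status = "inactive"
    target = session.get(model_cls, version_id)
    if target is not None:
        target.is_active = True
        target.status = "active"
    session.commit()
    return target
```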


@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests
Tests TokenRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
import pytest
from inference.data.repositories.token_repository import TokenRepository
class TestTokenCreate:
"""Tests for token creation."""
def test_create_new_token(self, patched_session):
"""Test creating a new admin token."""
repo = TokenRepository()
repo.create(
token="new-test-token-abc123",
name="New Test Admin",
)
token = repo.get("new-test-token-abc123")
assert token is not None
assert token.token == "new-test-token-abc123"
assert token.name == "New Test Admin"
assert token.is_active is True
assert token.expires_at is None
def test_create_token_with_expiration(self, patched_session):
"""Test creating token with expiration date."""
repo = TokenRepository()
expiry = datetime.now(timezone.utc) + timedelta(days=30)
repo.create(
token="expiring-token-xyz789",
name="Expiring Token",
expires_at=expiry,
)
token = repo.get("expiring-token-xyz789")
assert token is not None
assert token.expires_at is not None
def test_create_updates_existing_token(self, patched_session, admin_token):
"""Test creating with existing token updates it."""
repo = TokenRepository()
new_expiry = datetime.now(timezone.utc) + timedelta(days=60)
repo.create(
token=admin_token.token,
name="Updated Admin Name",
expires_at=new_expiry,
)
token = repo.get(admin_token.token)
assert token is not None
assert token.name == "Updated Admin Name"
assert token.is_active is True
class TestTokenValidation:
"""Tests for token validation."""
def test_is_valid_active_token(self, patched_session, admin_token):
"""Test that active token is valid."""
repo = TokenRepository()
result = repo.is_valid(admin_token.token)
assert result is True
def test_is_valid_nonexistent_token(self, patched_session):
"""Test that nonexistent token is invalid."""
repo = TokenRepository()
result = repo.is_valid("nonexistent-token-12345")
assert result is False
def test_is_valid_deactivated_token(self, patched_session, admin_token):
"""Test that deactivated token is invalid."""
repo = TokenRepository()
repo.deactivate(admin_token.token)
result = repo.is_valid(admin_token.token)
assert result is False
def test_is_valid_expired_token(self, patched_session):
"""Test that expired token is invalid."""
repo = TokenRepository()
past_expiry = datetime.now(timezone.utc) - timedelta(days=1)
repo.create(
token="expired-token-test",
name="Expired Token",
expires_at=past_expiry,
)
result = repo.is_valid("expired-token-test")
assert result is False
def test_is_valid_not_yet_expired_token(self, patched_session):
"""Test that not-yet-expired token is valid."""
repo = TokenRepository()
future_expiry = datetime.now(timezone.utc) + timedelta(days=7)
repo.create(
token="valid-expiring-token",
name="Valid Expiring Token",
expires_at=future_expiry,
)
result = repo.is_valid("valid-expiring-token")
assert result is True
class TestTokenGet:
"""Tests for token retrieval."""
def test_get_existing_token(self, patched_session, admin_token):
"""Test getting an existing token."""
repo = TokenRepository()
token = repo.get(admin_token.token)
assert token is not None
assert token.token == admin_token.token
assert token.name == admin_token.name
def test_get_nonexistent_token(self, patched_session):
"""Test getting a token that doesn't exist."""
repo = TokenRepository()
token = repo.get("nonexistent-token-xyz")
assert token is None
class TestTokenDeactivate:
"""Tests for token deactivation."""
def test_deactivate_existing_token(self, patched_session, admin_token):
"""Test deactivating an existing token."""
repo = TokenRepository()
result = repo.deactivate(admin_token.token)
assert result is True
token = repo.get(admin_token.token)
assert token is not None
assert token.is_active is False
def test_deactivate_nonexistent_token(self, patched_session):
"""Test deactivating a token that doesn't exist."""
repo = TokenRepository()
result = repo.deactivate("nonexistent-token-abc")
assert result is False
def test_reactivate_deactivated_token(self, patched_session, admin_token):
"""Test reactivating a deactivated token via create."""
repo = TokenRepository()
# Deactivate first
repo.deactivate(admin_token.token)
assert repo.is_valid(admin_token.token) is False
# Reactivate via create
repo.create(
token=admin_token.token,
name="Reactivated Admin",
)
assert repo.is_valid(admin_token.token) is True
class TestTokenUsageTracking:
"""Tests for token usage tracking."""
def test_update_usage(self, patched_session, admin_token):
"""Test updating token last used timestamp."""
repo = TokenRepository()
# Initially last_used_at might be None
initial_token = repo.get(admin_token.token)
initial_last_used = initial_token.last_used_at
repo.update_usage(admin_token.token)
updated_token = repo.get(admin_token.token)
assert updated_token.last_used_at is not None
if initial_last_used:
assert updated_token.last_used_at >= initial_last_used
def test_update_usage_nonexistent_token(self, patched_session):
"""Test updating usage for nonexistent token does nothing."""
repo = TokenRepository()
# Should not raise, just does nothing
repo.update_usage("nonexistent-token-usage")
token = repo.get("nonexistent-token-usage")
assert token is None
class TestTokenWorkflow:
"""Tests for complete token workflows."""
def test_full_token_lifecycle(self, patched_session):
"""Test complete token lifecycle: create, validate, use, deactivate."""
repo = TokenRepository()
token_str = "lifecycle-test-token"
# 1. Create token
repo.create(token=token_str, name="Lifecycle Token")
assert repo.is_valid(token_str) is True
# 2. Use token
repo.update_usage(token_str)
token = repo.get(token_str)
assert token.last_used_at is not None
# 3. Update token info
new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
repo.create(
token=token_str,
name="Updated Lifecycle Token",
expires_at=new_expiry,
)
token = repo.get(token_str)
assert token.name == "Updated Lifecycle Token"
# 4. Deactivate token
result = repo.deactivate(token_str)
assert result is True
assert repo.is_valid(token_str) is False
# 5. Reactivate token
repo.create(token=token_str, name="Reactivated Token")
assert repo.is_valid(token_str) is True
def test_multiple_tokens(self, patched_session):
"""Test managing multiple tokens."""
repo = TokenRepository()
# Create multiple tokens
tokens = [
("token-a", "Admin A"),
("token-b", "Admin B"),
("token-c", "Admin C"),
]
for token_str, name in tokens:
repo.create(token=token_str, name=name)
# Verify all are valid
for token_str, _ in tokens:
assert repo.is_valid(token_str) is True
# Deactivate one
repo.deactivate("token-b")
# Verify states
assert repo.is_valid("token-a") is True
assert repo.is_valid("token-b") is False
assert repo.is_valid("token-c") is True


@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests
Tests TrainingTaskRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:
"""Tests for training task creation."""
def test_create_training_task(self, patched_session, admin_token):
"""Test creating a training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Test Training Task",
task_type="train",
description="Integration test training task",
config={"epochs": 100, "batch_size": 16},
)
assert task_id is not None
task = repo.get(task_id)
assert task is not None
assert task.name == "Test Training Task"
assert task.task_type == "train"
assert task.status == "pending"
assert task.config["epochs"] == 100
def test_create_scheduled_task(self, patched_session, admin_token):
"""Test creating a scheduled training task."""
repo = TrainingTaskRepository()
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
task_id = repo.create(
admin_token=admin_token.token,
name="Scheduled Task",
scheduled_at=scheduled_time,
)
task = repo.get(task_id)
assert task is not None
assert task.status == "scheduled"
assert task.scheduled_at is not None
def test_create_recurring_task(self, patched_session, admin_token):
"""Test creating a recurring training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Recurring Task",
cron_expression="0 2 * * *",
is_recurring=True,
)
task = repo.get(task_id)
assert task is not None
assert task.is_recurring is True
assert task.cron_expression == "0 2 * * *"
def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
"""Test creating task linked to a dataset."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Dataset Training Task",
dataset_id=str(sample_dataset.dataset_id),
)
task = repo.get(task_id)
assert task is not None
assert task.dataset_id == sample_dataset.dataset_id
class TestTrainingTaskRead:
"""Tests for training task retrieval."""
def test_get_task_by_id(self, patched_session, sample_training_task):
"""Test getting task by ID."""
repo = TrainingTaskRepository()
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.task_id == sample_training_task.task_id
def test_get_nonexistent_task(self, patched_session):
"""Test getting task that doesn't exist."""
repo = TrainingTaskRepository()
task = repo.get(str(uuid4()))
assert task is None
def test_get_paginated_tasks(self, patched_session, admin_token):
"""Test paginated task listing."""
repo = TrainingTaskRepository()
# Create multiple tasks
for i in range(5):
repo.create(admin_token=admin_token.token, name=f"Task {i}")
tasks, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(tasks) == 2
def test_get_paginated_with_status_filter(self, patched_session, admin_token):
"""Test filtering tasks by status."""
repo = TrainingTaskRepository()
# Create tasks with different statuses
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
repo.create(admin_token=admin_token.token, name="Pending Task")
tasks, total = repo.get_paginated(status="running")
assert total == 1
assert tasks[0].status == "running"
def test_get_pending_tasks(self, patched_session, admin_token):
"""Test getting pending tasks ready to run."""
repo = TrainingTaskRepository()
# Create pending task
repo.create(admin_token=admin_token.token, name="Ready Task")
# Create scheduled task in the past (should be included)
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Past Scheduled Task",
scheduled_at=past_time,
)
# Create scheduled task in the future (should not be included)
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Future Scheduled Task",
scheduled_at=future_time,
)
pending = repo.get_pending()
# Should include pending and past scheduled, not future scheduled
assert len(pending) >= 2
names = [t.name for t in pending]
assert "Ready Task" in names
assert "Past Scheduled Task" in names
def test_get_running_task(self, patched_session, admin_token):
"""Test getting currently running task."""
repo = TrainingTaskRepository()
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
running = repo.get_running()
assert running is not None
assert running.status == "running"
def test_get_running_task_none(self, patched_session, admin_token):
"""Test getting running task when none is running."""
repo = TrainingTaskRepository()
repo.create(admin_token=admin_token.token, name="Pending Task")
running = repo.get_running()
assert running is None
class TestTrainingTaskUpdate:
"""Tests for training task updates."""
def test_update_status_to_running(self, patched_session, sample_training_task):
"""Test updating task status to running."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "running"
assert task.started_at is not None
def test_update_status_to_completed(self, patched_session, sample_training_task):
"""Test updating task status to completed."""
repo = TrainingTaskRepository()
metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
repo.update_status(
str(sample_training_task.task_id),
"completed",
result_metrics=metrics,
model_path="/models/trained_model.pt",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "completed"
assert task.completed_at is not None
assert task.result_metrics["mAP"] == 0.92
assert task.model_path == "/models/trained_model.pt"
def test_update_status_to_failed(self, patched_session, sample_training_task):
"""Test updating task status to failed with error message."""
repo = TrainingTaskRepository()
repo.update_status(
str(sample_training_task.task_id),
"failed",
error_message="CUDA out of memory",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "failed"
assert task.completed_at is not None
assert "CUDA out of memory" in task.error_message
def test_cancel_pending_task(self, patched_session, sample_training_task):
"""Test cancelling a pending task."""
repo = TrainingTaskRepository()
result = repo.cancel(str(sample_training_task.task_id))
assert result is True
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "cancelled"
def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
"""Test that running task cannot be cancelled."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
result = repo.cancel(str(sample_training_task.task_id))
assert result is False
task = repo.get(str(sample_training_task.task_id))
assert task.status == "running"
class TestTrainingLogs:
"""Tests for training log management."""
def test_add_log_entry(self, patched_session, sample_training_task):
"""Test adding a training log entry."""
repo = TrainingTaskRepository()
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message="Starting training...",
details={"epoch": 1, "batch": 0},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 1
assert logs[0].level == "INFO"
assert logs[0].message == "Starting training..."
def test_add_multiple_log_entries(self, patched_session, sample_training_task):
"""Test adding multiple log entries."""
repo = TrainingTaskRepository()
for i in range(5):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Epoch {i} completed",
details={"epoch": i, "loss": 0.5 - i * 0.1},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 5
def test_get_logs_pagination(self, patched_session, sample_training_task):
"""Test paginated log retrieval."""
repo = TrainingTaskRepository()
for i in range(10):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Log entry {i}",
)
logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
assert len(logs) == 5
logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
assert len(logs_page2) == 5
class TestDocumentLinks:
"""Tests for training document link management."""
def test_create_document_link(self, patched_session, sample_training_task, sample_document):
"""Test creating a document link."""
repo = TrainingTaskRepository()
link = repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=sample_document.document_id,
annotation_snapshot={"count": 5, "verified": 3},
)
assert link is not None
assert link.task_id == sample_training_task.task_id
assert link.document_id == sample_document.document_id
assert link.annotation_snapshot["count"] == 5
def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
"""Test getting all document links for a task."""
repo = TrainingTaskRepository()
for doc in multiple_documents[:3]:
repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=doc.document_id,
)
links = repo.get_document_links(sample_training_task.task_id)
assert len(links) == 3
def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
"""Test getting training tasks that used a document."""
repo = TrainingTaskRepository()
# Create multiple tasks using the same document
task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
task2_id = repo.create(admin_token=admin_token.token, name="Task 2")
repo.create_document_link(
task_id=repo.get(task1_id).task_id,
document_id=sample_document.document_id,
)
repo.create_document_link(
task_id=repo.get(task2_id).task_id,
document_id=sample_document.document_id,
)
links = repo.get_document_training_tasks(sample_document.document_id)
assert len(links) == 2
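
`test_get_pending_tasks` implies a "ready to run" filter that accepts plain pending tasks plus past-due scheduled ones. A minimal sketch of that predicate (the actual query is an assumption):

```python
from datetime import datetime, timezone

def is_ready_to_run(status: str, scheduled_at: datetime | None) -> bool:
    """True for pending tasks and for scheduled tasks whose time has passed."""
    now = datetime.now(timezone.utc)
    if status == "pending":
        return True
    if status == "scheduled" and scheduled_at is not None:
        return scheduled_at <= now  # past-due schedules become runnable
    return False
```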


@@ -0,0 +1 @@
"""Service integration tests."""


@@ -0,0 +1,497 @@
"""
Dashboard Service Integration Tests
Tests DashboardStatsService and DashboardActivityService with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
AnnotationHistory,
ModelVersion,
TrainingDataset,
TrainingTask,
)
from inference.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
is_annotation_complete,
IDENTIFIER_CLASS_IDS,
PAYMENT_CLASS_IDS,
)
class TestIsAnnotationComplete:
"""Tests for is_annotation_complete function."""
def test_complete_with_invoice_number_and_bankgiro(self):
"""Test complete with invoice_number (0) and bankgiro (4)."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 4}, # bankgiro
]
assert is_annotation_complete(annotations) is True
def test_complete_with_ocr_number_and_plusgiro(self):
"""Test complete with ocr_number (3) and plusgiro (5)."""
annotations = [
{"class_id": 3}, # ocr_number
{"class_id": 5}, # plusgiro
]
assert is_annotation_complete(annotations) is True
def test_incomplete_missing_identifier(self):
"""Test incomplete when missing identifier."""
annotations = [
{"class_id": 4}, # bankgiro only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_missing_payment(self):
"""Test incomplete when missing payment."""
annotations = [
{"class_id": 0}, # invoice_number only
]
assert is_annotation_complete(annotations) is False
def test_incomplete_empty_annotations(self):
"""Test incomplete with empty annotations."""
assert is_annotation_complete([]) is False
def test_complete_with_multiple_fields(self):
"""Test complete with multiple fields."""
annotations = [
{"class_id": 0}, # invoice_number
{"class_id": 1}, # invoice_date
{"class_id": 3}, # ocr_number
{"class_id": 4}, # bankgiro
{"class_id": 5}, # plusgiro
{"class_id": 6}, # amount
]
assert is_annotation_complete(annotations) is True
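# The six cases above pin down the completeness rule: at least one identifier
# field (invoice_number=0 or ocr_number=3) AND at least one payment field
# (bankgiro=4 or plusgiro=5). A hedged reconstruction for reference -- the
# imported IDENTIFIER_CLASS_IDS / PAYMENT_CLASS_IDS sets are authoritative,
# and the real is_annotation_complete may differ in detail:
def _is_complete_sketch(annotations: list[dict]) -> bool:
    class_ids = {a["class_id"] for a in annotations}
    return bool(class_ids & IDENTIFIER_CLASS_IDS) and bool(class_ids & PAYMENT_CLASS_IDS)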
class TestDashboardStatsService:
"""Tests for DashboardStatsService."""
def test_get_stats_empty_database(self, patched_session):
"""Test stats with empty database."""
service = DashboardStatsService()
stats = service.get_stats()
assert stats["total_documents"] == 0
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 0
assert stats["pending"] == 0
assert stats["completeness_rate"] == 0.0
def test_get_stats_with_documents(self, patched_session, admin_token):
"""Test stats with various document states."""
service = DashboardStatsService()
session = patched_session
# Create documents with different statuses
docs = []
for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/doc_{i}.pdf",
page_count=1,
status=status,
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
docs.append(doc)
session.commit()
stats = service.get_stats()
assert stats["total_documents"] == 5
assert stats["pending"] == 2 # pending + auto_labeling
def test_get_stats_complete_annotations(self, patched_session, admin_token):
"""Test completeness calculation with proper annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document with complete annotations
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="complete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/complete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add identifier annotation (invoice_number = class_id 0)
ann1 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann1)
# Add payment annotation (bankgiro = class_id 4)
ann2 = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=4,
class_name="bankgiro",
x_center=0.5,
y_center=0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=160,
bbox_width=160,
bbox_height=40,
text_value="123-4567",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann2)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 0
assert stats["completeness_rate"] == 100.0
def test_get_stats_incomplete_annotations(self, patched_session, admin_token):
"""Test completeness with incomplete annotations."""
service = DashboardStatsService()
session = patched_session
# Create a labeled document missing payment annotation
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="incomplete_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/incomplete_doc.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
# Add only identifier annotation (missing payment)
ann = AdminAnnotation(
annotation_id=uuid4(),
document_id=doc.document_id,
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
text_value="INV-001",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(ann)
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 0
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 0.0
def test_get_stats_mixed_completeness(self, patched_session, admin_token):
"""Test stats with mix of complete and incomplete documents."""
service = DashboardStatsService()
session = patched_session
# Create 2 labeled documents
docs = []
for i in range(2):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"mixed_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/mixed_doc_{i}.pdf",
page_count=1,
status="labeled",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
docs.append(doc)
session.commit()
# First document: complete (has identifier + payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=0, # invoice_number
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[0].document_id,
page_number=1,
class_id=4, # bankgiro
class_name="bankgiro",
x_center=0.5, y_center=0.2, width=0.2, height=0.05,
bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
# Second document: incomplete (missing payment)
session.add(AdminAnnotation(
annotation_id=uuid4(),
document_id=docs[1].document_id,
page_number=1,
class_id=0, # invoice_number only
class_name="invoice_number",
x_center=0.5, y_center=0.1, width=0.2, height=0.05,
bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
))
session.commit()
stats = service.get_stats()
assert stats["annotation_complete"] == 1
assert stats["annotation_incomplete"] == 1
assert stats["completeness_rate"] == 50.0
class TestDashboardActivityService:
"""Tests for DashboardActivityService."""
def test_get_recent_activities_empty(self, patched_session):
"""Test activities with empty database."""
service = DashboardActivityService()
activities = service.get_recent_activities()
assert activities == []
def test_get_recent_activities_document_uploads(self, patched_session, admin_token):
"""Test activities include document uploads."""
service = DashboardActivityService()
session = patched_session
# Create documents
for i in range(3):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"activity_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/activity_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities()
upload_activities = [a for a in activities if a["type"] == "document_uploaded"]
assert len(upload_activities) == 3
def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation):
"""Test activities include annotation overrides."""
service = DashboardActivityService()
session = patched_session
# Create annotation history with override
history = AnnotationHistory(
history_id=uuid4(),
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="override",
previous_value={"text_value": "OLD-001"},
new_value={"text_value": "NEW-001", "class_name": "invoice_number"},
changed_by="test-admin",
created_at=datetime.now(timezone.utc),
)
session.add(history)
session.commit()
activities = service.get_recent_activities()
override_activities = [a for a in activities if a["type"] == "annotation_modified"]
assert len(override_activities) >= 1
def test_get_recent_activities_training_completed(self, patched_session, sample_training_task):
"""Test activities include training completions."""
service = DashboardActivityService()
session = patched_session
# Update training task to completed
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.85
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
training_activities = [a for a in activities if a["type"] == "training_completed"]
assert len(training_activities) >= 1
assert "mAP" in training_activities[0]["metadata"]
def test_get_recent_activities_training_failed(self, patched_session, sample_training_task):
"""Test activities include training failures."""
service = DashboardActivityService()
session = patched_session
# Update training task to failed
sample_training_task.status = "failed"
sample_training_task.error_message = "CUDA out of memory"
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
failed_activities = [a for a in activities if a["type"] == "training_failed"]
assert len(failed_activities) >= 1
assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory"
def test_get_recent_activities_model_activated(self, patched_session, sample_model_version):
"""Test activities include model activations."""
service = DashboardActivityService()
session = patched_session
# Activate model
sample_model_version.is_active = True
sample_model_version.activated_at = datetime.now(timezone.utc)
session.add(sample_model_version)
session.commit()
activities = service.get_recent_activities()
activation_activities = [a for a in activities if a["type"] == "model_activated"]
assert len(activation_activities) >= 1
assert activation_activities[0]["metadata"]["version"] == sample_model_version.version
def test_get_recent_activities_limit(self, patched_session, admin_token):
"""Test activity limit parameter."""
service = DashboardActivityService()
session = patched_session
# Create many documents
for i in range(20):
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename=f"limit_doc_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/limit_doc_{i}.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
session.commit()
activities = service.get_recent_activities(limit=5)
assert len(activities) <= 5
def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task):
"""Test activities are sorted by timestamp descending."""
service = DashboardActivityService()
session = patched_session
# Create document
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="sorted_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/sorted_doc.pdf",
page_count=1,
status="pending",
upload_source="ui",
category="invoice",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(doc)
# Complete training task
sample_training_task.status = "completed"
sample_training_task.metrics_mAP = 0.90
sample_training_task.updated_at = datetime.now(timezone.utc)
session.add(sample_training_task)
session.commit()
activities = service.get_recent_activities()
# Verify sorted by timestamp DESC
timestamps = [a["timestamp"] for a in activities]
assert timestamps == sorted(timestamps, reverse=True)
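# For reference, a minimal sketch of the aggregation pattern these tests expect
# (an assumption for illustration, not the service's real query code): collect
# typed events from several sources, sort newest-first, then truncate.
def _sketch_merge_activities(*event_lists: list[dict], limit: int = 10) -> list[dict]:
    merged = [event for events in event_lists for event in events]
    merged.sort(key=lambda event: event["timestamp"], reverse=True)
    return merged[:limit]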

View File

@@ -0,0 +1,453 @@
"""
Dataset Builder Service Integration Tests
Tests DatasetBuilder with real file operations and repository interactions.
"""
import shutil
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4
import pytest
import yaml
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder
@pytest.fixture
def dataset_builder(patched_session, temp_dataset_dir):
"""Create a DatasetBuilder with real repositories."""
return DatasetBuilder(
datasets_repo=DatasetRepository(),
documents_repo=DocumentRepository(),
annotations_repo=AnnotationRepository(),
base_dir=temp_dataset_dir,
)
@pytest.fixture
def admin_images_dir(temp_upload_dir):
"""Create a directory for admin images."""
images_dir = temp_upload_dir / "admin_images"
images_dir.mkdir(parents=True, exist_ok=True)
return images_dir
@pytest.fixture
def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir):
"""Create documents with annotations and corresponding image files."""
documents = []
doc_repo = DocumentRepository()
ann_repo = AnnotationRepository()
for i in range(5):
# Create document
doc_id = doc_repo.create(
filename=f"invoice_{i}.pdf",
file_size=1024,
content_type="application/pdf",
file_path=f"/uploads/invoice_{i}.pdf",
page_count=2,
category="invoice",
group_key=f"group_{i % 2}", # Two groups
)
# Create image files for each page
doc_dir = admin_images_dir / doc_id
doc_dir.mkdir(parents=True, exist_ok=True)
for page in range(1, 3):
image_path = doc_dir / f"page_{page}.png"
# Create a minimal fake PNG
image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
# Create annotations
for j in range(3):
ann_repo.create(
document_id=doc_id,
page_number=1,
class_id=j,
class_name=f"field_{j}",
x_center=0.5,
y_center=0.1 + j * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + j * 160,
bbox_width=160,
bbox_height=40,
text_value=f"value_{j}",
confidence=0.95,
source="auto",
)
doc = doc_repo.get(doc_id)
documents.append(doc)
return documents
class TestDatasetBuilderBasic:
"""Tests for basic dataset building operations."""
def test_build_dataset_creates_directory_structure(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that building creates proper directory structure."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Test Dataset")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Check directory structure
assert (dataset_dir / "images" / "train").exists()
assert (dataset_dir / "images" / "val").exists()
assert (dataset_dir / "images" / "test").exists()
assert (dataset_dir / "labels" / "train").exists()
assert (dataset_dir / "labels" / "val").exists()
assert (dataset_dir / "labels" / "test").exists()
def test_build_dataset_copies_images(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that images are copied to dataset directory."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Image Copy Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
result = dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Count total images across all splits
total_images = 0
for split in ["train", "val", "test"]:
images = list((dataset_dir / "images" / split).glob("*.png"))
total_images += len(images)
# 5 docs * 2 pages = 10 images
assert total_images == 10
assert result["total_images"] == 10
def test_build_dataset_generates_labels(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that YOLO label files are generated."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Label Generation Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Count total label files
total_labels = 0
for split in ["train", "val", "test"]:
labels = list((dataset_dir / "labels" / split).glob("*.txt"))
total_labels += len(labels)
# Same count as images
assert total_labels == 10
def test_build_dataset_generates_data_yaml(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that data.yaml is generated correctly."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="YAML Generation Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
yaml_path = dataset_dir / "data.yaml"
assert yaml_path.exists()
with open(yaml_path) as f:
data = yaml.safe_load(f)
assert data["train"] == "images/train"
assert data["val"] == "images/val"
assert data["test"] == "images/test"
assert "nc" in data
assert "names" in data
class TestDatasetBuilderSplits:
"""Tests for train/val/test split assignment."""
def test_split_ratio_respected(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that split ratios are approximately respected."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Split Ratio Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.6,
val_ratio=0.2,
seed=42,
admin_images_dir=admin_images_dir,
)
# Check document assignments in database
dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
splits = {"train": 0, "val": 0, "test": 0}
for doc in dataset_docs:
splits[doc.split] += 1
# With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1
# Due to rounding and group constraints, allow some variation
assert splits["train"] >= 2
assert splits["val"] >= 1 or splits["test"] >= 1
def test_same_seed_same_split(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that same seed produces same split."""
dataset_repo = DatasetRepository()
doc_ids = [str(d.document_id) for d in documents_with_annotations]
# Build first dataset
dataset1 = dataset_repo.create(name="Seed Test 1")
dataset_builder.build_dataset(
dataset_id=str(dataset1.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=12345,
admin_images_dir=admin_images_dir,
)
# Build second dataset with same seed
dataset2 = dataset_repo.create(name="Seed Test 2")
dataset_builder.build_dataset(
dataset_id=str(dataset2.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=12345,
admin_images_dir=admin_images_dir,
)
# Compare splits
docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))}
docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))}
assert docs1 == docs2
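# For reference, a deterministic, group-aware split can be sketched as follows
# (an assumption for illustration; the builder's actual assignment logic may
# differ): shuffle the group keys with a seeded RNG so the same seed always
# yields the same document -> split mapping, and keep each group in one split.
def _sketch_assign_splits(
    group_keys: list[str], train_ratio: float, val_ratio: float, seed: int
) -> dict[str, str]:
    import random
    keys = sorted(set(group_keys))
    random.Random(seed).shuffle(keys)
    n_train = int(len(keys) * train_ratio)
    n_val = int(len(keys) * val_ratio)
    splits: dict[str, str] = {}
    for i, key in enumerate(keys):
        if i < n_train:
            splits[key] = "train"
        elif i < n_train + n_val:
            splits[key] = "val"
        else:
            splits[key] = "test"
    return splits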
class TestDatasetBuilderDatabase:
"""Tests for database interactions."""
def test_updates_dataset_status(
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
):
"""Test that dataset status is updated after build."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Status Update Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
updated = dataset_repo.get(str(dataset.dataset_id))
assert updated.status == "ready"
assert updated.total_documents == 5
assert updated.total_images == 10
assert updated.total_annotations > 0
assert updated.dataset_path is not None
def test_records_document_assignments(
self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session
):
"""Test that document assignments are recorded in database."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Assignment Recording Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id))
assert len(dataset_docs) == 5
for doc in dataset_docs:
assert doc.split in ["train", "val", "test"]
assert doc.page_count > 0
class TestDatasetBuilderErrors:
"""Tests for error handling."""
def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, patched_session):
"""Test that building fails with empty document list."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Empty Docs Test")
with pytest.raises(ValueError, match="No valid documents"):
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session):
"""Test that building fails with nonexistent document IDs."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Invalid IDs Test")
fake_ids = [str(uuid4()) for _ in range(3)]
with pytest.raises(ValueError, match="No valid documents"):
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=fake_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session):
"""Test that dataset status is set to failed on error."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Failure Status Test")
try:
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[],
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
except ValueError:
pass
updated = dataset_repo.get(str(dataset.dataset_id))
assert updated.status == "failed"
assert updated.error_message is not None
class TestLabelFileFormat:
"""Tests for YOLO label file format."""
def test_label_file_format(
self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session
):
"""Test that label files are in correct YOLO format."""
dataset_repo = DatasetRepository()
dataset = dataset_repo.create(name="Label Format Test")
doc_ids = [str(d.document_id) for d in documents_with_annotations]
dataset_builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=doc_ids,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
admin_images_dir=admin_images_dir,
)
dataset_dir = temp_dataset_dir / str(dataset.dataset_id)
# Find a label file with content
label_files = []
for split in ["train", "val", "test"]:
label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt")))
# Check at least one label file has correct format
found_valid_label = False
for label_file in label_files:
content = label_file.read_text().strip()
if content:
lines = content.split("\n")
for line in lines:
parts = line.split()
assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}"
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
assert 0 <= class_id < 10
assert 0 <= x_center <= 1
assert 0 <= y_center <= 1
assert 0 <= width <= 1
assert 0 <= height <= 1
found_valid_label = True
break
assert found_valid_label, "No valid label files found"
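# For reference, one YOLO label line is "class_id x_center y_center width height"
# with coordinates normalized to [0, 1], as validated above (sketch only):
def _sketch_yolo_label_line(
    class_id: int, x_center: float, y_center: float, width: float, height: float
) -> str:
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"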

View File

@@ -0,0 +1,283 @@
"""
Document Service Integration Tests
Tests DocumentService with real storage operations.
"""
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from inference.web.services.document_service import DocumentService, DocumentResult
class MockStorageBackend:
"""Simple in-memory storage backend for testing."""
def __init__(self):
self._files: dict[str, bytes] = {}
def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None:
if not overwrite and remote_path in self._files:
raise FileExistsError(f"File already exists: {remote_path}")
self._files[remote_path] = content
def download_bytes(self, remote_path: str) -> bytes:
if remote_path not in self._files:
raise FileNotFoundError(f"File not found: {remote_path}")
return self._files[remote_path]
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}"
def exists(self, remote_path: str) -> bool:
return remote_path in self._files
def delete(self, remote_path: str) -> bool:
if remote_path in self._files:
del self._files[remote_path]
return True
return False
def list_files(self, prefix: str) -> list[str]:
return [path for path in self._files.keys() if path.startswith(prefix)]
@pytest.fixture
def mock_storage():
"""Create a mock storage backend."""
return MockStorageBackend()
@pytest.fixture
def document_service(mock_storage):
"""Create a DocumentService with mock storage."""
return DocumentService(storage_backend=mock_storage)
class TestDocumentUpload:
"""Tests for document upload operations."""
def test_upload_document(self, document_service):
"""Test uploading a document."""
content = b"%PDF-1.4 test content"
filename = "test_invoice.pdf"
result = document_service.upload_document(content, filename)
assert result is not None
assert result.id is not None
assert result.filename == filename
assert result.file_path.startswith("documents/")
assert result.file_path.endswith(".pdf")
def test_upload_document_with_custom_id(self, document_service):
"""Test uploading with custom document ID."""
content = b"%PDF-1.4 test content"
filename = "invoice.pdf"
custom_id = "custom-doc-12345"
result = document_service.upload_document(
content, filename, document_id=custom_id
)
assert result.id == custom_id
assert custom_id in result.file_path
def test_upload_preserves_extension(self, document_service):
"""Test that file extension is preserved."""
cases = [
("document.pdf", ".pdf"),
("image.PNG", ".png"),
("file.JPEG", ".jpeg"),
("noextension", ""),
]
for filename, expected_ext in cases:
result = document_service.upload_document(b"content", filename)
if expected_ext:
assert result.file_path.endswith(expected_ext)
def test_upload_document_overwrite(self, document_service, mock_storage):
"""Test that upload overwrites existing file."""
content1 = b"original content"
content2 = b"new content"
doc_id = "overwrite-test"
document_service.upload_document(content1, "doc.pdf", document_id=doc_id)
document_service.upload_document(content2, "doc.pdf", document_id=doc_id)
# Should have new content
remote_path = f"documents/{doc_id}.pdf"
stored_content = mock_storage.download_bytes(remote_path)
assert stored_content == content2
class TestDocumentDownload:
"""Tests for document download operations."""
def test_download_document(self, document_service, mock_storage):
"""Test downloading a document."""
content = b"test document content"
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(content, remote_path)
downloaded = document_service.download_document(remote_path)
assert downloaded == content
def test_download_nonexistent_document(self, document_service):
"""Test downloading document that doesn't exist."""
with pytest.raises(FileNotFoundError):
document_service.download_document("documents/nonexistent.pdf")
class TestDocumentUrl:
"""Tests for document URL generation."""
def test_get_document_url(self, document_service, mock_storage):
"""Test getting presigned URL for document."""
remote_path = "documents/test-doc.pdf"
mock_storage.upload_bytes(b"content", remote_path)
url = document_service.get_document_url(remote_path, expires_in_seconds=7200)
assert url.startswith("https://")
assert remote_path in url
assert "7200" in url
def test_get_document_url_default_expiry(self, document_service):
"""Test default URL expiry."""
url = document_service.get_document_url("documents/doc.pdf")
assert "3600" in url
class TestDocumentExists:
"""Tests for document existence check."""
def test_document_exists(self, document_service, mock_storage):
"""Test checking if document exists."""
remote_path = "documents/existing.pdf"
mock_storage.upload_bytes(b"content", remote_path)
assert document_service.document_exists(remote_path) is True
def test_document_not_exists(self, document_service):
"""Test checking if nonexistent document exists."""
assert document_service.document_exists("documents/nonexistent.pdf") is False
class TestDocumentDelete:
"""Tests for document deletion."""
def test_delete_document(self, document_service, mock_storage):
"""Test deleting a document."""
remote_path = "documents/to-delete.pdf"
mock_storage.upload_bytes(b"content", remote_path)
result = document_service.delete_document_files(remote_path)
assert result is True
assert document_service.document_exists(remote_path) is False
def test_delete_nonexistent_document(self, document_service):
"""Test deleting document that doesn't exist."""
result = document_service.delete_document_files("documents/nonexistent.pdf")
assert result is False
class TestPageImages:
"""Tests for page image operations."""
def test_save_page_image(self, document_service, mock_storage):
"""Test saving a page image."""
doc_id = "test-doc-123"
page_num = 1
image_content = b"\x89PNG\r\n\x1a\n fake png"
remote_path = document_service.save_page_image(doc_id, page_num, image_content)
assert remote_path == f"images/{doc_id}/page_{page_num}.png"
assert mock_storage.exists(remote_path)
def test_save_multiple_page_images(self, document_service, mock_storage):
"""Test saving images for multiple pages."""
doc_id = "multi-page-doc"
for page_num in range(1, 4):
content = f"page {page_num} content".encode()
document_service.save_page_image(doc_id, page_num, content)
images = document_service.list_document_images(doc_id)
assert len(images) == 3
def test_get_page_image(self, document_service, mock_storage):
"""Test downloading a page image."""
doc_id = "test-doc"
page_num = 2
image_content = b"image data"
document_service.save_page_image(doc_id, page_num, image_content)
downloaded = document_service.get_page_image(doc_id, page_num)
assert downloaded == image_content
def test_get_page_image_url(self, document_service):
"""Test getting URL for page image."""
doc_id = "test-doc"
page_num = 1
url = document_service.get_page_image_url(doc_id, page_num)
assert f"images/{doc_id}/page_{page_num}.png" in url
def test_list_document_images(self, document_service, mock_storage):
"""Test listing all images for a document."""
doc_id = "list-test-doc"
for i in range(5):
document_service.save_page_image(doc_id, i + 1, f"page {i}".encode())
images = document_service.list_document_images(doc_id)
assert len(images) == 5
def test_delete_document_images(self, document_service, mock_storage):
"""Test deleting all images for a document."""
doc_id = "delete-images-doc"
for i in range(3):
document_service.save_page_image(doc_id, i + 1, b"content")
deleted_count = document_service.delete_document_images(doc_id)
assert deleted_count == 3
assert len(document_service.list_document_images(doc_id)) == 0
class TestRoundTrip:
"""Tests for complete upload-download cycles."""
def test_document_round_trip(self, document_service):
"""Test uploading and downloading document."""
original_content = b"%PDF-1.4 complete document content here"
filename = "roundtrip.pdf"
result = document_service.upload_document(original_content, filename)
downloaded = document_service.download_document(result.file_path)
assert downloaded == original_content
def test_image_round_trip(self, document_service):
"""Test saving and retrieving page image."""
doc_id = "roundtrip-doc"
page_num = 1
original_image = b"\x89PNG fake image data"
document_service.save_page_image(doc_id, page_num, original_image)
retrieved = document_service.get_page_image(doc_id, page_num)
assert retrieved == original_image

View File

@@ -0,0 +1,258 @@
"""
Database Setup Integration Tests
Tests for database connection, session management, and basic operations.
"""
import pytest
from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument, AdminToken
class TestDatabaseConnection:
"""Tests for database engine and connection."""
def test_engine_connection(self, test_engine):
"""Verify database engine can establish connection."""
with test_engine.connect() as conn:
result = conn.execute(select(1))
assert result.scalar() == 1
def test_tables_created(self, test_engine):
"""Verify all expected tables are created."""
from sqlmodel import SQLModel
table_names = SQLModel.metadata.tables.keys()
expected_tables = [
"admin_tokens",
"admin_documents",
"admin_annotations",
"training_tasks",
"training_logs",
"batch_uploads",
"batch_upload_files",
"training_datasets",
"dataset_documents",
"training_document_links",
"model_versions",
]
for table in expected_tables:
assert table in table_names, f"Table '{table}' not found"
class TestSessionManagement:
"""Tests for database session context manager."""
def test_session_commit(self, db_session):
"""Verify session commits changes successfully."""
token = AdminToken(
token="commit-test-token",
name="Commit Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.exec(
select(AdminToken).where(AdminToken.token == "commit-test-token")
).first()
assert result is not None
assert result.name == "Commit Test"
def test_session_rollback_on_error(self, test_engine):
"""Verify session rollback on exception."""
session = Session(test_engine)
try:
token = AdminToken(
token="rollback-test-token",
name="Rollback Test",
is_active=True,
)
session.add(token)
session.commit()
# Try to insert duplicate (should fail)
duplicate = AdminToken(
token="rollback-test-token", # Same primary key
name="Duplicate",
is_active=True,
)
session.add(duplicate)
session.commit()
except Exception:
session.rollback()
finally:
session.close()
# Verify original record exists
with Session(test_engine) as verify_session:
result = verify_session.exec(
select(AdminToken).where(AdminToken.token == "rollback-test-token")
).first()
assert result is not None
assert result.name == "Rollback Test"
def test_session_isolation(self, test_engine):
"""Verify sessions are isolated from each other."""
session1 = Session(test_engine)
session2 = Session(test_engine)
try:
# Insert in session1, don't commit
token = AdminToken(
token="isolation-test-token",
name="Isolation Test",
is_active=True,
)
session1.add(token)
session1.flush()
# Session2 should not see uncommitted data (with proper isolation)
# Note: SQLite in-memory may have different isolation behavior
session1.commit()
result = session2.exec(
select(AdminToken).where(AdminToken.token == "isolation-test-token")
).first()
# After commit, session2 should see the data
assert result is not None
finally:
session1.close()
session2.close()
class TestBasicCRUDOperations:
"""Tests for basic CRUD operations on database."""
def test_create_and_read_token(self, db_session):
"""Test creating and reading admin token."""
token = AdminToken(
token="crud-test-token",
name="CRUD Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.get(AdminToken, "crud-test-token")
assert result is not None
assert result.name == "CRUD Test"
assert result.is_active is True
def test_update_entity(self, db_session, admin_token):
"""Test updating an entity."""
admin_token.name = "Updated Name"
db_session.add(admin_token)
db_session.commit()
result = db_session.get(AdminToken, admin_token.token)
assert result is not None
assert result.name == "Updated Name"
def test_delete_entity(self, db_session):
"""Test deleting an entity."""
token = AdminToken(
token="delete-test-token",
name="Delete Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
db_session.delete(token)
db_session.commit()
result = db_session.get(AdminToken, "delete-test-token")
assert result is None
def test_foreign_key_constraint(self, db_session, admin_token):
"""Test foreign key constraints are enforced."""
from uuid import uuid4
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="fk_test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/test/fk_test.pdf",
page_count=1,
status="pending",
)
db_session.add(doc)
db_session.commit()
# Document should reference valid token
result = db_session.get(AdminDocument, doc.document_id)
assert result is not None
assert result.admin_token == admin_token.token
class TestQueryOperations:
"""Tests for various query operations."""
def test_select_with_filter(self, db_session, multiple_documents):
"""Test SELECT with WHERE clause."""
results = db_session.exec(
select(AdminDocument).where(AdminDocument.status == "labeled")
).all()
assert len(results) == 2
for doc in results:
assert doc.status == "labeled"
def test_select_with_order(self, db_session, multiple_documents):
"""Test SELECT with ORDER BY clause."""
results = db_session.exec(
select(AdminDocument).order_by(AdminDocument.file_size.desc())
).all()
file_sizes = [doc.file_size for doc in results]
assert file_sizes == sorted(file_sizes, reverse=True)
def test_select_with_limit_offset(self, db_session, multiple_documents):
"""Test SELECT with LIMIT and OFFSET."""
results = db_session.exec(
select(AdminDocument)
.order_by(AdminDocument.filename)
.offset(2)
.limit(2)
).all()
assert len(results) == 2
def test_count_query(self, db_session, multiple_documents):
"""Test COUNT aggregation."""
from sqlalchemy import func
count = db_session.exec(
select(func.count()).select_from(AdminDocument)
).one()
assert count == 5
def test_group_by_query(self, db_session, multiple_documents):
"""Test GROUP BY aggregation."""
from sqlalchemy import func
results = db_session.exec(
select(
AdminDocument.status,
func.count(AdminDocument.document_id).label("count"),
).group_by(AdminDocument.status)
).all()
status_counts = {row[0]: row[1] for row in results}
assert status_counts.get("pending") == 2
assert status_counts.get("labeled") == 2
assert status_counts.get("exported") == 1

View File

@@ -196,3 +196,121 @@ class TestAnnotationModel:
assert 0 <= ann.y_center <= 1
assert 0 <= ann.width <= 1
assert 0 <= ann.height <= 1
class TestAutoLabelFilePathResolution:
"""Tests for auto-label file path resolution.
The auto-label endpoint needs to resolve the storage path (e.g., "raw_pdfs/uuid.pdf")
to an actual filesystem path via the storage helper.
"""
def test_extracts_filename_from_storage_path(self):
"""Test that filename is extracted from storage path correctly."""
# Storage paths are like "raw_pdfs/uuid.pdf"
storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf"
# The annotation endpoint extracts filename
filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path
assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"
def test_handles_path_without_prefix(self):
"""Test that paths without prefix are handled."""
storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf"
filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path
assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"
def test_storage_helper_resolves_path(self):
"""Test that storage helper can resolve the path."""
from pathlib import Path
from unittest.mock import MagicMock, patch
# Mock storage helper
mock_storage = MagicMock()
mock_path = Path("/storage/raw_pdfs/test.pdf")
mock_storage.get_raw_pdf_local_path.return_value = mock_path
with patch(
"inference.web.services.storage_helpers.get_storage_helper",
return_value=mock_storage,
):
from inference.web.services.storage_helpers import get_storage_helper
storage = get_storage_helper()
result = storage.get_raw_pdf_local_path("test.pdf")
assert result == mock_path
mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf")
    def test_auto_label_request_validation(self):
        """Test AutoLabelRequest validates field_values."""
        # Local import, matching the sibling tests in this class
        # (assumed module path, consistent with the other admin schemas)
        from inference.web.schemas.admin import AutoLabelRequest
        # Valid request
request = AutoLabelRequest(
field_values={"InvoiceNumber": "12345"},
replace_existing=False,
)
assert request.field_values == {"InvoiceNumber": "12345"}
# Empty field_values should be valid at schema level
# (endpoint validates non-empty)
request_empty = AutoLabelRequest(
field_values={},
replace_existing=False,
)
assert request_empty.field_values == {}
class TestMatchClassAttributes:
"""Tests for Match class attributes used in auto-labeling.
The autolabel service uses Match objects from FieldMatcher.
Verifies the correct attribute names are used.
"""
def test_match_has_matched_text_attribute(self):
"""Test that Match class has matched_text attribute (not matched_value)."""
from shared.matcher.models import Match
# Create a Match object
match = Match(
field="invoice_number",
value="12345",
bbox=(100, 100, 200, 150),
page_no=0,
score=0.95,
matched_text="INV-12345",
context_keywords=["faktura", "nummer"],
)
# Verify matched_text exists (this is what autolabel.py should use)
assert hasattr(match, "matched_text")
assert match.matched_text == "INV-12345"
# Verify matched_value does NOT exist
# This was the bug - autolabel.py was using matched_value instead of matched_text
assert not hasattr(match, "matched_value")
def test_match_attributes_for_annotation_creation(self):
"""Test that Match has all attributes needed for annotation creation."""
from shared.matcher.models import Match
match = Match(
field="amount",
value="1000.00",
bbox=(50, 200, 150, 230),
page_no=0,
score=0.88,
matched_text="1 000,00",
context_keywords=["att betala", "summa"],
)
# These are all the attributes used in autolabel._create_annotations_from_matches
assert hasattr(match, "bbox")
assert hasattr(match, "matched_text") # NOT matched_value
assert hasattr(match, "score")
# Verify bbox format
assert len(match.bbox) == 4 # (x0, y0, x1, y1)

View File

@@ -3,7 +3,7 @@ Tests for Admin Authentication.
"""
import pytest
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
@@ -132,6 +132,47 @@ class TestTokenRepository:
with patch.object(repo, "_now", return_value=datetime.utcnow()):
assert repo.is_valid("test-token") is False
def test_is_valid_expired_token_timezone_aware(self):
"""Test expired token with timezone-aware datetime.
This verifies the fix for comparing timezone-aware and naive datetimes.
The auth API now creates tokens with timezone-aware expiration dates.
"""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
# Create token with timezone-aware expiration (as auth API now does)
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
)
mock_session.get.return_value = mock_token
repo = TokenRepository()
# _now() returns timezone-aware datetime, should compare correctly
assert repo.is_valid("test-token") is False
def test_is_valid_not_expired_token_timezone_aware(self):
"""Test non-expired token with timezone-aware datetime."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
# Create token with timezone-aware expiration in the future
mock_token = AdminToken(
token="test-token",
name="Test",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=1),
)
mock_session.get.return_value = mock_token
repo = TokenRepository()
assert repo.is_valid("test-token") is True
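    # The fix these two tests cover (assumed shape, for illustration): _now()
    # returns a timezone-aware datetime so comparisons against expires_at never
    # mix naive and aware values, e.g.:
    #
    #     def _now(self) -> datetime:
    #         return datetime.now(timezone.utc)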
def test_is_valid_token_not_found(self):
"""Test token not found."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:

View File

@@ -0,0 +1,317 @@
"""
Tests for Dashboard API Endpoints and Services.
Tests are split into:
1. Unit tests for business logic (is_annotation_complete, etc.)
2. Service tests with mocked database
3. Integration tests via TestClient (requires DB)
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
# Test data constants
TEST_DOC_UUID_1 = "550e8400-e29b-41d4-a716-446655440001"
TEST_MODEL_UUID = "660e8400-e29b-41d4-a716-446655440001"
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440001"
class TestAnnotationCompletenessLogic:
"""Unit tests for annotation completeness calculation logic.
These tests verify the core business logic:
- Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro)
- Incomplete: labeled but missing required fields
"""
def test_document_with_invoice_number_and_bankgiro_is_complete(self):
"""Document with invoice_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_ocr_number_and_plusgiro_is_complete(self):
"""Document with ocr_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
{"class_id": 5, "class_name": "plusgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_invoice_number_and_plusgiro_is_complete(self):
"""Document with invoice_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 5, "class_name": "plusgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_ocr_number_and_bankgiro_is_complete(self):
"""Document with ocr_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is True
def test_document_with_only_identifier_is_incomplete(self):
"""Document with only identifier field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_only_payment_is_incomplete(self):
"""Document with only payment field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 4, "class_name": "bankgiro"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_no_annotations_is_incomplete(self):
"""Document with no annotations should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
assert is_annotation_complete([]) is False
def test_document_with_other_fields_only_is_incomplete(self):
"""Document with only non-essential fields should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 1, "class_name": "invoice_date"},
{"class_id": 6, "class_name": "amount"},
]
assert is_annotation_complete(annotations) is False
def test_document_with_all_fields_is_complete(self):
"""Document with all fields should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
{"class_id": 1, "class_name": "invoice_date"},
{"class_id": 4, "class_name": "bankgiro"},
{"class_id": 6, "class_name": "amount"},
]
assert is_annotation_complete(annotations) is True
class TestDashboardStatsService:
"""Tests for DashboardStatsService with mocked database."""
@pytest.fixture
def mock_session(self):
"""Create a mock database session."""
session = MagicMock()
session.exec.return_value.one.return_value = 0
return session
def test_completeness_rate_calculation(self):
"""Test completeness rate is calculated correctly."""
# Direct calculation test
complete = 25
incomplete = 8
total_assessed = complete + incomplete
expected_rate = round(complete / total_assessed * 100, 2)
assert expected_rate == pytest.approx(75.76, rel=0.01)
def test_completeness_rate_zero_documents(self):
"""Test completeness rate is 0 when no documents."""
complete = 0
incomplete = 0
total_assessed = complete + incomplete
completeness_rate = (
round(complete / total_assessed * 100, 2)
if total_assessed > 0
else 0.0
)
assert completeness_rate == 0.0
class TestDashboardActivityService:
"""Tests for DashboardActivityService activity aggregation."""
    def test_activity_types(self):
        """Test the expected activity types are distinct and non-empty."""
        expected_types = [
            "document_uploaded",
            "annotation_modified",
            "training_completed",
            "training_failed",
            "model_activated",
        ]
        assert len(set(expected_types)) == len(expected_types)
        assert all(expected_types)
class TestDashboardSchemas:
"""Tests for Dashboard API schemas."""
def test_dashboard_stats_response_schema(self):
"""Test DashboardStatsResponse schema validation."""
from inference.web.schemas.admin import DashboardStatsResponse
response = DashboardStatsResponse(
total_documents=38,
annotation_complete=25,
annotation_incomplete=8,
pending=5,
completeness_rate=75.76,
)
assert response.total_documents == 38
assert response.annotation_complete == 25
assert response.annotation_incomplete == 8
assert response.pending == 5
assert response.completeness_rate == 75.76
def test_active_model_response_schema(self):
"""Test ActiveModelResponse schema with null model."""
from inference.web.schemas.admin import ActiveModelResponse
response = ActiveModelResponse(
model=None,
running_training=None,
)
assert response.model is None
assert response.running_training is None
def test_active_model_info_schema(self):
"""Test ActiveModelInfo schema validation."""
from inference.web.schemas.admin import ActiveModelInfo
model = ActiveModelInfo(
version_id=TEST_MODEL_UUID,
version="1.2.0",
name="Invoice Model",
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=500,
activated_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc),
)
assert model.version == "1.2.0"
assert model.name == "Invoice Model"
assert model.metrics_mAP == 0.951
def test_running_training_info_schema(self):
"""Test RunningTrainingInfo schema validation."""
from inference.web.schemas.admin import RunningTrainingInfo
task = RunningTrainingInfo(
task_id=TEST_TASK_UUID,
name="Run-2024-02",
status="running",
started_at=datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc),
progress=45,
)
assert task.name == "Run-2024-02"
assert task.status == "running"
assert task.progress == 45
def test_activity_item_schema(self):
"""Test ActivityItem schema validation."""
from inference.web.schemas.admin import ActivityItem
activity = ActivityItem(
type="model_activated",
description="Activated model v1.2.0",
timestamp=datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc),
metadata={"version_id": TEST_MODEL_UUID, "version": "1.2.0"},
)
assert activity.type == "model_activated"
assert activity.description == "Activated model v1.2.0"
assert activity.metadata["version"] == "1.2.0"
def test_recent_activity_response_schema(self):
"""Test RecentActivityResponse schema with empty activities."""
from inference.web.schemas.admin import RecentActivityResponse
response = RecentActivityResponse(activities=[])
assert response.activities == []
class TestDashboardRouterCreation:
"""Tests for dashboard router creation."""
def test_creates_router_with_expected_endpoints(self):
"""Test router is created with expected endpoint paths."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
paths = [route.path for route in router.routes]
assert any("/stats" in p for p in paths)
assert any("/active-model" in p for p in paths)
assert any("/activity" in p for p in paths)
def test_router_has_correct_prefix(self):
"""Test router has /admin/dashboard prefix."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
assert router.prefix == "/admin/dashboard"
def test_router_has_dashboard_tag(self):
"""Test router uses Dashboard tag."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
assert "Dashboard" in router.tags
class TestFieldClassIds:
"""Tests for field class ID constants."""
def test_identifier_class_ids(self):
"""Test identifier field class IDs."""
from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
# invoice_number = 0, ocr_number = 3
assert 0 in IDENTIFIER_CLASS_IDS
assert 3 in IDENTIFIER_CLASS_IDS
def test_payment_class_ids(self):
"""Test payment field class IDs."""
from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS
# bankgiro = 4, plusgiro = 5
assert 4 in PAYMENT_CLASS_IDS
assert 5 in PAYMENT_CLASS_IDS

View File

@@ -1,68 +0,0 @@
#!/usr/bin/env python3
"""Update test imports to use new structure."""
import re
from pathlib import Path
# Import mapping: old -> new
IMPORT_MAPPINGS = {
# Admin routes
r'from src\.web\.admin_routes import': 'from src.web.api.v1.admin.documents import',
r'from src\.web\.admin_annotation_routes import': 'from src.web.api.v1.admin.annotations import',
r'from src\.web\.admin_training_routes import': 'from src.web.api.v1.admin.training import',
# Auth and core
r'from src\.web\.admin_auth import': 'from src.web.core.auth import',
r'from src\.web\.admin_autolabel import': 'from src.web.services.autolabel import',
r'from src\.web\.admin_scheduler import': 'from src.web.core.scheduler import',
# Schemas
r'from src\.web\.admin_schemas import': 'from src.web.schemas.admin import',
r'from src\.web\.schemas import': 'from src.web.schemas.inference import',
# Services
r'from src\.web\.services import': 'from src.web.services.inference import',
r'from src\.web\.async_service import': 'from src.web.services.async_processing import',
r'from src\.web\.batch_upload_service import': 'from src.web.services.batch_upload import',
# Workers
r'from src\.web\.async_queue import': 'from src.web.workers.async_queue import',
r'from src\.web\.batch_queue import': 'from src.web.workers.batch_queue import',
# Routes
r'from src\.web\.routes import': 'from src.web.api.v1.routes import',
r'from src\.web\.async_routes import': 'from src.web.api.v1.async_api.routes import',
r'from src\.web\.batch_upload_routes import': 'from src.web.api.v1.batch.routes import',
}
def update_file(file_path: Path) -> bool:
"""Update imports in a single file."""
content = file_path.read_text(encoding='utf-8')
original_content = content
for old_pattern, new_import in IMPORT_MAPPINGS.items():
content = re.sub(old_pattern, new_import, content)
if content != original_content:
file_path.write_text(content, encoding='utf-8')
return True
return False
def main():
"""Update all test files."""
test_dir = Path('tests/web')
updated_files = []
for test_file in test_dir.glob('test_*.py'):
if update_file(test_file):
updated_files.append(test_file.name)
if updated_files:
print(f"✓ Updated {len(updated_files)} test files:")
for filename in sorted(updated_files):
print(f" - {filename}")
else:
print("No files needed updating")
if __name__ == '__main__':
main()