diff --git a/.coverage b/.coverage
index 3ce6985..94d8a0a 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/docs/Dashboard-UI-Prompts.md b/docs/Dashboard-UI-Prompts.md
new file mode 100644
index 0000000..bb32135
--- /dev/null
+++ b/docs/Dashboard-UI-Prompts.md
@@ -0,0 +1,99 @@
+# Dashboard Mockup Prompts
+
+> Visual style: modern minimalism, matching the existing Warm theme
+> Color scheme: Warm light palette (off-white background #FAFAF8, white cards, dark gray text #121212)
+> Target platform: Web (desktop)
+
+---
+
+## Current Color Scheme Reference
+
+| Purpose | Value | Notes |
+|------|--------|------|
+| Page background | #FAFAF8 | Warm off-white |
+| Card background | #FFFFFF | Pure white |
+| Border | #E6E4E1 | Light gray-brown |
+| Primary text | #121212 | Near black |
+| Secondary text | #6B6B6B | Mid gray |
+| Success state | #3E4A3A + green-500 | Dark olive green + bright green indicator dot |
+| Warning state | #4A4A3A + yellow-50 | Dark yellow-brown + light yellow background |
+| Info state | #3A3A3A + blue-50 | Dark gray + light blue background |
+
+---
+
+## Page 1: Dashboard Main View (Normal State)
+
+**Page description**: The post-login landing page, showing document statistics, data quality, active model status, and recent activity.
+
+**Prompt**:
+```
+A modern web application dashboard UI for a document annotation system, main overview page, warm minimalist design theme, page background color #FAFAF8 warm off-white, single column layout with header navigation at top, content area below with multiple sections, top section shows: 4 equal-width stat cards in a row on white #FFFFFF background with subtle border #E6E4E1, first card Total Documents (38) with gray file icon on #FAFAF8 background, second card Complete (25) with dark olive green checkmark icon on light green #dcfce7 background, third card Incomplete (8) with orange alert icon on light orange #fef3c7 background, fourth card Pending (5) with blue clock icon on light blue #dbeafe background, each card has icon top-left in rounded square and large bold number in #121212 with label below in #6B6B6B, cards have subtle shadow on hover, middle section has two-column layout (50%/50%): left panel white card titled DATA QUALITY in uppercase #6B6B6B with circular progress ring 120px showing 78% in center with green #22C55E filled portion and gray #E5E7EB remaining, percentage text 36px bold #121212 centered in ring, text Annotation Complete next to ring, stats list below showing Complete 25 and Incomplete 8 and Pending 5 with small colored dots, text button View Incomplete Docs in primary color at bottom, right panel white card titled ACTIVE MODEL showing v1.2.0 - Invoice Model as title in bold #121212, thin horizontal divider #E6E4E1 below, three-column metrics row displaying mAP 95.1% and Precision 94% and Recall 92% in 24px bold with 12px labels below in #6B6B6B, info rows showing Activated 2024-01-20 and Documents 500 in 14px, training progress section at bottom showing Run-2024-02 with horizontal progress bar, below panels is full-width white card RECENT ACTIVITY section with list of 6 activity items each 40px height showing icon on left and description text in #121212 and relative timestamp in #6B6B6B right aligned, activity icons: rocket in purple for model activation, checkmark in green for training complete, edit pencil in orange for annotation modified, file in blue for document uploaded, x in red for training failed, subtle hover background #F1F0ED on activity rows, bottom section is SYSTEM STATUS white card showing Backend API Online with bright green #22C55E dot and Database Connected with green dot and GPU Available with green dot, all text in #2A2A2A, Inter font family, rounded corners 8px on all cards, subtle card shadow, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style, dribbble quality
+```
+
+---
+
+## Page 2: Dashboard Empty State (No Active Model)
+
+**Page description**: The onboarding view shown right after deployment, or whenever no model has been trained yet.
+
+**Prompt**:
+```
+A modern web application dashboard UI for a document annotation system, empty state variation, warm minimalist design theme, page background #FAFAF8 warm off-white, single column layout with header navigation, top section shows: 4 stat cards on white background with #E6E4E1 border, all showing 0 values, Total Documents 0 with gray icon, Complete 0 with muted green, Incomplete 0 with muted orange, Pending 0 with muted blue, middle section two-column layout: left DATA QUALITY panel white card shows circular progress ring at 0% completely gray #E5E7EB with dashed outline style, large text 0% in #6B6B6B centered, text No data yet below in muted color, empty stats all showing 0, right ACTIVE MODEL panel white card shows empty state with large subtle model icon in center opacity 20%, text No Active Model as heading in #121212, subtext Train and activate a model to see stats here in #6B6B6B, primary button Go to Training at bottom, below panels RECENT ACTIVITY white card shows empty state with Activity icon centered at 20% opacity, text No recent activity in #121212, subtext Start by uploading documents or creating training jobs in #6B6B6B, bottom SYSTEM STATUS card showing all services online with green #22C55E dots, warm color palette throughout, Inter font, rounded corners 8px, subtle shadows, friendly and inviting empty state design, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
+```
+
+---
+
+## Page 3: Dashboard Training-in-Progress State
+
+**Page description**: While a model is training, the Active Model panel shows the training progress.
+
+**Prompt**:
+```
+A modern web application dashboard UI for a document annotation system, training in progress state, warm minimalist theme with #FAFAF8 background, header with navigation, top section: 4 white stat cards with #E6E4E1 borders showing Total Documents 38, Complete 25 with green icon on #dcfce7, Incomplete 8 with orange icon on #fef3c7, Pending 5 with blue icon on #dbeafe, middle section two-column layout: left DATA QUALITY white card with 78% progress ring in green #22C55E, stats list showing counts, right ACTIVE MODEL white card showing current model v1.1.0 in bold #121212 with metrics mAP 93.5% Precision 92% Recall 88% in grid, below a highlighted training section with subtle blue tint background #EFF6FF, pulsing blue dot indicator, text Training in Progress in #121212, task name Run-2024-02, horizontal progress bar 45% complete with blue #3B82F6 fill and gray #E5E7EB track, text Started 2 hours ago in #6B6B6B below, RECENT ACTIVITY white card below with latest item showing blue spinner icon and Training started Run-2024-02, other activities listed with appropriate icons, SYSTEM STATUS card at bottom showing GPU Available highlighted with green dot indicating active usage, warm color scheme throughout, Inter font, 8px rounded corners, subtle card shadows, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
+```
+
+---
+
+## Page 4: Dashboard Mobile Responsive View
+
+**Page description**: The single-column stacked layout on mobile (<768px).
+
+**Prompt**:
+```
+A modern mobile web application dashboard UI for a document annotation system, responsive mobile layout on smartphone screen, warm minimalist theme with #FAFAF8 background, single column stacked layout, top shows condensed header with hamburger menu icon and logo, below 2x2 grid of compact white stat cards with #E6E4E1 borders showing Total 38 Complete 25 Incomplete 8 Pending 5 with small colored icons on tinted backgrounds, DATA QUALITY section below as full-width white card with smaller progress ring 80px showing 78% in green #22C55E, horizontal stats row compact, ACTIVE MODEL section below as full-width white card with model name v1.2.0 in bold, compact metrics row showing mAP Precision Recall values, RECENT ACTIVITY section full-width white card with scrollable list of 4 visible items with icons and timestamps in #6B6B6B, compact SYSTEM STATUS bar at bottom with three green #22C55E status dots, warm color palette #FAFAF8 background white cards #121212 text, Inter font, touch-friendly tap targets 44px minimum, comfortable 16px padding, 8px rounded corners, iOS/Android native feel, UI/UX design, high fidelity mockup, mobile screen 375x812 iPhone size, professional, Figma style
+```
+
+---
+
+## Usage Notes
+
+1. Paste a prompt into an AI image tool (e.g. Midjourney, DALL-E, Stable Diffusion)
+2. Generate "Page 1: Main View" first to verify the style matches the existing design
+3. The prompts already include the current color scheme:
+   - Page background: #FAFAF8 (warm off-white)
+   - Card background: #FFFFFF (white)
+   - Border: #E6E4E1 (light gray-brown)
+   - Primary text: #121212 (near black)
+   - Secondary text: #6B6B6B (mid gray)
+   - Success: #22C55E (bright green) / #3E4A3A (dark olive green text)
+   - Icon backgrounds: #dcfce7 (light green) / #fef3c7 (light yellow) / #dbeafe (light blue)
+4. If the generated colors are off, adjust them afterwards in Figma
+
+---
+
+## Tailwind Class Reference (for development)
+
+```
+Background:             bg-warm-bg (#FAFAF8)
+Card:                   bg-warm-card (#FFFFFF)
+Border:                 border-warm-border (#E6E4E1)
+Primary text:           text-warm-text-primary (#121212)
+Secondary text:         text-warm-text-secondary (#2A2A2A)
+Muted text:             text-warm-text-muted (#6B6B6B)
+Hover background:       bg-warm-hover (#F1F0ED)
+Success state:          text-warm-state-success (#3E4A3A)
+Green icon background:  bg-green-50 (#dcfce7)
+Yellow icon background: bg-yellow-50 (#fef3c7)
+Blue icon background:   bg-blue-50 (#dbeafe)
+Green indicator dot:    bg-green-500 (#22C55E)
+```
diff --git a/docs/FORTNOX_INTEGRATION_SPEC.md b/docs/FORTNOX_INTEGRATION_SPEC.md
new file mode 100644
index 0000000..79dfc32
--- /dev/null
+++ b/docs/FORTNOX_INTEGRATION_SPEC.md
@@ -0,0 +1,1690 @@
+# Invoice Master - Fortnox Integration Technical Specification
+
+**Version**: v1.0
+**Date**: 2026-02-01
+**Author**: Claude Code
+**Status**: Design phase
+
+---
+
+## Table of Contents
+
+1. [Overview](#overview)
+2. [Integration Modes](#integration-modes)
+3. [System Architecture](#system-architecture)
+4. [Fortnox API Analysis](#fortnox-api-analysis)
+5. [Data Mapping Design](#data-mapping-design)
+6. [Core Modules](#core-modules)
+7. [User Flow Design](#user-flow-design)
+8. [UI Design Guidelines](#ui-design-guidelines)
+9. [API Design](#api-design)
+10. [Database Design](#database-design)
+11. [Security Design](#security-design)
+12. [Error Handling](#error-handling)
+13. [Development Plan](#development-plan)
+14. [Testing Strategy](#testing-strategy)
+15. [Deployment Plan](#deployment-plan)
+16. [Appendix](#appendix)
+
+---
+
+## Overview
+
+### 1.1 Background
+
+Invoice Master is an invoice field extraction system built on YOLOv11 + PaddleOCR, currently reaching 94.8% accuracy. This document describes packaging Invoice Master as a plugin/extension for the Fortnox accounting software, providing seamless import of invoice data.
+
+### 1.2 Goals
+
+- Give Fortnox users intelligent invoice recognition
+- Import invoice data into Fortnox with one click
+- Match suppliers and ledger accounts automatically
+- Cut manual data entry by 90%
+
+### 1.3 Scope
+
+**In scope:**
+- Fortnox OAuth2 authentication
+- Invoice PDF upload and OCR extraction
+- Automatic supplier matching/creation
+- Automatic voucher generation
+- Invoice image archiving
+
+**Out of scope (Phase 2):**
+- Additional document types (receipts, statements)
+- Automated payment flows
+- Approval workflows
+
+### 1.4 Terminology
+
+| Term | In Fortnox (Swedish) | Notes |
+|------|------|------|
+| Supplier | Leverantör | Vendor record in Fortnox |
+| Voucher | Verifikation | Accounting voucher in Fortnox |
+| Invoice | Faktura | |
+| Account | Konto | Account in the chart of accounts (kontoplan) |
+| OCR number | OCR-nummer | Swedish-specific payment reference number |
+
+---
+
+## Integration Modes
+
+### 2.1 Fortnox Extension UI Modes
+
+Fortnox offers **two main integration modes**; Invoice Master uses **Mode 1: external standalone app**.
+
+#### Mode 1: External standalone app (recommended)
+
+**Flow overview:**
+
+```
+User flow:
+┌─────────────────┐      ┌─────────────────────┐      ┌─────────────────┐
+│    Fortnox      │────▶│   Invoice Master    │────▶│    Fortnox      │
+│ (open the app)  │      │ (standalone web app)│      │ (data imported) │
+└─────────────────┘      └─────────────────────┘      └─────────────────┘
+        │                          │                          │
+        ▼                          ▼                          ▼
+ "Invoice Master" is      The user uploads and       Back in Fortnox, the
+ listed in Fortnox;       reviews invoices on        user sees the
+ clicking it opens        your site                  imported vouchers
+ a new tab
+```
+
+**Characteristics:**
+- ✅ **Has its own full UI** (a standalone site)
+- ✅ Connects to Fortnox via OAuth2
+- ✅ Clicking the entry in Fortnox redirects the user to your site
+- ✅ Data syncs both ways through the API
+- ✅ More flexibility in features and user experience
+
+**How it appears in Fortnox:**
+- Listed on the Fortnox Integrations page
+- Clicking it opens your site in a new tab
+- Shows the connection status and basic settings
+
+#### Mode 2: Embedded integration - limited support
+
+**Fortnox currently supports:**
+- Menu Links - add links to the Fortnox menu
+- Quick Actions - a limited set of contextual actions
+- File Import - via the Inbox API
+
+**Fortnox does not provide:**
+- ❌ Embedding third-party UI in an iframe
+- ❌ Custom pages/tabs
+- ❌ Deep UI customization
+
+### 2.2 Recommended Approach: Hybrid
+
+**Architecture:**
+
+```
+┌─────────────────────────────────────────────────────────────────┐
+│                  Invoice Master for Fortnox                      │
+│                                                                  │
+│  ┌──────────────────────────────────────────────────────────┐  │
+│  │          Standalone web app (your domain)                 │  │
+│  │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐    │  │
+│  │  │   Upload     │  │   Review     │  │   History    │    │  │
+│  │  │    page      │  │    page      │  │    page      │    │  │
+│  │  └──────────────┘  └──────────────┘  └──────────────┘    │  │
+│  │                                                            │  │
+│  │  Features: OCR extraction, supplier matching,              │  │
+│  │  preview/confirm, one-click import to Fortnox              │  │
+│  └──────────────────────────────────────────────────────────┘  │
+│                              │                                   │
+│                              │ HTTPS API                         │
+│                              ▼                                   │
+│  ┌──────────────────────────────────────────────────────────┐  │
+│  │              Fortnox Integration Service                  │  │
+│  │                     (Backend API)                         │  │
+│  └──────────────────────────────────────────────────────────┘  │
+│                              │                                   │
+└──────────────────────────────┼───────────────────────────────────┘
+                               │ OAuth2 + REST API
+                    ┌──────────▼──────────┐
+                    │       Fortnox       │
+                    │  (storage/display)  │
+                    └─────────────────────┘
+```
+
+### 2.3 Comparison with an API-only Approach
+
+| Aspect | Standalone UI (recommended) | API-only |
+|------|------------------|-----------|
+| **User experience** | ⭐⭐⭐⭐⭐ Full visual interface | ⭐⭐ Users must call the API themselves |
+| **Development effort** | ⭐⭐⭐ Frontend + backend | ⭐⭐ Backend only |
+| **Feature flexibility** | ⭐⭐⭐⭐⭐ OCR preview and editing | ⭐⭐ Direct import, no preview |
+| **User barrier** | ⭐⭐⭐⭐⭐ Low; usable by non-technical users | ⭐ High; developers only |
+| **Fortnox review** | ⭐⭐⭐⭐ Standard process | ⭐⭐⭐⭐ Simpler |
+
+### 2.4 End-to-End User Flow
+
+```
+1. Discovery
+   The user finds "Invoice Master" on the Fortnox Integrations page
+
+2. Authorization
+   The user clicks "Connect" → OAuth2 consent → redirected to Invoice Master
+
+3. Usage (on the Invoice Master site)
+   Upload PDF → OCR extraction → confirm/edit → import to Fortnox
+
+4. Review (back in Fortnox)
+   The user views the imported vouchers and invoices in Fortnox
+```
+
+---
+
+## System Architecture
+
+### 3.1 Overall Architecture
+
+**Note: this is the technical architecture for the "standalone web app" mode described in chapter 2.**
+
+```
+┌─────────────────────────────────────────────────────────────────┐
+│                        Fortnox Platform                          │
+│   ┌──────────────┐    ┌──────────────┐    ┌──────────────┐      │
+│   │   Fortnox    │    │   Fortnox    │    │   Fortnox    │      │
+│   │      UI      │    │     API      │    │   Database   │      │
+│   └──────┬───────┘    └──────┬───────┘    └──────────────┘      │
+└──────────┼───────────────────┼───────────────────────────────────┘
+           │                   │
+           │ OAuth2            │ HTTPS
+           │                   │
+┌──────────▼───────────────────▼───────────────────────────────────┐
+│                   Invoice Master Integration                      │
+│  ┌──────────────────────────────────────────────────────────┐   │
+│  │              Fortnox Integration Service                  │   │
+│  │  ┌────────────┐   ┌────────────┐   ┌────────────┐        │   │
+│  │  │    Auth    │   │  Invoice   │   │  Supplier  │        │   │
+│  │  │   Module   │   │  Handler   │   │  Matcher   │        │   │
+│  │  └────────────┘   └────────────┘   └────────────┘        │   │
+│  │  ┌────────────┐   ┌────────────┐   ┌────────────┐        │   │
+│  │  │  Voucher   │   │    File    │   │  Webhook   │        │   │
+│  │  │  Creator   │   │  Storage   │   │  Handler   │        │   │
+│  │  └────────────┘   └────────────┘   └────────────┘        │   │
+│  └──────────────────────────────────────────────────────────┘   │
+│                                                                   │
+│  ┌──────────────────────────────────────────────────────────┐   │
+│  │              Invoice Master Core Services                 │   │
+│  │  ┌────────────┐   ┌────────────┐   ┌────────────┐        │   │
+│  │  │    OCR     │   │    YOLO    │   │   Field    │        │   │
+│  │  │   Engine   │   │  Detector  │   │ Normalizer │        │   │
+│  │  └────────────┘   └────────────┘   └────────────┘        │   │
+│  └──────────────────────────────────────────────────────────┘   │
+└───────────┬───────────────────────────────────────────────────────┘
+            │
+            │ PostgreSQL / Azure Blob
+            │
+┌───────────▼─────────────────────────────────────────────────────┐
+│                          Data Storage                            │
+│  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐         │
+│  │   Invoice    │   │   Fortnox    │   │     File     │         │
+│  │     Data     │   │    Tokens    │   │   Storage    │         │
+│  └──────────────┘   └──────────────┘   └──────────────┘         │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+### 3.2 Components
+
+| Component | Stack | Responsibility |
+|------|--------|------|
+| **Integration Service** | FastAPI + Python | Fortnox API interaction, business logic |
+| **Auth Module** | OAuth2 + JWT | Fortnox authentication, token management |
+| **Invoice Handler** | - | Orchestrates the invoice processing flow |
+| **Supplier Matcher** | Fuzzy matching | Supplier matching algorithm |
+| **Voucher Creator** | - | Builds Fortnox accounting vouchers |
+| **File Storage** | Azure Blob / S3 | Invoice PDF storage |
+| **Webhook Handler** | - | Receives Fortnox events |
+
+### 3.3 Technology Stack
+
+| Layer | Technology | Notes |
+|------|------|------|
+| **Backend** | FastAPI + Python 3.11 | API service |
+| **Database** | PostgreSQL 15 | Relational data |
+| **Cache** | Redis | Token cache, rate limiting |
+| **Storage** | Azure Blob Storage | File storage |
+| **Message Queue** | Redis Queue | Async tasks |
+| **Monitoring** | Prometheus + Grafana | Metrics and alerting |
+
+---
+
+## Fortnox API Analysis
+
+### 4.1 Authentication
+
+Fortnox uses the OAuth2 authorization-code flow:
+
+```
+┌─────────┐                                      ┌─────────────┐
+│  User   │──(1) Authorization Request─────────▶│   Fortnox   │
+│         │◀──(2) Authorization Code────────────│   OAuth2    │
+│         │                                      │   Server    │
+│         │──(3) Token Request─────────────────▶│             │
+│         │◀──(4) Access + Refresh Token────────│             │
+└─────────┘                                      └─────────────┘
+```
+
+**Key endpoints:**
+```
+Authorization URL: https://apps.fortnox.se/oauth-v1/auth
+Token URL:         https://apps.fortnox.se/oauth-v1/token
+API Base URL:      https://api.fortnox.se/3
+```
+
+**Required scopes:**
+```
+supplier           - supplier management
+invoice            - invoices (if needed)
+voucher            - accounting vouchers
+account            - chart of accounts
+companyinformation - company information
+```
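+
+As a minimal sketch (not part of the spec's codebase), refreshing an expired access token against the token endpoint above could look like the following; the HTTP Basic client credentials mirror the auth manager in section 6.1, everything else is illustrative:
+
+```python
+import httpx
+
+FORTNOX_TOKEN_URL = "https://apps.fortnox.se/oauth-v1/token"
+
+async def refresh_access_token(client_id: str, client_secret: str,
+                               refresh_token: str) -> dict:
+    """Exchange a refresh token for a new access token (OAuth2 refresh grant)."""
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            FORTNOX_TOKEN_URL,
+            auth=(client_id, client_secret),  # HTTP Basic client credentials
+            data={
+                "grant_type": "refresh_token",
+                "refresh_token": refresh_token,
+            },
+        )
+    response.raise_for_status()
+    return response.json()  # access_token, refresh_token, expires_in, ...
+```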
"ABC Company", + "SupplierNumber": "123", + "OrganisationNumber": "556677-8899" + } + ] +} + +# 创建供应商 +POST /3/suppliers +Body: { + "Supplier": { + "Name": "New Supplier", + "OrganisationNumber": "112233-4455" + } +} +``` + +#### 3.2.2 会计凭证 (Voucher) + +```http +# 创建会计凭证 +POST /3/vouchers +Body: { + "Voucher": { + "VoucherSeries": "A", // 凭证系列 + "TransactionDate": "2024-01-15", // 交易日期 + "VoucherRows": [ + { + "Account": 2440, // 应付账款科目 + "Debit": 1250.00, + "Credit": 0, + "Description": "Invoice F2024-001" + }, + { + "Account": 5460, // 费用科目 + "Debit": 0, + "Credit": 1000.00, + "Description": "Office supplies" + }, + { + "Account": 2610, // 增值税科目 + "Debit": 0, + "Credit": 250.00, + "Description": "VAT 25%" + } + ] + } +} +``` + +#### 3.2.3 文件上传 + +```http +# 上传附件到Fortnox +POST /3/inbox +Content-Type: multipart/form-data +Body: { + "file": [PDF file], + "name": "Invoice_F2024_001.pdf" +} +``` + +### 3.3 API限制 + +| 限制类型 | 值 | 说明 | +|---------|-----|------| +| 速率限制 | 300请求/分钟 | 超出返回429 | +| 并发连接 | 10 | 同时连接数 | +| Token有效期 | 3600秒 | 需使用Refresh Token | +| 文件大小 | 10MB | 单个文件限制 | + +--- + +## 数据映射设计 + +### 4.1 发票字段映射 + +**Invoice Master提取字段 → Fortnox字段** + +| Invoice Master | Fortnox | 类型 | 必填 | 转换逻辑 | +|---------------|---------|------|------|----------| +| `invoice_number` | `ExternalInvoiceNumber` | string | 是 | 直接映射 | +| `invoice_date` | `TransactionDate` | date | 是 | ISO 8601格式 | +| `due_date` | `DueDate` | date | 否 | 计算或提取 | +| `supplier_name` | `SupplierName` | string | 是 | 匹配或创建 | +| `supplier_org_number` | `SupplierOrganisationNumber` | string | 否 | 用于匹配 | +| `amount_total` | `TotalAmount` | decimal | 是 | 直接映射 | +| `amount_vat` | `VatAmount` | decimal | 否 | 计算得出 | +| `ocr_number` | `OCRNumber` | string | 否 | 瑞典特有 | +| `bankgiro` | `BankgiroNumber` | string | 否 | 付款信息 | +| `plusgiro` | `PlusgiroNumber` | string | 否 | 付款信息 | +| `currency` | `Currency` | string | 是 | 默认SEK | + +### 4.2 会计科目映射 + +**默认科目映射表 (Kontoplan BAS2024)** + +| 费用类型 | 科目代码 | 科目名称 | 说明 | +|---------|---------|---------|------| +| 应付账款 | 2440 | Leverantörsskulder | 默认贷方 | +| 办公用品 | 5460 | Kontorsmaterial | 常见费用 | +| 咨询服务 | 6210 | Konsultarvoden | 外部服务 | +| 运输费 | 5710 | Frakter | 物流费用 | +| 增值税进项 | 2610 | Ingående moms | 25% VAT | +| 增值税进项12% | 2620 | Ingående moms 12% | 食品等 | +| 增值税进项6% | 2630 | Ingående moms 6% | 交通等 | + +**科目选择逻辑:** +```python +def select_account(invoice_data: dict) -> int: + """根据发票内容选择会计科目""" + + # 1. 检查是否有历史映射 + if invoice_data['supplier_org_number']: + historical = get_historical_account(invoice_data['supplier_org_number']) + if historical: + return historical + + # 2. 关键词匹配 + description = invoice_data.get('description', '').lower() + if any(word in description for word in ['kontor', 'papper', 'penna']): + return 5460 # 办公用品 + elif any(word in description for word in ['konsult', 'tjänst']): + return 6210 # 咨询服务 + elif any(word in description for word in ['frakt', 'transport']): + return 5710 # 运输费 + + # 3. 默认科目 + return 6100 # 其他外部费用 +``` + +### 4.3 供应商匹配算法 + +**匹配优先级:** + +```python +class SupplierMatcher: + def match_supplier(self, extracted_data: dict) -> MatchResult: + """ + 供应商匹配算法 + 返回: (supplier_number, confidence_score, action) + """ + + # 1. 组织号精确匹配 (权重: 100%) + if extracted_data.get('supplier_org_number'): + exact_match = self.find_by_org_number( + extracted_data['supplier_org_number'] + ) + if exact_match: + return MatchResult( + supplier_number=exact_match.number, + confidence=1.0, + action='USE_EXISTING' + ) + + # 2. 
+
+### 5.3 Supplier Matching Algorithm
+
+**Matching priority:**
+
+```python
+class SupplierMatcher:
+    def match_supplier(self, extracted_data: dict) -> MatchResult:
+        """
+        Supplier matching algorithm.
+        Returns: (supplier_number, confidence_score, action)
+        """
+
+        # 1. Exact match on organisation number (weight: 100%)
+        if extracted_data.get('supplier_org_number'):
+            exact_match = self.find_by_org_number(
+                extracted_data['supplier_org_number']
+            )
+            if exact_match:
+                return MatchResult(
+                    supplier_number=exact_match.number,
+                    confidence=1.0,
+                    action='USE_EXISTING'
+                )
+
+        # 2. Fuzzy match on the name (weight: 80%)
+        name_matches = self.fuzzy_match_name(
+            extracted_data['supplier_name'],
+            threshold=0.85
+        )
+        if name_matches and name_matches[0].score > 0.9:
+            return MatchResult(
+                supplier_number=name_matches[0].number,
+                confidence=name_matches[0].score,
+                action='USE_EXISTING'
+            )
+
+        # 3. Suggest creating a new supplier (weight: <80%)
+        return MatchResult(
+            supplier_number=None,
+            confidence=0.0,
+            action='CREATE_NEW',
+            suggested_name=extracted_data['supplier_name']
+        )
+```
+
+---
+
+## Core Modules
+
+### 6.1 Auth Module
+
+**Responsibilities:**
+- Manage the Fortnox OAuth2 flow
+- Store and refresh tokens
+- Multi-tenant isolation
+
+**Core class:**
+
+```python
+class FortnoxAuthManager:
+    """Fortnox authentication manager."""
+
+    def __init__(self, client_id: str, client_secret: str):
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.token_store = TokenStore()
+
+    def get_authorization_url(self, state: str) -> str:
+        """Build the Fortnox authorization URL."""
+        params = {
+            'client_id': self.client_id,
+            'redirect_uri': settings.FORTNOX_REDIRECT_URI,
+            'scope': 'supplier voucher account companyinformation',
+            'state': state,
+            'response_type': 'code'
+        }
+        return f"{FORTNOX_AUTH_URL}?{urlencode(params)}"
+
+    async def exchange_code_for_token(self, code: str) -> FortnoxToken:
+        """Exchange the authorization code for tokens."""
+        # httpx.post is synchronous; use an AsyncClient in async code
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                FORTNOX_TOKEN_URL,
+                auth=(self.client_id, self.client_secret),
+                data={
+                    'grant_type': 'authorization_code',
+                    'code': code,
+                    'redirect_uri': settings.FORTNOX_REDIRECT_URI
+                }
+            )
+        token_data = response.json()
+
+        return FortnoxToken(
+            access_token=token_data['access_token'],
+            refresh_token=token_data['refresh_token'],
+            expires_at=datetime.utcnow() + timedelta(seconds=token_data['expires_in']),
+            scope=token_data['scope']
+        )
+
+    async def get_valid_access_token(self, tenant_id: str) -> str:
+        """Return a valid access token, refreshing it if needed."""
+        token = await self.token_store.get(tenant_id)
+
+        if token.is_expired():
+            token = await self.refresh_token(token.refresh_token)
+            await self.token_store.save(tenant_id, token)
+
+        return token.access_token
+```
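+
+As a hedged sketch of how this manager could be wired into the auth endpoints described in the API Design chapter; `auth_manager` and `save_connection` are assumed to exist elsewhere, and error handling is omitted:
+
+```python
+import secrets
+from fastapi import APIRouter
+
+router = APIRouter(prefix="/api/v1/fortnox/auth")
+
+@router.get("/url")
+async def get_auth_url():
+    state = secrets.token_urlsafe(16)  # CSRF protection via the state parameter
+    return {
+        "authorization_url": auth_manager.get_authorization_url(state),
+        "state": state,
+    }
+
+@router.get("/callback")
+async def callback(code: str, state: str):
+    token = await auth_manager.exchange_code_for_token(code)
+    await save_connection(state=state, token=token)  # persist per tenant
+    return {"status": "success"}
+```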
+
+### 6.2 Invoice Handler
+
+**Processing flow:**
+
+```python
+class InvoiceProcessingService:
+    """Invoice processing service."""
+
+    async def process_invoice(
+        self,
+        tenant_id: str,
+        pdf_file: UploadFile,
+        settings: ProcessingSettings
+    ) -> ProcessingResult:
+        """Main invoice processing pipeline."""
+
+        # 1. Persist the PDF file
+        file_path = await self.file_storage.save(pdf_file)
+
+        # 2. OCR extraction
+        extraction_result = await self.ocr_service.extract(file_path)
+
+        # 3. Validate the extraction
+        if not self.validate_extraction(extraction_result):
+            return ProcessingResult(
+                status='FAILED',
+                error='Extraction validation failed'
+            )
+
+        # 4. Supplier matching
+        supplier_match = await self.supplier_matcher.match(
+            tenant_id,
+            extraction_result
+        )
+
+        # 5. Create or reuse the supplier
+        if supplier_match.action == 'CREATE_NEW':
+            supplier_number = await self.create_supplier(
+                tenant_id,
+                extraction_result
+            )
+        else:
+            supplier_number = supplier_match.supplier_number
+
+        # 6. Create the accounting voucher
+        voucher = await self.voucher_creator.create(
+            tenant_id,
+            extraction_result,
+            supplier_number,
+            settings
+        )
+
+        # 7. Attach the PDF
+        if settings.attach_pdf:
+            await self.attach_invoice_pdf(tenant_id, voucher.id, file_path)
+
+        return ProcessingResult(
+            status='SUCCESS',
+            extraction=extraction_result,
+            supplier_number=supplier_number,
+            voucher_id=voucher.id,
+            confidence=supplier_match.confidence
+        )
+```
+
+### 6.3 Supplier Matcher
+
+```python
+class FortnoxSupplierMatcher:
+    """Matches extracted suppliers against Fortnox suppliers."""
+
+    def __init__(self, fortnox_client: FortnoxClient):
+        self.client = fortnox_client
+        self.cache = SupplierCache()
+
+    async def match(
+        self,
+        tenant_id: str,
+        extraction: ExtractionResult
+    ) -> SupplierMatchResult:
+        """Match a supplier."""
+
+        # Fetch all suppliers (cached)
+        suppliers = await self.cache.get_suppliers(tenant_id)
+
+        # 1. Exact match on organisation number
+        if extraction.supplier_org_number:
+            match = self._match_by_org_number(
+                suppliers,
+                extraction.supplier_org_number
+            )
+            if match:
+                return SupplierMatchResult(
+                    supplier_number=match['SupplierNumber'],
+                    confidence=1.0,
+                    action='USE_EXISTING'
+                )
+
+        # 2. Fuzzy match on the name
+        name_match = self._fuzzy_match_name(
+            suppliers,
+            extraction.supplier_name
+        )
+
+        if name_match and name_match['score'] > 0.9:
+            return SupplierMatchResult(
+                supplier_number=name_match['supplier']['SupplierNumber'],
+                confidence=name_match['score'],
+                action='USE_EXISTING'
+            )
+        elif name_match and name_match['score'] > 0.7:
+            return SupplierMatchResult(
+                supplier_number=name_match['supplier']['SupplierNumber'],
+                confidence=name_match['score'],
+                action='SUGGEST_MATCH',
+                suggested_name=extraction.supplier_name
+            )
+
+        # 3. Suggest creating a new supplier
+        return SupplierMatchResult(
+            supplier_number=None,
+            confidence=0.0,
+            action='CREATE_NEW',
+            suggested_name=extraction.supplier_name,
+            suggested_org_number=extraction.supplier_org_number
+        )
+
+    async def create_supplier(
+        self,
+        tenant_id: str,
+        extraction: ExtractionResult
+    ) -> str:
+        """Create a new supplier in Fortnox."""
+
+        supplier_data = {
+            'Supplier': {
+                'Name': extraction.supplier_name,
+                'OrganisationNumber': extraction.supplier_org_number,
+                'Address1': extraction.supplier_address,
+                'Phone': extraction.supplier_phone,
+                'Email': extraction.supplier_email,
+                'BankgiroNumber': extraction.bankgiro,
+                'PlusgiroNumber': extraction.plusgiro
+            }
+        }
+
+        response = await self.client.post(
+            tenant_id,
+            '/3/suppliers',
+            json=supplier_data
+        )
+
+        # Invalidate the cache
+        await self.cache.invalidate(tenant_id)
+
+        return response['Supplier']['SupplierNumber']
+```
+
+### 6.4 Voucher Creator
+
+```python
+class FortnoxVoucherCreator:
+    """Builds Fortnox accounting vouchers."""
+
+    async def create_voucher(
+        self,
+        tenant_id: str,
+        extraction: ExtractionResult,
+        supplier_number: str,
+        settings: VoucherSettings
+    ) -> VoucherResult:
+        """Create an accounting voucher."""
+
+        # Pick the ledger account
+        account = await self.select_account(extraction)
+
+        # Compute VAT
+        vat_amount = self.calculate_vat(
+            extraction.amount_total,
+            extraction.vat_rate or 25
+        )
+        amount_excl_vat = extraction.amount_total - vat_amount
+
+        # Build the voucher rows
+        voucher_rows = [
+            # Debit: expense account
+            {
+                'Account': account,
+                'Debit': amount_excl_vat,
+                'Credit': 0,
+                'Description': f"{extraction.supplier_name} - {extraction.invoice_number}",
+                'Project': settings.project_code
+            },
+            # Debit: input VAT
+            {
+                'Account': self.get_vat_account(extraction.vat_rate),
+                'Debit': vat_amount,
+                'Credit': 0,
+                'Description': f"Moms {extraction.vat_rate}%"
+            },
+            # Credit: accounts payable
+            {
+                'Account': 2440,  # Leverantörsskulder
+                'Debit': 0,
+                'Credit': extraction.amount_total,
+                'Description': f"Faktura {extraction.invoice_number}",
+                'SupplierNumber': supplier_number,
+                'OCRNumber': extraction.ocr_number
+            }
+        ]
+
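+        # Added sanity-check sketch (not in the original spec), assuming
+        # calculate_vat takes the VAT share out of the VAT-inclusive total
+        # (e.g. 1250.00 at 25% -> 1250 * 25 / 125 = 250.00, net 1000.00):
+        # the debits (1000.00 + 250.00) must equal the 1250.00 credit.
+        assert sum(r['Debit'] for r in voucher_rows) == \
+            sum(r['Credit'] for r in voucher_rows), "voucher rows must balance"
+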
+        voucher_data = {
+            'Voucher': {
+                'VoucherSeries': settings.voucher_series or 'A',
+                'TransactionDate': extraction.invoice_date.isoformat(),
+                'VoucherText': f"Inköp {extraction.supplier_name}",
+                'VoucherRows': voucher_rows
+            }
+        }
+
+        response = await self.client.post(
+            tenant_id,
+            '/3/vouchers',
+            json=voucher_data
+        )
+
+        return VoucherResult(
+            voucher_id=response['Voucher']['VoucherNumber'],
+            series=response['Voucher']['VoucherSeries'],
+            url=response['Voucher']['@url']
+        )
+```
+
+---
+
+## User Flow Design
+
+### 7.1 Integration Entry Points (inside Fortnox)
+
+**Listing on the Fortnox Integrations page:**
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│  Invoice Master - Smart Invoice OCR                          │
+│                                                              │
+│  📄 Automatic invoice data extraction                        │
+│  🤖 AI-powered OCR                                           │
+│  ⚡ One-click import into Fortnox                            │
+│                                                              │
+│  [Connect/Open]                                              │
+└─────────────────────────────────────────────────────────────┘
+```
+
+**Configuration page inside Fortnox:**
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│  Invoice Master Settings                                     │
+│                                                              │
+│  Status:  ✅ Connected                                       │
+│  Company: My Company AB                                      │
+│                                                              │
+│  Defaults:                                                   │
+│  - Voucher series: [A ▼]                                     │
+│  - Auto import:    [✓]                                       │
+│  - Attach PDFs:    [✓]                                       │
+│                                                              │
+│  [Save settings]  [Disconnect]                               │
+└─────────────────────────────────────────────────────────────┘
+```
+
+### 7.2 First-Time Setup Flow
+
+```
+The user clicks "Connect Fortnox"
+          │
+          ▼
+┌───────────────────┐
+│ Redirect to the   │
+│ Fortnox consent   │
+│ page (OAuth2)     │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ The user logs in  │
+│ to Fortnox and    │
+│ grants access     │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Back to the       │
+│ Invoice Master    │
+│ callback page     │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Fetch company     │
+│ info, verify the  │
+│ connection        │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Configure defaults│
+│ - voucher series  │
+│ - default account │
+│ - file storage    │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Done! Show the    │
+│ invoice upload    │
+│ screen            │
+└───────────────────┘
+```
+
+### 7.3 Invoice Processing Flow
+
+```
+The user uploads a PDF invoice
+          │
+          ▼
+┌───────────────────┐
+│ Show progress     │
+│ - extracting...   │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Show the result   │
+│ for confirm/edit  │
+│ ┌───────────────┐ │
+│ │ Supplier: XXX │ │
+│ │ Amount: 1,250 │ │
+│ │ Date: 2024... │ │
+│ │ [Edit][Confirm]│ │
+│ └───────────────┘ │
+└─────────┬─────────┘
+     ┌────┴────┐
+     ▼         ▼
+┌────────┐ ┌────────┐
+│  Edit  │ │Confirm │
+│  data  │ │ import │
+└───┬────┘ └───┬────┘
+    └─────┬────┘
+          ▼
+┌───────────────────┐
+│ Supplier matching │
+│ - find existing   │
+│ - or create new   │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Create a voucher, │
+│ push to Fortnox   │
+└─────────┬─────────┘
+          ▼
+┌───────────────────┐
+│ Show success and  │
+│ a Fortnox link to │
+│ view the voucher  │
+└───────────────────┘
+```
+
+### 7.4 Standalone Web App UI
+
+**Important: the screens below belong to the standalone Invoice Master web app, which opens when the user clicks "Open" in Fortnox.**
+
+#### Main screen
│ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 📤 上传发票 │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ 拖放PDF文件到这里 │ │ +│ │ 或点击选择文件 │ │ +│ │ │ │ +│ └───────────────────────────────────────────────────────┘ │ +│ │ +│ 📋 最近处理的发票 │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ 文件名 │ 供应商 │ 金额 │ 状态 │ │ +│ ├───────────────────────────────────────────────────────┤ │ +│ │ INV001.pdf │ ABC Company │ 1,250 │ ✅ 已导入 │ │ +│ │ INV002.pdf │ XYZ AB │ 3,450 │ ✅ 已导入 │ │ +│ │ INV003.pdf │ (未匹配) │ 890 │ ⚠️ 待确认 │ │ +│ └───────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +#### 结果确认界面 + +**说明: 此界面在Invoice Master独立Web应用中显示,用于用户确认OCR识别结果。 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 确认发票信息 [✕] [✓] │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 📄 Invoice_F2024_001.pdf │ +│ │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ 供应商信息 │ │ +│ │ ┌─────────────────────────────────────────────────┐ │ │ +│ │ │ 名称: ABC Company │ │ │ +│ │ │ 组织号: 556677-8899 │ │ │ +│ │ │ 状态: ✅ 已匹配现有供应商 │ │ │ +│ │ └─────────────────────────────────────────────────┘ │ │ +│ └───────────────────────────────────────────────────────┘ │ +│ │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ 发票信息 │ │ +│ │ ┌─────────────────────────────────────────────────┐ │ │ +│ │ │ 发票号: F2024-001 │ │ │ +│ │ │ 日期: 2024-01-15 │ │ │ +│ │ │ 到期日: 2024-02-15 │ │ │ +│ │ │ 金额: 1,250.00 SEK │ │ │ +│ │ │ OCR: 7350012345678 │ │ │ +│ │ └─────────────────────────────────────────────────┘ │ │ +│ └───────────────────────────────────────────────────────┘ │ +│ │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ 会计科目 │ │ +│ │ 借方: 5460 - Kontorsmaterial 1,000.00 │ │ +│ │ 借方: 2610 - Ingående moms 250.00 │ │ +│ │ 贷方: 2440 - Leverantörsskulder 1,250.00 │ │ +│ └───────────────────────────────────────────────────────┘ │ +│ │ +│ [编辑信息] [重新识别] [取消] [确认并导入到Fortnox] │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## UI设计规范 + +### 7.1 设计原则 + +**独立Web应用设计原则:** + +1. **品牌一致性**: 保持Invoice Master品牌,同时尊重Fortnox用户习惯 +2. **简洁高效**: 发票处理是高频操作,界面必须简洁快速 +3. **清晰反馈**: OCR识别结果必须清晰展示,便于用户确认 +4. 
+
+### 8.2 Responsive Design
+
+**Breakpoints:**
+
+| Breakpoint | Width | Layout |
+|------|------|------|
+| Mobile | < 768px | Single column, stacked |
+| Tablet | 768px - 1024px | Two columns |
+| Desktop | > 1024px | Three columns |
+
+### 8.3 Component Specs
+
+#### File upload area
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│  📤 Upload invoices                                          │
+├─────────────────────────────────────────────────────────────┤
+│                                                              │
+│  ┌───────────────────────────────────────────────────────┐ │
+│  │                                                        │ │
+│  │         📄 Drag & drop PDF files here                  │ │
+│  │                                                        │ │
+│  │               or click to browse                       │ │
+│  │                                                        │ │
+│  │    Supported formats: PDF, JPG, PNG (max 10MB)         │ │
+│  │                                                        │ │
+│  └───────────────────────────────────────────────────────┘ │
+│                                                              │
+└─────────────────────────────────────────────────────────────┘
+```
+
+**Interaction states:**
+- Default: gray dashed border
+- Hover: blue border, light blue background
+- Drag-over: blue border, darker blue background
+- Uploading: progress bar
+
+#### Invoice card
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ 📄 Invoice_F2024_001.pdf                     [✓] [✏️] [🗑️] │
+├─────────────────────────────────────────────────────────────┤
+│                                                              │
+│ Supplier: ABC Company (556677-8899)                          │
+│ Amount:   1,250.00 SEK                                       │
+│ Date:     2024-01-15                                         │
+│                                                              │
+│ Status:  ✅ Imported into Fortnox                            │
+│ Voucher: A-1234  [View in Fortnox]                           │
+│                                                              │
+└─────────────────────────────────────────────────────────────┘
+```
+
+### 8.4 Colors
+
+**Primary palette:**
+
+| Purpose | Color | Hex |
+|------|------|-----|
+| Primary | Blue | #2563EB |
+| Success | Green | #10B981 |
+| Warning | Yellow | #F59E0B |
+| Error | Red | #EF4444 |
+| Background | Light gray | #F9FAFB |
+| Text | Dark gray | #1F2937 |
+
+**Fortnox brand alignment:**
+- Use Fortnox blue as the secondary color (#0057FF)
+- Use the Fortnox brand color on the "Import to Fortnox" button
+
+### 8.5 Typography
+
+| Element | Font | Size | Weight |
+|------|------|------|------|
+| Heading | Inter | 24px | 600 |
+| Subheading | Inter | 18px | 500 |
+| Body | Inter | 14px | 400 |
+| Caption | Inter | 12px | 400 |
+| Numbers | Inter | 16px | 600 (tabular) |
+
+---
+
+## API Design
+
+### 9.1 REST API Endpoints
+
+#### Authentication
+
+```http
+# Get the Fortnox authorization URL
+GET /api/v1/fortnox/auth/url
+Response: {
+  "authorization_url": "https://apps.fortnox.se/oauth-v1/auth?...",
+  "state": "random_state_string"
+}
+
+# OAuth callback
+GET /api/v1/fortnox/auth/callback?code=xxx&state=xxx
+Response: {
+  "status": "success",
+  "company_name": "My Company AB",
+  "connected_at": "2024-01-15T10:30:00Z"
+}
+
+# Disconnect
+DELETE /api/v1/fortnox/auth
+Response: {
+  "status": "disconnected"
+}
+```
+
+#### Invoice processing
+
+```http
+# Upload and process an invoice
+POST /api/v1/fortnox/invoices
+Content-Type: multipart/form-data
+Body: {
+  "file": [PDF file],
+  "auto_import": false,   // if false, return a preview instead of importing
+  "settings": {
+    "voucher_series": "A",
+    "attach_pdf": true
+  }
+}
+
+Response (preview mode): {
+  "id": "uuid",
+  "status": "preview",
+  "extraction": {
+    "supplier_name": "ABC Company",
+    "supplier_org_number": "556677-8899",
+    "invoice_number": "F2024-001",
+    "invoice_date": "2024-01-15",
+    "amount_total": 1250.00,
+    "ocr_number": "7350012345678"
+  },
+  "supplier_match": {
+    "action": "USE_EXISTING",
+    "supplier_number": "123",
+    "confidence": 1.0
+  },
+  "voucher_preview": {
+    "rows": [...]
+  }
+}
+
+Response (auto-import mode): {
+  "id": "uuid",
+  "status": "imported",
+  "voucher": {
+    "voucher_number": "1234",
+    "series": "A",
+    "url": "https://api.fortnox.se/3/vouchers/A/1234"
+  },
+  "fortnox_url": "https://apps.fortnox.se/..."
+}
+```
+
+#### Supplier management
+
+```http
+# List Fortnox suppliers
+GET /api/v1/fortnox/suppliers
+Response: {
+  "suppliers": [
+    {
+      "supplier_number": "123",
+      "name": "ABC Company",
+      "organisation_number": "556677-8899"
+    }
+  ]
+}
+
+# Create a supplier
+POST /api/v1/fortnox/suppliers
+Body: {
+  "name": "New Supplier",
+  "organisation_number": "112233-4455",
+  "address": "..."
+}
+```
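+
+A usage sketch for the upload endpoint above; the base URL and tenant header value are placeholders, and the `X-Tenant-ID` header follows the access-control decorator in the Security chapter:
+
+```python
+import httpx
+
+def upload_invoice_preview(pdf_path: str) -> dict:
+    """Upload a PDF in preview mode and return the extraction preview."""
+    with open(pdf_path, "rb") as f:
+        response = httpx.post(
+            "https://invoice-master.example.com/api/v1/fortnox/invoices",
+            headers={"X-Tenant-ID": "00000000-0000-0000-0000-000000000000"},
+            files={"file": ("invoice.pdf", f, "application/pdf")},
+            data={"auto_import": "false"},  # preview instead of direct import
+        )
+    response.raise_for_status()
+    return response.json()  # status == "preview", schema as above
+```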
+
+### 9.2 Incoming Webhooks
+
+```http
+# Fortnox webhook receiver
+POST /webhooks/fortnox
+Headers: {
+  "X-Fortnox-Event": "voucher.created"
+}
+Body: {
+  "event": "voucher.created",
+  "data": {
+    "voucher_number": "1234",
+    "series": "A"
+  }
+}
+```
+
+---
+
+## Database Design
+
+### 10.1 Entity-Relationship Diagram
+
+```
+┌─────────────────┐       ┌──────────────────┐       ┌─────────────────┐
+│ fortnox_tenants │       │ fortnox_invoices │       │ supplier_cache  │
+├─────────────────┤       ├──────────────────┤       ├─────────────────┤
+│ id (PK)         │◄──────┤ id (PK)          │       │ id (PK)         │
+│ organization_id │       │ tenant_id (FK)   │       │ tenant_id (FK)  │
+│ access_token    │       │ file_path        │       │ supplier_number │
+│ refresh_token   │       │ extraction_data  │       │ name            │
+│ expires_at      │       │ voucher_id       │       │ org_number      │
+│ company_name    │       │ status           │       │ cached_at       │
+│ created_at      │       │ created_at       │       └─────────────────┘
+└─────────────────┘       └──────────────────┘
+        │
+        │                 ┌──────────────────┐
+        │                 │ processing_queue │
+        │                 ├──────────────────┤
+        └────────────────►│ id (PK)          │
+                          │ invoice_id (FK)  │
+                          │ status           │
+                          │ retry_count      │
+                          └──────────────────┘
+```
+
+### 10.2 SQL Schema
+
+```sql
+-- Fortnox tenants
+CREATE TABLE fortnox_tenants (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    organization_id UUID NOT NULL REFERENCES organizations(id),
+
+    -- OAuth tokens
+    access_token TEXT NOT NULL,
+    refresh_token TEXT NOT NULL,
+    expires_at TIMESTAMP NOT NULL,
+    scope TEXT,
+
+    -- Company info
+    company_name VARCHAR(255),
+    company_org_number VARCHAR(20),
+
+    -- Settings
+    default_voucher_series VARCHAR(10) DEFAULT 'A',
+    default_account_code INTEGER DEFAULT 5460,
+    auto_attach_pdf BOOLEAN DEFAULT true,
+
+    -- State
+    is_active BOOLEAN DEFAULT true,
+    last_sync_at TIMESTAMP,
+
+    created_at TIMESTAMP DEFAULT NOW(),
+    updated_at TIMESTAMP DEFAULT NOW(),
+
+    UNIQUE(organization_id)
+);
+
+-- Fortnox invoice processing records
+CREATE TABLE fortnox_invoices (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    tenant_id UUID NOT NULL REFERENCES fortnox_tenants(id),
+
+    -- File info
+    original_filename VARCHAR(255),
+    storage_path TEXT,
+    file_size INTEGER,
+
+    -- OCR extraction result
+    extraction_data JSONB,
+    extraction_confidence DECIMAL(3,2),
+
+    -- Supplier matching
+    supplier_number VARCHAR(50),
+    supplier_match_confidence DECIMAL(3,2),
+    supplier_match_action VARCHAR(20), -- USE_EXISTING, CREATE_NEW, SUGGEST_MATCH
+
+    -- Fortnox voucher
+    voucher_series VARCHAR(10),
+    voucher_number VARCHAR(50),
+    voucher_url TEXT,
+
+    -- Processing state
+    status VARCHAR(20) DEFAULT 'pending', -- pending, processing, preview, imported, failed
+    error_message TEXT,
+
+    -- User review
+    reviewed_by UUID,
+    reviewed_at TIMESTAMP,
+
+    created_at TIMESTAMP DEFAULT NOW(),
+    updated_at TIMESTAMP DEFAULT NOW()
+);
+
+-- Supplier cache
+CREATE TABLE supplier_cache (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    tenant_id UUID NOT NULL REFERENCES fortnox_tenants(id),
+
+    supplier_number VARCHAR(50) NOT NULL,
+    name VARCHAR(255),
+    organisation_number VARCHAR(20),
+    address TEXT,
+    phone VARCHAR(50),
+    email VARCHAR(255),
+
+    cached_at TIMESTAMP DEFAULT NOW(),
+
+    UNIQUE(tenant_id, supplier_number)
+);
+
+-- Processing queue
+CREATE TABLE processing_queue (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    invoice_id UUID NOT NULL REFERENCES fortnox_invoices(id),
+
+    status VARCHAR(20) DEFAULT 'queued', -- queued, processing, completed, failed
+    priority INTEGER DEFAULT 5,
+    retry_count INTEGER DEFAULT 0,
+    max_retries INTEGER DEFAULT 3,
+
+    scheduled_at TIMESTAMP DEFAULT NOW(),
+    started_at TIMESTAMP,
+    completed_at TIMESTAMP,
+    error_message TEXT
+);
+
+-- Indexes
+CREATE INDEX idx_fortnox_invoices_tenant ON fortnox_invoices(tenant_id);
+CREATE INDEX idx_fortnox_invoices_status ON fortnox_invoices(status);
+CREATE INDEX idx_supplier_cache_tenant ON supplier_cache(tenant_id);
+CREATE INDEX idx_processing_queue_status ON processing_queue(status);
+```
+
+---
+
+## Security Design
+
+### 11.1 Authentication Security
+
+**Token storage:**
+- Access and refresh tokens are stored encrypted with AES-256
+- Encryption keys live in Azure Key Vault / AWS Secrets Manager
+- Tokens are rotated regularly
+
+**OAuth security:**
+- The state parameter protects against CSRF
+- Callbacks are HTTPS-only
+- Authorization codes are single-use
+
+### 11.2 Data Security
+
+**In transit:**
+- All API traffic requires TLS 1.3
+- Certificate pinning guards against man-in-the-middle attacks
+
+**At rest:**
+- Invoice PDFs are stored encrypted (AES-256)
+- Database connections use SSL
+- Sensitive fields are encrypted (organisation numbers, bank details)
+
+### 11.3 Access Control
+
+```python
+# Permission-check decorator (the factory itself must not be async)
+def require_fortnox_connection(func):
+    @wraps(func)
+    async def wrapper(request: Request, *args, **kwargs):
+        tenant_id = request.headers.get('X-Tenant-ID')
+
+        # Ensure an active Fortnox connection exists
+        connection = await get_fortnox_connection(tenant_id)
+        if not connection or not connection.is_active:
+            raise HTTPException(
+                status_code=401,
+                detail="Fortnox connection required"
+            )
+
+        # Refresh the token if it has expired
+        if connection.is_token_expired():
+            await refresh_fortnox_token(connection)
+
+        return await func(request, *args, **kwargs)
+    return wrapper
+```
+
+---
+
+## Error Handling
+
+### 12.1 Error Classification
+
+| Type | Example | Strategy |
+|---------|------|----------|
+| **Authentication** | Expired or invalid token | Auto-refresh, or prompt the user to reconnect |
+| **API limits** | 429 Too Many Requests | Retry with exponential backoff |
+| **Data** | Malformed organisation number | Return a specific validation error |
+| **Network** | Connection timeout | Retry 3 times, then fail |
+| **Business** | Supplier does not exist | Offer to create it |
+
+### 12.2 Error Response Format
+
+```json
+{
+  "error": {
+    "code": "FORTNOX_TOKEN_EXPIRED",
+    "message": "Fortnox access token has expired",
+    "details": {
+      "action": "RECONNECT_REQUIRED",
+      "reconnect_url": "/api/v1/fortnox/auth/url"
+    },
+    "timestamp": "2024-01-15T10:30:00Z",
+    "request_id": "req_123456"
+  }
+}
+```
+
+### 12.3 Retry Strategy
+
+```python
+class FortnoxAPIRetry:
+    """Retry policy for Fortnox API calls."""
+
+    def __init__(self):
+        self.max_retries = 3
+        self.base_delay = 1  # seconds
+
+    async def execute(self, func, *args, **kwargs):
+        for attempt in range(self.max_retries):
+            try:
+                return await func(*args, **kwargs)
+            except FortnoxAPIError as e:
+                # Retry rate limits and server errors with exponential
+                # backoff; re-raise once the attempts are exhausted
+                if e.status_code == 429 or e.status_code in [500, 502, 503, 504]:
+                    if attempt < self.max_retries - 1:
+                        delay = self.base_delay * (2 ** attempt)
+                        await asyncio.sleep(delay)
+                    else:
+                        raise
+                else:
+                    raise
+```
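+
+A usage sketch for the retry helper; a `client.get` method with a `(tenant_id, path)` signature is assumed by analogy with the `client.post` calls in the Core Modules chapter:
+
+```python
+retry = FortnoxAPIRetry()
+
+async def list_suppliers(client, tenant_id: str) -> dict:
+    # Retries 429/5xx responses with exponential backoff (1s, 2s, 4s)
+    return await retry.execute(client.get, tenant_id, '/3/suppliers')
+```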
+
+---
+
+## Development Plan
+
+### 13.1 Milestones
+
+| Phase | Timeframe | Goal | Deliverables |
+|------|------|------|--------|
+| **M1** | Week 1-2 | Foundations | Auth module, database |
+| **M2** | Week 3-4 | Core features | Invoice processing, supplier matching |
+| **M3** | Week 5-6 | Fortnox integration | API integration, voucher creation |
+| **M4** | Week 7-8 | UI | Frontend, user flows |
+| **M5** | Week 9-10 | Testing & tuning | Tests, performance work |
+| **M6** | Week 11-12 | Launch prep | Docs, review, deployment |
+
+### 13.2 Task Breakdown
+
+**Week 1-2: Foundations**
+- [ ] Create a Fortnox developer account
+- [ ] Design the database schema
+- [ ] Implement the OAuth2 flow
+- [ ] Token management and refresh
+- [ ] Basic API client
+
+**Week 3-4: Core features**
+- [ ] Integrate the Invoice Master OCR
+- [ ] Implement the supplier matching algorithm
+- [ ] File upload and storage
+- [ ] Async processing queue
+
+**Week 5-6: Fortnox integration**
+- [ ] Supplier API integration
+- [ ] Voucher creation logic
+- [ ] File attachment upload
+- [ ] Error handling and retries
+
+**Week 7-8: UI**
+- [ ] Connection settings page
+- [ ] Invoice upload screen
+- [ ] Preview/edit screen
+- [ ] History page
+
+**Week 9-10: Testing & tuning**
+- [ ] Unit tests (80% coverage target)
+- [ ] Integration tests
+- [ ] Performance tests
+- [ ] Security audit
+
+**Week 11-12: Launch prep**
+- [ ] User documentation
+- [ ] API documentation
+- [ ] Fortnox review submission
+- [ ] Production deployment
+
+---
+
+## Testing Strategy
+
+### 14.1 Test Types
+
+| Type | Tooling | Coverage target | Notes |
+|---------|------|-----------|------|
+| **Unit tests** | pytest | 80% | Core logic |
+| **Integration tests** | pytest + httpx | - | Fortnox API interaction |
+| **E2E tests** | Playwright | Core flows | User scenarios |
+| **Performance tests** | Locust | - | Concurrency |
+| **Security tests** | bandit, safety | - | Vulnerability scanning |
+
+### 14.2 Example Test Cases
+
+```python
+# Supplier matching tests
+class TestSupplierMatcher:
+    async def test_exact_org_number_match(self):
+        """Exact organisation-number match."""
+        matcher = FortnoxSupplierMatcher(mock_client)
+
+        result = await matcher.match(
+            tenant_id="test",
+            extraction=ExtractionResult(
+                supplier_org_number="556677-8899"
+            )
+        )
+
+        assert result.action == 'USE_EXISTING'
+        assert result.confidence == 1.0
+
+    async def test_fuzzy_name_match(self):
+        """Fuzzy name match."""
+        matcher = FortnoxSupplierMatcher(mock_client)
+
+        result = await matcher.match(
+            tenant_id="test",
+            extraction=ExtractionResult(
+                supplier_name="ABC Company AB"
+            )
+        )
+
+        assert result.confidence > 0.85
+
+# Fortnox API integration tests
+class TestFortnoxIntegration:
+    async def test_create_voucher(self):
+        """Create an accounting voucher."""
+        creator = FortnoxVoucherCreator(client)
+
+        result = await creator.create_voucher(
+            tenant_id="test",
+            extraction=mock_extraction,
+            supplier_number="123",
+            settings=mock_settings
+        )
+
+        assert result.voucher_id is not None
+```
+
+---
+
+## Deployment Plan
+
+### 15.1 Deployment Architecture
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│                          Azure                               │
+│  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐     │
+│  │  Container   │   │  PostgreSQL  │   │     Blob     │     │
+│  │     Apps     │   │   Flexible   │   │   Storage    │     │
+│  │  (FastAPI)   │   │    Server    │   │              │     │
+│  └──────────────┘   └──────────────┘   └──────────────┘     │
+│         │                  │                  │              │
+│         └──────────────────┼──────────────────┘              │
+│                            │                                 │
+│                     ┌──────▼──────┐                          │
+│                     │    Redis    │                          │
+│                     │    Cache    │                          │
+│                     └─────────────┘                          │
+└─────────────────────────────────────────────────────────────┘
+```
+
+### 15.2 Resource Sizing
+
+**Container Apps:**
+- CPU: 1 vCPU
+- Memory: 2 GiB
+- Min replicas: 1
+- Max replicas: 5
+
+**PostgreSQL:**
+- SKU: Standard_B1ms
+- Storage: 32 GB
+- Backup: 7 days
+
+**Blob Storage:**
+- Tier: Hot
+- Redundancy: LRS
+
+### 15.3 Deployment Steps
+
+```bash
+# 1. Provision the resource group
+az group create --name invoice-master-rg --location swedencentral
+
+# 2. Database
+az postgres flexible-server create \
+  --name invoice-master-db \
+  --resource-group invoice-master-rg \
+  --sku-name Standard_B1ms
+
+# 3. Application
+az containerapp create \
+  --name invoice-master-fortnox \
+  --resource-group invoice-master-rg \
+  --image invoicemaster.azurecr.io/fortnox-integration:latest \
+  --cpu 1 --memory 2Gi \
+  --min-replicas 1 --max-replicas 5
+```
+
+---
+
+## Appendix
+
+### A. Fortnox API Quick Reference
+
+| Purpose | Method | Endpoint |
+|------|------|------|
+| Get company information | GET | /3/companyinformation |
+| List suppliers | GET | /3/suppliers |
+| Create a supplier | POST | /3/suppliers |
+| List ledger accounts | GET | /3/accounts |
+| Create a voucher | POST | /3/vouchers |
+| Upload a file | POST | /3/inbox |
+
+### B. Chart of Accounts Reference (BAS2024)
+
+**Common accounts:**
+
+| Code | Name | Notes |
+|------|------|------|
+| 2440 | Leverantörsskulder | Accounts payable |
+| 2610 | Ingående moms | Input VAT 25% |
+| 2620 | Ingående moms | Input VAT 12% |
+| 2630 | Ingående moms | Input VAT 6% |
+| 5460 | Kontorsmaterial | Office supplies |
+| 5710 | Frakter | Freight |
+| 6100 | Övriga externa tjänster | Other external services |
+| 6210 | Konsultarvoden | Consulting fees |
+
+### C. Error Codes
+
+| Code | Description | HTTP status |
+|------|------|---------|
+| FORTNOX_TOKEN_EXPIRED | Token has expired | 401 |
+| FORTNOX_RATE_LIMITED | Too many requests | 429 |
+| SUPPLIER_NOT_FOUND | Supplier does not exist | 404 |
+| INVALID_ORG_NUMBER | Invalid organisation number | 400 |
+| EXTRACTION_FAILED | OCR extraction failed | 422 |
+
+### D. Links
+
+- [Fortnox Developer Portal](https://developer.fortnox.se/)
+- [Fortnox API Docs](https://api.fortnox.se/apidocs)
+- [BAS Kontoplan](https://www.bas.se/)
+
+---
+
+**Document history:**
+
+| Version | Date | Author | Changes |
+|------|------|------|---------|
+| 1.0 | 2026-02-01 | Claude Code | Initial version; adds the Fortnox integration modes and UI design guidelines |
+
+---
+
+**Approvals:**
+
+- [ ] Tech lead
+- [ ] Product manager
+- [ ] Security team
diff --git a/frontend/src/api/endpoints/dashboard.ts b/frontend/src/api/endpoints/dashboard.ts
new file mode 100644
index 0000000..fd73fb3
--- /dev/null
+++ b/frontend/src/api/endpoints/dashboard.ts
@@ -0,0 +1,25 @@
+import apiClient from '../client'
+import type {
+  DashboardStatsResponse,
+  DashboardActiveModelResponse,
+  RecentActivityResponse,
+} from '../types'
+
+export const dashboardApi = {
+  getStats: async (): Promise<DashboardStatsResponse> => {
+    const response = await apiClient.get('/api/v1/admin/dashboard/stats')
+    return response.data
+  },
+
+  getActiveModel: async (): Promise<DashboardActiveModelResponse> => {
+    const response = await apiClient.get('/api/v1/admin/dashboard/active-model')
+    return response.data
+  },
+
+  getRecentActivity: async (limit: number = 10): Promise<RecentActivityResponse> => {
+    const response = await apiClient.get('/api/v1/admin/dashboard/activity', {
+      params: { limit },
+    })
+    return response.data
+  },
+}
diff --git a/frontend/src/api/endpoints/index.ts b/frontend/src/api/endpoints/index.ts
index 554ac30..1533939 100644
--- a/frontend/src/api/endpoints/index.ts
+++ b/frontend/src/api/endpoints/index.ts
@@ -5,3 +5,4 @@ export { inferenceApi } from './inference'
 export { datasetsApi } from './datasets'
 export { augmentationApi } from './augmentation'
 export { modelsApi } from './models'
+export { dashboardApi } from './dashboard'
diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts
index 7ceda95..46a35f5 100644
--- a/frontend/src/api/types.ts
+++ b/frontend/src/api/types.ts
@@ -362,3 +362,48 @@ export interface ActiveModelResponse {
   has_active_model: boolean
   model: ModelVersionItem | null
 }
+
+// Dashboard types
+
+export interface DashboardStatsResponse {
+  total_documents: number
+  annotation_complete: number
+  annotation_incomplete: number
+  pending: number
+  completeness_rate: number
+}
+
+export interface DashboardActiveModelInfo {
+  version_id: string
+  version: string
+  name: string
+  metrics_mAP: number | null
+  metrics_precision: number | null
+  metrics_recall: number | null
+  document_count: number
+  activated_at: string | null
+}
+
+export interface DashboardRunningTrainingInfo {
+  task_id: string
+  name: string
+  status: string
+  started_at: string | null
+  progress: number
+}
+
+export interface DashboardActiveModelResponse {
+  model: DashboardActiveModelInfo | null
+  running_training: DashboardRunningTrainingInfo | null
+}
+
+export interface ActivityItem {
+  type: 'document_uploaded' | 'annotation_modified' | 'training_completed' | 'training_failed' | 'model_activated'
+  description: string
+  timestamp: string
+  metadata: Record<string, unknown>
+}
+
+export interface RecentActivityResponse {
+  activities: ActivityItem[]
+}
diff --git a/frontend/src/components/DashboardOverview.tsx b/frontend/src/components/DashboardOverview.tsx
index 6de1561..d0cf234 100644
--- a/frontend/src/components/DashboardOverview.tsx
+++ b/frontend/src/components/DashboardOverview.tsx
@@ -1,47 +1,58 @@
 import React from 'react'
-import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
-import { Button } from './Button'
-import { useDocuments } from '../hooks/useDocuments'
-import { useTraining } from '../hooks/useTraining'
+import { FileText, 
CheckCircle, AlertCircle, Clock, RefreshCw } from 'lucide-react' +import { + StatsCard, + DataQualityPanel, + ActiveModelPanel, + RecentActivityPanel, + SystemStatusBar, +} from './dashboard/index' +import { useDashboard } from '../hooks/useDashboard' interface DashboardOverviewProps { onNavigate: (view: string) => void } export const DashboardOverview: React.FC = ({ onNavigate }) => { - const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 }) - const { models, isLoadingModels } = useTraining() + const { + stats, + model, + runningTraining, + activities, + isLoading, + error, + } = useDashboard() - const stats = [ - { - label: 'Total Documents', - value: docsLoading ? '...' : totalDocs.toString(), - icon: FileText, - color: 'text-warm-text-primary', - bgColor: 'bg-warm-bg', - }, - { - label: 'Labeled', - value: '0', - icon: CheckCircle, - color: 'text-warm-state-success', - bgColor: 'bg-green-50', - }, - { - label: 'Pending', - value: '0', - icon: Clock, - color: 'text-warm-state-warning', - bgColor: 'bg-yellow-50', - }, - { - label: 'Training Models', - value: isLoadingModels ? '...' : models.length.toString(), - icon: TrendingUp, - color: 'text-warm-state-info', - bgColor: 'bg-blue-50', - }, - ] + const handleStatsClick = (filter?: string) => { + if (filter) { + onNavigate(`documents?status=${filter}`) + } else { + onNavigate('documents') + } + } + + if (error) { + return ( +
+
+ +

+ Failed to load dashboard +

+

+ {error instanceof Error ? error.message : 'An unexpected error occurred'} +

+ +
+
+ ) + } return (
@@ -55,94 +66,74 @@ export const DashboardOverview: React.FC = ({ onNavigate

- {/* Stats Grid */} + {/* Stats Cards Row */}
- {stats.map((stat) => ( -
-
-
- -
-
-

- {stat.value} -

-

{stat.label}

-
- ))} + handleStatsClick()} + /> + handleStatsClick('labeled')} + /> + handleStatsClick('labeled')} + /> + handleStatsClick('pending')} + />
- {/* Quick Actions */} -
-

- Quick Actions -

-
- - - -
+ {/* Two-column layout: Data Quality + Active Model */} +
+ handleStatsClick('labeled')} + /> + onNavigate('training')} + />
{/* Recent Activity */} -
-
-

- Recent Activity -

-
-
-
- -

No recent activity

-

- Start by uploading documents or creating training jobs -

-
-
+
+
{/* System Status */} -
-

- System Status -

-
-
- Backend API - - - Online - -
-
- Database - - - Connected - -
-
- GPU - - - Available - -
-
-
+
) } diff --git a/frontend/src/components/dashboard/ActiveModelPanel.tsx b/frontend/src/components/dashboard/ActiveModelPanel.tsx new file mode 100644 index 0000000..7baac3f --- /dev/null +++ b/frontend/src/components/dashboard/ActiveModelPanel.tsx @@ -0,0 +1,143 @@ +import React from 'react' +import { TrendingUp } from 'lucide-react' +import { Button } from '../Button' +import type { DashboardActiveModelInfo, DashboardRunningTrainingInfo } from '../../api/types' + +interface ActiveModelPanelProps { + model: DashboardActiveModelInfo | null + runningTraining: DashboardRunningTrainingInfo | null + isLoading?: boolean + onGoToTraining?: () => void +} + +const formatDate = (dateStr: string | null): string => { + if (!dateStr) return 'N/A' + const date = new Date(dateStr) + return date.toLocaleDateString('en-US', { + year: 'numeric', + month: 'short', + day: 'numeric', + }) +} + +const formatMetric = (value: number | null): string => { + if (value === null) return 'N/A' + return `${(value * 100).toFixed(1)}%` +} + +const getMetricColor = (value: number | null): string => { + if (value === null) return 'text-warm-text-muted' + if (value >= 0.9) return 'text-green-600' + if (value >= 0.8) return 'text-yellow-600' + return 'text-red-600' +} + +export const ActiveModelPanel: React.FC = ({ + model, + runningTraining, + isLoading = false, + onGoToTraining, +}) => { + if (isLoading) { + return ( +
+

+ Active Model +

+
+
Loading...
+
+
+ ) + } + + if (!model) { + return ( +
+

+ Active Model +

+
+ +

No Active Model

+

+ Train and activate a model to see stats here +

+ {onGoToTraining && ( + + )} +
+
+ ) + } + + return ( +
+

+ Active Model +

+ +
+ {model.version} + - {model.name} +
+ +
+
+
+

+ {formatMetric(model.metrics_mAP)} +

+

mAP

+
+
+

+ {formatMetric(model.metrics_precision)} +

+

Precision

+
+
+

+ {formatMetric(model.metrics_recall)} +

+

Recall

+
+
+
+ +
+

+ Activated:{' '} + {formatDate(model.activated_at)} +

+

+ Documents:{' '} + {model.document_count.toLocaleString()} +

+
+ + {runningTraining && ( +
+
+ + + Training in Progress + +
+

{runningTraining.name}

+
+
+
+

+ {runningTraining.progress}% complete +

+
+ )} +
+ ) +} diff --git a/frontend/src/components/dashboard/DataQualityPanel.tsx b/frontend/src/components/dashboard/DataQualityPanel.tsx new file mode 100644 index 0000000..a775e35 --- /dev/null +++ b/frontend/src/components/dashboard/DataQualityPanel.tsx @@ -0,0 +1,105 @@ +import React from 'react' +import { Button } from '../Button' + +interface DataQualityPanelProps { + completenessRate: number + completeCount: number + incompleteCount: number + pendingCount: number + isLoading?: boolean + onViewIncomplete?: () => void +} + +export const DataQualityPanel: React.FC = ({ + completenessRate, + completeCount, + incompleteCount, + pendingCount, + isLoading = false, + onViewIncomplete, +}) => { + const radius = 54 + const circumference = 2 * Math.PI * radius + const strokeDashoffset = circumference - (completenessRate / 100) * circumference + + return ( +
+

+ Data Quality +

+ +
+
+ + + + +
+ + {isLoading ? '...' : `${Math.round(completenessRate)}%`} + +
+
+ +
+

+ Annotation Complete +

+ +
+
+ + + Complete + + {isLoading ? '...' : completeCount} +
+
+ + + Incomplete + + {isLoading ? '...' : incompleteCount} +
+
+ + + Pending + + {isLoading ? '...' : pendingCount} +
+
+
+
+ + {onViewIncomplete && incompleteCount > 0 && ( +
+ +
+ )} +
+ ) +} diff --git a/frontend/src/components/dashboard/RecentActivityPanel.tsx b/frontend/src/components/dashboard/RecentActivityPanel.tsx new file mode 100644 index 0000000..0092b48 --- /dev/null +++ b/frontend/src/components/dashboard/RecentActivityPanel.tsx @@ -0,0 +1,134 @@ +import React from 'react' +import { + FileText, + Edit, + CheckCircle, + XCircle, + Rocket, + Activity, +} from 'lucide-react' +import type { ActivityItem } from '../../api/types' + +interface RecentActivityPanelProps { + activities: ActivityItem[] + isLoading?: boolean + onSeeAll?: () => void +} + +const getActivityIcon = (type: ActivityItem['type']) => { + switch (type) { + case 'document_uploaded': + return { Icon: FileText, color: 'text-blue-500', bg: 'bg-blue-50' } + case 'annotation_modified': + return { Icon: Edit, color: 'text-orange-500', bg: 'bg-orange-50' } + case 'training_completed': + return { Icon: CheckCircle, color: 'text-green-500', bg: 'bg-green-50' } + case 'training_failed': + return { Icon: XCircle, color: 'text-red-500', bg: 'bg-red-50' } + case 'model_activated': + return { Icon: Rocket, color: 'text-purple-500', bg: 'bg-purple-50' } + default: + return { Icon: Activity, color: 'text-gray-500', bg: 'bg-gray-50' } + } +} + +const formatTimestamp = (timestamp: string): string => { + const date = new Date(timestamp) + const now = new Date() + const diffMs = now.getTime() - date.getTime() + const diffMinutes = Math.floor(diffMs / 60000) + const diffHours = Math.floor(diffMs / 3600000) + const diffDays = Math.floor(diffMs / 86400000) + + if (diffMinutes < 1) return 'just now' + if (diffMinutes < 60) return `${diffMinutes} minutes ago` + if (diffHours < 24) return `${diffHours} hours ago` + if (diffDays === 1) return 'yesterday' + if (diffDays < 7) return `${diffDays} days ago` + + return date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' }) +} + +export const RecentActivityPanel: React.FC = ({ + activities, + isLoading = false, + onSeeAll, +}) => { + if (isLoading) { + return ( +
+
+

+ Recent Activity +

+
+
+
+
Loading...
+
+
+
+ ) + } + + if (activities.length === 0) { + return ( +
+
+

+ Recent Activity +

+
+
+
+ +

No recent activity

+

+ Start by uploading documents or creating training jobs +

+
+
+
+ ) + } + + return ( +
+
+

+ Recent Activity +

+ {onSeeAll && ( + + )} +
+
+ {activities.map((activity, index) => { + const { Icon, color, bg } = getActivityIcon(activity.type) + + return ( +
+
+ +
+

+ {activity.description} +

+ + {formatTimestamp(activity.timestamp)} + +
+ ) + })} +
+
+ ) +} diff --git a/frontend/src/components/dashboard/StatsCard.tsx b/frontend/src/components/dashboard/StatsCard.tsx new file mode 100644 index 0000000..e917350 --- /dev/null +++ b/frontend/src/components/dashboard/StatsCard.tsx @@ -0,0 +1,44 @@ +import React from 'react' +import { LucideIcon } from 'lucide-react' + +interface StatsCardProps { + label: string + value: string | number + icon: LucideIcon + iconColor: string + iconBgColor: string + onClick?: () => void + isLoading?: boolean +} + +export const StatsCard: React.FC = ({ + label, + value, + icon: Icon, + iconColor, + iconBgColor, + onClick, + isLoading = false, +}) => { + return ( +
+    <div
+      className="rounded-lg border border-gray-200 bg-white p-4 transition-shadow hover:shadow-sm"
+      onClick={onClick}
+      role={onClick ? 'button' : undefined}
+      tabIndex={onClick ? 0 : undefined}
+      onKeyDown={onClick ? (e) => e.key === 'Enter' && onClick() : undefined}
+    >
+      <div className={`flex h-10 w-10 items-center justify-center rounded-lg ${iconBgColor}`}>
+        <Icon className={`h-5 w-5 ${iconColor}`} />
+      </div>
+      <p className="mt-3 text-2xl font-bold text-gray-900">
+        {isLoading ? '...' : value}
+      </p>
+      <p className="text-sm text-gray-500">{label}</p>
+    </div>
+  )
+}
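A usage sketch for the dashboard's top row of stat cards. The values map directly to `DashboardStatsResponse` fields; the icon and tint choices are illustrative assumptions:

```tsx
// Sketch only: icon/tint pairings are assumptions, not part of this patch.
import React from 'react'
import { FileText, CheckCircle, AlertTriangle, Clock } from 'lucide-react'
import { StatsCard } from '../components/dashboard'
import type { DashboardStatsResponse } from '../api/types'

interface StatsRowProps {
  stats?: DashboardStatsResponse
  isLoading: boolean
}

export const StatsRow: React.FC<StatsRowProps> = ({ stats, isLoading }) => (
  <div className="grid grid-cols-4 gap-4">
    <StatsCard label="Total Documents" value={stats?.total_documents ?? 0}
      icon={FileText} iconColor="text-gray-500" iconBgColor="bg-gray-50"
      isLoading={isLoading} />
    <StatsCard label="Complete" value={stats?.annotation_complete ?? 0}
      icon={CheckCircle} iconColor="text-green-600" iconBgColor="bg-green-50"
      isLoading={isLoading} />
    <StatsCard label="Incomplete" value={stats?.annotation_incomplete ?? 0}
      icon={AlertTriangle} iconColor="text-orange-500" iconBgColor="bg-orange-50"
      isLoading={isLoading} />
    <StatsCard label="Pending" value={stats?.pending ?? 0}
      icon={Clock} iconColor="text-blue-500" iconBgColor="bg-blue-50"
      isLoading={isLoading} />
  </div>
)
```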
diff --git a/frontend/src/components/dashboard/SystemStatusBar.tsx b/frontend/src/components/dashboard/SystemStatusBar.tsx
new file mode 100644
index 0000000..f9d487f
--- /dev/null
+++ b/frontend/src/components/dashboard/SystemStatusBar.tsx
@@ -0,0 +1,62 @@
+import React from 'react'
+
+interface StatusItem {
+  label: string
+  status: 'online' | 'degraded' | 'offline'
+  statusText: string
+}
+
+interface SystemStatusBarProps {
+  items?: StatusItem[]
+}
+
+const getStatusColor = (status: StatusItem['status']) => {
+  switch (status) {
+    case 'online':
+      return 'bg-green-500'
+    case 'degraded':
+      return 'bg-yellow-500'
+    case 'offline':
+      return 'bg-red-500'
+  }
+}
+
+const getStatusTextColor = (status: StatusItem['status']) => {
+  switch (status) {
+    case 'online':
+      return 'text-warm-state-success'
+    case 'degraded':
+      return 'text-yellow-600'
+    case 'offline':
+      return 'text-red-600'
+  }
+}
+
+const defaultItems: StatusItem[] = [
+  { label: 'Backend API', status: 'online', statusText: 'Online' },
+  { label: 'Database', status: 'online', statusText: 'Connected' },
+  { label: 'GPU', status: 'online', statusText: 'Available' },
+]
+
+export const SystemStatusBar: React.FC<SystemStatusBarProps> = ({
+  items = defaultItems,
+}) => {
+  return (
+    <div className="rounded-lg border border-gray-200 bg-white p-4">
+      <h3 className="text-xs font-semibold uppercase tracking-wide text-gray-500">
+        System Status
+      </h3>
+      <div className="mt-3 flex flex-wrap gap-6">
+        {items.map((item) => (
+          <div key={item.label} className="flex items-center gap-2 text-sm">
+            <span className="text-gray-900">{item.label}</span>
+            <span className={`h-2 w-2 rounded-full ${getStatusColor(item.status)}`} />
+            <span className={getStatusTextColor(item.status)}>{item.statusText}</span>
+          </div>
+        ))}
+      </div>
+    </div>
+  )
+}
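The hooks in the next file import `dashboardApi` from `../api/endpoints`, which is outside this patch. Given the backend router below (prefix `/admin/dashboard`, mounted under `/api/v1` in `app.py`), a plausible minimal client might look like this; the token handling is an assumption:

```ts
// Plausible sketch of the dashboardApi client the hooks import;
// frontend/src/api/endpoints.ts is not shown in this diff. Paths follow
// the backend router. Auth via a stored admin token is an assumption.
import type {
  DashboardStatsResponse,
  DashboardActiveModelResponse,
  RecentActivityResponse,
} from './types'

const BASE = '/api/v1/admin/dashboard'

async function get<T>(path: string): Promise<T> {
  const res = await fetch(`${BASE}${path}`, {
    headers: { Authorization: `Bearer ${localStorage.getItem('adminToken') ?? ''}` },
  })
  if (!res.ok) throw new Error(`Dashboard API error: ${res.status}`)
  return res.json() as Promise<T>
}

export const dashboardApi = {
  getStats: () => get<DashboardStatsResponse>('/stats'),
  getActiveModel: () => get<DashboardActiveModelResponse>('/active-model'),
  getRecentActivity: (limit = 10) =>
    get<RecentActivityResponse>(`/activity?limit=${limit}`),
}
```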
diff --git a/frontend/src/components/dashboard/index.ts b/frontend/src/components/dashboard/index.ts
new file mode 100644
index 0000000..c43abcf
--- /dev/null
+++ b/frontend/src/components/dashboard/index.ts
@@ -0,0 +1,5 @@
+export { StatsCard } from './StatsCard'
+export { DataQualityPanel } from './DataQualityPanel'
+export { ActiveModelPanel } from './ActiveModelPanel'
+export { RecentActivityPanel } from './RecentActivityPanel'
+export { SystemStatusBar } from './SystemStatusBar'
diff --git a/frontend/src/hooks/index.ts b/frontend/src/hooks/index.ts
index fb72d26..47b6ad5 100644
--- a/frontend/src/hooks/index.ts
+++ b/frontend/src/hooks/index.ts
@@ -5,3 +5,4 @@ export { useTraining, useTrainingDocuments } from './useTraining'
 export { useDatasets, useDatasetDetail } from './useDatasets'
 export { useAugmentation } from './useAugmentation'
 export { useModels, useModelDetail, useActiveModel } from './useModels'
+export { useDashboard, useDashboardStats, useActiveModel as useDashboardActiveModel, useRecentActivity } from './useDashboard'
diff --git a/frontend/src/hooks/useDashboard.ts b/frontend/src/hooks/useDashboard.ts
new file mode 100644
index 0000000..4b3d35f
--- /dev/null
+++ b/frontend/src/hooks/useDashboard.ts
@@ -0,0 +1,76 @@
+import { useQuery } from '@tanstack/react-query'
+import { dashboardApi } from '../api/endpoints'
+import type {
+  DashboardStatsResponse,
+  DashboardActiveModelResponse,
+  RecentActivityResponse,
+} from '../api/types'
+
+export const useDashboardStats = () => {
+  const { data, isLoading, error, refetch } = useQuery<DashboardStatsResponse>({
+    queryKey: ['dashboard', 'stats'],
+    queryFn: () => dashboardApi.getStats(),
+    staleTime: 30000,
+    refetchInterval: 60000,
+  })
+
+  return {
+    stats: data,
+    isLoading,
+    error,
+    refetch,
+  }
+}
+
+export const useActiveModel = () => {
+  const { data, isLoading, error, refetch } = useQuery<DashboardActiveModelResponse>({
+    queryKey: ['dashboard', 'active-model'],
+    queryFn: () => dashboardApi.getActiveModel(),
+    staleTime: 30000,
+    refetchInterval: 60000,
+  })
+
+  return {
+    model: data?.model ?? null,
+    runningTraining: data?.running_training ?? null,
+    isLoading,
+    error,
+    refetch,
+  }
+}
+
+export const useRecentActivity = (limit: number = 10) => {
+  const { data, isLoading, error, refetch } = useQuery<RecentActivityResponse>({
+    queryKey: ['dashboard', 'activity', limit],
+    queryFn: () => dashboardApi.getRecentActivity(limit),
+    staleTime: 30000,
+    refetchInterval: 60000,
+  })
+
+  return {
+    activities: data?.activities ??
[], + isLoading, + error, + refetch, + } +} + +export const useDashboard = () => { + const stats = useDashboardStats() + const activeModel = useActiveModel() + const activity = useRecentActivity() + + return { + stats: stats.stats, + model: activeModel.model, + runningTraining: activeModel.runningTraining, + activities: activity.activities, + isLoading: stats.isLoading || activeModel.isLoading || activity.isLoading, + error: stats.error || activeModel.error || activity.error, + refetch: () => { + stats.refetch() + activeModel.refetch() + activity.refetch() + }, + } +} diff --git a/packages/inference/inference/data/database.py b/packages/inference/inference/data/database.py index 656636e..a6909c6 100644 --- a/packages/inference/inference/data/database.py +++ b/packages/inference/inference/data/database.py @@ -175,6 +175,80 @@ def run_migrations() -> None: ); """, ), + # Migration 007: Add extra columns to training_tasks + ( + "training_tasks_name", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255); + UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL; + ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL; + CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name); + """, + ), + ( + "training_tasks_description", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT; + """, + ), + ( + "training_tasks_admin_token", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255); + """, + ), + ( + "training_tasks_task_type", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train'; + """, + ), + ( + "training_tasks_recurring", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50); + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE; + """, + ), + ( + "training_tasks_metrics", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB; + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0; + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION; + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION; + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION; + CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP); + """, + ), + ( + "training_tasks_updated_at", + """ + ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(); + """, + ), + # Migration 008: Fix model_versions foreign key constraints + ( + "model_versions_fk_fix", + """ + ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey; + ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey; + ALTER TABLE model_versions + ADD CONSTRAINT model_versions_dataset_id_fkey + FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL; + ALTER TABLE model_versions + ADD CONSTRAINT model_versions_task_id_fkey + FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL; + """, + ), + # Migration 006b: Ensure only one active model at a time + ( + "model_versions_single_active", + """ + CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active + ON model_versions(is_active) WHERE is_active = TRUE; + """, + ), ] with engine.connect() as conn: diff --git 
a/packages/inference/inference/data/repositories/annotation_repository.py b/packages/inference/inference/data/repositories/annotation_repository.py index 9de9b30..701b61c 100644 --- a/packages/inference/inference/data/repositories/annotation_repository.py +++ b/packages/inference/inference/data/repositories/annotation_repository.py @@ -193,6 +193,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]): annotation = session.get(AdminAnnotation, UUID(annotation_id)) if annotation: session.delete(annotation) + session.commit() return True return False @@ -216,6 +217,7 @@ class AnnotationRepository(BaseRepository[AdminAnnotation]): count = len(annotations) for ann in annotations: session.delete(ann) + session.commit() return count def verify( diff --git a/packages/inference/inference/data/repositories/dataset_repository.py b/packages/inference/inference/data/repositories/dataset_repository.py index c714ea0..3e8d20e 100644 --- a/packages/inference/inference/data/repositories/dataset_repository.py +++ b/packages/inference/inference/data/repositories/dataset_repository.py @@ -203,6 +203,14 @@ class DatasetRepository(BaseRepository[TrainingDataset]): dataset = session.get(TrainingDataset, UUID(str(dataset_id))) if not dataset: return False + # Delete associated document links first + doc_links = session.exec( + select(DatasetDocument).where( + DatasetDocument.dataset_id == UUID(str(dataset_id)) + ) + ).all() + for link in doc_links: + session.delete(link) session.delete(dataset) session.commit() return True diff --git a/packages/inference/inference/data/repositories/document_repository.py b/packages/inference/inference/data/repositories/document_repository.py index 69dca6b..ccc2164 100644 --- a/packages/inference/inference/data/repositories/document_repository.py +++ b/packages/inference/inference/data/repositories/document_repository.py @@ -264,6 +264,7 @@ class DocumentRepository(BaseRepository[AdminDocument]): for ann in annotations: session.delete(ann) session.delete(document) + session.commit() return True return False @@ -389,7 +390,11 @@ class DocumentRepository(BaseRepository[AdminDocument]): return None now = datetime.now(timezone.utc) - if doc.annotation_lock_until and doc.annotation_lock_until > now: + lock_until = doc.annotation_lock_until + # Handle PostgreSQL returning offset-naive datetimes + if lock_until and lock_until.tzinfo is None: + lock_until = lock_until.replace(tzinfo=timezone.utc) + if lock_until and lock_until > now: return None doc.annotation_lock_until = now + timedelta(seconds=duration_seconds) @@ -433,10 +438,14 @@ class DocumentRepository(BaseRepository[AdminDocument]): return None now = datetime.now(timezone.utc) - if not doc.annotation_lock_until or doc.annotation_lock_until <= now: + lock_until = doc.annotation_lock_until + # Handle PostgreSQL returning offset-naive datetimes + if lock_until and lock_until.tzinfo is None: + lock_until = lock_until.replace(tzinfo=timezone.utc) + if not lock_until or lock_until <= now: return None - doc.annotation_lock_until = doc.annotation_lock_until + timedelta(seconds=additional_seconds) + doc.annotation_lock_until = lock_until + timedelta(seconds=additional_seconds) session.add(doc) session.commit() session.refresh(doc) diff --git a/packages/inference/inference/data/repositories/training_task_repository.py b/packages/inference/inference/data/repositories/training_task_repository.py index 2b44ee9..1ec7edf 100644 --- a/packages/inference/inference/data/repositories/training_task_repository.py +++ 
b/packages/inference/inference/data/repositories/training_task_repository.py @@ -118,6 +118,22 @@ class TrainingTaskRepository(BaseRepository[TrainingTask]): session.expunge(r) return list(results) + def get_running(self) -> TrainingTask | None: + """Get currently running training task. + + Returns: + Running task or None if no task is running + """ + with get_session_context() as session: + result = session.exec( + select(TrainingTask) + .where(TrainingTask.status == "running") + .order_by(TrainingTask.started_at.desc()) + ).first() + if result: + session.expunge(result) + return result + def update_status( self, task_id: str, diff --git a/packages/inference/inference/pipeline/normalizers/__init__.py b/packages/inference/inference/pipeline/normalizers/__init__.py index eec623a..b4d0390 100644 --- a/packages/inference/inference/pipeline/normalizers/__init__.py +++ b/packages/inference/inference/pipeline/normalizers/__init__.py @@ -55,5 +55,6 @@ def create_normalizer_registry( "Amount": amount_normalizer, "InvoiceDate": date_normalizer, "InvoiceDueDate": date_normalizer, - "supplier_org_number": SupplierOrgNumberNormalizer(), + # Note: field_name is "supplier_organisation_number" (from CLASS_TO_FIELD mapping) + "supplier_organisation_number": SupplierOrgNumberNormalizer(), } diff --git a/packages/inference/inference/web/api/v1/admin/annotations.py b/packages/inference/inference/web/api/v1/admin/annotations.py index 592fedf..68d56a8 100644 --- a/packages/inference/inference/web/api/v1/admin/annotations.py +++ b/packages/inference/inference/web/api/v1/admin/annotations.py @@ -481,11 +481,22 @@ def create_annotation_router() -> APIRouter: detail="At least one field value is required", ) + # Get the actual file path from storage + # document.file_path is a relative storage path like "raw_pdfs/uuid.pdf" + storage = get_storage_helper() + filename = document.file_path.split("/")[-1] if "/" in document.file_path else document.file_path + file_path = storage.get_raw_pdf_local_path(filename) + if file_path is None: + raise HTTPException( + status_code=500, + detail=f"Cannot find PDF file: {document.file_path}", + ) + # Run auto-labeling service = get_auto_label_service() result = service.auto_label_document( document_id=document_id, - file_path=document.file_path, + file_path=str(file_path), field_values=request.field_values, doc_repo=doc_repo, ann_repo=ann_repo, diff --git a/packages/inference/inference/web/api/v1/admin/auth.py b/packages/inference/inference/web/api/v1/admin/auth.py index f1208fc..e9fd2e8 100644 --- a/packages/inference/inference/web/api/v1/admin/auth.py +++ b/packages/inference/inference/web/api/v1/admin/auth.py @@ -6,7 +6,7 @@ FastAPI endpoints for admin token management. 
import logging import secrets -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from fastapi import APIRouter @@ -41,10 +41,10 @@ def create_auth_router() -> APIRouter: # Generate secure token token = secrets.token_urlsafe(32) - # Calculate expiration + # Calculate expiration (use timezone-aware datetime) expires_at = None if request.expires_in_days: - expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days) + expires_at = datetime.now(timezone.utc) + timedelta(days=request.expires_in_days) # Create token in database tokens.create( diff --git a/packages/inference/inference/web/api/v1/admin/dashboard.py b/packages/inference/inference/web/api/v1/admin/dashboard.py new file mode 100644 index 0000000..0b68b38 --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/dashboard.py @@ -0,0 +1,135 @@ +""" +Dashboard API Routes + +FastAPI endpoints for dashboard statistics and activity. +""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, Query + +from inference.web.core.auth import ( + AdminTokenDep, + get_model_version_repository, + get_training_task_repository, + ModelVersionRepoDep, + TrainingTaskRepoDep, +) +from inference.web.schemas.admin import ( + DashboardStatsResponse, + ActiveModelResponse, + ActiveModelInfo, + RunningTrainingInfo, + RecentActivityResponse, + ActivityItem, +) +from inference.web.services.dashboard_service import ( + DashboardStatsService, + DashboardActivityService, +) + +logger = logging.getLogger(__name__) + + +def create_dashboard_router() -> APIRouter: + """Create dashboard API router.""" + router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"]) + + @router.get( + "/stats", + response_model=DashboardStatsResponse, + summary="Get dashboard statistics", + description="Returns document counts and annotation completeness metrics.", + ) + async def get_dashboard_stats( + admin_token: AdminTokenDep, + ) -> DashboardStatsResponse: + """Get dashboard statistics.""" + service = DashboardStatsService() + stats = service.get_stats() + + return DashboardStatsResponse( + total_documents=stats["total_documents"], + annotation_complete=stats["annotation_complete"], + annotation_incomplete=stats["annotation_incomplete"], + pending=stats["pending"], + completeness_rate=stats["completeness_rate"], + ) + + @router.get( + "/active-model", + response_model=ActiveModelResponse, + summary="Get active model info", + description="Returns current active model and running training status.", + ) + async def get_active_model( + admin_token: AdminTokenDep, + model_repo: ModelVersionRepoDep, + task_repo: TrainingTaskRepoDep, + ) -> ActiveModelResponse: + """Get active model and training status.""" + # Get active model + active_model = model_repo.get_active() + model_info = None + + if active_model: + model_info = ActiveModelInfo( + version_id=str(active_model.version_id), + version=active_model.version, + name=active_model.name, + metrics_mAP=active_model.metrics_mAP, + metrics_precision=active_model.metrics_precision, + metrics_recall=active_model.metrics_recall, + document_count=active_model.document_count, + activated_at=active_model.activated_at, + ) + + # Get running training task + running_task = task_repo.get_running() + training_info = None + + if running_task: + training_info = RunningTrainingInfo( + task_id=str(running_task.task_id), + name=running_task.name, + status=running_task.status, + started_at=running_task.started_at, + progress=running_task.progress or 0, + ) 
+ + return ActiveModelResponse( + model=model_info, + running_training=training_info, + ) + + @router.get( + "/activity", + response_model=RecentActivityResponse, + summary="Get recent activity", + description="Returns recent system activities sorted by timestamp.", + ) + async def get_recent_activity( + admin_token: AdminTokenDep, + limit: Annotated[ + int, + Query(ge=1, le=50, description="Maximum number of activities"), + ] = 10, + ) -> RecentActivityResponse: + """Get recent system activity.""" + service = DashboardActivityService() + activities = service.get_recent_activities(limit=limit) + + return RecentActivityResponse( + activities=[ + ActivityItem( + type=act["type"], + description=act["description"], + timestamp=act["timestamp"], + metadata=act["metadata"], + ) + for act in activities + ] + ) + + return router diff --git a/packages/inference/inference/web/app.py b/packages/inference/inference/web/app.py index 01bb160..af83421 100644 --- a/packages/inference/inference/web/app.py +++ b/packages/inference/inference/web/app.py @@ -44,6 +44,7 @@ from inference.web.api.v1.admin import ( create_locks_router, create_training_router, ) +from inference.web.api.v1.admin.dashboard import create_dashboard_router from inference.web.core.scheduler import start_scheduler, stop_scheduler from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler @@ -115,13 +116,21 @@ def create_app(config: AppConfig | None = None) -> FastAPI: """Application lifespan manager.""" logger.info("Starting Invoice Inference API...") - # Initialize database tables + # Initialize async request database tables try: async_db.create_tables() logger.info("Async database tables ready") except Exception as e: logger.error(f"Failed to initialize async database: {e}") + # Initialize admin database tables (admin_tokens, admin_documents, training_tasks, etc.) 
+ try: + from inference.data.database import create_db_and_tables + create_db_and_tables() + logger.info("Admin database tables ready") + except Exception as e: + logger.error(f"Failed to initialize admin database: {e}") + # Initialize inference service on startup try: inference_service.initialize() @@ -279,6 +288,10 @@ def create_app(config: AppConfig | None = None) -> FastAPI: augmentation_router = create_augmentation_router() app.include_router(augmentation_router, prefix="/api/v1/admin") + # Include dashboard routes + dashboard_router = create_dashboard_router() + app.include_router(dashboard_router, prefix="/api/v1") + # Include batch upload routes app.include_router(batch_upload_router) diff --git a/packages/inference/inference/web/schemas/admin/__init__.py b/packages/inference/inference/web/schemas/admin/__init__.py index ca4d999..b8c8228 100644 --- a/packages/inference/inference/web/schemas/admin/__init__.py +++ b/packages/inference/inference/web/schemas/admin/__init__.py @@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403 from .training import * # noqa: F401, F403 from .datasets import * # noqa: F401, F403 from .models import * # noqa: F401, F403 +from .dashboard import * # noqa: F401, F403 # Resolve forward references for DocumentDetailResponse from .documents import DocumentDetailResponse diff --git a/packages/inference/inference/web/schemas/admin/dashboard.py b/packages/inference/inference/web/schemas/admin/dashboard.py new file mode 100644 index 0000000..27532df --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/dashboard.py @@ -0,0 +1,92 @@ +""" +Dashboard API Schemas + +Pydantic models for dashboard statistics and activity endpoints. +""" + +from datetime import datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +# Activity type literals for type safety +ActivityType = Literal[ + "document_uploaded", + "annotation_modified", + "training_completed", + "training_failed", + "model_activated", +] + + +class DashboardStatsResponse(BaseModel): + """Response for dashboard statistics.""" + + total_documents: int = Field(..., description="Total number of documents") + annotation_complete: int = Field( + ..., description="Documents with complete annotations" + ) + annotation_incomplete: int = Field( + ..., description="Documents with incomplete annotations" + ) + pending: int = Field(..., description="Documents pending processing") + completeness_rate: float = Field( + ..., description="Annotation completeness percentage" + ) + + +class ActiveModelInfo(BaseModel): + """Active model information.""" + + version_id: str = Field(..., description="Model version UUID") + version: str = Field(..., description="Model version string") + name: str = Field(..., description="Model name") + metrics_mAP: float | None = Field(None, description="Mean Average Precision") + metrics_precision: float | None = Field(None, description="Precision score") + metrics_recall: float | None = Field(None, description="Recall score") + document_count: int = Field(0, description="Number of training documents") + activated_at: datetime | None = Field(None, description="Activation timestamp") + + +class RunningTrainingInfo(BaseModel): + """Running training task information.""" + + task_id: str = Field(..., description="Training task UUID") + name: str = Field(..., description="Training task name") + status: str = Field(..., description="Training status") + started_at: datetime | None = Field(None, description="Start timestamp") + progress: int = Field(0, 
description="Training progress percentage") + + +class ActiveModelResponse(BaseModel): + """Response for active model endpoint.""" + + model: ActiveModelInfo | None = Field( + None, description="Active model info, null if none" + ) + running_training: RunningTrainingInfo | None = Field( + None, description="Running training task, null if none" + ) + + +class ActivityItem(BaseModel): + """Single activity item.""" + + type: ActivityType = Field( + ..., + description="Activity type: document_uploaded, annotation_modified, training_completed, training_failed, model_activated", + ) + description: str = Field(..., description="Human-readable description") + timestamp: datetime = Field(..., description="Activity timestamp") + metadata: dict[str, Any] = Field( + default_factory=dict, description="Additional metadata" + ) + + +class RecentActivityResponse(BaseModel): + """Response for recent activity endpoint.""" + + activities: list[ActivityItem] = Field( + default_factory=list, description="List of recent activities" + ) diff --git a/packages/inference/inference/web/services/autolabel.py b/packages/inference/inference/web/services/autolabel.py index ebfbaff..25ebc27 100644 --- a/packages/inference/inference/web/services/autolabel.py +++ b/packages/inference/inference/web/services/autolabel.py @@ -291,7 +291,7 @@ class AutoLabelService: "bbox_y": bbox_y, "bbox_width": bbox_width, "bbox_height": bbox_height, - "text_value": best_match.matched_value, + "text_value": best_match.matched_text, "confidence": best_match.score, "source": "auto", }) diff --git a/packages/inference/inference/web/services/dashboard_service.py b/packages/inference/inference/web/services/dashboard_service.py new file mode 100644 index 0000000..44b2145 --- /dev/null +++ b/packages/inference/inference/web/services/dashboard_service.py @@ -0,0 +1,276 @@ +""" +Dashboard Service + +Business logic for dashboard statistics and activity aggregation. +""" + +import logging +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from sqlalchemy import func, exists, and_, or_ +from sqlmodel import select + +from inference.data.database import get_session_context +from inference.data.admin_models import ( + AdminDocument, + AdminAnnotation, + AnnotationHistory, + TrainingTask, + ModelVersion, +) + +logger = logging.getLogger(__name__) + +# Field class IDs for completeness calculation +# Identifiers: invoice_number (0) or ocr_number (3) +IDENTIFIER_CLASS_IDS = {0, 3} +# Payment accounts: bankgiro (4) or plusgiro (5) +PAYMENT_CLASS_IDS = {4, 5} + + +def is_annotation_complete(annotations: list[dict[str, Any]]) -> bool: + """Check if a document's annotations are complete. + + A document is complete if it has: + - At least one identifier field (invoice_number OR ocr_number) + - At least one payment field (bankgiro OR plusgiro) + + Args: + annotations: List of annotation dicts with class_id + + Returns: + True if document has required fields + """ + class_ids = {ann.get("class_id") for ann in annotations} + + has_identifier = bool(class_ids & IDENTIFIER_CLASS_IDS) + has_payment = bool(class_ids & PAYMENT_CLASS_IDS) + + return has_identifier and has_payment + + +class DashboardStatsService: + """Service for computing dashboard statistics.""" + + def get_stats(self) -> dict[str, Any]: + """Get dashboard statistics. 
+ + Returns: + Dict with total_documents, annotation_complete, annotation_incomplete, + pending, and completeness_rate + """ + with get_session_context() as session: + # Total documents + total = session.exec( + select(func.count()).select_from(AdminDocument) + ).one() + + # Pending documents (status in ['pending', 'auto_labeling']) + pending = session.exec( + select(func.count()) + .select_from(AdminDocument) + .where(AdminDocument.status.in_(["pending", "auto_labeling"])) + ).one() + + # Complete annotations: labeled + has identifier + has payment + complete = self._count_complete(session) + + # Incomplete: labeled but not complete + labeled_count = session.exec( + select(func.count()) + .select_from(AdminDocument) + .where(AdminDocument.status == "labeled") + ).one() + incomplete = labeled_count - complete + + # Calculate completeness rate + total_assessed = complete + incomplete + completeness_rate = ( + round(complete / total_assessed * 100, 2) + if total_assessed > 0 + else 0.0 + ) + + return { + "total_documents": total, + "annotation_complete": complete, + "annotation_incomplete": incomplete, + "pending": pending, + "completeness_rate": completeness_rate, + } + + def _count_complete(self, session) -> int: + """Count documents with complete annotations. + + A document is complete if it: + 1. Has status = 'labeled' + 2. Has at least one identifier annotation (class_id 0 or 3) + 3. Has at least one payment annotation (class_id 4 or 5) + """ + # Subquery for documents with identifier + has_identifier = exists( + select(1) + .select_from(AdminAnnotation) + .where( + and_( + AdminAnnotation.document_id == AdminDocument.document_id, + AdminAnnotation.class_id.in_(IDENTIFIER_CLASS_IDS), + ) + ) + ) + + # Subquery for documents with payment + has_payment = exists( + select(1) + .select_from(AdminAnnotation) + .where( + and_( + AdminAnnotation.document_id == AdminDocument.document_id, + AdminAnnotation.class_id.in_(PAYMENT_CLASS_IDS), + ) + ) + ) + + count = session.exec( + select(func.count()) + .select_from(AdminDocument) + .where( + and_( + AdminDocument.status == "labeled", + has_identifier, + has_payment, + ) + ) + ).one() + + return count + + +class DashboardActivityService: + """Service for aggregating recent activities.""" + + def get_recent_activities(self, limit: int = 10) -> list[dict[str, Any]]: + """Get recent system activities. 
+ + Aggregates from: + - Document uploads + - Annotation modifications + - Training completions/failures + - Model activations + + Args: + limit: Maximum number of activities to return + + Returns: + List of activity dicts sorted by timestamp DESC + """ + activities = [] + + with get_session_context() as session: + # Document uploads (recent 10) + uploads = session.exec( + select(AdminDocument) + .order_by(AdminDocument.created_at.desc()) + .limit(limit) + ).all() + + for doc in uploads: + activities.append({ + "type": "document_uploaded", + "description": f"Uploaded {doc.filename}", + "timestamp": doc.created_at, + "metadata": { + "document_id": str(doc.document_id), + "filename": doc.filename, + }, + }) + + # Annotation modifications (from history) + modifications = session.exec( + select(AnnotationHistory) + .where(AnnotationHistory.action == "override") + .order_by(AnnotationHistory.created_at.desc()) + .limit(limit) + ).all() + + for mod in modifications: + # Get document filename + doc = session.get(AdminDocument, mod.document_id) + filename = doc.filename if doc else "Unknown" + field_name = "" + if mod.new_value and isinstance(mod.new_value, dict): + field_name = mod.new_value.get("class_name", "") + + activities.append({ + "type": "annotation_modified", + "description": f"Modified {filename} {field_name}".strip(), + "timestamp": mod.created_at, + "metadata": { + "annotation_id": str(mod.annotation_id), + "document_id": str(mod.document_id), + "field_name": field_name, + }, + }) + + # Training completions and failures + training_tasks = session.exec( + select(TrainingTask) + .where(TrainingTask.status.in_(["completed", "failed"])) + .order_by(TrainingTask.updated_at.desc()) + .limit(limit) + ).all() + + for task in training_tasks: + if task.updated_at is None: + continue + if task.status == "completed": + # Use metrics_mAP field directly + mAP = task.metrics_mAP or 0.0 + activities.append({ + "type": "training_completed", + "description": f"Training complete: {task.name}, mAP {mAP:.1%}", + "timestamp": task.updated_at, + "metadata": { + "task_id": str(task.task_id), + "task_name": task.name, + "mAP": mAP, + }, + }) + else: + activities.append({ + "type": "training_failed", + "description": f"Training failed: {task.name}", + "timestamp": task.updated_at, + "metadata": { + "task_id": str(task.task_id), + "task_name": task.name, + "error": task.error_message or "", + }, + }) + + # Model activations + model_versions = session.exec( + select(ModelVersion) + .where(ModelVersion.activated_at.is_not(None)) + .order_by(ModelVersion.activated_at.desc()) + .limit(limit) + ).all() + + for model in model_versions: + if model.activated_at is None: + continue + activities.append({ + "type": "model_activated", + "description": f"Activated model {model.version}", + "timestamp": model.activated_at, + "metadata": { + "version_id": str(model.version_id), + "version": model.version, + }, + }) + + # Sort all activities by timestamp DESC and return top N + activities.sort(key=lambda x: x["timestamp"], reverse=True) + return activities[:limit] diff --git a/pyproject.toml b/pyproject.toml index fe13e50..04613fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "black>=23.0.0", "ruff>=0.1.0", "mypy>=1.0.0", + "testcontainers[postgres]>=4.0.0", ] gpu = [ "paddlepaddle-gpu>=2.5.0", diff --git a/run_migration.py b/run_migration.py deleted file mode 100644 index 35cc7f8..0000000 --- a/run_migration.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Run database migration for training_status 
fields.""" -import psycopg2 -import os - -# Read password from .env file -password = "" -try: - with open(".env") as f: - for line in f: - if line.startswith("DB_PASSWORD="): - password = line.strip().split("=", 1)[1].strip('"').strip("'") - break -except Exception as e: - print(f"Error reading .env: {e}") - -print(f"Password found: {bool(password)}") - -conn = psycopg2.connect( - host="192.168.68.31", - port=5432, - database="docmaster", - user="docmaster", - password=password -) -conn.autocommit = True -cur = conn.cursor() - -# Add training_status column -try: - cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL") - print("Added training_status column") -except Exception as e: - print(f"training_status: {e}") - -# Add active_training_task_id column -try: - cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL") - print("Added active_training_task_id column") -except Exception as e: - print(f"active_training_task_id: {e}") - -# Create indexes -try: - cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)") - print("Created training_status index") -except Exception as e: - print(f"index training_status: {e}") - -try: - cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)") - print("Created active_training_task_id index") -except Exception as e: - print(f"index active_training_task_id: {e}") - -# Update existing datasets that have been used in completed training tasks to trained status -try: - cur.execute(""" - UPDATE training_datasets d - SET status = 'trained' - WHERE d.status = 'ready' - AND EXISTS ( - SELECT 1 FROM training_tasks t - WHERE t.dataset_id = d.dataset_id - AND t.status = 'completed' - ) - """) - print(f"Updated {cur.rowcount} datasets to trained status") -except Exception as e: - print(f"update status: {e}") - -cur.close() -conn.close() -print("Migration complete!") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..fe002bb --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for invoice-master-poc-v2.""" diff --git a/tests/integration/api/__init__.py b/tests/integration/api/__init__.py new file mode 100644 index 0000000..a537cdd --- /dev/null +++ b/tests/integration/api/__init__.py @@ -0,0 +1 @@ +"""API integration tests.""" diff --git a/tests/integration/api/test_api_integration.py b/tests/integration/api/test_api_integration.py new file mode 100644 index 0000000..9fff079 --- /dev/null +++ b/tests/integration/api/test_api_integration.py @@ -0,0 +1,389 @@ +""" +API Integration Tests + +Tests FastAPI endpoints with mocked services. +These tests verify the API layer works correctly with the service layer. 
+""" + +import io +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +@dataclass +class MockServiceResult: + """Mock result from inference service.""" + + document_id: str = "test-doc-123" + success: bool = True + document_type: str = "invoice" + fields: dict[str, str] = field(default_factory=lambda: { + "InvoiceNumber": "INV-2024-001", + "Amount": "1500.00", + "InvoiceDate": "2024-01-15", + "OCR": "12345678901234", + "Bankgiro": "1234-5678", + }) + confidence: dict[str, float] = field(default_factory=lambda: { + "InvoiceNumber": 0.95, + "Amount": 0.92, + "InvoiceDate": 0.88, + "OCR": 0.95, + "Bankgiro": 0.90, + }) + detections: list[dict[str, Any]] = field(default_factory=list) + processing_time_ms: float = 150.5 + visualization_path: Path | None = None + errors: list[str] = field(default_factory=list) + + +@pytest.fixture +def temp_storage_dir(): + """Create temporary storage directories.""" + with tempfile.TemporaryDirectory() as tmpdir: + base = Path(tmpdir) + uploads_dir = base / "uploads" / "inference" + results_dir = base / "results" + uploads_dir.mkdir(parents=True, exist_ok=True) + results_dir.mkdir(parents=True, exist_ok=True) + yield { + "base": base, + "uploads": uploads_dir, + "results": results_dir, + } + + +@pytest.fixture +def mock_inference_service(): + """Create a mock inference service.""" + service = MagicMock() + service.is_initialized = True + service.gpu_available = False + + # Create a realistic mock result + mock_result = MockServiceResult() + + service.process_pdf.return_value = mock_result + service.process_image.return_value = mock_result + service.initialize.return_value = None + + return service + + +@pytest.fixture +def mock_storage_config(temp_storage_dir): + """Create mock storage configuration.""" + from inference.web.config import StorageConfig + + return StorageConfig( + upload_dir=temp_storage_dir["uploads"], + result_dir=temp_storage_dir["results"], + max_file_size_mb=50, + ) + + +@pytest.fixture +def mock_storage_helper(temp_storage_dir): + """Create a mock storage helper.""" + helper = MagicMock() + helper.get_uploads_base_path.return_value = temp_storage_dir["uploads"] + helper.get_result_local_path.return_value = None + helper.result_exists.return_value = False + return helper + + +@pytest.fixture +def test_app(mock_inference_service, mock_storage_config, mock_storage_helper): + """Create a test FastAPI application with mocked storage.""" + from inference.web.api.v1.public.inference import create_inference_router + + app = FastAPI() + + # Patch get_storage_helper to return our mock + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + inference_router = create_inference_router(mock_inference_service, mock_storage_config) + app.include_router(inference_router) + + return app + + +@pytest.fixture +def client(test_app, mock_storage_helper): + """Create a test client with storage helper patched.""" + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + yield TestClient(test_app) + + +class TestHealthEndpoint: + """Tests for health check endpoint.""" + + def test_health_check(self, client, mock_inference_service): + """Test health check returns status.""" + response = 
client.get("/api/v1/health") + + assert response.status_code == 200 + data = response.json() + assert "status" in data + assert "model_loaded" in data + + +class TestInferenceEndpoint: + """Tests for inference endpoint.""" + + def test_infer_pdf(self, client, mock_inference_service, mock_storage_helper, temp_storage_dir): + """Test PDF inference endpoint.""" + # Create a minimal PDF content + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + assert response.status_code == 200 + data = response.json() + assert "result" in data + assert data["result"]["success"] is True + assert "InvoiceNumber" in data["result"]["fields"] + + def test_infer_image(self, client, mock_inference_service, mock_storage_helper): + """Test image inference endpoint.""" + # Create minimal PNG header + png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.png", io.BytesIO(png_header), "image/png")}, + ) + + assert response.status_code == 200 + data = response.json() + assert "result" in data + + def test_infer_invalid_file_type(self, client, mock_storage_helper): + """Test rejection of invalid file types.""" + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")}, + ) + + assert response.status_code == 400 + + def test_infer_no_file(self, client, mock_storage_helper): + """Test rejection when no file provided.""" + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post("/api/v1/infer") + + assert response.status_code == 422 # Validation error + + def test_infer_result_structure(self, client, mock_inference_service, mock_storage_helper): + """Test that result has expected structure.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + data = response.json() + result = data["result"] + + # Check required fields + assert "document_id" in result + assert "success" in result + assert "fields" in result + assert "confidence" in result + assert "processing_time_ms" in result + + +class TestInferenceResultFormat: + """Tests for inference result formatting.""" + + def test_result_fields_mapped_correctly(self, client, mock_inference_service, mock_storage_helper): + """Test that fields are mapped to API response format.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + data = response.json() + fields = data["result"]["fields"] + + assert fields["InvoiceNumber"] == "INV-2024-001" + assert fields["Amount"] == "1500.00" + assert fields["InvoiceDate"] == "2024-01-15" + + def 
test_confidence_values_included(self, client, mock_inference_service, mock_storage_helper): + """Test that confidence values are included.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + data = response.json() + confidence = data["result"]["confidence"] + + assert "InvoiceNumber" in confidence + assert confidence["InvoiceNumber"] == 0.95 + + +class TestErrorHandling: + """Tests for error handling in API.""" + + def test_service_error_handling(self, client, mock_inference_service, mock_storage_helper): + """Test handling of service errors.""" + mock_inference_service.process_pdf.side_effect = Exception("Processing failed") + + pdf_content = b"%PDF-1.4\n%test\n" + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + # Should return error response + assert response.status_code >= 400 + + def test_empty_file_handling(self, client, mock_storage_helper): + """Test handling of empty files.""" + # Empty file still has valid content type + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(b""), "application/pdf")}, + ) + + # Empty file may be processed or rejected depending on implementation + # Just verify we get a response + assert response.status_code in [200, 400, 422, 500] + + +class TestResponseFormat: + """Tests for API response format consistency.""" + + def test_success_response_format(self, client, mock_inference_service, mock_storage_helper): + """Test successful response format.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + + data = response.json() + assert isinstance(data, dict) + assert "result" in data + + def test_json_serialization(self, client, mock_inference_service, mock_storage_helper): + """Test that all result fields are JSON serializable.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + # If this doesn't raise, JSON is valid + data = response.json() + assert data is not None + + +class TestDocumentIdGeneration: + """Tests for document ID handling.""" + + def test_document_id_generated(self, client, mock_inference_service, mock_storage_helper): + """Test that document ID is generated.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("test.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + data = response.json() + assert "document_id" in data["result"] + 
assert data["result"]["document_id"] is not None + + def test_document_id_from_filename(self, client, mock_inference_service, mock_storage_helper): + """Test document ID derived from filename.""" + pdf_content = b"%PDF-1.4\n%test\n" + + with patch( + "inference.web.api.v1.public.inference.get_storage_helper", + return_value=mock_storage_helper, + ): + response = client.post( + "/api/v1/infer", + files={"file": ("my_invoice_123.pdf", io.BytesIO(pdf_content), "application/pdf")}, + ) + + data = response.json() + # Document ID should be set (either from filename or generated) + assert data["result"]["document_id"] is not None diff --git a/tests/integration/api/test_dashboard_api_integration.py b/tests/integration/api/test_dashboard_api_integration.py new file mode 100644 index 0000000..ec81202 --- /dev/null +++ b/tests/integration/api/test_dashboard_api_integration.py @@ -0,0 +1,400 @@ +""" +Dashboard API Integration Tests + +Tests Dashboard API endpoints with real database operations via TestClient. +""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from inference.data.admin_models import ( + AdminAnnotation, + AdminDocument, + AdminToken, + AnnotationHistory, + ModelVersion, + TrainingDataset, + TrainingTask, +) +from inference.web.api.v1.admin.dashboard import create_dashboard_router +from inference.web.core.auth import get_admin_token_dep + + +def create_test_app(override_token_dep): + """Create a FastAPI test application with dashboard router.""" + app = FastAPI() + router = create_dashboard_router() + app.include_router(router) + + # Override auth dependency + app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep + + return app + + +class TestDashboardStatsEndpoint: + """Tests for GET /admin/dashboard/stats endpoint.""" + + def test_stats_empty_database(self, patched_session, admin_token): + """Test stats endpoint with empty database.""" + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/stats") + + assert response.status_code == 200 + data = response.json() + assert data["total_documents"] == 0 + assert data["annotation_complete"] == 0 + assert data["annotation_incomplete"] == 0 + assert data["pending"] == 0 + assert data["completeness_rate"] == 0.0 + + def test_stats_with_pending_documents(self, patched_session, admin_token): + """Test stats with pending documents.""" + session = patched_session + + # Create pending documents + for i in range(3): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"pending_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/pending_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/stats") + + assert response.status_code == 200 + data = response.json() + assert data["total_documents"] == 3 + assert data["pending"] == 3 + + def test_stats_with_complete_annotations(self, patched_session, admin_token): + """Test stats with complete annotations.""" + session = patched_session + + # Create labeled document with complete annotations + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + 
filename="complete.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/complete.pdf", + page_count=1, + status="labeled", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + # Add identifier and payment annotations + session.add(AdminAnnotation( + annotation_id=uuid4(), + document_id=doc.document_id, + page_number=1, + class_id=0, # invoice_number + class_name="invoice_number", + x_center=0.5, y_center=0.1, width=0.2, height=0.05, + bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + )) + session.add(AdminAnnotation( + annotation_id=uuid4(), + document_id=doc.document_id, + page_number=1, + class_id=4, # bankgiro + class_name="bankgiro", + x_center=0.5, y_center=0.2, width=0.2, height=0.05, + bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + )) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/stats") + + assert response.status_code == 200 + data = response.json() + assert data["annotation_complete"] == 1 + assert data["completeness_rate"] == 100.0 + + +class TestActiveModelEndpoint: + """Tests for GET /admin/dashboard/active-model endpoint.""" + + def test_active_model_none(self, patched_session, admin_token): + """Test active-model endpoint with no active model.""" + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/active-model") + + assert response.status_code == 200 + data = response.json() + assert data["model"] is None + assert data["running_training"] is None + + def test_active_model_with_model(self, patched_session, admin_token, sample_dataset): + """Test active-model endpoint with active model.""" + session = patched_session + + # Create training task + task = TrainingTask( + task_id=uuid4(), + admin_token=admin_token.token, + name="Test Task", + status="completed", + task_type="train", + dataset_id=sample_dataset.dataset_id, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(task) + session.commit() + + # Create active model + model = ModelVersion( + version_id=uuid4(), + version="1.0.0", + name="Test Model", + model_path="/models/test.pt", + status="active", + is_active=True, + task_id=task.task_id, + dataset_id=sample_dataset.dataset_id, + metrics_mAP=0.90, + metrics_precision=0.88, + metrics_recall=0.85, + document_count=100, + file_size=50000000, + activated_at=datetime.now(timezone.utc), + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(model) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/active-model") + + assert response.status_code == 200 + data = response.json() + assert data["model"] is not None + assert data["model"]["version"] == "1.0.0" + assert data["model"]["name"] == "Test Model" + assert data["model"]["metrics_mAP"] == 0.90 + + def test_active_model_with_running_training(self, patched_session, admin_token, sample_dataset): + """Test active-model endpoint with running training.""" + session = patched_session + + # Create running training task + task = TrainingTask( + task_id=uuid4(), + 
admin_token=admin_token.token, + name="Running Task", + status="running", + task_type="train", + dataset_id=sample_dataset.dataset_id, + started_at=datetime.now(timezone.utc), + progress=50, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(task) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/active-model") + + assert response.status_code == 200 + data = response.json() + assert data["running_training"] is not None + assert data["running_training"]["name"] == "Running Task" + assert data["running_training"]["status"] == "running" + assert data["running_training"]["progress"] == 50 + + +class TestRecentActivityEndpoint: + """Tests for GET /admin/dashboard/activity endpoint.""" + + def test_activity_empty(self, patched_session, admin_token): + """Test activity endpoint with no activities.""" + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/activity") + + assert response.status_code == 200 + data = response.json() + assert data["activities"] == [] + + def test_activity_with_uploads(self, patched_session, admin_token): + """Test activity includes document uploads.""" + session = patched_session + + # Create documents + for i in range(3): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"activity_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/activity_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/activity") + + assert response.status_code == 200 + data = response.json() + upload_activities = [a for a in data["activities"] if a["type"] == "document_uploaded"] + assert len(upload_activities) == 3 + + def test_activity_limit_parameter(self, patched_session, admin_token): + """Test activity limit parameter.""" + session = patched_session + + # Create many documents + for i in range(15): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"limit_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/limit_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/activity?limit=5") + + assert response.status_code == 200 + data = response.json() + assert len(data["activities"]) <= 5 + + def test_activity_invalid_limit(self, patched_session, admin_token): + """Test activity with invalid limit parameter.""" + app = create_test_app(admin_token.token) + client = TestClient(app) + + # Limit too high + response = client.get("/admin/dashboard/activity?limit=100") + assert response.status_code == 422 + + # Limit too low + response = client.get("/admin/dashboard/activity?limit=0") + assert response.status_code == 422 + + def test_activity_with_training_completion(self, patched_session, admin_token, sample_dataset): + """Test activity includes training completions.""" + session = patched_session + + # Create completed 
training task + task = TrainingTask( + task_id=uuid4(), + admin_token=admin_token.token, + name="Completed Task", + status="completed", + task_type="train", + dataset_id=sample_dataset.dataset_id, + metrics_mAP=0.95, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(task) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/activity") + + assert response.status_code == 200 + data = response.json() + training_activities = [a for a in data["activities"] if a["type"] == "training_completed"] + assert len(training_activities) >= 1 + + def test_activity_sorted_by_timestamp(self, patched_session, admin_token): + """Test activities are sorted by timestamp descending.""" + session = patched_session + + # Create documents + for i in range(5): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"sorted_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/sorted_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + app = create_test_app(admin_token.token) + client = TestClient(app) + + response = client.get("/admin/dashboard/activity") + + assert response.status_code == 200 + data = response.json() + timestamps = [a["timestamp"] for a in data["activities"]] + assert timestamps == sorted(timestamps, reverse=True) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..6df8c55 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,465 @@ +""" +Integration Test Fixtures + +Provides shared fixtures for integration tests using PostgreSQL. + +IMPORTANT: Integration tests MUST use Docker testcontainers for database isolation. +This ensures tests never touch the real production/development database. + +Supported modes: +1. Docker testcontainers (default): Automatically starts a PostgreSQL container +2. TEST_DB_URL environment variable: Use a dedicated test database (NOT production!) + +To use an external test database, set: + TEST_DB_URL=postgresql://user:password@host:port/test_dbname +""" + +import os +import tempfile +from contextlib import contextmanager, ExitStack +from datetime import datetime, timezone +from pathlib import Path +from typing import Generator +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlmodel import Session, SQLModel, create_engine + +from inference.data.admin_models import ( + AdminAnnotation, + AdminDocument, + AdminToken, + AnnotationHistory, + BatchUpload, + BatchUploadFile, + DatasetDocument, + ModelVersion, + TrainingDataset, + TrainingDocumentLink, + TrainingLog, + TrainingTask, +) + + +# ============================================================================= +# Database Fixtures +# ============================================================================= + + +def _is_docker_available() -> bool: + """Check if Docker is available.""" + try: + import docker + client = docker.from_env() + client.ping() + return True + except Exception: + return False + + +def _get_test_db_url() -> str | None: + """Get test database URL from environment.""" + return os.environ.get("TEST_DB_URL") + + +@pytest.fixture(scope="session") +def test_engine(): + """Create a SQLAlchemy engine for testing. + + Uses one of: + 1. 
TEST_DB_URL environment variable (dedicated test database)
+    2. Docker testcontainers (if Docker is available)
+
+    IMPORTANT: Will NOT fall back to production database. If Docker is not
+    available and TEST_DB_URL is not set, tests will fail with a clear error.
+
+    The engine is shared across all tests in a session for efficiency.
+    """
+    # Try to get URL from environment first
+    connection_url = _get_test_db_url()
+    # Container handle, kept in this generator's scope so the teardown code
+    # after the yield can stop the container we started (if any)
+    postgres_container = None
+
+    if connection_url:
+        # Use external test database from environment
+        # Warn if it looks like a production database
+        if "docmaster" in connection_url and "_test" not in connection_url:
+            import warnings
+            warnings.warn(
+                "TEST_DB_URL appears to point to a production database. "
+                "Please use a dedicated test database (e.g., docmaster_test).",
+                UserWarning,
+            )
+    elif _is_docker_available():
+        # Use testcontainers - this is the recommended approach
+        from testcontainers.postgres import PostgresContainer
+        postgres_container = PostgresContainer("postgres:15-alpine")
+        postgres_container.start()
+
+        connection_url = postgres_container.get_connection_url()
+        if "psycopg2" in connection_url:
+            connection_url = connection_url.replace("postgresql+psycopg2://", "postgresql://")
+    else:
+        # No Docker and no TEST_DB_URL - fail with clear instructions
+        pytest.fail(
+            "Integration tests require Docker or a TEST_DB_URL environment variable.\n\n"
+            "Option 1 (Recommended): Install Docker Desktop and ensure it's running.\n"
+            " - Windows: https://docs.docker.com/desktop/install/windows-install/\n"
+            " - The testcontainers library will automatically create a PostgreSQL container.\n\n"
+            "Option 2: Set TEST_DB_URL to a dedicated test database:\n"
+            " - export TEST_DB_URL=postgresql://user:password@host:port/test_dbname\n"
+            " - NEVER use your production database for tests!\n\n"
+            "Integration tests will NOT fall back to the production database."
+        )
+
+    engine = create_engine(
+        connection_url,
+        echo=False,
+        pool_pre_ping=True,
+    )
+
+    # Create all tables
+    SQLModel.metadata.create_all(engine)
+
+    yield engine
+
+    # Cleanup
+    SQLModel.metadata.drop_all(engine)
+    engine.dispose()
+
+    # Stop container if we started one
+    if postgres_container is not None:
+        postgres_container.stop()
+
+
+@pytest.fixture(scope="function")
+def db_session(test_engine) -> Generator[Session, None, None]:
+    """Provide a database session for each test function.
+
+    Each test gets a fresh session that rolls back after the test,
+    ensuring test isolation.
+    """
+    connection = test_engine.connect()
+    transaction = connection.begin()
+    session = Session(bind=connection)
+
+    yield session
+
+    # Rollback and cleanup
+    session.close()
+    transaction.rollback()
+    connection.close()
+
+
+@pytest.fixture(scope="function")
+def patched_session(db_session):
+    """Patch get_session_context to use the test session.
+
+    This allows repository classes to use the test database session
+    instead of creating their own connections.
+
+    We need to patch in multiple locations because each repository module
+    imports get_session_context directly.
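+
+    A minimal usage sketch (hypothetical test; DocumentRepository is one of the
+    repository classes whose module is patched above):
+
+        def test_sketch(patched_session):
+            repo = DocumentRepository()
+            doc_id = repo.create(
+                filename="sketch.pdf",
+                file_size=1024,
+                content_type="application/pdf",
+                file_path="/uploads/sketch.pdf",
+            )
+            assert repo.get(doc_id) is not None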
+ """ + + @contextmanager + def mock_session_context() -> Generator[Session, None, None]: + yield db_session + + # All modules that import get_session_context + patch_targets = [ + "inference.data.database.get_session_context", + "inference.data.repositories.document_repository.get_session_context", + "inference.data.repositories.annotation_repository.get_session_context", + "inference.data.repositories.dataset_repository.get_session_context", + "inference.data.repositories.training_task_repository.get_session_context", + "inference.data.repositories.model_version_repository.get_session_context", + "inference.data.repositories.batch_upload_repository.get_session_context", + "inference.data.repositories.token_repository.get_session_context", + "inference.web.services.dashboard_service.get_session_context", + ] + + with ExitStack() as stack: + for target in patch_targets: + try: + stack.enter_context(patch(target, mock_session_context)) + except (ModuleNotFoundError, AttributeError): + # Skip if module doesn't exist or doesn't have the attribute + pass + yield db_session + + +# ============================================================================= +# Test Data Fixtures +# ============================================================================= + + +@pytest.fixture +def admin_token(db_session) -> AdminToken: + """Create a test admin token.""" + token = AdminToken( + token="test-admin-token-12345", + name="Test Admin", + is_active=True, + created_at=datetime.now(timezone.utc), + ) + db_session.add(token) + db_session.commit() + db_session.refresh(token) + return token + + +@pytest.fixture +def sample_document(db_session, admin_token) -> AdminDocument: + """Create a sample document for testing.""" + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename="test_invoice.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/test_invoice.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(doc) + db_session.commit() + db_session.refresh(doc) + return doc + + +@pytest.fixture +def sample_annotation(db_session, sample_document) -> AdminAnnotation: + """Create a sample annotation for testing.""" + annotation = AdminAnnotation( + annotation_id=uuid4(), + document_id=sample_document.document_id, + page_number=1, + class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.3, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=240, + bbox_width=160, + bbox_height=40, + text_value="INV-2024-001", + confidence=0.95, + source="auto", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(annotation) + db_session.commit() + db_session.refresh(annotation) + return annotation + + +@pytest.fixture +def sample_dataset(db_session) -> TrainingDataset: + """Create a sample training dataset for testing.""" + dataset = TrainingDataset( + dataset_id=uuid4(), + name="Test Dataset", + description="Dataset for integration testing", + status="building", + train_ratio=0.8, + val_ratio=0.1, + seed=42, + total_documents=0, + total_images=0, + total_annotations=0, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(dataset) + db_session.commit() + db_session.refresh(dataset) + return dataset + + +@pytest.fixture +def sample_training_task(db_session, admin_token, sample_dataset) -> TrainingTask: + """Create a sample 
training task for testing.""" + task = TrainingTask( + task_id=uuid4(), + admin_token=admin_token.token, + name="Test Training Task", + description="Training task for integration testing", + status="pending", + task_type="train", + dataset_id=sample_dataset.dataset_id, + config={"epochs": 10, "batch_size": 16}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + return task + + +@pytest.fixture +def sample_model_version(db_session, sample_training_task, sample_dataset) -> ModelVersion: + """Create a sample model version for testing.""" + version = ModelVersion( + version_id=uuid4(), + version="1.0.0", + name="Test Model v1", + description="Model version for integration testing", + model_path="/models/test_model.pt", + status="inactive", + is_active=False, + task_id=sample_training_task.task_id, + dataset_id=sample_dataset.dataset_id, + metrics_mAP=0.85, + metrics_precision=0.88, + metrics_recall=0.82, + document_count=100, + file_size=50000000, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(version) + db_session.commit() + db_session.refresh(version) + return version + + +@pytest.fixture +def sample_batch_upload(db_session, admin_token) -> BatchUpload: + """Create a sample batch upload for testing.""" + batch = BatchUpload( + batch_id=uuid4(), + admin_token=admin_token.token, + filename="test_batch.zip", + file_size=10240, + upload_source="api", + status="processing", + total_files=5, + processed_files=0, + successful_files=0, + failed_files=0, + created_at=datetime.now(timezone.utc), + ) + db_session.add(batch) + db_session.commit() + db_session.refresh(batch) + return batch + + +# ============================================================================= +# Multiple Documents Fixture +# ============================================================================= + + +@pytest.fixture +def multiple_documents(db_session, admin_token) -> list[AdminDocument]: + """Create multiple documents for pagination/filtering tests.""" + documents = [] + statuses = ["pending", "pending", "labeled", "labeled", "exported"] + categories = ["invoice", "invoice", "invoice", "letter", "invoice"] + + for i, (status, category) in enumerate(zip(statuses, categories)): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"test_doc_{i}.pdf", + file_size=1024 + i * 100, + content_type="application/pdf", + file_path=f"/uploads/test_doc_{i}.pdf", + page_count=1, + status=status, + upload_source="ui", + category=category, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(doc) + documents.append(doc) + + db_session.commit() + for doc in documents: + db_session.refresh(doc) + + return documents + + +# ============================================================================= +# Temporary File Fixtures +# ============================================================================= + + +@pytest.fixture +def temp_upload_dir() -> Generator[Path, None, None]: + """Create a temporary directory for file uploads.""" + with tempfile.TemporaryDirectory() as tmpdir: + upload_dir = Path(tmpdir) / "uploads" + upload_dir.mkdir(parents=True, exist_ok=True) + yield upload_dir + + +@pytest.fixture +def temp_model_dir() -> Generator[Path, None, None]: + """Create a temporary directory for model files.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) / 
"models" + model_dir.mkdir(parents=True, exist_ok=True) + yield model_dir + + +@pytest.fixture +def temp_dataset_dir() -> Generator[Path, None, None]: + """Create a temporary directory for dataset files.""" + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) / "datasets" + dataset_dir.mkdir(parents=True, exist_ok=True) + yield dataset_dir + + +# ============================================================================= +# Sample PDF Fixture +# ============================================================================= + + +@pytest.fixture +def sample_pdf_bytes() -> bytes: + """Return minimal valid PDF bytes for testing.""" + # Minimal valid PDF structure + return b"""%PDF-1.4 +1 0 obj +<< /Type /Catalog /Pages 2 0 R >> +endobj +2 0 obj +<< /Type /Pages /Kids [3 0 R] /Count 1 >> +endobj +3 0 obj +<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] >> +endobj +xref +0 4 +0000000000 65535 f +0000000009 00000 n +0000000058 00000 n +0000000115 00000 n +trailer +<< /Size 4 /Root 1 0 R >> +startxref +196 +%%EOF""" + + +@pytest.fixture +def sample_pdf_file(temp_upload_dir, sample_pdf_bytes) -> Path: + """Create a sample PDF file for testing.""" + pdf_path = temp_upload_dir / "test_invoice.pdf" + pdf_path.write_bytes(sample_pdf_bytes) + return pdf_path diff --git a/tests/integration/pipeline/__init__.py b/tests/integration/pipeline/__init__.py new file mode 100644 index 0000000..fe2e66e --- /dev/null +++ b/tests/integration/pipeline/__init__.py @@ -0,0 +1 @@ +"""Pipeline integration tests.""" diff --git a/tests/integration/pipeline/test_pipeline_integration.py b/tests/integration/pipeline/test_pipeline_integration.py new file mode 100644 index 0000000..0de7252 --- /dev/null +++ b/tests/integration/pipeline/test_pipeline_integration.py @@ -0,0 +1,456 @@ +""" +Inference Pipeline Integration Tests + +Tests the complete pipeline from input to output. +Note: These tests use mocks for YOLO and OCR to avoid requiring actual models, +but test the integration of pipeline components. 
+""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +import numpy as np + +from inference.pipeline.pipeline import ( + InferencePipeline, + InferenceResult, + CrossValidationResult, +) +from inference.pipeline.yolo_detector import Detection +from inference.pipeline.field_extractor import ExtractedField + + +@pytest.fixture +def mock_detection(): + """Create a mock detection.""" + return Detection( + class_id=0, + class_name="invoice_number", + confidence=0.95, + bbox=(100, 50, 200, 30), + page_no=0, + ) + + +@pytest.fixture +def mock_extracted_field(): + """Create a mock extracted field.""" + return ExtractedField( + field_name="InvoiceNumber", + raw_text="INV-2024-001", + normalized_value="INV-2024-001", + confidence=0.95, + bbox=(100, 50, 200, 30), + page_no=0, + is_valid=True, + ) + + +class TestInferenceResultConstruction: + """Tests for InferenceResult construction and methods.""" + + def test_default_result(self): + """Test default InferenceResult values.""" + result = InferenceResult() + + assert result.document_id is None + assert result.success is False + assert result.fields == {} + assert result.confidence == {} + assert result.raw_detections == [] + assert result.extracted_fields == [] + assert result.errors == [] + assert result.fallback_used is False + assert result.cross_validation is None + + def test_result_to_json(self): + """Test JSON serialization of result.""" + result = InferenceResult( + document_id="test-doc", + success=True, + fields={ + "InvoiceNumber": "INV-001", + "Amount": "1500.00", + }, + confidence={ + "InvoiceNumber": 0.95, + "Amount": 0.92, + }, + bboxes={ + "InvoiceNumber": (100, 50, 200, 30), + }, + ) + + json_data = result.to_json() + + assert json_data["DocumentId"] == "test-doc" + assert json_data["success"] is True + assert json_data["InvoiceNumber"] == "INV-001" + assert json_data["Amount"] == "1500.00" + assert json_data["confidence"]["InvoiceNumber"] == 0.95 + assert "bboxes" in json_data + + def test_result_get_field(self): + """Test getting field value and confidence.""" + result = InferenceResult( + fields={"InvoiceNumber": "INV-001"}, + confidence={"InvoiceNumber": 0.95}, + ) + + value, conf = result.get_field("InvoiceNumber") + assert value == "INV-001" + assert conf == 0.95 + + value, conf = result.get_field("Amount") + assert value is None + assert conf == 0.0 + + +class TestCrossValidation: + """Tests for cross-validation logic.""" + + def test_cross_validation_default(self): + """Test default CrossValidationResult values.""" + cv = CrossValidationResult() + + assert cv.is_valid is False + assert cv.ocr_match is None + assert cv.amount_match is None + assert cv.bankgiro_match is None + assert cv.plusgiro_match is None + assert cv.payment_line_ocr is None + assert cv.payment_line_amount is None + assert cv.details == [] + + def test_cross_validation_with_matches(self): + """Test CrossValidationResult with matches.""" + cv = CrossValidationResult( + is_valid=True, + ocr_match=True, + amount_match=True, + bankgiro_match=True, + payment_line_ocr="12345678901234", + payment_line_amount="1500.00", + payment_line_account="1234-5678", + payment_line_account_type="bankgiro", + details=["OCR match", "Amount match", "Bankgiro match"], + ) + + assert cv.is_valid is True + assert cv.ocr_match is True + assert cv.amount_match is True + assert len(cv.details) == 3 + + +class TestPipelineMergeFields: + """Tests for field merging logic.""" + + def 
test_merge_selects_highest_confidence(self): + """Test that merge selects highest confidence for duplicate fields.""" + # Create mock pipeline with minimal mocking + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + pipeline.payment_line_parser = MagicMock() + pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False) + + result = InferenceResult() + result.extracted_fields = [ + ExtractedField( + field_name="InvoiceNumber", + raw_text="INV-001", + normalized_value="INV-001", + confidence=0.85, + detection_confidence=0.90, + ocr_confidence=0.85, + bbox=(100, 50, 200, 30), + page_no=0, + is_valid=True, + ), + ExtractedField( + field_name="InvoiceNumber", + raw_text="INV-001", + normalized_value="INV-001", + confidence=0.95, # Higher confidence + detection_confidence=0.95, + ocr_confidence=0.95, + bbox=(105, 52, 198, 28), + page_no=0, + is_valid=True, + ), + ] + + pipeline._merge_fields(result) + + assert result.fields["InvoiceNumber"] == "INV-001" + assert result.confidence["InvoiceNumber"] == 0.95 + + def test_merge_skips_invalid_fields(self): + """Test that merge skips invalid extracted fields.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + pipeline.payment_line_parser = MagicMock() + pipeline.payment_line_parser.parse.return_value = MagicMock(is_valid=False) + + result = InferenceResult() + result.extracted_fields = [ + ExtractedField( + field_name="InvoiceNumber", + raw_text="", + normalized_value=None, + confidence=0.95, + detection_confidence=0.95, + ocr_confidence=0.95, + bbox=(100, 50, 200, 30), + page_no=0, + is_valid=False, # Invalid + ), + ExtractedField( + field_name="Amount", + raw_text="1500.00", + normalized_value="1500.00", + confidence=0.92, + detection_confidence=0.92, + ocr_confidence=0.92, + bbox=(200, 100, 100, 25), + page_no=0, + is_valid=True, + ), + ] + + pipeline._merge_fields(result) + + assert "InvoiceNumber" not in result.fields + assert result.fields["Amount"] == "1500.00" + + +class TestPaymentLineValidation: + """Tests for payment line cross-validation.""" + + def test_payment_line_overrides_ocr(self): + """Test that payment line OCR overrides detected OCR.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + # Mock payment line parser + mock_parsed = MagicMock() + mock_parsed.is_valid = True + mock_parsed.ocr_number = "12345678901234" + mock_parsed.amount = "1500.00" + mock_parsed.account_number = "12345678" + + pipeline.payment_line_parser = MagicMock() + pipeline.payment_line_parser.parse.return_value = mock_parsed + + result = InferenceResult( + fields={ + "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#", + "OCR": "99999999999999", # Different OCR + }, + confidence={"OCR": 0.85}, + ) + + pipeline._cross_validate_payment_line(result) + + # Payment line OCR should override + assert result.fields["OCR"] == "12345678901234" + assert result.confidence["OCR"] == 0.95 + + def test_payment_line_overrides_amount(self): + """Test that payment line amount overrides detected amount.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + mock_parsed = MagicMock() + mock_parsed.is_valid = True + mock_parsed.ocr_number = None + mock_parsed.amount = "2500.50" + 
mock_parsed.account_number = None + + pipeline.payment_line_parser = MagicMock() + pipeline.payment_line_parser.parse.return_value = mock_parsed + + result = InferenceResult( + fields={ + "payment_line": "# ... # 2500 50 5 > ...", + "Amount": "2500.00", # Slightly different + }, + confidence={"Amount": 0.80}, + ) + + pipeline._cross_validate_payment_line(result) + + assert result.fields["Amount"] == "2500.50" + assert result.confidence["Amount"] == 0.95 + + def test_cross_validation_records_matches(self): + """Test that cross-validation records match status.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + mock_parsed = MagicMock() + mock_parsed.is_valid = True + mock_parsed.ocr_number = "12345678901234" + mock_parsed.amount = "1500.00" + mock_parsed.account_number = "12345678" + + pipeline.payment_line_parser = MagicMock() + pipeline.payment_line_parser.parse.return_value = mock_parsed + + result = InferenceResult( + fields={ + "payment_line": "# 12345678901234 # 1500 00 5 > 12345678#41#", + "OCR": "12345678901234", # Matching + "Amount": "1500.00", # Matching + "Bankgiro": "1234-5678", # Matching + }, + confidence={}, + ) + + pipeline._cross_validate_payment_line(result) + + assert result.cross_validation is not None + assert result.cross_validation.ocr_match is True + assert result.cross_validation.amount_match is True + assert result.cross_validation.is_valid is True + + +class TestFallbackLogic: + """Tests for fallback detection logic.""" + + def test_needs_fallback_when_key_fields_missing(self): + """Test fallback is triggered when key fields missing.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + # Only one key field present + result = InferenceResult(fields={"Amount": "1500.00"}) + + assert pipeline._needs_fallback(result) is True + + def test_no_fallback_when_fields_present(self): + """Test no fallback when key fields present.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + # All key fields present + result = InferenceResult( + fields={ + "Amount": "1500.00", + "InvoiceNumber": "INV-001", + "OCR": "12345678901234", + } + ) + + assert pipeline._needs_fallback(result) is False + + +class TestPatternExtraction: + """Tests for fallback pattern extraction.""" + + def test_extract_amount_pattern(self): + """Test amount extraction with regex.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + text = "Att betala: 1 500,00 SEK" + result = InferenceResult() + + pipeline._extract_with_patterns(text, result) + + assert "Amount" in result.fields + assert result.confidence["Amount"] == 0.5 + + def test_extract_bankgiro_pattern(self): + """Test bankgiro extraction with regex.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + text = "Bankgiro: 1234-5678" + result = InferenceResult() + + pipeline._extract_with_patterns(text, result) + + assert "Bankgiro" in result.fields + assert result.fields["Bankgiro"] == "1234-5678" + + def test_extract_ocr_pattern(self): + """Test OCR extraction with regex.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = 
InferencePipeline.__new__(InferencePipeline) + + text = "OCR: 12345678901234567890" + result = InferenceResult() + + pipeline._extract_with_patterns(text, result) + + assert "OCR" in result.fields + assert result.fields["OCR"] == "12345678901234567890" + + def test_does_not_override_existing_fields(self): + """Test pattern extraction doesn't override existing fields.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + text = "Fakturanr: 999" + result = InferenceResult(fields={"InvoiceNumber": "INV-001"}) + + pipeline._extract_with_patterns(text, result) + + # Should keep existing value + assert result.fields["InvoiceNumber"] == "INV-001" + + +class TestAmountNormalization: + """Tests for amount normalization.""" + + def test_normalize_swedish_format(self): + """Test normalizing Swedish amount format.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + # Swedish format: space as thousands separator, comma as decimal + assert pipeline._normalize_amount_for_compare("1 500,00") == 1500.00 + # Standard format: dot as decimal + assert pipeline._normalize_amount_for_compare("1500.00") == 1500.00 + # Swedish format with comma as decimal + assert pipeline._normalize_amount_for_compare("1500,00") == 1500.00 + + def test_normalize_invalid_amount(self): + """Test normalizing invalid amount returns None.""" + with patch.object(InferencePipeline, "__init__", lambda x, *args, **kwargs: None): + pipeline = InferencePipeline.__new__(InferencePipeline) + + assert pipeline._normalize_amount_for_compare("invalid") is None + assert pipeline._normalize_amount_for_compare("") is None + + +class TestResultSerialization: + """Tests for result serialization with cross-validation.""" + + def test_to_json_with_cross_validation(self): + """Test JSON serialization includes cross-validation.""" + cv = CrossValidationResult( + is_valid=True, + ocr_match=True, + amount_match=True, + payment_line_ocr="12345678901234", + payment_line_amount="1500.00", + details=["OCR match", "Amount match"], + ) + + result = InferenceResult( + document_id="test-doc", + success=True, + fields={"InvoiceNumber": "INV-001"}, + cross_validation=cv, + ) + + json_data = result.to_json() + + assert "cross_validation" in json_data + assert json_data["cross_validation"]["is_valid"] is True + assert json_data["cross_validation"]["ocr_match"] is True + assert json_data["cross_validation"]["payment_line_ocr"] == "12345678901234" diff --git a/tests/integration/repositories/__init__.py b/tests/integration/repositories/__init__.py new file mode 100644 index 0000000..97dcb2d --- /dev/null +++ b/tests/integration/repositories/__init__.py @@ -0,0 +1 @@ +"""Repository integration tests.""" diff --git a/tests/integration/repositories/test_annotation_repo_integration.py b/tests/integration/repositories/test_annotation_repo_integration.py new file mode 100644 index 0000000..c38024f --- /dev/null +++ b/tests/integration/repositories/test_annotation_repo_integration.py @@ -0,0 +1,464 @@ +""" +Annotation Repository Integration Tests + +Tests AnnotationRepository with real database operations. 
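+
+All tests below run through the patched_session fixture from
+tests/integration/conftest.py, so each repository call hits the test database.
+A representative sketch (illustrative values; see the fixtures for real data):
+
+    repo = AnnotationRepository()
+    ann_id = repo.create(
+        document_id=str(doc.document_id), page_number=1,
+        class_id=0, class_name="invoice_number",
+        x_center=0.5, y_center=0.3, width=0.2, height=0.05,
+        bbox_x=400, bbox_y=240, bbox_width=160, bbox_height=40,
+    )
+    assert repo.get(ann_id) is not None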
+""" + +from uuid import uuid4 + +import pytest + +from inference.data.repositories.annotation_repository import AnnotationRepository + + +class TestAnnotationRepositoryCreate: + """Tests for annotation creation.""" + + def test_create_annotation(self, patched_session, sample_document): + """Test creating a single annotation.""" + repo = AnnotationRepository() + + ann_id = repo.create( + document_id=str(sample_document.document_id), + page_number=1, + class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.3, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=240, + bbox_width=160, + bbox_height=40, + text_value="INV-2024-001", + confidence=0.95, + source="auto", + ) + + assert ann_id is not None + + ann = repo.get(ann_id) + assert ann is not None + assert ann.class_name == "invoice_number" + assert ann.text_value == "INV-2024-001" + assert ann.confidence == 0.95 + assert ann.source == "auto" + + def test_create_batch_annotations(self, patched_session, sample_document): + """Test batch creation of annotations.""" + repo = AnnotationRepository() + + annotations_data = [ + { + "document_id": str(sample_document.document_id), + "page_number": 1, + "class_id": 0, + "class_name": "invoice_number", + "x_center": 0.5, + "y_center": 0.1, + "width": 0.2, + "height": 0.05, + "bbox_x": 400, + "bbox_y": 80, + "bbox_width": 160, + "bbox_height": 40, + "text_value": "INV-001", + "confidence": 0.95, + }, + { + "document_id": str(sample_document.document_id), + "page_number": 1, + "class_id": 1, + "class_name": "invoice_date", + "x_center": 0.5, + "y_center": 0.2, + "width": 0.15, + "height": 0.04, + "bbox_x": 400, + "bbox_y": 160, + "bbox_width": 120, + "bbox_height": 32, + "text_value": "2024-01-15", + "confidence": 0.92, + }, + { + "document_id": str(sample_document.document_id), + "page_number": 1, + "class_id": 6, + "class_name": "amount", + "x_center": 0.7, + "y_center": 0.8, + "width": 0.1, + "height": 0.04, + "bbox_x": 560, + "bbox_y": 640, + "bbox_width": 80, + "bbox_height": 32, + "text_value": "1500.00", + "confidence": 0.98, + }, + ] + + ids = repo.create_batch(annotations_data) + + assert len(ids) == 3 + + # Verify all annotations exist + for ann_id in ids: + ann = repo.get(ann_id) + assert ann is not None + + +class TestAnnotationRepositoryRead: + """Tests for annotation retrieval.""" + + def test_get_nonexistent_annotation(self, patched_session): + """Test getting an annotation that doesn't exist.""" + repo = AnnotationRepository() + + ann = repo.get(str(uuid4())) + assert ann is None + + def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation): + """Test getting all annotations for a document.""" + repo = AnnotationRepository() + + # Add another annotation + repo.create( + document_id=str(sample_document.document_id), + page_number=1, + class_id=1, + class_name="invoice_date", + x_center=0.5, + y_center=0.4, + width=0.15, + height=0.04, + bbox_x=400, + bbox_y=320, + bbox_width=120, + bbox_height=32, + text_value="2024-01-15", + ) + + annotations = repo.get_for_document(str(sample_document.document_id)) + + assert len(annotations) == 2 + # Should be ordered by class_id + assert annotations[0].class_id == 0 + assert annotations[1].class_id == 1 + + def test_get_annotations_for_specific_page(self, patched_session, sample_document): + """Test getting annotations for a specific page.""" + repo = AnnotationRepository() + + # Create annotations on different pages + repo.create( + document_id=str(sample_document.document_id), + page_number=1, 
+ class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.1, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80, + bbox_width=160, + bbox_height=40, + ) + repo.create( + document_id=str(sample_document.document_id), + page_number=2, + class_id=6, + class_name="amount", + x_center=0.7, + y_center=0.8, + width=0.1, + height=0.04, + bbox_x=560, + bbox_y=640, + bbox_width=80, + bbox_height=32, + ) + + page1_annotations = repo.get_for_document( + str(sample_document.document_id), + page_number=1, + ) + page2_annotations = repo.get_for_document( + str(sample_document.document_id), + page_number=2, + ) + + assert len(page1_annotations) == 1 + assert len(page2_annotations) == 1 + assert page1_annotations[0].page_number == 1 + assert page2_annotations[0].page_number == 2 + + +class TestAnnotationRepositoryUpdate: + """Tests for annotation updates.""" + + def test_update_annotation_bbox(self, patched_session, sample_annotation): + """Test updating annotation bounding box.""" + repo = AnnotationRepository() + + result = repo.update( + str(sample_annotation.annotation_id), + x_center=0.6, + y_center=0.4, + width=0.25, + height=0.06, + bbox_x=480, + bbox_y=320, + bbox_width=200, + bbox_height=48, + ) + + assert result is True + + ann = repo.get(str(sample_annotation.annotation_id)) + assert ann is not None + assert ann.x_center == 0.6 + assert ann.y_center == 0.4 + assert ann.bbox_x == 480 + assert ann.bbox_width == 200 + + def test_update_annotation_text(self, patched_session, sample_annotation): + """Test updating annotation text value.""" + repo = AnnotationRepository() + + result = repo.update( + str(sample_annotation.annotation_id), + text_value="INV-2024-002", + ) + + assert result is True + + ann = repo.get(str(sample_annotation.annotation_id)) + assert ann is not None + assert ann.text_value == "INV-2024-002" + + def test_update_annotation_class(self, patched_session, sample_annotation): + """Test updating annotation class.""" + repo = AnnotationRepository() + + result = repo.update( + str(sample_annotation.annotation_id), + class_id=1, + class_name="invoice_date", + ) + + assert result is True + + ann = repo.get(str(sample_annotation.annotation_id)) + assert ann is not None + assert ann.class_id == 1 + assert ann.class_name == "invoice_date" + + def test_update_nonexistent_annotation(self, patched_session): + """Test updating annotation that doesn't exist.""" + repo = AnnotationRepository() + + result = repo.update( + str(uuid4()), + text_value="new value", + ) + + assert result is False + + +class TestAnnotationRepositoryDelete: + """Tests for annotation deletion.""" + + def test_delete_annotation(self, patched_session, sample_annotation): + """Test deleting a single annotation.""" + repo = AnnotationRepository() + + result = repo.delete(str(sample_annotation.annotation_id)) + assert result is True + + ann = repo.get(str(sample_annotation.annotation_id)) + assert ann is None + + def test_delete_nonexistent_annotation(self, patched_session): + """Test deleting annotation that doesn't exist.""" + repo = AnnotationRepository() + + result = repo.delete(str(uuid4())) + assert result is False + + def test_delete_annotations_for_document(self, patched_session, sample_document): + """Test deleting all annotations for a document.""" + repo = AnnotationRepository() + + # Create multiple annotations + for i in range(3): + repo.create( + document_id=str(sample_document.document_id), + page_number=1, + class_id=i, + class_name=f"field_{i}", + x_center=0.5, + y_center=0.1 + i * 0.2, + 
width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80 + i * 160, + bbox_width=160, + bbox_height=40, + ) + + # Delete all + count = repo.delete_for_document(str(sample_document.document_id)) + + assert count == 3 + + annotations = repo.get_for_document(str(sample_document.document_id)) + assert len(annotations) == 0 + + def test_delete_annotations_by_source(self, patched_session, sample_document): + """Test deleting annotations by source type.""" + repo = AnnotationRepository() + + # Create auto and manual annotations + repo.create( + document_id=str(sample_document.document_id), + page_number=1, + class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.1, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80, + bbox_width=160, + bbox_height=40, + source="auto", + ) + repo.create( + document_id=str(sample_document.document_id), + page_number=1, + class_id=1, + class_name="invoice_date", + x_center=0.5, + y_center=0.2, + width=0.15, + height=0.04, + bbox_x=400, + bbox_y=160, + bbox_width=120, + bbox_height=32, + source="manual", + ) + + # Delete only auto annotations + count = repo.delete_for_document(str(sample_document.document_id), source="auto") + + assert count == 1 + + remaining = repo.get_for_document(str(sample_document.document_id)) + assert len(remaining) == 1 + assert remaining[0].source == "manual" + + +class TestAnnotationVerification: + """Tests for annotation verification.""" + + def test_verify_annotation(self, patched_session, admin_token, sample_annotation): + """Test marking annotation as verified.""" + repo = AnnotationRepository() + + ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token) + + assert ann is not None + assert ann.is_verified is True + assert ann.verified_by == admin_token.token + assert ann.verified_at is not None + + +class TestAnnotationOverride: + """Tests for annotation override functionality.""" + + def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation): + """Test overriding an auto-generated annotation.""" + repo = AnnotationRepository() + + # Override the annotation + ann = repo.override( + str(sample_annotation.annotation_id), + admin_token.token, + change_reason="Correcting OCR error", + text_value="INV-2024-CORRECTED", + x_center=0.55, + ) + + assert ann is not None + assert ann.text_value == "INV-2024-CORRECTED" + assert ann.x_center == 0.55 + assert ann.source == "manual" # Changed from auto to manual + assert ann.override_source == "auto" + + +class TestAnnotationHistory: + """Tests for annotation history tracking.""" + + def test_create_history_record(self, patched_session, sample_annotation): + """Test creating annotation history record.""" + repo = AnnotationRepository() + + history = repo.create_history( + annotation_id=sample_annotation.annotation_id, + document_id=sample_annotation.document_id, + action="created", + new_value={"text_value": "INV-001"}, + changed_by="test-user", + ) + + assert history is not None + assert history.action == "created" + assert history.changed_by == "test-user" + + def test_get_annotation_history(self, patched_session, sample_annotation): + """Test getting history for an annotation.""" + repo = AnnotationRepository() + + # Create history records + repo.create_history( + annotation_id=sample_annotation.annotation_id, + document_id=sample_annotation.document_id, + action="created", + new_value={"text_value": "INV-001"}, + ) + repo.create_history( + annotation_id=sample_annotation.annotation_id, + document_id=sample_annotation.document_id, + 
action="updated", + previous_value={"text_value": "INV-001"}, + new_value={"text_value": "INV-002"}, + ) + + history = repo.get_history(sample_annotation.annotation_id) + + assert len(history) == 2 + # Should be ordered by created_at desc + assert history[0].action == "updated" + assert history[1].action == "created" + + def test_get_document_history(self, patched_session, sample_document, sample_annotation): + """Test getting all annotation history for a document.""" + repo = AnnotationRepository() + + repo.create_history( + annotation_id=sample_annotation.annotation_id, + document_id=sample_document.document_id, + action="created", + new_value={"class_name": "invoice_number"}, + ) + + history = repo.get_document_history(sample_document.document_id) + + assert len(history) >= 1 + assert all(h.document_id == sample_document.document_id for h in history) diff --git a/tests/integration/repositories/test_batch_upload_repo_integration.py b/tests/integration/repositories/test_batch_upload_repo_integration.py new file mode 100644 index 0000000..5ece694 --- /dev/null +++ b/tests/integration/repositories/test_batch_upload_repo_integration.py @@ -0,0 +1,355 @@ +""" +Batch Upload Repository Integration Tests + +Tests BatchUploadRepository with real database operations. +""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest + +from inference.data.repositories.batch_upload_repository import BatchUploadRepository + + +class TestBatchUploadCreate: + """Tests for batch upload creation.""" + + def test_create_batch_upload(self, patched_session, admin_token): + """Test creating a batch upload.""" + repo = BatchUploadRepository() + + batch = repo.create( + admin_token=admin_token.token, + filename="test_batch.zip", + file_size=10240, + upload_source="api", + ) + + assert batch is not None + assert batch.batch_id is not None + assert batch.filename == "test_batch.zip" + assert batch.file_size == 10240 + assert batch.upload_source == "api" + assert batch.status == "processing" + assert batch.total_files == 0 + assert batch.processed_files == 0 + + def test_create_batch_upload_default_source(self, patched_session, admin_token): + """Test creating batch upload with default source.""" + repo = BatchUploadRepository() + + batch = repo.create( + admin_token=admin_token.token, + filename="ui_batch.zip", + file_size=5120, + ) + + assert batch.upload_source == "ui" + + +class TestBatchUploadRead: + """Tests for batch upload retrieval.""" + + def test_get_batch_upload(self, patched_session, sample_batch_upload): + """Test getting a batch upload by ID.""" + repo = BatchUploadRepository() + + batch = repo.get(sample_batch_upload.batch_id) + + assert batch is not None + assert batch.batch_id == sample_batch_upload.batch_id + assert batch.filename == sample_batch_upload.filename + + def test_get_nonexistent_batch_upload(self, patched_session): + """Test getting a batch upload that doesn't exist.""" + repo = BatchUploadRepository() + + batch = repo.get(uuid4()) + assert batch is None + + def test_get_paginated_batch_uploads(self, patched_session, admin_token): + """Test paginated batch upload listing.""" + repo = BatchUploadRepository() + + # Create multiple batches + for i in range(5): + repo.create( + admin_token=admin_token.token, + filename=f"batch_{i}.zip", + file_size=1024 * (i + 1), + ) + + batches, total = repo.get_paginated(limit=3, offset=0) + + assert total == 5 + assert len(batches) == 3 + + def test_get_paginated_with_offset(self, patched_session, admin_token): + """Test 
pagination offset.""" + repo = BatchUploadRepository() + + for i in range(5): + repo.create( + admin_token=admin_token.token, + filename=f"batch_{i}.zip", + file_size=1024, + ) + + page1, _ = repo.get_paginated(limit=2, offset=0) + page2, _ = repo.get_paginated(limit=2, offset=2) + + ids_page1 = {b.batch_id for b in page1} + ids_page2 = {b.batch_id for b in page2} + + assert len(ids_page1 & ids_page2) == 0 + + +class TestBatchUploadUpdate: + """Tests for batch upload updates.""" + + def test_update_batch_status(self, patched_session, sample_batch_upload): + """Test updating batch upload status.""" + repo = BatchUploadRepository() + + repo.update( + sample_batch_upload.batch_id, + status="completed", + total_files=10, + processed_files=10, + successful_files=8, + failed_files=2, + ) + + # Need to commit to see changes + patched_session.commit() + + batch = repo.get(sample_batch_upload.batch_id) + assert batch.status == "completed" + assert batch.total_files == 10 + assert batch.successful_files == 8 + assert batch.failed_files == 2 + + def test_update_batch_with_error(self, patched_session, sample_batch_upload): + """Test updating batch upload with error message.""" + repo = BatchUploadRepository() + + repo.update( + sample_batch_upload.batch_id, + status="failed", + error_message="ZIP extraction failed", + ) + + patched_session.commit() + + batch = repo.get(sample_batch_upload.batch_id) + assert batch.status == "failed" + assert batch.error_message == "ZIP extraction failed" + + def test_update_batch_csv_info(self, patched_session, sample_batch_upload): + """Test updating batch with CSV information.""" + repo = BatchUploadRepository() + + repo.update( + sample_batch_upload.batch_id, + csv_filename="manifest.csv", + csv_row_count=100, + ) + + patched_session.commit() + + batch = repo.get(sample_batch_upload.batch_id) + assert batch.csv_filename == "manifest.csv" + assert batch.csv_row_count == 100 + + +class TestBatchUploadFiles: + """Tests for batch upload file management.""" + + def test_create_batch_file(self, patched_session, sample_batch_upload): + """Test creating a batch upload file record.""" + repo = BatchUploadRepository() + + file_record = repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename="invoice_001.pdf", + status="pending", + ) + + assert file_record is not None + assert file_record.file_id is not None + assert file_record.filename == "invoice_001.pdf" + assert file_record.batch_id == sample_batch_upload.batch_id + assert file_record.status == "pending" + + def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document): + """Test creating batch file linked to a document.""" + repo = BatchUploadRepository() + + file_record = repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename="invoice_linked.pdf", + document_id=sample_document.document_id, + status="completed", + annotation_count=5, + ) + + assert file_record.document_id == sample_document.document_id + assert file_record.status == "completed" + assert file_record.annotation_count == 5 + + def test_get_batch_files(self, patched_session, sample_batch_upload): + """Test getting all files for a batch.""" + repo = BatchUploadRepository() + + # Create multiple files + for i in range(3): + repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename=f"file_{i}.pdf", + ) + + files = repo.get_files(sample_batch_upload.batch_id) + + assert len(files) == 3 + assert all(f.batch_id == sample_batch_upload.batch_id for f in files) + + def 
test_get_batch_files_empty(self, patched_session, sample_batch_upload): + """Test getting files for batch with no files.""" + repo = BatchUploadRepository() + + files = repo.get_files(sample_batch_upload.batch_id) + + assert files == [] + + def test_update_batch_file_status(self, patched_session, sample_batch_upload): + """Test updating batch file status.""" + repo = BatchUploadRepository() + + file_record = repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename="test.pdf", + ) + + repo.update_file( + file_record.file_id, + status="completed", + annotation_count=10, + ) + + patched_session.commit() + + files = repo.get_files(sample_batch_upload.batch_id) + updated_file = files[0] + assert updated_file.status == "completed" + assert updated_file.annotation_count == 10 + + def test_update_batch_file_with_error(self, patched_session, sample_batch_upload): + """Test updating batch file with error.""" + repo = BatchUploadRepository() + + file_record = repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename="corrupt.pdf", + ) + + repo.update_file( + file_record.file_id, + status="failed", + error_message="Invalid PDF format", + ) + + patched_session.commit() + + files = repo.get_files(sample_batch_upload.batch_id) + updated_file = files[0] + assert updated_file.status == "failed" + assert updated_file.error_message == "Invalid PDF format" + + def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload): + """Test updating batch file with CSV row data.""" + repo = BatchUploadRepository() + + file_record = repo.create_file( + batch_id=sample_batch_upload.batch_id, + filename="invoice_with_csv.pdf", + ) + + csv_data = { + "invoice_number": "INV-001", + "amount": "1500.00", + "supplier": "Test Corp", + } + + repo.update_file( + file_record.file_id, + csv_row_data=csv_data, + ) + + patched_session.commit() + + files = repo.get_files(sample_batch_upload.batch_id) + updated_file = files[0] + assert updated_file.csv_row_data == csv_data + + +class TestBatchUploadWorkflow: + """Tests for complete batch upload workflows.""" + + def test_complete_batch_workflow(self, patched_session, admin_token): + """Test complete batch upload workflow.""" + repo = BatchUploadRepository() + + # 1. Create batch + batch = repo.create( + admin_token=admin_token.token, + filename="full_workflow.zip", + file_size=50000, + ) + + # 2. Update with file count + repo.update(batch.batch_id, total_files=3) + patched_session.commit() + + # 3. Create file records + file_ids = [] + for i in range(3): + file_record = repo.create_file( + batch_id=batch.batch_id, + filename=f"doc_{i}.pdf", + ) + file_ids.append(file_record.file_id) + + # 4. Process files one by one + for i, file_id in enumerate(file_ids): + status = "completed" if i < 2 else "failed" + repo.update_file( + file_id, + status=status, + annotation_count=5 if status == "completed" else 0, + ) + + # 5. 
Update batch progress + repo.update( + batch.batch_id, + processed_files=3, + successful_files=2, + failed_files=1, + status="partial", + ) + patched_session.commit() + + # Verify final state + final_batch = repo.get(batch.batch_id) + assert final_batch.status == "partial" + assert final_batch.total_files == 3 + assert final_batch.processed_files == 3 + assert final_batch.successful_files == 2 + assert final_batch.failed_files == 1 + + files = repo.get_files(batch.batch_id) + assert len(files) == 3 + completed = [f for f in files if f.status == "completed"] + failed = [f for f in files if f.status == "failed"] + assert len(completed) == 2 + assert len(failed) == 1 diff --git a/tests/integration/repositories/test_dataset_repo_integration.py b/tests/integration/repositories/test_dataset_repo_integration.py new file mode 100644 index 0000000..dafddf4 --- /dev/null +++ b/tests/integration/repositories/test_dataset_repo_integration.py @@ -0,0 +1,321 @@ +""" +Dataset Repository Integration Tests + +Tests DatasetRepository with real database operations. +""" + +from uuid import uuid4 + +import pytest + +from inference.data.repositories.dataset_repository import DatasetRepository + + +class TestDatasetRepositoryCreate: + """Tests for dataset creation.""" + + def test_create_dataset(self, patched_session): + """Test creating a training dataset.""" + repo = DatasetRepository() + + dataset = repo.create( + name="Test Dataset", + description="Dataset for integration testing", + train_ratio=0.8, + val_ratio=0.1, + seed=42, + ) + + assert dataset is not None + assert dataset.name == "Test Dataset" + assert dataset.description == "Dataset for integration testing" + assert dataset.train_ratio == 0.8 + assert dataset.val_ratio == 0.1 + assert dataset.seed == 42 + assert dataset.status == "building" + + def test_create_dataset_with_defaults(self, patched_session): + """Test creating dataset with default values.""" + repo = DatasetRepository() + + dataset = repo.create(name="Minimal Dataset") + + assert dataset is not None + assert dataset.train_ratio == 0.8 + assert dataset.val_ratio == 0.1 + assert dataset.seed == 42 + + +class TestDatasetRepositoryRead: + """Tests for dataset retrieval.""" + + def test_get_dataset_by_id(self, patched_session, sample_dataset): + """Test getting dataset by ID.""" + repo = DatasetRepository() + + dataset = repo.get(str(sample_dataset.dataset_id)) + + assert dataset is not None + assert dataset.dataset_id == sample_dataset.dataset_id + assert dataset.name == sample_dataset.name + + def test_get_nonexistent_dataset(self, patched_session): + """Test getting dataset that doesn't exist.""" + repo = DatasetRepository() + + dataset = repo.get(str(uuid4())) + assert dataset is None + + def test_get_paginated_datasets(self, patched_session): + """Test paginated dataset listing.""" + repo = DatasetRepository() + + # Create multiple datasets + for i in range(5): + repo.create(name=f"Dataset {i}") + + datasets, total = repo.get_paginated(limit=2, offset=0) + + assert total == 5 + assert len(datasets) == 2 + + def test_get_paginated_with_status_filter(self, patched_session): + """Test filtering datasets by status.""" + repo = DatasetRepository() + + # Create datasets with different statuses + d1 = repo.create(name="Building Dataset") + repo.update_status(str(d1.dataset_id), "ready") + + d2 = repo.create(name="Another Building Dataset") + # stays as "building" + + datasets, total = repo.get_paginated(status="ready") + + assert total == 1 + assert datasets[0].status == "ready" + + +class 
TestDatasetRepositoryUpdate: + """Tests for dataset updates.""" + + def test_update_status(self, patched_session, sample_dataset): + """Test updating dataset status.""" + repo = DatasetRepository() + + repo.update_status( + str(sample_dataset.dataset_id), + status="ready", + total_documents=100, + total_images=150, + total_annotations=500, + ) + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is not None + assert dataset.status == "ready" + assert dataset.total_documents == 100 + assert dataset.total_images == 150 + assert dataset.total_annotations == 500 + + def test_update_status_with_error(self, patched_session, sample_dataset): + """Test updating dataset status with error message.""" + repo = DatasetRepository() + + repo.update_status( + str(sample_dataset.dataset_id), + status="failed", + error_message="Failed to build dataset: insufficient documents", + ) + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is not None + assert dataset.status == "failed" + assert "insufficient documents" in dataset.error_message + + def test_update_status_with_path(self, patched_session, sample_dataset): + """Test updating dataset path.""" + repo = DatasetRepository() + + repo.update_status( + str(sample_dataset.dataset_id), + status="ready", + dataset_path="/datasets/test_dataset_2024", + ) + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is not None + assert dataset.dataset_path == "/datasets/test_dataset_2024" + + def test_update_training_status(self, patched_session, sample_dataset, sample_training_task): + """Test updating dataset training status.""" + repo = DatasetRepository() + + repo.update_training_status( + str(sample_dataset.dataset_id), + training_status="running", + active_training_task_id=str(sample_training_task.task_id), + ) + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is not None + assert dataset.training_status == "running" + assert dataset.active_training_task_id == sample_training_task.task_id + + def test_update_training_status_completed(self, patched_session, sample_dataset): + """Test updating training status to completed updates main status.""" + repo = DatasetRepository() + + # First set to ready + repo.update_status(str(sample_dataset.dataset_id), status="ready") + + # Then complete training + repo.update_training_status( + str(sample_dataset.dataset_id), + training_status="completed", + update_main_status=True, + ) + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is not None + assert dataset.training_status == "completed" + assert dataset.status == "trained" + + +class TestDatasetDocuments: + """Tests for dataset document management.""" + + def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents): + """Test adding documents to a dataset.""" + repo = DatasetRepository() + + documents_data = [ + { + "document_id": str(multiple_documents[0].document_id), + "split": "train", + "page_count": 1, + "annotation_count": 5, + }, + { + "document_id": str(multiple_documents[1].document_id), + "split": "train", + "page_count": 2, + "annotation_count": 8, + }, + { + "document_id": str(multiple_documents[2].document_id), + "split": "val", + "page_count": 1, + "annotation_count": 3, + }, + ] + + repo.add_documents(str(sample_dataset.dataset_id), documents_data) + + # Verify documents were added + docs = repo.get_documents(str(sample_dataset.dataset_id)) + assert len(docs) == 3 + + train_docs = [d for d in docs if d.split == "train"] + 
val_docs = [d for d in docs if d.split == "val"] + + assert len(train_docs) == 2 + assert len(val_docs) == 1 + + def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document): + """Test getting documents from a dataset.""" + repo = DatasetRepository() + + repo.add_documents( + str(sample_dataset.dataset_id), + [ + { + "document_id": str(sample_document.document_id), + "split": "train", + "page_count": 1, + "annotation_count": 5, + } + ], + ) + + docs = repo.get_documents(str(sample_dataset.dataset_id)) + + assert len(docs) == 1 + assert docs[0].document_id == sample_document.document_id + assert docs[0].split == "train" + assert docs[0].page_count == 1 + assert docs[0].annotation_count == 5 + + +class TestDatasetRepositoryDelete: + """Tests for dataset deletion.""" + + def test_delete_dataset(self, patched_session, sample_dataset): + """Test deleting a dataset.""" + repo = DatasetRepository() + + result = repo.delete(str(sample_dataset.dataset_id)) + assert result is True + + dataset = repo.get(str(sample_dataset.dataset_id)) + assert dataset is None + + def test_delete_nonexistent_dataset(self, patched_session): + """Test deleting dataset that doesn't exist.""" + repo = DatasetRepository() + + result = repo.delete(str(uuid4())) + assert result is False + + def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document): + """Test deleting dataset also removes document links.""" + repo = DatasetRepository() + + # Add document to dataset + repo.add_documents( + str(sample_dataset.dataset_id), + [ + { + "document_id": str(sample_document.document_id), + "split": "train", + "page_count": 1, + "annotation_count": 5, + } + ], + ) + + # Delete dataset + repo.delete(str(sample_dataset.dataset_id)) + + # Document links should be gone + docs = repo.get_documents(str(sample_dataset.dataset_id)) + assert len(docs) == 0 + + +class TestActiveTrainingTasks: + """Tests for active training task queries.""" + + def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task): + """Test getting active training tasks for datasets.""" + repo = DatasetRepository() + + # Update task to running + from inference.data.repositories.training_task_repository import TrainingTaskRepository + + task_repo = TrainingTaskRepository() + task_repo.update_status(str(sample_training_task.task_id), "running") + + result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)]) + + assert str(sample_dataset.dataset_id) in result + assert result[str(sample_dataset.dataset_id)]["status"] == "running" + + def test_get_active_training_tasks_empty(self, patched_session, sample_dataset): + """Test getting active training tasks returns empty when no tasks exist.""" + repo = DatasetRepository() + + result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)]) + + # No training task exists for this dataset, so result should be empty + assert str(sample_dataset.dataset_id) not in result + assert result == {} diff --git a/tests/integration/repositories/test_document_repo_integration.py b/tests/integration/repositories/test_document_repo_integration.py new file mode 100644 index 0000000..a8f812a --- /dev/null +++ b/tests/integration/repositories/test_document_repo_integration.py @@ -0,0 +1,350 @@ +""" +Document Repository Integration Tests + +Tests DocumentRepository with real database operations. 
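+
+Note on timestamps: PostgreSQL can return offset-naive datetimes, so comparisons
+against datetime.now(timezone.utc) go through the ensure_utc() helper defined
+below, e.g.:
+
+    assert ensure_utc(doc.created_at) <= datetime.now(timezone.utc)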
+""" + +from datetime import datetime, timezone, timedelta +from uuid import uuid4 + +import pytest +from sqlmodel import select + +from inference.data.admin_models import AdminAnnotation, AdminDocument +from inference.data.repositories.document_repository import DocumentRepository + + +def ensure_utc(dt: datetime | None) -> datetime | None: + """Ensure datetime is timezone-aware (UTC). + + PostgreSQL may return offset-naive datetimes. This helper + converts them to UTC for proper comparison. + """ + if dt is None: + return None + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + +class TestDocumentRepositoryCreate: + """Tests for document creation.""" + + def test_create_document(self, patched_session): + """Test creating a document and retrieving it.""" + repo = DocumentRepository() + + doc_id = repo.create( + filename="test_invoice.pdf", + file_size=2048, + content_type="application/pdf", + file_path="/uploads/test_invoice.pdf", + page_count=2, + upload_source="api", + category="invoice", + ) + + assert doc_id is not None + + doc = repo.get(doc_id) + assert doc is not None + assert doc.filename == "test_invoice.pdf" + assert doc.file_size == 2048 + assert doc.page_count == 2 + assert doc.upload_source == "api" + assert doc.category == "invoice" + assert doc.status == "pending" + + def test_create_document_with_csv_values(self, patched_session): + """Test creating document with CSV field values.""" + repo = DocumentRepository() + + csv_values = { + "invoice_number": "INV-001", + "amount": "1500.00", + "supplier_name": "Test Supplier AB", + } + + doc_id = repo.create( + filename="invoice_with_csv.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/invoice_with_csv.pdf", + csv_field_values=csv_values, + ) + + doc = repo.get(doc_id) + assert doc is not None + assert doc.csv_field_values == csv_values + + def test_create_document_with_group_key(self, patched_session): + """Test creating document with group key.""" + repo = DocumentRepository() + + doc_id = repo.create( + filename="grouped_doc.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/grouped_doc.pdf", + group_key="batch-2024-01", + ) + + doc = repo.get(doc_id) + assert doc is not None + assert doc.group_key == "batch-2024-01" + + +class TestDocumentRepositoryRead: + """Tests for document retrieval.""" + + def test_get_nonexistent_document(self, patched_session): + """Test getting a document that doesn't exist.""" + repo = DocumentRepository() + + doc = repo.get(str(uuid4())) + assert doc is None + + def test_get_paginated_documents(self, patched_session, multiple_documents): + """Test paginated document listing.""" + repo = DocumentRepository() + + docs, total = repo.get_paginated(limit=2, offset=0) + + assert total == 5 + assert len(docs) == 2 + + def test_get_paginated_with_status_filter(self, patched_session, multiple_documents): + """Test filtering documents by status.""" + repo = DocumentRepository() + + docs, total = repo.get_paginated(status="labeled") + + assert total == 2 + for doc in docs: + assert doc.status == "labeled" + + def test_get_paginated_with_category_filter(self, patched_session, multiple_documents): + """Test filtering documents by category.""" + repo = DocumentRepository() + + docs, total = repo.get_paginated(category="letter") + + assert total == 1 + assert docs[0].category == "letter" + + def test_get_paginated_with_offset(self, patched_session, multiple_documents): + """Test pagination offset.""" + repo = 
DocumentRepository() + + docs_page1, _ = repo.get_paginated(limit=2, offset=0) + docs_page2, _ = repo.get_paginated(limit=2, offset=2) + + doc_ids_page1 = {str(d.document_id) for d in docs_page1} + doc_ids_page2 = {str(d.document_id) for d in docs_page2} + + assert len(doc_ids_page1 & doc_ids_page2) == 0 + + def test_get_by_ids(self, patched_session, multiple_documents): + """Test getting multiple documents by IDs.""" + repo = DocumentRepository() + + ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)] + docs = repo.get_by_ids(ids_to_fetch) + + assert len(docs) == 2 + fetched_ids = {str(d.document_id) for d in docs} + assert fetched_ids == set(ids_to_fetch) + + +class TestDocumentRepositoryUpdate: + """Tests for document updates.""" + + def test_update_status(self, patched_session, sample_document): + """Test updating document status.""" + repo = DocumentRepository() + + repo.update_status( + str(sample_document.document_id), + status="labeled", + auto_label_status="completed", + ) + + doc = repo.get(str(sample_document.document_id)) + assert doc is not None + assert doc.status == "labeled" + assert doc.auto_label_status == "completed" + + def test_update_status_with_error(self, patched_session, sample_document): + """Test updating document status with error message.""" + repo = DocumentRepository() + + repo.update_status( + str(sample_document.document_id), + status="pending", + auto_label_status="failed", + auto_label_error="OCR extraction failed", + ) + + doc = repo.get(str(sample_document.document_id)) + assert doc is not None + assert doc.auto_label_status == "failed" + assert doc.auto_label_error == "OCR extraction failed" + + def test_update_file_path(self, patched_session, sample_document): + """Test updating document file path.""" + repo = DocumentRepository() + + new_path = "/archive/2024/test_invoice.pdf" + repo.update_file_path(str(sample_document.document_id), new_path) + + doc = repo.get(str(sample_document.document_id)) + assert doc is not None + assert doc.file_path == new_path + + def test_update_group_key(self, patched_session, sample_document): + """Test updating document group key.""" + repo = DocumentRepository() + + result = repo.update_group_key(str(sample_document.document_id), "new-group-key") + assert result is True + + doc = repo.get(str(sample_document.document_id)) + assert doc is not None + assert doc.group_key == "new-group-key" + + def test_update_category(self, patched_session, sample_document): + """Test updating document category.""" + repo = DocumentRepository() + + doc = repo.update_category(str(sample_document.document_id), "letter") + + assert doc is not None + assert doc.category == "letter" + + +class TestDocumentRepositoryDelete: + """Tests for document deletion.""" + + def test_delete_document(self, patched_session, sample_document): + """Test deleting a document.""" + repo = DocumentRepository() + + result = repo.delete(str(sample_document.document_id)) + assert result is True + + doc = repo.get(str(sample_document.document_id)) + assert doc is None + + def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation): + """Test deleting document also deletes its annotations.""" + repo = DocumentRepository() + + result = repo.delete(str(sample_document.document_id)) + assert result is True + + # Verify annotation is also deleted + from inference.data.repositories.annotation_repository import AnnotationRepository + + ann_repo = AnnotationRepository() + annotations = 
ann_repo.get_for_document(str(sample_document.document_id)) + assert len(annotations) == 0 + + def test_delete_nonexistent_document(self, patched_session): + """Test deleting a document that doesn't exist.""" + repo = DocumentRepository() + + result = repo.delete(str(uuid4())) + assert result is False + + +class TestDocumentRepositoryQueries: + """Tests for complex document queries.""" + + def test_count_by_status(self, patched_session, multiple_documents): + """Test counting documents by status.""" + repo = DocumentRepository() + + counts = repo.count_by_status() + + assert counts.get("pending") == 2 + assert counts.get("labeled") == 2 + assert counts.get("exported") == 1 + + def test_get_categories(self, patched_session, multiple_documents): + """Test getting unique categories.""" + repo = DocumentRepository() + + categories = repo.get_categories() + + assert "invoice" in categories + assert "letter" in categories + + def test_get_labeled_for_export(self, patched_session, multiple_documents): + """Test getting labeled documents for export.""" + repo = DocumentRepository() + + docs = repo.get_labeled_for_export() + + assert len(docs) == 2 + for doc in docs: + assert doc.status == "labeled" + + +class TestDocumentAnnotationLocking: + """Tests for annotation locking mechanism.""" + + def test_acquire_annotation_lock(self, patched_session, sample_document): + """Test acquiring annotation lock.""" + repo = DocumentRepository() + + doc = repo.acquire_annotation_lock( + str(sample_document.document_id), + duration_seconds=300, + ) + + assert doc is not None + assert doc.annotation_lock_until is not None + lock_until = ensure_utc(doc.annotation_lock_until) + assert lock_until > datetime.now(timezone.utc) + + def test_acquire_lock_when_already_locked(self, patched_session, sample_document): + """Test acquiring lock fails when already locked.""" + repo = DocumentRepository() + + # First lock + repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300) + + # Second lock attempt should fail + result = repo.acquire_annotation_lock(str(sample_document.document_id)) + assert result is None + + def test_release_annotation_lock(self, patched_session, sample_document): + """Test releasing annotation lock.""" + repo = DocumentRepository() + + repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300) + doc = repo.release_annotation_lock(str(sample_document.document_id)) + + assert doc is not None + assert doc.annotation_lock_until is None + + def test_extend_annotation_lock(self, patched_session, sample_document): + """Test extending annotation lock.""" + repo = DocumentRepository() + + # Acquire initial lock + initial_doc = repo.acquire_annotation_lock( + str(sample_document.document_id), + duration_seconds=300, + ) + initial_expiry = ensure_utc(initial_doc.annotation_lock_until) + + # Extend lock + extended_doc = repo.extend_annotation_lock( + str(sample_document.document_id), + additional_seconds=300, + ) + + assert extended_doc is not None + extended_expiry = ensure_utc(extended_doc.annotation_lock_until) + assert extended_expiry > initial_expiry diff --git a/tests/integration/repositories/test_model_version_repo_integration.py b/tests/integration/repositories/test_model_version_repo_integration.py new file mode 100644 index 0000000..6353000 --- /dev/null +++ b/tests/integration/repositories/test_model_version_repo_integration.py @@ -0,0 +1,310 @@ +""" +Model Version Repository Integration Tests + +Tests ModelVersionRepository with real database 
operations. +""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest + +from inference.data.repositories.model_version_repository import ModelVersionRepository + + +class TestModelVersionCreate: + """Tests for model version creation.""" + + def test_create_model_version(self, patched_session): + """Test creating a model version.""" + repo = ModelVersionRepository() + + model = repo.create( + version="1.0.0", + name="Invoice Extractor v1", + model_path="/models/invoice_v1.pt", + description="Initial production model", + metrics_mAP=0.92, + metrics_precision=0.89, + metrics_recall=0.85, + document_count=1000, + file_size=50000000, + ) + + assert model is not None + assert model.version == "1.0.0" + assert model.name == "Invoice Extractor v1" + assert model.model_path == "/models/invoice_v1.pt" + assert model.metrics_mAP == 0.92 + assert model.is_active is False + assert model.status == "inactive" + + def test_create_model_version_with_training_info( + self, patched_session, sample_training_task, sample_dataset + ): + """Test creating model version linked to training task and dataset.""" + repo = ModelVersionRepository() + + model = repo.create( + version="1.1.0", + name="Invoice Extractor v1.1", + model_path="/models/invoice_v1.1.pt", + task_id=sample_training_task.task_id, + dataset_id=sample_dataset.dataset_id, + training_config={"epochs": 100, "batch_size": 16}, + trained_at=datetime.now(timezone.utc), + ) + + assert model is not None + assert model.task_id == sample_training_task.task_id + assert model.dataset_id == sample_dataset.dataset_id + assert model.training_config["epochs"] == 100 + + +class TestModelVersionRead: + """Tests for model version retrieval.""" + + def test_get_model_version_by_id(self, patched_session, sample_model_version): + """Test getting model version by ID.""" + repo = ModelVersionRepository() + + model = repo.get(str(sample_model_version.version_id)) + + assert model is not None + assert model.version_id == sample_model_version.version_id + + def test_get_nonexistent_model_version(self, patched_session): + """Test getting model version that doesn't exist.""" + repo = ModelVersionRepository() + + model = repo.get(str(uuid4())) + assert model is None + + def test_get_paginated_model_versions(self, patched_session): + """Test paginated model version listing.""" + repo = ModelVersionRepository() + + # Create multiple versions + for i in range(5): + repo.create( + version=f"1.{i}.0", + name=f"Model v1.{i}", + model_path=f"/models/model_v1.{i}.pt", + ) + + models, total = repo.get_paginated(limit=2, offset=0) + + assert total == 5 + assert len(models) == 2 + + def test_get_paginated_with_status_filter(self, patched_session): + """Test filtering model versions by status.""" + repo = ModelVersionRepository() + + # Create active and inactive models + m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt") + repo.activate(str(m1.version_id)) + + repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt") + + active_models, active_total = repo.get_paginated(status="active") + inactive_models, inactive_total = repo.get_paginated(status="inactive") + + assert active_total == 1 + assert inactive_total == 1 + + +class TestModelVersionActivation: + """Tests for model version activation.""" + + def test_activate_model_version(self, patched_session, sample_model_version): + """Test activating a model version.""" + repo = ModelVersionRepository() + + model = 
repo.activate(str(sample_model_version.version_id)) + + assert model is not None + assert model.is_active is True + assert model.status == "active" + assert model.activated_at is not None + + def test_activate_deactivates_others(self, patched_session): + """Test that activating one version deactivates others.""" + repo = ModelVersionRepository() + + # Create and activate first model + m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt") + repo.activate(str(m1.version_id)) + + # Create and activate second model + m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt") + repo.activate(str(m2.version_id)) + + # Check first model is now inactive + m1_after = repo.get(str(m1.version_id)) + assert m1_after.is_active is False + assert m1_after.status == "inactive" + + # Check second model is active + m2_after = repo.get(str(m2.version_id)) + assert m2_after.is_active is True + + def test_get_active_model(self, patched_session, sample_model_version): + """Test getting the currently active model.""" + repo = ModelVersionRepository() + + # Initially no active model + active = repo.get_active() + assert active is None + + # Activate model + repo.activate(str(sample_model_version.version_id)) + + # Now should return active model + active = repo.get_active() + assert active is not None + assert active.version_id == sample_model_version.version_id + + def test_deactivate_model_version(self, patched_session, sample_model_version): + """Test deactivating a model version.""" + repo = ModelVersionRepository() + + # First activate + repo.activate(str(sample_model_version.version_id)) + + # Then deactivate + model = repo.deactivate(str(sample_model_version.version_id)) + + assert model is not None + assert model.is_active is False + assert model.status == "inactive" + + +class TestModelVersionUpdate: + """Tests for model version updates.""" + + def test_update_model_metadata(self, patched_session, sample_model_version): + """Test updating model version metadata.""" + repo = ModelVersionRepository() + + model = repo.update( + str(sample_model_version.version_id), + name="Updated Model Name", + description="Updated description", + ) + + assert model is not None + assert model.name == "Updated Model Name" + assert model.description == "Updated description" + + def test_update_model_status(self, patched_session, sample_model_version): + """Test updating model version status.""" + repo = ModelVersionRepository() + + model = repo.update(str(sample_model_version.version_id), status="deprecated") + + assert model is not None + assert model.status == "deprecated" + + def test_update_nonexistent_model(self, patched_session): + """Test updating model that doesn't exist.""" + repo = ModelVersionRepository() + + model = repo.update(str(uuid4()), name="New Name") + assert model is None + + +class TestModelVersionArchive: + """Tests for model version archiving.""" + + def test_archive_model_version(self, patched_session, sample_model_version): + """Test archiving an inactive model version.""" + repo = ModelVersionRepository() + + model = repo.archive(str(sample_model_version.version_id)) + + assert model is not None + assert model.status == "archived" + + def test_cannot_archive_active_model(self, patched_session, sample_model_version): + """Test that active model cannot be archived.""" + repo = ModelVersionRepository() + + # Activate the model + repo.activate(str(sample_model_version.version_id)) + + # Try to archive + model = repo.archive(str(sample_model_version.version_id)) 
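Reviewer note: `test_activate_deactivates_others` and `TestOnlyOneActiveModel` together assert a single invariant — at most one active model version at any time. A hedged sketch of activation logic that would satisfy those tests (illustrative only; the real `ModelVersionRepository.activate` is not shown in this diff):

```python
# Hedged sketch: enforce the "exactly one active version" invariant the
# tests assert. Session handling and field names are assumed from the
# admin_models imports used elsewhere in this diff.
from datetime import datetime, timezone

from sqlmodel import Session, select

from inference.data.admin_models import ModelVersion

def activate_version(session: Session, version_id) -> None:
    """Activate one ModelVersion and deactivate all others in one commit."""
    for mv in session.exec(select(ModelVersion)).all():
        is_target = mv.version_id == version_id
        mv.is_active = is_target
        mv.status = "active" if is_target else "inactive"
        if is_target:
            mv.activated_at = datetime.now(timezone.utc)
        session.add(mv)
    session.commit()
```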
+ + assert model is None + + # Verify model is still active + current = repo.get(str(sample_model_version.version_id)) + assert current.status == "active" + + +class TestModelVersionDelete: + """Tests for model version deletion.""" + + def test_delete_inactive_model(self, patched_session, sample_model_version): + """Test deleting an inactive model version.""" + repo = ModelVersionRepository() + + result = repo.delete(str(sample_model_version.version_id)) + + assert result is True + + model = repo.get(str(sample_model_version.version_id)) + assert model is None + + def test_cannot_delete_active_model(self, patched_session, sample_model_version): + """Test that active model cannot be deleted.""" + repo = ModelVersionRepository() + + # Activate the model + repo.activate(str(sample_model_version.version_id)) + + # Try to delete + result = repo.delete(str(sample_model_version.version_id)) + + assert result is False + + # Verify model still exists + model = repo.get(str(sample_model_version.version_id)) + assert model is not None + + def test_delete_nonexistent_model(self, patched_session): + """Test deleting model that doesn't exist.""" + repo = ModelVersionRepository() + + result = repo.delete(str(uuid4())) + assert result is False + + +class TestOnlyOneActiveModel: + """Tests to verify only one model can be active at a time.""" + + def test_single_active_model_constraint(self, patched_session): + """Test that only one model can be active at any time.""" + repo = ModelVersionRepository() + + # Create multiple models + models = [] + for i in range(3): + m = repo.create( + version=f"1.{i}.0", + name=f"Model {i}", + model_path=f"/models/model_{i}.pt", + ) + models.append(m) + + # Activate each model in sequence + for model in models: + repo.activate(str(model.version_id)) + + # Count active models + all_models, _ = repo.get_paginated(status="active") + assert len(all_models) == 1 + + # Verify it's the last one activated + assert all_models[0].version_id == models[-1].version_id diff --git a/tests/integration/repositories/test_token_repo_integration.py b/tests/integration/repositories/test_token_repo_integration.py new file mode 100644 index 0000000..e747c60 --- /dev/null +++ b/tests/integration/repositories/test_token_repo_integration.py @@ -0,0 +1,274 @@ +""" +Token Repository Integration Tests + +Tests TokenRepository with real database operations. 
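Reviewer note: `TestTokenValidation` below exercises three rejection paths — missing token, deactivated token, expired token. A minimal sketch of the validity rule those tests pin down (the real `TokenRepository.is_valid` is not shown, so treat this as illustrative):

```python
# Hedged sketch of the validity rule TestTokenValidation exercises:
# a token is valid iff it exists, is active, and has not expired.
from datetime import datetime, timezone

def is_token_valid(token_row) -> bool:
    if token_row is None or not token_row.is_active:
        return False
    if token_row.expires_at is not None:
        expires = token_row.expires_at
        if expires.tzinfo is None:  # some drivers return naive datetimes
            expires = expires.replace(tzinfo=timezone.utc)
        if expires <= datetime.now(timezone.utc):
            return False
    return True
```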
+""" + +from datetime import datetime, timezone, timedelta + +import pytest + +from inference.data.repositories.token_repository import TokenRepository + + +class TestTokenCreate: + """Tests for token creation.""" + + def test_create_new_token(self, patched_session): + """Test creating a new admin token.""" + repo = TokenRepository() + + repo.create( + token="new-test-token-abc123", + name="New Test Admin", + ) + + token = repo.get("new-test-token-abc123") + assert token is not None + assert token.token == "new-test-token-abc123" + assert token.name == "New Test Admin" + assert token.is_active is True + assert token.expires_at is None + + def test_create_token_with_expiration(self, patched_session): + """Test creating token with expiration date.""" + repo = TokenRepository() + expiry = datetime.now(timezone.utc) + timedelta(days=30) + + repo.create( + token="expiring-token-xyz789", + name="Expiring Token", + expires_at=expiry, + ) + + token = repo.get("expiring-token-xyz789") + assert token is not None + assert token.expires_at is not None + + def test_create_updates_existing_token(self, patched_session, admin_token): + """Test creating with existing token updates it.""" + repo = TokenRepository() + new_expiry = datetime.now(timezone.utc) + timedelta(days=60) + + repo.create( + token=admin_token.token, + name="Updated Admin Name", + expires_at=new_expiry, + ) + + token = repo.get(admin_token.token) + assert token is not None + assert token.name == "Updated Admin Name" + assert token.is_active is True + + +class TestTokenValidation: + """Tests for token validation.""" + + def test_is_valid_active_token(self, patched_session, admin_token): + """Test that active token is valid.""" + repo = TokenRepository() + + result = repo.is_valid(admin_token.token) + + assert result is True + + def test_is_valid_nonexistent_token(self, patched_session): + """Test that nonexistent token is invalid.""" + repo = TokenRepository() + + result = repo.is_valid("nonexistent-token-12345") + + assert result is False + + def test_is_valid_deactivated_token(self, patched_session, admin_token): + """Test that deactivated token is invalid.""" + repo = TokenRepository() + + repo.deactivate(admin_token.token) + result = repo.is_valid(admin_token.token) + + assert result is False + + def test_is_valid_expired_token(self, patched_session): + """Test that expired token is invalid.""" + repo = TokenRepository() + past_expiry = datetime.now(timezone.utc) - timedelta(days=1) + + repo.create( + token="expired-token-test", + name="Expired Token", + expires_at=past_expiry, + ) + + result = repo.is_valid("expired-token-test") + + assert result is False + + def test_is_valid_not_yet_expired_token(self, patched_session): + """Test that not-yet-expired token is valid.""" + repo = TokenRepository() + future_expiry = datetime.now(timezone.utc) + timedelta(days=7) + + repo.create( + token="valid-expiring-token", + name="Valid Expiring Token", + expires_at=future_expiry, + ) + + result = repo.is_valid("valid-expiring-token") + + assert result is True + + +class TestTokenGet: + """Tests for token retrieval.""" + + def test_get_existing_token(self, patched_session, admin_token): + """Test getting an existing token.""" + repo = TokenRepository() + + token = repo.get(admin_token.token) + + assert token is not None + assert token.token == admin_token.token + assert token.name == admin_token.name + + def test_get_nonexistent_token(self, patched_session): + """Test getting a token that doesn't exist.""" + repo = TokenRepository() + + token = 
repo.get("nonexistent-token-xyz") + + assert token is None + + +class TestTokenDeactivate: + """Tests for token deactivation.""" + + def test_deactivate_existing_token(self, patched_session, admin_token): + """Test deactivating an existing token.""" + repo = TokenRepository() + + result = repo.deactivate(admin_token.token) + + assert result is True + token = repo.get(admin_token.token) + assert token is not None + assert token.is_active is False + + def test_deactivate_nonexistent_token(self, patched_session): + """Test deactivating a token that doesn't exist.""" + repo = TokenRepository() + + result = repo.deactivate("nonexistent-token-abc") + + assert result is False + + def test_reactivate_deactivated_token(self, patched_session, admin_token): + """Test reactivating a deactivated token via create.""" + repo = TokenRepository() + + # Deactivate first + repo.deactivate(admin_token.token) + assert repo.is_valid(admin_token.token) is False + + # Reactivate via create + repo.create( + token=admin_token.token, + name="Reactivated Admin", + ) + + assert repo.is_valid(admin_token.token) is True + + +class TestTokenUsageTracking: + """Tests for token usage tracking.""" + + def test_update_usage(self, patched_session, admin_token): + """Test updating token last used timestamp.""" + repo = TokenRepository() + + # Initially last_used_at might be None + initial_token = repo.get(admin_token.token) + initial_last_used = initial_token.last_used_at + + repo.update_usage(admin_token.token) + + updated_token = repo.get(admin_token.token) + assert updated_token.last_used_at is not None + if initial_last_used: + assert updated_token.last_used_at >= initial_last_used + + def test_update_usage_nonexistent_token(self, patched_session): + """Test updating usage for nonexistent token does nothing.""" + repo = TokenRepository() + + # Should not raise, just does nothing + repo.update_usage("nonexistent-token-usage") + + token = repo.get("nonexistent-token-usage") + assert token is None + + +class TestTokenWorkflow: + """Tests for complete token workflows.""" + + def test_full_token_lifecycle(self, patched_session): + """Test complete token lifecycle: create, validate, use, deactivate.""" + repo = TokenRepository() + token_str = "lifecycle-test-token" + + # 1. Create token + repo.create(token=token_str, name="Lifecycle Token") + assert repo.is_valid(token_str) is True + + # 2. Use token + repo.update_usage(token_str) + token = repo.get(token_str) + assert token.last_used_at is not None + + # 3. Update token info + new_expiry = datetime.now(timezone.utc) + timedelta(days=90) + repo.create( + token=token_str, + name="Updated Lifecycle Token", + expires_at=new_expiry, + ) + token = repo.get(token_str) + assert token.name == "Updated Lifecycle Token" + + # 4. Deactivate token + result = repo.deactivate(token_str) + assert result is True + assert repo.is_valid(token_str) is False + + # 5. 
Reactivate token + repo.create(token=token_str, name="Reactivated Token") + assert repo.is_valid(token_str) is True + + def test_multiple_tokens(self, patched_session): + """Test managing multiple tokens.""" + repo = TokenRepository() + + # Create multiple tokens + tokens = [ + ("token-a", "Admin A"), + ("token-b", "Admin B"), + ("token-c", "Admin C"), + ] + + for token_str, name in tokens: + repo.create(token=token_str, name=name) + + # Verify all are valid + for token_str, _ in tokens: + assert repo.is_valid(token_str) is True + + # Deactivate one + repo.deactivate("token-b") + + # Verify states + assert repo.is_valid("token-a") is True + assert repo.is_valid("token-b") is False + assert repo.is_valid("token-c") is True diff --git a/tests/integration/repositories/test_training_task_repo_integration.py b/tests/integration/repositories/test_training_task_repo_integration.py new file mode 100644 index 0000000..e3a6d19 --- /dev/null +++ b/tests/integration/repositories/test_training_task_repo_integration.py @@ -0,0 +1,364 @@ +""" +Training Task Repository Integration Tests + +Tests TrainingTaskRepository with real database operations. +""" + +from datetime import datetime, timezone, timedelta +from uuid import uuid4 + +import pytest + +from inference.data.repositories.training_task_repository import TrainingTaskRepository + + +class TestTrainingTaskCreate: + """Tests for training task creation.""" + + def test_create_training_task(self, patched_session, admin_token): + """Test creating a training task.""" + repo = TrainingTaskRepository() + + task_id = repo.create( + admin_token=admin_token.token, + name="Test Training Task", + task_type="train", + description="Integration test training task", + config={"epochs": 100, "batch_size": 16}, + ) + + assert task_id is not None + + task = repo.get(task_id) + assert task is not None + assert task.name == "Test Training Task" + assert task.task_type == "train" + assert task.status == "pending" + assert task.config["epochs"] == 100 + + def test_create_scheduled_task(self, patched_session, admin_token): + """Test creating a scheduled training task.""" + repo = TrainingTaskRepository() + + scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1) + + task_id = repo.create( + admin_token=admin_token.token, + name="Scheduled Task", + scheduled_at=scheduled_time, + ) + + task = repo.get(task_id) + assert task is not None + assert task.status == "scheduled" + assert task.scheduled_at is not None + + def test_create_recurring_task(self, patched_session, admin_token): + """Test creating a recurring training task.""" + repo = TrainingTaskRepository() + + task_id = repo.create( + admin_token=admin_token.token, + name="Recurring Task", + cron_expression="0 2 * * *", + is_recurring=True, + ) + + task = repo.get(task_id) + assert task is not None + assert task.is_recurring is True + assert task.cron_expression == "0 2 * * *" + + def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset): + """Test creating task linked to a dataset.""" + repo = TrainingTaskRepository() + + task_id = repo.create( + admin_token=admin_token.token, + name="Dataset Training Task", + dataset_id=str(sample_dataset.dataset_id), + ) + + task = repo.get(task_id) + assert task is not None + assert task.dataset_id == sample_dataset.dataset_id + + +class TestTrainingTaskRead: + """Tests for training task retrieval.""" + + def test_get_task_by_id(self, patched_session, sample_training_task): + """Test getting task by ID.""" + repo = TrainingTaskRepository() + + 
task = repo.get(str(sample_training_task.task_id)) + + assert task is not None + assert task.task_id == sample_training_task.task_id + + def test_get_nonexistent_task(self, patched_session): + """Test getting task that doesn't exist.""" + repo = TrainingTaskRepository() + + task = repo.get(str(uuid4())) + assert task is None + + def test_get_paginated_tasks(self, patched_session, admin_token): + """Test paginated task listing.""" + repo = TrainingTaskRepository() + + # Create multiple tasks + for i in range(5): + repo.create(admin_token=admin_token.token, name=f"Task {i}") + + tasks, total = repo.get_paginated(limit=2, offset=0) + + assert total == 5 + assert len(tasks) == 2 + + def test_get_paginated_with_status_filter(self, patched_session, admin_token): + """Test filtering tasks by status.""" + repo = TrainingTaskRepository() + + # Create tasks with different statuses + task_id = repo.create(admin_token=admin_token.token, name="Running Task") + repo.update_status(task_id, "running") + + repo.create(admin_token=admin_token.token, name="Pending Task") + + tasks, total = repo.get_paginated(status="running") + + assert total == 1 + assert tasks[0].status == "running" + + def test_get_pending_tasks(self, patched_session, admin_token): + """Test getting pending tasks ready to run.""" + repo = TrainingTaskRepository() + + # Create pending task + repo.create(admin_token=admin_token.token, name="Ready Task") + + # Create scheduled task in the past (should be included) + past_time = datetime.now(timezone.utc) - timedelta(hours=1) + repo.create( + admin_token=admin_token.token, + name="Past Scheduled Task", + scheduled_at=past_time, + ) + + # Create scheduled task in the future (should not be included) + future_time = datetime.now(timezone.utc) + timedelta(hours=1) + repo.create( + admin_token=admin_token.token, + name="Future Scheduled Task", + scheduled_at=future_time, + ) + + pending = repo.get_pending() + + # Should include pending and past scheduled, not future scheduled + assert len(pending) >= 2 + names = [t.name for t in pending] + assert "Ready Task" in names + assert "Past Scheduled Task" in names + + def test_get_running_task(self, patched_session, admin_token): + """Test getting currently running task.""" + repo = TrainingTaskRepository() + + task_id = repo.create(admin_token=admin_token.token, name="Running Task") + repo.update_status(task_id, "running") + + running = repo.get_running() + + assert running is not None + assert running.status == "running" + + def test_get_running_task_none(self, patched_session, admin_token): + """Test getting running task when none is running.""" + repo = TrainingTaskRepository() + + repo.create(admin_token=admin_token.token, name="Pending Task") + + running = repo.get_running() + assert running is None + + +class TestTrainingTaskUpdate: + """Tests for training task updates.""" + + def test_update_status_to_running(self, patched_session, sample_training_task): + """Test updating task status to running.""" + repo = TrainingTaskRepository() + + repo.update_status(str(sample_training_task.task_id), "running") + + task = repo.get(str(sample_training_task.task_id)) + assert task is not None + assert task.status == "running" + assert task.started_at is not None + + def test_update_status_to_completed(self, patched_session, sample_training_task): + """Test updating task status to completed.""" + repo = TrainingTaskRepository() + + metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85} + + repo.update_status( + str(sample_training_task.task_id), + 
"completed", + result_metrics=metrics, + model_path="/models/trained_model.pt", + ) + + task = repo.get(str(sample_training_task.task_id)) + assert task is not None + assert task.status == "completed" + assert task.completed_at is not None + assert task.result_metrics["mAP"] == 0.92 + assert task.model_path == "/models/trained_model.pt" + + def test_update_status_to_failed(self, patched_session, sample_training_task): + """Test updating task status to failed with error message.""" + repo = TrainingTaskRepository() + + repo.update_status( + str(sample_training_task.task_id), + "failed", + error_message="CUDA out of memory", + ) + + task = repo.get(str(sample_training_task.task_id)) + assert task is not None + assert task.status == "failed" + assert task.completed_at is not None + assert "CUDA out of memory" in task.error_message + + def test_cancel_pending_task(self, patched_session, sample_training_task): + """Test cancelling a pending task.""" + repo = TrainingTaskRepository() + + result = repo.cancel(str(sample_training_task.task_id)) + + assert result is True + + task = repo.get(str(sample_training_task.task_id)) + assert task is not None + assert task.status == "cancelled" + + def test_cannot_cancel_running_task(self, patched_session, sample_training_task): + """Test that running task cannot be cancelled.""" + repo = TrainingTaskRepository() + + repo.update_status(str(sample_training_task.task_id), "running") + + result = repo.cancel(str(sample_training_task.task_id)) + + assert result is False + + task = repo.get(str(sample_training_task.task_id)) + assert task.status == "running" + + +class TestTrainingLogs: + """Tests for training log management.""" + + def test_add_log_entry(self, patched_session, sample_training_task): + """Test adding a training log entry.""" + repo = TrainingTaskRepository() + + repo.add_log( + str(sample_training_task.task_id), + level="INFO", + message="Starting training...", + details={"epoch": 1, "batch": 0}, + ) + + logs = repo.get_logs(str(sample_training_task.task_id)) + assert len(logs) == 1 + assert logs[0].level == "INFO" + assert logs[0].message == "Starting training..." 
+ + def test_add_multiple_log_entries(self, patched_session, sample_training_task): + """Test adding multiple log entries.""" + repo = TrainingTaskRepository() + + for i in range(5): + repo.add_log( + str(sample_training_task.task_id), + level="INFO", + message=f"Epoch {i} completed", + details={"epoch": i, "loss": 0.5 - i * 0.1}, + ) + + logs = repo.get_logs(str(sample_training_task.task_id)) + assert len(logs) == 5 + + def test_get_logs_pagination(self, patched_session, sample_training_task): + """Test paginated log retrieval.""" + repo = TrainingTaskRepository() + + for i in range(10): + repo.add_log( + str(sample_training_task.task_id), + level="INFO", + message=f"Log entry {i}", + ) + + logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0) + assert len(logs) == 5 + + logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5) + assert len(logs_page2) == 5 + + +class TestDocumentLinks: + """Tests for training document link management.""" + + def test_create_document_link(self, patched_session, sample_training_task, sample_document): + """Test creating a document link.""" + repo = TrainingTaskRepository() + + link = repo.create_document_link( + task_id=sample_training_task.task_id, + document_id=sample_document.document_id, + annotation_snapshot={"count": 5, "verified": 3}, + ) + + assert link is not None + assert link.task_id == sample_training_task.task_id + assert link.document_id == sample_document.document_id + assert link.annotation_snapshot["count"] == 5 + + def test_get_document_links(self, patched_session, sample_training_task, multiple_documents): + """Test getting all document links for a task.""" + repo = TrainingTaskRepository() + + for doc in multiple_documents[:3]: + repo.create_document_link( + task_id=sample_training_task.task_id, + document_id=doc.document_id, + ) + + links = repo.get_document_links(sample_training_task.task_id) + assert len(links) == 3 + + def test_get_document_training_tasks(self, patched_session, admin_token, sample_document): + """Test getting training tasks that used a document.""" + repo = TrainingTaskRepository() + + # Create multiple tasks using the same document + task1_id = repo.create(admin_token=admin_token.token, name="Task 1") + task2_id = repo.create(admin_token=admin_token.token, name="Task 2") + + repo.create_document_link( + task_id=repo.get(task1_id).task_id, + document_id=sample_document.document_id, + ) + repo.create_document_link( + task_id=repo.get(task2_id).task_id, + document_id=sample_document.document_id, + ) + + links = repo.get_document_training_tasks(sample_document.document_id) + assert len(links) == 2 diff --git a/tests/integration/services/__init__.py b/tests/integration/services/__init__.py new file mode 100644 index 0000000..ef52312 --- /dev/null +++ b/tests/integration/services/__init__.py @@ -0,0 +1 @@ +"""Service integration tests.""" diff --git a/tests/integration/services/test_dashboard_service_integration.py b/tests/integration/services/test_dashboard_service_integration.py new file mode 100644 index 0000000..f930d90 --- /dev/null +++ b/tests/integration/services/test_dashboard_service_integration.py @@ -0,0 +1,497 @@ +""" +Dashboard Service Integration Tests + +Tests DashboardStatsService and DashboardActivityService with real database operations. 
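Reviewer note: the `TestIsAnnotationComplete` cases below fully determine the completeness rule — a document is complete when its annotations contain at least one identifier field (invoice_number=0 or ocr_number=3) and at least one payment field (bankgiro=4 or plusgiro=5). A minimal sketch mirroring the imported `IDENTIFIER_CLASS_IDS` / `PAYMENT_CLASS_IDS` (the set names here are illustrative stand-ins):

```python
# Hedged sketch of the completeness rule the tests below pin down.
IDENTIFIER_IDS = {0, 3}  # invoice_number, ocr_number
PAYMENT_IDS = {4, 5}     # bankgiro, plusgiro

def annotation_complete(annotations: list[dict]) -> bool:
    """Complete iff at least one identifier AND one payment field exist."""
    class_ids = {a["class_id"] for a in annotations}
    return bool(class_ids & IDENTIFIER_IDS) and bool(class_ids & PAYMENT_IDS)

assert annotation_complete([{"class_id": 0}, {"class_id": 4}])
assert not annotation_complete([{"class_id": 4}])   # payment only
assert not annotation_complete([])                  # empty
```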
+""" + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest + +from inference.data.admin_models import ( + AdminAnnotation, + AdminDocument, + AnnotationHistory, + ModelVersion, + TrainingDataset, + TrainingTask, +) +from inference.web.services.dashboard_service import ( + DashboardStatsService, + DashboardActivityService, + is_annotation_complete, + IDENTIFIER_CLASS_IDS, + PAYMENT_CLASS_IDS, +) + + +class TestIsAnnotationComplete: + """Tests for is_annotation_complete function.""" + + def test_complete_with_invoice_number_and_bankgiro(self): + """Test complete with invoice_number (0) and bankgiro (4).""" + annotations = [ + {"class_id": 0}, # invoice_number + {"class_id": 4}, # bankgiro + ] + assert is_annotation_complete(annotations) is True + + def test_complete_with_ocr_number_and_plusgiro(self): + """Test complete with ocr_number (3) and plusgiro (5).""" + annotations = [ + {"class_id": 3}, # ocr_number + {"class_id": 5}, # plusgiro + ] + assert is_annotation_complete(annotations) is True + + def test_incomplete_missing_identifier(self): + """Test incomplete when missing identifier.""" + annotations = [ + {"class_id": 4}, # bankgiro only + ] + assert is_annotation_complete(annotations) is False + + def test_incomplete_missing_payment(self): + """Test incomplete when missing payment.""" + annotations = [ + {"class_id": 0}, # invoice_number only + ] + assert is_annotation_complete(annotations) is False + + def test_incomplete_empty_annotations(self): + """Test incomplete with empty annotations.""" + assert is_annotation_complete([]) is False + + def test_complete_with_multiple_fields(self): + """Test complete with multiple fields.""" + annotations = [ + {"class_id": 0}, # invoice_number + {"class_id": 1}, # invoice_date + {"class_id": 3}, # ocr_number + {"class_id": 4}, # bankgiro + {"class_id": 5}, # plusgiro + {"class_id": 6}, # amount + ] + assert is_annotation_complete(annotations) is True + + +class TestDashboardStatsService: + """Tests for DashboardStatsService.""" + + def test_get_stats_empty_database(self, patched_session): + """Test stats with empty database.""" + service = DashboardStatsService() + + stats = service.get_stats() + + assert stats["total_documents"] == 0 + assert stats["annotation_complete"] == 0 + assert stats["annotation_incomplete"] == 0 + assert stats["pending"] == 0 + assert stats["completeness_rate"] == 0.0 + + def test_get_stats_with_documents(self, patched_session, admin_token): + """Test stats with various document states.""" + service = DashboardStatsService() + session = patched_session + + # Create documents with different statuses + docs = [] + for i, status in enumerate(["pending", "auto_labeling", "labeled", "labeled", "exported"]): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"doc_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/doc_{i}.pdf", + page_count=1, + status=status, + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + docs.append(doc) + session.commit() + + stats = service.get_stats() + + assert stats["total_documents"] == 5 + assert stats["pending"] == 2 # pending + auto_labeling + + def test_get_stats_complete_annotations(self, patched_session, admin_token): + """Test completeness calculation with proper annotations.""" + service = DashboardStatsService() + session = patched_session + + # Create a labeled document with complete 
annotations + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename="complete_doc.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/complete_doc.pdf", + page_count=1, + status="labeled", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + # Add identifier annotation (invoice_number = class_id 0) + ann1 = AdminAnnotation( + annotation_id=uuid4(), + document_id=doc.document_id, + page_number=1, + class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.1, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80, + bbox_width=160, + bbox_height=40, + text_value="INV-001", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(ann1) + + # Add payment annotation (bankgiro = class_id 4) + ann2 = AdminAnnotation( + annotation_id=uuid4(), + document_id=doc.document_id, + page_number=1, + class_id=4, + class_name="bankgiro", + x_center=0.5, + y_center=0.2, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=160, + bbox_width=160, + bbox_height=40, + text_value="123-4567", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(ann2) + session.commit() + + stats = service.get_stats() + + assert stats["annotation_complete"] == 1 + assert stats["annotation_incomplete"] == 0 + assert stats["completeness_rate"] == 100.0 + + def test_get_stats_incomplete_annotations(self, patched_session, admin_token): + """Test completeness with incomplete annotations.""" + service = DashboardStatsService() + session = patched_session + + # Create a labeled document missing payment annotation + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename="incomplete_doc.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/incomplete_doc.pdf", + page_count=1, + status="labeled", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + # Add only identifier annotation (missing payment) + ann = AdminAnnotation( + annotation_id=uuid4(), + document_id=doc.document_id, + page_number=1, + class_id=0, + class_name="invoice_number", + x_center=0.5, + y_center=0.1, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80, + bbox_width=160, + bbox_height=40, + text_value="INV-001", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(ann) + session.commit() + + stats = service.get_stats() + + assert stats["annotation_complete"] == 0 + assert stats["annotation_incomplete"] == 1 + assert stats["completeness_rate"] == 0.0 + + def test_get_stats_mixed_completeness(self, patched_session, admin_token): + """Test stats with mix of complete and incomplete documents.""" + service = DashboardStatsService() + session = patched_session + + # Create 2 labeled documents + docs = [] + for i in range(2): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"mixed_doc_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/mixed_doc_{i}.pdf", + page_count=1, + status="labeled", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + docs.append(doc) + session.commit() + + # First document: complete 
(has identifier + payment) + session.add(AdminAnnotation( + annotation_id=uuid4(), + document_id=docs[0].document_id, + page_number=1, + class_id=0, # invoice_number + class_name="invoice_number", + x_center=0.5, y_center=0.1, width=0.2, height=0.05, + bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + )) + session.add(AdminAnnotation( + annotation_id=uuid4(), + document_id=docs[0].document_id, + page_number=1, + class_id=4, # bankgiro + class_name="bankgiro", + x_center=0.5, y_center=0.2, width=0.2, height=0.05, + bbox_x=400, bbox_y=160, bbox_width=160, bbox_height=40, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + )) + + # Second document: incomplete (missing payment) + session.add(AdminAnnotation( + annotation_id=uuid4(), + document_id=docs[1].document_id, + page_number=1, + class_id=0, # invoice_number only + class_name="invoice_number", + x_center=0.5, y_center=0.1, width=0.2, height=0.05, + bbox_x=400, bbox_y=80, bbox_width=160, bbox_height=40, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + )) + session.commit() + + stats = service.get_stats() + + assert stats["annotation_complete"] == 1 + assert stats["annotation_incomplete"] == 1 + assert stats["completeness_rate"] == 50.0 + + +class TestDashboardActivityService: + """Tests for DashboardActivityService.""" + + def test_get_recent_activities_empty(self, patched_session): + """Test activities with empty database.""" + service = DashboardActivityService() + + activities = service.get_recent_activities() + + assert activities == [] + + def test_get_recent_activities_document_uploads(self, patched_session, admin_token): + """Test activities include document uploads.""" + service = DashboardActivityService() + session = patched_session + + # Create documents + for i in range(3): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"activity_doc_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/activity_doc_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + activities = service.get_recent_activities() + + upload_activities = [a for a in activities if a["type"] == "document_uploaded"] + assert len(upload_activities) == 3 + + def test_get_recent_activities_annotation_overrides(self, patched_session, sample_document, sample_annotation): + """Test activities include annotation overrides.""" + service = DashboardActivityService() + session = patched_session + + # Create annotation history with override + history = AnnotationHistory( + history_id=uuid4(), + annotation_id=sample_annotation.annotation_id, + document_id=sample_document.document_id, + action="override", + previous_value={"text_value": "OLD-001"}, + new_value={"text_value": "NEW-001", "class_name": "invoice_number"}, + changed_by="test-admin", + created_at=datetime.now(timezone.utc), + ) + session.add(history) + session.commit() + + activities = service.get_recent_activities() + + override_activities = [a for a in activities if a["type"] == "annotation_modified"] + assert len(override_activities) >= 1 + + def test_get_recent_activities_training_completed(self, patched_session, sample_training_task): + """Test activities include training completions.""" + service = 
DashboardActivityService() + session = patched_session + + # Update training task to completed + sample_training_task.status = "completed" + sample_training_task.metrics_mAP = 0.85 + sample_training_task.updated_at = datetime.now(timezone.utc) + session.add(sample_training_task) + session.commit() + + activities = service.get_recent_activities() + + training_activities = [a for a in activities if a["type"] == "training_completed"] + assert len(training_activities) >= 1 + assert "mAP" in training_activities[0]["metadata"] + + def test_get_recent_activities_training_failed(self, patched_session, sample_training_task): + """Test activities include training failures.""" + service = DashboardActivityService() + session = patched_session + + # Update training task to failed + sample_training_task.status = "failed" + sample_training_task.error_message = "CUDA out of memory" + sample_training_task.updated_at = datetime.now(timezone.utc) + session.add(sample_training_task) + session.commit() + + activities = service.get_recent_activities() + + failed_activities = [a for a in activities if a["type"] == "training_failed"] + assert len(failed_activities) >= 1 + assert failed_activities[0]["metadata"]["error"] == "CUDA out of memory" + + def test_get_recent_activities_model_activated(self, patched_session, sample_model_version): + """Test activities include model activations.""" + service = DashboardActivityService() + session = patched_session + + # Activate model + sample_model_version.is_active = True + sample_model_version.activated_at = datetime.now(timezone.utc) + session.add(sample_model_version) + session.commit() + + activities = service.get_recent_activities() + + activation_activities = [a for a in activities if a["type"] == "model_activated"] + assert len(activation_activities) >= 1 + assert activation_activities[0]["metadata"]["version"] == sample_model_version.version + + def test_get_recent_activities_limit(self, patched_session, admin_token): + """Test activity limit parameter.""" + service = DashboardActivityService() + session = patched_session + + # Create many documents + for i in range(20): + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename=f"limit_doc_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/limit_doc_{i}.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + session.commit() + + activities = service.get_recent_activities(limit=5) + + assert len(activities) <= 5 + + def test_get_recent_activities_sorted_by_timestamp(self, patched_session, admin_token, sample_training_task): + """Test activities are sorted by timestamp descending.""" + service = DashboardActivityService() + session = patched_session + + # Create document + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename="sorted_doc.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/uploads/sorted_doc.pdf", + page_count=1, + status="pending", + upload_source="ui", + category="invoice", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(doc) + + # Complete training task + sample_training_task.status = "completed" + sample_training_task.metrics_mAP = 0.90 + sample_training_task.updated_at = datetime.now(timezone.utc) + session.add(sample_training_task) + session.commit() + + activities = 
service.get_recent_activities() + + # Verify sorted by timestamp DESC + timestamps = [a["timestamp"] for a in activities] + assert timestamps == sorted(timestamps, reverse=True) diff --git a/tests/integration/services/test_dataset_builder_integration.py b/tests/integration/services/test_dataset_builder_integration.py new file mode 100644 index 0000000..633db37 --- /dev/null +++ b/tests/integration/services/test_dataset_builder_integration.py @@ -0,0 +1,453 @@ +""" +Dataset Builder Service Integration Tests + +Tests DatasetBuilder with real file operations and repository interactions. +""" + +import shutil +from datetime import datetime, timezone +from pathlib import Path +from uuid import uuid4 + +import pytest +import yaml + +from inference.data.admin_models import AdminAnnotation, AdminDocument +from inference.data.repositories.annotation_repository import AnnotationRepository +from inference.data.repositories.dataset_repository import DatasetRepository +from inference.data.repositories.document_repository import DocumentRepository +from inference.web.services.dataset_builder import DatasetBuilder + + +@pytest.fixture +def dataset_builder(patched_session, temp_dataset_dir): + """Create a DatasetBuilder with real repositories.""" + return DatasetBuilder( + datasets_repo=DatasetRepository(), + documents_repo=DocumentRepository(), + annotations_repo=AnnotationRepository(), + base_dir=temp_dataset_dir, + ) + + +@pytest.fixture +def admin_images_dir(temp_upload_dir): + """Create a directory for admin images.""" + images_dir = temp_upload_dir / "admin_images" + images_dir.mkdir(parents=True, exist_ok=True) + return images_dir + + +@pytest.fixture +def documents_with_annotations(patched_session, db_session, admin_token, admin_images_dir): + """Create documents with annotations and corresponding image files.""" + documents = [] + doc_repo = DocumentRepository() + ann_repo = AnnotationRepository() + + for i in range(5): + # Create document + doc_id = doc_repo.create( + filename=f"invoice_{i}.pdf", + file_size=1024, + content_type="application/pdf", + file_path=f"/uploads/invoice_{i}.pdf", + page_count=2, + category="invoice", + group_key=f"group_{i % 2}", # Two groups + ) + + # Create image files for each page + doc_dir = admin_images_dir / doc_id + doc_dir.mkdir(parents=True, exist_ok=True) + + for page in range(1, 3): + image_path = doc_dir / f"page_{page}.png" + # Create a minimal fake PNG + image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + + # Create annotations + for j in range(3): + ann_repo.create( + document_id=doc_id, + page_number=1, + class_id=j, + class_name=f"field_{j}", + x_center=0.5, + y_center=0.1 + j * 0.2, + width=0.2, + height=0.05, + bbox_x=400, + bbox_y=80 + j * 160, + bbox_width=160, + bbox_height=40, + text_value=f"value_{j}", + confidence=0.95, + source="auto", + ) + + doc = doc_repo.get(doc_id) + documents.append(doc) + + return documents + + +class TestDatasetBuilderBasic: + """Tests for basic dataset building operations.""" + + def test_build_dataset_creates_directory_structure( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that building creates proper directory structure.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Test Dataset") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + 
admin_images_dir=admin_images_dir, + ) + + dataset_dir = temp_dataset_dir / str(dataset.dataset_id) + + # Check directory structure + assert (dataset_dir / "images" / "train").exists() + assert (dataset_dir / "images" / "val").exists() + assert (dataset_dir / "images" / "test").exists() + assert (dataset_dir / "labels" / "train").exists() + assert (dataset_dir / "labels" / "val").exists() + assert (dataset_dir / "labels" / "test").exists() + + def test_build_dataset_copies_images( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that images are copied to dataset directory.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Image Copy Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + result = dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + dataset_dir = temp_dataset_dir / str(dataset.dataset_id) + + # Count total images across all splits + total_images = 0 + for split in ["train", "val", "test"]: + images = list((dataset_dir / "images" / split).glob("*.png")) + total_images += len(images) + + # 5 docs * 2 pages = 10 images + assert total_images == 10 + assert result["total_images"] == 10 + + def test_build_dataset_generates_labels( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that YOLO label files are generated.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Label Generation Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + dataset_dir = temp_dataset_dir / str(dataset.dataset_id) + + # Count total label files + total_labels = 0 + for split in ["train", "val", "test"]: + labels = list((dataset_dir / "labels" / split).glob("*.txt")) + total_labels += len(labels) + + # Same count as images + assert total_labels == 10 + + def test_build_dataset_generates_data_yaml( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that data.yaml is generated correctly.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="YAML Generation Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + dataset_dir = temp_dataset_dir / str(dataset.dataset_id) + yaml_path = dataset_dir / "data.yaml" + + assert yaml_path.exists() + + with open(yaml_path) as f: + data = yaml.safe_load(f) + + assert data["train"] == "images/train" + assert data["val"] == "images/val" + assert data["test"] == "images/test" + assert "nc" in data + assert "names" in data + + +class TestDatasetBuilderSplits: + """Tests for train/val/test split assignment.""" + + def test_split_ratio_respected( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that split ratios are approximately respected.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Split Ratio Test") + + doc_ids = 
[str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.6, + val_ratio=0.2, + seed=42, + admin_images_dir=admin_images_dir, + ) + + # Check document assignments in database + dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id)) + + splits = {"train": 0, "val": 0, "test": 0} + for doc in dataset_docs: + splits[doc.split] += 1 + + # With 5 docs and ratios 0.6/0.2/0.2, expect ~3/1/1 + # Due to rounding and group constraints, allow some variation + assert splits["train"] >= 2 + assert splits["val"] >= 1 or splits["test"] >= 1 + + def test_same_seed_same_split( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that same seed produces same split.""" + dataset_repo = DatasetRepository() + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + # Build first dataset + dataset1 = dataset_repo.create(name="Seed Test 1") + dataset_builder.build_dataset( + dataset_id=str(dataset1.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=12345, + admin_images_dir=admin_images_dir, + ) + + # Build second dataset with same seed + dataset2 = dataset_repo.create(name="Seed Test 2") + dataset_builder.build_dataset( + dataset_id=str(dataset2.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=12345, + admin_images_dir=admin_images_dir, + ) + + # Compare splits + docs1 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset1.dataset_id))} + docs2 = {str(d.document_id): d.split for d in dataset_repo.get_documents(str(dataset2.dataset_id))} + + assert docs1 == docs2 + + +class TestDatasetBuilderDatabase: + """Tests for database interactions.""" + + def test_updates_dataset_status( + self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session + ): + """Test that dataset status is updated after build.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Status Update Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + updated = dataset_repo.get(str(dataset.dataset_id)) + + assert updated.status == "ready" + assert updated.total_documents == 5 + assert updated.total_images == 10 + assert updated.total_annotations > 0 + assert updated.dataset_path is not None + + def test_records_document_assignments( + self, dataset_builder, documents_with_annotations, admin_images_dir, patched_session + ): + """Test that document assignments are recorded in database.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Assignment Recording Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + dataset_docs = dataset_repo.get_documents(str(dataset.dataset_id)) + + assert len(dataset_docs) == 5 + + for doc in dataset_docs: + assert doc.split in ["train", "val", "test"] + assert doc.page_count > 0 + + +class TestDatasetBuilderErrors: + """Tests for error handling.""" + + def test_fails_with_no_documents(self, dataset_builder, admin_images_dir, 
patched_session): + """Test that building fails with empty document list.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Empty Docs Test") + + with pytest.raises(ValueError, match="No valid documents"): + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + def test_fails_with_invalid_doc_ids(self, dataset_builder, admin_images_dir, patched_session): + """Test that building fails with nonexistent document IDs.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Invalid IDs Test") + + fake_ids = [str(uuid4()) for _ in range(3)] + + with pytest.raises(ValueError, match="No valid documents"): + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=fake_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + def test_updates_status_on_failure(self, dataset_builder, admin_images_dir, patched_session): + """Test that dataset status is set to failed on error.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Failure Status Test") + + try: + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + except ValueError: + pass + + updated = dataset_repo.get(str(dataset.dataset_id)) + assert updated.status == "failed" + assert updated.error_message is not None + + +class TestLabelFileFormat: + """Tests for YOLO label file format.""" + + def test_label_file_format( + self, dataset_builder, documents_with_annotations, admin_images_dir, temp_dataset_dir, patched_session + ): + """Test that label files are in correct YOLO format.""" + dataset_repo = DatasetRepository() + dataset = dataset_repo.create(name="Label Format Test") + + doc_ids = [str(d.document_id) for d in documents_with_annotations] + + dataset_builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=doc_ids, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=admin_images_dir, + ) + + dataset_dir = temp_dataset_dir / str(dataset.dataset_id) + + # Find a label file with content + label_files = [] + for split in ["train", "val", "test"]: + label_files.extend(list((dataset_dir / "labels" / split).glob("*.txt"))) + + # Check at least one label file has correct format + found_valid_label = False + for label_file in label_files: + content = label_file.read_text().strip() + if content: + lines = content.split("\n") + for line in lines: + parts = line.split() + assert len(parts) == 5, f"Expected 5 parts, got {len(parts)}: {line}" + + class_id = int(parts[0]) + x_center = float(parts[1]) + y_center = float(parts[2]) + width = float(parts[3]) + height = float(parts[4]) + + assert 0 <= class_id < 10 + assert 0 <= x_center <= 1 + assert 0 <= y_center <= 1 + assert 0 <= width <= 1 + assert 0 <= height <= 1 + + found_valid_label = True + break + + assert found_valid_label, "No valid label files found" diff --git a/tests/integration/services/test_document_service_integration.py b/tests/integration/services/test_document_service_integration.py new file mode 100644 index 0000000..ce48dc3 --- /dev/null +++ b/tests/integration/services/test_document_service_integration.py @@ -0,0 +1,283 @@ +""" +Document Service Integration Tests + +Tests DocumentService with real storage operations. 
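As a reading aid before the tests that follow: MockStorageBackend implements the storage contract that DocumentService is written against. Here is that implied interface written out as a structural Protocol, a sketch only; the method names and signatures are copied from the mock below, while the Protocol form itself is an editorial assumption rather than code from the patch.

```
# Sketch only: the storage contract implied by MockStorageBackend below.
# Method names and signatures mirror the mock; the Protocol form is assumed.
from typing import Protocol


class StorageBackend(Protocol):
    def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None: ...
    def download_bytes(self, remote_path: str) -> bytes: ...
    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str: ...
    def exists(self, remote_path: str) -> bool: ...
    def delete(self, remote_path: str) -> bool: ...
    def list_files(self, prefix: str) -> list[str]: ...
```

Any object with this shape (in-memory, local disk, S3) can back DocumentService, which is what lets these tests run without real object storage.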
+""" + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from inference.web.services.document_service import DocumentService, DocumentResult + + +class MockStorageBackend: + """Simple in-memory storage backend for testing.""" + + def __init__(self): + self._files: dict[str, bytes] = {} + + def upload_bytes(self, content: bytes, remote_path: str, overwrite: bool = False) -> None: + if not overwrite and remote_path in self._files: + raise FileExistsError(f"File already exists: {remote_path}") + self._files[remote_path] = content + + def download_bytes(self, remote_path: str) -> bytes: + if remote_path not in self._files: + raise FileNotFoundError(f"File not found: {remote_path}") + return self._files[remote_path] + + def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str: + return f"https://storage.example.com/{remote_path}?expires={expires_in_seconds}" + + def exists(self, remote_path: str) -> bool: + return remote_path in self._files + + def delete(self, remote_path: str) -> bool: + if remote_path in self._files: + del self._files[remote_path] + return True + return False + + def list_files(self, prefix: str) -> list[str]: + return [path for path in self._files.keys() if path.startswith(prefix)] + + +@pytest.fixture +def mock_storage(): + """Create a mock storage backend.""" + return MockStorageBackend() + + +@pytest.fixture +def document_service(mock_storage): + """Create a DocumentService with mock storage.""" + return DocumentService(storage_backend=mock_storage) + + +class TestDocumentUpload: + """Tests for document upload operations.""" + + def test_upload_document(self, document_service): + """Test uploading a document.""" + content = b"%PDF-1.4 test content" + filename = "test_invoice.pdf" + + result = document_service.upload_document(content, filename) + + assert result is not None + assert result.id is not None + assert result.filename == filename + assert result.file_path.startswith("documents/") + assert result.file_path.endswith(".pdf") + + def test_upload_document_with_custom_id(self, document_service): + """Test uploading with custom document ID.""" + content = b"%PDF-1.4 test content" + filename = "invoice.pdf" + custom_id = "custom-doc-12345" + + result = document_service.upload_document( + content, filename, document_id=custom_id + ) + + assert result.id == custom_id + assert custom_id in result.file_path + + def test_upload_preserves_extension(self, document_service): + """Test that file extension is preserved.""" + cases = [ + ("document.pdf", ".pdf"), + ("image.PNG", ".png"), + ("file.JPEG", ".jpeg"), + ("noextension", ""), + ] + + for filename, expected_ext in cases: + result = document_service.upload_document(b"content", filename) + if expected_ext: + assert result.file_path.endswith(expected_ext) + + def test_upload_document_overwrite(self, document_service, mock_storage): + """Test that upload overwrites existing file.""" + content1 = b"original content" + content2 = b"new content" + doc_id = "overwrite-test" + + document_service.upload_document(content1, "doc.pdf", document_id=doc_id) + document_service.upload_document(content2, "doc.pdf", document_id=doc_id) + + # Should have new content + remote_path = f"documents/{doc_id}.pdf" + stored_content = mock_storage.download_bytes(remote_path) + assert stored_content == content2 + + +class TestDocumentDownload: + """Tests for document download operations.""" + + def test_download_document(self, document_service, mock_storage): + """Test downloading a 
document.""" + content = b"test document content" + remote_path = "documents/test-doc.pdf" + mock_storage.upload_bytes(content, remote_path) + + downloaded = document_service.download_document(remote_path) + + assert downloaded == content + + def test_download_nonexistent_document(self, document_service): + """Test downloading document that doesn't exist.""" + with pytest.raises(FileNotFoundError): + document_service.download_document("documents/nonexistent.pdf") + + +class TestDocumentUrl: + """Tests for document URL generation.""" + + def test_get_document_url(self, document_service, mock_storage): + """Test getting presigned URL for document.""" + remote_path = "documents/test-doc.pdf" + mock_storage.upload_bytes(b"content", remote_path) + + url = document_service.get_document_url(remote_path, expires_in_seconds=7200) + + assert url.startswith("https://") + assert remote_path in url + assert "7200" in url + + def test_get_document_url_default_expiry(self, document_service): + """Test default URL expiry.""" + url = document_service.get_document_url("documents/doc.pdf") + + assert "3600" in url + + +class TestDocumentExists: + """Tests for document existence check.""" + + def test_document_exists(self, document_service, mock_storage): + """Test checking if document exists.""" + remote_path = "documents/existing.pdf" + mock_storage.upload_bytes(b"content", remote_path) + + assert document_service.document_exists(remote_path) is True + + def test_document_not_exists(self, document_service): + """Test checking if nonexistent document exists.""" + assert document_service.document_exists("documents/nonexistent.pdf") is False + + +class TestDocumentDelete: + """Tests for document deletion.""" + + def test_delete_document(self, document_service, mock_storage): + """Test deleting a document.""" + remote_path = "documents/to-delete.pdf" + mock_storage.upload_bytes(b"content", remote_path) + + result = document_service.delete_document_files(remote_path) + + assert result is True + assert document_service.document_exists(remote_path) is False + + def test_delete_nonexistent_document(self, document_service): + """Test deleting document that doesn't exist.""" + result = document_service.delete_document_files("documents/nonexistent.pdf") + + assert result is False + + +class TestPageImages: + """Tests for page image operations.""" + + def test_save_page_image(self, document_service, mock_storage): + """Test saving a page image.""" + doc_id = "test-doc-123" + page_num = 1 + image_content = b"\x89PNG\r\n\x1a\n fake png" + + remote_path = document_service.save_page_image(doc_id, page_num, image_content) + + assert remote_path == f"images/{doc_id}/page_{page_num}.png" + assert mock_storage.exists(remote_path) + + def test_save_multiple_page_images(self, document_service, mock_storage): + """Test saving images for multiple pages.""" + doc_id = "multi-page-doc" + + for page_num in range(1, 4): + content = f"page {page_num} content".encode() + document_service.save_page_image(doc_id, page_num, content) + + images = document_service.list_document_images(doc_id) + assert len(images) == 3 + + def test_get_page_image(self, document_service, mock_storage): + """Test downloading a page image.""" + doc_id = "test-doc" + page_num = 2 + image_content = b"image data" + + document_service.save_page_image(doc_id, page_num, image_content) + downloaded = document_service.get_page_image(doc_id, page_num) + + assert downloaded == image_content + + def test_get_page_image_url(self, document_service): + """Test getting URL for 
page image.""" + doc_id = "test-doc" + page_num = 1 + + url = document_service.get_page_image_url(doc_id, page_num) + + assert f"images/{doc_id}/page_{page_num}.png" in url + + def test_list_document_images(self, document_service, mock_storage): + """Test listing all images for a document.""" + doc_id = "list-test-doc" + + for i in range(5): + document_service.save_page_image(doc_id, i + 1, f"page {i}".encode()) + + images = document_service.list_document_images(doc_id) + + assert len(images) == 5 + + def test_delete_document_images(self, document_service, mock_storage): + """Test deleting all images for a document.""" + doc_id = "delete-images-doc" + + for i in range(3): + document_service.save_page_image(doc_id, i + 1, b"content") + + deleted_count = document_service.delete_document_images(doc_id) + + assert deleted_count == 3 + assert len(document_service.list_document_images(doc_id)) == 0 + + +class TestRoundTrip: + """Tests for complete upload-download cycles.""" + + def test_document_round_trip(self, document_service): + """Test uploading and downloading document.""" + original_content = b"%PDF-1.4 complete document content here" + filename = "roundtrip.pdf" + + result = document_service.upload_document(original_content, filename) + downloaded = document_service.download_document(result.file_path) + + assert downloaded == original_content + + def test_image_round_trip(self, document_service): + """Test saving and retrieving page image.""" + doc_id = "roundtrip-doc" + page_num = 1 + original_image = b"\x89PNG fake image data" + + document_service.save_page_image(doc_id, page_num, original_image) + retrieved = document_service.get_page_image(doc_id, page_num) + + assert retrieved == original_image diff --git a/tests/integration/test_database_setup.py b/tests/integration/test_database_setup.py new file mode 100644 index 0000000..e0cca35 --- /dev/null +++ b/tests/integration/test_database_setup.py @@ -0,0 +1,258 @@ +""" +Database Setup Integration Tests + +Tests for database connection, session management, and basic operations. 
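The session tests that follow lean on one pattern: commit persists, rollback on error restores the last committed state, and separate Session objects are isolated. A minimal sketch of that lifecycle, assuming a SQLModel engine; session_scope is an illustrative stand-in for the project's BaseRepository._session, whose body is not shown in this diff.

```
# Sketch only: the commit/rollback lifecycle the tests below exercise.
# session_scope is illustrative; the project's helper is BaseRepository._session.
from contextlib import contextmanager

from sqlmodel import Session


@contextmanager
def session_scope(engine):
    """Commit on success, roll back on any exception, always close."""
    session = Session(engine)
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
```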
+""" + +import pytest +from sqlmodel import Session, select + +from inference.data.admin_models import AdminDocument, AdminToken + + +class TestDatabaseConnection: + """Tests for database engine and connection.""" + + def test_engine_connection(self, test_engine): + """Verify database engine can establish connection.""" + with test_engine.connect() as conn: + result = conn.execute(select(1)) + assert result.scalar() == 1 + + def test_tables_created(self, test_engine): + """Verify all expected tables are created.""" + from sqlmodel import SQLModel + + table_names = SQLModel.metadata.tables.keys() + + expected_tables = [ + "admin_tokens", + "admin_documents", + "admin_annotations", + "training_tasks", + "training_logs", + "batch_uploads", + "batch_upload_files", + "training_datasets", + "dataset_documents", + "training_document_links", + "model_versions", + ] + + for table in expected_tables: + assert table in table_names, f"Table '{table}' not found" + + +class TestSessionManagement: + """Tests for database session context manager.""" + + def test_session_commit(self, db_session): + """Verify session commits changes successfully.""" + token = AdminToken( + token="commit-test-token", + name="Commit Test", + is_active=True, + ) + db_session.add(token) + db_session.commit() + + result = db_session.exec( + select(AdminToken).where(AdminToken.token == "commit-test-token") + ).first() + + assert result is not None + assert result.name == "Commit Test" + + def test_session_rollback_on_error(self, test_engine): + """Verify session rollback on exception.""" + session = Session(test_engine) + + try: + token = AdminToken( + token="rollback-test-token", + name="Rollback Test", + is_active=True, + ) + session.add(token) + session.commit() + + # Try to insert duplicate (should fail) + duplicate = AdminToken( + token="rollback-test-token", # Same primary key + name="Duplicate", + is_active=True, + ) + session.add(duplicate) + session.commit() + except Exception: + session.rollback() + finally: + session.close() + + # Verify original record exists + with Session(test_engine) as verify_session: + result = verify_session.exec( + select(AdminToken).where(AdminToken.token == "rollback-test-token") + ).first() + assert result is not None + assert result.name == "Rollback Test" + + def test_session_isolation(self, test_engine): + """Verify sessions are isolated from each other.""" + session1 = Session(test_engine) + session2 = Session(test_engine) + + try: + # Insert in session1, don't commit + token = AdminToken( + token="isolation-test-token", + name="Isolation Test", + is_active=True, + ) + session1.add(token) + session1.flush() + + # Session2 should not see uncommitted data (with proper isolation) + # Note: SQLite in-memory may have different isolation behavior + session1.commit() + + result = session2.exec( + select(AdminToken).where(AdminToken.token == "isolation-test-token") + ).first() + + # After commit, session2 should see the data + assert result is not None + + finally: + session1.close() + session2.close() + + +class TestBasicCRUDOperations: + """Tests for basic CRUD operations on database.""" + + def test_create_and_read_token(self, db_session): + """Test creating and reading admin token.""" + token = AdminToken( + token="crud-test-token", + name="CRUD Test", + is_active=True, + ) + db_session.add(token) + db_session.commit() + + result = db_session.get(AdminToken, "crud-test-token") + + assert result is not None + assert result.name == "CRUD Test" + assert result.is_active is True + + def 
test_update_entity(self, db_session, admin_token): + """Test updating an entity.""" + admin_token.name = "Updated Name" + db_session.add(admin_token) + db_session.commit() + + result = db_session.get(AdminToken, admin_token.token) + + assert result is not None + assert result.name == "Updated Name" + + def test_delete_entity(self, db_session): + """Test deleting an entity.""" + token = AdminToken( + token="delete-test-token", + name="Delete Test", + is_active=True, + ) + db_session.add(token) + db_session.commit() + + db_session.delete(token) + db_session.commit() + + result = db_session.get(AdminToken, "delete-test-token") + assert result is None + + def test_foreign_key_constraint(self, db_session, admin_token): + """Test foreign key constraints are enforced.""" + from uuid import uuid4 + + doc = AdminDocument( + document_id=uuid4(), + admin_token=admin_token.token, + filename="fk_test.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/test/fk_test.pdf", + page_count=1, + status="pending", + ) + db_session.add(doc) + db_session.commit() + + # Document should reference valid token + result = db_session.get(AdminDocument, doc.document_id) + assert result is not None + assert result.admin_token == admin_token.token + + +class TestQueryOperations: + """Tests for various query operations.""" + + def test_select_with_filter(self, db_session, multiple_documents): + """Test SELECT with WHERE clause.""" + results = db_session.exec( + select(AdminDocument).where(AdminDocument.status == "labeled") + ).all() + + assert len(results) == 2 + for doc in results: + assert doc.status == "labeled" + + def test_select_with_order(self, db_session, multiple_documents): + """Test SELECT with ORDER BY clause.""" + results = db_session.exec( + select(AdminDocument).order_by(AdminDocument.file_size.desc()) + ).all() + + file_sizes = [doc.file_size for doc in results] + assert file_sizes == sorted(file_sizes, reverse=True) + + def test_select_with_limit_offset(self, db_session, multiple_documents): + """Test SELECT with LIMIT and OFFSET.""" + results = db_session.exec( + select(AdminDocument) + .order_by(AdminDocument.filename) + .offset(2) + .limit(2) + ).all() + + assert len(results) == 2 + + def test_count_query(self, db_session, multiple_documents): + """Test COUNT aggregation.""" + from sqlalchemy import func + + count = db_session.exec( + select(func.count()).select_from(AdminDocument) + ).one() + + assert count == 5 + + def test_group_by_query(self, db_session, multiple_documents): + """Test GROUP BY aggregation.""" + from sqlalchemy import func + + results = db_session.exec( + select( + AdminDocument.status, + func.count(AdminDocument.document_id).label("count"), + ).group_by(AdminDocument.status) + ).all() + + status_counts = {row[0]: row[1] for row in results} + + assert status_counts.get("pending") == 2 + assert status_counts.get("labeled") == 2 + assert status_counts.get("exported") == 1 diff --git a/tests/web/test_admin_annotations.py b/tests/web/test_admin_annotations.py index 0140b03..642c61c 100644 --- a/tests/web/test_admin_annotations.py +++ b/tests/web/test_admin_annotations.py @@ -196,3 +196,121 @@ class TestAnnotationModel: assert 0 <= ann.y_center <= 1 assert 0 <= ann.width <= 1 assert 0 <= ann.height <= 1 + + +class TestAutoLabelFilePathResolution: + """Tests for auto-label file path resolution. + + The auto-label endpoint needs to resolve the storage path (e.g., "raw_pdfs/uuid.pdf") + to an actual filesystem path via the storage helper. 
+ """ + + def test_extracts_filename_from_storage_path(self): + """Test that filename is extracted from storage path correctly.""" + # Storage paths are like "raw_pdfs/uuid.pdf" + storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf" + + # The annotation endpoint extracts filename + filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path + + assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf" + + def test_handles_path_without_prefix(self): + """Test that paths without prefix are handled.""" + storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf" + + filename = storage_path.split("/")[-1] if "/" in storage_path else storage_path + + assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf" + + def test_storage_helper_resolves_path(self): + """Test that storage helper can resolve the path.""" + from pathlib import Path + from unittest.mock import MagicMock, patch + + # Mock storage helper + mock_storage = MagicMock() + mock_path = Path("/storage/raw_pdfs/test.pdf") + mock_storage.get_raw_pdf_local_path.return_value = mock_path + + with patch( + "inference.web.services.storage_helpers.get_storage_helper", + return_value=mock_storage, + ): + from inference.web.services.storage_helpers import get_storage_helper + + storage = get_storage_helper() + result = storage.get_raw_pdf_local_path("test.pdf") + + assert result == mock_path + mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf") + + def test_auto_label_request_validation(self): + """Test AutoLabelRequest validates field_values.""" + # Valid request + request = AutoLabelRequest( + field_values={"InvoiceNumber": "12345"}, + replace_existing=False, + ) + assert request.field_values == {"InvoiceNumber": "12345"} + + # Empty field_values should be valid at schema level + # (endpoint validates non-empty) + request_empty = AutoLabelRequest( + field_values={}, + replace_existing=False, + ) + assert request_empty.field_values == {} + + +class TestMatchClassAttributes: + """Tests for Match class attributes used in auto-labeling. + + The autolabel service uses Match objects from FieldMatcher. + Verifies the correct attribute names are used. 
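The tests that follow pin down which Match attributes the autolabel code may read. For orientation, here is a sketch of the consuming side, building an annotation payload from a Match, using only the attributes asserted below; the payload shape and the page-number conversion are illustrative assumptions.

```
# Sketch only: turn a Match into an annotation payload using matched_text
# (the bug was reading a nonexistent matched_value). Payload shape assumed.
from shared.matcher.models import Match


def annotation_payload(match: Match, class_id: int) -> dict:
    x0, y0, x1, y1 = match.bbox  # pixel-space corners (x0, y0, x1, y1)
    return {
        "class_id": class_id,
        "bbox": (x0, y0, x1 - x0, y1 - y0),  # x, y, width, height
        "text_value": match.matched_text,    # NOT match.matched_value
        "confidence": match.score,
        "page_number": match.page_no + 1,    # assuming Match.page_no is 0-based
    }
```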
+ """ + + def test_match_has_matched_text_attribute(self): + """Test that Match class has matched_text attribute (not matched_value).""" + from shared.matcher.models import Match + + # Create a Match object + match = Match( + field="invoice_number", + value="12345", + bbox=(100, 100, 200, 150), + page_no=0, + score=0.95, + matched_text="INV-12345", + context_keywords=["faktura", "nummer"], + ) + + # Verify matched_text exists (this is what autolabel.py should use) + assert hasattr(match, "matched_text") + assert match.matched_text == "INV-12345" + + # Verify matched_value does NOT exist + # This was the bug - autolabel.py was using matched_value instead of matched_text + assert not hasattr(match, "matched_value") + + def test_match_attributes_for_annotation_creation(self): + """Test that Match has all attributes needed for annotation creation.""" + from shared.matcher.models import Match + + match = Match( + field="amount", + value="1000.00", + bbox=(50, 200, 150, 230), + page_no=0, + score=0.88, + matched_text="1 000,00", + context_keywords=["att betala", "summa"], + ) + + # These are all the attributes used in autolabel._create_annotations_from_matches + assert hasattr(match, "bbox") + assert hasattr(match, "matched_text") # NOT matched_value + assert hasattr(match, "score") + + # Verify bbox format + assert len(match.bbox) == 4 # (x0, y0, x1, y1) diff --git a/tests/web/test_admin_auth.py b/tests/web/test_admin_auth.py index c2f6d92..72403a4 100644 --- a/tests/web/test_admin_auth.py +++ b/tests/web/test_admin_auth.py @@ -3,7 +3,7 @@ Tests for Admin Authentication. """ import pytest -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, patch from fastapi import HTTPException @@ -132,6 +132,47 @@ class TestTokenRepository: with patch.object(repo, "_now", return_value=datetime.utcnow()): assert repo.is_valid("test-token") is False + def test_is_valid_expired_token_timezone_aware(self): + """Test expired token with timezone-aware datetime. + + This verifies the fix for comparing timezone-aware and naive datetimes. + The auth API now creates tokens with timezone-aware expiration dates. 
+ """ + with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx: + mock_session = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_session + + # Create token with timezone-aware expiration (as auth API now does) + mock_token = AdminToken( + token="test-token", + name="Test", + is_active=True, + expires_at=datetime.now(timezone.utc) - timedelta(days=1), + ) + mock_session.get.return_value = mock_token + + repo = TokenRepository() + # _now() returns timezone-aware datetime, should compare correctly + assert repo.is_valid("test-token") is False + + def test_is_valid_not_expired_token_timezone_aware(self): + """Test non-expired token with timezone-aware datetime.""" + with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx: + mock_session = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_session + + # Create token with timezone-aware expiration in the future + mock_token = AdminToken( + token="test-token", + name="Test", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ) + mock_session.get.return_value = mock_token + + repo = TokenRepository() + assert repo.is_valid("test-token") is True + def test_is_valid_token_not_found(self): """Test token not found.""" with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx: diff --git a/tests/web/test_dashboard_api.py b/tests/web/test_dashboard_api.py new file mode 100644 index 0000000..7cea5de --- /dev/null +++ b/tests/web/test_dashboard_api.py @@ -0,0 +1,317 @@ +""" +Tests for Dashboard API Endpoints and Services. + +Tests are split into: +1. Unit tests for business logic (is_annotation_complete, etc.) +2. Service tests with mocked database +3. Integration tests via TestClient (requires DB) +""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + + +# Test data constants +TEST_DOC_UUID_1 = "550e8400-e29b-41d4-a716-446655440001" +TEST_MODEL_UUID = "660e8400-e29b-41d4-a716-446655440001" +TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440001" + + +class TestAnnotationCompletenessLogic: + """Unit tests for annotation completeness calculation logic. 
+ + These tests verify the core business logic: + - Complete: has (invoice_number OR ocr_number) AND (bankgiro OR plusgiro) + - Incomplete: labeled but missing required fields + """ + + def test_document_with_invoice_number_and_bankgiro_is_complete(self): + """Document with invoice_number + bankgiro should be complete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 0, "class_name": "invoice_number"}, + {"class_id": 4, "class_name": "bankgiro"}, + ] + + assert is_annotation_complete(annotations) is True + + def test_document_with_ocr_number_and_plusgiro_is_complete(self): + """Document with ocr_number + plusgiro should be complete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 3, "class_name": "ocr_number"}, + {"class_id": 5, "class_name": "plusgiro"}, + ] + + assert is_annotation_complete(annotations) is True + + def test_document_with_invoice_number_and_plusgiro_is_complete(self): + """Document with invoice_number + plusgiro should be complete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 0, "class_name": "invoice_number"}, + {"class_id": 5, "class_name": "plusgiro"}, + ] + + assert is_annotation_complete(annotations) is True + + def test_document_with_ocr_number_and_bankgiro_is_complete(self): + """Document with ocr_number + bankgiro should be complete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 3, "class_name": "ocr_number"}, + {"class_id": 4, "class_name": "bankgiro"}, + ] + + assert is_annotation_complete(annotations) is True + + def test_document_with_only_identifier_is_incomplete(self): + """Document with only identifier field should be incomplete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 0, "class_name": "invoice_number"}, + ] + + assert is_annotation_complete(annotations) is False + + def test_document_with_only_payment_is_incomplete(self): + """Document with only payment field should be incomplete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 4, "class_name": "bankgiro"}, + ] + + assert is_annotation_complete(annotations) is False + + def test_document_with_no_annotations_is_incomplete(self): + """Document with no annotations should be incomplete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + assert is_annotation_complete([]) is False + + def test_document_with_other_fields_only_is_incomplete(self): + """Document with only non-essential fields should be incomplete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 1, "class_name": "invoice_date"}, + {"class_id": 6, "class_name": "amount"}, + ] + + assert is_annotation_complete(annotations) is False + + def test_document_with_all_fields_is_complete(self): + """Document with all fields should be complete.""" + from inference.web.services.dashboard_service import is_annotation_complete + + annotations = [ + {"class_id": 0, "class_name": "invoice_number"}, + {"class_id": 1, "class_name": "invoice_date"}, + {"class_id": 4, "class_name": "bankgiro"}, + {"class_id": 6, "class_name": "amount"}, + ] + + assert is_annotation_complete(annotations) is True + + +class TestDashboardStatsService: + """Tests for 
DashboardStatsService with mocked database.""" + + @pytest.fixture + def mock_session(self): + """Create a mock database session.""" + session = MagicMock() + session.exec.return_value.one.return_value = 0 + return session + + def test_completeness_rate_calculation(self): + """Test completeness rate is calculated correctly.""" + # Direct calculation test + complete = 25 + incomplete = 8 + total_assessed = complete + incomplete + expected_rate = round(complete / total_assessed * 100, 2) + + assert expected_rate == pytest.approx(75.76, rel=0.01) + + def test_completeness_rate_zero_documents(self): + """Test completeness rate is 0 when no documents.""" + complete = 0 + incomplete = 0 + total_assessed = complete + incomplete + + completeness_rate = ( + round(complete / total_assessed * 100, 2) + if total_assessed > 0 + else 0.0 + ) + + assert completeness_rate == 0.0 + + + class TestDashboardActivityService: + """Tests for DashboardActivityService activity aggregation.""" + + def test_activity_types(self): + """Test the expected activity type identifiers are distinct.""" + expected_types = [ + "document_uploaded", + "annotation_modified", + "training_completed", + "training_failed", + "model_activated", + ] + + # Guard the canonical list: exactly five distinct activity identifiers. + assert len(expected_types) == len(set(expected_types)) == 5 + + + class TestDashboardSchemas: + """Tests for Dashboard API schemas.""" + + def test_dashboard_stats_response_schema(self): + """Test DashboardStatsResponse schema validation.""" + from inference.web.schemas.admin import DashboardStatsResponse + + response = DashboardStatsResponse( + total_documents=38, + annotation_complete=25, + annotation_incomplete=8, + pending=5, + completeness_rate=75.76, + ) + + assert response.total_documents == 38 + assert response.annotation_complete == 25 + assert response.annotation_incomplete == 8 + assert response.pending == 5 + assert response.completeness_rate == 75.76 + + def test_active_model_response_schema(self): + """Test ActiveModelResponse schema with null model.""" + from inference.web.schemas.admin import ActiveModelResponse + + response = ActiveModelResponse( + model=None, + running_training=None, + ) + + assert response.model is None + assert response.running_training is None + + def test_active_model_info_schema(self): + """Test ActiveModelInfo schema validation.""" + from inference.web.schemas.admin import ActiveModelInfo + + model = ActiveModelInfo( + version_id=TEST_MODEL_UUID, + version="1.2.0", + name="Invoice Model", + metrics_mAP=0.951, + metrics_precision=0.94, + metrics_recall=0.92, + document_count=500, + activated_at=datetime(2024, 1, 20, 15, 0, 0, tzinfo=timezone.utc), + ) + + assert model.version == "1.2.0" + assert model.name == "Invoice Model" + assert model.metrics_mAP == 0.951 + + def test_running_training_info_schema(self): + """Test RunningTrainingInfo schema validation.""" + from inference.web.schemas.admin import RunningTrainingInfo + + task = RunningTrainingInfo( + task_id=TEST_TASK_UUID, + name="Run-2024-02", + status="running", + started_at=datetime(2024, 1, 25, 10, 0, 0, tzinfo=timezone.utc), + progress=45, + ) + + assert task.name == "Run-2024-02" + assert task.status == "running" + assert task.progress == 45 + + def test_activity_item_schema(self): + """Test ActivityItem schema validation.""" + from inference.web.schemas.admin import ActivityItem + + activity = ActivityItem( + type="model_activated", + description="Activated model v1.2.0", + timestamp=datetime(2024, 1, 25, 12, 0, 0, tzinfo=timezone.utc), + metadata={"version_id": TEST_MODEL_UUID, "version":
"1.2.0"}, + ) + + assert activity.type == "model_activated" + assert activity.description == "Activated model v1.2.0" + assert activity.metadata["version"] == "1.2.0" + + def test_recent_activity_response_schema(self): + """Test RecentActivityResponse schema with empty activities.""" + from inference.web.schemas.admin import RecentActivityResponse + + response = RecentActivityResponse(activities=[]) + + assert response.activities == [] + + +class TestDashboardRouterCreation: + """Tests for dashboard router creation.""" + + def test_creates_router_with_expected_endpoints(self): + """Test router is created with expected endpoint paths.""" + from inference.web.api.v1.admin.dashboard import create_dashboard_router + + router = create_dashboard_router() + + paths = [route.path for route in router.routes] + + assert any("/stats" in p for p in paths) + assert any("/active-model" in p for p in paths) + assert any("/activity" in p for p in paths) + + def test_router_has_correct_prefix(self): + """Test router has /admin/dashboard prefix.""" + from inference.web.api.v1.admin.dashboard import create_dashboard_router + + router = create_dashboard_router() + + assert router.prefix == "/admin/dashboard" + + def test_router_has_dashboard_tag(self): + """Test router uses Dashboard tag.""" + from inference.web.api.v1.admin.dashboard import create_dashboard_router + + router = create_dashboard_router() + + assert "Dashboard" in router.tags + + +class TestFieldClassIds: + """Tests for field class ID constants.""" + + def test_identifier_class_ids(self): + """Test identifier field class IDs.""" + from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS + + # invoice_number = 0, ocr_number = 3 + assert 0 in IDENTIFIER_CLASS_IDS + assert 3 in IDENTIFIER_CLASS_IDS + + def test_payment_class_ids(self): + """Test payment field class IDs.""" + from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS + + # bankgiro = 4, plusgiro = 5 + assert 4 in PAYMENT_CLASS_IDS + assert 5 in PAYMENT_CLASS_IDS diff --git a/update_test_imports.py b/update_test_imports.py deleted file mode 100644 index 7c8d7ed..0000000 --- a/update_test_imports.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3 -"""Update test imports to use new structure.""" - -import re -from pathlib import Path - -# Import mapping: old -> new -IMPORT_MAPPINGS = { - # Admin routes - r'from src\.web\.admin_routes import': 'from src.web.api.v1.admin.documents import', - r'from src\.web\.admin_annotation_routes import': 'from src.web.api.v1.admin.annotations import', - r'from src\.web\.admin_training_routes import': 'from src.web.api.v1.admin.training import', - - # Auth and core - r'from src\.web\.admin_auth import': 'from src.web.core.auth import', - r'from src\.web\.admin_autolabel import': 'from src.web.services.autolabel import', - r'from src\.web\.admin_scheduler import': 'from src.web.core.scheduler import', - - # Schemas - r'from src\.web\.admin_schemas import': 'from src.web.schemas.admin import', - r'from src\.web\.schemas import': 'from src.web.schemas.inference import', - - # Services - r'from src\.web\.services import': 'from src.web.services.inference import', - r'from src\.web\.async_service import': 'from src.web.services.async_processing import', - r'from src\.web\.batch_upload_service import': 'from src.web.services.batch_upload import', - - # Workers - r'from src\.web\.async_queue import': 'from src.web.workers.async_queue import', - r'from src\.web\.batch_queue import': 'from src.web.workers.batch_queue 
import', - - # Routes - r'from src\.web\.routes import': 'from src.web.api.v1.routes import', - r'from src\.web\.async_routes import': 'from src.web.api.v1.async_api.routes import', - r'from src\.web\.batch_upload_routes import': 'from src.web.api.v1.batch.routes import', -} - -def update_file(file_path: Path) -> bool: - """Update imports in a single file.""" - content = file_path.read_text(encoding='utf-8') - original_content = content - - for old_pattern, new_import in IMPORT_MAPPINGS.items(): - content = re.sub(old_pattern, new_import, content) - - if content != original_content: - file_path.write_text(content, encoding='utf-8') - return True - return False - -def main(): - """Update all test files.""" - test_dir = Path('tests/web') - updated_files = [] - - for test_file in test_dir.glob('test_*.py'): - if update_file(test_file): - updated_files.append(test_file.name) - - if updated_files: - print(f"✓ Updated {len(updated_files)} test files:") - for filename in sorted(updated_files): - print(f" - {filename}") - else: - print("No files needed updating") - -if __name__ == '__main__': - main()