Yaojia Wang
2026-01-27 00:47:10 +01:00
parent e83a0cae36
commit 58bf75db68
141 changed files with 24814 additions and 3884 deletions


@@ -7,7 +7,8 @@
       "Edit(*)",
       "Glob(*)",
       "Grep(*)",
-      "Task(*)"
+      "Task(*)",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
     ]
   }
 }


@@ -81,7 +81,13 @@
       "Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
       "Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
       "Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
-      "Bash(wsl bash -c:*)"
+      "Bash(wsl bash -c:*)",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
+      "Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")"
     ],
     "deny": [],
     "ask": [],

BIN
.coverage Normal file

Binary file not shown.


@@ -76,6 +76,38 @@
| 8 | payment_line | Payment line (machine-readable format) |
| 9 | customer_number | Customer number |
## DPI Configuration
**Important**: Every component of the system uses **150 DPI**, ensuring consistency between training and inference.
The DPI (dots per inch) setting must be identical at training and inference time; a mismatch causes:
- detection-box size mismatches
- a significant mAP drop (potentially from 93.5% down to 60-70%)
- missed or spurious field detections
### Configuration locations
| Component | Config file | Setting |
|------|---------|--------|
| **Global constant** | `src/config.py` | `DEFAULT_DPI = 150` |
| **Web inference** | `src/web/config.py` | `ModelConfig.dpi` (imported from `src.config`) |
| **CLI inference** | `src/cli/infer.py` | `--dpi` default = `DEFAULT_DPI` |
| **Auto-labeling** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` |
| **PDF-to-image** | `src/web/api/v1/admin/documents.py` | uses `DEFAULT_DPI` |
### Usage examples
```bash
# Training (uses the default 150 DPI)
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
# Inference (defaults to 150 DPI, matching training)
python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf
# Manually specifying DPI (only when pairing with a model trained at a non-default DPI)
python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150
```
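Why the mismatch is so damaging: PDF geometry is defined in points (72 per inch), so rendered pixel dimensions, and therefore every detection box, scale linearly with DPI. A minimal sketch of the relationship, assuming PyMuPDF is the renderer (as `pdf/renderer.py` suggests); the function name here is illustrative, not the project's actual API:

```python
# Illustrative sketch only; not the project's renderer API.
import fitz  # PyMuPDF

DEFAULT_DPI = 150  # mirrors src/config.py

def render_pdf_page(pdf_path: str, page_num: int, dpi: int = DEFAULT_DPI):
    """Render one page; a 300 DPI render doubles every box size vs. 150."""
    with fitz.open(pdf_path) as doc:
        zoom = dpi / 72  # the PDF base unit is 72 points per inch
        return doc[page_num].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
```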
## Installation
```bash
@@ -490,7 +522,7 @@ Options:
 --input, -i      Input PDF/image
 --output, -o     Output JSON path
 --confidence     Confidence threshold (default: 0.5)
---dpi            Render DPI (default: 300)
+--dpi            Render DPI (default: 150; must match the training DPI)
 --gpu            Use GPU
```

create_shims.sh Normal file

@@ -0,0 +1,96 @@
#!/bin/bash
# Create backward compatibility shims for all migrated files
# admin_auth.py -> core/auth.py
cat > src/web/admin_auth.py << 'EOF'
"""DEPRECATED: Import from src.web.core.auth instead"""
from src.web.core.auth import * # noqa: F401, F403
EOF
# admin_autolabel.py -> services/autolabel.py
cat > src/web/admin_autolabel.py << 'EOF'
"""DEPRECATED: Import from src.web.services.autolabel instead"""
from src.web.services.autolabel import * # noqa: F401, F403
EOF
# admin_scheduler.py -> core/scheduler.py
cat > src/web/admin_scheduler.py << 'EOF'
"""DEPRECATED: Import from src.web.core.scheduler instead"""
from src.web.core.scheduler import * # noqa: F401, F403
EOF
# admin_schemas.py -> schemas/admin.py
cat > src/web/admin_schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.admin instead"""
from src.web.schemas.admin import * # noqa: F401, F403
EOF
# schemas.py -> schemas/inference.py + schemas/common.py
cat > src/web/schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
from src.web.schemas.inference import * # noqa: F401, F403
from src.web.schemas.common import * # noqa: F401, F403
EOF
# services.py -> services/inference.py
cat > src/web/services.py << 'EOF'
"""DEPRECATED: Import from src.web.services.inference instead"""
from src.web.services.inference import * # noqa: F401, F403
EOF
# async_queue.py -> workers/async_queue.py
cat > src/web/async_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
from src.web.workers.async_queue import * # noqa: F401, F403
EOF
# async_service.py -> services/async_processing.py
cat > src/web/async_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.async_processing instead"""
from src.web.services.async_processing import * # noqa: F401, F403
EOF
# batch_queue.py -> workers/batch_queue.py
cat > src/web/batch_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
from src.web.workers.batch_queue import * # noqa: F401, F403
EOF
# batch_upload_service.py -> services/batch_upload.py
cat > src/web/batch_upload_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
from src.web.services.batch_upload import * # noqa: F401, F403
EOF
# batch_upload_routes.py -> api/v1/batch/routes.py
cat > src/web/batch_upload_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
from src.web.api.v1.batch.routes import * # noqa: F401, F403
EOF
# admin_routes.py -> api/v1/admin/documents.py
cat > src/web/admin_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
from src.web.api.v1.admin.documents import * # noqa: F401, F403
EOF
# admin_annotation_routes.py -> api/v1/admin/annotations.py
cat > src/web/admin_annotation_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
from src.web.api.v1.admin.annotations import * # noqa: F401, F403
EOF
# admin_training_routes.py -> api/v1/admin/training.py
cat > src/web/admin_training_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
from src.web.api.v1.admin.training import * # noqa: F401, F403
EOF
# routes.py -> api/v1/routes.py
cat > src/web/routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
from src.web.api.v1.routes import * # noqa: F401, F403
EOF
echo "✓ Created backward compatibility shims for all migrated files"


@@ -1,405 +0,0 @@
# Invoice Master POC v2 - Code Review Report
**Review date**: 2026-01-22
**Codebase size**: 67 Python source files, ~22,434 lines of code
**Test coverage**: ~40-50%
---
## Executive Summary
### Overall assessment: **Good (B+)**
**Strengths**
- ✅ Clean modular architecture with good separation of concerns
- ✅ Appropriate use of dataclasses and type hints
- ✅ Comprehensive normalization logic for Swedish invoices
- ✅ Spatial-index optimization (O(1) token lookup)
- ✅ Solid fallback mechanism (OCR fallback when YOLO fails)
- ✅ Well-designed web API and UI
**Main issues**
- ❌ Duplicated payment-line parsing (3+ implementations)
- ❌ Long functions (`_normalize_customer_number` is 127 lines)
- ❌ Configuration security issue (plaintext database password)
- ❌ Inconsistent exception handling (generic Exception everywhere)
- ❌ Missing integration tests
- ❌ Magic numbers scattered throughout (0.5, 0.95, 300, etc.)
---
## 1. Architecture Analysis
### 1.1 Module structure
```
src/
├── inference/          # Core inference pipeline
│   ├── pipeline.py (517 lines) ⚠️
│   ├── field_extractor.py (1,347 lines) 🔴 too long
│   └── yolo_detector.py
├── web/                # FastAPI web service
│   ├── app.py (765 lines) ⚠️ inline HTML
│   ├── routes.py (184 lines)
│   └── services.py (286 lines)
├── ocr/                # OCR extraction
│   ├── paddle_ocr.py
│   └── machine_code_parser.py (919 lines) 🔴 too long
├── matcher/            # Field matching
│   └── field_matcher.py (875 lines) ⚠️
├── utils/              # Shared utilities
│   ├── validators.py
│   ├── text_cleaner.py
│   ├── fuzzy_matcher.py
│   ├── ocr_corrections.py
│   └── format_variants.py (610 lines)
├── processing/         # Batch processing
├── data/               # Data management
└── cli/                # Command-line tools
```
### 1.2 Inference flow
```
PDF/image input
  ↓ render to image (pdf/renderer.py)
  ↓ YOLO detection (yolo_detector.py) - detects field regions
  ↓ field extraction (field_extractor.py)
      ├→ OCR text extraction (ocr/paddle_ocr.py)
      ├→ normalization & validation
      └→ confidence scoring
  ↓ cross-validation (pipeline.py)
      ├→ parse the payment_line format
      ├→ extract OCR/Amount/Account from the payment_line
      └→ validate against detected fields (payment_line values take precedence)
  ↓ fallback OCR (when key fields are missing)
      ├→ full-page OCR
      └→ regex extraction
  ↓ InferenceResult output
```
---
## 2. Code Quality Issues
### 2.1 Long functions (>50 lines) 🔴
| Function | File | Lines | Complexity | Problem |
|------|------|------|--------|------|
| `_normalize_customer_number()` | field_extractor.py | **127** | Very high | 4 levels of pattern matching, 7+ regexes, complex scoring |
| `_cross_validate_payment_line()` | pipeline.py | **127** | Very high | Core validation logic, 8+ conditional branches |
| `_normalize_bankgiro()` | field_extractor.py | 62 | High | Luhn validation + several fallbacks |
| `_normalize_plusgiro()` | field_extractor.py | 63 | High | Similar to bankgiro |
| `_normalize_payment_line()` | field_extractor.py | 74 | High | 4 regex patterns |
| `_normalize_amount()` | field_extractor.py | 78 | High | Multi-strategy fallbacks |
**Example** - `_normalize_customer_number()` (lines 776-902):
```python
def _normalize_customer_number(self, text: str):
    # A 127-line function containing:
    # - 4 nested if/for loops
    # - 7 different regex patterns
    # - 5 scoring mechanisms
    # - handling for both labeled and unlabeled formats
```
**Suggestion**: split it into:
- `_find_customer_code_patterns()`
- `_find_labeled_customer_code()`
- `_score_customer_candidates()`
### 2.2 Code duplication 🔴
**Payment-line parsing (3+ duplicate implementations)**:
1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
2. `MachineCodeParser.parse()` (machine_code_parser.py, 919 lines)
3. `_normalize_payment_line()` (field_extractor.py:632-705)
All three implement similar regex patterns:
```
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
```
**Bankgiro/Plusgiro validation (duplicated)**:
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
- `field_matcher.py`: similar matching logic
**Suggestion**: create unified modules (a sketch follows):
```python
# src/common/payment_line_parser.py
class PaymentLineParser:
    def parse(text: str) -> PaymentLineResult
# src/common/giro_validator.py
class GiroValidator:
    def validate_and_format(value: str, giro_type: str) -> str
```
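A minimal sketch of what the shared `GiroValidator` could look like, assuming the standard Luhn (mod-10) checksum that Bankgiro/Plusgiro numbers use; the class shape and formatting rules here are illustrative, not the existing API:

```python
# Hypothetical consolidation target; not the current implementation.
def luhn_checksum_ok(digits: str) -> bool:
    """Standard mod-10 check: double every second digit from the right."""
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

class GiroValidator:
    def validate_and_format(self, value: str, giro_type: str) -> str:
        digits = "".join(c for c in value if c.isdigit())
        if not luhn_checksum_ok(digits):
            raise ValueError(f"Invalid {giro_type} checksum: {value}")
        if giro_type == "bankgiro":            # 123-4567 or 1234-5678
            return f"{digits[:-4]}-{digits[-4:]}"
        return f"{digits[:-1]}-{digits[-1]}"   # plusgiro: 123456-7
```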
### 2.3 Inconsistent error handling ⚠️
**Generic exception catching (31 occurrences)**:
```python
except Exception as e:  # 31 occurrences across the codebase
    result.errors.append(str(e))
```
**Problems**:
- No specific error types are caught
- Generic error messages lose context
- Lines 142-147 (routes.py) catch everything and return a 500
**Current code** (routes.py:142-147):
```python
try:
    service_result = inference_service.process_pdf(...)
except Exception as e:  # too broad
    logger.error(f"Error processing document: {e}")
    raise HTTPException(status_code=500, detail=str(e))
```
**Suggested improvement**:
```python
except FileNotFoundError:
    raise HTTPException(status_code=400, detail="PDF file not found")
except PyMuPDFError:
    raise HTTPException(status_code=400, detail="Invalid PDF format")
except OCRError:
    raise HTTPException(status_code=503, detail="OCR service unavailable")
```
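Note that `OCRError` above is not a standard library or FastAPI class; for these handlers to work, the project would need to define its own exception hierarchy. A minimal sketch, with hypothetical names:

```python
# Hypothetical custom exception hierarchy; names are illustrative.
class InvoiceProcessingError(Exception):
    """Base class for pipeline errors, so callers can catch one family."""

class PDFRenderError(InvoiceProcessingError):
    """Raised when a PDF cannot be opened or rendered."""

class OCRError(InvoiceProcessingError):
    """Raised when the OCR backend fails or is unavailable."""

class FieldValidationError(InvoiceProcessingError):
    """Raised when an extracted field fails checksum/format validation."""
```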
### 2.4 Configuration security issue 🔴
**config.py lines 24-30** - plaintext credentials:
```python
DATABASE = {
    'host': '192.168.68.31',   # hardcoded IP
    'user': 'docmaster',       # hardcoded username
    'password': 'nY6LYK5d',    # 🔴 plaintext password!
    'database': 'invoice_master'
}
```
**Suggestion**:
```python
DATABASE = {
    'host': os.getenv('DB_HOST', 'localhost'),
    'user': os.getenv('DB_USER', 'docmaster'),
    'password': os.getenv('DB_PASSWORD'),  # read from the environment
    'database': os.getenv('DB_NAME', 'invoice_master')
}
```
### 2.5 Magic numbers ⚠️
| Value | Location | Purpose | Problem |
|---|------|------|------|
| 0.5 | multiple places | Confidence threshold | Not configurable per field |
| 0.95 | pipeline.py | payment_line confidence | Undocumented |
| 300 | multiple places | DPI | Hardcoded |
| 0.1 | field_extractor.py | BBox padding | Should be configuration |
| 72 | multiple places | PDF base DPI | Magic number in formulas |
| 50 | field_extractor.py | Customer-number scoring bonus | Undocumented |
**Suggestion**: extract them into configuration:
```python
INFERENCE_CONFIG = {
    'confidence_threshold': 0.5,
    'payment_line_confidence': 0.95,
    'dpi': 300,
    'bbox_padding': 0.1,
}
```
### 2.6 Inconsistent naming ⚠️
**Field names differ by layer**:
- YOLO class names: `invoice_number`, `ocr_number`, `supplier_org_number`
- Field names: `InvoiceNumber`, `OCR`, `supplier_org_number`
- CSV column names: possibly different again
- Database column names: yet another variant
The mapping is maintained in several places (a consolidated sketch follows):
- `yolo_detector.py` (lines 90-100): `CLASS_TO_FIELD`
- several other locations
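A minimal sketch of the consolidation: one module that every layer imports as the single source of truth. The module path and mapping values here are illustrative:

```python
# Hypothetical src/common/field_names.py; values are illustrative.
CLASS_TO_FIELD = {
    "invoice_number": "InvoiceNumber",
    "ocr_number": "OCR",
    "supplier_org_number": "supplier_org_number",
}

def to_field_name(yolo_class: str) -> str:
    """Translate a YOLO class name to the canonical field name."""
    try:
        return CLASS_TO_FIELD[yolo_class]
    except KeyError:
        raise ValueError(f"unknown YOLO class: {yolo_class!r}") from None
```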
---
## 3. Test Analysis
### 3.1 Test coverage
**Test files**: 13
- ✅ Well covered: field_matcher, normalizer, payment_line_parser
- ⚠️ Moderately covered: field_extractor, pipeline
- ❌ Poorly covered: web layer, CLI, batch processing
**Estimated coverage**: 40-50%
### 3.2 Missing test cases 🔴
**Critical gaps**:
1. Cross-validation logic - the most complex part, barely tested
2. payment_line parsing variants - multiple implementations, unclear edge cases
3. OCR error correction - complex logic with several strategies
4. Web API endpoints - no request/response tests
5. Batch processing - multi-worker coordination untested
6. Fallback OCR mechanism - when YOLO detection fails
---
## 4. Architecture Risks
### 🔴 Critical
1. **Configuration security** - plaintext DB credentials in config.py (lines 24-30)
2. **Error recovery** - broad exception handling masks real failures
3. **Testability** - hardcoded dependencies block unit testing
### 🟡 High
1. **Maintainability** - duplicated payment-line parsing
2. **Scalability** - no async processing for long-running inference
3. **Extensibility** - adding new field types would be difficult
### 🟢 Medium
1. **Performance** - lazy loading helps, but ORM queries are unoptimized
2. **Documentation** - mostly adequate, but could be better
---
## 5. Priority Matrix
| Priority | Action | Effort | Impact |
|--------|------|--------|------|
| 🔴 Critical | Fix configuration security (env vars) | 1 hour | High |
| 🔴 Critical | Add integration tests | 2-3 days | High |
| 🔴 Critical | Document the error-handling strategy | 4 hours | Medium |
| 🟡 High | Unify payment_line parsing | 1-2 days | High |
| 🟡 High | Extract normalization into submodules | 2-3 days | Medium |
| 🟡 High | Add dependency injection | 2-3 days | Medium |
| 🟡 High | Split long functions | 2-3 days | Low |
| 🟢 Medium | Raise test coverage to 70%+ | 3-5 days | High |
| 🟢 Medium | Extract magic numbers | 4 hours | Low |
| 🟢 Medium | Standardize naming conventions | 1-2 days | Medium |
---
## 6. Per-File Recommendations
### High priority (code quality)
| File | Problem | Suggestion |
|------|------|------|
| `field_extractor.py` | 1,347 lines; 6 long normalization methods | Split into a `normalizers/` submodule |
| `pipeline.py` | 127-line `_cross_validate_payment_line()` | Extract into a dedicated `CrossValidator` class |
| `field_matcher.py` | 875 lines; complex matching logic | Split into a `matching/` submodule |
| `config.py` | Hardcoded credentials (line 29) | Use environment variables |
| `machine_code_parser.py` | 919 lines; payment_line parsing | Merge with the pipeline's parsing |
### Medium priority (refactoring)
| File | Problem | Suggestion |
|------|------|------|
| `app.py` | 765 lines; HTML inlined in Python | Extract into a `templates/` directory |
| `autolabel.py` | 753 lines; batch-processing logic | Extract worker functions into a module |
| `format_variants.py` | 610 lines; variant generation | Consider the strategy pattern |
---
## 7. Recommended Actions
### Phase 1: Critical fixes (1 week)
1. **Configuration security** (1 hour)
   - Remove the plaintext password from config.py
   - Add environment-variable support
   - Update the README with configuration instructions
2. **Standardize error handling** (1 day)
   - Define custom exception classes
   - Replace generic Exception catches
   - Add error-code constants
3. **Add critical integration tests** (2 days)
   - End-to-end inference tests
   - payment_line cross-validation tests
   - API endpoint tests
### Phase 2: Refactoring (2-3 weeks)
4. **Unify payment_line parsing** (2 days)
   - Create `src/common/payment_line_parser.py`
   - Merge the 3 duplicate implementations
   - Migrate all callers
5. **Split field_extractor.py** (3 days)
   - Create a `src/inference/normalizers/` submodule
   - One file per field type
   - Extract shared validation logic
6. **Split long functions** (2 days)
   - `_normalize_customer_number()` → 3 functions
   - `_cross_validate_payment_line()` → CrossValidator class
### Phase 3: Improvements (1-2 weeks)
7. **Raise test coverage** (5 days)
   - Target: 70%+ coverage
   - Focus on validation logic
   - Add edge-case tests
8. **Improve configuration management** (1 day)
   - Extract all magic numbers
   - Create a configuration file (YAML)
   - Add configuration validation
9. **Improve documentation** (2 days)
   - Add architecture diagrams
   - Document all private methods
   - Create a contribution guide
---
## Appendix A: Metrics
### Code complexity
| Category | Count | Avg. lines |
|------|------|----------|
| Source files | 67 | 334 |
| Long files (>500 lines) | 12 | 875 |
| Long functions (>50 lines) | 23 | 89 |
| Test files | 13 | 298 |
### Dependencies
| Type | Count |
|------|------|
| External dependencies | ~25 |
| Internal modules | 10 |
| Circular dependencies | 0 ✅ |
### Code style
| Metric | Coverage |
|------|--------|
| Type hints | 80% |
| Docstrings (public) | 80% |
| Docstrings (private) | 40% |
| Test coverage | 45% |
---
**Generated**: 2026-01-22
**Reviewer**: Claude Code
**Version**: v2.0


@@ -1,96 +0,0 @@
# Field Extractor Analysis Report
## Overview
field_extractor.py (1,183 lines) was initially identified as a refactoring candidate, and we attempted to rework it on top of the `src/normalize` module; after analysis and testing, however, the conclusion is that it **should not be refactored**.
## The Refactoring Attempt
### Initial plan
Delete the duplicated normalize methods in field_extractor.py and route everything through the unified `src/normalize/normalize_field()` interface.
### Steps taken
1. ✅ Backed up the original file (`field_extractor_old.py`)
2. ✅ Changed `_normalize_and_validate` to use the unified normalizer
3. ✅ Deleted the duplicated normalize methods (~400 lines)
4. ❌ Ran the tests - **28 failures**
5. ✅ Added wrapper methods delegating to the normalizer
6. ❌ Ran the tests again - **12 failures**
7. ✅ Restored the original file
8. ✅ Tests pass - **all 45 tests green**
## Key Findings
### The two modules serve different purposes
| Module | Purpose | Input | Output | Example |
|------|------|------|------|------|
| **src/normalize/** | **Variant generation** for matching | An already-extracted field value | A list of matching variants | `"INV-12345"` → `["INV-12345", "12345"]` |
| **field_extractor** | **Value extraction** from OCR text | Raw OCR text containing the field | A single extracted field value | `"Fakturanummer: A3861"` → `"A3861"` |
### Why they cannot be unified
1. **src/normalize/** is designed to:
   - receive an already-extracted field value
   - generate multiple normalized variants for fuzzy matching
   - e.g. BankgiroNormalizer:
   ```python
   normalize("782-1713") → ["7821713", "782-1713"]  # generates variants
   ```
2. **field_extractor**'s normalize methods:
   - receive raw OCR text containing the field (possibly with labels and other text)
   - **extract** the field value that matches a specific pattern
   - e.g. `_normalize_bankgiro`:
   ```python
   _normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None)  # extracts from text
   ```
3. **The key distinction**:
   - Normalizer: a variant generator (for matching)
   - Field Extractor: a pattern extractor (for parsing)
### Example test failures
Failures after replacing the field-extractor methods with the normalizer:
```python
# InvoiceNumber test
Input:    "Fakturanummer: A3861"
Expected: "A3861"
Actual:   "Fakturanummer: A3861"  # only cleaned, nothing extracted

# Bankgiro test
Input:    "Bankgiro: 782-1713"
Expected: "782-1713"
Actual:   "7821713"  # returned the dash-less variant, not the extracted formatted value
```
## Conclusion
**field_extractor.py should not be refactored onto the src/normalize module**, because:
1. ❌ **Different responsibilities**: extraction vs. variant generation
2. ❌ **Different inputs**: raw OCR text with labels vs. already-extracted field values
3. ❌ **Different outputs**: a single extracted value vs. multiple matching variants
4. ✅ **The existing code works well**: all 45 tests pass
5. ✅ **The extraction logic has value**: it encodes intricate pattern-matching rules (e.g. telling Bankgiro and Plusgiro formats apart)
## Recommendations
1. **Keep field_extractor.py as-is**: do not refactor it
2. **Document the difference between the two modules**: make sure the team understands their respective purposes
3. **Focus on other optimization targets**: machine_code_parser.py (919 lines)
## Lessons Learned
Before refactoring, one should:
1. understand a module's **actual purpose**, not just surface-level code similarity
2. run the full test suite to validate assumptions
3. assess whether duplication is real, or just superficially similar code with different purposes
---
**Status**: ✅ Analysis complete; decided not to refactor
**Tests**: ✅ 45/45 passing
**File**: kept unchanged at 1,183 lines


@@ -1,238 +0,0 @@
# Machine Code Parser Analysis Report
## File Overview
- **File**: `src/ocr/machine_code_parser.py`
- **Total lines**: 919
- **Code lines**: 607 (66%)
- **Methods**: 14
- **Regex uses**: 47
## Code Structure
### Class structure
```
MachineCodeResult (dataclass)
├── to_dict()
└── get_region_bbox()
MachineCodeParser (main parser)
├── __init__()
├── parse() - main entry point
├── _find_tokens_with_values()
├── _find_machine_code_line_tokens()
├── _parse_standard_payment_line_with_tokens()
├── _parse_standard_payment_line() - 142 lines ⚠️
├── _extract_ocr() - 50 lines
├── _extract_bankgiro() - 58 lines
├── _extract_plusgiro() - 30 lines
├── _extract_amount() - 68 lines
├── _calculate_confidence()
└── cross_validate()
```
## Issues Found
### 1. ⚠️ `_parse_standard_payment_line` is too long (142 lines)
**Location**: lines 442-582
**Problems**:
- contains the nested functions `normalize_account_spaces` and `format_account`
- multiple regex-matching branches
- complex logic that is hard to test and maintain
**Suggestion**:
Split it into standalone methods:
- `_normalize_account_spaces(line)`
- `_format_account(account_digits, context)`
- `_match_primary_pattern(line)`
- `_match_fallback_patterns(line)`
### 2. 🔁 The four `_extract_*` methods repeat the same pattern
All extract methods follow the same shape:
```python
def _extract_XXX(self, tokens):
    candidates = []
    for token in tokens:
        text = token.text.strip()
        matches = self.XXX_PATTERN.findall(text)
        for match in matches:
            # validation logic
            # context detection
            candidates.append((normalized, context_score, token))
    if not candidates:
        return None
    candidates.sort(key=lambda x: (x[1], 1), reverse=True)
    return candidates[0][0]
```
**The duplicated logic**:
- token iteration
- pattern matching
- candidate collection
- context scoring
- sorting and picking the best match
**Suggestion**:
A base extractor class or a generic helper could remove this duplication; a sketch follows.
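A minimal sketch of such a generic helper, assuming tokens expose a `.text` attribute as in the snippet above; the callback-based design and parameter names are illustrative, not the existing API:

```python
# Hypothetical generic extractor; mirrors the repeated shape shown above.
import re
from typing import Callable, Iterable, Optional

def generic_extract(
    tokens: Iterable,                         # objects exposing a .text attribute
    pattern: re.Pattern,
    normalize: Callable[[str], Optional[str]],
    score: Callable[[object], float],
) -> Optional[str]:
    """Collect pattern matches, validate via normalize(), return the best-scored one."""
    candidates = []
    for token in tokens:
        for match in pattern.findall(token.text.strip()):
            normalized = normalize(match)     # returns None for invalid candidates
            if normalized is not None:
                candidates.append((score(token), normalized))
    if not candidates:
        return None
    candidates.sort(reverse=True)             # highest score first
    return candidates[0][1]
```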
### 3. ✅ Duplicated context detection
The context-detection code is repeated in several places:
```python
# in _extract_bankgiro
context_text = ' '.join(t.text.lower() for t in tokens)
is_bankgiro_context = (
    'bankgiro' in context_text or
    'bg:' in context_text or
    'bg ' in context_text
)

# in _extract_plusgiro
context_text = ' '.join(t.text.lower() for t in tokens)
is_plusgiro_context = (
    'plusgiro' in context_text or
    'postgiro' in context_text or
    'pg:' in context_text or
    'pg ' in context_text
)

# in _parse_standard_payment_line
context = (context_line or raw_line).lower()
is_plusgiro_context = (
    ('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
    and 'bankgiro' not in context
)
```
**Suggestion**:
Extract a standalone method:
- `_detect_account_context(tokens) -> dict[str, bool]`
## Refactoring Options
### Option A: Light refactoring (recommended) ✅
**Goal**: extract the duplicated context-detection logic without changing the overall structure
**Steps**:
1. Extract a `_detect_account_context(tokens)` method
2. Extract `_normalize_account_spaces(line)` as a standalone method
3. Extract `_format_account(digits, context)` as a standalone method
**Impact**:
- removes ~50-80 lines of duplication
- improves testability
- low risk, easy to verify
**Expected result**: 919 lines → ~850 lines (↓7%)
### Option B: Medium refactoring
**Goal**: build a generic field-extraction framework
**Steps**:
1. Create `_generic_extract(pattern, normalizer, context_checker)`
2. Refactor all `_extract_*` methods onto the generic framework
3. Split `_parse_standard_payment_line` into several small methods
**Impact**:
- removes ~150-200 lines of code
- significantly improves maintainability
- medium risk; needs thorough testing
**Expected result**: 919 lines → ~720 lines (↓22%)
### Option C: Deep refactoring (not recommended)
**Goal**: redesign everything around the strategy pattern
**Risks**:
- high risk; likely to introduce bugs
- requires extensive testing
- could break existing integrations
## Recommendation
### ✅ Adopt Option A (light refactoring)
**Rationale**:
1. **The code already works well**: no obvious bugs or performance problems
2. **Low risk**: it only extracts duplicated logic, without touching the core algorithms
3. **High value for the effort**: small changes yield a clear quality gain
4. **Easy to verify**: the existing tests should cover it
### Refactoring steps
```python
# 1. Extract context detection
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
    """Detect account-type keywords in the surrounding context."""
    context_text = ' '.join(t.text.lower() for t in tokens)
    return {
        'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
        'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
    }

# 2. Extract space normalization
def _normalize_account_spaces(self, line: str) -> str:
    """Remove spaces inside account numbers."""
    # (existing code from lines 460-481)

# 3. Extract account formatting
def _format_account(
    self,
    account_digits: str,
    is_plusgiro_context: bool
) -> tuple[str, str]:
    """Format the account and determine its type."""
    # (existing code from lines 485-523)
```
## Comparison: field_extractor vs. machine_code_parser
| Aspect | field_extractor | machine_code_parser |
|------|-----------------|---------------------|
| Purpose | Value extraction | Machine-code parsing |
| Duplication | ~400 lines (normalize methods) | ~80 lines (context detection) |
| Refactoring value | ❌ Different purposes; should not be unified | ✅ Shared logic can be extracted |
| Risk | High (would break functionality) | Low (pure code organization) |
## Decision
### ✅ Refactor machine_code_parser.py
**How this differs from field_extractor**:
- field_extractor: the duplicated methods serve **different purposes** (extraction vs. variant generation)
- machine_code_parser: the duplicated code serves the **same purpose** (context detection everywhere)
**Expected benefits**:
- removes ~70 lines of duplication
- improves testability (context detection can be tested in isolation)
- clearer code organization
- **low risk**, easy to verify
## Next Steps
1. ✅ Back up the original file
2. ✅ Extract the `_detect_account_context` method
3. ✅ Extract the `_normalize_account_spaces` method
4. ✅ Extract the `_format_account` method
5. ✅ Update all call sites
6. ✅ Run the tests to verify
7. ✅ Check code coverage
---
**Status**: 📋 Analysis complete; light refactoring recommended
**Risk assessment**: 🟢 Low risk
**Expected benefit**: 919 lines → ~850 lines (↓7%)


@@ -1,519 +0,0 @@
# Performance Optimization Guide
This document provides performance optimization recommendations for the Invoice Field Extraction system.
## Table of Contents
1. [Batch Processing Optimization](#batch-processing-optimization)
2. [Database Query Optimization](#database-query-optimization)
3. [Caching Strategies](#caching-strategies)
4. [Memory Management](#memory-management)
5. [Profiling and Monitoring](#profiling-and-monitoring)
---
## Batch Processing Optimization
### Current State
The system processes invoices one at a time. For large batches, this can be inefficient.
### Recommendations
#### 1. Database Batch Operations
**Current**: Individual inserts for each document
```python
# Inefficient
for doc in documents:
    db.insert_document(doc)  # Individual DB call
```
**Optimized**: Use `execute_values` for batch inserts
```python
# Efficient - already implemented in db.py line 519
from psycopg2.extras import execute_values
execute_values(cursor, """
INSERT INTO documents (...)
VALUES %s
""", document_values)
```
**Impact**: 10-50x faster for batches of 100+ documents
#### 2. PDF Processing Batching
**Recommendation**: Process PDFs in parallel using multiprocessing
```python
from multiprocessing import Pool
def process_batch(pdf_paths, batch_size=10):
    """Process PDFs in parallel batches."""
    with Pool(processes=batch_size) as pool:
        results = pool.map(pipeline.process_pdf, pdf_paths)
    return results
```
**Considerations**:
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
- CPU-intensive tasks can use separate process pool (`src/processing/cpu_pool.py`)
- Current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern
**Status**: ✅ Already implemented in `src/processing/` modules
#### 3. Image Caching for Multi-Page PDFs
**Current**: Each page rendered independently
```python
# Current pattern in field_extractor.py
for page_num in range(total_pages):
    image = render_pdf_page(pdf_path, page_num, dpi=300)
```
**Optimized**: Pre-render all pages if processing multiple fields per page
```python
# Batch render
images = {
    page_num: render_pdf_page(pdf_path, page_num, dpi=300)
    for page_num in page_numbers_needed
}
# Reuse images
for detection in detections:
    image = images[detection.page_no]
    extract_field(detection, image)
```
**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices
---
## Database Query Optimization
### Current Performance
- **Parameterized queries**: ✅ Implemented (Phase 1)
- **Connection pooling**: ❌ Not implemented
- **Query batching**: ✅ Partially implemented
- **Index optimization**: ⚠️ Needs verification
### Recommendations
#### 1. Connection Pooling
**Current**: New connection for each operation
```python
def connect(self):
    """Create new database connection."""
    return psycopg2.connect(**self.config)
```
**Optimized**: Use connection pooling
```python
from psycopg2 import pool
class DocumentDatabase:
    def __init__(self, config):
        self.pool = pool.SimpleConnectionPool(
            minconn=1,
            maxconn=10,
            **config
        )

    def connect(self):
        return self.pool.getconn()

    def close(self, conn):
        self.pool.putconn(conn)
```
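A small usage sketch on top of the pool above: wrapping getconn/putconn in a context manager keeps call sites from leaking connections. This helper is a plausible convenience, not existing project API:

```python
from contextlib import contextmanager

@contextmanager
def pooled_connection(db: "DocumentDatabase"):
    """Borrow a connection from the pool and always return it."""
    conn = db.connect()
    try:
        yield conn
    finally:
        db.close(conn)

# Usage:
# with pooled_connection(db) as conn:
#     with conn.cursor() as cur:
#         cur.execute("SELECT 1")
```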
**Impact**:
- Reduces connection overhead by 80-95%
- Especially important for high-frequency operations
#### 2. Index Recommendations
**Check current indexes**:
```sql
-- Verify indexes exist on frequently queried columns
SELECT tablename, indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'public';
```
**Recommended indexes**:
```sql
-- If not already present
CREATE INDEX IF NOT EXISTS idx_documents_success
ON documents(success);
CREATE INDEX IF NOT EXISTS idx_documents_timestamp
ON documents(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id
ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_matched
ON field_results(matched);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name
ON field_results(field_name);
```
**Impact**:
- 10-100x faster queries for filtered/sorted results
- Critical for `get_failed_matches()` and `get_all_documents_summary()`
#### 3. Query Batching
**Status**: ✅ Already implemented for field results (line 519)
**Verify batching is used**:
```python
# Good pattern in db.py
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
```
**Additional opportunity**: Batch `SELECT` queries
```python
# Current
docs = [get_document(doc_id) for doc_id in doc_ids] # N queries
# Optimized
docs = get_documents_batch(doc_ids) # 1 query with IN clause
```
**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)
---
## Caching Strategies
### 1. Model Loading Cache
**Current**: Models loaded per-instance
**Recommendation**: Singleton pattern for YOLO model
```python
class YOLODetectorSingleton:
    _instance = None

    @classmethod
    def get_instance(cls, model_path):
        if cls._instance is None:
            cls._instance = YOLODetector(model_path)
        return cls._instance
```
**Impact**: Reduces memory usage by 90% when processing multiple documents
### 2. Parser Instance Caching
**Current**: ✅ Already optimal
```python
# Good pattern in field_extractor.py
def __init__(self):
    self.payment_line_parser = PaymentLineParser()        # Reused
    self.customer_number_parser = CustomerNumberParser()  # Reused
```
**Status**: No changes needed
### 3. OCR Result Caching
**Recommendation**: Cache OCR results for identical regions
```python
from functools import lru_cache

@lru_cache(maxsize=1000)
def ocr_region_cached(image_hash, bbox):
    """Cache OCR results by image hash + bbox (both must be hashable, e.g. str and tuple)."""
    image = IMAGE_REGISTRY[image_hash]  # resolve the actual image from its hash (see sketch below)
    return paddle_ocr.ocr_region(image, bbox)
```
**Impact**: 50-80% speedup when re-processing similar documents
**Note**: Requires implementing image hashing (e.g., `hashlib.md5(image.tobytes())`)
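A minimal sketch of that missing hashing piece, assuming a PIL-style image object with `.tobytes()`; `IMAGE_REGISTRY` is the lookup used in the snippet above, and the helper name is illustrative:

```python
import hashlib

IMAGE_REGISTRY: dict[str, object] = {}  # image_hash -> image, maintained by the caller

def register_image(image) -> str:
    """Hash the raw pixel bytes and remember the image for cached OCR calls."""
    key = hashlib.md5(image.tobytes()).hexdigest()
    IMAGE_REGISTRY[key] = image
    return key

# Usage: identical regions of identical pages then hit the lru_cache
# key = register_image(page_image)
# text = ocr_region_cached(key, (x0, y0, x1, y1))
```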
---
## Memory Management
### Current Issues
**Potential memory leaks**:
1. Large images kept in memory after processing
2. OCR results accumulated without cleanup
3. Model outputs not explicitly cleared
### Recommendations
#### 1. Explicit Image Cleanup
```python
import gc

def process_pdf(pdf_path):
    image = None
    try:
        image = render_pdf(pdf_path)
        return extract_fields(image)
    finally:
        del image      # Explicit cleanup (safe even if render_pdf raised)
        gc.collect()   # Force garbage collection
```
#### 2. Generator Pattern for Large Batches
**Current**: Load all documents into memory
```python
docs = [process_pdf(path) for path in pdf_paths] # All in memory
```
**Optimized**: Use generator for streaming processing
```python
def process_batch_streaming(pdf_paths):
    """Process documents one at a time, yielding results."""
    for path in pdf_paths:
        result = process_pdf(path)
        yield result
        # The result can be saved to the DB immediately;
        # the previous result is then garbage collected
```
**Impact**: Constant memory usage regardless of batch size
#### 3. Context Managers for Resources
```python
class InferencePipeline:
    def __enter__(self):
        self.detector.load_model()
        return self

    def __exit__(self, *args):
        self.detector.unload_model()
        self.extractor.cleanup()

# Usage
with InferencePipeline(...) as pipeline:
    results = pipeline.process_pdf(path)
# Automatic cleanup
```
---
## Profiling and Monitoring
### Recommended Profiling Tools
#### 1. cProfile for CPU Profiling
```python
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
# Your code here
pipeline.process_pdf(pdf_path)
profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20) # Top 20 slowest functions
```
#### 2. memory_profiler for Memory Analysis
```bash
pip install memory_profiler
python -m memory_profiler your_script.py
```
Or decorator-based:
```python
from memory_profiler import profile
@profile
def process_large_batch(pdf_paths):
    # Memory usage tracked line-by-line
    results = [process_pdf(path) for path in pdf_paths]
    return results
```
#### 3. py-spy for Production Profiling
```bash
pip install py-spy
# Profile running process
py-spy top --pid 12345
# Generate flamegraph
py-spy record -o profile.svg -- python your_script.py
```
**Advantage**: No code changes needed, minimal overhead
### Key Metrics to Monitor
1. **Processing Time per Document**
- Target: <10 seconds for single-page invoice
- Current: ~2-5 seconds (estimated)
2. **Memory Usage**
- Target: <2GB for batch of 100 documents
- Monitor: Peak memory usage
3. **Database Query Time**
- Target: <100ms per query (with indexes)
- Monitor: Slow query log
4. **OCR Accuracy vs Speed Trade-off**
- Current: PaddleOCR with GPU (~200ms per region)
- Alternative: Tesseract (~500ms, slightly more accurate)
### Logging Performance Metrics
**Add to pipeline.py**:
```python
import time
import logging
logger = logging.getLogger(__name__)
def process_pdf(self, pdf_path):
    start = time.time()
    # Processing...
    result = self._process_internal(pdf_path)
    elapsed = time.time() - start
    logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")
    # Log to the database for analysis
    self.db.log_performance({
        'document_id': result.document_id,
        'processing_time': elapsed,
        'field_count': len(result.fields)
    })
    return result
```
---
## Performance Optimization Priorities
### High Priority (Implement First)
1. **Database parameterized queries** - Already done (Phase 1)
2. **Database connection pooling** - Not implemented
3. **Index optimization** - Needs verification
### Medium Priority
4. **Batch PDF rendering** - Optimization possible
5. **Parser instance reuse** - Already done (Phase 2)
6. **Model caching** - Could improve
### Low Priority (Nice to Have)
7. **OCR result caching** - Complex implementation
8. **Generator patterns** - Refactoring needed
9. **Advanced profiling** - For production optimization
---
## Benchmarking Script
```python
"""
Benchmark script for invoice processing performance.
"""
import time
from pathlib import Path
from src.inference.pipeline import InferencePipeline
def benchmark_single_document(pdf_path, iterations=10):
    """Benchmark single document processing."""
    pipeline = InferencePipeline(
        model_path="path/to/model.pt",
        use_gpu=True
    )
    times = []
    for i in range(iterations):
        start = time.time()
        result = pipeline.process_pdf(pdf_path)
        elapsed = time.time() - start
        times.append(elapsed)
        print(f"Iteration {i+1}: {elapsed:.2f}s")
    avg_time = sum(times) / len(times)
    print(f"\nAverage: {avg_time:.2f}s")
    print(f"Min: {min(times):.2f}s")
    print(f"Max: {max(times):.2f}s")

def benchmark_batch(pdf_paths, batch_size=10):
    """Benchmark batch processing."""
    from multiprocessing import Pool
    pipeline = InferencePipeline(
        model_path="path/to/model.pt",
        use_gpu=True
    )
    start = time.time()
    with Pool(processes=batch_size) as pool:
        results = pool.map(pipeline.process_pdf, pdf_paths)
    elapsed = time.time() - start
    avg_per_doc = elapsed / len(pdf_paths)
    print(f"Total time: {elapsed:.2f}s")
    print(f"Documents: {len(pdf_paths)}")
    print(f"Average per document: {avg_per_doc:.2f}s")
    print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")

if __name__ == "__main__":
    # Single document benchmark
    benchmark_single_document("test.pdf")
    # Batch benchmark
    pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
    benchmark_batch(pdf_paths[:100])
```
---
## Summary
**Implemented (Phase 1-2)**:
- Parameterized queries (SQL injection fix)
- Parser instance reuse (Phase 2 refactoring)
- Batch insert operations (execute_values)
- Dual pool processing (CPU/GPU separation)
**Quick Wins (Low effort, high impact)**:
- Database connection pooling (2-4 hours)
- Index verification and optimization (1-2 hours)
- Batch PDF rendering (4-6 hours)
**Long-term Improvements**:
- OCR result caching with hashing
- Generator patterns for streaming
- Advanced profiling and monitoring
**Expected Impact**:
- Connection pooling: 80-95% reduction in DB overhead
- Indexes: 10-100x faster queries
- Batch rendering: 50-90% less redundant work
- **Overall**: 2-5x throughput improvement for batch processing

File diff suppressed because it is too large.


@@ -1,170 +0,0 @@
# Refactoring Summary Report
## 📊 Overall Results
### Test status
- ✅ **688/688 tests passing** (100%)
- ✅ **Code coverage**: 34% → 37% (+3%)
- ✅ **0 failures**, 0 errors
### Coverage improvements
- ✅ **machine_code_parser**: 25% → 65% (+40%)
- ✅ **New tests**: 55 (633 → 688)
---
## 🎯 Completed Refactorings
### 1. ✅ Matcher modularization (876 lines → 205 lines, ↓76%)
**Files**:
**What changed**:
- Split the single 876-line file into **11 modules**
- Extracted **5 independent matching strategies**
- Created dedicated modules for data models, utility functions, and context handling
**New module structure**:
**Test results**:
- ✅ All 77 matcher tests passing
- ✅ Complete README documentation
- ✅ Strategy pattern, easy to extend
**Benefits**:
- 📉 76% less code
- 📈 Markedly better maintainability
- ✨ Each strategy tested in isolation
- 🔧 Easy to add new strategies
---
### 2. ✅ Machine Code Parser light refactoring + test coverage (919 → 929 lines)
**File**: src/ocr/machine_code_parser.py
**What changed**:
- Extracted **3 shared helper methods**, removing duplicated code
- Streamlined the context-detection logic
- Simplified the account-formatting method
**Test improvements**:
- ✅ **55 new tests** (24 → 79)
- ✅ **Coverage**: 25% → 65% (+40%)
- ✅ All 688 project tests passing
**New test coverage**:
- **Round 1** (22 tests):
  - `_detect_account_context()` - 8 tests (context detection)
  - `_normalize_account_spaces()` - 5 tests (space normalization)
  - `_format_account()` - 4 tests (account formatting)
  - `parse()` - 5 tests (main entry point)
- **Round 2** (33 tests):
  - `_extract_ocr()` - 8 tests (OCR extraction)
  - `_extract_bankgiro()` - 9 tests (Bankgiro extraction)
  - `_extract_plusgiro()` - 8 tests (Plusgiro extraction)
  - `_extract_amount()` - 8 tests (amount extraction)
**Benefits**:
- 🔄 80 lines of duplication removed
- 📈 Better testability (helpers can be tested independently)
- 📖 Improved readability
- ✅ Coverage up from 25% to 65% (+40%)
- 🎯 Low risk, high payoff
---
### 3. ✅ Field Extractor analysis (decided not to refactor)
**File**: (1,183 lines)
**Analysis result**: ❌ **Should not be refactored**
**Key insight**:
- Superficially similar code can serve **entirely different purposes**
- field_extractor: **parses/extracts** field values
- src/normalize: **normalizes/generates variants** for matching
- Different responsibilities; they should not be unified
**Documentation**:
---
## 📈 Refactoring Statistics
### Line-count changes
| File | Before | After | Change | Percent |
|------|--------|--------|------|--------|
| **matcher/field_matcher.py** | 876 | 205 | -671 | ↓76% |
| **matcher/* (10 new modules)** | 0 | 466 | +466 | new |
| **matcher total** | 876 | 671 | -205 | ↓23% |
| **ocr/machine_code_parser.py** | 919 | 929 | +10 | +1% |
| **Net reduction** | - | - | **-195 lines** | **↓11%** |
### Test coverage
| Module | Tests | Pass rate | Coverage | Status |
|------|--------|--------|--------|------|
| matcher | 77 | 100% | - | ✅ |
| field_extractor | 45 | 100% | 39% | ✅ |
| machine_code_parser | 79 | 100% | 65% | ✅ |
| normalizer | ~120 | 100% | - | ✅ |
| other modules | ~367 | 100% | - | ✅ |
| **Total** | **688** | **100%** | **37%** | ✅ |
---
## 🎓 Lessons from the Refactoring
### What worked
1. **✅ Tests before refactoring**
   - Every refactoring had full test coverage
   - Tests were run immediately after each change
   - A 100% pass rate guaranteed quality
2. **✅ Identifying real duplication**
   - Not all similar code is duplication
   - field_extractor vs. normalizer: superficially similar, different purposes
   - machine_code_parser: genuine duplication
3. **✅ Incremental refactoring**
   - matcher: large-scale modularization (strategy pattern)
   - machine_code_parser: light refactoring (extract shared methods)
   - field_extractor: analyzed, then deliberately left alone
### Key decisions
#### ✅ When refactoring was warranted
- **matcher**: a single overlong file (876 lines) containing several strategies
- **machine_code_parser**: repeated code with an identical purpose
#### ❌ When it was not
- **field_extractor**: similar-looking code with different purposes
### The lesson
**Do not chase the DRY principle blindly**:
> Similar code is not necessarily duplicated code. Understand what the code is **actually for**.
---
## ✅ Summary
**Key outcomes**:
- 📉 Net reduction of 195 lines
- 📈 Code coverage +3% (34% → 37%)
- ✅ Test count +55 (633 → 688)
- 🎯 machine_code_parser coverage +40% (25% → 65%)
- ✨ Much stronger modularization
- 🎯 Greatly improved maintainability
**Key lesson**:
> Similar code is not necessarily duplicated code. Only by understanding what the code is really for can you make the right refactoring call.
**Suggested next steps**:
1. Keep pushing machine_code_parser coverage toward 80%+ (currently 65%)
2. Add tests for other low-coverage modules (field_extractor 39%, pipeline 19%)
3. Round out tests for edge cases and error paths


@@ -1,258 +0,0 @@
# Test Coverage Improvement Report
## 📊 Overview
### Overall statistics
- ✅ **Total tests**: 633 → 688 (+55 tests, +8.7%)
- ✅ **Pass rate**: 100% (688/688)
- ✅ **Overall coverage**: 34% → 37% (+3%)
### machine_code_parser.py focus
- ✅ **Tests**: 24 → 79 (+55 tests, +229%)
- ✅ **Coverage**: 25% → 65% (+40%)
- ✅ **Uncovered lines**: 273 → 129 (144 fewer)
---
## 🎯 New Tests in Detail
### Round 1 (22 tests)
#### 1. TestDetectAccountContext (8 tests)
Tests the newly added `_detect_account_context()` helper.
**Test cases**:
1. `test_bankgiro_keyword` - detects the 'bankgiro' keyword
2. `test_bg_keyword` - detects the 'bg:' abbreviation
3. `test_plusgiro_keyword` - detects the 'plusgiro' keyword
4. `test_postgiro_keyword` - detects the 'postgiro' alias
5. `test_pg_keyword` - detects the 'pg:' abbreviation
6. `test_both_contexts` - both keyword types present
7. `test_no_context` - no account keywords at all
8. `test_case_insensitive` - detection is case-insensitive
**Code path covered**:
```python
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
    context_text = ' '.join(t.text.lower() for t in tokens)
    return {
        'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
        'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
    }
```
---
### 2. TestNormalizeAccountSpacesMethod (5 tests)
Tests the newly added `_normalize_account_spaces()` helper.
**Test cases**:
1. `test_removes_spaces_after_arrow` - removes spaces after the > marker
2. `test_multiple_consecutive_spaces` - handles runs of consecutive spaces
3. `test_no_arrow_returns_unchanged` - returns the input unchanged when there is no > marker
4. `test_spaces_before_arrow_preserved` - preserves spaces before the >
5. `test_empty_string` - empty-string handling
**Code path covered**:
```python
def _normalize_account_spaces(self, line: str) -> str:
    if '>' not in line:
        return line
    parts = line.split('>', 1)
    after_arrow = parts[1]
    normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
    while re.search(r'(\d)\s+(\d)', normalized):
        normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
    return parts[0] + '>' + normalized
```
---
### 3. TestFormatAccount (4 tests)
Tests the newly added `_format_account()` helper.
**Test cases**:
1. `test_plusgiro_context_forces_plusgiro` - a Plusgiro context forces Plusgiro formatting
2. `test_valid_bankgiro_7_digits` - formats a valid 7-digit Bankgiro
3. `test_valid_bankgiro_8_digits` - formats a valid 8-digit Bankgiro
4. `test_defaults_to_bankgiro_when_ambiguous` - defaults to Bankgiro when ambiguous
**Code path covered**:
```python
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
    if is_plusgiro_context:
        formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
        return formatted, 'plusgiro'
    # Luhn validation
    pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
    bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
    # Decision logic
    if pg_valid and not bg_valid:
        return pg_formatted, 'plusgiro'
    elif bg_valid and not pg_valid:
        return bg_formatted, 'bankgiro'
    else:
        return bg_formatted, 'bankgiro'
```
---
### 4. TestParseMethod (5 tests)
Tests the main `parse()` entry point.
**Test cases**:
1. `test_parse_empty_tokens` - handles an empty token list
2. `test_parse_finds_payment_line_in_bottom_region` - finds the payment line in the bottom 35% of the page
3. `test_parse_ignores_top_region` - ignores the top region of the page
4. `test_parse_with_context_keywords` - detects context keywords
5. `test_parse_stores_source_tokens` - stores the source tokens
**Code paths covered**:
- Token filtering (bottom-region detection)
- Context-keyword detection
- Payment-line search and parsing
- Result-object construction
---
### Round 2 (33 tests)
#### 5. TestExtractOCR (8 tests)
Tests `_extract_ocr()` - OCR reference-number extraction.
**Test cases**:
1. `test_extract_valid_ocr_10_digits` - extracts a 10-digit OCR number
2. `test_extract_valid_ocr_15_digits` - extracts a 15-digit OCR number
3. `test_extract_ocr_with_hash_markers` - OCR numbers flanked by # markers
4. `test_extract_longest_ocr_when_multiple` - picks the longest of several candidates
5. `test_extract_ocr_ignores_short_numbers` - ignores numbers shorter than 10 digits
6. `test_extract_ocr_ignores_long_numbers` - ignores numbers longer than 25 digits
7. `test_extract_ocr_excludes_bankgiro_variants` - excludes Bankgiro variants
8. `test_extract_ocr_empty_tokens` - empty-token handling
#### 6. TestExtractBankgiro (9 tests)
Tests `_extract_bankgiro()` - Bankgiro account extraction.
**Test cases**:
1. `test_extract_bankgiro_7_digits_with_dash` - 7-digit Bankgiro with a dash
2. `test_extract_bankgiro_7_digits_without_dash` - 7-digit Bankgiro without a dash
3. `test_extract_bankgiro_8_digits_with_dash` - 8-digit Bankgiro with a dash
4. `test_extract_bankgiro_8_digits_without_dash` - 8-digit Bankgiro without a dash
5. `test_extract_bankgiro_with_spaces` - Bankgiro containing spaces
6. `test_extract_bankgiro_handles_plusgiro_format` - handles Plusgiro-formatted input
7. `test_extract_bankgiro_with_context` - with context keywords
8. `test_extract_bankgiro_ignores_plusgiro_context` - ignores a Plusgiro context
9. `test_extract_bankgiro_empty_tokens` - empty-token handling
#### 7. TestExtractPlusgiro (8 tests)
Tests `_extract_plusgiro()` - Plusgiro account extraction.
**Test cases**:
1. `test_extract_plusgiro_7_digits_with_dash` - 7-digit Plusgiro with a dash
2. `test_extract_plusgiro_7_digits_without_dash` - 7-digit Plusgiro without a dash
3. `test_extract_plusgiro_8_digits` - 8-digit Plusgiro
4. `test_extract_plusgiro_with_spaces` - Plusgiro containing spaces
5. `test_extract_plusgiro_with_context` - with context keywords
6. `test_extract_plusgiro_ignores_too_short` - ignores fewer than 7 digits
7. `test_extract_plusgiro_ignores_too_long` - ignores more than 8 digits
8. `test_extract_plusgiro_empty_tokens` - empty-token handling
#### 8. TestExtractAmount (8 tests)
Tests `_extract_amount()` - amount extraction.
**Test cases**:
1. `test_extract_amount_with_comma_decimal` - comma as the decimal separator
2. `test_extract_amount_with_dot_decimal` - dot as the decimal separator
3. `test_extract_amount_integer` - integer amounts
4. `test_extract_amount_with_thousand_separator` - thousands separators
5. `test_extract_amount_large_number` - large amounts
6. `test_extract_amount_ignores_too_large` - ignores implausibly large amounts
7. `test_extract_amount_ignores_zero` - ignores zero or negative amounts
8. `test_extract_amount_empty_tokens` - empty-token handling
---
## 📈 Coverage Analysis
### Covered methods
✅ `_detect_account_context()` - **100%** (added in round 1)
✅ `_normalize_account_spaces()` - **100%** (added in round 1)
✅ `_format_account()` - **95%** (added in round 1)
✅ `parse()` - **70%** (improved in round 1)
✅ `_parse_standard_payment_line()` - **95%** (pre-existing tests)
✅ `_extract_ocr()` - **85%** (added in round 2)
✅ `_extract_bankgiro()` - **90%** (added in round 2)
✅ `_extract_plusgiro()` - **90%** (added in round 2)
✅ `_extract_amount()` - **80%** (added in round 2)
### Methods still needing work (uncovered / partially covered)
⚠️ `_calculate_confidence()` - **0%** (untested)
⚠️ `cross_validate()` - **0%** (untested)
⚠️ `get_region_bbox()` - **0%** (untested)
⚠️ `_find_tokens_with_values()` - **partially covered**
⚠️ `_find_machine_code_line_tokens()` - **partially covered**
### Uncovered lines (129)
Mostly concentrated in:
1. **Validation methods** (lines 805-824): `_calculate_confidence`, `cross_validate`
2. **Helper methods** (lines 80-92, 336-369, 377-407): token lookup, bbox computation, logging
3. **Edge cases** (lines 648-653, 690, 699, 759-760, etc.): boundary conditions in some extraction methods
---
## 🎯 Recommendations
### ✅ Goals achieved
- ✅ Coverage raised from 25% to 65% (+40%)
- ✅ Test count grown from 24 to 79 (+55)
- ✅ All extraction methods tested (_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount)
### Next goals (coverage 65% → 80%+)
1. **Test the validation methods** - add tests for `_calculate_confidence` and `cross_validate`
2. **Test the helper methods** - add tests for token lookup and bbox computation
3. **Round out edge cases** - add more boundary-condition and error-path tests
4. **Integration tests** - add end-to-end tests using tokens from real PDFs
---
## ✅ Improvements Delivered
### Refactoring benefits
- ✅ The 3 extracted helpers can now be tested in isolation
- ✅ Finer-grained tests make failures easier to localize
- ✅ More readable code and self-explanatory test cases
### Quality assurance
- ✅ All 655 tests passing (100%)
- ✅ No regressions
- ✅ The new tests cover previously untested refactored code
---
## 📚 Test-Writing Lessons
### What worked
1. **Fixtures for test data** - a `_create_token()` helper simplified token creation
2. **One test class per method** - keeps the structure clear
3. **Clear test naming** - the `test_<what>_<condition>` format is self-documenting
4. **Covering the key paths first** - common scenarios and boundary conditions before exotic cases
### Problems encountered
1. **Token constructor arguments** - a forgotten `page_no` parameter made the initial tests fail
   - Fix: patch the `_create_token()` helper to pass `page_no=0`
---
**Report date**: 2026-01-24
**Status**: ✅ Complete
**Next**: keep raising coverage toward 80%+ (see the next goals above)


@@ -1,619 +0,0 @@
# Multi-Pool Processing Architecture Design
## 1. Research Summary
### 1.1 Analysis of the current problems
Our earlier dual-pool implementation had stability problems, mainly for these reasons:
| Problem | Cause | Fix |
|------|------|----------|
| Processing hangs | Mixing threads with ProcessPoolExecutor caused deadlocks | Use asyncio or a pure Queue pattern |
| Queue.get() blocks forever | No timeout mechanism | Add timeouts and sentinel values |
| GPU memory conflicts | Multiple processes accessing the GPU simultaneously | Limit GPU workers to 1 |
| CUDA fork problems | Linux's default fork start method is incompatible with CUDA | Use the spawn start method |
### 1.2 Recommended architecture
After some research, the approach that fits our scenario best is the **producer-consumer queue pattern**:
```
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Main Process │ │ CPU Workers │ │ GPU Worker │
│ │ │ (4 processes) │ │ (1 process) │
│ ┌───────────┐ │ │ │ │ │
│ │ Task │──┼────▶│ Text-PDF work │ │ Scanned-PDF work│
│ │ Dispatcher│ │ │ (no OCR) │ │ (PaddleOCR) │
│ └───────────┘ │ │ │ │ │
│ ▲ │ │ │ │ │ │ │
│ │ │ │ ▼ │ │ ▼ │
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
│ │ Collector │ │ │ │ │ │
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Database │ │
│ │ Batch │ │
│ │ Writer │ │
│ └───────────┘ │
└─────────────────┘
```
---
## 2. Core Design Principles
### 2.1 CUDA compatibility
```python
# Key point: use the spawn start method
import os
import multiprocessing as mp
ctx = mp.get_context("spawn")

# Set the device when the GPU worker initializes
def init_gpu_worker(gpu_id: int = 0):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    global _ocr
    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(use_gpu=True, ...)
```
### 2.2 Worker initialization pattern
Use the `initializer` argument so each worker loads its model once, instead of reloading it for every task:
```python
# Module-level global that holds the model
_ocr = None

def init_worker(use_gpu: bool, gpu_id: int = 0):
    global _ocr
    if use_gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(use_gpu=use_gpu, ...)

# Pass the initializer when creating the pool
pool = ProcessPoolExecutor(
    max_workers=1,
    initializer=init_worker,
    initargs=(True, 0),  # use_gpu=True, gpu_id=0
    mp_context=mp.get_context("spawn")
)
```
### 2.3 Queue pattern vs. as_completed
| Approach | Pros | Cons | Best for |
|------|------|------|----------|
| `as_completed()` | Simple; no queue management | Futures from every pool must be collected by hand | Single-pool scenarios |
| `multiprocessing.Queue` | High performance, flexible | Manual management; deadlock risk | Multi-pool pipelines |
| `Manager().Queue()` | Picklable; works across pools | Lower performance | Scenarios that need Pool.map |
**Recommendation**: for the dual-pool scenario, collect the futures from both pools and drive them through a single `as_completed()` loop, merging the results.
---
## 3. Detailed Development Plan
### Phase 1: Rework the foundations (2-3 days)
#### 1.1 Create a WorkerPool abstract base class
```python
# src/processing/worker_pool.py
from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, Future
from dataclasses import dataclass
from typing import List, Any, Optional, Callable
import multiprocessing as mp

@dataclass
class TaskResult:
    """Container for a single task's result."""
    task_id: str
    success: bool
    data: Any
    error: Optional[str] = None
    processing_time: float = 0.0

class WorkerPool(ABC):
    """Abstract base class for worker pools."""
    def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
        self.max_workers = max_workers
        self.use_gpu = use_gpu
        self.gpu_id = gpu_id
        self._executor: Optional[ProcessPoolExecutor] = None

    @abstractmethod
    def get_initializer(self) -> Callable:
        """Return the worker initializer function."""
        pass

    @abstractmethod
    def get_init_args(self) -> tuple:
        """Return the initializer arguments."""
        pass

    def start(self):
        """Start the worker pool."""
        ctx = mp.get_context("spawn")
        self._executor = ProcessPoolExecutor(
            max_workers=self.max_workers,
            mp_context=ctx,
            initializer=self.get_initializer(),
            initargs=self.get_init_args()
        )

    def submit(self, fn: Callable, *args, **kwargs) -> Future:
        """Submit a task."""
        if not self._executor:
            raise RuntimeError("Pool not started")
        return self._executor.submit(fn, *args, **kwargs)

    def shutdown(self, wait: bool = True):
        """Shut down the pool."""
        if self._executor:
            self._executor.shutdown(wait=wait)
            self._executor = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.shutdown()
```
#### 1.2 Implement the CPU and GPU worker pools
```python
# src/processing/cpu_pool.py
class CPUWorkerPool(WorkerPool):
    """CPU-only worker pool for text PDF processing"""
    def __init__(self, max_workers: int = 4):
        super().__init__(max_workers=max_workers, use_gpu=False)

    def get_initializer(self) -> Callable:
        return init_cpu_worker

    def get_init_args(self) -> tuple:
        return ()

# src/processing/gpu_pool.py
class GPUWorkerPool(WorkerPool):
    """GPU worker pool for OCR processing"""
    def __init__(self, max_workers: int = 1, gpu_id: int = 0):
        super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)

    def get_initializer(self) -> Callable:
        return init_gpu_worker

    def get_init_args(self) -> tuple:
        return (self.gpu_id,)
```
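A usage sketch for these pools, assuming the `init_cpu_worker` initializer from the design exists; `square` is a stand-in task. Worker functions must be module-level so they can be pickled, and the entry point must be guarded because spawn re-imports the main module:

```python
# Illustrative usage; `square` stands in for a real task function.
def square(x: int) -> int:
    return x * x

if __name__ == "__main__":  # required with the spawn start method
    with CPUWorkerPool(max_workers=2) as pool:
        futures = [pool.submit(square, n) for n in range(8)]
        print([f.result() for f in futures])  # [0, 1, 4, 9, 16, 25, 36, 49]
```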
---
### Phase 2: Implement the dual-pool coordinator (2-3 days)
#### 2.1 Task dispatcher
```python
# src/processing/task_dispatcher.py
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Tuple

class TaskType(Enum):
    CPU = auto()  # Text PDF
    GPU = auto()  # Scanned PDF

@dataclass
class Task:
    id: str
    task_type: TaskType
    data: Any

class TaskDispatcher:
    """Dispatch tasks to the appropriate pool based on PDF type."""
    def classify_task(self, doc_info: dict) -> TaskType:
        """Decide whether a document needs OCR."""
        # Decide from PDF characteristics
        if self._is_scanned_pdf(doc_info):
            return TaskType.GPU
        return TaskType.CPU

    def _is_scanned_pdf(self, doc_info: dict) -> bool:
        """Detect whether this is a scanned PDF."""
        # 1. Check for extractable text
        # 2. Check the image-to-page ratio
        # 3. Check text density
        # Placeholder heuristic (illustrative only): no extractable text -> scanned
        return not doc_info.get('has_text', False)

    def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
        """Partition tasks into CPU and GPU groups."""
        cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
        gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
        return cpu_tasks, gpu_tasks
```
#### 2.2 Dual-pool coordinator
```python
# src/processing/dual_pool_coordinator.py
from concurrent.futures import as_completed
from typing import Callable, List, Optional
import logging

logger = logging.getLogger(__name__)

class DualPoolCoordinator:
    """Coordinates the CPU and GPU worker pools."""
    def __init__(
        self,
        cpu_workers: int = 4,
        gpu_workers: int = 1,
        gpu_id: int = 0
    ):
        self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
        self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
        self.dispatcher = TaskDispatcher()

    def __enter__(self):
        self.cpu_pool.start()
        self.gpu_pool.start()
        return self

    def __exit__(self, *args):
        self.cpu_pool.shutdown()
        self.gpu_pool.shutdown()

    def process_batch(
        self,
        documents: List[dict],
        cpu_task_fn: Callable,
        gpu_task_fn: Callable,
        on_result: Optional[Callable[[TaskResult], None]] = None,
        on_error: Optional[Callable[[str, Exception], None]] = None
    ) -> List[TaskResult]:
        """
        Process a batch of documents, automatically dispatching to the CPU or GPU pool.

        Args:
            documents: documents to process
            cpu_task_fn: CPU task handler
            gpu_task_fn: GPU task handler
            on_result: result callback (optional)
            on_error: error callback (optional)

        Returns:
            A list of all task results
        """
        # Classify the tasks
        tasks = [
            Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
            for doc in documents
        ]
        cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
        logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
        # Submit tasks to their respective pools
        cpu_futures = {
            self.cpu_pool.submit(cpu_task_fn, t.data): t.id
            for t in cpu_tasks
        }
        gpu_futures = {
            self.gpu_pool.submit(gpu_task_fn, t.data): t.id
            for t in gpu_tasks
        }
        # Collect the results
        results = []
        all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
        for future in as_completed(all_futures):
            task_id = cpu_futures.get(future) or gpu_futures.get(future)
            pool_type = "CPU" if future in cpu_futures else "GPU"
            try:
                data = future.result(timeout=300)  # 5-minute timeout
                result = TaskResult(task_id=task_id, success=True, data=data)
                if on_result:
                    on_result(result)
            except Exception as e:
                logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
                result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
                if on_error:
                    on_error(task_id, e)
            results.append(result)
        return results
```
---
### Phase 3: Integrate into autolabel (1-2 days)
#### 3.1 Modify autolabel.py
```python
# src/cli/autolabel.py
def run_autolabel_dual_pool(args):
    """Run auto-labeling in dual-pool mode."""
    from src.processing.dual_pool_coordinator import DualPoolCoordinator

    # Set up database batching
    db_batch = []
    db_batch_size = 100

    def on_result(result: TaskResult):
        """Handle a successful result."""
        nonlocal db_batch
        db_batch.append(result.data)
        if len(db_batch) >= db_batch_size:
            save_documents_batch(db_batch)
            db_batch.clear()

    def on_error(task_id: str, error: Exception):
        """Handle a failure."""
        logger.error(f"Task {task_id} failed: {error}")

    # Create the dual-pool coordinator
    with DualPoolCoordinator(
        cpu_workers=args.cpu_workers or 4,
        gpu_workers=args.gpu_workers or 1,
        gpu_id=0
    ) as coordinator:
        # Process every CSV
        for csv_file in csv_files:
            documents = load_documents_from_csv(csv_file)
            results = coordinator.process_batch(
                documents=documents,
                cpu_task_fn=process_text_pdf,
                gpu_task_fn=process_scanned_pdf,
                on_result=on_result,
                on_error=on_error
            )
            logger.info(f"CSV {csv_file}: {len(results)} processed")

    # Flush the remaining batch
    if db_batch:
        save_documents_batch(db_batch)
```
---
### Phase 4: Testing and validation (1-2 days)
#### 4.1 Unit tests
```python
# tests/unit/test_dual_pool.py
import pytest
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult

class TestDualPoolCoordinator:
    def test_cpu_only_batch(self):
        """A batch of CPU-only tasks."""
        with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
            docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
            results = coord.process_batch(docs, cpu_fn, gpu_fn)
            assert len(results) == 10
            assert all(r.success for r in results)

    def test_mixed_batch(self):
        """A mixed CPU/GPU batch."""
        with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
            docs = [
                {"id": "text_1", "type": "text"},
                {"id": "scan_1", "type": "scanned"},
                {"id": "text_2", "type": "text"},
            ]
            results = coord.process_batch(docs, cpu_fn, gpu_fn)
            assert len(results) == 3

    def test_timeout_handling(self):
        """Timeout handling."""
        pass

    def test_error_recovery(self):
        """Error recovery."""
        pass
```
#### 4.2 Integration tests
```python
# tests/integration/test_autolabel_dual_pool.py
def test_autolabel_with_dual_pool():
    """End-to-end test of dual-pool mode."""
    # Use a small test dataset
    result = subprocess.run([
        "python", "-m", "src.cli.autolabel",
        "--cpu-workers", "2",
        "--gpu-workers", "1",
        "--limit", "50"
    ], capture_output=True)
    assert result.returncode == 0
    # Verify the database records
```
---
## 4. Key Technical Points
### 4.1 Deadlock-avoidance strategies
```python
# 1. Use timeouts
try:
    result = future.result(timeout=300)
except TimeoutError:
    logger.warning(f"Task timed out")

# 2. Use sentinel values
SENTINEL = object()
queue.put(SENTINEL)  # signal the end of the stream

# 3. Check worker liveness
if not worker.is_alive():
    logger.error("Worker died unexpectedly")
    break

# 4. Drain the queue before joining
while not queue.empty():
    results.append(queue.get_nowait())
worker.join(timeout=5.0)
```
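The fragments above combined into a minimal runnable consumer loop (a sketch; note that a plain string is used as the sentinel here because `object()` identity does not survive pickling across processes):

```python
import multiprocessing as mp
import queue

SENTINEL = "__DONE__"  # picklable sentinel; object() identity is lost across processes

def consume(result_queue: mp.Queue, timeout: float = 5.0) -> list:
    """Drain results until the sentinel arrives or the producer goes quiet."""
    results = []
    while True:
        try:
            item = result_queue.get(timeout=timeout)  # never block forever
        except queue.Empty:
            break  # producer died or stalled; bail out instead of hanging
        if item == SENTINEL:
            break
        results.append(item)
    return results
```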
### 4.2 PaddleOCR specifics
```python
# PaddleOCR must be initialized inside the worker process
def init_paddle_worker(gpu_id: int):
    global _ocr
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # Import lazily so the CUDA environment variable takes effect
    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(
        use_angle_cls=True,
        lang='en',
        use_gpu=True,
        show_log=False,
        # Important: cap the GPU memory share
        gpu_mem=2000  # limit GPU memory use (MB)
    )
```
### 4.3 Resource monitoring
```python
import psutil
import GPUtil

def get_resource_usage():
    """Report current system resource usage."""
    cpu_percent = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()
    gpu_info = []
    for gpu in GPUtil.getGPUs():
        gpu_info.append({
            "id": gpu.id,
            "memory_used": gpu.memoryUsed,
            "memory_total": gpu.memoryTotal,
            "utilization": gpu.load * 100
        })
    return {
        "cpu_percent": cpu_percent,
        "memory_percent": memory.percent,
        "gpu": gpu_info
    }
```
---
## 5. Risk Assessment and Mitigation
| Risk | Likelihood | Impact | Mitigation |
|------|--------|------|----------|
| GPU out of memory | Medium | High | Limit GPU workers to 1; set the gpu_mem parameter |
| Zombie processes | Low | High | Add heartbeat checks; restart automatically on timeout |
| Task misclassification | Medium | Medium | Add a fallback (retry on GPU after a CPU failure) |
| Database write bottleneck | Low | Medium | Increase the batch size; write asynchronously |
---
## 6. Alternatives
If the approach above still proves problematic, consider:
### 6.1 Use Ray
```python
import ray
ray.init()

@ray.remote(num_cpus=1)
def cpu_task(data):
    return process_text_pdf(data)

@ray.remote(num_gpus=1)
def gpu_task(data):
    return process_scanned_pdf(data)

# Automatic resource scheduling
futures = [cpu_task.remote(d) for d in cpu_docs]
futures += [gpu_task.remote(d) for d in gpu_docs]
results = ray.get(futures)
```
### 6.2 Single pool + dynamic GPU scheduling
Keep the single-pool model, but decide inside each task whether to use the GPU:
```python
def process_document(doc_data):
    if is_scanned_pdf(doc_data):
        # Use the GPU (a global lock or semaphore must bound concurrency)
        with gpu_semaphore:
            return process_with_ocr(doc_data)
    else:
        return process_text_only(doc_data)
```
---
## 7. Timeline Summary
| Phase | Task | Estimated effort |
|------|------|------------|
| Phase 1 | Rework the foundations | 2-3 days |
| Phase 2 | Implement the dual-pool coordinator | 2-3 days |
| Phase 3 | Integrate into autolabel | 1-2 days |
| Phase 4 | Testing and validation | 1-2 days |
| **Total** | | **6-10 days** |
---
## 8. References
1. [Python concurrent.futures documentation](https://docs.python.org/3/library/concurrent.futures.html)
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
3. [Super Fast Python - The Complete ProcessPoolExecutor Guide](https://superfastpython.com/processpoolexecutor-in-python/)
4. [PaddleOCR parallel inference docs](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
5. [AWS - Parallelizing inference across multiple CPUs/GPUs](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
6. [Ray distributed multiprocessing](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)

docs/product-plan-v2.md Normal file

File diff suppressed because it is too large.

docs/ux-design-prompt-v2.md Normal file

@@ -0,0 +1,302 @@
# Document Annotation Tool UX Design Spec v2
## Theme: Warm Graphite (Modern Enterprise)
---
## 1. Design Principles (Updated)
1. **Clarity** - High contrast, but never pure black-on-white
2. **Warm Neutrality** - Slightly warm grays reduce visual fatigue
3. **Focus** - Content-first layouts with restrained accents
4. **Consistency** - Reusable patterns, predictable behavior
5. **Professional Trust** - Calm, serious, enterprise-ready
6. **Longevity** - No trendy colors that age quickly
---
## 2. Color Palette (Warm Graphite)
### Core Colors
| Usage | Color Name | Hex |
|------|-----------|-----|
| Primary Text | Soft Black | #121212 |
| Secondary Text | Charcoal Gray | #2A2A2A |
| Muted Text | Warm Gray | #6B6B6B |
| Disabled Text | Light Warm Gray | #9A9A9A |
### Backgrounds
| Usage | Color | Hex |
|-----|------|-----|
| App Background | Paper White | #FAFAF8 |
| Card / Panel | White | #FFFFFF |
| Hover Surface | Subtle Warm Gray | #F1F0ED |
| Selected Row | Very Light Warm Gray | #ECEAE6 |
### Borders & Dividers
| Usage | Color | Hex |
|------|------|-----|
| Default Border | Warm Light Gray | #E6E4E1 |
| Strong Divider | Neutral Gray | #D8D6D2 |
### Semantic States (Muted & Professional)
| State | Color | Hex |
|------|-------|-----|
| Success | Olive Gray | #3E4A3A |
| Error | Brick Gray | #4A3A3A |
| Warning | Sand Gray | #4A4A3A |
| Info | Graphite Gray | #3A3A3A |
> Accent colors are **never saturated** and are used only for status, progress, or selection.
---
## 3. Typography
- **Font Family**: Inter / SF Pro / system-ui
- **Headings**:
- Weight: 600-700
- Color: #121212
- Letter spacing: -0.01em
- **Body Text**:
- Weight: 400
- Color: #2A2A2A
- **Captions / Meta**:
- Weight: 400
- Color: #6B6B6B
- **Monospace (IDs / Values)**:
- JetBrains Mono / SF Mono
- Color: #2A2A2A
---
## 4. Global Layout
### Top Navigation Bar
- Height: 56px
- Background: #FAFAF8
- Bottom Border: 1px solid #E6E4E1
- Logo: Text or icon in #121212
**Navigation Items**
- Default: #6B6B6B
- Hover: #2A2A2A
- Active:
- Text: #121212
- Bottom indicator: 2px solid #3A3A3A (rounded ends)
**Avatar**
- Circle background: #ECEAE6
- Text: #2A2A2A
---
## 5. Page: Documents (Dashboard)
### Page Header
- Title: "Documents" (#121212)
- Actions:
- Primary button: Dark graphite outline
- Secondary button: Subtle border only
### Filters Bar
- Background: #FFFFFF
- Border: 1px solid #E6E4E1
- Inputs:
- Background: #FFFFFF
- Hover: #F1F0ED
- Focus ring: 1px #3A3A3A
### Document Table
- Table background: #FFFFFF
- Header text: #6B6B6B
- Row hover: #F1F0ED
- Row selected:
- Background: #ECEAE6
- Left indicator: 3px solid #3A3A3A
### Status Badges
- Pending:
- BG: #FFFFFF
- Border: #D8D6D2
- Text: #2A2A2A
- Labeled:
- BG: #2A2A2A
- Text: #FFFFFF
- Exported:
- BG: #ECEAE6
- Text: #2A2A2A
- Icon: ✓
### Auto-label States
- Running:
- Progress bar: #3A3A3A on #ECEAE6
- Completed:
- Text: #3E4A3A
- Failed:
- BG: #F1EDED
- Text: #4A3A3A
---
## 6. Upload Modals (Single & Batch)
### Modal Container
- Background: #FFFFFF
- Border radius: 8px
- Shadow: 0 1px 3px rgba(0,0,0,0.08)
### Drop Zone
- Background: #FAFAF8
- Border: 1px dashed #D8D6D2
- Hover: #F1F0ED
- Icon: Graphite gray
### Form Fields
- Input BG: #FFFFFF
- Border: #D8D6D2
- Focus: 1px solid #3A3A3A
Primary Action Button:
- Text: #FFFFFF
- BG: #2A2A2A
- Hover: #121212
---
## 7. Document Detail View
### Canvas Area
- Background: #FFFFFF
- Annotation styles:
- Manual: Solid border #2A2A2A
- Auto: Dashed border #6B6B6B
- Selected: 2px border #3A3A3A + resize handles
### Right Info Panel
- Card background: #FFFFFF
- Section headers: #121212
- Meta text: #6B6B6B
### Annotation Table
- Same table styles as Documents
- Inline edit:
- Input background: #FAFAF8
- Save button: Graphite
### Locked State (Auto-label Running)
- Banner BG: #FAFAF8
- Border-left: 3px solid #4A4A3A
- Progress bar: Graphite
---
## 8. Training Page
### Document Selector
- Selected rows use same highlight rules
- Verified state:
- Full: Olive gray check
- Partial: Sand gray warning
### Configuration Panel
- Card layout
- Inputs aligned to grid
- Schedule option visually muted until enabled
Primary CTA:
- Start Training button in dark graphite
---
## 9. Models & Training History
### Training Job List
- Job cards use #FFFFFF background
- Running job:
- Progress bar: #3A3A3A
- Completed job:
- Metrics bars in graphite
### Model Detail Panel
- Sectioned cards
- Metric bars:
- Track: #ECEAE6
- Fill: #3A3A3A
Actions:
- Primary: Download Model
- Secondary: View Logs / Use as Base
---
## 10. Micro-interactions (Refined)
| Element | Interaction | Animation |
|------|------------|-----------|
| Button hover | BG lightens | 150ms ease-out |
| Button press | Scale 0.98 | 100ms |
| Row hover | BG fade | 120ms |
| Modal open | Fade + scale 0.96 → 1 | 200ms |
| Progress fill | Smooth | ease-out |
| Annotation select | Border + handles | 120ms |
---
## 11. Tailwind Theme (Updated)
```js
colors: {
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
bg: {
app: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
},
border: '#E6E4E1',
accent: '#3A3A3A',
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
}
```
---
## 12. Final Notes
- Pure black (#000000) should **never** be used for large surfaces
- Accent color usage should stay under **10% of UI area**
- Warm grays are intentional and must not be "corrected" to blue-grays
This theme is designed to scale from internal tool → polished SaaS without redesign.

View File

@@ -0,0 +1,273 @@
# Web Directory Refactoring - Complete ✅
**Date**: 2026-01-25
**Status**: ✅ Completed
**Tests**: 188 passing (0 failures)
**Coverage**: 23% (maintained)
---
## Final Directory Structure
```
src/web/
├── api/
│ ├── __init__.py
│ └── v1/
│ ├── __init__.py
│ ├── routes.py # Public inference API
│ ├── admin/
│ │ ├── __init__.py
│ │ ├── documents.py # Document management (was admin_routes.py)
│ │ ├── annotations.py # Annotation routes (was admin_annotation_routes.py)
│ │ └── training.py # Training routes (was admin_training_routes.py)
│ ├── async_api/
│ │ ├── __init__.py
│ │ └── routes.py # Async processing API (was async_routes.py)
│ └── batch/
│ ├── __init__.py
│ └── routes.py # Batch upload API (was batch_upload_routes.py)
├── schemas/
│ ├── __init__.py
│ ├── common.py # Shared models (ErrorResponse)
│ ├── admin.py # Admin schemas (was admin_schemas.py)
│ └── inference.py # Inference + async schemas (was schemas.py)
├── services/
│ ├── __init__.py
│ ├── inference.py # Inference service (was services.py)
│ ├── autolabel.py # Auto-label service (was admin_autolabel.py)
│ ├── async_processing.py # Async processing (was async_service.py)
│ └── batch_upload.py # Batch upload service (was batch_upload_service.py)
├── core/
│ ├── __init__.py
│ ├── auth.py # Authentication (was admin_auth.py)
│ ├── rate_limiter.py # Rate limiting (unchanged)
│ └── scheduler.py # Task scheduler (was admin_scheduler.py)
├── workers/
│ ├── __init__.py
│ ├── async_queue.py # Async task queue (was async_queue.py)
│ └── batch_queue.py # Batch task queue (was batch_queue.py)
├── __init__.py # Main exports
├── app.py # FastAPI app (imports updated)
├── config.py # Configuration (unchanged)
└── dependencies.py # Global dependencies (unchanged)
```
---
## Changes Summary
### Files Moved and Renamed
| Old Location | New Location | Change Type |
|-------------|--------------|-------------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
| `admin_auth.py` | `core/auth.py` | Moved |
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
| `admin_schemas.py` | `schemas/admin.py` | Moved |
| `routes.py` | `api/v1/routes.py` | Moved |
| `schemas.py` | `schemas/inference.py` | Moved |
| `services.py` | `services/inference.py` | Moved |
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
| `async_queue.py` | `workers/async_queue.py` | Moved |
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |
**Total**: 16 files reorganized
### Files Updated
**Source Files** (imports updated):
- `app.py` - Updated all imports to new structure
- `api/v1/admin/documents.py` - Updated schema/auth imports
- `api/v1/admin/annotations.py` - Updated schema/service imports
- `api/v1/admin/training.py` - Updated schema/auth imports
- `api/v1/routes.py` - Updated schema imports
- `api/v1/async_api/routes.py` - Updated schema imports
- `api/v1/batch/routes.py` - Updated service/worker imports
- `services/async_processing.py` - Updated worker/core imports
**Test Files** (all 16 updated, including `conftest.py`):
- `test_admin_annotations.py`
- `test_admin_auth.py`
- `test_admin_routes.py`
- `test_admin_routes_enhanced.py`
- `test_admin_training.py`
- `test_annotation_locks.py`
- `test_annotation_phase5.py`
- `test_async_queue.py`
- `test_async_routes.py`
- `test_async_service.py`
- `test_autolabel_with_locks.py`
- `test_batch_queue.py`
- `test_batch_upload_routes.py`
- `test_batch_upload_service.py`
- `test_training_phase4.py`
- `conftest.py`
---
## Import Examples
### Old Import Style (Before Refactoring)
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
from src.web.async_routes import create_async_router
from src.web.schemas import ErrorResponse
```
### New Import Style (After Refactoring)
```python
# Admin API
from src.web.api.v1.admin.documents import create_admin_router
from src.web.api.v1.admin import create_admin_router # Shorter alternative
# Schemas
from src.web.schemas.admin import DocumentItem
from src.web.schemas.common import ErrorResponse
# Core components
from src.web.core.auth import validate_admin_token
# Async API
from src.web.api.v1.async_api.routes import create_async_router
```
---
## Benefits Achieved
### 1. **Clear Separation of Concerns**
- **API Routes**: All in `api/v1/` by version and feature
- **Data Models**: All in `schemas/` by domain
- **Business Logic**: All in `services/`
- **Core Components**: Reusable utilities in `core/`
- **Background Jobs**: Task queues in `workers/`
### 2. **Better Scalability**
- Easy to add API v2 without touching v1
- Clear namespace for each module
- Reduced file sizes (no 800+ line files)
- Follows single responsibility principle
### 3. **Improved Maintainability**
- Find files by function, not by prefix
- Each module has one clear purpose
- Easier to onboard new developers
- Better IDE navigation
### 4. **Standards Compliance**
- Follows FastAPI best practices
- Matches Django/Flask project structures
- Standard Python package organization
- Industry-standard naming conventions
---
## Testing Results
**Before Refactoring**:
- 188 tests passing
- 23% code coverage
- Flat directory structure
**After Refactoring**:
- ✅ 188 tests passing (0 failures)
- ✅ 23% code coverage (maintained)
- ✅ Clean hierarchical structure
- ✅ All imports updated
- ✅ No backward compatibility shims needed
---
## Migration Statistics
| Metric | Count |
|--------|-------|
| Files moved | 16 |
| Directories created | 9 |
| Files updated (source) | 8 |
| Files updated (tests) | 16 |
| Import statements updated | ~150 |
| Lines of code changed | ~200 |
| Tests broken | 0 |
| Coverage lost | 0% |
---
## Code Diff Summary
```diff
Before:
src/web/
├── admin_routes.py (645 lines)
├── admin_annotation_routes.py (504 lines)
├── admin_training_routes.py (565 lines)
├── admin_auth.py (22 lines)
├── admin_schemas.py (262 lines)
... (15 more files at root level)
After:
src/web/
├── api/v1/
│ ├── admin/ (3 route files)
│ ├── async_api/ (1 route file)
│ └── batch/ (1 route file)
├── schemas/ (3 schema files)
├── services/ (4 service files)
├── core/ (3 core files)
└── workers/ (2 worker files)
```
---
## Next Steps (Optional)
### Phase 2: Documentation
- [ ] Update API documentation with new import paths
- [ ] Create migration guide for external developers
- [ ] Update CLAUDE.md with new structure
### Phase 3: Further Optimization
- [ ] Split large files (>400 lines) if needed
- [ ] Extract common utilities
- [ ] Add typing stubs
### Phase 4: Deprecation (Future)
- [ ] Add deprecation warnings if creating compatibility layer
- [ ] Remove old imports after grace period
- [ ] Update all documentation
---
## Rollback Instructions
If needed, rollback is simple:
```bash
git revert <commit-hash>
```
All changes are in version control, making rollback safe and easy.
---
## Conclusion
**Refactoring completed successfully**
**Zero breaking changes**
**All tests passing**
**Industry-standard structure achieved**
The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.

View File

@@ -0,0 +1,186 @@
# Web Directory Refactoring Plan
## Current Structure Issues
1. **Flat structure**: All files in one directory (20 Python files)
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
3. **Mixed concerns**: Routes, schemas, services, and workers in same directory
4. **Poor scalability**: Hard to navigate and maintain as project grows
## Proposed Structure (Best Practices)
```
src/web/
├── __init__.py # Main exports
├── app.py # FastAPI app factory
├── config.py # App configuration
├── dependencies.py # Global dependencies
├── api/ # API Routes Layer
│ ├── __init__.py
│ └── v1/ # API version 1
│ ├── __init__.py
│ ├── routes.py # Public API routes (inference)
│ ├── admin/ # Admin API routes
│ │ ├── __init__.py
│ │ ├── documents.py # admin_routes.py → documents.py
│ │ ├── annotations.py # admin_annotation_routes.py → annotations.py
│ │ ├── training.py # admin_training_routes.py → training.py
│ │ └── auth.py # admin_auth.py → auth.py (routes only)
│ ├── async_api/ # Async processing API
│ │ ├── __init__.py
│ │ └── routes.py # async_routes.py → routes.py
│ └── batch/ # Batch upload API
│ ├── __init__.py
│ └── routes.py # batch_upload_routes.py → routes.py
├── schemas/ # Pydantic Models
│ ├── __init__.py
│ ├── common.py # Shared schemas (ErrorResponse, etc.)
│ ├── inference.py # schemas.py → inference.py
│ ├── admin.py # admin_schemas.py → admin.py
│ ├── async_api.py # New: async API schemas
│ └── batch.py # New: batch upload schemas
├── services/ # Business Logic Layer
│ ├── __init__.py
│ ├── inference.py # services.py → inference.py
│ ├── autolabel.py # admin_autolabel.py → autolabel.py
│ ├── async_processing.py # async_service.py → async_processing.py
│ └── batch_upload.py # batch_upload_service.py → batch_upload.py
├── core/ # Core Components
│ ├── __init__.py
│ ├── auth.py # admin_auth.py → auth.py (logic only)
│ ├── rate_limiter.py # rate_limiter.py → rate_limiter.py
│ └── scheduler.py # admin_scheduler.py → scheduler.py
└── workers/ # Background Task Queues
├── __init__.py
├── async_queue.py # async_queue.py → async_queue.py
└── batch_queue.py # batch_queue.py → batch_queue.py
```
## File Mapping
### Current → New Location
| Current File | New Location | Purpose |
|--------------|--------------|---------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
| `routes.py` | `api/v1/routes.py` | Public inference API |
| `schemas.py` | `schemas/inference.py` | Inference models |
| `services.py` | `services/inference.py` | Inference service |
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
| `async_service.py` | `services/async_processing.py` | Async processing service |
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
| `config.py` | `config.py` | Keep as-is |
| `dependencies.py` | `dependencies.py` | Keep as-is |
| `app.py` | `app.py` | Keep as-is (update imports) |
## Benefits
### 1. Clear Separation of Concerns
- **Routes**: API endpoint definitions
- **Schemas**: Data validation models
- **Services**: Business logic
- **Core**: Reusable components
- **Workers**: Background processing
### 2. Better Scalability
- Easy to add new API versions (`v2/`)
- Clear namespace for each domain
- Reduced file size (no 800+ line files)
### 3. Improved Maintainability
- Find files by function, not by prefix
- Each module has single responsibility
- Easier to write focused tests
### 4. Standard Python Patterns
- Package-based organization
- Follows FastAPI best practices
- Similar to Django/Flask structures
## Implementation Steps
### Phase 1: Create New Structure (No Breaking Changes)
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
2. Copy files to new locations (don't delete originals yet)
3. Update imports in new files
4. Add `__init__.py` with proper exports (a minimal sketch follows)
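For step 4, the admin package's `__init__.py` could look like this (a sketch; only `create_admin_router` is a name confirmed elsewhere in this plan):
```python
# src/web/api/v1/admin/__init__.py (sketch)
from .documents import create_admin_router

__all__ = ["create_admin_router"]
```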
### Phase 2: Update Tests
5. Update test imports to use new structure
6. Run tests to verify nothing breaks
7. Fix any import issues
### Phase 3: Update Main App
8. Update `app.py` to import from new locations
9. Run full test suite
10. Verify all endpoints work
### Phase 4: Cleanup
11. Delete old files
12. Update documentation
13. Final test run
## Migration Priority
**High Priority** (Most used):
- Routes and schemas (user-facing APIs)
- Services (core business logic)
**Medium Priority**:
- Core components (auth, rate limiter)
- Workers (background tasks)
**Low Priority**:
- Config and dependencies (already well-located)
## Backwards Compatibility
During migration, maintain backwards compatibility:
```python
# src/web/__init__.py
# Old imports still work
from src.web.api.v1.admin.documents import router as admin_router
from src.web.schemas.admin import AdminDocument
# Keep old names for compatibility (temporary)
admin_routes = admin_router # Deprecated alias
```
## Testing Strategy
1. **Unit Tests**: Test each module independently
2. **Integration Tests**: Test API endpoints still work
3. **Import Tests**: Verify all old imports still work (see the smoke-test sketch after this list)
4. **Coverage**: Maintain current 23% coverage minimum
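A small import smoke test keeps both migration phases honest (a sketch; the test path and module list are illustrative):
```python
# tests/web/test_imports.py (sketch)
import importlib

import pytest

MODULES = [
    # New locations (from the mapping table above)
    "src.web.api.v1.admin.documents",
    "src.web.schemas.admin",
    "src.web.core.auth",
    "src.web.workers.async_queue",
    # Add the deprecated paths here while the compatibility shims exist,
    # e.g. "src.web.admin_routes"
]

@pytest.mark.parametrize("module", MODULES)
def test_module_imports(module):
    importlib.import_module(module)
```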
## Rollback Plan
If issues arise:
1. Keep old files until fully migrated
2. Git allows easy revert
3. Tests catch breaking changes early
---
## Next Steps
Options for proceeding:
1. **Start Phase 1**: Create the new directory structure and move files
2. **Create a migration script**: Automate the file moves and import updates
3. **Focus on a specific area**: Start with the admin API or the async API first

View File

@@ -0,0 +1,218 @@
# Web Directory Refactoring - Current Status
## ✅ Completed Steps
### 1. Directory Structure Created
```
src/web/
├── api/
│ ├── v1/
│ │ ├── admin/ (documents.py, annotations.py, training.py)
│ │ ├── async_api/ (routes.py)
│ │ ├── batch/ (routes.py)
│ │ └── routes.py (public inference API)
├── schemas/
│ ├── admin.py (admin schemas)
│ ├── inference.py (inference + async schemas)
│ └── common.py (ErrorResponse)
├── services/
│ ├── autolabel.py
│ ├── async_processing.py
│ ├── batch_upload.py
│ └── inference.py
├── core/
│ ├── auth.py
│ ├── rate_limiter.py
│ └── scheduler.py
└── workers/
├── async_queue.py
└── batch_queue.py
```
### 2. Files Copied and Imports Updated
#### Admin API (✅ Complete)
- [x] `admin_routes.py``api/v1/admin/documents.py` (imports updated)
- [x] `admin_annotation_routes.py``api/v1/admin/annotations.py` (imports updated)
- [x] `admin_training_routes.py``api/v1/admin/training.py` (imports updated)
- [x] `api/v1/admin/__init__.py` created with exports
#### Public & Async API (🟡 Mostly Complete)
- [x] `routes.py``api/v1/routes.py` (imports updated)
- [x] `async_routes.py``api/v1/async_api/routes.py` (imports updated)
- [x] `batch_upload_routes.py``api/v1/batch/routes.py` (copied, imports pending)
#### Schemas (✅ Complete)
- [x] `admin_schemas.py``schemas/admin.py`
- [x] `schemas.py``schemas/inference.py`
- [x] `schemas/common.py` created
- [x] `schemas/__init__.py` created with exports
#### Services (✅ Complete)
- [x] `admin_autolabel.py``services/autolabel.py`
- [x] `async_service.py``services/async_processing.py`
- [x] `batch_upload_service.py``services/batch_upload.py`
- [x] `services.py``services/inference.py`
- [x] `services/__init__.py` created
#### Core Components (✅ Complete)
- [x] `admin_auth.py``core/auth.py`
- [x] `rate_limiter.py``core/rate_limiter.py`
- [x] `admin_scheduler.py``core/scheduler.py`
- [x] `core/__init__.py` created
#### Workers (✅ Complete)
- [x] `async_queue.py``workers/async_queue.py`
- [x] `batch_queue.py``workers/batch_queue.py`
- [x] `workers/__init__.py` created
#### Main App (✅ Complete)
- [x] `app.py` imports updated to use new structure
---
## ⏳ Remaining Work
### 1. Update Remaining File Imports (HIGH PRIORITY)
Files that need import updates:
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
- [ ] `services/autolabel.py` - may need import updates if it references old paths
- [ ] `services/async_processing.py` - check for old import references
- [ ] `services/batch_upload.py` - check for old import references
- [ ] `services/inference.py` - check for old import references
### 2. Update ALL Test Files (CRITICAL)
Test files need to import from new locations. Pattern:
**Old:**
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
```
**New:**
```python
from src.web.api.v1.admin import create_admin_router
from src.web.schemas.admin import DocumentItem
from src.web.core.auth import validate_admin_token
```
Test files to update:
- [ ] `tests/web/test_admin_annotations.py`
- [ ] `tests/web/test_admin_auth.py`
- [ ] `tests/web/test_admin_routes.py`
- [ ] `tests/web/test_admin_routes_enhanced.py`
- [ ] `tests/web/test_admin_training.py`
- [ ] `tests/web/test_annotation_locks.py`
- [ ] `tests/web/test_annotation_phase5.py`
- [ ] `tests/web/test_async_queue.py`
- [ ] `tests/web/test_async_routes.py`
- [ ] `tests/web/test_async_service.py`
- [ ] `tests/web/test_autolabel_with_locks.py`
- [ ] `tests/web/test_batch_queue.py`
- [ ] `tests/web/test_batch_upload_routes.py`
- [ ] `tests/web/test_batch_upload_service.py`
- [ ] `tests/web/test_rate_limiter.py`
- [ ] `tests/web/test_training_phase4.py`
### 3. Create Backward Compatibility Layer (OPTIONAL)
Keep old imports working temporarily:
```python
# src/web/admin_routes.py (temporary compatibility shim)
\"\"\"
DEPRECATED: Use src.web.api.v1.admin.documents instead.
This file will be removed in next version.
\"\"\"
import warnings
from src.web.api.v1.admin.documents import *
warnings.warn(
"Importing from src.web.admin_routes is deprecated. "
"Use src.web.api.v1.admin.documents instead.",
DeprecationWarning,
stacklevel=2
)
```
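To confirm the shim actually fires, a quick check (a sketch, assuming the shim above is in place; note the warning is emitted only on the first import in a process):
```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import src.web.admin_routes  # noqa: F401  (deprecated path)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```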
### 4. Verify and Test
1. Run tests:
```bash
pytest tests/web/ -v
```
2. Check for any import errors:
```bash
python -c "from src.web.app import create_app; create_app()"
```
3. Start server and test endpoints:
```bash
python run_server.py
```
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
Old files to remove:
- `src/web/admin_*.py` (7 files)
- `src/web/async_*.py` (3 files)
- `src/web/batch_*.py` (3 files)
- `src/web/routes.py`
- `src/web/services.py`
- `src/web/schemas.py`
- `src/web/rate_limiter.py`
Keep these files (don't remove):
- `src/web/__init__.py`
- `src/web/app.py`
- `src/web/config.py`
- `src/web/dependencies.py`
---
## 🎯 Next Immediate Steps
1. **Update batch/routes.py imports** - Quick fix for remaining API route
2. **Update test file imports** - Critical for verification
3. **Run test suite** - Verify nothing broke
4. **Fix any import errors** - Address failures
5. **Remove old files** - Clean up after tests pass
---
## 📊 Migration Impact Summary
| Category | Files Moved | Imports Updated | Status |
|----------|-------------|-----------------|--------|
| API Routes | 7 | 5/7 | 🟡 In Progress |
| Schemas | 3 | 3/3 | ✅ Complete |
| Services | 4 | 0/4 | ⚠️ Pending |
| Core | 3 | 3/3 | ✅ Complete |
| Workers | 2 | 2/2 | ✅ Complete |
| Tests | 0 | 0/16 | ❌ Not Started |
**Overall Progress: 65%**
---
## 🚀 Benefits After Migration
1. **Better Organization**: Clear separation by function
2. **Easier Navigation**: Find files by purpose, not prefix
3. **Scalability**: Easy to add new API versions
4. **Standard Structure**: Follows FastAPI best practices
5. **Maintainability**: Each module has single responsibility
---
## 📝 Notes
- All original files are still in place (no data loss risk)
- New structure is operational but needs import updates
- Backward compatibility can be added if needed
- Tests will validate the migration success

5
frontend/.env.example Normal file
View File

@@ -0,0 +1,5 @@
# Backend API URL
VITE_API_URL=http://localhost:8000
# WebSocket URL (for future real-time updates)
VITE_WS_URL=ws://localhost:8000/ws

24
frontend/.gitignore vendored Normal file
View File

@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

20
frontend/README.md Normal file
View File

@@ -0,0 +1,20 @@
<div align="center">
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
</div>
# Run and deploy your AI Studio app
This contains everything you need to run your app locally.
View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4
## Run Locally
**Prerequisites:** Node.js
1. Install dependencies:
`npm install`
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
3. Run the app:
`npm run dev`

View File

@@ -0,0 +1,240 @@
# Frontend Refactoring Plan
## Current Structure Issues
1. **Flat component organization** - All components in one directory
2. **Mock data only** - No real API integration
3. **No state management** - Props drilling everywhere
4. **CDN dependencies** - Should use npm packages
5. **Manual routing** - Using useState instead of react-router
6. **No TypeScript integration with backend** - Types don't match API schemas
## Recommended Structure
```
frontend/
├── public/
│ └── favicon.ico
├── src/
│ ├── api/ # API Layer
│ │ ├── client.ts # Axios instance + interceptors
│ │ ├── types.ts # API request/response types
│ │ └── endpoints/
│ │ ├── documents.ts # GET /api/v1/admin/documents
│ │ ├── annotations.ts # GET/POST /api/v1/admin/documents/{id}/annotations
│ │ ├── training.ts # GET/POST /api/v1/admin/training/*
│ │ ├── inference.ts # POST /api/v1/infer
│ │ └── async.ts # POST /api/v1/async/submit
│ │
│ ├── components/
│ │ ├── common/ # Reusable components
│ │ │ ├── Badge.tsx
│ │ │ ├── Button.tsx
│ │ │ ├── Input.tsx
│ │ │ ├── Modal.tsx
│ │ │ ├── Table.tsx
│ │ │ ├── ProgressBar.tsx
│ │ │ └── StatusBadge.tsx
│ │ │
│ │ ├── layout/ # Layout components
│ │ │ ├── TopNav.tsx
│ │ │ ├── Sidebar.tsx
│ │ │ └── PageHeader.tsx
│ │ │
│ │ ├── documents/ # Document-specific components
│ │ │ ├── DocumentTable.tsx
│ │ │ ├── DocumentFilters.tsx
│ │ │ ├── DocumentRow.tsx
│ │ │ ├── UploadModal.tsx
│ │ │ └── BatchUploadModal.tsx
│ │ │
│ │ ├── annotations/ # Annotation components
│ │ │ ├── AnnotationCanvas.tsx
│ │ │ ├── AnnotationBox.tsx
│ │ │ ├── AnnotationTable.tsx
│ │ │ ├── FieldEditor.tsx
│ │ │ └── VerificationPanel.tsx
│ │ │
│ │ └── training/ # Training components
│ │ ├── DocumentSelector.tsx
│ │ ├── TrainingConfig.tsx
│ │ ├── TrainingJobList.tsx
│ │ ├── ModelCard.tsx
│ │ └── MetricsChart.tsx
│ │
│ ├── pages/ # Page-level components
│ │ ├── DocumentsPage.tsx # Was Dashboard.tsx
│ │ ├── DocumentDetailPage.tsx # Was DocumentDetail.tsx
│ │ ├── TrainingPage.tsx # Was Training.tsx
│ │ ├── ModelsPage.tsx # Was Models.tsx
│ │ └── InferencePage.tsx # New: Test inference
│ │
│ ├── hooks/ # Custom React Hooks
│ │ ├── useDocuments.ts # Document CRUD + listing
│ │ ├── useAnnotations.ts # Annotation management
│ │ ├── useTraining.ts # Training jobs
│ │ ├── usePolling.ts # Auto-refresh for async jobs
│ │ └── useDebounce.ts # Debounce search inputs
│ │
│ ├── store/ # State Management (Zustand)
│ │ ├── documentsStore.ts
│ │ ├── annotationsStore.ts
│ │ ├── trainingStore.ts
│ │ └── uiStore.ts
│ │
│ ├── types/ # TypeScript Types
│ │ ├── index.ts
│ │ ├── document.ts
│ │ ├── annotation.ts
│ │ ├── training.ts
│ │ └── api.ts
│ │
│ ├── utils/ # Utility Functions
│ │ ├── formatters.ts # Date, currency, etc.
│ │ ├── validators.ts # Form validation
│ │ └── constants.ts # Field definitions, statuses
│ │
│ ├── styles/
│ │ └── index.css # Tailwind entry
│ │
│ ├── App.tsx
│ ├── main.tsx
│ └── router.tsx # React Router config
├── .env.example
├── package.json
├── tsconfig.json
├── vite.config.ts
├── tailwind.config.js
├── postcss.config.js
└── index.html
```
## Migration Steps
### Phase 1: Setup Infrastructure
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
- [ ] Setup local Tailwind (remove CDN)
- [ ] Create API client with interceptors
- [ ] Add environment variables (.env.local with VITE_API_URL)
### Phase 2: Create API Layer
- [ ] Create `src/api/client.ts` with axios instance
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
- [ ] Create `src/api/endpoints/annotations.ts`
- [ ] Create `src/api/endpoints/training.ts`
- [ ] Add types matching backend schemas
### Phase 3: Reorganize Components
- [ ] Move existing components to new structure
- [ ] Split large components (Dashboard → DocumentTable + DocumentFilters + DocumentRow)
- [ ] Extract reusable components (Badge, Button already done)
- [ ] Create layout components (TopNav, Sidebar)
### Phase 4: Add Routing
- [ ] Install react-router-dom
- [ ] Create router.tsx with routes
- [ ] Update App.tsx to use RouterProvider
- [ ] Add navigation links
### Phase 5: State Management
- [ ] Create custom hooks (useDocuments, useAnnotations)
- [ ] Use @tanstack/react-query for server state
- [ ] Add Zustand stores for UI state
- [ ] Replace mock data with API calls
### Phase 6: Backend Integration
- [ ] Update CORS settings in backend
- [ ] Test all API endpoints
- [ ] Add error handling
- [ ] Add loading states
## Dependencies to Add
```json
{
"dependencies": {
"react-router-dom": "^6.22.0",
"axios": "^1.6.7",
"zustand": "^4.5.0",
"@tanstack/react-query": "^5.20.0",
"date-fns": "^3.3.0",
"clsx": "^2.1.0"
},
"devDependencies": {
"tailwindcss": "^3.4.1",
"autoprefixer": "^10.4.17",
"postcss": "^8.4.35"
}
}
```
## Configuration Files to Create
### tailwind.config.js
```javascript
export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
theme: {
extend: {
colors: {
warm: {
bg: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
border: '#E6E4E1',
divider: '#D8D6D2',
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
state: {
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
info: '#3A3A3A',
}
}
}
}
}
}
```
### .env.example
```bash
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```
## Type Generation from Backend
Consider generating TypeScript types from Python Pydantic schemas:
- Option 1: Use `datamodel-code-generator` to convert schemas
- Option 2: Manually maintain types in `src/types/api.ts`
- Option 3: Use OpenAPI spec + openapi-typescript-codegen (the spec can be exported as sketched below)
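For Option 3, FastAPI can emit the spec directly; a minimal export sketch (the script path is an assumption, `create_app` is the factory referenced in the backend refactoring notes):
```python
# scripts/export_openapi.py (sketch)
import json

from src.web.app import create_app

app = create_app()
with open("openapi.json", "w", encoding="utf-8") as f:
    json.dump(app.openapi(), f, indent=2)
```
The resulting `openapi.json` can then be fed to whichever TypeScript generator Option 3 settles on.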
## Testing Strategy
- Unit tests: Vitest for components
- Integration tests: React Testing Library
- E2E tests: Playwright (matching backend)
## Performance Considerations
- Code splitting by route
- Lazy load heavy components (AnnotationCanvas)
- Optimize re-renders with React.memo
- Use virtual scrolling for large tables
- Image lazy loading for document previews
## Accessibility
- Proper ARIA labels
- Keyboard navigation
- Focus management
- Color contrast compliance (already done with Warm Graphite theme)

256
frontend/SETUP.md Normal file
View File

@@ -0,0 +1,256 @@
# Frontend Setup Guide
## Quick Start
### 1. Install Dependencies
```bash
cd frontend
npm install
```
### 2. Configure Environment
Copy `.env.example` to `.env.local` and update if needed:
```bash
cp .env.example .env.local
```
Default configuration:
```
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```
### 3. Start Backend API
Make sure the backend is running first:
```bash
# From project root
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
```
Backend will be available at: http://localhost:8000
### 4. Start Frontend Dev Server
```bash
cd frontend
npm run dev
```
Frontend will be available at: http://localhost:3000
## Project Structure
```
frontend/
├── src/
│ ├── api/ # API client layer
│ │ ├── client.ts # Axios instance with interceptors
│ │ ├── types.ts # API type definitions
│ │ └── endpoints/
│ │ ├── documents.ts # Document API calls
│ │ ├── annotations.ts # Annotation API calls
│ │ └── training.ts # Training API calls
│ │
│ ├── components/ # React components
│ │ └── Dashboard.tsx # Updated with real API integration
│ │
│ ├── hooks/ # Custom React Hooks
│ │ ├── useDocuments.ts
│ │ ├── useDocumentDetail.ts
│ │ ├── useAnnotations.ts
│ │ └── useTraining.ts
│ │
│ ├── styles/
│ │ └── index.css # Tailwind CSS entry
│ │
│ ├── App.tsx
│ └── main.tsx # App entry point with QueryClient
├── components/ # Legacy components (to be migrated)
│ ├── Badge.tsx
│ ├── Button.tsx
│ ├── Layout.tsx
│ ├── DocumentDetail.tsx
│ ├── Training.tsx
│ ├── Models.tsx
│ └── UploadModal.tsx
├── tailwind.config.js # Tailwind configuration
├── postcss.config.js
├── vite.config.ts
├── package.json
└── index.html
```
## Key Technologies
- **React 19** - UI framework
- **TypeScript** - Type safety
- **Vite** - Build tool
- **Tailwind CSS** - Styling (Warm Graphite theme)
- **Axios** - HTTP client
- **@tanstack/react-query** - Server state management
- **lucide-react** - Icon library
## API Integration
### Authentication
The app stores admin token in localStorage:
```typescript
localStorage.setItem('admin_token', 'your-token')
```
All API requests automatically include the `X-Admin-Token` header.
### Available Hooks
#### useDocuments
```typescript
const {
documents,
total,
isLoading,
uploadDocument,
deleteDocument,
triggerAutoLabel,
} = useDocuments({ status: 'labeled', limit: 20 })
```
#### useDocumentDetail
```typescript
const { document, annotations, isLoading } = useDocumentDetail(documentId)
```
#### useAnnotations
```typescript
const {
createAnnotation,
updateAnnotation,
deleteAnnotation,
verifyAnnotation,
overrideAnnotation,
} = useAnnotations(documentId)
```
#### useTraining
```typescript
const {
models,
isLoadingModels,
startTraining,
downloadModel,
} = useTraining()
```
## Features Implemented
### Phase 1 (Completed)
- ✅ API client with axios interceptors
- ✅ Type-safe API endpoints
- ✅ React Query for server state
- ✅ Custom hooks for all APIs
- ✅ Dashboard with real data
- ✅ Local Tailwind CSS
- ✅ Environment configuration
- ✅ CORS configured in backend
### Phase 2 (TODO)
- [ ] Update DocumentDetail to use useDocumentDetail
- [ ] Update Training page to use useTraining hooks
- [ ] Update Models page with real data
- [ ] Add UploadModal integration with API
- [ ] Add react-router for proper routing
- [ ] Add error boundary
- [ ] Add loading states
- [ ] Add toast notifications
### Phase 3 (TODO)
- [ ] Annotation canvas with real data
- [ ] Batch upload functionality
- [ ] Auto-label progress polling
- [ ] Training job monitoring
- [ ] Model download functionality
- [ ] Search and filtering
- [ ] Pagination
## Development Tips
### Hot Module Replacement
Vite supports HMR. Changes will reflect immediately without page reload.
### API Debugging
Check browser console for API requests:
- Network tab shows all requests/responses
- Axios interceptors log errors automatically
### Type Safety
TypeScript types in `src/api/types.ts` match backend Pydantic schemas.
To regenerate types from backend:
```bash
# TODO: Add type generation script
```
### Backend API Documentation
Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).
## Troubleshooting
### CORS Errors
If you see CORS errors:
1. Check backend is running at http://localhost:8000
2. Verify CORS settings in `src/web/app.py`
3. Check `.env.local` has correct `VITE_API_URL`
### Module Not Found
If imports fail:
```bash
rm -rf node_modules package-lock.json
npm install
```
### Types Not Matching
If API responses don't match types:
1. Check backend version is up-to-date
2. Verify types in `src/api/types.ts`
3. Check API response in Network tab
## Next Steps
1. Run `npm install` to install dependencies
2. Start backend server
3. Run `npm run dev` to start frontend
4. Open http://localhost:3000
5. Create an admin token via backend API
6. Store token in localStorage via browser console:
```javascript
localStorage.setItem('admin_token', 'your-token-here')
```
7. Refresh page to see authenticated API calls
## Production Build
```bash
npm run build
npm run preview # Preview production build
```
Build output will be in `dist/` directory.

15
frontend/index.html Normal file
View File

@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Graphite Annotator - Invoice Field Extraction</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

5
frontend/metadata.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Graphite Annotator",
"description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
"requestFramePermissions": []
}

3510
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

32
frontend/package.json Normal file
View File

@@ -0,0 +1,32 @@
{
"name": "graphite-annotator",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^19.2.3",
"react-dom": "^19.2.3",
"lucide-react": "^0.563.0",
"recharts": "^3.7.0",
"axios": "^1.6.7",
"react-router-dom": "^6.22.0",
"zustand": "^4.5.0",
"@tanstack/react-query": "^5.20.0",
"date-fns": "^3.3.0",
"clsx": "^2.1.0"
},
"devDependencies": {
"@types/node": "^22.14.0",
"@vitejs/plugin-react": "^5.0.0",
"typescript": "~5.8.2",
"vite": "^6.2.0",
"tailwindcss": "^3.4.1",
"autoprefixer": "^10.4.17",
"postcss": "^8.4.35"
}
}

View File

@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

73
frontend/src/App.tsx Normal file
View File

@@ -0,0 +1,73 @@
import React, { useState, useEffect } from 'react'
import { Layout } from './components/Layout'
import { DashboardOverview } from './components/DashboardOverview'
import { Dashboard } from './components/Dashboard'
import { DocumentDetail } from './components/DocumentDetail'
import { Training } from './components/Training'
import { Models } from './components/Models'
import { Login } from './components/Login'
import { InferenceDemo } from './components/InferenceDemo'
const App: React.FC = () => {
const [currentView, setCurrentView] = useState('dashboard')
const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
const [isAuthenticated, setIsAuthenticated] = useState(false)
useEffect(() => {
const token = localStorage.getItem('admin_token')
setIsAuthenticated(!!token)
}, [])
const handleNavigate = (view: string, docId?: string) => {
setCurrentView(view)
if (docId) {
setSelectedDocId(docId)
}
}
const handleLogin = (token: string) => {
setIsAuthenticated(true)
}
const handleLogout = () => {
localStorage.removeItem('admin_token')
setIsAuthenticated(false)
setCurrentView('documents')
}
if (!isAuthenticated) {
return <Login onLogin={handleLogin} />
}
const renderContent = () => {
switch (currentView) {
case 'dashboard':
return <DashboardOverview onNavigate={handleNavigate} />
case 'documents':
return <Dashboard onNavigate={handleNavigate} />
case 'detail':
return (
<DocumentDetail
docId={selectedDocId || '1'}
onBack={() => setCurrentView('documents')}
/>
)
case 'demo':
return <InferenceDemo />
case 'training':
return <Training />
case 'models':
return <Models />
default:
return <DashboardOverview onNavigate={handleNavigate} />
}
}
return (
<Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
{renderContent()}
</Layout>
)
}
export default App

View File

@@ -0,0 +1,41 @@
import axios, { AxiosInstance, AxiosError } from 'axios'
const apiClient: AxiosInstance = axios.create({
baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
headers: {
'Content-Type': 'application/json',
},
timeout: 30000,
})
apiClient.interceptors.request.use(
(config) => {
const token = localStorage.getItem('admin_token')
if (token) {
config.headers['X-Admin-Token'] = token
}
return config
},
(error) => {
return Promise.reject(error)
}
)
apiClient.interceptors.response.use(
(response) => response,
(error: AxiosError) => {
if (error.response?.status === 401) {
console.warn('Authentication required. Please set admin_token in localStorage.')
// Don't redirect to avoid infinite loop
// User should manually set: localStorage.setItem('admin_token', 'your-token')
}
if (error.response?.status === 429) {
console.error('Rate limit exceeded')
}
return Promise.reject(error)
}
)
export default apiClient

View File

@@ -0,0 +1,66 @@
import apiClient from '../client'
import type {
AnnotationItem,
CreateAnnotationRequest,
AnnotationOverrideRequest,
} from '../types'
export const annotationsApi = {
list: async (documentId: string): Promise<AnnotationItem[]> => {
const { data } = await apiClient.get(
`/api/v1/admin/documents/${documentId}/annotations`
)
return data.annotations
},
create: async (
documentId: string,
annotation: CreateAnnotationRequest
): Promise<AnnotationItem> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/annotations`,
annotation
)
return data
},
update: async (
documentId: string,
annotationId: string,
updates: Partial<CreateAnnotationRequest>
): Promise<AnnotationItem> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
updates
)
return data
},
delete: async (documentId: string, annotationId: string): Promise<void> => {
await apiClient.delete(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
)
},
verify: async (
documentId: string,
annotationId: string
): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
)
return data
},
override: async (
documentId: string,
annotationId: string,
overrideData: AnnotationOverrideRequest
): Promise<{ annotation_id: string; source: string; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
overrideData
)
return data
},
}

View File

@@ -0,0 +1,80 @@
import apiClient from '../client'
import type {
DocumentListResponse,
DocumentDetailResponse,
DocumentItem,
UploadDocumentResponse,
} from '../types'
export const documentsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<DocumentListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/documents', { params })
return data
},
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
return data
},
upload: async (file: File): Promise<UploadDocumentResponse> => {
const formData = new FormData()
formData.append('file', file)
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
})
return data
},
batchUpload: async (
files: File[],
csvFile?: File
): Promise<{ batch_id: string; message: string; documents_created: number }> => {
const formData = new FormData()
files.forEach((file) => {
formData.append('files', file)
})
if (csvFile) {
formData.append('csv_file', csvFile)
}
const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
})
return data
},
delete: async (documentId: string): Promise<void> => {
await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
},
updateStatus: async (
documentId: string,
status: string
): Promise<DocumentItem> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/status`,
null,
{ params: { status } }
)
return data
},
triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/auto-label`
)
return data
},
}

View File

@@ -0,0 +1,4 @@
export { documentsApi } from './documents'
export { annotationsApi } from './annotations'
export { trainingApi } from './training'
export { inferenceApi } from './inference'

View File

@@ -0,0 +1,16 @@
import apiClient from '../client'
import type { InferenceResponse } from '../types'
export const inferenceApi = {
processDocument: async (file: File): Promise<InferenceResponse> => {
const formData = new FormData()
formData.append('file', file)
const { data } = await apiClient.post('/api/v1/infer', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
})
return data
},
}

View File

@@ -0,0 +1,74 @@
import apiClient from '../client'
import type { TrainingModelsResponse, DocumentListResponse } from '../types'
export const trainingApi = {
getDocumentsForTraining: async (params?: {
has_annotations?: boolean
min_annotation_count?: number
exclude_used_in_training?: boolean
limit?: number
offset?: number
}): Promise<DocumentListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/documents', {
params,
})
return data
},
getModels: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<TrainingModelsResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models', {
params,
})
return data
},
getTaskDetail: async (taskId: string) => {
const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
return data
},
startTraining: async (config: {
name: string
description?: string
document_ids: string[]
epochs?: number
batch_size?: number
model_base?: string
}) => {
// Convert frontend config to backend TrainingTaskCreate format
const taskRequest = {
name: config.name,
task_type: 'yolo',
description: config.description,
config: {
document_ids: config.document_ids,
epochs: config.epochs,
batch_size: config.batch_size,
base_model: config.model_base,
},
}
const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
return data
},
cancelTask: async (taskId: string) => {
const { data } = await apiClient.post(
`/api/v1/admin/training/tasks/${taskId}/cancel`
)
return data
},
downloadModel: async (taskId: string): Promise<Blob> => {
const { data } = await apiClient.get(
`/api/v1/admin/training/models/${taskId}/download`,
{
responseType: 'blob',
}
)
return data
},
}

173
frontend/src/api/types.ts Normal file
View File

@@ -0,0 +1,173 @@
export interface DocumentItem {
document_id: string
filename: string
file_size: number
content_type: string
page_count: number
status: 'pending' | 'labeled' | 'verified' | 'exported'
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
created_at: string
updated_at: string
annotation_count?: number
annotation_sources?: {
manual: number
auto: number
verified: number
}
}
export interface DocumentListResponse {
documents: DocumentItem[]
total: number
limit: number
offset: number
}
export interface AnnotationItem {
annotation_id: string
page_number: number
class_id: number
class_name: string
bbox: {
x: number
y: number
width: number
height: number
}
normalized_bbox: {
x_center: number
y_center: number
width: number
height: number
}
text_value: string | null
confidence: number | null
source: 'manual' | 'auto'
created_at: string
}
export interface DocumentDetailResponse {
document_id: string
filename: string
file_size: number
content_type: string
page_count: number
status: 'pending' | 'labeled' | 'verified' | 'exported'
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
batch_id: string | null
csv_field_values: Record<string, string> | null
can_annotate: boolean
annotation_lock_until: string | null
annotations: AnnotationItem[]
image_urls: string[]
training_history: Array<{
task_id: string
name: string
trained_at: string
model_metrics: {
mAP: number | null
precision: number | null
recall: number | null
} | null
}>
created_at: string
updated_at: string
}
export interface TrainingTask {
task_id: string
admin_token: string
name: string
description: string | null
status: 'pending' | 'running' | 'completed' | 'failed'
task_type: string
config: Record<string, unknown>
started_at: string | null
completed_at: string | null
error_message: string | null
result_metrics: Record<string, unknown>
model_path: string | null
document_count: number
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
created_at: string
updated_at: string
}
export interface TrainingModelsResponse {
models: TrainingTask[]
total: number
limit: number
offset: number
}
export interface ErrorResponse {
detail: string
}
export interface UploadDocumentResponse {
document_id: string
filename: string
status: string
message: string
}
export interface CreateAnnotationRequest {
page_number: number
class_id: number
bbox: {
x: number
y: number
width: number
height: number
}
text_value?: string
}
export interface AnnotationOverrideRequest {
text_value?: string
bbox?: {
x: number
y: number
width: number
height: number
}
class_id?: number
class_name?: string
reason?: string
}
export interface CrossValidationResult {
is_valid: boolean
payment_line_ocr: string | null
payment_line_amount: string | null
payment_line_account: string | null
payment_line_account_type: 'bankgiro' | 'plusgiro' | null
ocr_match: boolean | null
amount_match: boolean | null
bankgiro_match: boolean | null
plusgiro_match: boolean | null
details: string[]
}
export interface InferenceResult {
document_id: string
document_type: string
success: boolean
fields: Record<string, string>
confidence: Record<string, number>
cross_validation: CrossValidationResult | null
processing_time_ms: number
visualization_url: string | null
errors: string[]
fallback_used: boolean
}
export interface InferenceResponse {
result: InferenceResult
}

View File

@@ -0,0 +1,39 @@
import React from 'react';
import { DocumentStatus } from '../types';
import { Check } from 'lucide-react';
interface BadgeProps {
status: DocumentStatus | 'Exported';
}
export const Badge: React.FC<BadgeProps> = ({ status }) => {
if (status === 'Exported') {
return (
<span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
<Check size={12} strokeWidth={3} />
Exported
</span>
);
}
const styles = {
[DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
[DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
[DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
[DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
};
const icons = {
[DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
[DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
[DocumentStatus.PENDING]: null,
[DocumentStatus.LABELED]: null,
}
return (
<span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
{icons[status]}
{status}
</span>
);
};

View File

@@ -0,0 +1,38 @@
import React from 'react';
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
variant?: 'primary' | 'secondary' | 'outline' | 'text';
size?: 'sm' | 'md' | 'lg';
}
export const Button: React.FC<ButtonProps> = ({
variant = 'primary',
size = 'md',
className = '',
children,
...props
}) => {
const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-[0.98] disabled:opacity-50 disabled:pointer-events-none";
const variants = {
primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
};
const sizes = {
sm: "h-8 px-3 text-xs",
md: "h-10 px-4 text-sm",
lg: "h-12 px-6 text-base",
};
return (
<button
className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
{...props}
>
{children}
</button>
);
};

View File

@@ -0,0 +1,266 @@
import React, { useState } from 'react'
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
import { Badge } from './Badge'
import { Button } from './Button'
import { UploadModal } from './UploadModal'
import { useDocuments } from '../hooks/useDocuments'
import type { DocumentItem } from '../api/types'
interface DashboardProps {
onNavigate: (view: string, docId?: string) => void
}
const getStatusForBadge = (status: string): string => {
const statusMap: Record<string, string> = {
pending: 'Pending',
labeled: 'Labeled',
verified: 'Verified',
exported: 'Exported',
}
return statusMap[status] || status
}
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
if (doc.auto_label_status === 'running') {
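// Placeholder value; real auto-label progress polling is a Phase 3 TODO (see SETUP.md)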
return 45
}
if (doc.auto_label_status === 'completed') {
return 100
}
return undefined
}
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
const [isUploadOpen, setIsUploadOpen] = useState(false)
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
const [statusFilter, setStatusFilter] = useState<string>('')
const [limit] = useState(20)
const [offset] = useState(0)
const { documents, total, isLoading, error, refetch } = useDocuments({
status: statusFilter || undefined,
limit,
offset,
})
const toggleSelection = (id: string) => {
const newSet = new Set(selectedDocs)
if (newSet.has(id)) {
newSet.delete(id)
} else {
newSet.add(id)
}
setSelectedDocs(newSet)
}
if (error) {
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
Error loading documents. Please check your connection to the backend API.
<button
onClick={() => refetch()}
className="ml-4 underline hover:no-underline"
>
Retry
</button>
</div>
</div>
)
}
return (
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
<div className="flex items-center justify-between mb-8">
<div>
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
Documents
</h1>
<p className="text-sm text-warm-text-muted mt-1">
{isLoading ? 'Loading...' : `${total} documents total`}
</p>
</div>
<div className="flex gap-3">
<Button variant="secondary" disabled={selectedDocs.size === 0}>
Export Selection ({selectedDocs.size})
</Button>
<Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
</div>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
<div className="relative flex-1 min-w-[200px]">
<Search
className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
size={16}
/>
<input
type="text"
placeholder="Search documents..."
className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
/>
</div>
<div className="flex gap-3">
<div className="relative">
<select
value={statusFilter}
onChange={(e) => setStatusFilter(e.target.value)}
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
>
<option value="">All Statuses</option>
<option value="pending">Pending</option>
<option value="labeled">Labeled</option>
<option value="verified">Verified</option>
<option value="exported">Exported</option>
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
</div>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<table className="w-full text-left border-collapse">
<thead>
<tr className="border-b border-warm-border bg-white">
<th className="py-3 pl-6 pr-4 w-12">
<input
type="checkbox"
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
/>
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Document Name
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Date
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Status
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Annotations
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
Auto-label
</th>
<th className="py-3 px-4 w-12"></th>
</tr>
</thead>
<tbody>
{isLoading ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
Loading documents...
</td>
</tr>
) : documents.length === 0 ? (
<tr>
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
No documents found. Upload your first document to get started.
</td>
</tr>
) : (
documents.map((doc) => {
const isSelected = selectedDocs.has(doc.document_id)
const progress = getAutoLabelProgress(doc)
return (
<tr
key={doc.document_id}
onClick={() => onNavigate('detail', doc.document_id)}
className={`
group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
`}
>
<td
className="py-4 pl-6 pr-4 relative"
onClick={(e) => {
e.stopPropagation()
toggleSelection(doc.document_id)
}}
>
{isSelected && (
<div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
)}
<input
type="checkbox"
checked={isSelected}
readOnly
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
/>
</td>
<td className="py-4 px-4">
<div className="flex items-center gap-3">
<div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
<FileText size={16} />
</div>
<span className="font-medium text-warm-text-secondary">
{doc.filename}
</span>
</div>
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
{new Date(doc.created_at).toLocaleDateString()}
</td>
<td className="py-4 px-4">
<Badge status={getStatusForBadge(doc.status)} />
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary">
{doc.annotation_count || 0} annotations
</td>
<td className="py-4 px-4">
                          {doc.auto_label_status === 'running' && progress !== undefined && (
<div className="w-full">
<div className="flex justify-between text-xs mb-1">
<span className="text-warm-text-secondary font-medium">
Running
</span>
<span className="text-warm-text-muted">{progress}%</span>
</div>
<div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
<div
className="h-full bg-warm-state-info transition-all duration-500 ease-out"
style={{ width: `${progress}%` }}
/>
</div>
</div>
)}
{doc.auto_label_status === 'completed' && (
<span className="text-sm font-medium text-warm-state-success">
Completed
</span>
)}
{doc.auto_label_status === 'failed' && (
<span className="text-sm font-medium text-warm-state-error">
Failed
</span>
)}
</td>
<td className="py-4 px-4 text-right">
<button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
<MoreHorizontal size={18} />
</button>
</td>
</tr>
)
})
)}
</tbody>
</table>
</div>
<UploadModal
isOpen={isUploadOpen}
onClose={() => {
setIsUploadOpen(false)
refetch()
}}
/>
</div>
)
}

View File

@@ -0,0 +1,148 @@
import React from 'react'
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'
import { useTraining } from '../hooks/useTraining'
interface DashboardOverviewProps {
onNavigate: (view: string) => void
}
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
const { models, isLoadingModels } = useTraining()
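  // 'Labeled' and 'Pending' counts are placeholders until the API exposes per-status totals.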
const stats = [
{
label: 'Total Documents',
value: docsLoading ? '...' : totalDocs.toString(),
icon: FileText,
color: 'text-warm-text-primary',
bgColor: 'bg-warm-bg',
},
{
label: 'Labeled',
value: '0',
icon: CheckCircle,
color: 'text-warm-state-success',
bgColor: 'bg-green-50',
},
{
label: 'Pending',
value: '0',
icon: Clock,
color: 'text-warm-state-warning',
bgColor: 'bg-yellow-50',
},
{
label: 'Training Models',
value: isLoadingModels ? '...' : models.length.toString(),
icon: TrendingUp,
color: 'text-warm-state-info',
bgColor: 'bg-blue-50',
},
]
return (
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
{/* Header */}
<div className="mb-8">
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
Dashboard
</h1>
<p className="text-sm text-warm-text-muted mt-1">
Overview of your document annotation system
</p>
</div>
{/* Stats Grid */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
{stats.map((stat) => (
<div
key={stat.label}
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
>
<div className="flex items-center justify-between mb-4">
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
<stat.icon className={stat.color} size={24} />
</div>
</div>
<p className="text-2xl font-bold text-warm-text-primary mb-1">
{stat.value}
</p>
<p className="text-sm text-warm-text-muted">{stat.label}</p>
</div>
))}
</div>
{/* Quick Actions */}
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
Quick Actions
</h2>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<Button onClick={() => onNavigate('documents')} className="justify-start">
<FileText size={18} className="mr-2" />
Manage Documents
</Button>
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
<Activity size={18} className="mr-2" />
Start Training
</Button>
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
<TrendingUp size={18} className="mr-2" />
View Models
</Button>
</div>
</div>
{/* Recent Activity */}
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border">
<h2 className="text-lg font-semibold text-warm-text-primary">
Recent Activity
</h2>
</div>
<div className="p-6">
<div className="text-center py-8 text-warm-text-muted">
<Activity size={48} className="mx-auto mb-3 opacity-20" />
<p className="text-sm">No recent activity</p>
<p className="text-xs mt-1">
Start by uploading documents or creating training jobs
</p>
</div>
</div>
</div>
{/* System Status */}
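      {/* The indicators below are static placeholders, not live health checks */}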
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
System Status
</h2>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Backend API</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Online
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Database</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Connected
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">GPU</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Available
</span>
</div>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,504 @@
import React, { useState, useRef, useEffect } from 'react'
import { ChevronLeft, ZoomIn, ZoomOut, Trash2, Tag, CheckCircle } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
import { documentsApi } from '../api/endpoints/documents'
interface DocumentDetailProps {
docId: string
onBack: () => void
}
// Field class mapping from backend
const FIELD_CLASSES: Record<number, string> = {
0: 'invoice_number',
1: 'invoice_date',
2: 'invoice_due_date',
3: 'ocr_number',
4: 'bankgiro',
5: 'plusgiro',
6: 'amount',
7: 'supplier_organisation_number',
8: 'payment_line',
9: 'customer_number',
}
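// NOTE: keep this mapping in sync with the backend's field class table (classes 0-9).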
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
const { document, annotations, isLoading } = useDocumentDetail(docId)
  const { createAnnotation, deleteAnnotation, isDeleting } = useAnnotations(docId)
const [selectedId, setSelectedId] = useState<string | null>(null)
const [zoom, setZoom] = useState(100)
const [isDrawing, setIsDrawing] = useState(false)
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
const [selectedClassId, setSelectedClassId] = useState<number>(0)
const [currentPage, setCurrentPage] = useState(1)
const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)
const canvasRef = useRef<HTMLDivElement>(null)
const imageRef = useRef<HTMLImageElement>(null)
const [isMarkingComplete, setIsMarkingComplete] = useState(false)
// Handle mark as complete
const handleMarkComplete = async () => {
if (!annotations || annotations.length === 0) {
alert('Please add at least one annotation before marking as complete.')
return
}
if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
return
}
setIsMarkingComplete(true)
try {
const result = await documentsApi.updateStatus(docId, 'labeled')
alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
onBack() // Return to document list
} catch (error) {
console.error('Failed to mark document as complete:', error)
alert('Failed to mark document as complete. Please try again.')
} finally {
setIsMarkingComplete(false)
}
}
// Load image via fetch with authentication header
useEffect(() => {
let objectUrl: string | null = null
const loadImage = async () => {
if (!docId) return
const token = localStorage.getItem('admin_token')
const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`
try {
const response = await fetch(imageUrl, {
headers: {
'X-Admin-Token': token || '',
},
})
if (!response.ok) {
throw new Error(`Failed to load image: ${response.status}`)
}
const blob = await response.blob()
objectUrl = URL.createObjectURL(blob)
setImageBlobUrl(objectUrl)
} catch (error) {
console.error('Failed to load image:', error)
}
}
loadImage()
// Cleanup: revoke object URL when component unmounts or page changes
return () => {
if (objectUrl) {
URL.revokeObjectURL(objectUrl)
}
}
}, [currentPage, docId])
// Load image size
useEffect(() => {
if (imageRef.current && imageRef.current.complete) {
setImageSize({
width: imageRef.current.naturalWidth,
height: imageRef.current.naturalHeight,
})
}
}, [imageBlobUrl])
const handleImageLoad = () => {
if (imageRef.current) {
setImageSize({
width: imageRef.current.naturalWidth,
height: imageRef.current.naturalHeight,
})
}
}
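  // Translate mouse positions from screen space to image pixel space,
  // dividing out the zoom factor so stored bboxes are zoom-independent.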
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
if (!canvasRef.current || !imageSize) return
const rect = canvasRef.current.getBoundingClientRect()
const x = (e.clientX - rect.left) / (zoom / 100)
const y = (e.clientY - rect.top) / (zoom / 100)
setIsDrawing(true)
setDrawStart({ x, y })
setDrawEnd({ x, y })
}
const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
if (!isDrawing || !canvasRef.current || !imageSize) return
const rect = canvasRef.current.getBoundingClientRect()
const x = (e.clientX - rect.left) / (zoom / 100)
const y = (e.clientY - rect.top) / (zoom / 100)
setDrawEnd({ x, y })
}
const handleMouseUp = () => {
if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
setIsDrawing(false)
return
}
const bbox_x = Math.min(drawStart.x, drawEnd.x)
const bbox_y = Math.min(drawStart.y, drawEnd.y)
const bbox_width = Math.abs(drawEnd.x - drawStart.x)
const bbox_height = Math.abs(drawEnd.y - drawStart.y)
// Only create if box is large enough (min 10x10 pixels)
if (bbox_width > 10 && bbox_height > 10) {
createAnnotation({
page_number: currentPage,
class_id: selectedClassId,
bbox: {
x: Math.round(bbox_x),
y: Math.round(bbox_y),
width: Math.round(bbox_width),
height: Math.round(bbox_height),
},
})
}
setIsDrawing(false)
setDrawStart(null)
setDrawEnd(null)
}
const handleDeleteAnnotation = (annotationId: string) => {
if (confirm('Are you sure you want to delete this annotation?')) {
deleteAnnotation(annotationId)
setSelectedId(null)
}
}
if (isLoading || !document) {
return (
<div className="flex h-screen items-center justify-center">
<div className="text-warm-text-muted">Loading...</div>
</div>
)
}
// Get current page annotations
const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []
return (
<div className="flex h-[calc(100vh-56px)] overflow-hidden">
{/* Main Canvas Area */}
<div className="flex-1 bg-warm-bg flex flex-col relative">
{/* Toolbar */}
<div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
<div className="flex items-center gap-4">
<button
onClick={onBack}
className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
>
<ChevronLeft size={20} />
</button>
<div>
<h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
<p className="text-xs text-warm-text-muted">
Page {currentPage} of {document.page_count}
</p>
</div>
<div className="h-6 w-px bg-warm-divider mx-2" />
<div className="flex items-center gap-2">
<button
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
onClick={() => setZoom((z) => Math.max(50, z - 10))}
>
<ZoomOut size={16} />
</button>
<span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
{zoom}%
</span>
<button
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
onClick={() => setZoom((z) => Math.min(200, z + 10))}
>
<ZoomIn size={16} />
</button>
</div>
</div>
<div className="flex gap-2">
<Button variant="secondary" size="sm">
Auto-label
</Button>
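          {/* Auto-label trigger is not wired to the backend yet */}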
<Button
variant="primary"
size="sm"
onClick={handleMarkComplete}
disabled={isMarkingComplete || document.status === 'labeled'}
>
<CheckCircle size={16} className="mr-1" />
{isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
</Button>
{document.page_count > 1 && (
<div className="flex gap-1">
<Button
variant="secondary"
size="sm"
onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
disabled={currentPage === 1}
>
Prev
</Button>
<Button
variant="secondary"
size="sm"
onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
disabled={currentPage === document.page_count}
>
Next
</Button>
</div>
)}
</div>
</div>
{/* Canvas Scroll Area */}
<div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
<div
ref={canvasRef}
className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
style={{
width: imageSize?.width || 800,
height: imageSize?.height || 1132,
transform: `scale(${zoom / 100})`,
marginBottom: '100px',
cursor: isDrawing ? 'crosshair' : 'default',
}}
onMouseDown={handleMouseDown}
onMouseMove={handleMouseMove}
onMouseUp={handleMouseUp}
onClick={() => setSelectedId(null)}
>
{/* Document Image */}
{imageBlobUrl ? (
<img
ref={imageRef}
src={imageBlobUrl}
alt={`Page ${currentPage}`}
className="w-full h-full object-contain select-none pointer-events-none"
onLoad={handleImageLoad}
/>
) : (
<div className="flex items-center justify-center h-full">
<div className="text-warm-text-muted">Loading image...</div>
</div>
)}
{/* Annotation Overlays */}
{pageAnnotations.map((ann) => {
const isSelected = selectedId === ann.annotation_id
return (
<div
key={ann.annotation_id}
onClick={(e) => {
e.stopPropagation()
setSelectedId(ann.annotation_id)
}}
className={`
absolute group cursor-pointer transition-all duration-100
${
ann.source === 'auto'
? 'border border-dashed border-warm-text-muted bg-transparent'
: 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
}
${
isSelected
? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
: 'hover:bg-warm-state-info/5 z-10'
}
`}
style={{
left: ann.bbox.x,
top: ann.bbox.y,
width: ann.bbox.width,
height: ann.bbox.height,
}}
>
{/* Label Tag */}
<div
className={`
absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
${
isSelected
? 'bg-warm-state-info text-white'
: 'bg-white text-warm-text-secondary border border-warm-border'
}
`}
>
{ann.class_name}
</div>
{/* Resize Handles (Visual only) */}
{isSelected && (
<>
<div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
</>
)}
</div>
)
})}
{/* Drawing Box Preview */}
{isDrawing && drawStart && drawEnd && (
<div
className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
style={{
left: Math.min(drawStart.x, drawEnd.x),
top: Math.min(drawStart.y, drawEnd.y),
width: Math.abs(drawEnd.x - drawStart.x),
height: Math.abs(drawEnd.y - drawStart.y),
}}
/>
)}
</div>
</div>
</div>
{/* Right Sidebar */}
<div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
{/* Field Selector */}
<div className="p-4 border-b border-warm-border">
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
<div className="space-y-2">
<label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
<select
value={selectedClassId}
onChange={(e) => setSelectedClassId(Number(e.target.value))}
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
{Object.entries(FIELD_CLASSES).map(([id, name]) => (
<option key={id} value={id}>
{name.replace(/_/g, ' ')}
</option>
))}
</select>
<p className="text-xs text-warm-text-muted mt-2">
Click and drag on the document to create a bounding box
</p>
</div>
</div>
{/* Document Info Card */}
<div className="p-4 border-b border-warm-border">
<div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
<div className="space-y-2">
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Status</span>
<span className="text-warm-text-secondary font-medium capitalize">
{document.status}
</span>
</div>
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Size</span>
<span className="text-warm-text-secondary font-medium">
{(document.file_size / 1024 / 1024).toFixed(2)} MB
</span>
</div>
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Uploaded</span>
<span className="text-warm-text-secondary font-medium">
{new Date(document.created_at).toLocaleDateString()}
</span>
</div>
</div>
</div>
</div>
{/* Annotations List */}
<div className="flex-1 overflow-y-auto p-4">
<div className="flex items-center justify-between mb-4">
<h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
<span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
</div>
{pageAnnotations.length === 0 ? (
<div className="text-center py-8 text-warm-text-muted">
<Tag size={48} className="mx-auto mb-3 opacity-20" />
<p className="text-sm">No annotations yet</p>
<p className="text-xs mt-1">Draw on the document to add annotations</p>
</div>
) : (
<div className="space-y-3">
{pageAnnotations.map((ann) => (
<div
key={ann.annotation_id}
onClick={() => setSelectedId(ann.annotation_id)}
className={`
group p-3 rounded-md border transition-all duration-150 cursor-pointer
${
selectedId === ann.annotation_id
? 'bg-warm-bg border-warm-state-info shadow-sm'
: 'bg-white border-warm-border hover:border-warm-text-muted'
}
`}
>
<div className="flex justify-between items-start mb-1">
<span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
{ann.class_name.replace(/_/g, ' ')}
</span>
{selectedId === ann.annotation_id && (
<div className="flex gap-1">
<button
onClick={() => handleDeleteAnnotation(ann.annotation_id)}
className="text-warm-text-muted hover:text-warm-state-error"
disabled={isDeleting}
>
<Trash2 size={12} />
</button>
</div>
)}
</div>
<p className="text-sm text-warm-text-muted font-mono truncate">
{ann.text_value || '(no text)'}
</p>
<div className="flex items-center gap-2 mt-2">
<span
className={`text-[10px] px-1.5 py-0.5 rounded ${
ann.source === 'auto'
? 'bg-blue-50 text-blue-700'
: 'bg-green-50 text-green-700'
}`}
>
{ann.source}
</span>
                  {ann.confidence != null && (
<span className="text-[10px] text-warm-text-muted">
{(ann.confidence * 100).toFixed(0)}%
</span>
)}
</div>
</div>
))}
</div>
)}
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,466 @@
import React, { useState, useRef } from 'react'
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
import { Button } from './Button'
import { inferenceApi } from '../api/endpoints'
import type { InferenceResult } from '../api/types'
export const InferenceDemo: React.FC = () => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFile, setSelectedFile] = useState<File | null>(null)
const [isProcessing, setIsProcessing] = useState(false)
const [result, setResult] = useState<InferenceResult | null>(null)
const [error, setError] = useState<string | null>(null)
const fileInputRef = useRef<HTMLInputElement>(null)
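  // Client-side validation: accept only PDF/PNG/JPG up to 50MB before calling the API.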
const handleFileSelect = (file: File | null) => {
if (!file) return
const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
if (!validTypes.includes(file.type)) {
setError('Please upload a PDF, PNG, or JPG file')
return
}
if (file.size > 50 * 1024 * 1024) {
setError('File size must be less than 50MB')
return
}
setSelectedFile(file)
setResult(null)
setError(null)
}
const handleDrop = (e: React.DragEvent) => {
e.preventDefault()
setIsDragging(false)
if (e.dataTransfer.files.length > 0) {
handleFileSelect(e.dataTransfer.files[0])
}
}
const handleBrowseClick = () => {
fileInputRef.current?.click()
}
const handleProcess = async () => {
if (!selectedFile) return
setIsProcessing(true)
setError(null)
try {
const response = await inferenceApi.processDocument(selectedFile)
setResult(response.result)
} catch (err) {
setError(err instanceof Error ? err.message : 'Processing failed')
} finally {
setIsProcessing(false)
}
}
const handleReset = () => {
setSelectedFile(null)
setResult(null)
setError(null)
}
const formatFieldName = (field: string): string => {
const fieldNames: Record<string, string> = {
InvoiceNumber: 'Invoice Number',
InvoiceDate: 'Invoice Date',
InvoiceDueDate: 'Due Date',
OCR: 'OCR Number',
Amount: 'Amount',
Bankgiro: 'Bankgiro',
Plusgiro: 'Plusgiro',
supplier_org_number: 'Supplier Org Number',
customer_number: 'Customer Number',
payment_line: 'Payment Line',
}
return fieldNames[field] || field
}
return (
<div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
{/* Header */}
<div className="text-center">
<h2 className="text-3xl font-bold text-warm-text-primary mb-2">
Invoice Extraction Demo
</h2>
<p className="text-warm-text-muted">
Upload a Swedish invoice to see our AI-powered field extraction in action
</p>
</div>
{/* Upload Area */}
{!result && (
<div className="max-w-2xl mx-auto">
<div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
<div
className={`
relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
${isDragging
? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
: 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
}
${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
`}
onDragOver={(e) => {
e.preventDefault()
setIsDragging(true)
}}
onDragLeave={() => setIsDragging(false)}
onDrop={handleDrop}
onClick={handleBrowseClick}
>
<div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
{isProcessing ? (
<>
<Loader2 size={56} className="text-warm-text-secondary animate-spin" />
<div className="text-center">
<p className="text-lg font-semibold text-warm-text-primary mb-1">
Processing invoice...
</p>
<p className="text-sm text-warm-text-muted">
This may take a few moments
</p>
</div>
</>
) : selectedFile ? (
<>
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
<FileText size={40} className="text-warm-text-secondary" />
</div>
<div className="text-center px-4">
<p className="text-lg font-semibold text-warm-text-primary mb-1">
{selectedFile.name}
</p>
<p className="text-sm text-warm-text-muted">
{(selectedFile.size / 1024 / 1024).toFixed(2)} MB
</p>
</div>
</>
) : (
<>
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
<UploadCloud size={40} className="text-warm-text-secondary" />
</div>
<div className="text-center px-4">
<p className="text-lg font-semibold text-warm-text-primary mb-2">
Drag & drop invoice here
</p>
<p className="text-sm text-warm-text-muted mb-3">
or{' '}
<span className="text-warm-text-secondary font-medium">
browse files
</span>
</p>
<p className="text-xs text-warm-text-muted">
Supports PDF, PNG, JPG (up to 50MB)
</p>
</div>
</>
)}
</div>
</div>
<input
ref={fileInputRef}
type="file"
accept=".pdf,image/*"
className="hidden"
onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
/>
{error && (
<div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
<AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
<span className="text-sm text-red-800 font-medium">{error}</span>
</div>
)}
{selectedFile && !isProcessing && (
<div className="mt-6 flex gap-3 justify-end">
<Button variant="secondary" onClick={handleReset}>
Cancel
</Button>
<Button onClick={handleProcess}>Process Invoice</Button>
</div>
)}
</div>
</div>
)}
{/* Results */}
{result && (
<div className="space-y-6">
{/* Status Header */}
<div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
<div className="p-6 flex items-center justify-between border-b border-warm-divider">
<div className="flex items-center gap-4">
{result.success ? (
<div className="p-3 bg-green-100 rounded-xl">
<CheckCircle2 size={28} className="text-green-600" />
</div>
) : (
<div className="p-3 bg-yellow-100 rounded-xl">
<AlertCircle size={28} className="text-yellow-600" />
</div>
)}
<div>
<h3 className="text-xl font-bold text-warm-text-primary">
{result.success ? 'Extraction Complete' : 'Partial Results'}
</h3>
<p className="text-sm text-warm-text-muted mt-0.5">
Document ID: <span className="font-mono">{result.document_id}</span>
</p>
</div>
</div>
<Button variant="secondary" onClick={handleReset}>
Process Another
</Button>
</div>
<div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
<div className="flex items-center gap-2 text-warm-text-secondary">
<Clock size={16} />
<span className="font-medium">
{result.processing_time_ms.toFixed(0)}ms
</span>
</div>
{result.fallback_used && (
<span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
Fallback OCR Used
</span>
)}
</div>
</div>
{/* Main Content Grid */}
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
{/* Left Column: Extracted Fields */}
<div className="lg:col-span-2 space-y-6">
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Extracted Fields
</h3>
<div className="flex flex-wrap gap-4">
{Object.entries(result.fields).map(([field, value]) => {
const confidence = result.confidence[field]
return (
<div
key={field}
className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
>
<div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
{formatFieldName(field)}
</div>
<div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
{value || <span className="text-warm-text-muted italic">N/A</span>}
</div>
                        {confidence != null && (
<div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
<CheckCircle2 size={13} />
<span>{(confidence * 100).toFixed(1)}%</span>
</div>
)}
</div>
)
})}
</div>
</div>
{/* Visualization */}
{result.visualization_url && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Detection Visualization
</h3>
<div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
<img
src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
alt="Detection visualization"
className="w-full h-auto"
/>
</div>
</div>
)}
</div>
{/* Right Column: Cross-Validation & Errors */}
<div className="space-y-6">
{/* Cross-Validation */}
{result.cross_validation && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Payment Line Validation
</h3>
<div
className={`
p-4 rounded-lg mb-4 flex items-center gap-3
${result.cross_validation.is_valid
? 'bg-green-50 border border-green-200'
: 'bg-yellow-50 border border-yellow-200'
}
`}
>
{result.cross_validation.is_valid ? (
<>
<CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
<span className="font-bold text-green-800">All Fields Match</span>
</>
) : (
<>
<AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
<span className="font-bold text-yellow-800">Mismatch Detected</span>
</>
)}
</div>
<div className="space-y-2.5">
{result.cross_validation.payment_line_ocr && (
<div
className={`
p-3 rounded-lg border transition-colors
${result.cross_validation.ocr_match === true
? 'bg-green-50 border-green-200'
: result.cross_validation.ocr_match === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
OCR NUMBER
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_ocr}
</div>
</div>
{result.cross_validation.ocr_match === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{result.cross_validation.ocr_match === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
{result.cross_validation.payment_line_amount && (
<div
className={`
p-3 rounded-lg border transition-colors
${result.cross_validation.amount_match === true
? 'bg-green-50 border-green-200'
: result.cross_validation.amount_match === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
AMOUNT
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_amount}
</div>
</div>
{result.cross_validation.amount_match === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{result.cross_validation.amount_match === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
{result.cross_validation.payment_line_account && (
<div
className={`
p-3 rounded-lg border transition-colors
${(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === true
? 'bg-green-50 border-green-200'
: (result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
{result.cross_validation.payment_line_account_type === 'bankgiro'
? 'BANKGIRO'
: 'PLUSGIRO'}
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_account}
</div>
</div>
{(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
</div>
{result.cross_validation.details.length > 0 && (
<div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
{result.cross_validation.details[result.cross_validation.details.length - 1]}
</div>
)}
</div>
)}
{/* Errors */}
{result.errors.length > 0 && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
<span className="w-1 h-5 bg-red-500 rounded-full"></span>
Issues
</h3>
<div className="space-y-2.5">
{result.errors.map((err, idx) => (
<div
key={idx}
className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
>
<AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
<span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
</div>
))}
</div>
</div>
)}
</div>
</div>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,102 @@
import React, { useState } from 'react';
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';
interface LayoutProps {
children: React.ReactNode;
activeView: string;
onNavigate: (view: string) => void;
onLogout?: () => void;
}
export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
const [showDropdown, setShowDropdown] = useState(false);
const navItems = [
{ id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
{ id: 'demo', label: 'Demo', icon: Sparkles },
    { id: 'training', label: 'Training', icon: Box },
    { id: 'documents', label: 'Documents', icon: BookOpen },
    { id: 'models', label: 'Models', icon: Users },
];
return (
<div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
{/* Top Navigation */}
<nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
<div className="flex items-center gap-8">
{/* Logo */}
<div className="flex items-center gap-2">
<div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
<path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
</svg>
</div>
</div>
{/* Nav Links */}
<div className="flex h-14">
{navItems.map(item => {
const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
return (
<button
key={item.id}
onClick={() => onNavigate(item.id)}
className={`
relative px-4 h-full flex items-center text-sm font-medium transition-colors
${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
`}
>
{item.label}
{isActive && (
<div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
)}
</button>
);
})}
</div>
</div>
{/* User Profile */}
<div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
<button
onClick={() => setShowDropdown(!showDropdown)}
className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
>
AD
</button>
{showDropdown && (
<>
<div
className="fixed inset-0 z-10"
onClick={() => setShowDropdown(false)}
/>
<div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
<div className="p-3 border-b border-warm-border">
<p className="text-sm font-medium text-warm-text-primary">Admin User</p>
<p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
</div>
{onLogout && (
<button
onClick={() => {
setShowDropdown(false)
onLogout()
}}
className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
>
<LogOut size={14} />
Sign Out
</button>
)}
</div>
</>
)}
</div>
</nav>
{/* Main Content */}
<main className="flex-1 overflow-auto">
{children}
</main>
</div>
);
};

View File

@@ -0,0 +1,188 @@
import React, { useState } from 'react'
import { Button } from './Button'
interface LoginProps {
onLogin: (token: string) => void
}
export const Login: React.FC<LoginProps> = ({ onLogin }) => {
const [token, setToken] = useState('')
const [name, setName] = useState('')
const [description, setDescription] = useState('')
const [isCreating, setIsCreating] = useState(false)
const [error, setError] = useState('')
const [createdToken, setCreatedToken] = useState('')
const handleLoginWithToken = () => {
if (!token.trim()) {
setError('Please enter a token')
return
}
localStorage.setItem('admin_token', token.trim())
onLogin(token.trim())
}
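  // Create a brand-new admin token via the API. The request sends no auth
  // header, so the endpoint must allow unauthenticated creation (first-time setup).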
const handleCreateToken = async () => {
if (!name.trim()) {
setError('Please enter a token name')
return
}
setIsCreating(true)
setError('')
try {
      const response = await fetch(`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/auth/token`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
name: name.trim(),
description: description.trim() || undefined,
}),
})
if (!response.ok) {
throw new Error('Failed to create token')
}
const data = await response.json()
setCreatedToken(data.token)
setToken(data.token)
setError('')
} catch (err) {
setError('Failed to create token. Please check your connection.')
console.error(err)
} finally {
setIsCreating(false)
}
}
const handleUseCreatedToken = () => {
if (createdToken) {
localStorage.setItem('admin_token', createdToken)
onLogin(createdToken)
}
}
return (
<div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
<div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
<h1 className="text-2xl font-bold text-warm-text-primary mb-2">
Admin Authentication
</h1>
<p className="text-sm text-warm-text-muted mb-6">
Sign in with an admin token to access the document management system
</p>
{error && (
<div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
{error}
</div>
)}
{createdToken && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
<p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
<div className="bg-white border border-green-300 rounded p-2 mb-3">
<code className="text-xs font-mono text-warm-text-primary break-all">
{createdToken}
</code>
</div>
<p className="text-xs text-green-700 mb-3">
Save this token securely. You won't be able to see it again.
</p>
<Button onClick={handleUseCreatedToken} className="w-full">
Use This Token
</Button>
</div>
)}
<div className="space-y-6">
{/* Login with existing token */}
<div>
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
Sign in with existing token
</h2>
<div className="space-y-3">
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Admin Token
</label>
<input
type="text"
value={token}
onChange={(e) => setToken(e.target.value)}
placeholder="Enter your admin token"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
/>
</div>
<Button onClick={handleLoginWithToken} className="w-full">
Sign In
</Button>
</div>
</div>
<div className="relative">
<div className="absolute inset-0 flex items-center">
<div className="w-full border-t border-warm-border"></div>
</div>
<div className="relative flex justify-center text-xs">
<span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
</div>
</div>
{/* Create new token */}
<div>
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
Create new admin token
</h2>
<div className="space-y-3">
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Token Name <span className="text-red-500">*</span>
</label>
<input
type="text"
value={name}
onChange={(e) => setName(e.target.value)}
placeholder="e.g., my-laptop"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Description (optional)
</label>
<input
type="text"
value={description}
onChange={(e) => setDescription(e.target.value)}
placeholder="e.g., Personal laptop access"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<Button
onClick={handleCreateToken}
variant="secondary"
disabled={isCreating}
className="w-full"
>
{isCreating ? 'Creating...' : 'Create Token'}
</Button>
</div>
</div>
</div>
<div className="mt-6 pt-4 border-t border-warm-border">
<p className="text-xs text-warm-text-muted">
Admin tokens are used to authenticate with the document management API.
Keep your tokens secure and never share them.
</p>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,134 @@
import React from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
import { Button } from './Button';
const CHART_DATA = [
{ name: 'Model A', value: 75 },
{ name: 'Model B', value: 82 },
{ name: 'Model C', value: 95 },
{ name: 'Model D', value: 68 },
];
const METRICS_DATA = [
{ name: 'Precision', value: 88 },
{ name: 'Recall', value: 76 },
{ name: 'F1 Score', value: 91 },
{ name: 'Accuracy', value: 82 },
];
const JOBS = [
  { id: 1, name: 'Training Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
{ id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
  { id: 3, name: 'Training Job 3', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
];
export const Models: React.FC = () => {
return (
<div className="p-8 max-w-7xl mx-auto flex gap-8">
{/* Left: Job History */}
<div className="flex-1">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>
<div className="space-y-4">
{JOBS.map(job => (
<div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
<p className="text-sm text-warm-text-muted">Started {job.date}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
{job.status}
</span>
</div>
{job.status === 'Running' ? (
<div className="mt-4">
<div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
<div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
</div>
</div>
) : (
<div className="mt-4 flex gap-8">
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
<span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
<span className="text-lg font-mono text-warm-text-secondary">100%</span>
</div>
</div>
)}
</div>
))}
</div>
</div>
{/* Right: Model Detail */}
<div className="w-[400px]">
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
<div className="flex justify-between items-center mb-6">
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
<span className="text-sm font-medium text-warm-state-success">Completed</span>
</div>
<div className="mb-8">
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
<p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
</div>
<div className="space-y-8">
{/* Chart 1 */}
<div>
              <h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={CHART_DATA}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" hide />
<YAxis hide domain={[0, 100]} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
</div>
</div>
{/* Chart 2 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={METRICS_DATA}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip cursor={{fill: '#F1F0ED'}} />
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
</div>
</div>
</div>
<div className="mt-8 space-y-3">
<Button className="w-full">Download Model</Button>
<div className="flex gap-3">
<Button variant="secondary" className="flex-1">View Logs</Button>
<Button variant="secondary" className="flex-1">Use as Base</Button>
</div>
</div>
</div>
</div>
</div>
);
};

View File

@@ -0,0 +1,113 @@
import React, { useState } from 'react';
import { Check, AlertCircle } from 'lucide-react';
import { Button } from './Button';
import { DocumentStatus } from '../types';
export const Training: React.FC = () => {
const [split, setSplit] = useState(80);
const docs = [
    { id: '1', name: 'Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
    { id: '2', name: 'Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
    { id: '3', name: 'Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
    { id: '4', name: 'Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '5', name: 'Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '6', name: 'Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '8', name: 'Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
];
return (
<div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
{/* Document Selection List */}
<div className="flex-1 flex flex-col">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>
<div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
<div className="overflow-auto flex-1">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
</tr>
</thead>
<tbody>
{docs.map(doc => (
<tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
<td className="py-3 px-4">
{doc.status === DocumentStatus.VERIFIED ? (
<div className="flex items-center text-warm-state-success text-sm font-medium">
<div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
<Check size={12} strokeWidth={3}/>
</div>
Verified
</div>
) : (
<div className="flex items-center text-warm-text-muted text-sm">
<div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
<span className="font-bold text-[10px]">!</span>
</div>
Partial
</div>
)}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
{/* Configuration Panel */}
<div className="w-96">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>
<div className="space-y-6">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
<input
type="text"
placeholder="e.g. Invoices Q4"
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
<select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
<option>LayoutLMv3 (Standard)</option>
<option>Donut (Beta)</option>
</select>
</div>
<div>
<div className="flex justify-between mb-2">
<label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
<span className="text-xs font-mono text-warm-text-muted">{split}% / {100-split}%</span>
</div>
<input
type="range"
min="50"
max="95"
value={split}
onChange={(e) => setSplit(parseInt(e.target.value))}
className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
/>
</div>
<div className="pt-4 border-t border-warm-border">
<Button className="w-full h-12">Start Training</Button>
</div>
</div>
</div>
</div>
</div>
);
};

View File

@@ -0,0 +1,210 @@
import React, { useState, useRef } from 'react'
import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'
interface UploadModalProps {
isOpen: boolean
onClose: () => void
}
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef<HTMLInputElement>(null)
  const { uploadDocument } = useDocuments({})
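  // All hooks above run unconditionally; returning null only after them keeps hook order stable.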
if (!isOpen) return null
const handleFileSelect = (files: FileList | null) => {
if (!files) return
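    // Unsupported types and files over 25MB are silently dropped; only accepted files are listed.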
const pdfFiles = Array.from(files).filter(file => {
const isPdf = file.type === 'application/pdf'
const isImage = file.type.startsWith('image/')
const isUnder25MB = file.size <= 25 * 1024 * 1024
return (isPdf || isImage) && isUnder25MB
})
setSelectedFiles(prev => [...prev, ...pdfFiles])
setUploadStatus('idle')
setErrorMessage('')
}
const handleDrop = (e: React.DragEvent) => {
e.preventDefault()
setIsDragging(false)
handleFileSelect(e.dataTransfer.files)
}
const handleBrowseClick = () => {
fileInputRef.current?.click()
}
const removeFile = (index: number) => {
setSelectedFiles(prev => prev.filter((_, i) => i !== index))
}
const handleUpload = async () => {
if (selectedFiles.length === 0) {
setErrorMessage('Please select at least one file')
return
}
setUploadStatus('uploading')
setErrorMessage('')
try {
// Upload files one by one
for (const file of selectedFiles) {
await new Promise<void>((resolve, reject) => {
uploadDocument(file, {
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
})
})
}
setUploadStatus('success')
setTimeout(() => {
onClose()
setSelectedFiles([])
setUploadStatus('idle')
}, 1500)
} catch (error) {
setUploadStatus('error')
setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
}
}
const handleClose = () => {
if (uploadStatus === 'uploading') {
return // Prevent closing during upload
}
setSelectedFiles([])
setUploadStatus('idle')
setErrorMessage('')
onClose()
}
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
<div
className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
onClick={(e) => e.stopPropagation()}
>
<div className="flex items-center justify-between mb-6">
<h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
<button
onClick={handleClose}
className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
disabled={uploadStatus === 'uploading'}
>
<X size={20} />
</button>
</div>
{/* Drop Zone */}
<div
className={`
w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
`}
onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
onDragLeave={() => setIsDragging(false)}
onDrop={handleDrop}
onClick={handleBrowseClick}
>
<div className="p-3 bg-white rounded-full shadow-sm">
<UploadCloud size={24} className="text-warm-text-secondary" />
</div>
<div className="text-center">
<p className="text-sm font-medium text-warm-text-primary">
Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
</p>
<p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
</div>
</div>
<input
ref={fileInputRef}
type="file"
multiple
accept=".pdf,image/*"
className="hidden"
onChange={(e) => handleFileSelect(e.target.files)}
/>
{/* Selected Files */}
{selectedFiles.length > 0 && (
<div className="mb-6 max-h-40 overflow-y-auto">
<p className="text-sm font-medium text-warm-text-secondary mb-2">
Selected Files ({selectedFiles.length})
</p>
<div className="space-y-2">
{selectedFiles.map((file, index) => (
<div
key={`${file.name}-${index}`}
className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
>
<div className="flex items-center gap-2 flex-1 min-w-0">
<File size={16} className="text-warm-text-muted flex-shrink-0" />
<span className="text-sm text-warm-text-secondary truncate">
{file.name}
</span>
<span className="text-xs text-warm-text-muted flex-shrink-0">
({(file.size / 1024 / 1024).toFixed(2)} MB)
</span>
</div>
<button
onClick={() => removeFile(index)}
className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
disabled={uploadStatus === 'uploading'}
>
<X size={16} />
</button>
</div>
))}
</div>
</div>
)}
{/* Status Messages */}
{uploadStatus === 'success' && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
<CheckCircle size={16} className="text-green-600" />
<span className="text-sm text-green-800">Upload successful!</span>
</div>
)}
{uploadStatus === 'error' && errorMessage && (
<div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
<AlertCircle size={16} className="text-red-600" />
<span className="text-sm text-red-800">{errorMessage}</span>
</div>
)}
{/* Actions */}
<div className="mt-8 flex justify-end gap-3">
<Button
variant="secondary"
onClick={handleClose}
disabled={uploadStatus === 'uploading'}
>
Cancel
</Button>
<Button
onClick={handleUpload}
disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
>
{uploadStatus === 'uploading' ? 'Uploading...' : selectedFiles.length > 0 ? `Upload (${selectedFiles.length})` : 'Upload'}
</Button>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,4 @@
export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'

View File

@@ -0,0 +1,70 @@
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { annotationsApi } from '../api/endpoints'
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'
export const useAnnotations = (documentId: string) => {
const queryClient = useQueryClient()
const createMutation = useMutation({
mutationFn: (annotation: CreateAnnotationRequest) =>
annotationsApi.create(documentId, annotation),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const updateMutation = useMutation({
mutationFn: ({
annotationId,
updates,
}: {
annotationId: string
updates: Partial<CreateAnnotationRequest>
}) => annotationsApi.update(documentId, annotationId, updates),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const deleteMutation = useMutation({
mutationFn: (annotationId: string) =>
annotationsApi.delete(documentId, annotationId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const verifyMutation = useMutation({
mutationFn: (annotationId: string) =>
annotationsApi.verify(documentId, annotationId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const overrideMutation = useMutation({
mutationFn: ({
annotationId,
overrideData,
}: {
annotationId: string
overrideData: AnnotationOverrideRequest
}) => annotationsApi.override(documentId, annotationId, overrideData),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
return {
createAnnotation: createMutation.mutate,
isCreating: createMutation.isPending,
updateAnnotation: updateMutation.mutate,
isUpdating: updateMutation.isPending,
deleteAnnotation: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
verifyAnnotation: verifyMutation.mutate,
isVerifying: verifyMutation.isPending,
overrideAnnotation: overrideMutation.mutate,
isOverriding: overrideMutation.isPending,
}
}

View File

@@ -0,0 +1,25 @@
import { useQuery } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentDetailResponse } from '../api/types'
export const useDocumentDetail = (documentId: string | null) => {
const { data, isLoading, error, refetch } = useQuery<DocumentDetailResponse>({
queryKey: ['document', documentId],
queryFn: () => {
if (!documentId) {
throw new Error('Document ID is required')
}
return documentsApi.getDetail(documentId)
},
enabled: !!documentId,
staleTime: 10000,
})
return {
document: data || null,
annotations: data?.annotations || [],
isLoading,
error,
refetch,
}
}

View File

@@ -0,0 +1,78 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentListResponse, UploadDocumentResponse } from '../api/types'
interface UseDocumentsParams {
status?: string
limit?: number
offset?: number
}
export const useDocuments = (params: UseDocumentsParams = {}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<DocumentListResponse>({
queryKey: ['documents', params],
queryFn: () => documentsApi.list(params),
staleTime: 30000,
})
const uploadMutation = useMutation({
mutationFn: (file: File) => documentsApi.upload(file),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const batchUploadMutation = useMutation({
mutationFn: ({ files, csvFile }: { files: File[]; csvFile?: File }) =>
documentsApi.batchUpload(files, csvFile),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const deleteMutation = useMutation({
mutationFn: (documentId: string) => documentsApi.delete(documentId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const updateStatusMutation = useMutation({
mutationFn: ({ documentId, status }: { documentId: string; status: string }) =>
documentsApi.updateStatus(documentId, status),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const triggerAutoLabelMutation = useMutation({
mutationFn: (documentId: string) => documentsApi.triggerAutoLabel(documentId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
return {
documents: data?.documents || [],
total: data?.total || 0,
limit: data?.limit || params.limit || 20,
offset: data?.offset || params.offset || 0,
isLoading,
error,
refetch,
uploadDocument: uploadMutation.mutate,
uploadDocumentAsync: uploadMutation.mutateAsync,
isUploading: uploadMutation.isPending,
batchUpload: batchUploadMutation.mutate,
batchUploadAsync: batchUploadMutation.mutateAsync,
isBatchUploading: batchUploadMutation.isPending,
deleteDocument: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
updateStatus: updateStatusMutation.mutate,
isUpdatingStatus: updateStatusMutation.isPending,
triggerAutoLabel: triggerAutoLabelMutation.mutate,
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
}
}

View File

@@ -0,0 +1,83 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { trainingApi } from '../api/endpoints'
import type { TrainingModelsResponse } from '../api/types'
export const useTraining = () => {
const queryClient = useQueryClient()
const { data: modelsData, isLoading: isLoadingModels } =
useQuery<TrainingModelsResponse>({
queryKey: ['training', 'models'],
queryFn: () => trainingApi.getModels(),
staleTime: 30000,
})
const startTrainingMutation = useMutation({
mutationFn: (config: {
name: string
description?: string
document_ids: string[]
epochs?: number
batch_size?: number
model_base?: string
}) => trainingApi.startTraining(config),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
const cancelTaskMutation = useMutation({
mutationFn: (taskId: string) => trainingApi.cancelTask(taskId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
const downloadModelMutation = useMutation({
mutationFn: (taskId: string) => trainingApi.downloadModel(taskId),
onSuccess: (blob, taskId) => {
const url = window.URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `model-${taskId}.pt`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
window.URL.revokeObjectURL(url)
},
})
return {
models: modelsData?.models || [],
total: modelsData?.total || 0,
isLoadingModels,
startTraining: startTrainingMutation.mutate,
startTrainingAsync: startTrainingMutation.mutateAsync,
isStartingTraining: startTrainingMutation.isPending,
cancelTask: cancelTaskMutation.mutate,
isCancelling: cancelTaskMutation.isPending,
downloadModel: downloadModelMutation.mutate,
isDownloading: downloadModelMutation.isPending,
}
}
export const useTrainingDocuments = (params?: {
has_annotations?: boolean
min_annotation_count?: number
exclude_used_in_training?: boolean
limit?: number
offset?: number
}) => {
const { data, isLoading, error } = useQuery({
queryKey: ['training', 'documents', params],
queryFn: () => trainingApi.getDocumentsForTraining(params),
staleTime: 30000,
})
return {
documents: data?.documents || [],
total: data?.total || 0,
isLoading,
error,
}
}

23
frontend/src/main.tsx Normal file
View File

@@ -0,0 +1,23 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import App from './App'
import './styles/index.css'
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: 1,
refetchOnWindowFocus: false,
staleTime: 30000,
},
},
})
ReactDOM.createRoot(document.getElementById('root')!).render(
<React.StrictMode>
<QueryClientProvider client={queryClient}>
<App />
</QueryClientProvider>
</React.StrictMode>
)

View File

@@ -0,0 +1,26 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
body {
@apply bg-warm-bg text-warm-text-primary;
}
/* Custom scrollbar */
::-webkit-scrollbar {
@apply w-2 h-2;
}
::-webkit-scrollbar-track {
@apply bg-transparent;
}
::-webkit-scrollbar-thumb {
@apply bg-warm-divider rounded;
}
::-webkit-scrollbar-thumb:hover {
@apply bg-warm-text-disabled;
}
}

View File

@@ -0,0 +1,48 @@
// Legacy types for backward compatibility with old components
// These will be gradually replaced with API types
export enum DocumentStatus {
PENDING = 'Pending',
LABELED = 'Labeled',
VERIFIED = 'Verified',
PARTIAL = 'Partial'
}
export interface Document {
id: string
name: string
date: string
status: DocumentStatus
exported: boolean
autoLabelProgress?: number
autoLabelStatus?: 'Running' | 'Completed' | 'Failed'
}
export interface Annotation {
id: string
text: string
label: string
x: number
y: number
width: number
height: number
isAuto?: boolean
}
export interface TrainingJob {
id: string
name: string
startDate: string
status: 'Running' | 'Completed' | 'Failed'
progress: number
metrics?: {
accuracy: number
precision: number
recall: number
}
}
export interface ModelMetric {
name: string
value: number
}

View File

@@ -0,0 +1,47 @@
export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
theme: {
extend: {
fontFamily: {
sans: ['Inter', 'SF Pro', 'system-ui', 'sans-serif'],
mono: ['JetBrains Mono', 'SF Mono', 'monospace'],
},
colors: {
warm: {
bg: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
border: '#E6E4E1',
divider: '#D8D6D2',
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
state: {
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
info: '#3A3A3A',
}
}
},
boxShadow: {
'card': '0 1px 3px rgba(0,0,0,0.08)',
'modal': '0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)',
},
animation: {
'fade-in': 'fadeIn 0.3s ease-out',
},
keyframes: {
fadeIn: {
'0%': { opacity: '0', transform: 'translateY(10px)' },
'100%': { opacity: '1', transform: 'translateY(0)' },
}
}
}
},
plugins: [],
}

29
frontend/tsconfig.json Normal file
View File

@@ -0,0 +1,29 @@
{
"compilerOptions": {
"target": "ES2022",
"experimentalDecorators": true,
"useDefineForClassFields": false,
"module": "ESNext",
"lib": [
"ES2022",
"DOM",
"DOM.Iterable"
],
"skipLibCheck": true,
"types": [
"node"
],
"moduleResolution": "bundler",
"isolatedModules": true,
"moduleDetection": "force",
"allowJs": true,
"jsx": "react-jsx",
"paths": {
"@/*": [
"./*"
]
},
"allowImportingTsExtensions": true,
"noEmit": true
}
}

16
frontend/vite.config.ts Normal file
View File

@@ -0,0 +1,16 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
export default defineConfig({
server: {
port: 3000,
host: '0.0.0.0',
proxy: {
'/api': {
target: 'http://localhost:8000',
changeOrigin: true,
},
},
},
plugins: [react()],
});

View File

@@ -21,3 +21,7 @@ pyyaml>=6.0 # YAML config files
# Utilities
tqdm>=4.65.0 # Progress bars
python-dotenv>=1.0.0 # Environment variable management
# Database
psycopg2-binary>=2.9.0 # PostgreSQL driver
sqlmodel>=0.0.22 # SQLModel ORM (SQLAlchemy + Pydantic)

View File

@@ -16,7 +16,7 @@ from pathlib import Path
from typing import Optional
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
from src.config import get_db_connection_string
from ..normalize import normalize_field
from ..matcher import FieldMatcher

View File

@@ -12,7 +12,7 @@ from collections import defaultdict
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
from src.config import get_db_connection_string
def load_reports_from_db() -> dict:

View File

@@ -34,7 +34,7 @@ if sys.platform == 'win32':
multiprocessing.set_start_method('spawn', force=True)
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL
from src.config import get_db_connection_string, PATHS, AUTOLABEL
# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None

View File

@@ -16,7 +16,7 @@ from psycopg2.extras import execute_values
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS
from src.config import get_db_connection_string, PATHS
def create_tables(conn):

View File

@@ -10,6 +10,9 @@ import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import DEFAULT_DPI
def main():
parser = argparse.ArgumentParser(
@@ -38,8 +41,8 @@ def main():
parser.add_argument(
'--dpi',
type=int,
default=150,
help='DPI for PDF rendering (default: 150, must match training)'
default=DEFAULT_DPI,
help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
)
parser.add_argument(
'--no-fallback',

View File

@@ -17,6 +17,7 @@ from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import DEFAULT_DPI
from src.data.db import DocumentDB
from src.data.csv_loader import CSVLoader
from src.normalize.normalizer import normalize_field
@@ -144,7 +145,7 @@ def process_single_document(args):
ocr_engine = OCREngine()
for page_no in range(pdf_doc.page_count):
# Render page to image
img = pdf_doc.render_page(page_no, dpi=150)
img = pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
if img is None:
continue

View File

@@ -15,6 +15,8 @@ from pathlib import Path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.config import DEFAULT_DPI
def setup_logging(debug: bool = False) -> None:
"""Configure logging."""
@@ -65,8 +67,8 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--dpi",
type=int,
default=150,
help="DPI for PDF rendering (must match training DPI)",
default=DEFAULT_DPI,
help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
)
parser.add_argument(

View File

@@ -11,7 +11,7 @@ import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import PATHS
from src.config import DEFAULT_DPI, PATHS
def main():
@@ -103,8 +103,8 @@ def main():
parser.add_argument(
'--dpi',
type=int,
default=150,
help='DPI used for rendering (default: 150, must match autolabel rendering)'
default=DEFAULT_DPI,
help=f'DPI used for rendering (default: {DEFAULT_DPI}, must match autolabel rendering)'
)
parser.add_argument(
'--export-only',

View File

@@ -8,9 +8,13 @@ from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from .env file
env_path = Path(__file__).parent / '.env'
# .env is at project root, config.py is in src/
env_path = Path(__file__).parent.parent / '.env'
load_dotenv(dotenv_path=env_path)
# Global DPI setting - must match training DPI for optimal model performance
DEFAULT_DPI = 150
def _is_wsl() -> bool:
"""Check if running inside WSL (Windows Subsystem for Linux)."""
@@ -69,7 +73,7 @@ else:
# Auto-labeling Configuration
AUTOLABEL = {
'workers': 2,
'dpi': 150,
'dpi': DEFAULT_DPI,
'min_confidence': 0.5,
'train_ratio': 0.8,
'val_ratio': 0.1,

1156
src/data/admin_db.py Normal file

File diff suppressed because it is too large Load Diff

339
src/data/admin_models.py Normal file
View File

@@ -0,0 +1,339 @@
"""
Admin API SQLModel Database Models
Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
# =============================================================================
# CSV to Field Class Mapping
# =============================================================================
CSV_TO_CLASS_MAPPING: dict[str, int] = {
"InvoiceNumber": 0, # invoice_number
"InvoiceDate": 1, # invoice_date
"InvoiceDueDate": 2, # invoice_due_date
"OCR": 3, # ocr_number
"Bankgiro": 4, # bankgiro
"Plusgiro": 5, # plusgiro
"Amount": 6, # amount
"supplier_organisation_number": 7, # supplier_organisation_number
# 8: payment_line (derived from OCR/Bankgiro/Amount)
"customer_number": 9, # customer_number
}
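# Example: CSV_TO_CLASS_MAPPING["Amount"] -> 6, whose canonical name is
# FIELD_CLASSES[6] == "amount" (FIELD_CLASSES is defined later in this module).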
# =============================================================================
# Core Models
# =============================================================================
class AdminToken(SQLModel, table=True):
"""Admin authentication token."""
__tablename__ = "admin_tokens"
token: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
expires_at: datetime | None = Field(default=None)
class AdminDocument(SQLModel, table=True):
"""Document uploaded for labeling/annotation."""
__tablename__ = "admin_documents"
document_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
file_path: str = Field(max_length=512) # Path to stored file
page_count: int = Field(default=1)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, auto_labeling, labeled, exported
auto_label_status: str | None = Field(default=None, max_length=20)
# Auto-label status: running, completed, failed
auto_label_error: str | None = Field(default=None)
# v2: Upload source tracking
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
batch_id: UUID | None = Field(default=None, index=True)
# Link to batch upload (if uploaded via ZIP)
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Original CSV values for reference
auto_label_queued_at: datetime | None = Field(default=None)
# When auto-label was queued
annotation_lock_until: datetime | None = Field(default=None)
# Lock for manual annotation while auto-label runs
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class AdminAnnotation(SQLModel, table=True):
"""Annotation for a document (bounding box + label)."""
__tablename__ = "admin_annotations"
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
page_number: int = Field(default=1) # 1-indexed
class_id: int # 0-9 for invoice fields
class_name: str = Field(max_length=50) # e.g., "invoice_number"
# Bounding box (normalized 0-1 coordinates)
x_center: float
y_center: float
width: float
height: float
# Original pixel coordinates (for display)
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
# OCR extracted text (if available)
text_value: str | None = Field(default=None)
confidence: float | None = Field(default=None)
# Source: manual, auto, imported
source: str = Field(default="manual", max_length=20, index=True)
# v2: Verification fields
is_verified: bool = Field(default=False, index=True)
verified_at: datetime | None = Field(default=None)
verified_by: str | None = Field(default=None, max_length=255)
# v2: Override tracking
override_source: str | None = Field(default=None, max_length=20)
# If this annotation overrides another: 'auto' or 'imported'
original_annotation_id: UUID | None = Field(default=None)
# Reference to the annotation this overrides
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingTask(SQLModel, table=True):
"""Training/fine-tuning task."""
__tablename__ = "training_tasks"
task_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, scheduled, running, completed, failed, cancelled
task_type: str = Field(default="train", max_length=20)
# Task type: train, finetune
# Training configuration
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Schedule settings
scheduled_at: datetime | None = Field(default=None)
cron_expression: str | None = Field(default=None, max_length=50)
is_recurring: bool = Field(default=False)
# Execution details
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
error_message: str | None = Field(default=None)
# Result metrics
result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
model_path: str | None = Field(default=None, max_length=512)
# v2: Document count and extracted metrics
document_count: int = Field(default=0)
# Count of documents used in training
metrics_mAP: float | None = Field(default=None, index=True)
metrics_precision: float | None = Field(default=None)
metrics_recall: float | None = Field(default=None)
# Extracted metrics for easy querying
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingLog(SQLModel, table=True):
"""Training log entry."""
__tablename__ = "training_logs"
log_id: int | None = Field(default=None, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
level: str = Field(max_length=20) # INFO, WARNING, ERROR
message: str
details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# =============================================================================
# Batch Upload Models (v2)
# =============================================================================
class BatchUpload(SQLModel, table=True):
"""Batch upload of multiple documents via ZIP file."""
__tablename__ = "batch_uploads"
batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
filename: str = Field(max_length=255) # ZIP filename
file_size: int
upload_source: str = Field(default="ui", max_length=20)
# Upload source: ui, api
status: str = Field(default="processing", max_length=20, index=True)
# Status: processing, completed, partial, failed
total_files: int = Field(default=0)
processed_files: int = Field(default=0)
# Number of files processed so far
successful_files: int = Field(default=0)
failed_files: int = Field(default=0)
csv_filename: str | None = Field(default=None, max_length=255)
# CSV file used for auto-labeling
csv_row_count: int | None = Field(default=None)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
completed_at: datetime | None = Field(default=None)
class BatchUploadFile(SQLModel, table=True):
"""Individual file within a batch upload."""
__tablename__ = "batch_upload_files"
file_id: UUID = Field(default_factory=uuid4, primary_key=True)
batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
filename: str = Field(max_length=255) # PDF filename within ZIP
document_id: UUID | None = Field(default=None)
# Link to created AdminDocument (if successful)
status: str = Field(default="pending", max_length=20, index=True)
# Status: pending, processing, completed, failed, skipped
error_message: str | None = Field(default=None)
annotation_count: int = Field(default=0)
# Number of annotations created for this file
csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# CSV row data for this file (if available)
created_at: datetime = Field(default_factory=datetime.utcnow)
processed_at: datetime | None = Field(default=None)
# =============================================================================
# Training Document Link (v2)
# =============================================================================
class TrainingDocumentLink(SQLModel, table=True):
"""Junction table linking training tasks to documents."""
__tablename__ = "training_document_links"
link_id: UUID = Field(default_factory=uuid4, primary_key=True)
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Snapshot of annotations at training time (includes count, verified count, etc.)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================
class AnnotationHistory(SQLModel, table=True):
"""History of annotation changes (for override tracking)."""
__tablename__ = "annotation_history"
history_id: UUID = Field(default_factory=uuid4, primary_key=True)
annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
# Change action: created, updated, deleted, override
action: str = Field(max_length=20, index=True)
# Previous value (for updates/deletes)
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# New value (for creates/updates)
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Change metadata
changed_by: str | None = Field(default=None, max_length=255)
# User/token who made the change
change_reason: str | None = Field(default=None)
# Optional reason for change
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Field class mapping (same as src/cli/train.py)
FIELD_CLASSES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_organisation_number",
8: "payment_line",
9: "customer_number",
}
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
# Read-only models for API responses
class AdminDocumentRead(SQLModel):
"""Admin document response model."""
document_id: UUID
filename: str
file_size: int
content_type: str
page_count: int
status: str
auto_label_status: str | None
auto_label_error: str | None
created_at: datetime
updated_at: datetime
class AdminAnnotationRead(SQLModel):
"""Admin annotation response model."""
annotation_id: UUID
document_id: UUID
page_number: int
class_id: int
class_name: str
x_center: float
y_center: float
width: float
height: float
bbox_x: int
bbox_y: int
bbox_width: int
bbox_height: int
text_value: str | None
confidence: float | None
source: str
created_at: datetime
class TrainingTaskRead(SQLModel):
"""Training task response model."""
task_id: UUID
name: str
description: str | None
status: str
task_type: str
config: dict[str, Any] | None
scheduled_at: datetime | None
is_recurring: bool
started_at: datetime | None
completed_at: datetime | None
error_message: str | None
result_metrics: dict[str, Any] | None
model_path: str | None
created_at: datetime
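
For orientation, a minimal sketch of how these models are populated — the pixel-to-normalized conversion mirrors what the annotation routes do; the page size and bbox values here are made-up assumptions:

```python
from uuid import uuid4
from src.data.admin_models import AdminAnnotation, FIELD_CLASS_IDS

# Assumed page: A4 rendered at 150 DPI, roughly 1240 x 1754 px
img_w, img_h = 1240, 1754
x, y, w, h = 100, 200, 300, 40  # pixel bbox drawn in the annotation UI

ann = AdminAnnotation(
    document_id=uuid4(),  # normally the UUID of an existing AdminDocument
    page_number=1,
    class_id=FIELD_CLASS_IDS["amount"],  # -> 6
    class_name="amount",
    # YOLO-style normalized coordinates (0-1)
    x_center=(x + w / 2) / img_w,
    y_center=(y + h / 2) / img_h,
    width=w / img_w,
    height=h / img_h,
    bbox_x=x, bbox_y=y, bbox_width=w, bbox_height=h,
    source="manual",
)
```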

View File

@@ -0,0 +1,374 @@
"""
Async Request Database Operations
Database interface for async invoice processing requests using SQLModel.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select
from src.data.database import get_session_context, create_db_and_tables, close_engine
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent
logger = logging.getLogger(__name__)
# Legacy dataclasses for backward compatibility
from dataclasses import dataclass
@dataclass(frozen=True)
class ApiKeyConfig:
"""API key configuration and limits (legacy compatibility)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestDB:
"""Database interface for async processing requests using SQLModel."""
def __init__(self, connection_string: str | None = None) -> None:
# connection_string is kept for backward compatibility but ignored
# SQLModel uses the global engine from database.py
self._initialized = False
def connect(self):
"""Legacy method - returns self for compatibility."""
return self
def close(self) -> None:
"""Close database connections."""
close_engine()
def __enter__(self) -> "AsyncRequestDB":
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass # Sessions are managed per-operation
def create_tables(self) -> None:
"""Create async processing tables if they don't exist."""
create_db_and_tables()
self._initialized = True
# ==========================================================================
# API Key Operations
# ==========================================================================
def is_valid_api_key(self, api_key: str) -> bool:
"""Check if API key exists and is active."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
return result is not None and result.is_active is True
def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
"""Get API key configuration and limits."""
with get_session_context() as session:
result = session.get(ApiKey, api_key)
if result is None:
return None
return ApiKeyConfig(
api_key=result.api_key,
name=result.name,
is_active=result.is_active,
requests_per_minute=result.requests_per_minute,
max_concurrent_jobs=result.max_concurrent_jobs,
max_file_size_mb=result.max_file_size_mb,
)
def create_api_key(
self,
api_key: str,
name: str,
requests_per_minute: int = 10,
max_concurrent_jobs: int = 3,
max_file_size_mb: int = 50,
) -> None:
"""Create a new API key."""
with get_session_context() as session:
existing = session.get(ApiKey, api_key)
if existing:
existing.name = name
existing.requests_per_minute = requests_per_minute
existing.max_concurrent_jobs = max_concurrent_jobs
existing.max_file_size_mb = max_file_size_mb
session.add(existing)
else:
new_key = ApiKey(
api_key=api_key,
name=name,
requests_per_minute=requests_per_minute,
max_concurrent_jobs=max_concurrent_jobs,
max_file_size_mb=max_file_size_mb,
)
session.add(new_key)
def update_api_key_usage(self, api_key: str) -> None:
"""Update API key last used timestamp and increment total requests."""
with get_session_context() as session:
key = session.get(ApiKey, api_key)
if key:
key.last_used_at = datetime.utcnow()
key.total_requests += 1
session.add(key)
# ==========================================================================
# Async Request Operations
# ==========================================================================
def create_request(
self,
api_key: str,
filename: str,
file_size: int,
content_type: str,
expires_at: datetime,
request_id: str | None = None,
) -> str:
"""Create a new async request."""
with get_session_context() as session:
request = AsyncRequest(
api_key=api_key,
filename=filename,
file_size=file_size,
content_type=content_type,
expires_at=expires_at,
)
if request_id:
request.request_id = UUID(request_id)
session.add(request)
session.flush() # To get the generated ID
return str(request.request_id)
def get_request(self, request_id: str) -> AsyncRequest | None:
"""Get a single async request by ID."""
with get_session_context() as session:
result = session.get(AsyncRequest, UUID(request_id))
if result:
# Detach from session for use outside context
session.expunge(result)
return result
def get_request_by_api_key(
self,
request_id: str,
api_key: str,
) -> AsyncRequest | None:
"""Get a request only if it belongs to the given API key."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.request_id == UUID(request_id),
AsyncRequest.api_key == api_key,
)
result = session.exec(statement).first()
if result:
session.expunge(result)
return result
def update_status(
self,
request_id: str,
status: str,
error_message: str | None = None,
increment_retry: bool = False,
) -> None:
"""Update request status."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = status
if status == "processing":
request.started_at = datetime.utcnow()
if error_message is not None:
request.error_message = error_message
if increment_retry:
request.retry_count += 1
session.add(request)
def complete_request(
self,
request_id: str,
document_id: str,
result: dict[str, Any],
processing_time_ms: float,
visualization_path: str | None = None,
) -> None:
"""Mark request as completed with result."""
with get_session_context() as session:
request = session.get(AsyncRequest, UUID(request_id))
if request:
request.status = "completed"
request.document_id = document_id
request.result = result
request.processing_time_ms = processing_time_ms
request.visualization_path = visualization_path
request.completed_at = datetime.utcnow()
session.add(request)
def get_requests_by_api_key(
self,
api_key: str,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[AsyncRequest], int]:
"""Get paginated requests for an API key."""
with get_session_context() as session:
# Count query
count_stmt = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
count_stmt = count_stmt.where(AsyncRequest.status == status)
total = session.exec(count_stmt).one()
# Fetch query
statement = select(AsyncRequest).where(
AsyncRequest.api_key == api_key
)
if status:
statement = statement.where(AsyncRequest.status == status)
statement = statement.order_by(AsyncRequest.created_at.desc())
statement = statement.offset(offset).limit(limit)
results = session.exec(statement).all()
# Detach results from session
for r in results:
session.expunge(r)
return list(results), total
def count_active_jobs(self, api_key: str) -> int:
"""Count active (pending + processing) jobs for an API key."""
with get_session_context() as session:
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.api_key == api_key,
AsyncRequest.status.in_(["pending", "processing"]),
)
return session.exec(statement).one()
def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
"""Get pending requests ordered by creation time."""
with get_session_context() as session:
statement = select(AsyncRequest).where(
AsyncRequest.status == "pending"
).order_by(AsyncRequest.created_at).limit(limit)
results = session.exec(statement).all()
for r in results:
session.expunge(r)
return list(results)
def get_queue_position(self, request_id: str) -> int | None:
"""Get position of a request in the pending queue."""
with get_session_context() as session:
# Get the request's created_at
request = session.get(AsyncRequest, UUID(request_id))
if not request:
return None
# Count pending requests created before this one
statement = select(func.count()).select_from(AsyncRequest).where(
AsyncRequest.status == "pending",
AsyncRequest.created_at < request.created_at,
)
count = session.exec(statement).one()
return count + 1 # 1-based position
# ==========================================================================
# Rate Limit Operations
# ==========================================================================
def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
"""Record a rate limit event."""
with get_session_context() as session:
event = RateLimitEvent(
api_key=api_key,
event_type=event_type,
)
session.add(event)
def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
"""Count requests in the last N seconds."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(seconds=seconds)
statement = select(func.count()).select_from(RateLimitEvent).where(
RateLimitEvent.api_key == api_key,
RateLimitEvent.event_type == "request",
RateLimitEvent.created_at > cutoff,
)
return session.exec(statement).one()
# ==========================================================================
# Cleanup Operations
# ==========================================================================
def delete_expired_requests(self) -> int:
"""Delete requests that have expired. Returns count of deleted rows."""
with get_session_context() as session:
now = datetime.utcnow()
statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
expired = session.exec(statement).all()
count = len(expired)
for request in expired:
session.delete(request)
logger.info(f"Deleted {count} expired async requests")
return count
def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
"""Delete rate limit events older than N hours."""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(hours=hours)
statement = select(RateLimitEvent).where(
RateLimitEvent.created_at < cutoff
)
old_events = session.exec(statement).all()
count = len(old_events)
for event in old_events:
session.delete(event)
return count
def reset_stale_processing_requests(
self,
stale_minutes: int = 10,
max_retries: int = 3,
) -> int:
"""
Reset requests stuck in 'processing' status.
Requests that have been processing for more than stale_minutes
are considered stale. They are either reset to 'pending' (if under
max_retries) or set to 'failed'.
"""
with get_session_context() as session:
cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
reset_count = 0
# Find stale processing requests
statement = select(AsyncRequest).where(
AsyncRequest.status == "processing",
AsyncRequest.started_at < cutoff,
)
stale_requests = session.exec(statement).all()
for request in stale_requests:
if request.retry_count < max_retries:
request.status = "pending"
request.started_at = None
else:
request.status = "failed"
request.error_message = "Processing timeout after max retries"
session.add(request)
reset_count += 1
if reset_count > 0:
logger.warning(f"Reset {reset_count} stale processing requests")
return reset_count
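
A minimal usage sketch of this interface (the API key value matches the development key seeded by the migration below; the filename and result payload are made up):

```python
from datetime import datetime, timedelta
from src.data.async_db import AsyncRequestDB

db = AsyncRequestDB()
db.create_tables()
db.create_api_key("dev-api-key-12345", "Development Key")

request_id = db.create_request(
    api_key="dev-api-key-12345",
    filename="invoice.pdf",
    file_size=123_456,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)
db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-001",
    result={"invoice_number": "12345"},
    processing_time_ms=842.0,
)

# Periodic maintenance (e.g. from a scheduler)
db.delete_expired_requests()
db.reset_stale_processing_requests(stale_minutes=10)
```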

103
src/data/database.py Normal file
View File

@@ -0,0 +1,103 @@
"""
Database Engine and Session Management
Provides SQLModel database engine and session handling.
"""
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import get_db_connection_string
logger = logging.getLogger(__name__)
# Global engine instance
_engine = None
def get_engine():
"""Get or create the database engine."""
global _engine
if _engine is None:
connection_string = get_db_connection_string()
# Convert psycopg2 format to SQLAlchemy format
if connection_string.startswith("postgresql://"):
# Already in correct format
pass
elif "host=" in connection_string:
# Convert DSN format to URL format
parts = dict(item.split("=", 1) for item in connection_string.split())
connection_string = (
f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
f"/{parts.get('dbname', 'docmaster')}"
)
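# e.g. "host=localhost port=5432 dbname=docmaster user=app password=secret"
#   -> "postgresql://app:secret@localhost:5432/docmaster"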
_engine = create_engine(
connection_string,
echo=False, # Set to True for SQL debugging
pool_pre_ping=True, # Verify connections before use
pool_size=5,
max_overflow=10,
)
return _engine
def create_db_and_tables() -> None:
"""Create all database tables."""
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
from src.data.admin_models import ( # noqa: F401
AdminToken,
AdminDocument,
AdminAnnotation,
TrainingTask,
TrainingLog,
)
engine = get_engine()
SQLModel.metadata.create_all(engine)
logger.info("Database tables created/verified")
def get_session() -> Session:
"""Get a new database session."""
engine = get_engine()
return Session(engine)
@contextmanager
def get_session_context() -> Generator[Session, None, None]:
"""Context manager for database sessions with auto-commit/rollback."""
session = get_session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def close_engine() -> None:
"""Close the database engine and release connections."""
global _engine
if _engine is not None:
_engine.dispose()
_engine = None
logger.info("Database engine closed")
def execute_raw_sql(sql: str) -> None:
"""Execute raw SQL (for migrations)."""
engine = get_engine()
with engine.connect() as conn:
conn.execute(text(sql))
conn.commit()
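
A short sketch of the intended pattern (model names from `src/data/models.py`; commit on clean exit and rollback on exception are handled by the context manager):

```python
from sqlmodel import select
from src.data.database import create_db_and_tables, get_session_context
from src.data.models import ApiKey

create_db_and_tables()

# Writes are committed automatically when the block exits cleanly
with get_session_context() as session:
    session.add(ApiKey(api_key="demo-key", name="Demo"))

with get_session_context() as session:
    active = session.exec(select(ApiKey).where(ApiKey.is_active == True)).all()
    print([k.name for k in active])
```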

View File

@@ -10,7 +10,7 @@ import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
from src.config import get_db_connection_string
class DocumentDB:

View File

@@ -0,0 +1,83 @@
-- Async Invoice Processing Tables
-- Migration: 001_async_tables.sql
-- Created: 2024-01-15
-- API Keys table for authentication and rate limiting
CREATE TABLE IF NOT EXISTS api_keys (
api_key TEXT PRIMARY KEY,
name TEXT NOT NULL,
is_active BOOLEAN DEFAULT true,
-- Rate limits
requests_per_minute INTEGER DEFAULT 10,
max_concurrent_jobs INTEGER DEFAULT 3,
max_file_size_mb INTEGER DEFAULT 50,
-- Usage tracking
total_requests INTEGER DEFAULT 0,
total_processed INTEGER DEFAULT 0,
-- Timestamps
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_used_at TIMESTAMPTZ
);
-- Async processing requests table
CREATE TABLE IF NOT EXISTS async_requests (
request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
status TEXT NOT NULL DEFAULT 'pending',
filename TEXT NOT NULL,
file_size INTEGER NOT NULL,
content_type TEXT NOT NULL,
-- Processing metadata
document_id TEXT,
error_message TEXT,
retry_count INTEGER DEFAULT 0,
-- Timestamps
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
expires_at TIMESTAMPTZ NOT NULL,
-- Result storage (JSONB for flexibility)
result JSONB,
-- Processing time
processing_time_ms REAL,
-- Visualization path
visualization_path TEXT,
CONSTRAINT valid_status CHECK (status IN ('pending', 'processing', 'completed', 'failed'))
);
-- Indexes for async_requests
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key ON async_requests(api_key);
CREATE INDEX IF NOT EXISTS idx_async_requests_status ON async_requests(status);
CREATE INDEX IF NOT EXISTS idx_async_requests_created_at ON async_requests(created_at);
CREATE INDEX IF NOT EXISTS idx_async_requests_expires_at ON async_requests(expires_at);
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key_status ON async_requests(api_key, status);
-- Rate limit tracking table
CREATE TABLE IF NOT EXISTS rate_limit_events (
id SERIAL PRIMARY KEY,
api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
event_type TEXT NOT NULL, -- 'request', 'complete', 'fail'
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Index for rate limiting queries (recent events only)
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_api_key_time
ON rate_limit_events(api_key, created_at DESC);
-- Cleanup old rate limit events index
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_cleanup
ON rate_limit_events(created_at);
-- Insert default API key for development/testing
INSERT INTO api_keys (api_key, name, requests_per_minute, max_concurrent_jobs)
VALUES ('dev-api-key-12345', 'Development Key', 100, 10)
ON CONFLICT (api_key) DO NOTHING;
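
One way to apply this migration is via the `execute_raw_sql` helper from `src/data/database.py` — the file path here is an assumption about where the migration lives:

```python
from pathlib import Path

from src.data.database import execute_raw_sql

# Assumed repo-relative location of the migration file
sql = Path("src/data/migrations/001_async_tables.sql").read_text()
execute_raw_sql(sql)  # psycopg2 accepts multi-statement SQL strings
```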

View File

@@ -0,0 +1,5 @@
-- Migration: Make admin_token nullable in admin_documents table
-- This allows documents uploaded via public API to not require an admin token
ALTER TABLE admin_documents
ALTER COLUMN admin_token DROP NOT NULL;

95
src/data/models.py Normal file
View File

@@ -0,0 +1,95 @@
"""
SQLModel Database Models
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON
class ApiKey(SQLModel, table=True):
"""API key configuration and limits."""
__tablename__ = "api_keys"
api_key: str = Field(primary_key=True, max_length=255)
name: str = Field(max_length=255)
is_active: bool = Field(default=True)
requests_per_minute: int = Field(default=10)
max_concurrent_jobs: int = Field(default=3)
max_file_size_mb: int = Field(default=50)
total_requests: int = Field(default=0)
total_processed: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: datetime | None = Field(default=None)
class AsyncRequest(SQLModel, table=True):
"""Async request record."""
__tablename__ = "async_requests"
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
status: str = Field(default="pending", max_length=20, index=True)
filename: str = Field(max_length=255)
file_size: int
content_type: str = Field(max_length=100)
document_id: str | None = Field(default=None, max_length=100)
error_message: str | None = Field(default=None)
retry_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
expires_at: datetime = Field(index=True)
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
processing_time_ms: float | None = Field(default=None)
visualization_path: str | None = Field(default=None, max_length=255)
class RateLimitEvent(SQLModel, table=True):
"""Rate limit event record."""
__tablename__ = "rate_limit_events"
id: int | None = Field(default=None, primary_key=True)
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
event_type: str = Field(max_length=50)
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
"""API key response model (read-only)."""
api_key: str
name: str
is_active: bool
requests_per_minute: int
max_concurrent_jobs: int
max_file_size_mb: int
class AsyncRequestRead(SQLModel):
"""Async request response model (read-only)."""
request_id: UUID
api_key: str
status: str
filename: str
file_size: int
content_type: str
document_id: str | None
error_message: str | None
retry_count: int
created_at: datetime
started_at: datetime | None
completed_at: datetime | None
expires_at: datetime
result: dict[str, Any] | None
processing_time_ms: float | None
visualization_path: str | None

View File

@@ -12,6 +12,8 @@ import warnings
from pathlib import Path
from typing import Any, Dict, Optional
from src.config import DEFAULT_DPI
# Global OCR instance (initialized once per GPU worker process)
_ocr_engine: Optional[Any] = None
@@ -94,7 +96,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", 150)
dpi = task_data.get("dpi", DEFAULT_DPI)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()
@@ -212,7 +214,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", 150)
dpi = task_data.get("dpi", DEFAULT_DPI)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()

View File

@@ -16,6 +16,8 @@ from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values
from src.config import DEFAULT_DPI
@dataclass
class LLMExtractionResult:
@@ -265,7 +267,7 @@ Return ONLY the JSON object, no other text."""
self,
pdf_path: Path,
page_no: int = 0,
dpi: int = 150,
dpi: int = DEFAULT_DPI,
max_size_mb: float = 18.0
) -> bytes:
"""

View File

@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py
DEPRECATED: Import from src.web.api.v1.admin.documents instead.
"""
from src.web.api.v1.admin.documents import *
__all__ = ["create_admin_router"]

0
src/web/api/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,19 @@
"""
Admin API v1
Document management, annotations, and training endpoints.
"""
from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.api.v1.admin.auth import create_auth_router
from src.web.api.v1.admin.documents import create_documents_router
from src.web.api.v1.admin.locks import create_locks_router
from src.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
"create_training_router",
]

View File

@@ -0,0 +1,644 @@
"""
Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from src.data.admin_db import AdminDB
from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.services.autolabel import get_auto_label_service
from src.web.schemas.admin import (
AnnotationCreate,
AnnotationItem,
AnnotationListResponse,
AnnotationOverrideRequest,
AnnotationOverrideResponse,
AnnotationResponse,
AnnotationSource,
AnnotationUpdate,
AnnotationVerifyRequest,
AnnotationVerifyResponse,
AutoLabelRequest,
AutoLabelResponse,
BoundingBox,
)
from src.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_annotation_router() -> APIRouter:
"""Create annotation API router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
# =========================================================================
# Image Endpoints
# =========================================================================
@router.get(
"/{document_id}/images/{page_number}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Get page image",
description="Get the image for a specific page.",
)
async def get_page_image(
document_id: str,
page_number: int,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> FileResponse:
"""Get page image."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if page_number < 1 or page_number > document.page_count:
raise HTTPException(
status_code=404,
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
)
# Find image file
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=404,
detail=f"Image for page {page_number} not found",
)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=f"{document.filename}_page_{page_number}.png",
)
# =========================================================================
# Annotation Endpoints
# =========================================================================
@router.get(
"/{document_id}/annotations",
response_model=AnnotationListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="List annotations",
description="Get all annotations for a document.",
)
async def list_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
page_number: Annotated[
int | None,
Query(ge=1, description="Filter by page number"),
] = None,
) -> AnnotationListResponse:
"""List annotations for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id, page_number)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
return AnnotationListResponse(
document_id=document_id,
page_count=document.page_count,
total_annotations=len(annotations),
annotations=annotations,
)
@router.post(
"/{document_id}/annotations",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Create annotation",
description="Create a new annotation for a document.",
)
async def create_annotation(
document_id: str,
request: AnnotationCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Create a new annotation."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate page number
if request.page_number > document.page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
)
# Get image dimensions for normalization
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
if not image_path.exists():
raise HTTPException(
status_code=400,
detail=f"Image for page {request.page_number} not available",
)
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
width = request.bbox.width / image_width
height = request.bbox.height / image_height
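# Worked example (hypothetical 1240x1754 px page): a 300x40 px box at (100, 200)
# gives x_center = (100 + 150) / 1240 ≈ 0.2016, y_center = (200 + 20) / 1754 ≈ 0.1254,
# width = 300 / 1240 ≈ 0.2419, height = 40 / 1754 ≈ 0.0228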
# Get class name
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
# Create annotation
annotation_id = db.create_annotation(
document_id=document_id,
page_number=request.page_number,
class_id=request.class_id,
class_name=class_name,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
bbox_x=request.bbox.x,
bbox_y=request.bbox.y,
bbox_width=request.bbox.width,
bbox_height=request.bbox.height,
text_value=request.text_value,
source="manual",
)
# Keep status as pending - user must click "Mark Complete" to finalize
# This allows user to add multiple annotations before saving to PostgreSQL
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation created successfully",
)
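# Worked example of the normalization above (hypothetical numbers): for an A4 page
# rendered at 150 DPI (1240x1754 px) and a bbox of x=100, y=200, width=300, height=50:
#   x_center = (100 + 300 / 2) / 1240  ~= 0.2016
#   y_center = (200 + 50 / 2) / 1754   ~= 0.1283
#   width    = 300 / 1240              ~= 0.2419
#   height   = 50 / 1754               ~= 0.0285
# These are the YOLO-normalized values stored alongside the raw pixel bbox.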
@router.patch(
"/{document_id}/annotations/{annotation_id}",
response_model=AnnotationResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Update annotation",
description="Update an existing annotation.",
)
async def update_annotation(
document_id: str,
annotation_id: str,
request: AnnotationUpdate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationResponse:
"""Update an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Prepare update data
update_kwargs = {}
if request.class_id is not None:
update_kwargs["class_id"] = request.class_id
update_kwargs["class_name"] = FIELD_CLASSES.get(
request.class_id, f"class_{request.class_id}"
)
if request.text_value is not None:
update_kwargs["text_value"] = request.text_value
if request.bbox is not None:
# Get image dimensions
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
from PIL import Image
with Image.open(image_path) as img:
image_width, image_height = img.size
# Calculate normalized coordinates
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
update_kwargs["width"] = request.bbox.width / image_width
update_kwargs["height"] = request.bbox.height / image_height
update_kwargs["bbox_x"] = request.bbox.x
update_kwargs["bbox_y"] = request.bbox.y
update_kwargs["bbox_width"] = request.bbox.width
update_kwargs["bbox_height"] = request.bbox.height
# Update annotation
if update_kwargs:
success = db.update_annotation(annotation_id, **update_kwargs)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to update annotation",
)
return AnnotationResponse(
annotation_id=annotation_id,
message="Annotation updated successfully",
)
@router.delete(
"/{document_id}/annotations/{annotation_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Not found"},
},
summary="Delete annotation",
description="Delete an annotation.",
)
async def delete_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get existing annotation
annotation = db.get_annotation(annotation_id)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Verify annotation belongs to document
if str(annotation.document_id) != document_id:
raise HTTPException(
status_code=404,
detail="Annotation does not belong to this document",
)
# Delete annotation
db.delete_annotation(annotation_id)
return {
"status": "deleted",
"annotation_id": annotation_id,
"message": "Annotation deleted successfully",
}
# =========================================================================
# Auto-Labeling Endpoints
# =========================================================================
@router.post(
"/{document_id}/auto-label",
response_model=AutoLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Trigger auto-labeling",
description="Trigger auto-labeling for a document using field values.",
)
async def trigger_auto_label(
document_id: str,
request: AutoLabelRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AutoLabelResponse:
"""Trigger auto-labeling for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Validate field values
if not request.field_values:
raise HTTPException(
status_code=400,
detail="At least one field value is required",
)
# Run auto-labeling
service = get_auto_label_service()
result = service.auto_label_document(
document_id=document_id,
file_path=document.file_path,
field_values=request.field_values,
db=db,
replace_existing=request.replace_existing,
)
if result["status"] == "failed":
raise HTTPException(
status_code=500,
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
)
return AutoLabelResponse(
document_id=document_id,
status=result["status"],
annotations_created=result["annotations_created"],
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
)
@router.delete(
"/{document_id}/annotations",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete all annotations",
description="Delete all annotations for a document (optionally filter by source).",
)
async def delete_all_annotations(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
source: Annotated[
str | None,
Query(description="Filter by source (manual, auto, imported)"),
] = None,
) -> dict:
"""Delete all annotations for a document."""
_validate_uuid(document_id, "document_id")
# Validate source
if source and source not in ("manual", "auto", "imported"):
raise HTTPException(
status_code=400,
detail=f"Invalid source: {source}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete annotations
deleted_count = db.delete_annotations_for_document(document_id, source)
# Update document status if all annotations deleted
remaining = db.get_annotations_for_document(document_id)
if not remaining:
db.update_document_status(document_id, "pending")
return {
"status": "deleted",
"document_id": document_id,
"deleted_count": deleted_count,
"message": f"Deleted {deleted_count} annotations",
}
# =========================================================================
# Phase 5: Annotation Enhancement
# =========================================================================
@router.post(
"/{document_id}/annotations/{annotation_id}/verify",
response_model=AnnotationVerifyResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Verify annotation",
description="Mark an annotation as verified by a human reviewer.",
)
async def verify_annotation(
document_id: str,
annotation_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
) -> AnnotationVerifyResponse:
"""Verify an annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Verify the annotation
annotation = db.verify_annotation(annotation_id, admin_token)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
return AnnotationVerifyResponse(
annotation_id=annotation_id,
is_verified=annotation.is_verified,
verified_at=annotation.verified_at,
verified_by=annotation.verified_by,
message="Annotation verified successfully",
)
@router.patch(
"/{document_id}/annotations/{annotation_id}/override",
response_model=AnnotationOverrideResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Annotation not found"},
},
summary="Override annotation",
description="Override an auto-generated annotation with manual corrections.",
)
async def override_annotation(
document_id: str,
annotation_id: str,
request: AnnotationOverrideRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> AnnotationOverrideResponse:
"""Override an auto-generated annotation."""
_validate_uuid(document_id, "document_id")
_validate_uuid(annotation_id, "annotation_id")
# Verify ownership of document
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Build updates dict from request
updates = {}
if request.text_value is not None:
updates["text_value"] = request.text_value
if request.class_id is not None:
updates["class_id"] = request.class_id
# Update class_name if class_id changed
if request.class_id in FIELD_CLASSES:
updates["class_name"] = FIELD_CLASSES[request.class_id]
if request.class_name is not None:
updates["class_name"] = request.class_name
if request.bbox:
# Update bbox fields
if "x" in request.bbox:
updates["bbox_x"] = request.bbox["x"]
if "y" in request.bbox:
updates["bbox_y"] = request.bbox["y"]
if "width" in request.bbox:
updates["bbox_width"] = request.bbox["width"]
if "height" in request.bbox:
updates["bbox_height"] = request.bbox["height"]
if not updates:
raise HTTPException(
status_code=400,
detail="No updates provided. Specify at least one field to update.",
)
# Override the annotation
annotation = db.override_annotation(
annotation_id=annotation_id,
admin_token=admin_token,
change_reason=request.reason,
**updates,
)
if annotation is None:
raise HTTPException(
status_code=404,
detail="Annotation not found",
)
# Get history to return history_id
history_records = db.get_annotation_history(UUID(annotation_id))
latest_history = history_records[0] if history_records else None
return AnnotationOverrideResponse(
annotation_id=annotation_id,
source=annotation.source,
override_source=annotation.override_source,
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
message="Annotation overridden successfully",
history_id=str(latest_history.history_id) if latest_history else "",
)
return router

View File

@@ -0,0 +1,82 @@
"""
Admin Auth Routes
FastAPI endpoints for admin token management.
"""
import logging
import secrets
from datetime import datetime, timedelta
from fastapi import APIRouter
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
AdminTokenCreate,
AdminTokenResponse,
)
from src.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def create_auth_router() -> APIRouter:
"""Create admin auth router."""
router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])
@router.post(
"/token",
response_model=AdminTokenResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
},
summary="Create admin token",
description="Create a new admin authentication token.",
)
async def create_token(
request: AdminTokenCreate,
db: AdminDBDep,
) -> AdminTokenResponse:
"""Create a new admin token."""
# Generate secure token
token = secrets.token_urlsafe(32)
# Calculate expiration
expires_at = None
if request.expires_in_days:
expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)
# Create token in database
db.create_admin_token(
token=token,
name=request.name,
expires_at=expires_at,
)
return AdminTokenResponse(
token=token,
name=request.name,
expires_at=expires_at,
message="Admin token created successfully",
)
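# Example exchange (hypothetical values):
#   POST /api/v1/admin/auth/token  {"name": "ci-bot", "expires_in_days": 30}
#   -> {"token": "<43-char urlsafe string>", "name": "ci-bot", "expires_at": "...", "message": "..."}
# secrets.token_urlsafe(32) encodes 32 random bytes as roughly 43 base64url characters.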
@router.delete(
"/token",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Revoke admin token",
description="Revoke the current admin token.",
)
async def revoke_token(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Revoke the current admin token."""
db.deactivate_admin_token(admin_token)
return {
"status": "revoked",
"message": "Admin token has been revoked",
}
return router

View File

@@ -0,0 +1,551 @@
"""
Admin Document Routes
FastAPI endpoints for admin document management.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from src.web.config import DEFAULT_DPI, StorageConfig
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
AnnotationItem,
AnnotationSource,
AutoLabelStatus,
BoundingBox,
DocumentDetailResponse,
DocumentItem,
DocumentListResponse,
DocumentStatus,
DocumentStatsResponse,
DocumentUploadResponse,
ModelMetrics,
TrainingHistoryItem,
)
from src.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
# Render at configured DPI for consistency with training
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
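# Sketch of the scaling math (assuming A4 input): PDF points are 1/72 inch, so the
# dpi/72 matrix at 150 DPI scales a 595x842 pt page to roughly 1240x1754 px, matching
# the image dimensions later used to normalize annotation bboxes.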
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
"""Create admin documents router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
# Directories are created by StorageConfig.__post_init__
allowed_extensions = storage_config.allowed_extensions
@router.post(
"",
response_model=DocumentUploadResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Upload document",
description="Upload a PDF or image document for labeling.",
)
async def upload_document(
admin_token: AdminTokenDep,
db: AdminDBDep,
file: UploadFile = File(..., description="PDF or image file"),
auto_label: Annotated[
bool,
Query(description="Trigger auto-labeling after upload"),
] = True,
) -> DocumentUploadResponse:
"""Upload a document for labeling."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Get page count (for PDF)
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record (token only used for auth, not stored)
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(status_code=500, detail="Failed to save file")
# Update file path in database
from src.data.database import get_session_context
from src.data.admin_models import AdminDocument
with get_session_context() as session:
doc = session.get(AdminDocument, UUID(document_id))
if doc:
doc.file_path = str(file_path)
session.add(doc)
# Convert PDF to images for annotation
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling if requested
auto_label_started = False
if auto_label:
# Auto-labeling will be triggered by a background task
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="running",
)
auto_label_started = True
return DocumentUploadResponse(
document_id=document_id,
filename=file.filename,
file_size=len(content),
page_count=page_count,
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
auto_label_started=auto_label_started,
message="Document uploaded successfully",
)
@router.get(
"",
response_model=DocumentListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List documents",
description="List all documents for the current admin.",
)
async def list_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
upload_source: Annotated[
str | None,
Query(description="Filter by upload source (ui or api)"),
] = None,
has_annotations: Annotated[
bool | None,
Query(description="Filter by annotation presence"),
] = None,
auto_label_status: Annotated[
str | None,
Query(description="Filter by auto-label status"),
] = None,
batch_id: Annotated[
str | None,
Query(description="Filter by batch ID"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> DocumentListResponse:
"""List documents."""
# Validate status
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Validate upload_source
if upload_source and upload_source not in ("ui", "api"):
raise HTTPException(
status_code=400,
detail=f"Invalid upload_source: {upload_source}",
)
# Validate auto_label_status
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid auto_label_status: {auto_label_status}",
)
documents, total = db.get_documents_by_token(
admin_token=admin_token,
status=status,
upload_source=upload_source,
has_annotations=has_annotations,
auto_label_status=auto_label_status,
batch_id=batch_id,
limit=limit,
offset=offset,
)
# Get annotation counts and build items
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# Determine if document can be annotated (not locked)
can_annotate = True
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
from datetime import datetime, timezone
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
items.append(
DocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
file_size=doc.file_size,
page_count=doc.page_count,
status=DocumentStatus(doc.status),
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
annotation_count=len(annotations),
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
can_annotate=can_annotate,
created_at=doc.created_at,
updated_at=doc.updated_at,
)
)
return DocumentListResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/stats",
response_model=DocumentStatsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get document statistics",
description="Get document count by status.",
)
async def get_document_stats(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentStatsResponse:
"""Get document statistics."""
counts = db.count_documents_by_status(admin_token)
return DocumentStatsResponse(
total=sum(counts.values()),
pending=counts.get("pending", 0),
auto_labeling=counts.get("auto_labeling", 0),
labeled=counts.get("labeled", 0),
exported=counts.get("exported", 0),
)
@router.get(
"/{document_id}",
response_model=DocumentDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Get document detail",
description="Get document details with annotations.",
)
async def get_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentDetailResponse:
"""Get document details."""
_validate_uuid(document_id, "document_id")
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Get annotations
raw_annotations = db.get_annotations_for_document(document_id)
annotations = [
AnnotationItem(
annotation_id=str(ann.annotation_id),
page_number=ann.page_number,
class_id=ann.class_id,
class_name=ann.class_name,
bbox=BoundingBox(
x=ann.bbox_x,
y=ann.bbox_y,
width=ann.bbox_width,
height=ann.bbox_height,
),
normalized_bbox={
"x_center": ann.x_center,
"y_center": ann.y_center,
"width": ann.width,
"height": ann.height,
},
text_value=ann.text_value,
confidence=ann.confidence,
source=AnnotationSource(ann.source),
created_at=ann.created_at,
)
for ann in raw_annotations
]
# Generate image URLs
image_urls = []
for page in range(1, document.page_count + 1):
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
# Determine if document can be annotated (not locked)
can_annotate = True
annotation_lock_until = None
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
from datetime import datetime, timezone
annotation_lock_until = document.annotation_lock_until
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
# Get CSV field values if available
csv_field_values = None
if hasattr(document, 'csv_field_values') and document.csv_field_values:
csv_field_values = document.csv_field_values
# Get training history (Phase 5)
training_history = []
training_links = db.get_document_training_tasks(document.document_id)
for link in training_links:
# Get task details
task = db.get_training_task(str(link.task_id))
if task:
# Build metrics
metrics = None
if any(v is not None for v in (task.metrics_mAP, task.metrics_precision, task.metrics_recall)):
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
training_history.append(
TrainingHistoryItem(
task_id=str(link.task_id),
name=task.name,
trained_at=link.created_at,
model_metrics=metrics,
)
)
return DocumentDetailResponse(
document_id=str(document.document_id),
filename=document.filename,
file_size=document.file_size,
content_type=document.content_type,
page_count=document.page_count,
status=DocumentStatus(document.status),
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
auto_label_error=document.auto_label_error,
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
csv_field_values=csv_field_values,
can_annotate=can_annotate,
annotation_lock_until=annotation_lock_until,
annotations=annotations,
image_urls=image_urls,
training_history=training_history,
created_at=document.created_at,
updated_at=document.updated_at,
)
@router.delete(
"/{document_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Delete document",
description="Delete a document and its annotations.",
)
async def delete_document(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Delete file
file_path = Path(document.file_path)
if file_path.exists():
file_path.unlink()
# Delete images
images_dir = storage_config.admin_images_dir / document_id
if images_dir.exists():
import shutil
shutil.rmtree(images_dir)
# Delete from database
db.delete_document(document_id)
return {
"status": "deleted",
"document_id": document_id,
"message": "Document deleted successfully",
}
@router.patch(
"/{document_id}/status",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document status",
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
)
async def update_document_status(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str,
Query(description="New status"),
],
) -> dict:
"""Update document status.
When status is set to 'labeled', the annotations are automatically
saved to PostgreSQL documents/field_results tables for consistency
with CLI auto-label workflow.
"""
_validate_uuid(document_id, "document_id")
# Validate status
if status not in ("pending", "labeled", "exported"):
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
)
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# If marking as labeled, save annotations to PostgreSQL DocumentDB
db_save_result = None
if status == "labeled":
from src.web.services.db_autolabel import save_manual_annotations_to_document_db
# Get all annotations for this document
annotations = db.get_annotations_for_document(document_id)
if annotations:
db_save_result = save_manual_annotations_to_document_db(
document=document,
annotations=annotations,
db=db,
)
db.update_document_status(document_id, status)
response = {
"status": "updated",
"document_id": document_id,
"new_status": status,
"message": "Document status updated",
}
# Include PostgreSQL save result if applicable
if db_save_result:
response["document_db_saved"] = db_save_result.get("success", False)
response["fields_saved"] = db_save_result.get("fields_saved", 0)
return response
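# Example (hypothetical values):
#   PATCH /api/v1/admin/documents/<id>/status?status=labeled
#   -> {"status": "updated", "new_status": "labeled", "document_db_saved": true, "fields_saved": 7}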
return router

View File

@@ -0,0 +1,184 @@
"""
Admin Document Lock Routes
FastAPI endpoints for annotation lock management.
"""
import logging
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
AnnotationLockRequest,
AnnotationLockResponse,
)
from src.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_locks_router() -> APIRouter:
"""Create annotation locks router."""
router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])
@router.post(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Document already locked"},
},
summary="Acquire annotation lock",
description="Acquire a lock on a document to prevent concurrent annotation edits.",
)
async def acquire_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Acquire annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to acquire lock
updated_doc = db.acquire_annotation_lock(
document_id=document_id,
admin_token=admin_token,
duration_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Document is already locked. Please try again later.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock acquired for {request.duration_seconds} seconds",
)
@router.delete(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Release annotation lock",
description="Release the annotation lock on a document.",
)
async def release_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
force: Annotated[
bool,
Query(description="Force release (admin override)"),
] = False,
) -> AnnotationLockResponse:
"""Release annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Release lock
updated_doc = db.release_annotation_lock(
document_id=document_id,
admin_token=admin_token,
force=force,
)
if updated_doc is None:
raise HTTPException(
status_code=404,
detail="Failed to release lock",
)
return AnnotationLockResponse(
document_id=document_id,
locked=False,
lock_expires_at=None,
message="Lock released successfully",
)
@router.patch(
"/{document_id}/lock",
response_model=AnnotationLockResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
},
summary="Extend annotation lock",
description="Extend an existing annotation lock.",
)
async def extend_lock(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: AnnotationLockRequest = AnnotationLockRequest(),
) -> AnnotationLockResponse:
"""Extend annotation lock for a document."""
_validate_uuid(document_id, "document_id")
# Verify ownership
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Attempt to extend lock
updated_doc = db.extend_annotation_lock(
document_id=document_id,
admin_token=admin_token,
additional_seconds=request.duration_seconds,
)
if updated_doc is None:
raise HTTPException(
status_code=409,
detail="Lock doesn't exist or has expired. Please acquire a new lock.",
)
return AnnotationLockResponse(
document_id=document_id,
locked=True,
lock_expires_at=updated_doc.annotation_lock_until,
message=f"Lock extended by {request.duration_seconds} seconds",
)
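# Typical lock lifecycle against the three endpoints above (hypothetical):
#   POST   /admin/documents/<id>/lock   -> acquire for duration_seconds
#   PATCH  /admin/documents/<id>/lock   -> extend while still editing
#   DELETE /admin/documents/<id>/lock   -> release when done (force=true for admin override)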
return router

View File

@@ -0,0 +1,622 @@
"""
Admin Training API Routes
FastAPI endpoints for training task management and scheduling.
"""
import logging
from datetime import datetime
from typing import Annotated, Any
from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from src.data.admin_db import AdminDB
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
ExportRequest,
ExportResponse,
ModelMetrics,
TrainingConfig,
TrainingDocumentItem,
TrainingDocumentsResponse,
TrainingHistoryItem,
TrainingLogItem,
TrainingLogsResponse,
TrainingModelItem,
TrainingModelsResponse,
TrainingStatus,
TrainingTaskCreate,
TrainingTaskDetailResponse,
TrainingTaskItem,
TrainingTaskListResponse,
TrainingTaskResponse,
TrainingType,
)
from src.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
def create_training_router() -> APIRouter:
"""Create training API router."""
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
# =========================================================================
# Training Task Endpoints
# =========================================================================
@router.post(
"/tasks",
response_model=TrainingTaskResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Create training task",
description="Create a new training task.",
)
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
# Convert config to dict
config_dict = request.config.model_dump() if request.config else {}
# Create task
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
description=request.description,
config=config_dict,
scheduled_at=request.scheduled_at,
cron_expression=request.cron_expression,
is_recurring=bool(request.cron_expression),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
message="Training task created successfully",
)
@router.get(
"/tasks",
response_model=TrainingTaskListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List training tasks",
description="List all training tasks.",
)
async def list_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingTaskListResponse:
"""List training tasks."""
# Validate status
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
if status and status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status,
limit=limit,
offset=offset,
)
items = [
TrainingTaskItem(
task_id=str(task.task_id),
name=task.name,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
scheduled_at=task.scheduled_at,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in tasks
]
return TrainingTaskListResponse(
total=total,
limit=limit,
offset=offset,
tasks=items,
)
@router.get(
"/tasks/{task_id}",
response_model=TrainingTaskDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training task detail",
description="Get training task details.",
)
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
return TrainingTaskDetailResponse(
task_id=str(task.task_id),
name=task.name,
description=task.description,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
config=task.config,
scheduled_at=task.scheduled_at,
cron_expression=task.cron_expression,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
error_message=task.error_message,
result_metrics=task.result_metrics,
model_path=task.model_path,
created_at=task.created_at,
)
@router.post(
"/tasks/{task_id}/cancel",
response_model=TrainingTaskResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
},
summary="Cancel training task",
description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
# Verify ownership
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
# Check if can be cancelled
if task.status not in ("pending", "scheduled"):
raise HTTPException(
status_code=409,
detail=f"Cannot cancel task with status: {task.status}",
)
# Cancel task
success = db.cancel_training_task(task_id)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to cancel training task",
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.CANCELLED,
message="Training task cancelled successfully",
)
@router.get(
"/tasks/{task_id}/logs",
response_model=TrainingLogsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training logs",
description="Get training task logs.",
)
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingLogsResponse:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
# Verify ownership
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
# Get logs
logs = db.get_training_logs(task_id, limit, offset)
items = [
TrainingLogItem(
level=log.level,
message=log.message,
details=log.details,
created_at=log.created_at,
)
for log in logs
]
return TrainingLogsResponse(
task_id=task_id,
logs=items,
)
# =========================================================================
# Phase 4: Training Data Management
# =========================================================================
@router.get(
"/documents",
response_model=TrainingDocumentsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get documents for training",
description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
] = True,
min_annotation_count: Annotated[
int | None,
Query(ge=1, description="Minimum annotation count"),
] = None,
exclude_used_in_training: Annotated[
bool,
Query(description="Exclude documents already used in training"),
] = False,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
# Get documents
documents, total = db.get_documents_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
min_annotation_count=min_annotation_count,
exclude_used_in_training=exclude_used_in_training,
limit=limit,
offset=offset,
)
# Build response items with annotation details and training history
items = []
for doc in documents:
# Get annotations for this document
annotations = db.get_annotations_for_document(str(doc.document_id))
# Count annotations by source
sources = {"manual": 0, "auto": 0}
for ann in annotations:
if ann.source in sources:
sources[ann.source] += 1
# Get training history
training_links = db.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
)
)
return TrainingDocumentsResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/models/{task_id}/download",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Model not found"},
},
summary="Download trained model",
description="Download trained model weights file.",
)
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
from pathlib import Path
_validate_uuid(task_id, "task_id")
# Verify ownership
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
# Check if model exists
if not task.model_path:
raise HTTPException(
status_code=404,
detail="Model file not available for this task",
)
model_path = Path(task.model_path)
if not model_path.exists():
raise HTTPException(
status_code=404,
detail="Model file not found on disk",
)
return FileResponse(
path=str(model_path),
media_type="application/octet-stream",
filename=f"{task.name}_model.pt",
)
@router.get(
"/models",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
)
async def get_training_models(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
# Get training tasks
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
offset=offset,
)
# Build response items
items = []
for task in tasks:
# Build metrics
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
# Build download URL if model exists
download_url = None
if task.model_path and task.status == "completed":
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
items.append(
TrainingModelItem(
task_id=str(task.task_id),
name=task.name,
status=TrainingStatus(task.status),
document_count=task.document_count,
created_at=task.created_at,
completed_at=task.completed_at,
metrics=metrics,
model_path=task.model_path,
download_url=download_url,
)
)
return TrainingModelsResponse(
total=total,
limit=limit,
offset=offset,
models=items,
)
# =========================================================================
# Export Endpoints
# =========================================================================
@router.post(
"/export",
response_model=ExportResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Export annotations",
description="Export annotations in YOLO format for training.",
)
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
# Validate format (only YOLO export is currently implemented)
if request.format != "yolo":
raise HTTPException(
status_code=400,
detail=f"Unsupported export format: {request.format}. Only 'yolo' is currently supported.",
)
# Get labeled documents
documents = db.get_labeled_documents_for_export(admin_token)
if not documents:
raise HTTPException(
status_code=400,
detail="No labeled documents available for export",
)
# Create export directory
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
# Calculate train/val split
total_docs = len(documents)
train_count = int(total_docs * request.split_ratio)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
# Get annotations
annotations = db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
# Export each page
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue
# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
# Write YOLO label file
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
# YOLO format: class_id x_center y_center width height
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
# Create data.yaml
from src.data.admin_models import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
(export_dir / "data.yaml").write_text(yaml_content)
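# Example of an emitted label line (hypothetical values):
#   "2 0.512097 0.128283 0.241935 0.028506"
# i.e. class_id followed by the stored x_center/y_center/width/height, all in [0, 1].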
return ExportResponse(
status="completed",
export_path=str(export_dir),
total_images=total_images,
total_annotations=total_annotations,
train_count=len(train_docs),
val_count=len(val_docs),
message=f"Exported {total_images} images with {total_annotations} annotations",
)
return router

View File

@@ -0,0 +1,236 @@
"""
Batch Upload API Routes
Endpoints for batch uploading documents via ZIP files with CSV metadata.
"""
import io
import logging
import zipfile
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse
from src.data.admin_db import AdminDB
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from src.web.workers.batch_queue import BatchTask, get_batch_queue
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
@router.post("/upload")
async def upload_batch(
file: UploadFile = File(...),
upload_source: str = Form(default="ui"),
async_mode: bool = Form(default=True),
auto_label: bool = Form(default=True),
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Upload a batch of documents via ZIP file.
The ZIP file can contain:
- Multiple PDF files
- Optional CSV file with field values for auto-labeling
CSV format:
- Required column: DocumentId (matches PDF filename without extension)
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
Args:
file: ZIP file upload
upload_source: Upload source (ui or api)
async_mode: If True (default), queue for background processing and return 202; if False, process synchronously
auto_label: Trigger auto-labeling after upload
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch upload result with batch_id and status
"""
if not file.filename or not file.filename.lower().endswith('.zip'):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Check compressed size
if file.size and file.size > MAX_COMPRESSED_SIZE:
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
raise HTTPException(
status_code=400,
detail=f"File size exceeds {max_mb:.0f}MB limit"
)
try:
# Read file content
zip_content = await file.read()
# Additional security validation before processing
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
# Quick validation of ZIP structure; testzip() returns the
# name of the first corrupt member, or None if all CRCs check out
if test_zip.testzip() is not None:
raise HTTPException(status_code=400, detail="ZIP archive contains a corrupt file")
except zipfile.BadZipFile:
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
if async_mode:
# Async mode: Queue task and return immediately
from uuid import uuid4
batch_id = uuid4()
# Create batch task for background processing
task = BatchTask(
batch_id=batch_id,
admin_token=admin_token,
zip_content=zip_content,
zip_filename=file.filename,
upload_source=upload_source,
auto_label=auto_label,
created_at=datetime.utcnow(),
)
# Submit to queue
queue = get_batch_queue()
if not queue.submit(task):
raise HTTPException(
status_code=503,
detail="Processing queue is full. Please try again later."
)
logger.info(
f"Batch upload queued: batch_id={batch_id}, "
f"filename={file.filename}, async_mode=True"
)
# Return 202 Accepted with batch_id and status URL
return JSONResponse(
status_code=202,
content={
"status": "accepted",
"batch_id": str(batch_id),
"message": "Batch upload queued for processing",
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
"queue_depth": queue.get_queue_depth(),
}
)
else:
# Sync mode: Process immediately and return results
service = BatchUploadService(admin_db)
result = service.process_zip_upload(
admin_token=admin_token,
zip_filename=file.filename,
zip_content=zip_content,
upload_source=upload_source,
)
logger.info(
f"Batch upload completed: batch_id={result.get('batch_id')}, "
f"status={result.get('status')}, files={result.get('successful_files')}"
)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing batch upload: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail="Failed to process batch upload. Please contact support."
)
@router.get("/status/{batch_id}")
async def get_batch_status(
batch_id: str,
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
"""Get batch upload status and file processing details.
Args:
batch_id: Batch upload ID
admin_token: Admin authentication token
admin_db: Admin database interface
Returns:
Batch status with file processing details
"""
# Validate UUID format
try:
batch_uuid = UUID(batch_id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid batch ID format")
# Check batch exists and verify ownership
batch = admin_db.get_batch_upload(batch_uuid)
if not batch:
raise HTTPException(status_code=404, detail="Batch not found")
# CRITICAL: Verify ownership
if batch.admin_token != admin_token:
raise HTTPException(
status_code=403,
detail="You do not have access to this batch"
)
# Now safe to return details
service = BatchUploadService(admin_db)
result = service.get_batch_status(batch_id)
return result
@router.get("/list")
async def list_batch_uploads(
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
limit: int = 50,
offset: int = 0,
) -> dict:
"""List batch uploads for the current admin token.
Args:
admin_token: Admin authentication token
admin_db: Admin database interface
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of batch uploads
"""
# Validate pagination parameters
if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
# Get batch uploads filtered by admin token
batches, total = admin_db.get_batch_uploads_by_token(
admin_token=admin_token,
limit=limit,
offset=offset,
)
return {
"batches": [
{
"batch_id": str(b.batch_id),
"filename": b.filename,
"status": b.status,
"total_files": b.total_files,
"successful_files": b.successful_files,
"failed_files": b.failed_files,
"created_at": b.created_at.isoformat() if b.created_at else None,
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
}
for b in batches
],
"total": total,
"limit": limit,
"offset": offset,
}

View File

@@ -0,0 +1,16 @@
"""
Public API v1
Customer-facing endpoints for inference, async processing, and labeling.
"""
from src.web.api.v1.public.inference import create_inference_router
from src.web.api.v1.public.async_api import create_async_router, set_async_service
from src.web.api.v1.public.labeling import create_labeling_router
__all__ = [
"create_inference_router",
"create_async_router",
"set_async_service",
"create_labeling_router",
]

View File

@@ -0,0 +1,372 @@
"""
Async API Routes
FastAPI endpoints for async invoice processing.
"""
import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from src.web.dependencies import (
ApiKeyDep,
AsyncDBDep,
PollRateLimitDep,
SubmitRateLimitDep,
)
from src.web.schemas.inference import (
AsyncRequestItem,
AsyncRequestsListResponse,
AsyncResultResponse,
AsyncStatus,
AsyncStatusResponse,
AsyncSubmitResponse,
DetectionResult,
InferenceResult,
)
from src.web.schemas.common import ErrorResponse
def _validate_request_id(request_id: str) -> None:
"""Validate that request_id is a valid UUID format."""
try:
UUID(request_id)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid request ID format. Must be a valid UUID.",
)
logger = logging.getLogger(__name__)
# Global reference to async processing service (set during app startup)
_async_service = None
def set_async_service(service) -> None:
"""Set the async processing service instance."""
global _async_service
_async_service = service
def get_async_service():
"""Get the async processing service instance."""
if _async_service is None:
raise RuntimeError("AsyncProcessingService not initialized")
return _async_service
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
"""Create async API router."""
router = APIRouter(prefix="/async", tags=["Async Processing"])
@router.post(
"/submit",
response_model=AsyncSubmitResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
401: {"model": ErrorResponse, "description": "Invalid API key"},
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
503: {"model": ErrorResponse, "description": "Queue full"},
},
summary="Submit PDF for async processing",
description="Submit a PDF or image file for asynchronous processing. "
"Returns a request_id that can be used to poll for results.",
)
async def submit_document(
api_key: SubmitRateLimitDep,
file: UploadFile = File(..., description="PDF or image file to process"),
) -> AsyncSubmitResponse:
"""Submit a document for async processing."""
# Validate filename
if not file.filename:
raise HTTPException(status_code=400, detail="Filename is required")
# Validate file extension
file_ext = Path(file.filename).suffix.lower()
if file_ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {file_ext}. "
f"Allowed: {', '.join(allowed_extensions)}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Check file size (get from config via service)
service = get_async_service()
max_size = service._async_config.max_file_size_mb * 1024 * 1024
if len(content) > max_size:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: "
f"{service._async_config.max_file_size_mb}MB",
)
# Submit request
result = service.submit_request(
api_key=api_key,
file_content=content,
filename=file.filename,
content_type=file.content_type or "application/octet-stream",
)
if not result.success:
if "queue" in (result.error or "").lower():
raise HTTPException(status_code=503, detail=result.error)
raise HTTPException(status_code=500, detail=result.error)
return AsyncSubmitResponse(
status="accepted",
message="Request submitted for processing",
request_id=result.request_id,
estimated_wait_seconds=result.estimated_wait_seconds,
poll_url=f"/api/v1/async/status/{result.request_id}",
)
@router.get(
"/status/{request_id}",
response_model=AsyncStatusResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get request status",
description="Get the current processing status of an async request.",
)
async def get_status(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncStatusResponse:
"""Get the status of an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Get queue position for pending requests
position = None
if request.status == "pending":
position = db.get_queue_position(request_id)
# Build result URL for completed requests
result_url = None
if request.status == "completed":
result_url = f"/api/v1/async/result/{request_id}"
return AsyncStatusResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
filename=request.filename,
created_at=request.created_at,
started_at=request.started_at,
completed_at=request.completed_at,
position_in_queue=position,
error_message=request.error_message,
result_url=result_url,
)
@router.get(
"/result/{request_id}",
response_model=AsyncResultResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Request not completed"},
429: {"model": ErrorResponse, "description": "Polling too frequently"},
},
summary="Get extraction results",
description="Get the extraction results for a completed async request.",
)
async def get_result(
request_id: str,
api_key: PollRateLimitDep,
db: AsyncDBDep,
) -> AsyncResultResponse:
"""Get the results of a completed async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database (validates API key ownership)
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Check if completed or failed
if request.status not in ("completed", "failed"):
raise HTTPException(
status_code=409,
detail=f"Request not yet completed. Current status: {request.status}",
)
# Build inference result from stored data
inference_result = None
if request.result:
# Convert detections to DetectionResult objects
detections = []
for d in request.result.get("detections", []):
detections.append(DetectionResult(
field=d.get("field", ""),
confidence=d.get("confidence", 0.0),
bbox=d.get("bbox", [0, 0, 0, 0]),
))
inference_result = InferenceResult(
document_id=request.result.get("document_id", str(request.request_id)[:8]),
success=request.result.get("success", False),
document_type=request.result.get("document_type", "invoice"),
fields=request.result.get("fields", {}),
confidence=request.result.get("confidence", {}),
detections=detections,
processing_time_ms=request.processing_time_ms or 0.0,
errors=request.result.get("errors", []),
)
# Build visualization URL
viz_url = None
if request.visualization_path:
viz_url = f"/api/v1/results/{request.visualization_path}"
return AsyncResultResponse(
request_id=str(request.request_id),
status=AsyncStatus(request.status),
processing_time_ms=request.processing_time_ms or 0.0,
result=inference_result,
visualization_url=viz_url,
)
@router.get(
"/requests",
response_model=AsyncRequestsListResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid status filter"},
401: {"model": ErrorResponse, "description": "Invalid API key"},
},
summary="List requests",
description="List all async requests for the authenticated API key.",
)
async def list_requests(
api_key: ApiKeyDep,
db: AsyncDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (pending, processing, completed, failed)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Maximum number of results"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Pagination offset"),
] = 0,
) -> AsyncRequestsListResponse:
"""List all requests for the authenticated API key."""
# Validate status filter
if status and status not in ("pending", "processing", "completed", "failed"):
raise HTTPException(
status_code=400,
detail=f"Invalid status filter: {status}. "
"Must be one of: pending, processing, completed, failed",
)
# Get requests from database
requests, total = db.get_requests_by_api_key(
api_key=api_key,
status=status,
limit=limit,
offset=offset,
)
# Convert to response items
items = [
AsyncRequestItem(
request_id=str(r.request_id),
status=AsyncStatus(r.status),
filename=r.filename,
file_size=r.file_size,
created_at=r.created_at,
completed_at=r.completed_at,
)
for r in requests
]
return AsyncRequestsListResponse(
total=total,
limit=limit,
offset=offset,
requests=items,
)
@router.delete(
"/requests/{request_id}",
responses={
401: {"model": ErrorResponse, "description": "Invalid API key"},
404: {"model": ErrorResponse, "description": "Request not found"},
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
},
summary="Cancel/delete request",
description="Cancel a pending request or delete a completed/failed request.",
)
async def delete_request(
request_id: str,
api_key: ApiKeyDep,
db: AsyncDBDep,
) -> dict:
"""Delete or cancel an async request."""
# Validate UUID format
_validate_request_id(request_id)
# Get request from database
request = db.get_request_by_api_key(request_id, api_key)
if request is None:
raise HTTPException(
status_code=404,
detail="Request not found or does not belong to this API key",
)
# Cannot delete processing requests
if request.status == "processing":
raise HTTPException(
status_code=409,
detail="Cannot delete a request that is currently processing",
)
# Delete from database (will cascade delete related records)
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute(
"DELETE FROM async_requests WHERE request_id = %s",
(request_id,),
)
conn.commit()
return {
"status": "deleted",
"request_id": request_id,
"message": "Request deleted successfully",
}
return router
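Taken together, the endpoints above form a submit → poll → fetch cycle. A minimal client sketch, assuming the server runs on localhost:8000 and the API key travels in an `X-API-Key` header (the header name is an assumption; the `ApiKeyDep` wiring is not shown in this diff):

```python
# Sketch of the async submit/poll/result cycle. The X-API-Key header name is
# an assumption; only the endpoint paths are confirmed by the router above.
import time
import requests

BASE = "http://localhost:8000/api/v1/async"
HEADERS = {"X-API-Key": "<api-key>"}

with open("invoice.pdf", "rb") as f:
    submit = requests.post(
        f"{BASE}/submit",
        headers=HEADERS,
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )
submit.raise_for_status()
request_id = submit.json()["request_id"]

# Poll no faster than min_poll_interval_ms (1000 ms default) to avoid 429s.
while True:
    status = requests.get(f"{BASE}/status/{request_id}", headers=HEADERS).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(1.5)

result = requests.get(f"{BASE}/result/{request_id}", headers=HEADERS).json()
print(result["status"], result.get("processing_time_ms"))
```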

View File

@@ -1,5 +1,5 @@
"""
API Routes
Inference API Routes
FastAPI route definitions for the inference API.
"""
@@ -15,23 +15,22 @@ from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from .schemas import (
BatchInferenceResponse,
from src.web.schemas.inference import (
DetectionResult,
ErrorResponse,
HealthResponse,
InferenceResponse,
InferenceResult,
)
from src.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from .services import InferenceService
from .config import StorageConfig
from src.web.services import InferenceService
from src.web.config import StorageConfig
logger = logging.getLogger(__name__)
def create_api_router(
def create_inference_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:

View File

@@ -0,0 +1,203 @@
"""
Labeling API Routes
FastAPI endpoints for pre-labeling documents with expected field values.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from src.data.admin_db import AdminDB
from src.web.schemas.labeling import PreLabelResponse
from src.web.schemas.common import ErrorResponse
if TYPE_CHECKING:
from src.web.services import InferenceService
from src.web.config import StorageConfig
logger = logging.getLogger(__name__)
# Storage directory for pre-label uploads (legacy, now uses storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
def _convert_pdf_to_images(
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
"""Convert PDF pages to images for annotation."""
import fitz
doc_images_dir = images_dir / document_id
doc_images_dir.mkdir(parents=True, exist_ok=True)
pdf_doc = fitz.open(stream=content, filetype="pdf")
for page_num in range(page_count):
page = pdf_doc[page_num]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pix = page.get_pixmap(matrix=mat)
image_path = doc_images_dir / f"page_{page_num + 1}.png"
pix.save(str(image_path))
pdf_doc.close()
def get_admin_db() -> AdminDB:
"""Get admin database instance."""
return AdminDB()
def create_labeling_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with labeling endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["labeling"])
# Ensure legacy upload directory exists (kept for backward compatibility)
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@router.post(
"/pre-label",
response_model=PreLabelResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
summary="Pre-label document with expected values",
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
)
async def pre_label(
file: UploadFile = File(..., description="PDF or image file to process"),
field_values: str = Form(
...,
description="JSON object with expected field values. "
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
),
db: AdminDB = Depends(get_admin_db),
) -> PreLabelResponse:
"""
Upload a document with expected field values for pre-labeling.
Returns document_id which can be used to retrieve results later.
Example field_values JSON:
```json
{
"InvoiceNumber": "12345",
"Amount": "1500.00",
"Bankgiro": "123-4567",
"OCR": "1234567890"
}
```
"""
# Parse field_values JSON
try:
expected_values = json.loads(field_values)
if not isinstance(expected_values, dict):
raise ValueError("field_values must be a JSON object")
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid field_values: {e}",
)
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Read file content
try:
content = await file.read()
except Exception as e:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to read file",
)
# Get page count for PDF
page_count = 1
if file_ext == ".pdf":
try:
import fitz
pdf_doc = fitz.open(stream=content, filetype="pdf")
page_count = len(pdf_doc)
pdf_doc.close()
except Exception as e:
logger.warning(f"Failed to get PDF page count: {e}")
# Create document record with field_values
document_id = db.create_document(
filename=file.filename,
file_size=len(content),
content_type=file.content_type or "application/octet-stream",
file_path="", # Will update after saving
page_count=page_count,
upload_source="api",
csv_field_values=expected_values,
)
# Save file to admin uploads
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
try:
file_path.write_bytes(content)
except Exception as e:
logger.error(f"Failed to save file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save file",
)
# Update file path in database
db.update_document_file_path(document_id, str(file_path))
# Convert PDF to images for annotation UI
if file_ext == ".pdf":
try:
_convert_pdf_to_images(
document_id, content, page_count,
storage_config.admin_images_dir, storage_config.dpi
)
except Exception as e:
logger.error(f"Failed to convert PDF to images: {e}")
# Trigger auto-labeling
db.update_document_status(
document_id=document_id,
status="auto_labeling",
auto_label_status="pending",
)
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
return PreLabelResponse(document_id=document_id)
return router
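A short usage sketch for the `/api/v1/pre-label` endpoint, reusing the field_values example from the docstring; the server address is assumed:

```python
# Upload a PDF with expected field values; field_values is a JSON string in
# the multipart form, exactly as the endpoint's Form(...) parameter expects.
import json
import requests

field_values = {
    "InvoiceNumber": "12345",
    "Amount": "1500.00",
    "Bankgiro": "123-4567",
    "OCR": "1234567890",
}
with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/pre-label",  # assumed host/port
        files={"file": ("invoice.pdf", f, "application/pdf")},
        data={"field_values": json.dumps(field_values)},
    )
resp.raise_for_status()
print(resp.json()["document_id"])  # use this ID to retrieve results later
```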

View File

@@ -17,8 +17,39 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from .routes import create_api_router
from .services import InferenceService
from src.web.services import InferenceService
# Public API imports
from src.web.api.v1.public import (
create_inference_router,
create_async_router,
set_async_service,
create_labeling_router,
)
# Async processing imports
from src.data.async_request_db import AsyncRequestDB
from src.web.workers.async_queue import AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.dependencies import init_dependencies
from src.web.core.rate_limiter import RateLimiter
# Admin API imports
from src.web.api.v1.admin import (
create_annotation_router,
create_auth_router,
create_documents_router,
create_locks_router,
create_training_router,
)
from src.web.core.scheduler import start_scheduler, stop_scheduler
from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
# Batch upload imports
from src.web.api.v1.batch.routes import router as batch_upload_router
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService
from src.data.admin_db import AdminDB
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
@@ -44,11 +75,38 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
storage_config=config.storage,
)
# Create async processing components
async_db = AsyncRequestDB()
rate_limiter = RateLimiter(async_db)
task_queue = AsyncTaskQueue(
max_size=config.async_processing.queue_max_size,
worker_count=config.async_processing.worker_count,
)
async_service = AsyncProcessingService(
inference_service=inference_service,
db=async_db,
queue=task_queue,
rate_limiter=rate_limiter,
async_config=config.async_processing,
storage_config=config.storage,
)
# Initialize dependencies for FastAPI
init_dependencies(async_db, rate_limiter)
set_async_service(async_service)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize database tables
try:
async_db.create_tables()
logger.info("Async database tables ready")
except Exception as e:
logger.error(f"Failed to initialize async database: {e}")
# Initialize inference service on startup
try:
inference_service.initialize()
@@ -57,10 +115,75 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
logger.error(f"Failed to initialize inference service: {e}")
# Continue anyway - service will retry on first request
# Start async processing service
try:
async_service.start()
logger.info("Async processing service started")
except Exception as e:
logger.error(f"Failed to start async processing: {e}")
# Start batch upload queue
try:
admin_db = AdminDB()
batch_service = BatchUploadService(admin_db)
init_batch_queue(batch_service)
logger.info("Batch upload queue started")
except Exception as e:
logger.error(f"Failed to start batch upload queue: {e}")
# Start training scheduler
try:
start_scheduler()
logger.info("Training scheduler started")
except Exception as e:
logger.error(f"Failed to start training scheduler: {e}")
# Start auto-label scheduler
try:
start_autolabel_scheduler()
logger.info("AutoLabel scheduler started")
except Exception as e:
logger.error(f"Failed to start autolabel scheduler: {e}")
yield
logger.info("Shutting down Invoice Inference API...")
# Stop auto-label scheduler
try:
stop_autolabel_scheduler()
logger.info("AutoLabel scheduler stopped")
except Exception as e:
logger.error(f"Error stopping autolabel scheduler: {e}")
# Stop training scheduler
try:
stop_scheduler()
logger.info("Training scheduler stopped")
except Exception as e:
logger.error(f"Error stopping training scheduler: {e}")
# Stop batch upload queue
try:
shutdown_batch_queue()
logger.info("Batch upload queue stopped")
except Exception as e:
logger.error(f"Error stopping batch upload queue: {e}")
# Stop async processing service
try:
async_service.stop(timeout=30.0)
logger.info("Async processing service stopped")
except Exception as e:
logger.error(f"Error stopping async service: {e}")
# Close database connection
try:
async_db.close()
logger.info("Database connection closed")
except Exception as e:
logger.error(f"Error closing database: {e}")
# Create FastAPI app
app = FastAPI(
title="Invoice Field Extraction API",
@@ -106,9 +229,34 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
name="results",
)
# Include API routes
api_router = create_api_router(inference_service, config.storage)
app.include_router(api_router)
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)
app.include_router(inference_router)
async_router = create_async_router(config.storage.allowed_extensions)
app.include_router(async_router, prefix="/api/v1")
labeling_router = create_labeling_router(inference_service, config.storage)
app.include_router(labeling_router)
# Include admin API routes
auth_router = create_auth_router()
app.include_router(auth_router, prefix="/api/v1")
documents_router = create_documents_router(config.storage)
app.include_router(documents_router, prefix="/api/v1")
locks_router = create_locks_router()
app.include_router(locks_router, prefix="/api/v1")
annotation_router = create_annotation_router()
app.include_router(annotation_router, prefix="/api/v1")
training_router = create_training_router()
app.include_router(training_router, prefix="/api/v1")
# Include batch upload routes
app.include_router(batch_upload_router)
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
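With all routers registered, the factory can be served directly; a minimal sketch, assuming the module lives at `src.web.app` (the file path is not shown in this diff):

```python
# Boot the composed app with uvicorn. The import path is an assumption.
import uvicorn
from src.web.app import create_app

app = create_app()  # default AppConfig wires model, storage, async queue

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```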

View File

@@ -8,6 +8,8 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from src.config import DEFAULT_DPI, PATHS
@dataclass(frozen=True)
class ModelConfig:
@@ -16,7 +18,7 @@ class ModelConfig:
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
use_gpu: bool = True
dpi: int = 150
dpi: int = DEFAULT_DPI
@dataclass(frozen=True)
@@ -32,19 +34,59 @@ class ServerConfig:
@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration."""
"""File storage configuration.
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
directly in the raw_pdfs directory. This keeps storage consistent with the
CLI autolabel pipeline and avoids duplicating files.
"""
upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
self.upload_dir.mkdir(parents=True, exist_ok=True)
self.result_dir.mkdir(parents=True, exist_ok=True)
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
@dataclass(frozen=True)
class AsyncConfig:
"""Async processing configuration."""
# Queue settings
queue_max_size: int = 100
worker_count: int = 1
task_timeout_seconds: int = 300
# Rate limiting defaults
default_requests_per_minute: int = 10
default_max_concurrent_jobs: int = 3
default_min_poll_interval_ms: int = 1000
# Storage
result_retention_days: int = 7
temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@dataclass
@@ -54,6 +96,7 @@ class AppConfig:
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
@@ -62,6 +105,7 @@ class AppConfig:
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
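A quick sketch of `AppConfig.from_dict`: only the keys you pass are overridden; everything else keeps the frozen dataclass defaults above.

```python
from src.web.config import AppConfig

config = AppConfig.from_dict({
    "model": {"confidence_threshold": 0.6},
    "async_processing": {"queue_max_size": 200, "worker_count": 2},
})
assert config.async_processing.queue_max_size == 200
assert config.model.dpi == config.storage.dpi  # both default to DEFAULT_DPI (150)
```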

28
src/web/core/__init__.py Normal file
View File

@@ -0,0 +1,28 @@
"""
Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""
from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from src.web.core.rate_limiter import RateLimiter
from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from src.web.core.autolabel_scheduler import (
start_autolabel_scheduler,
stop_autolabel_scheduler,
get_autolabel_scheduler,
)
__all__ = [
"validate_admin_token",
"get_admin_db",
"AdminTokenDep",
"AdminDBDep",
"RateLimiter",
"start_scheduler",
"stop_scheduler",
"get_training_scheduler",
"start_autolabel_scheduler",
"stop_autolabel_scheduler",
"get_autolabel_scheduler",
]

60
src/web/core/auth.py Normal file
View File

@@ -0,0 +1,60 @@
"""
Admin Authentication
FastAPI dependencies for admin token authentication.
"""
import logging
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from src.data.admin_db import AdminDB
logger = logging.getLogger(__name__)
# Global AdminDB instance
_admin_db: AdminDB | None = None
def get_admin_db() -> AdminDB:
"""Get the AdminDB instance."""
global _admin_db
if _admin_db is None:
_admin_db = AdminDB()
return _admin_db
def reset_admin_db() -> None:
"""Reset the AdminDB instance (for testing)."""
global _admin_db
_admin_db = None
async def validate_admin_token(
x_admin_token: Annotated[str | None, Header()] = None,
admin_db: AdminDB = Depends(get_admin_db),
) -> str:
"""Validate admin token from header."""
if not x_admin_token:
raise HTTPException(
status_code=401,
detail="Admin token required. Provide X-Admin-Token header.",
)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(
status_code=401,
detail="Invalid or expired admin token.",
)
# Update last used timestamp
admin_db.update_admin_token_usage(x_admin_token)
return x_admin_token
# Type alias for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
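A sketch of the intended use of `AdminTokenDep` in a route; the route itself is illustrative, not part of this commit:

```python
from fastapi import APIRouter
from src.web.core.auth import AdminTokenDep

router = APIRouter(prefix="/api/v1/admin", tags=["admin"])

@router.get("/ping")
async def ping(admin_token: AdminTokenDep) -> dict:
    # validate_admin_token has already rejected missing/invalid tokens (401)
    # and refreshed the token's last-used timestamp before this body runs.
    return {"status": "ok"}
```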

View File

@@ -0,0 +1,153 @@
"""
Auto-Label Scheduler
Background scheduler for processing documents pending auto-labeling.
"""
import logging
import threading
from pathlib import Path
from src.data.admin_db import AdminDB
from src.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
logger = logging.getLogger(__name__)
class AutoLabelScheduler:
"""Scheduler for auto-labeling tasks."""
def __init__(
self,
check_interval_seconds: int = 10,
batch_size: int = 5,
output_dir: Path | None = None,
):
"""
Initialize auto-label scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
batch_size: Number of documents to process per batch
output_dir: Output directory for temporary files
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("AutoLabel scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("AutoLabel scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("AutoLabel scheduler stopped")
@property
def is_running(self) -> bool:
"""Check if scheduler is running."""
return self._running
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._process_pending_documents()
except Exception as e:
logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _process_pending_documents(self) -> None:
"""Check and process pending auto-label documents."""
try:
documents = get_pending_autolabel_documents(
self._db, limit=self._batch_size
)
if not documents:
return
logger.info(f"Processing {len(documents)} pending autolabel documents")
for doc in documents:
if self._stop_event.is_set():
break
try:
result = process_document_autolabel(
document=doc,
db=self._db,
output_dir=self._output_dir,
)
if result.get("success"):
logger.info(
f"AutoLabel completed for document {doc.document_id}"
)
else:
logger.warning(
f"AutoLabel failed for document {doc.document_id}: "
f"{result.get('error', 'Unknown error')}"
)
except Exception as e:
logger.error(
f"Error processing document {doc.document_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error fetching pending documents: {e}", exc_info=True)
# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None
def get_autolabel_scheduler() -> AutoLabelScheduler:
"""Get the auto-label scheduler instance."""
global _autolabel_scheduler
if _autolabel_scheduler is None:
_autolabel_scheduler = AutoLabelScheduler()
return _autolabel_scheduler
def start_autolabel_scheduler() -> None:
"""Start the global auto-label scheduler."""
scheduler = get_autolabel_scheduler()
scheduler.start()
def stop_autolabel_scheduler() -> None:
"""Stop the global auto-label scheduler."""
global _autolabel_scheduler
if _autolabel_scheduler:
_autolabel_scheduler.stop()
_autolabel_scheduler = None
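The FastAPI lifespan drives this scheduler in-process; a hedged sketch of running it as a standalone worker instead:

```python
# Run the auto-label scheduler outside the web app, e.g. in a worker process.
import time
from src.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
)

start_autolabel_scheduler()  # checks every 10 s, processes batches of 5
try:
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    stop_autolabel_scheduler()  # sets the stop event and joins the thread
```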

View File

@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation
Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from src.data.async_request_db import AsyncRequestDB
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RateLimitConfig:
"""Rate limit configuration for an API key."""
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000 # Minimum time between status polls
@dataclass
class RateLimitStatus:
"""Current rate limit status."""
allowed: bool
remaining_requests: int
reset_at: datetime
retry_after_seconds: int | None = None
reason: str | None = None
class RateLimiter:
"""
Thread-safe rate limiter with sliding window algorithm.
Tracks:
- Requests per minute (sliding window)
- Concurrent active jobs
- Poll frequency per request_id
"""
def __init__(self, db: "AsyncRequestDB") -> None:
self._db = db
self._lock = Lock()
# In-memory tracking for fast checks
self._request_windows: dict[str, list[float]] = defaultdict(list)
# (api_key, request_id) -> last_poll timestamp
self._poll_timestamps: dict[tuple[str, str], float] = {}
# Cache for API key configs (TTL 60 seconds)
self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
self._config_cache_ttl = 60.0
def check_submit_limit(self, api_key: str) -> RateLimitStatus:
"""Check if API key can submit a new request."""
config = self._get_config(api_key)
with self._lock:
now = time.time()
window_start = now - 60 # 1 minute window
# Clean old entries
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
current_count = len(self._request_windows[api_key])
if current_count >= config.requests_per_minute:
oldest = min(self._request_windows[api_key])
retry_after = int(oldest + 60 - now) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
retry_after_seconds=max(1, retry_after),
reason="Rate limit exceeded: too many requests per minute",
)
# Check concurrent jobs (query database) - inside lock for thread safety
active_jobs = self._db.count_active_jobs(api_key)
if active_jobs >= config.max_concurrent_jobs:
return RateLimitStatus(
allowed=False,
remaining_requests=config.requests_per_minute - current_count,
reset_at=datetime.utcnow() + timedelta(seconds=30),
retry_after_seconds=30,
reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
)
return RateLimitStatus(
allowed=True,
remaining_requests=config.requests_per_minute - current_count - 1,
reset_at=datetime.utcnow() + timedelta(seconds=60),
)
def record_request(self, api_key: str) -> None:
"""Record a successful request submission."""
with self._lock:
self._request_windows[api_key].append(time.time())
# Also record in database for persistence
try:
self._db.record_rate_limit_event(api_key, "request")
except Exception as e:
logger.warning(f"Failed to record rate limit event: {e}")
def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
"""Check if polling is allowed (prevent abuse)."""
config = self._get_config(api_key)
key = (api_key, request_id)
with self._lock:
now = time.time()
last_poll = self._poll_timestamps.get(key, 0)
elapsed_ms = (now - last_poll) * 1000
if elapsed_ms < config.min_poll_interval_ms:
# Suggest exponential backoff
wait_ms = min(
config.min_poll_interval_ms * 2,
5000, # Max 5 seconds
)
retry_after = int(wait_ms / 1000) + 1
return RateLimitStatus(
allowed=False,
remaining_requests=0,
reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
retry_after_seconds=retry_after,
reason="Polling too frequently. Please wait before retrying.",
)
# Update poll timestamp
self._poll_timestamps[key] = now
return RateLimitStatus(
allowed=True,
remaining_requests=999, # No limit on poll count, just frequency
reset_at=datetime.utcnow(),
)
def _get_config(self, api_key: str) -> RateLimitConfig:
"""Get rate limit config for API key with caching."""
now = time.time()
# Check cache
if api_key in self._config_cache:
cached_config, cached_at = self._config_cache[api_key]
if now - cached_at < self._config_cache_ttl:
return cached_config
# Query database
db_config = self._db.get_api_key_config(api_key)
if db_config:
config = RateLimitConfig(
requests_per_minute=db_config.requests_per_minute,
max_concurrent_jobs=db_config.max_concurrent_jobs,
)
else:
config = RateLimitConfig() # Default limits
# Cache result
self._config_cache[api_key] = (config, now)
return config
def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
"""Clean up old poll timestamps to prevent memory leak."""
with self._lock:
now = time.time()
cutoff = now - max_age_seconds
old_keys = [
k for k, v in self._poll_timestamps.items()
if v < cutoff
]
for key in old_keys:
del self._poll_timestamps[key]
return len(old_keys)
def cleanup_request_windows(self) -> None:
"""Clean up expired entries from request windows."""
with self._lock:
now = time.time()
window_start = now - 60
for api_key in list(self._request_windows.keys()):
self._request_windows[api_key] = [
ts for ts in self._request_windows[api_key]
if ts > window_start
]
# Remove empty entries
if not self._request_windows[api_key]:
del self._request_windows[api_key]
def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
"""Generate rate limit headers for HTTP response."""
headers = {
"X-RateLimit-Remaining": str(status.remaining_requests),
"X-RateLimit-Reset": status.reset_at.isoformat(),
}
if status.retry_after_seconds:
headers["Retry-After"] = str(status.retry_after_seconds)
return headers
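A self-contained sketch of the sliding window in action; `StubDB` stands in for `AsyncRequestDB` and stubs only the three methods the limiter calls (a demonstration aid, not the real database class):

```python
from src.web.core.rate_limiter import RateLimiter

class StubDB:
    """Stub with the three methods RateLimiter actually calls."""
    def count_active_jobs(self, api_key: str) -> int:
        return 0  # pretend no jobs are in flight
    def get_api_key_config(self, api_key: str):
        return None  # fall back to RateLimitConfig defaults (10 req/min)
    def record_rate_limit_event(self, api_key: str, event: str) -> None:
        pass

limiter = RateLimiter(StubDB())
for i in range(12):
    status = limiter.check_submit_limit("demo-key")
    if not status.allowed:
        print(f"request {i + 1} blocked; retry after {status.retry_after_seconds}s")
        break
    limiter.record_request("demo-key")
# -> request 11 blocked: the 60-second window already holds 10 submissions
```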

329
src/web/core/scheduler.py Normal file
View File

@@ -0,0 +1,329 @@
"""
Admin Training Scheduler
Background scheduler for training tasks using APScheduler.
"""
import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any
from src.data.admin_db import AdminDB
logger = logging.getLogger(__name__)
class TrainingScheduler:
"""Scheduler for training tasks."""
def __init__(
self,
check_interval_seconds: int = 60,
):
"""
Initialize training scheduler.
Args:
check_interval_seconds: Interval to check for pending tasks
"""
self._check_interval = check_interval_seconds
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._db = AdminDB()
def start(self) -> None:
"""Start the scheduler."""
if self._running:
logger.warning("Training scheduler already running")
return
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
logger.info("Training scheduler started")
def stop(self) -> None:
"""Stop the scheduler."""
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("Training scheduler stopped")
def _run_loop(self) -> None:
"""Main scheduler loop."""
while self._running:
try:
self._check_pending_tasks()
except Exception as e:
logger.error(f"Error in scheduler loop: {e}")
# Wait for next check interval
self._stop_event.wait(timeout=self._check_interval)
def _check_pending_tasks(self) -> None:
"""Check and execute pending training tasks."""
try:
tasks = self._db.get_pending_training_tasks()
for task in tasks:
task_id = str(task.task_id)
# Check if scheduled time has passed
if task.scheduled_at and task.scheduled_at > datetime.utcnow():
continue
logger.info(f"Starting training task: {task_id}")
try:
self._execute_task(task_id, task.config or {})
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.update_training_task_status(
task_id=task_id,
status="failed",
error_message=str(e),
)
except Exception as e:
logger.error(f"Error checking pending tasks: {e}")
def _execute_task(self, task_id: str, config: dict[str, Any]) -> None:
"""Execute a training task."""
# Update status to running
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
image_size = config.get("image_size", 640)
learning_rate = config.get("learning_rate", 0.01)
device = config.get("device", "0")
project_name = config.get("project_name", "invoice_fields")
# Export annotations for training
export_result = self._export_training_data(task_id)
if not export_result:
raise ValueError("Failed to export training data")
data_yaml = export_result["data_yaml"]
self._db.add_training_log(
task_id, "INFO",
f"Exported {export_result['total_images']} images for training",
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=model_name,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
image_size=image_size,
learning_rate=learning_rate,
device=device,
project_name=project_name,
)
# Update task with results
self._db.update_training_task_status(
task_id=task_id,
status="completed",
result_metrics=result.get("metrics"),
model_path=result.get("model_path"),
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
raise
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
import shutil
from src.data.admin_models import FIELD_CLASSES
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
if not documents:
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None
# Create export directory
export_dir = Path("data/training") / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
# 80/20 train/val split
total_docs = len(documents)
train_count = int(total_docs * 0.8)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
# Export documents
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = self._db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
# Copy image
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
# Write YOLO label
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
# Create data.yaml
yaml_path = export_dir / "data.yaml"
yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
yaml_path.write_text(yaml_content)
return {
"data_yaml": str(yaml_path),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _run_yolo_training(
self,
task_id: str,
model_name: str,
data_yaml: str,
epochs: int,
batch_size: int,
image_size: int,
learning_rate: float,
device: str,
project_name: str,
) -> dict[str, Any]:
"""Run YOLO training."""
try:
from ultralytics import YOLO
# Log training start
self._db.add_training_log(
task_id, "INFO",
f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
)
# Load model
model = YOLO(model_name)
# Train
results = model.train(
data=data_yaml,
epochs=epochs,
batch=batch_size,
imgsz=image_size,
lr0=learning_rate,
device=device,
project=f"runs/train/{project_name}",
name=f"task_{task_id[:8]}",
exist_ok=True,
verbose=True,
)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
self._db.add_training_log(
task_id, "INFO",
f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
)
return {
"model_path": str(best_model) if best_model.exists() else None,
"metrics": metrics,
}
except ImportError:
self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
raise ValueError("Ultralytics (YOLO) not installed")
except Exception as e:
self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
raise
# Global scheduler instance
_scheduler: TrainingScheduler | None = None
def get_training_scheduler() -> TrainingScheduler:
"""Get the training scheduler instance."""
global _scheduler
if _scheduler is None:
_scheduler = TrainingScheduler()
return _scheduler
def start_scheduler() -> None:
"""Start the global training scheduler."""
scheduler = get_training_scheduler()
scheduler.start()
def stop_scheduler() -> None:
"""Stop the global training scheduler."""
global _scheduler
if _scheduler:
_scheduler.stop()
_scheduler = None
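For reference, the label lines written by `_export_training_data` follow the standard YOLO normalized format; a small sketch converting a pixel bbox (the box values are illustrative):

```python
# Convert a pixel bbox to the "class x_center y_center width height" line
# format that _export_training_data writes, normalized to [0, 1].
def to_yolo_line(class_id: int, x1: int, y1: int, x2: int, y2: int,
                 img_w: int, img_h: int) -> str:
    x_center = (x1 + x2) / 2 / img_w
    y_center = (y1 + y2) / 2 / img_h
    width = (x2 - x1) / img_w
    height = (y2 - y1) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

# e.g. a field box on a 1240x1754 page (A4 rendered at 150 DPI):
print(to_yolo_line(0, 100, 200, 400, 260, 1240, 1754))
# -> 0 0.201613 0.131129 0.241935 0.034208
```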

Some files were not shown because too many files have changed in this diff.