WIP
@@ -7,7 +7,8 @@
"Edit(*)",
"Glob(*)",
"Grep(*)",
"Task(*)"
"Task(*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
]
}
}
}
@@ -81,7 +81,13 @@
"Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
"Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
"Bash(wsl bash -c:*)"
"Bash(wsl bash -c:*)",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")"
],
"deny": [],
"ask": [],
34 README.md
@@ -76,6 +76,38 @@
| 8 | payment_line | Payment line (machine-readable format) |
| 9 | customer_number | Customer number |

## DPI Configuration

**Important**: Every component of the system uses **150 DPI**, which keeps training and inference consistent.

The DPI (dots-per-inch) setting must be identical at training and inference time; a mismatch causes:
- Detection boxes whose sizes no longer fit the rendered page
- A significant mAP drop (potentially from 93.5% down to 60-70%)
- Missed or spurious field detections

### Configuration locations

| Component | Config file | Setting |
|------|---------|--------|
| **Global constant** | `src/config.py` | `DEFAULT_DPI = 150` |
| **Web inference** | `src/web/config.py` | `ModelConfig.dpi` (imported from `src.config`) |
| **CLI inference** | `src/cli/infer.py` | `--dpi` default = `DEFAULT_DPI` |
| **Auto-labeling** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` |
| **PDF-to-image** | `src/web/api/v1/admin/documents.py` | uses `DEFAULT_DPI` |
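
A minimal sketch of how the shared constant is meant to be consumed. Only `DEFAULT_DPI` and the file locations come from the table above; the argparse wiring is illustrative:

```python
# Illustrative CLI wiring: the default is derived from the single DEFAULT_DPI
# constant instead of a hard-coded number.
import argparse

from src.config import DEFAULT_DPI  # 150

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dpi",
    type=int,
    default=DEFAULT_DPI,
    help=f"Render DPI (default: {DEFAULT_DPI}; must match the training DPI)",
)
```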

### Usage examples

```bash
# Training (uses the default 150 DPI)
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1

# Inference (defaults to 150 DPI, consistent with training)
python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf

# Set the DPI manually (only when pairing with a model trained at a non-default DPI)
python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150
```

## Installation

```bash
@@ -490,7 +522,7 @@ Options:
--input, -i      Input PDF/image
--output, -o     Output JSON path
--confidence     Confidence threshold (default: 0.5)
--dpi            Render DPI (default: 300)
--dpi            Render DPI (default: 150; must match the training DPI)
--gpu            Use GPU
```
96 create_shims.sh (new file)
@@ -0,0 +1,96 @@
#!/bin/bash

# Create backward compatibility shims for all migrated files

# admin_auth.py -> core/auth.py
cat > src/web/admin_auth.py << 'EOF'
"""DEPRECATED: Import from src.web.core.auth instead"""
from src.web.core.auth import *  # noqa: F401, F403
EOF

# admin_autolabel.py -> services/autolabel.py
cat > src/web/admin_autolabel.py << 'EOF'
"""DEPRECATED: Import from src.web.services.autolabel instead"""
from src.web.services.autolabel import *  # noqa: F401, F403
EOF

# admin_scheduler.py -> core/scheduler.py
cat > src/web/admin_scheduler.py << 'EOF'
"""DEPRECATED: Import from src.web.core.scheduler instead"""
from src.web.core.scheduler import *  # noqa: F401, F403
EOF

# admin_schemas.py -> schemas/admin.py
cat > src/web/admin_schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.admin instead"""
from src.web.schemas.admin import *  # noqa: F401, F403
EOF

# schemas.py -> schemas/inference.py + schemas/common.py
cat > src/web/schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
from src.web.schemas.inference import *  # noqa: F401, F403
from src.web.schemas.common import *  # noqa: F401, F403
EOF

# services.py -> services/inference.py
cat > src/web/services.py << 'EOF'
"""DEPRECATED: Import from src.web.services.inference instead"""
from src.web.services.inference import *  # noqa: F401, F403
EOF

# async_queue.py -> workers/async_queue.py
cat > src/web/async_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
from src.web.workers.async_queue import *  # noqa: F401, F403
EOF

# async_service.py -> services/async_processing.py
cat > src/web/async_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.async_processing instead"""
from src.web.services.async_processing import *  # noqa: F401, F403
EOF

# batch_queue.py -> workers/batch_queue.py
cat > src/web/batch_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
from src.web.workers.batch_queue import *  # noqa: F401, F403
EOF

# batch_upload_service.py -> services/batch_upload.py
cat > src/web/batch_upload_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
from src.web.services.batch_upload import *  # noqa: F401, F403
EOF

# batch_upload_routes.py -> api/v1/batch/routes.py
cat > src/web/batch_upload_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
from src.web.api.v1.batch.routes import *  # noqa: F401, F403
EOF

# admin_routes.py -> api/v1/admin/documents.py
cat > src/web/admin_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
from src.web.api.v1.admin.documents import *  # noqa: F401, F403
EOF

# admin_annotation_routes.py -> api/v1/admin/annotations.py
cat > src/web/admin_annotation_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
from src.web.api.v1.admin.annotations import *  # noqa: F401, F403
EOF

# admin_training_routes.py -> api/v1/admin/training.py
cat > src/web/admin_training_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
from src.web.api.v1.admin.training import *  # noqa: F401, F403
EOF

# routes.py -> api/v1/routes.py
cat > src/web/routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
from src.web.api.v1.routes import *  # noqa: F401, F403
EOF

echo "✓ Created backward compatibility shims for all migrated files"
@@ -1,405 +0,0 @@
|
||||
# Invoice Master POC v2 - 代码审查报告
|
||||
|
||||
**审查日期**: 2026-01-22
|
||||
**代码库规模**: 67 个 Python 源文件,约 22,434 行代码
|
||||
**测试覆盖率**: ~40-50%
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 总体评估:**良好(B+)**
|
||||
|
||||
**优势**:
|
||||
- ✅ 清晰的模块化架构,职责分离良好
|
||||
- ✅ 使用了合适的数据类和类型提示
|
||||
- ✅ 针对瑞典发票的全面规范化逻辑
|
||||
- ✅ 空间索引优化(O(1) token 查找)
|
||||
- ✅ 完善的降级机制(YOLO 失败时的 OCR fallback)
|
||||
- ✅ 设计良好的 Web API 和 UI
|
||||
|
||||
**主要问题**:
|
||||
- ❌ 支付行解析代码重复(3+ 处)
|
||||
- ❌ 长函数(`_normalize_customer_number` 127 行)
|
||||
- ❌ 配置安全问题(明文数据库密码)
|
||||
- ❌ 异常处理不一致(到处都是通用 Exception)
|
||||
- ❌ 缺少集成测试
|
||||
- ❌ 魔法数字散布各处(0.5, 0.95, 300 等)
|
||||
|
||||
---
|
||||
|
||||
## 1. 架构分析
|
||||
|
||||
### 1.1 模块结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── inference/ # 推理管道核心
|
||||
│ ├── pipeline.py (517 行) ⚠️
|
||||
│ ├── field_extractor.py (1,347 行) 🔴 太长
|
||||
│ └── yolo_detector.py
|
||||
├── web/ # FastAPI Web 服务
|
||||
│ ├── app.py (765 行) ⚠️ HTML 内联
|
||||
│ ├── routes.py (184 行)
|
||||
│ └── services.py (286 行)
|
||||
├── ocr/ # OCR 提取
|
||||
│ ├── paddle_ocr.py
|
||||
│ └── machine_code_parser.py (919 行) 🔴 太长
|
||||
├── matcher/ # 字段匹配
|
||||
│ └── field_matcher.py (875 行) ⚠️
|
||||
├── utils/ # 共享工具
|
||||
│ ├── validators.py
|
||||
│ ├── text_cleaner.py
|
||||
│ ├── fuzzy_matcher.py
|
||||
│ ├── ocr_corrections.py
|
||||
│ └── format_variants.py (610 行)
|
||||
├── processing/ # 批处理
|
||||
├── data/ # 数据管理
|
||||
└── cli/ # 命令行工具
|
||||
```
|
||||
|
||||
### 1.2 推理流程
|
||||
|
||||
```
|
||||
PDF/Image 输入
|
||||
↓
|
||||
渲染为图片 (pdf/renderer.py)
|
||||
↓
|
||||
YOLO 检测 (yolo_detector.py) - 检测字段区域
|
||||
↓
|
||||
字段提取 (field_extractor.py)
|
||||
├→ OCR 文本提取 (ocr/paddle_ocr.py)
|
||||
├→ 规范化 & 验证
|
||||
└→ 置信度计算
|
||||
↓
|
||||
交叉验证 (pipeline.py)
|
||||
├→ 解析 payment_line 格式
|
||||
├→ 从 payment_line 提取 OCR/Amount/Account
|
||||
└→ 与检测字段验证,payment_line 值优先
|
||||
↓
|
||||
降级 OCR(如果关键字段缺失)
|
||||
├→ 全页 OCR
|
||||
└→ 正则提取
|
||||
↓
|
||||
InferenceResult 输出
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 代码质量问题
|
||||
|
||||
### 2.1 长函数(>50 行)🔴
|
||||
|
||||
| 函数 | 文件 | 行数 | 复杂度 | 问题 |
|
||||
|------|------|------|--------|------|
|
||||
| `_normalize_customer_number()` | field_extractor.py | **127** | 极高 | 4 层模式匹配,7+ 正则,复杂评分 |
|
||||
| `_cross_validate_payment_line()` | pipeline.py | **127** | 极高 | 核心验证逻辑,8+ 条件分支 |
|
||||
| `_normalize_bankgiro()` | field_extractor.py | 62 | 高 | Luhn 验证 + 多种降级 |
|
||||
| `_normalize_plusgiro()` | field_extractor.py | 63 | 高 | 类似 bankgiro |
|
||||
| `_normalize_payment_line()` | field_extractor.py | 74 | 高 | 4 种正则模式 |
|
||||
| `_normalize_amount()` | field_extractor.py | 78 | 高 | 多策略降级 |
|
||||
|
||||
**Example problem** - `_normalize_customer_number()` (lines 776-902):
```python
def _normalize_customer_number(self, text: str):
    # A 127-line function containing:
    # - 4 nested if/for loops
    # - 7 different regex patterns
    # - 5 scoring mechanisms
    # - handling of both labeled and unlabeled formats
```

**Recommendation**: split it into three functions (a sketch follows below):
- `_find_customer_code_patterns()`
- `_find_labeled_customer_code()`
- `_score_customer_candidates()`
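
A minimal sketch of how the three-way split could fit together. The helper names come from the recommendation above; the label keywords, regex, and scores are placeholders rather than the project's actual rules:

```python
import re
from dataclasses import dataclass

# Hypothetical label keywords and pattern; the real rules live in field_extractor.py.
LABELS = ("kundnummer", "kundnr", "customer no")
CODE_RE = re.compile(r"\b[A-Z0-9][A-Z0-9\-]{3,14}\b", re.IGNORECASE)

@dataclass
class Candidate:
    value: str
    score: float

def _find_customer_code_patterns(text: str) -> list[Candidate]:
    """Unlabeled candidates: anything that merely looks like a customer code."""
    return [Candidate(m.group(0), 1.0) for m in CODE_RE.finditer(text)]

def _find_labeled_customer_code(text: str) -> list[Candidate]:
    """Labeled candidates: codes that follow an explicit label get a higher score."""
    found = []
    for label in LABELS:
        for m in re.finditer(rf"{label}\s*[:.]?\s*([A-Z0-9\-]+)", text, re.IGNORECASE):
            found.append(Candidate(m.group(1), 2.0))
    return found

def _score_customer_candidates(candidates: list[Candidate]) -> str | None:
    """Pick the highest-scoring candidate, preferring labeled matches."""
    return max(candidates, key=lambda c: c.score).value if candidates else None

def normalize_customer_number(text: str) -> str | None:
    cands = _find_labeled_customer_code(text) + _find_customer_code_patterns(text)
    return _score_customer_candidates(cands)
```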
|
||||
|
||||
### 2.2 代码重复 🔴
|
||||
|
||||
**支付行解析(3+ 处重复实现)**:
|
||||
|
||||
1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
|
||||
2. `MachineCodeParser.parse()` (machine_code_parser.py:919 行)
|
||||
3. `_normalize_payment_line()` (field_extractor.py:632-705)
|
||||
|
||||
所有三处都实现类似的正则模式:
|
||||
```
|
||||
格式: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
```
|
||||
|
||||
**Bankgiro/Plusgiro 验证(重复)**:
|
||||
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
|
||||
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
|
||||
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
|
||||
- `field_matcher.py`: 类似匹配逻辑
|
||||
|
||||
**建议**: 创建统一模块:
|
||||
```python
# src/common/payment_line_parser.py
class PaymentLineParser:
    def parse(self, text: str) -> "PaymentLineResult": ...


# src/common/giro_validator.py
class GiroValidator:
    def validate_and_format(self, value: str, giro_type: str) -> str: ...
```
|
||||
|
||||
### 2.3 错误处理不一致 ⚠️
|
||||
|
||||
**通用异常捕获(31 处)**:
|
||||
```python
|
||||
except Exception as e: # 代码库中 31 处
|
||||
result.errors.append(str(e))
|
||||
```
|
||||
|
||||
**问题**:
|
||||
- 没有捕获特定错误类型
|
||||
- 通用错误消息丢失上下文
|
||||
- 第 142-147 行 (routes.py): 捕获所有异常,返回 500 状态
|
||||
|
||||
**当前写法** (routes.py:142-147):
|
||||
```python
|
||||
try:
|
||||
service_result = inference_service.process_pdf(...)
|
||||
except Exception as e: # 太宽泛
|
||||
logger.error(f"Error processing document: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
```
|
||||
|
||||
**改进建议**:
|
||||
```python
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=400, detail="PDF 文件未找到")
|
||||
except PyMuPDFError:
|
||||
raise HTTPException(status_code=400, detail="无效的 PDF 格式")
|
||||
except OCRError:
|
||||
raise HTTPException(status_code=503, detail="OCR 服务不可用")
|
||||
```
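
The improved handlers above presuppose custom exception types; one possible, minimal declaration (hypothetical module and class hierarchy, only the names `PyMuPDFError` and `OCRError` come from the snippet) is:

```python
# e.g. src/exceptions.py (hypothetical location)
class InvoiceProcessingError(Exception):
    """Base class for errors raised by the extraction pipeline."""

class PyMuPDFError(InvoiceProcessingError):
    """Raised when a PDF cannot be opened or rendered."""

class OCRError(InvoiceProcessingError):
    """Raised when the OCR backend is unavailable or fails."""
```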
|
||||
|
||||
### 2.4 配置安全问题 🔴
|
||||
|
||||
**config.py 第 24-30 行** - 明文凭据:
|
||||
```python
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31', # 硬编码 IP
|
||||
'user': 'docmaster', # 硬编码用户名
|
||||
'password': 'nY6LYK5d', # 🔴 明文密码!
|
||||
'database': 'invoice_master'
|
||||
}
|
||||
```
|
||||
|
||||
**建议**:
|
||||
```python
|
||||
DATABASE = {
|
||||
'host': os.getenv('DB_HOST', 'localhost'),
|
||||
'user': os.getenv('DB_USER', 'docmaster'),
|
||||
'password': os.getenv('DB_PASSWORD'), # 从环境变量读取
|
||||
'database': os.getenv('DB_NAME', 'invoice_master')
|
||||
}
|
||||
```
|
||||
|
||||
### 2.5 魔法数字 ⚠️
|
||||
|
||||
| 值 | 位置 | 用途 | 问题 |
|
||||
|---|------|------|------|
|
||||
| 0.5 | 多处 | 置信度阈值 | 不可按字段配置 |
|
||||
| 0.95 | pipeline.py | payment_line 置信度 | 无说明 |
|
||||
| 300 | 多处 | DPI | 硬编码 |
|
||||
| 0.1 | field_extractor.py | BBox 填充 | 应为配置 |
|
||||
| 72 | 多处 | PDF 基础 DPI | 公式中的魔法数字 |
|
||||
| 50 | field_extractor.py | 客户编号评分加分 | 无说明 |
|
||||
|
||||
**建议**: 提取到配置:
|
||||
```python
|
||||
INFERENCE_CONFIG = {
|
||||
'confidence_threshold': 0.5,
|
||||
'payment_line_confidence': 0.95,
|
||||
'dpi': 300,
|
||||
'bbox_padding': 0.1,
|
||||
}
|
||||
```
|
||||
|
||||
### 2.6 Inconsistent naming ⚠️

**Field names are inconsistent**:
- YOLO class names: `invoice_number`, `ocr_number`, `supplier_org_number`
- Field names: `InvoiceNumber`, `OCR`, `supplier_org_number`
- CSV column names: possibly different again
- Database column names: yet another variant

The mapping is maintained in several places (one way to centralize it is sketched below):
- `yolo_detector.py` (lines 90-100): `CLASS_TO_FIELD`
- several other locations
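
A hedged sketch of a single canonical mapping module; the module path, alias sets, and helper are illustrative, and only `CLASS_TO_FIELD`-style names from the report are reused:

```python
# src/common/field_names.py (hypothetical)
CANONICAL_FIELDS = {
    # canonical name -> aliases used by YOLO classes, API fields, CSV/DB columns
    "invoice_number": {"invoice_number", "InvoiceNumber"},
    "ocr_number": {"ocr_number", "OCR"},
    "supplier_org_number": {"supplier_org_number"},
}

_ALIAS_TO_CANONICAL = {
    alias: canonical
    for canonical, aliases in CANONICAL_FIELDS.items()
    for alias in aliases
}

def to_canonical(name: str) -> str:
    """Map any known alias (YOLO class, CSV/DB column, API field) to one canonical name."""
    return _ALIAS_TO_CANONICAL.get(name, name)
```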
|
||||
|
||||
---
|
||||
|
||||
## 3. 测试分析
|
||||
|
||||
### 3.1 测试覆盖率
|
||||
|
||||
**测试文件**: 13 个
|
||||
- ✅ 覆盖良好: field_matcher, normalizer, payment_line_parser
|
||||
- ⚠️ 中等覆盖: field_extractor, pipeline
|
||||
- ❌ 覆盖不足: web 层, CLI, 批处理
|
||||
|
||||
**估算覆盖率**: 40-50%
|
||||
|
||||
### 3.2 缺失的测试用例 🔴
|
||||
|
||||
**关键缺失**:
|
||||
1. 交叉验证逻辑 - 最复杂部分,测试很少
|
||||
2. payment_line 解析变体 - 多种实现,边界情况不清楚
|
||||
3. OCR 错误纠正 - 不同策略的复杂逻辑
|
||||
4. Web API 端点 - 没有请求/响应测试
|
||||
5. 批处理 - 多 worker 协调未测试
|
||||
6. 降级 OCR 机制 - YOLO 检测失败时
|
||||
|
||||
---
|
||||
|
||||
## 4. 架构风险
|
||||
|
||||
### 🔴 关键风险
|
||||
|
||||
1. **配置安全** - config.py 中明文数据库凭据(24-30 行)
|
||||
2. **错误恢复** - 宽泛的异常处理掩盖真实问题
|
||||
3. **可测试性** - 硬编码依赖阻止单元测试
|
||||
|
||||
### 🟡 高风险
|
||||
|
||||
1. **代码可维护性** - 支付行解析重复
|
||||
2. **可扩展性** - 没有长时间推理的异步处理
|
||||
3. **扩展性** - 添加新字段类型会很困难
|
||||
|
||||
### 🟢 中等风险
|
||||
|
||||
1. **性能** - 懒加载有帮助,但 ORM 查询未优化
|
||||
2. **文档** - 大部分足够但可以更好
|
||||
|
||||
---
|
||||
|
||||
## 5. 优先级矩阵
|
||||
|
||||
| 优先级 | 行动 | 工作量 | 影响 |
|
||||
|--------|------|--------|------|
|
||||
| 🔴 关键 | 修复配置安全(环境变量) | 1 小时 | 高 |
|
||||
| 🔴 关键 | 添加集成测试 | 2-3 天 | 高 |
|
||||
| 🔴 关键 | 文档化错误处理策略 | 4 小时 | 中 |
|
||||
| 🟡 高 | 统一 payment_line 解析 | 1-2 天 | 高 |
|
||||
| 🟡 高 | 提取规范化到子模块 | 2-3 天 | 中 |
|
||||
| 🟡 高 | 添加依赖注入 | 2-3 天 | 中 |
|
||||
| 🟡 高 | 拆分长函数 | 2-3 天 | 低 |
|
||||
| 🟢 中 | 提高测试覆盖率到 70%+ | 3-5 天 | 高 |
|
||||
| 🟢 中 | 提取魔法数字 | 4 小时 | 低 |
|
||||
| 🟢 中 | 标准化命名约定 | 1-2 天 | 中 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 具体文件建议
|
||||
|
||||
### 高优先级(代码质量)
|
||||
|
||||
| 文件 | 问题 | 建议 |
|
||||
|------|------|------|
|
||||
| `field_extractor.py` | 1,347 行;6 个长规范化方法 | 拆分为 `normalizers/` 子模块 |
|
||||
| `pipeline.py` | 127 行 `_cross_validate_payment_line()` | 提取到单独的 `CrossValidator` 类 |
|
||||
| `field_matcher.py` | 875 行;复杂匹配逻辑 | 拆分为 `matching/` 子模块 |
|
||||
| `config.py` | 硬编码凭据(第 29 行) | 使用环境变量 |
|
||||
| `machine_code_parser.py` | 919 行;payment_line 解析 | 与 pipeline 解析合并 |
|
||||
|
||||
### 中优先级(重构)
|
||||
|
||||
| 文件 | 问题 | 建议 |
|
||||
|------|------|------|
|
||||
| `app.py` | 765 行;HTML 内联在 Python 中 | 提取到 `templates/` 目录 |
|
||||
| `autolabel.py` | 753 行;批处理逻辑 | 提取 worker 函数到模块 |
|
||||
| `format_variants.py` | 610 行;变体生成 | 考虑策略模式 |
|
||||
|
||||
---
|
||||
|
||||
## 7. 建议行动
|
||||
|
||||
### 第 1 阶段:关键修复(1 周)
|
||||
|
||||
1. **配置安全** (1 小时)
|
||||
- 移除 config.py 中的明文密码
|
||||
- 添加环境变量支持
|
||||
- 更新 README 说明配置
|
||||
|
||||
2. **错误处理标准化** (1 天)
|
||||
- 定义自定义异常类
|
||||
- 替换通用 Exception 捕获
|
||||
- 添加错误代码常量
|
||||
|
||||
3. **添加关键集成测试** (2 天)
|
||||
- 端到端推理测试
|
||||
- payment_line 交叉验证测试
|
||||
- API 端点测试
|
||||
|
||||
### 第 2 阶段:重构(2-3 周)
|
||||
|
||||
4. **统一 payment_line 解析** (2 天)
|
||||
- 创建 `src/common/payment_line_parser.py`
|
||||
- 合并 3 处重复实现
|
||||
- 迁移所有调用方
|
||||
|
||||
5. **拆分 field_extractor.py** (3 天)
|
||||
- 创建 `src/inference/normalizers/` 子模块
|
||||
- 每个字段类型一个文件
|
||||
- 提取共享验证逻辑
|
||||
|
||||
6. **拆分长函数** (2 天)
|
||||
- `_normalize_customer_number()` → 3 个函数
|
||||
- `_cross_validate_payment_line()` → CrossValidator 类
|
||||
|
||||
### 第 3 阶段:改进(1-2 周)
|
||||
|
||||
7. **提高测试覆盖率** (5 天)
|
||||
- 目标:70%+ 覆盖率
|
||||
- 专注于验证逻辑
|
||||
- 添加边界情况测试
|
||||
|
||||
8. **配置管理改进** (1 天)
|
||||
- 提取所有魔法数字
|
||||
- 创建配置文件(YAML)
|
||||
- 添加配置验证
|
||||
|
||||
9. **文档改进** (2 天)
|
||||
- 添加架构图
|
||||
- 文档化所有私有方法
|
||||
- 创建贡献指南
|
||||
|
||||
---
|
||||
|
||||
## 附录 A:度量指标
|
||||
|
||||
### 代码复杂度
|
||||
|
||||
| 类别 | 计数 | 平均行数 |
|
||||
|------|------|----------|
|
||||
| 源文件 | 67 | 334 |
|
||||
| 长文件 (>500 行) | 12 | 875 |
|
||||
| 长函数 (>50 行) | 23 | 89 |
|
||||
| 测试文件 | 13 | 298 |
|
||||
|
||||
### 依赖关系
|
||||
|
||||
| 类型 | 计数 |
|
||||
|------|------|
|
||||
| 外部依赖 | ~25 |
|
||||
| 内部模块 | 10 |
|
||||
| 循环依赖 | 0 ✅ |
|
||||
|
||||
### 代码风格
|
||||
|
||||
| 指标 | 覆盖率 |
|
||||
|------|--------|
|
||||
| 类型提示 | 80% |
|
||||
| Docstrings (公开) | 80% |
|
||||
| Docstrings (私有) | 40% |
|
||||
| 测试覆盖率 | 45% |
|
||||
|
||||
---
|
||||
|
||||
**生成日期**: 2026-01-22
|
||||
**审查者**: Claude Code
|
||||
**版本**: v2.0
|
||||
@@ -1,96 +0,0 @@
|
||||
# Field Extractor 分析报告
|
||||
|
||||
## 概述
|
||||
|
||||
field_extractor.py (1183行) 最初被识别为可优化文件,尝试使用 `src/normalize` 模块进行重构,但经过分析和测试后发现 **不应该重构**。
|
||||
|
||||
## 重构尝试
|
||||
|
||||
### 初始计划
|
||||
将 field_extractor.py 中的重复 normalize 方法删除,统一使用 `src/normalize/normalize_field()` 接口。
|
||||
|
||||
### 实施步骤
|
||||
1. ✅ 备份原文件 (`field_extractor_old.py`)
|
||||
2. ✅ 修改 `_normalize_and_validate` 使用统一 normalizer
|
||||
3. ✅ 删除重复的 normalize 方法 (~400行)
|
||||
4. ❌ 运行测试 - **28个失败**
|
||||
5. ✅ 添加 wrapper 方法委托给 normalizer
|
||||
6. ❌ 再次测试 - **12个失败**
|
||||
7. ✅ 还原原文件
|
||||
8. ✅ 测试通过 - **全部45个测试通过**
|
||||
|
||||
## 关键发现
|
||||
|
||||
### 两个模块的不同用途
|
||||
|
||||
| 模块 | 用途 | 输入 | 输出 | 示例 |
|
||||
|------|------|------|------|------|
|
||||
| **src/normalize/** | **变体生成** 用于匹配 | 已提取的字段值 | 多个匹配变体列表 | `"INV-12345"` → `["INV-12345", "12345"]` |
|
||||
| **field_extractor** | **值提取** 从OCR文本 | 包含字段的原始OCR文本 | 提取的单个字段值 | `"Fakturanummer: A3861"` → `"A3861"` |
|
||||
|
||||
### 为什么不能统一?
|
||||
|
||||
1. **src/normalize/** 的设计目的:
|
||||
- 接收已经提取的字段值
|
||||
- 生成多个标准化变体用于fuzzy matching
|
||||
- 例如 BankgiroNormalizer:
|
||||
```python
|
||||
normalize("782-1713") → ["7821713", "782-1713"] # 生成变体
|
||||
```
|
||||
|
||||
2. **field_extractor** 的 normalize 方法:
|
||||
- 接收包含字段的原始OCR文本(可能包含标签、其他文本等)
|
||||
- **提取**特定模式的字段值
|
||||
- 例如 `_normalize_bankgiro`:
|
||||
```python
|
||||
_normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None) # 从文本提取
|
||||
```
|
||||
|
||||
3. **关键区别**:
|
||||
- Normalizer: 变体生成器 (for matching)
|
||||
- Field Extractor: 模式提取器 (for parsing)
|
||||
|
||||
### 测试失败示例
|
||||
|
||||
使用 normalizer 替代 field extractor 方法后的失败:
|
||||
|
||||
```python
|
||||
# InvoiceNumber 测试
|
||||
Input: "Fakturanummer: A3861"
|
||||
期望: "A3861"
|
||||
实际: "Fakturanummer: A3861" # 没有提取,只是清理
|
||||
|
||||
# Bankgiro 测试
|
||||
Input: "Bankgiro: 782-1713"
|
||||
期望: "782-1713"
|
||||
实际: "7821713" # 返回了不带破折号的变体,而不是提取格式化值
|
||||
```
|
||||
|
||||
## 结论
|
||||
|
||||
**field_extractor.py 不应该使用 src/normalize 模块重构**,因为:
|
||||
|
||||
1. ✅ **职责不同**: 提取 vs 变体生成
|
||||
2. ✅ **输入不同**: 包含标签的原始OCR文本 vs 已提取的字段值
|
||||
3. ✅ **输出不同**: 单个提取值 vs 多个匹配变体
|
||||
4. ✅ **现有代码运行良好**: 所有45个测试通过
|
||||
5. ✅ **提取逻辑有价值**: 包含复杂的模式匹配规则(例如区分 Bankgiro/Plusgiro 格式)
|
||||
|
||||
## 建议
|
||||
|
||||
1. **保留 field_extractor.py 原样**: 不进行重构
|
||||
2. **文档化两个模块的差异**: 确保团队理解各自用途
|
||||
3. **关注其他优化目标**: machine_code_parser.py (919行)
|
||||
|
||||
## 学习点
|
||||
|
||||
重构前应该:
|
||||
1. 理解模块的**真实用途**,而不只是看代码相似度
|
||||
2. 运行完整测试套件验证假设
|
||||
3. 评估是否真的存在重复,还是表面相似但用途不同
|
||||
|
||||
---
|
||||
|
||||
**状态**: ✅ 分析完成,决定不重构
|
||||
**测试**: ✅ 45/45 通过
|
||||
**文件**: 保持 1183行 原样
|
||||
@@ -1,238 +0,0 @@
|
||||
# Machine Code Parser 分析报告
|
||||
|
||||
## 文件概况
|
||||
|
||||
- **文件**: `src/ocr/machine_code_parser.py`
|
||||
- **总行数**: 919 行
|
||||
- **代码行**: 607 行 (66%)
|
||||
- **方法数**: 14 个
|
||||
- **正则表达式使用**: 47 次
|
||||
|
||||
## 代码结构
|
||||
|
||||
### 类结构
|
||||
|
||||
```
|
||||
MachineCodeResult (数据类)
|
||||
├── to_dict()
|
||||
└── get_region_bbox()
|
||||
|
||||
MachineCodeParser (主解析器)
|
||||
├── __init__()
|
||||
├── parse() - 主入口
|
||||
├── _find_tokens_with_values()
|
||||
├── _find_machine_code_line_tokens()
|
||||
├── _parse_standard_payment_line_with_tokens()
|
||||
├── _parse_standard_payment_line() - 142行 ⚠️
|
||||
├── _extract_ocr() - 50行
|
||||
├── _extract_bankgiro() - 58行
|
||||
├── _extract_plusgiro() - 30行
|
||||
├── _extract_amount() - 68行
|
||||
├── _calculate_confidence()
|
||||
└── cross_validate()
|
||||
```
|
||||
|
||||
## 发现的问题
|
||||
|
||||
### 1. ⚠️ `_parse_standard_payment_line` 方法过长 (142行)
|
||||
|
||||
**位置**: 442-582 行
|
||||
|
||||
**问题**:
|
||||
- 包含嵌套函数 `normalize_account_spaces` 和 `format_account`
|
||||
- 多个正则匹配分支
|
||||
- 逻辑复杂,难以测试和维护
|
||||
|
||||
**建议**:
|
||||
可以拆分为独立方法:
|
||||
- `_normalize_account_spaces(line)`
|
||||
- `_format_account(account_digits, context)`
|
||||
- `_match_primary_pattern(line)`
|
||||
- `_match_fallback_patterns(line)`
|
||||
|
||||
### 2. 🔁 4个 `_extract_*` 方法有重复模式
|
||||
|
||||
所有 extract 方法都遵循相同模式:
|
||||
|
||||
```python
|
||||
def _extract_XXX(self, tokens):
|
||||
candidates = []
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
matches = self.XXX_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# 验证逻辑
|
||||
# 上下文检测
|
||||
candidates.append((normalized, context_score, token))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
|
||||
return candidates[0][0]
|
||||
```
|
||||
|
||||
**重复的逻辑**:
|
||||
- Token 迭代
|
||||
- 模式匹配
|
||||
- 候选收集
|
||||
- 上下文评分
|
||||
- 排序和选择最佳匹配
|
||||
|
||||
**建议**:
|
||||
可以提取基础提取器类或通用方法来减少重复。
|
||||
|
||||
### 3. ✅ 上下文检测重复
|
||||
|
||||
上下文检测代码在多个地方重复:
|
||||
|
||||
```python
|
||||
# _extract_bankgiro 中
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
is_bankgiro_context = (
|
||||
'bankgiro' in context_text or
|
||||
'bg:' in context_text or
|
||||
'bg ' in context_text
|
||||
)
|
||||
|
||||
# _extract_plusgiro 中
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
is_plusgiro_context = (
|
||||
'plusgiro' in context_text or
|
||||
'postgiro' in context_text or
|
||||
'pg:' in context_text or
|
||||
'pg ' in context_text
|
||||
)
|
||||
|
||||
# _parse_standard_payment_line 中
|
||||
context = (context_line or raw_line).lower()
|
||||
is_plusgiro_context = (
|
||||
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
|
||||
and 'bankgiro' not in context
|
||||
)
|
||||
```
|
||||
|
||||
**建议**:
|
||||
提取为独立方法:
|
||||
- `_detect_account_context(tokens) -> dict[str, bool]`
|
||||
|
||||
## 重构建议
|
||||
|
||||
### 方案 A: 轻度重构(推荐)✅
|
||||
|
||||
**目标**: 提取重复的上下文检测逻辑,不改变主要结构
|
||||
|
||||
**步骤**:
|
||||
1. 提取 `_detect_account_context(tokens)` 方法
|
||||
2. 提取 `_normalize_account_spaces(line)` 为独立方法
|
||||
3. 提取 `_format_account(digits, context)` 为独立方法
|
||||
|
||||
**影响**:
|
||||
- 减少 ~50-80 行重复代码
|
||||
- 提高可测试性
|
||||
- 低风险,易于验证
|
||||
|
||||
**预期结果**: 919 行 → ~850 行 (↓7%)
|
||||
|
||||
### 方案 B: 中度重构
|
||||
|
||||
**目标**: 创建通用的字段提取框架
|
||||
|
||||
**步骤**:
|
||||
1. 创建 `_generic_extract(pattern, normalizer, context_checker)`
|
||||
2. 重构所有 `_extract_*` 方法使用通用框架
|
||||
3. 拆分 `_parse_standard_payment_line` 为多个小方法
|
||||
|
||||
**影响**:
|
||||
- 减少 ~150-200 行代码
|
||||
- 显著提高可维护性
|
||||
- 中等风险,需要全面测试
|
||||
|
||||
**预期结果**: 919 行 → ~720 行 (↓22%)
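
A hedged sketch of what step 1's shared helper might look like. The signature comes from the step list above; the body merely restates the repeated candidate-collection pattern quoted in section 2 and is not existing code:

```python
import re
from typing import Callable, Iterable, Optional

def _generic_extract(
    tokens: Iterable,                             # objects exposing a .text attribute
    pattern: re.Pattern,
    normalizer: Callable[[str], Optional[str]],   # validate + normalize; None means reject
    context_checker: Callable[[str], float],      # score a candidate by its surrounding text
) -> Optional[str]:
    """Shared scan / validate / score / pick-best loop behind the _extract_* methods."""
    candidates = []
    for token in tokens:
        text = token.text.strip()
        for match in pattern.findall(text):
            normalized = normalizer(match)
            if normalized is None:
                continue
            candidates.append((normalized, context_checker(text.lower())))
    if not candidates:
        return None
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[0][0]
```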
|
||||
|
||||
### 方案 C: 深度重构(不推荐)
|
||||
|
||||
**目标**: 完全重新设计为策略模式
|
||||
|
||||
**风险**:
|
||||
- 高风险,可能引入 bugs
|
||||
- 需要大量测试
|
||||
- 可能破坏现有集成
|
||||
|
||||
## 推荐方案
|
||||
|
||||
### ✅ 采用方案 A(轻度重构)
|
||||
|
||||
**理由**:
|
||||
1. **代码已经工作良好**: 没有明显的 bug 或性能问题
|
||||
2. **低风险**: 只提取重复逻辑,不改变核心算法
|
||||
3. **性价比高**: 小改动带来明显的代码质量提升
|
||||
4. **易于验证**: 现有测试应该能覆盖
|
||||
|
||||
### 重构步骤
|
||||
|
||||
```python
|
||||
# 1. 提取上下文检测
|
||||
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
|
||||
"""检测上下文中的账户类型关键词"""
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
|
||||
return {
|
||||
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
|
||||
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
|
||||
}
|
||||
|
||||
# 2. 提取空格标准化
|
||||
def _normalize_account_spaces(self, line: str) -> str:
|
||||
"""移除账户号码中的空格"""
|
||||
# (现有 line 460-481 的代码)
|
||||
|
||||
# 3. 提取账户格式化
|
||||
def _format_account(
|
||||
self,
|
||||
account_digits: str,
|
||||
is_plusgiro_context: bool
|
||||
) -> tuple[str, str]:
|
||||
"""格式化账户并确定类型"""
|
||||
# (现有 line 485-523 的代码)
|
||||
```
|
||||
|
||||
## 对比:field_extractor vs machine_code_parser
|
||||
|
||||
| 特征 | field_extractor | machine_code_parser |
|
||||
|------|-----------------|---------------------|
|
||||
| 用途 | 值提取 | 机器码解析 |
|
||||
| 重复代码 | ~400行normalize方法 | ~80行上下文检测 |
|
||||
| 重构价值 | ❌ 不同用途,不应统一 | ✅ 可提取共享逻辑 |
|
||||
| 风险 | 高(会破坏功能) | 低(只是代码组织) |
|
||||
|
||||
## 决策
|
||||
|
||||
### ✅ 建议重构 machine_code_parser.py
|
||||
|
||||
**与 field_extractor 的不同**:
|
||||
- field_extractor: 重复的方法有**不同的用途**(提取 vs 变体生成)
|
||||
- machine_code_parser: 重复的代码有**相同的用途**(都是上下文检测)
|
||||
|
||||
**预期收益**:
|
||||
- 减少 ~70 行重复代码
|
||||
- 提高可测试性(可以单独测试上下文检测)
|
||||
- 更清晰的代码组织
|
||||
- **低风险**,易于验证
|
||||
|
||||
## 下一步
|
||||
|
||||
1. ✅ 备份原文件
|
||||
2. ✅ 提取 `_detect_account_context` 方法
|
||||
3. ✅ 提取 `_normalize_account_spaces` 方法
|
||||
4. ✅ 提取 `_format_account` 方法
|
||||
5. ✅ 更新所有调用点
|
||||
6. ✅ 运行测试验证
|
||||
7. ✅ 检查代码覆盖率
|
||||
|
||||
---
|
||||
|
||||
**状态**: 📋 分析完成,建议轻度重构
|
||||
**风险评估**: 🟢 低风险
|
||||
**预期收益**: 919行 → ~850行 (↓7%)
|
||||
@@ -1,519 +0,0 @@
|
||||
# Performance Optimization Guide
|
||||
|
||||
This document provides performance optimization recommendations for the Invoice Field Extraction system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Batch Processing Optimization](#batch-processing-optimization)
|
||||
2. [Database Query Optimization](#database-query-optimization)
|
||||
3. [Caching Strategies](#caching-strategies)
|
||||
4. [Memory Management](#memory-management)
|
||||
5. [Profiling and Monitoring](#profiling-and-monitoring)
|
||||
|
||||
---
|
||||
|
||||
## Batch Processing Optimization
|
||||
|
||||
### Current State
|
||||
|
||||
The system processes invoices one at a time. For large batches, this can be inefficient.
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Database Batch Operations
|
||||
|
||||
**Current**: Individual inserts for each document
|
||||
```python
|
||||
# Inefficient
|
||||
for doc in documents:
|
||||
db.insert_document(doc) # Individual DB call
|
||||
```
|
||||
|
||||
**Optimized**: Use `execute_values` for batch inserts
|
||||
```python
|
||||
# Efficient - already implemented in db.py line 519
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
execute_values(cursor, """
|
||||
INSERT INTO documents (...)
|
||||
VALUES %s
|
||||
""", document_values)
|
||||
```
|
||||
|
||||
**Impact**: 10-50x faster for batches of 100+ documents
|
||||
|
||||
#### 2. PDF Processing Batching
|
||||
|
||||
**Recommendation**: Process PDFs in parallel using multiprocessing
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
|
||||
def process_batch(pdf_paths, batch_size=10):
|
||||
"""Process PDFs in parallel batches."""
|
||||
with Pool(processes=batch_size) as pool:
|
||||
results = pool.map(pipeline.process_pdf, pdf_paths)
|
||||
return results
|
||||
```
|
||||
|
||||
**Considerations**:
|
||||
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
|
||||
- CPU-intensive tasks can use separate process pool (`src/processing/cpu_pool.py`)
|
||||
- Current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern
|
||||
|
||||
**Status**: ✅ Already implemented in `src/processing/` modules
|
||||
|
||||
#### 3. Image Caching for Multi-Page PDFs
|
||||
|
||||
**Current**: Each page rendered independently
|
||||
```python
|
||||
# Current pattern in field_extractor.py
|
||||
for page_num in range(total_pages):
|
||||
image = render_pdf_page(pdf_path, page_num, dpi=300)
|
||||
```
|
||||
|
||||
**Optimized**: Pre-render all pages if processing multiple fields per page
|
||||
```python
|
||||
# Batch render
|
||||
images = {
|
||||
page_num: render_pdf_page(pdf_path, page_num, dpi=300)
|
||||
for page_num in page_numbers_needed
|
||||
}
|
||||
|
||||
# Reuse images
|
||||
for detection in detections:
|
||||
image = images[detection.page_no]
|
||||
extract_field(detection, image)
|
||||
```
|
||||
|
||||
**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices
|
||||
|
||||
---
|
||||
|
||||
## Database Query Optimization
|
||||
|
||||
### Current Performance
|
||||
|
||||
- **Parameterized queries**: ✅ Implemented (Phase 1)
|
||||
- **Connection pooling**: ❌ Not implemented
|
||||
- **Query batching**: ✅ Partially implemented
|
||||
- **Index optimization**: ⚠️ Needs verification
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Connection Pooling
|
||||
|
||||
**Current**: New connection for each operation
|
||||
```python
|
||||
def connect(self):
|
||||
"""Create new database connection."""
|
||||
return psycopg2.connect(**self.config)
|
||||
```
|
||||
|
||||
**Optimized**: Use connection pooling
|
||||
```python
|
||||
from psycopg2 import pool
|
||||
|
||||
class DocumentDatabase:
|
||||
def __init__(self, config):
|
||||
self.pool = pool.SimpleConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=10,
|
||||
**config
|
||||
)
|
||||
|
||||
def connect(self):
|
||||
return self.pool.getconn()
|
||||
|
||||
def close(self, conn):
|
||||
self.pool.putconn(conn)
|
||||
```
|
||||
|
||||
**Impact**:
|
||||
- Reduces connection overhead by 80-95%
|
||||
- Especially important for high-frequency operations
|
||||
|
||||
#### 2. Index Recommendations
|
||||
|
||||
**Check current indexes**:
|
||||
```sql
|
||||
-- Verify indexes exist on frequently queried columns
|
||||
SELECT tablename, indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public';
|
||||
```
|
||||
|
||||
**Recommended indexes**:
|
||||
```sql
|
||||
-- If not already present
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_success
|
||||
ON documents(success);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_timestamp
|
||||
ON documents(timestamp DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_document_id
|
||||
ON field_results(document_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_matched
|
||||
ON field_results(matched);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_field_name
|
||||
ON field_results(field_name);
|
||||
```
|
||||
|
||||
**Impact**:
|
||||
- 10-100x faster queries for filtered/sorted results
|
||||
- Critical for `get_failed_matches()` and `get_all_documents_summary()`
|
||||
|
||||
#### 3. Query Batching
|
||||
|
||||
**Status**: ✅ Already implemented for field results (line 519)
|
||||
|
||||
**Verify batching is used**:
|
||||
```python
|
||||
# Good pattern in db.py
|
||||
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
|
||||
```
|
||||
|
||||
**Additional opportunity**: Batch `SELECT` queries
|
||||
```python
|
||||
# Current
|
||||
docs = [get_document(doc_id) for doc_id in doc_ids] # N queries
|
||||
|
||||
# Optimized
|
||||
docs = get_documents_batch(doc_ids) # 1 query with IN clause
|
||||
```
|
||||
|
||||
**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)
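
For reference, the typical shape of such a batched lookup with psycopg2 (a generic sketch, not the project's actual `get_documents_batch` in db.py):

```python
def get_documents_batch(cursor, doc_ids):
    """Fetch many documents in one round trip using an ANY(...) clause."""
    cursor.execute(
        "SELECT * FROM documents WHERE id = ANY(%s)",
        (list(doc_ids),),  # psycopg2 adapts a Python list to a Postgres array
    )
    return cursor.fetchall()
```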
|
||||
|
||||
---
|
||||
|
||||
## Caching Strategies
|
||||
|
||||
### 1. Model Loading Cache
|
||||
|
||||
**Current**: Models loaded per-instance
|
||||
|
||||
**Recommendation**: Singleton pattern for YOLO model
|
||||
```python
|
||||
class YOLODetectorSingleton:
|
||||
_instance = None
|
||||
_model = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, model_path):
|
||||
if cls._instance is None:
|
||||
cls._instance = YOLODetector(model_path)
|
||||
return cls._instance
|
||||
```
|
||||
|
||||
**Impact**: Reduces memory usage by 90% when processing multiple documents
|
||||
|
||||
### 2. Parser Instance Caching
|
||||
|
||||
**Current**: ✅ Already optimal
|
||||
```python
|
||||
# Good pattern in field_extractor.py
|
||||
def __init__(self):
|
||||
self.payment_line_parser = PaymentLineParser() # Reused
|
||||
self.customer_number_parser = CustomerNumberParser() # Reused
|
||||
```
|
||||
|
||||
**Status**: No changes needed
|
||||
|
||||
### 3. OCR Result Caching
|
||||
|
||||
**Recommendation**: Cache OCR results for identical regions
|
||||
```python
import hashlib

_ocr_cache = {}  # (image_hash, bbox) -> OCR result

def ocr_region_cached(image, bbox):
    """Cache OCR results keyed by image hash + bbox (bbox must be hashable, e.g. a tuple)."""
    key = (hashlib.md5(image.tobytes()).hexdigest(), tuple(bbox))
    if key not in _ocr_cache:
        _ocr_cache[key] = paddle_ocr.ocr_region(image, bbox)
    return _ocr_cache[key]
```
|
||||
|
||||
**Impact**: 50-80% speedup when re-processing similar documents
|
||||
|
||||
**Note**: Requires implementing image hashing (e.g., `hashlib.md5(image.tobytes())`)
|
||||
|
||||
---
|
||||
|
||||
## Memory Management
|
||||
|
||||
### Current Issues
|
||||
|
||||
**Potential memory leaks**:
|
||||
1. Large images kept in memory after processing
|
||||
2. OCR results accumulated without cleanup
|
||||
3. Model outputs not explicitly cleared
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Explicit Image Cleanup
|
||||
|
||||
```python
import gc

def process_pdf(pdf_path):
    image = None
    try:
        image = render_pdf(pdf_path)
        return extract_fields(image)
    finally:
        del image     # drop the reference even if rendering or extraction failed
        gc.collect()  # force garbage collection of the large image buffer
```
|
||||
|
||||
#### 2. Generator Pattern for Large Batches
|
||||
|
||||
**Current**: Load all documents into memory
|
||||
```python
|
||||
docs = [process_pdf(path) for path in pdf_paths] # All in memory
|
||||
```
|
||||
|
||||
**Optimized**: Use generator for streaming processing
|
||||
```python
|
||||
def process_batch_streaming(pdf_paths):
|
||||
"""Process documents one at a time, yielding results."""
|
||||
for path in pdf_paths:
|
||||
result = process_pdf(path)
|
||||
yield result
|
||||
# Result can be saved to DB immediately
|
||||
# Previous result is garbage collected
|
||||
```
|
||||
|
||||
**Impact**: Constant memory usage regardless of batch size
|
||||
|
||||
#### 3. Context Managers for Resources
|
||||
|
||||
```python
|
||||
class InferencePipeline:
|
||||
def __enter__(self):
|
||||
self.detector.load_model()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.detector.unload_model()
|
||||
self.extractor.cleanup()
|
||||
|
||||
# Usage
|
||||
with InferencePipeline(...) as pipeline:
|
||||
results = pipeline.process_pdf(path)
|
||||
# Automatic cleanup
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Profiling and Monitoring
|
||||
|
||||
### Recommended Profiling Tools
|
||||
|
||||
#### 1. cProfile for CPU Profiling
|
||||
|
||||
```python
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Your code here
|
||||
pipeline.process_pdf(pdf_path)
|
||||
|
||||
profiler.disable()
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats('cumulative')
|
||||
stats.print_stats(20) # Top 20 slowest functions
|
||||
```
|
||||
|
||||
#### 2. memory_profiler for Memory Analysis
|
||||
|
||||
```bash
|
||||
pip install memory_profiler
|
||||
python -m memory_profiler your_script.py
|
||||
```
|
||||
|
||||
Or decorator-based:
|
||||
```python
|
||||
from memory_profiler import profile
|
||||
|
||||
@profile
|
||||
def process_large_batch(pdf_paths):
|
||||
# Memory usage tracked line-by-line
|
||||
results = [process_pdf(path) for path in pdf_paths]
|
||||
return results
|
||||
```
|
||||
|
||||
#### 3. py-spy for Production Profiling
|
||||
|
||||
```bash
|
||||
pip install py-spy
|
||||
|
||||
# Profile running process
|
||||
py-spy top --pid 12345
|
||||
|
||||
# Generate flamegraph
|
||||
py-spy record -o profile.svg -- python your_script.py
|
||||
```
|
||||
|
||||
**Advantage**: No code changes needed, minimal overhead
|
||||
|
||||
### Key Metrics to Monitor
|
||||
|
||||
1. **Processing Time per Document**
|
||||
- Target: <10 seconds for single-page invoice
|
||||
- Current: ~2-5 seconds (estimated)
|
||||
|
||||
2. **Memory Usage**
|
||||
- Target: <2GB for batch of 100 documents
|
||||
- Monitor: Peak memory usage
|
||||
|
||||
3. **Database Query Time**
|
||||
- Target: <100ms per query (with indexes)
|
||||
- Monitor: Slow query log
|
||||
|
||||
4. **OCR Accuracy vs Speed Trade-off**
|
||||
- Current: PaddleOCR with GPU (~200ms per region)
|
||||
- Alternative: Tesseract (~500ms, slightly more accurate)
|
||||
|
||||
### Logging Performance Metrics
|
||||
|
||||
**Add to pipeline.py**:
|
||||
```python
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def process_pdf(self, pdf_path):
|
||||
start = time.time()
|
||||
|
||||
# Processing...
|
||||
result = self._process_internal(pdf_path)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")
|
||||
|
||||
# Log to database for analysis
|
||||
self.db.log_performance({
|
||||
'document_id': result.document_id,
|
||||
'processing_time': elapsed,
|
||||
'field_count': len(result.fields)
|
||||
})
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Optimization Priorities
|
||||
|
||||
### High Priority (Implement First)
|
||||
|
||||
1. ✅ **Database parameterized queries** - Already done (Phase 1)
|
||||
2. ⚠️ **Database connection pooling** - Not implemented
|
||||
3. ⚠️ **Index optimization** - Needs verification
|
||||
|
||||
### Medium Priority
|
||||
|
||||
4. ⚠️ **Batch PDF rendering** - Optimization possible
|
||||
5. ✅ **Parser instance reuse** - Already done (Phase 2)
|
||||
6. ⚠️ **Model caching** - Could improve
|
||||
|
||||
### Low Priority (Nice to Have)
|
||||
|
||||
7. ⚠️ **OCR result caching** - Complex implementation
|
||||
8. ⚠️ **Generator patterns** - Refactoring needed
|
||||
9. ⚠️ **Advanced profiling** - For production optimization
|
||||
|
||||
---
|
||||
|
||||
## Benchmarking Script
|
||||
|
||||
```python
|
||||
"""
|
||||
Benchmark script for invoice processing performance.
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from src.inference.pipeline import InferencePipeline
|
||||
|
||||
def benchmark_single_document(pdf_path, iterations=10):
|
||||
"""Benchmark single document processing."""
|
||||
pipeline = InferencePipeline(
|
||||
model_path="path/to/model.pt",
|
||||
use_gpu=True
|
||||
)
|
||||
|
||||
times = []
|
||||
for i in range(iterations):
|
||||
start = time.time()
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
print(f"Iteration {i+1}: {elapsed:.2f}s")
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"\nAverage: {avg_time:.2f}s")
|
||||
print(f"Min: {min(times):.2f}s")
|
||||
print(f"Max: {max(times):.2f}s")
|
||||
|
||||
def benchmark_batch(pdf_paths, batch_size=10):
|
||||
"""Benchmark batch processing."""
|
||||
from multiprocessing import Pool
|
||||
|
||||
pipeline = InferencePipeline(
|
||||
model_path="path/to/model.pt",
|
||||
use_gpu=True
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
|
||||
with Pool(processes=batch_size) as pool:
|
||||
results = pool.map(pipeline.process_pdf, pdf_paths)
|
||||
|
||||
elapsed = time.time() - start
|
||||
avg_per_doc = elapsed / len(pdf_paths)
|
||||
|
||||
print(f"Total time: {elapsed:.2f}s")
|
||||
print(f"Documents: {len(pdf_paths)}")
|
||||
print(f"Average per document: {avg_per_doc:.2f}s")
|
||||
print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Single document benchmark
|
||||
benchmark_single_document("test.pdf")
|
||||
|
||||
# Batch benchmark
|
||||
pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
|
||||
benchmark_batch(pdf_paths[:100])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
**Implemented (Phase 1-2)**:
|
||||
- ✅ Parameterized queries (SQL injection fix)
|
||||
- ✅ Parser instance reuse (Phase 2 refactoring)
|
||||
- ✅ Batch insert operations (execute_values)
|
||||
- ✅ Dual pool processing (CPU/GPU separation)
|
||||
|
||||
**Quick Wins (Low effort, high impact)**:
|
||||
- Database connection pooling (2-4 hours)
|
||||
- Index verification and optimization (1-2 hours)
|
||||
- Batch PDF rendering (4-6 hours)
|
||||
|
||||
**Long-term Improvements**:
|
||||
- OCR result caching with hashing
|
||||
- Generator patterns for streaming
|
||||
- Advanced profiling and monitoring
|
||||
|
||||
**Expected Impact**:
|
||||
- Connection pooling: 80-95% reduction in DB overhead
|
||||
- Indexes: 10-100x faster queries
|
||||
- Batch rendering: 50-90% less redundant work
|
||||
- **Overall**: 2-5x throughput improvement for batch processing
|
||||
File diff suppressed because it is too large
@@ -1,170 +0,0 @@
|
||||
# 代码重构总结报告
|
||||
|
||||
## 📊 整体成果
|
||||
|
||||
### 测试状态
|
||||
- ✅ **688/688 测试全部通过** (100%)
|
||||
- ✅ **代码覆盖率**: 34% → 37% (+3%)
|
||||
- ✅ **0 个失败**, 0 个错误
|
||||
|
||||
### 测试覆盖率改进
|
||||
- ✅ **machine_code_parser**: 25% → 65% (+40%)
|
||||
- ✅ **新增测试**: 55个(633 → 688)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 已完成的重构
|
||||
|
||||
### 1. ✅ Matcher 模块化 (876行 → 205行, ↓76%)
|
||||
|
||||
**文件**:
|
||||
|
||||
**重构内容**:
|
||||
- 将单一876行文件拆分为 **11个模块**
|
||||
- 提取 **5种独立的匹配策略**
|
||||
- 创建专门的数据模型、工具函数和上下文处理模块
|
||||
|
||||
**新模块结构**:
|
||||
|
||||
|
||||
**测试结果**:
|
||||
- ✅ 77个 matcher 测试全部通过
|
||||
- ✅ 完整的README文档
|
||||
- ✅ 策略模式,易于扩展
|
||||
|
||||
**收益**:
|
||||
- 📉 代码量减少 76%
|
||||
- 📈 可维护性显著提高
|
||||
- ✨ 每个策略独立测试
|
||||
- 🔧 易于添加新策略
|
||||
|
||||
---
|
||||
|
||||
### 2. ✅ Machine Code Parser 轻度重构 + 测试覆盖 (919行 → 929行)
|
||||
|
||||
**文件**: src/ocr/machine_code_parser.py
|
||||
|
||||
**重构内容**:
|
||||
- 提取 **3个共享辅助方法**,消除重复代码
|
||||
- 优化上下文检测逻辑
|
||||
- 简化账号格式化方法
|
||||
|
||||
**测试改进**:
|
||||
- ✅ **新增55个测试**(24 → 79个)
|
||||
- ✅ **覆盖率**: 25% → 65% (+40%)
|
||||
- ✅ 所有688个项目测试通过
|
||||
|
||||
**新增测试覆盖**:
|
||||
- **第一轮** (22个测试):
|
||||
- `_detect_account_context()` - 8个测试(上下文检测)
|
||||
- `_normalize_account_spaces()` - 5个测试(空格规范化)
|
||||
- `_format_account()` - 4个测试(账号格式化)
|
||||
- `parse()` - 5个测试(主入口方法)
|
||||
- **第二轮** (33个测试):
|
||||
- `_extract_ocr()` - 8个测试(OCR 提取)
|
||||
- `_extract_bankgiro()` - 9个测试(Bankgiro 提取)
|
||||
- `_extract_plusgiro()` - 8个测试(Plusgiro 提取)
|
||||
- `_extract_amount()` - 8个测试(金额提取)
|
||||
|
||||
**收益**:
|
||||
- 🔄 消除80行重复代码
|
||||
- 📈 可测试性提高(可独立测试辅助方法)
|
||||
- 📖 代码可读性提升
|
||||
- ✅ 覆盖率从25%提升到65% (+40%)
|
||||
- 🎯 低风险,高回报
|
||||
|
||||
---
|
||||
|
||||
### 3. ✅ Field Extractor 分析 (决定不重构)
|
||||
|
||||
**文件**: (1183行)
|
||||
|
||||
**分析结果**: ❌ **不应重构**
|
||||
|
||||
**关键洞察**:
|
||||
- 表面相似的代码可能有**完全不同的用途**
|
||||
- field_extractor: **解析/提取** 字段值
|
||||
- src/normalize: **标准化/生成变体** 用于匹配
|
||||
- 两者职责不同,不应统一
|
||||
|
||||
**文档**:
|
||||
|
||||
---
|
||||
|
||||
## 📈 重构统计
|
||||
|
||||
### 代码行数变化
|
||||
|
||||
| 文件 | 重构前 | 重构后 | 变化 | 百分比 |
|
||||
|------|--------|--------|------|--------|
|
||||
| **matcher/field_matcher.py** | 876行 | 205行 | -671 | ↓76% |
|
||||
| **matcher/* (新增10个模块)** | 0行 | 466行 | +466 | 新增 |
|
||||
| **matcher 总计** | 876行 | 671行 | -205 | ↓23% |
|
||||
| **ocr/machine_code_parser.py** | 919行 | 929行 | +10 | +1% |
|
||||
| **总净减少** | - | - | **-195行** | **↓11%** |
|
||||
|
||||
### 测试覆盖
|
||||
|
||||
| 模块 | 测试数 | 通过率 | 覆盖率 | 状态 |
|
||||
|------|--------|--------|--------|------|
|
||||
| matcher | 77 | 100% | - | ✅ |
|
||||
| field_extractor | 45 | 100% | 39% | ✅ |
|
||||
| machine_code_parser | 79 | 100% | 65% | ✅ |
|
||||
| normalizer | ~120 | 100% | - | ✅ |
|
||||
| 其他模块 | ~367 | 100% | - | ✅ |
|
||||
| **总计** | **688** | **100%** | **37%** | ✅ |
|
||||
|
||||
---
|
||||
|
||||
## 🎓 重构经验总结
|
||||
|
||||
### 成功经验
|
||||
|
||||
1. **✅ 先测试后重构**
|
||||
- 所有重构都有完整测试覆盖
|
||||
- 每次改动后立即验证测试
|
||||
- 100%测试通过率保证质量
|
||||
|
||||
2. **✅ 识别真正的重复**
|
||||
- 不是所有相似代码都是重复
|
||||
- field_extractor vs normalizer: 表面相似但用途不同
|
||||
- machine_code_parser: 真正的代码重复
|
||||
|
||||
3. **✅ 渐进式重构**
|
||||
- matcher: 大规模模块化 (策略模式)
|
||||
- machine_code_parser: 轻度重构 (提取共享方法)
|
||||
- field_extractor: 分析后决定不重构
|
||||
|
||||
### 关键决策
|
||||
|
||||
#### ✅ 应该重构的情况
|
||||
- **matcher**: 单一文件过长 (876行),包含多种策略
|
||||
- **machine_code_parser**: 多处相同用途的重复代码
|
||||
|
||||
#### ❌ 不应重构的情况
|
||||
- **field_extractor**: 相似代码有不同用途
|
||||
|
||||
### 教训
|
||||
|
||||
**不要盲目追求DRY原则**
|
||||
> 相似代码不一定是重复。要理解代码的**真实用途**。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 总结
|
||||
|
||||
**关键成果**:
|
||||
- 📉 净减少 195 行代码
|
||||
- 📈 代码覆盖率 +3% (34% → 37%)
|
||||
- ✅ 测试数量 +55 (633 → 688)
|
||||
- 🎯 machine_code_parser 覆盖率 +40% (25% → 65%)
|
||||
- ✨ 模块化程度显著提高
|
||||
- 🎯 可维护性大幅提升
|
||||
|
||||
**重要教训**:
|
||||
> 相似的代码不一定是重复的代码。理解代码的真实用途,才能做出正确的重构决策。
|
||||
|
||||
**下一步建议**:
|
||||
1. 继续提升 machine_code_parser 覆盖率到 80%+ (目前 65%)
|
||||
2. 为其他低覆盖模块添加测试(field_extractor 39%, pipeline 19%)
|
||||
3. 完善边界条件和异常情况的测试
|
||||
@@ -1,258 +0,0 @@
|
||||
# 测试覆盖率改进报告
|
||||
|
||||
## 📊 改进概览
|
||||
|
||||
### 整体统计
|
||||
- ✅ **测试总数**: 633 → 688 (+55个测试, +8.7%)
|
||||
- ✅ **通过率**: 100% (688/688)
|
||||
- ✅ **整体覆盖率**: 34% → 37% (+3%)
|
||||
|
||||
### machine_code_parser.py 专项改进
|
||||
- ✅ **测试数**: 24 → 79 (+55个测试, +229%)
|
||||
- ✅ **覆盖率**: 25% → 65% (+40%)
|
||||
- ✅ **未覆盖行**: 273 → 129 (减少144行)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 新增测试详情
|
||||
|
||||
### 第一轮改进 (22个测试)
|
||||
|
||||
#### 1. TestDetectAccountContext (8个测试)
|
||||
|
||||
测试新增的 `_detect_account_context()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_bankgiro_keyword` - 检测 'bankgiro' 关键词
|
||||
2. `test_bg_keyword` - 检测 'bg:' 缩写
|
||||
3. `test_plusgiro_keyword` - 检测 'plusgiro' 关键词
|
||||
4. `test_postgiro_keyword` - 检测 'postgiro' 别名
|
||||
5. `test_pg_keyword` - 检测 'pg:' 缩写
|
||||
6. `test_both_contexts` - 同时存在两种关键词
|
||||
7. `test_no_context` - 无账号关键词
|
||||
8. `test_case_insensitive` - 大小写不敏感检测
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
return {
|
||||
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
|
||||
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. TestNormalizeAccountSpacesMethod (5个测试)
|
||||
|
||||
测试新增的 `_normalize_account_spaces()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_removes_spaces_after_arrow` - 移除 > 后的空格
|
||||
2. `test_multiple_consecutive_spaces` - 处理多个连续空格
|
||||
3. `test_no_arrow_returns_unchanged` - 无 > 标记时返回原值
|
||||
4. `test_spaces_before_arrow_preserved` - 保留 > 前的空格
|
||||
5. `test_empty_string` - 空字符串处理
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _normalize_account_spaces(self, line: str) -> str:
|
||||
if '>' not in line:
|
||||
return line
|
||||
parts = line.split('>', 1)
|
||||
after_arrow = parts[1]
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
|
||||
while re.search(r'(\d)\s+(\d)', normalized):
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
|
||||
return parts[0] + '>' + normalized
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. TestFormatAccount (4个测试)
|
||||
|
||||
测试新增的 `_format_account()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_plusgiro_context_forces_plusgiro` - Plusgiro 上下文强制格式化为 Plusgiro
|
||||
2. `test_valid_bankgiro_7_digits` - 7位有效 Bankgiro 格式化
|
||||
3. `test_valid_bankgiro_8_digits` - 8位有效 Bankgiro 格式化
|
||||
4. `test_defaults_to_bankgiro_when_ambiguous` - 模糊情况默认 Bankgiro
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
|
||||
if is_plusgiro_context:
|
||||
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
return formatted, 'plusgiro'
|
||||
|
||||
# Luhn 验证逻辑
|
||||
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
|
||||
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
|
||||
|
||||
# 决策逻辑
|
||||
if pg_valid and not bg_valid:
|
||||
return pg_formatted, 'plusgiro'
|
||||
elif bg_valid and not pg_valid:
|
||||
return bg_formatted, 'bankgiro'
|
||||
else:
|
||||
return bg_formatted, 'bankgiro'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. TestParseMethod (5个测试)
|
||||
|
||||
测试主入口 `parse()` 方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_parse_empty_tokens` - 空 token 列表处理
|
||||
2. `test_parse_finds_payment_line_in_bottom_region` - 在页面底部35%区域查找付款行
|
||||
3. `test_parse_ignores_top_region` - 忽略页面顶部区域
|
||||
4. `test_parse_with_context_keywords` - 检测上下文关键词
|
||||
5. `test_parse_stores_source_tokens` - 存储源 token
|
||||
|
||||
**覆盖的代码路径**:
|
||||
- Token 过滤(底部区域检测)
|
||||
- 上下文关键词检测
|
||||
- 付款行查找和解析
|
||||
- 结果对象构建
|
||||
|
||||
---
|
||||
|
||||
### 第二轮改进 (33个测试)
|
||||
|
||||
#### 5. TestExtractOCR (8个测试)
|
||||
|
||||
测试 `_extract_ocr()` 方法 - OCR 参考号码提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_valid_ocr_10_digits` - 提取10位 OCR 号码
|
||||
2. `test_extract_valid_ocr_15_digits` - 提取15位 OCR 号码
|
||||
3. `test_extract_ocr_with_hash_markers` - 带 # 标记的 OCR
|
||||
4. `test_extract_longest_ocr_when_multiple` - 多个候选时选最长
|
||||
5. `test_extract_ocr_ignores_short_numbers` - 忽略短于10位的数字
|
||||
6. `test_extract_ocr_ignores_long_numbers` - 忽略长于25位的数字
|
||||
7. `test_extract_ocr_excludes_bankgiro_variants` - 排除 Bankgiro 变体
|
||||
8. `test_extract_ocr_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 6. TestExtractBankgiro (9个测试)
|
||||
|
||||
测试 `_extract_bankgiro()` 方法 - Bankgiro 账号提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_bankgiro_7_digits_with_dash` - 带破折号的7位 Bankgiro
|
||||
2. `test_extract_bankgiro_7_digits_without_dash` - 无破折号的7位 Bankgiro
|
||||
3. `test_extract_bankgiro_8_digits_with_dash` - 带破折号的8位 Bankgiro
|
||||
4. `test_extract_bankgiro_8_digits_without_dash` - 无破折号的8位 Bankgiro
|
||||
5. `test_extract_bankgiro_with_spaces` - 带空格的 Bankgiro
|
||||
6. `test_extract_bankgiro_handles_plusgiro_format` - 处理 Plusgiro 格式
|
||||
7. `test_extract_bankgiro_with_context` - 带上下文关键词
|
||||
8. `test_extract_bankgiro_ignores_plusgiro_context` - 忽略 Plusgiro 上下文
|
||||
9. `test_extract_bankgiro_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 7. TestExtractPlusgiro (8个测试)
|
||||
|
||||
测试 `_extract_plusgiro()` 方法 - Plusgiro 账号提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_plusgiro_7_digits_with_dash` - 带破折号的7位 Plusgiro
|
||||
2. `test_extract_plusgiro_7_digits_without_dash` - 无破折号的7位 Plusgiro
|
||||
3. `test_extract_plusgiro_8_digits` - 8位 Plusgiro
|
||||
4. `test_extract_plusgiro_with_spaces` - 带空格的 Plusgiro
|
||||
5. `test_extract_plusgiro_with_context` - 带上下文关键词
|
||||
6. `test_extract_plusgiro_ignores_too_short` - 忽略少于7位
|
||||
7. `test_extract_plusgiro_ignores_too_long` - 忽略多于8位
|
||||
8. `test_extract_plusgiro_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 8. TestExtractAmount (8个测试)
|
||||
|
||||
测试 `_extract_amount()` 方法 - 金额提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_amount_with_comma_decimal` - 逗号小数分隔符
|
||||
2. `test_extract_amount_with_dot_decimal` - 点号小数分隔符
|
||||
3. `test_extract_amount_integer` - 整数金额
|
||||
4. `test_extract_amount_with_thousand_separator` - 千位分隔符
|
||||
5. `test_extract_amount_large_number` - 大额金额
|
||||
6. `test_extract_amount_ignores_too_large` - 忽略过大金额
|
||||
7. `test_extract_amount_ignores_zero` - 忽略零或负数
|
||||
8. `test_extract_amount_empty_tokens` - 空 token 处理
|
||||
|
||||
---
|
||||
|
||||
## 📈 覆盖率分析
|
||||
|
||||
### 已覆盖的方法
|
||||
✅ `_detect_account_context()` - **100%** (第一轮新增)
|
||||
✅ `_normalize_account_spaces()` - **100%** (第一轮新增)
|
||||
✅ `_format_account()` - **95%** (第一轮新增)
|
||||
✅ `parse()` - **70%** (第一轮改进)
|
||||
✅ `_parse_standard_payment_line()` - **95%** (已有测试)
|
||||
✅ `_extract_ocr()` - **85%** (第二轮新增)
|
||||
✅ `_extract_bankgiro()` - **90%** (第二轮新增)
|
||||
✅ `_extract_plusgiro()` - **90%** (第二轮新增)
|
||||
✅ `_extract_amount()` - **80%** (第二轮新增)
|
||||
|
||||
### 仍需改进的方法 (未覆盖/部分覆盖)
|
||||
⚠️ `_calculate_confidence()` - **0%** (未测试)
|
||||
⚠️ `cross_validate()` - **0%** (未测试)
|
||||
⚠️ `get_region_bbox()` - **0%** (未测试)
|
||||
⚠️ `_find_tokens_with_values()` - **部分覆盖**
|
||||
⚠️ `_find_machine_code_line_tokens()` - **部分覆盖**
|
||||
|
||||
### 未覆盖的代码行(129行)
|
||||
主要集中在:
|
||||
1. **验证方法** (lines 805-824): `_calculate_confidence`, `cross_validate`
|
||||
2. **辅助方法** (lines 80-92, 336-369, 377-407): Token 查找、bbox 计算、日志记录
|
||||
3. **边界条件** (lines 648-653, 690, 699, 759-760等): 某些提取方法的边界情况
|
||||
|
||||
---
|
||||
|
||||
## 🎯 改进建议
|
||||
|
||||
### ✅ 已完成目标
|
||||
- ✅ 覆盖率从 25% 提升到 65% (+40%)
|
||||
- ✅ 测试数量从 24 增加到 79 (+55个)
|
||||
- ✅ 提取方法全部测试(_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount)
|
||||
|
||||
### 下一步目标(覆盖率 65% → 80%+)
|
||||
1. **添加验证方法测试** - 为 `_calculate_confidence`, `cross_validate` 添加测试
|
||||
2. **添加辅助方法测试** - 为 token 查找和 bbox 计算方法添加测试
|
||||
3. **完善边界条件** - 增加边界情况和异常处理的测试
|
||||
4. **集成测试** - 添加端到端的集成测试,使用真实 PDF token 数据
|
||||
|
||||
---
|
||||
|
||||
## ✅ 已完成的改进
|
||||
|
||||
### 重构收益
|
||||
- ✅ 提取的3个辅助方法现在可以独立测试
|
||||
- ✅ 测试粒度更细,更容易定位问题
|
||||
- ✅ 代码可读性提高,测试用例清晰易懂
|
||||
|
||||
### 质量保证
|
||||
- ✅ 所有688个测试100%通过
|
||||
- ✅ 无回归问题
|
||||
- ✅ 新增测试覆盖了之前未测试的重构代码
|
||||
|
||||
---
|
||||
|
||||
## 📚 测试编写经验
|
||||
|
||||
### 成功经验
|
||||
1. **使用 fixture 创建测试数据** - `_create_token()` 辅助方法简化了 token 创建
|
||||
2. **按方法组织测试类** - 每个方法一个测试类,结构清晰
|
||||
3. **测试用例命名清晰** - `test_<what>_<condition>` 格式,一目了然
|
||||
4. **覆盖关键路径** - 优先测试常见场景和边界条件
|
||||
|
||||
### 遇到的问题
|
||||
1. **Token 初始化参数** - 忘记了 `page_no` 参数,导致初始测试失败
|
||||
- 解决:修复 `_create_token()` 辅助方法,添加 `page_no=0`
|
||||
|
||||
---
|
||||
|
||||
**报告日期**: 2026-01-24
|
||||
**状态**: ✅ 完成
|
||||
**下一步**: 继续提升覆盖率到 80%+
|
||||
@@ -1,619 +0,0 @@
|
||||
# 多池处理架构设计文档
|
||||
|
||||
## 1. 研究总结
|
||||
|
||||
### 1.1 当前问题分析
|
||||
|
||||
我们之前实现的双池模式存在稳定性问题,主要原因:
|
||||
|
||||
| 问题 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| 处理卡住 | 线程 + ProcessPoolExecutor 混用导致死锁 | 使用 asyncio 或纯 Queue 模式 |
|
||||
| Queue.get() 无限阻塞 | 没有超时机制 | 添加 timeout 和哨兵值 |
|
||||
| GPU 内存冲突 | 多进程同时访问 GPU | 限制 GPU worker = 1 |
|
||||
| CUDA fork 问题 | Linux 默认 fork 不兼容 CUDA | 使用 spawn 启动方式 |
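
As a sketch of the "timeout + sentinel" fix referenced in the table above (the names `SENTINEL` and `drain_results` are illustrative, not existing code):

```python
import queue

SENTINEL = None  # the producer puts this once per consumer when no more work is coming

def drain_results(result_queue, poll_timeout: float = 5.0, max_idle_polls: int = 12):
    """Consume a (multiprocessing) queue without ever blocking forever."""
    results = []
    idle = 0
    while True:
        try:
            item = result_queue.get(timeout=poll_timeout)  # bounded wait instead of get()
        except queue.Empty:
            idle += 1
            if idle >= max_idle_polls:  # give up / report instead of hanging
                break
            continue
        idle = 0
        if item is SENTINEL:            # explicit end-of-stream marker
            break
        results.append(item)
    return results
```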
|
||||
|
||||
### 1.2 推荐架构方案
|
||||
|
||||
经过研究,最适合我们场景的方案是 **生产者-消费者队列模式**:
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Main Process │ │ CPU Workers │ │ GPU Worker │
|
||||
│ │ │ (4 processes) │ │ (1 process) │
|
||||
│ ┌───────────┐ │ │ │ │ │
|
||||
│ │ Task │──┼────▶│ Text PDF处理 │ │ Scanned PDF处理 │
|
||||
│ │ Dispatcher│ │ │ (无需OCR) │ │ (PaddleOCR) │
|
||||
│ └───────────┘ │ │ │ │ │
|
||||
│ ▲ │ │ │ │ │ │ │
|
||||
│ │ │ │ ▼ │ │ ▼ │
|
||||
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
|
||||
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
|
||||
│ │ Collector │ │ │ │ │ │
|
||||
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────┐ │
|
||||
│ │ Database │ │
|
||||
│ │ Batch │ │
|
||||
│ │ Writer │ │
|
||||
│ └───────────┘ │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心设计原则
|
||||
|
||||
### 2.1 CUDA 兼容性
|
||||
|
||||
```python
|
||||
# 关键:使用 spawn 启动方式
|
||||
import multiprocessing as mp
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
# GPU worker 初始化时设置设备
|
||||
def init_gpu_worker(gpu_id: int = 0):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
global _ocr
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=True, ...)
|
||||
```
|
||||
|
||||
### 2.2 Worker 初始化模式
|
||||
|
||||
使用 `initializer` 参数一次性加载模型,避免每个任务重新加载:
|
||||
|
||||
```python
|
||||
# 全局变量保存模型
|
||||
_ocr = None
|
||||
|
||||
def init_worker(use_gpu: bool, gpu_id: int = 0):
|
||||
global _ocr
|
||||
if use_gpu:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
|
||||
|
||||
# 创建 Pool 时使用 initializer
|
||||
pool = ProcessPoolExecutor(
|
||||
max_workers=1,
|
||||
initializer=init_worker,
|
||||
initargs=(True, 0), # use_gpu=True, gpu_id=0
|
||||
mp_context=mp.get_context("spawn")
|
||||
)
|
||||
```
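作为补充,下面的示意片段展示了任务函数如何复用 initializer 中加载好的全局 `_ocr`(紧接上面的代码片段;`ocr_image` 为假设的任务函数名,`.ocr()` 的返回格式以 PaddleOCR 官方文档为准):

```python
# 示意:worker 进程内的任务函数直接使用初始化好的全局 _ocr,避免每个任务重复加载模型
def ocr_image(image_path: str):
    assert _ocr is not None, "worker 尚未初始化"
    return _ocr.ocr(image_path)


future = pool.submit(ocr_image, "sample_invoice.png")  # sample_invoice.png 为示例路径
print(future.result())
```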
|
||||
|
||||
### 2.3 队列模式 vs as_completed
|
||||
|
||||
| 方式 | 优点 | 缺点 | 适用场景 |
|
||||
|------|------|------|----------|
|
||||
| `as_completed()` | 简单、无需管理队列 | 无法跨多个 Pool 使用 | 单池场景 |
|
||||
| `multiprocessing.Queue` | 高性能、灵活 | 需要手动管理、死锁风险 | 多池流水线 |
|
||||
| `Manager().Queue()` | 可 pickle、跨 Pool | 性能较低 | 需要 Pool.map 场景 |
|
||||
|
||||
**推荐**:对于双池场景,使用 `as_completed()` 分别处理每个池,然后合并结果。
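下面给出一个可运行的最小示意(`square`、`cube` 仅为占位任务,非真实处理逻辑),演示如何把两个池的 future 合并后统一用 `as_completed()` 收集;完整的协调器设计见第 2.2 节:

```python
# 最小示意:两个 ProcessPoolExecutor 的 future 合并后统一收集结果
from concurrent.futures import ProcessPoolExecutor, as_completed


def square(x: int) -> int:  # 占位:模拟 CPU 任务
    return x * x


def cube(x: int) -> int:  # 占位:模拟 GPU 任务
    return x ** 3


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4) as cpu_pool, \
         ProcessPoolExecutor(max_workers=1) as gpu_pool:
        futures = {cpu_pool.submit(square, i): ("cpu", i) for i in range(4)}
        futures.update({gpu_pool.submit(cube, i): ("gpu", i) for i in range(2)})

        for fut in as_completed(futures):
            origin, arg = futures[fut]
            print(origin, arg, fut.result())
```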
|
||||
|
||||
---
|
||||
|
||||
## 3. 详细开发计划
|
||||
|
||||
### 阶段 1:重构基础架构 (2-3天)
|
||||
|
||||
#### 1.1 创建 WorkerPool 抽象类
|
||||
|
||||
```python
|
||||
# src/processing/worker_pool.py
|
||||
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Optional, Callable
|
||||
import multiprocessing as mp
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""任务结果容器"""
|
||||
task_id: str
|
||||
success: bool
|
||||
data: Any
|
||||
error: Optional[str] = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
class WorkerPool(ABC):
|
||||
"""Worker Pool 抽象基类"""
|
||||
|
||||
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
|
||||
self.max_workers = max_workers
|
||||
self.use_gpu = use_gpu
|
||||
self.gpu_id = gpu_id
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_initializer(self) -> Callable:
|
||||
"""返回 worker 初始化函数"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_init_args(self) -> tuple:
|
||||
"""返回初始化参数"""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""启动 worker pool"""
|
||||
ctx = mp.get_context("spawn")
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=self.max_workers,
|
||||
mp_context=ctx,
|
||||
initializer=self.get_initializer(),
|
||||
initargs=self.get_init_args()
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, *args, **kwargs) -> Future:
|
||||
"""提交任务"""
|
||||
if not self._executor:
|
||||
raise RuntimeError("Pool not started")
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def shutdown(self, wait: bool = True):
|
||||
"""关闭 pool"""
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=wait)
|
||||
self._executor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.shutdown()
|
||||
```
|
||||
|
||||
#### 1.2 实现 CPU 和 GPU Worker Pool
|
||||
|
||||
```python
|
||||
# src/processing/cpu_pool.py
|
||||
|
||||
class CPUWorkerPool(WorkerPool):
|
||||
"""CPU-only worker pool for text PDF processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
super().__init__(max_workers=max_workers, use_gpu=False)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_cpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return ()
|
||||
|
||||
# src/processing/gpu_pool.py
|
||||
|
||||
class GPUWorkerPool(WorkerPool):
|
||||
"""GPU worker pool for OCR processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
|
||||
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_gpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return (self.gpu_id,)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 2:实现双池协调器 (2-3天)
|
||||
|
||||
#### 2.1 任务分发器
|
||||
|
||||
```python
|
||||
# src/processing/task_dispatcher.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
class TaskType(Enum):
|
||||
CPU = auto() # Text PDF
|
||||
GPU = auto() # Scanned PDF
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
id: str
|
||||
task_type: TaskType
|
||||
data: Any
|
||||
|
||||
class TaskDispatcher:
|
||||
"""根据 PDF 类型分发任务到不同的 pool"""
|
||||
|
||||
def classify_task(self, doc_info: dict) -> TaskType:
|
||||
"""判断文档是否需要 OCR"""
|
||||
# 基于 PDF 特征判断
|
||||
if self._is_scanned_pdf(doc_info):
|
||||
return TaskType.GPU
|
||||
return TaskType.CPU
|
||||
|
||||
def _is_scanned_pdf(self, doc_info: dict) -> bool:
|
||||
"""检测是否为扫描件"""
|
||||
# 1. 检查是否有可提取文本
|
||||
# 2. 检查图片比例
|
||||
# 3. 检查文本密度
|
||||
pass
|
||||
|
||||
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
|
||||
"""将任务分为 CPU 和 GPU 两组"""
|
||||
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
|
||||
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
|
||||
return cpu_tasks, gpu_tasks
|
||||
```
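作为参考,下面是 `_is_scanned_pdf` 判断逻辑的一种实现草稿(假设使用 PyMuPDF 提取文本;阈值仅为示例,需按真实数据调参,并非最终方案):

```python
# 示意草稿:基于可提取文本量判断 PDF 是否为扫描件(需要 OCR)
import fitz  # PyMuPDF,假设项目允许引入该依赖


def is_scanned_pdf(pdf_path: str, min_chars_per_page: int = 50) -> bool:
    """若整份 PDF 的可提取文本过少,则视为扫描件"""
    with fitz.open(pdf_path) as doc:
        total_chars = sum(len(page.get_text("text").strip()) for page in doc)
        return total_chars < min_chars_per_page * doc.page_count
```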
|
||||
|
||||
#### 2.2 双池协调器
|
||||
|
||||
```python
|
||||
# src/processing/dual_pool_coordinator.py
|
||||
|
||||
from concurrent.futures import as_completed
from typing import Callable, List, Optional
import logging

from src.processing.worker_pool import TaskResult
from src.processing.cpu_pool import CPUWorkerPool
from src.processing.gpu_pool import GPUWorkerPool
from src.processing.task_dispatcher import Task, TaskDispatcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DualPoolCoordinator:
|
||||
"""协调 CPU 和 GPU 两个 worker pool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_workers: int = 4,
|
||||
gpu_workers: int = 1,
|
||||
gpu_id: int = 0
|
||||
):
|
||||
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
|
||||
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
|
||||
self.dispatcher = TaskDispatcher()
|
||||
|
||||
def __enter__(self):
|
||||
self.cpu_pool.start()
|
||||
self.gpu_pool.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.cpu_pool.shutdown()
|
||||
self.gpu_pool.shutdown()
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
documents: List[dict],
|
||||
cpu_task_fn: Callable,
|
||||
gpu_task_fn: Callable,
|
||||
on_result: Optional[Callable[[TaskResult], None]] = None,
|
||||
on_error: Optional[Callable[[str, Exception], None]] = None
|
||||
) -> List[TaskResult]:
|
||||
"""
|
||||
处理一批文档,自动分发到 CPU 或 GPU pool
|
||||
|
||||
Args:
|
||||
documents: 待处理文档列表
|
||||
cpu_task_fn: CPU 任务处理函数
|
||||
gpu_task_fn: GPU 任务处理函数
|
||||
on_result: 结果回调(可选)
|
||||
on_error: 错误回调(可选)
|
||||
|
||||
Returns:
|
||||
所有任务结果列表
|
||||
"""
|
||||
# 分类任务
|
||||
tasks = [
|
||||
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
|
||||
for doc in documents
|
||||
]
|
||||
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
|
||||
|
||||
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
|
||||
|
||||
# 提交任务到各自的 pool
|
||||
cpu_futures = {
|
||||
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
|
||||
for t in cpu_tasks
|
||||
}
|
||||
gpu_futures = {
|
||||
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
|
||||
for t in gpu_tasks
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
results = []
|
||||
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
|
||||
|
||||
for future in as_completed(all_futures):
|
||||
task_id = cpu_futures.get(future) or gpu_futures.get(future)
|
||||
pool_type = "CPU" if future in cpu_futures else "GPU"
|
||||
|
||||
try:
|
||||
data = future.result(timeout=300) # 5分钟超时
|
||||
result = TaskResult(task_id=task_id, success=True, data=data)
|
||||
if on_result:
|
||||
on_result(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
|
||||
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
|
||||
if on_error:
|
||||
on_error(task_id, e)
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 3:集成到 autolabel (1-2天)
|
||||
|
||||
#### 3.1 修改 autolabel.py
|
||||
|
||||
```python
|
||||
# src/cli/autolabel.py
|
||||
|
||||
def run_autolabel_dual_pool(args):
|
||||
"""使用双池模式运行自动标注"""
|
||||
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator
|
||||
|
||||
# 初始化数据库批处理
|
||||
db_batch = []
|
||||
db_batch_size = 100
|
||||
|
||||
def on_result(result: TaskResult):
|
||||
"""处理成功结果"""
|
||||
nonlocal db_batch
|
||||
db_batch.append(result.data)
|
||||
|
||||
if len(db_batch) >= db_batch_size:
|
||||
save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
def on_error(task_id: str, error: Exception):
|
||||
"""处理错误"""
|
||||
logger.error(f"Task {task_id} failed: {error}")
|
||||
|
||||
# 创建双池协调器
|
||||
with DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers or 4,
|
||||
gpu_workers=args.gpu_workers or 1,
|
||||
gpu_id=0
|
||||
) as coordinator:
|
||||
|
||||
# 处理所有 CSV
|
||||
for csv_file in csv_files:
|
||||
documents = load_documents_from_csv(csv_file)
|
||||
|
||||
results = coordinator.process_batch(
|
||||
documents=documents,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=on_result,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
logger.info(f"CSV {csv_file}: {len(results)} processed")
|
||||
|
||||
# 保存剩余批次
|
||||
if db_batch:
|
||||
save_documents_batch(db_batch)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 4:测试与验证 (1-2天)
|
||||
|
||||
#### 4.1 单元测试
|
||||
|
||||
```python
|
||||
# tests/unit/test_dual_pool.py
|
||||
|
||||
import pytest
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult


def cpu_fn(doc: dict) -> dict:
    """测试用的简单 CPU 任务函数(示意)"""
    return {"id": doc["id"], "processed": True}


def gpu_fn(doc: dict) -> dict:
    """测试用的简单 GPU 任务函数(示意)"""
    return {"id": doc["id"], "processed": True}
|
||||
|
||||
class TestDualPoolCoordinator:
|
||||
|
||||
def test_cpu_only_batch(self):
|
||||
"""测试纯 CPU 任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 10
|
||||
assert all(r.success for r in results)
|
||||
|
||||
def test_mixed_batch(self):
|
||||
"""测试混合任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [
|
||||
{"id": "text_1", "type": "text"},
|
||||
{"id": "scan_1", "type": "scanned"},
|
||||
{"id": "text_2", "type": "text"},
|
||||
]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_timeout_handling(self):
|
||||
"""测试超时处理"""
|
||||
pass
|
||||
|
||||
def test_error_recovery(self):
|
||||
"""测试错误恢复"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### 4.2 集成测试
|
||||
|
||||
```python
|
||||
# tests/integration/test_autolabel_dual_pool.py

import subprocess
|
||||
|
||||
def test_autolabel_with_dual_pool():
|
||||
"""端到端测试双池模式"""
|
||||
# 使用少量测试数据
|
||||
result = subprocess.run([
|
||||
"python", "-m", "src.cli.autolabel",
|
||||
"--cpu-workers", "2",
|
||||
"--gpu-workers", "1",
|
||||
"--limit", "50"
|
||||
], capture_output=True)
|
||||
|
||||
assert result.returncode == 0
|
||||
# 验证数据库记录
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键技术点
|
||||
|
||||
### 4.1 避免死锁的策略
|
||||
|
||||
```python
|
||||
# 1. 使用 timeout
|
||||
try:
|
||||
result = future.result(timeout=300)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Task timed out")
|
||||
|
||||
# 2. 使用哨兵值
|
||||
SENTINEL = object()
|
||||
queue.put(SENTINEL) # 发送结束信号
|
||||
|
||||
# 3. 检查进程状态
|
||||
if not worker.is_alive():
|
||||
logger.error("Worker died unexpectedly")
|
||||
break
|
||||
|
||||
# 4. 先清空队列再 join
|
||||
while not queue.empty():
|
||||
results.append(queue.get_nowait())
|
||||
worker.join(timeout=5.0)
|
||||
```
|
||||
|
||||
### 4.2 PaddleOCR 特殊处理
|
||||
|
||||
```python
|
||||
# PaddleOCR 必须在 worker 进程中初始化
|
||||
def init_paddle_worker(gpu_id: int):
|
||||
global _ocr
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
# 延迟导入,确保 CUDA 环境变量生效
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang='en',
|
||||
use_gpu=True,
|
||||
show_log=False,
|
||||
# 重要:设置 GPU 内存比例
|
||||
gpu_mem=2000 # 限制 GPU 内存使用 (MB)
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3 资源监控
|
||||
|
||||
```python
|
||||
import psutil
|
||||
import GPUtil
|
||||
|
||||
def get_resource_usage():
|
||||
"""获取系统资源使用情况"""
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
gpu_info = []
|
||||
for gpu in GPUtil.getGPUs():
|
||||
gpu_info.append({
|
||||
"id": gpu.id,
|
||||
"memory_used": gpu.memoryUsed,
|
||||
"memory_total": gpu.memoryTotal,
|
||||
"utilization": gpu.load * 100
|
||||
})
|
||||
|
||||
return {
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_percent": memory.percent,
|
||||
"gpu": gpu_info
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 风险评估与应对
|
||||
|
||||
| 风险 | 可能性 | 影响 | 应对策略 |
|
||||
|------|--------|------|----------|
|
||||
| GPU 内存不足 | 中 | 高 | 限制 GPU worker = 1,设置 gpu_mem 参数 |
|
||||
| 进程僵死 | 低 | 高 | 添加心跳检测,超时自动重启 |
|
||||
| 任务分类错误 | 中 | 中 | 添加回退机制,CPU 失败后尝试 GPU |
|
||||
| 数据库写入瓶颈 | 低 | 中 | 增大批处理大小,异步写入 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 备选方案
|
||||
|
||||
如果上述方案仍存在问题,可以考虑:
|
||||
|
||||
### 6.1 使用 Ray
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def cpu_task(data):
|
||||
return process_text_pdf(data)
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def gpu_task(data):
|
||||
return process_scanned_pdf(data)
|
||||
|
||||
# 自动资源调度
|
||||
futures = [cpu_task.remote(d) for d in cpu_docs]
|
||||
futures += [gpu_task.remote(d) for d in gpu_docs]
|
||||
results = ray.get(futures)
|
||||
```
|
||||
|
||||
### 6.2 单池 + 动态 GPU 调度
|
||||
|
||||
保持单池模式,但在每个任务内部动态决定是否使用 GPU:
|
||||
|
||||
```python
|
||||
def process_document(doc_data):
|
||||
if is_scanned_pdf(doc_data):
|
||||
# 使用 GPU (需要全局锁或信号量控制并发)
|
||||
with gpu_semaphore:
|
||||
return process_with_ocr(doc_data)
|
||||
else:
|
||||
return process_text_only(doc_data)
|
||||
```
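下面是该思路的一个实现草稿(假设 `is_scanned_pdf`、`process_with_ocr`、`process_text_only` 已在别处实现;这里用 Manager 信号量控制 GPU 并发,以便安全地传入子进程):

```python
# 示意草稿:单池 + 跨进程信号量限制同时使用 GPU 的 worker 数量
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

_gpu_semaphore = None  # 每个 worker 进程内的全局引用


def _init_worker(semaphore):
    global _gpu_semaphore
    _gpu_semaphore = semaphore


def process_document(doc_data: dict):
    if is_scanned_pdf(doc_data):
        with _gpu_semaphore:  # 同一时刻最多 1 个进程执行 OCR
            return process_with_ocr(doc_data)
    return process_text_only(doc_data)


def run(documents: list[dict]):
    ctx = mp.get_context("spawn")
    manager = ctx.Manager()          # 保持引用,避免管理进程提前退出
    semaphore = manager.Semaphore(1)
    with ProcessPoolExecutor(
        max_workers=4,
        mp_context=ctx,
        initializer=_init_worker,
        initargs=(semaphore,),
    ) as pool:
        return list(pool.map(process_document, documents))
```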
|
||||
|
||||
---
|
||||
|
||||
## 7. 时间线总结
|
||||
|
||||
| 阶段 | 任务 | 预计工作量 |
|
||||
|------|------|------------|
|
||||
| 阶段 1 | 基础架构重构 | 2-3 天 |
|
||||
| 阶段 2 | 双池协调器实现 | 2-3 天 |
|
||||
| 阶段 3 | 集成到 autolabel | 1-2 天 |
|
||||
| 阶段 4 | 测试与验证 | 1-2 天 |
|
||||
| **总计** | | **6-10 天** |
|
||||
|
||||
---
|
||||
|
||||
## 8. 参考资料
|
||||
|
||||
1. [Python concurrent.futures 官方文档](https://docs.python.org/3/library/concurrent.futures.html)
|
||||
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
|
||||
3. [Super Fast Python - ProcessPoolExecutor 完整指南](https://superfastpython.com/processpoolexecutor-in-python/)
|
||||
4. [PaddleOCR 并行推理文档](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
|
||||
5. [AWS - 跨 CPU/GPU 并行化 ML 推理](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
|
||||
6. [Ray 分布式多进程处理](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)
|
||||
1223
docs/product-plan-v2.md
Normal file
File diff suppressed because it is too large
302
docs/ux-design-prompt-v2.md
Normal file
@@ -0,0 +1,302 @@
|
||||
# Document Annotation Tool – UX Design Spec v2
|
||||
|
||||
## Theme: Warm Graphite (Modern Enterprise)
|
||||
|
||||
---
|
||||
|
||||
## 1. Design Principles (Updated)
|
||||
|
||||
1. **Clarity** – High contrast, but never pure black-on-white
|
||||
2. **Warm Neutrality** – Slightly warm grays reduce visual fatigue
|
||||
3. **Focus** – Content-first layouts with restrained accents
|
||||
4. **Consistency** – Reusable patterns, predictable behavior
|
||||
5. **Professional Trust** – Calm, serious, enterprise-ready
|
||||
6. **Longevity** – No trendy colors that age quickly
|
||||
|
||||
---
|
||||
|
||||
## 2. Color Palette (Warm Graphite)
|
||||
|
||||
### Core Colors
|
||||
|
||||
| Usage | Color Name | Hex |
|
||||
|------|-----------|-----|
|
||||
| Primary Text | Soft Black | #121212 |
|
||||
| Secondary Text | Charcoal Gray | #2A2A2A |
|
||||
| Muted Text | Warm Gray | #6B6B6B |
|
||||
| Disabled Text | Light Warm Gray | #9A9A9A |
|
||||
|
||||
### Backgrounds
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|-----|------|-----|
|
||||
| App Background | Paper White | #FAFAF8 |
|
||||
| Card / Panel | White | #FFFFFF |
|
||||
| Hover Surface | Subtle Warm Gray | #F1F0ED |
|
||||
| Selected Row | Very Light Warm Gray | #ECEAE6 |
|
||||
|
||||
### Borders & Dividers
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|------|------|-----|
|
||||
| Default Border | Warm Light Gray | #E6E4E1 |
|
||||
| Strong Divider | Neutral Gray | #D8D6D2 |
|
||||
|
||||
### Semantic States (Muted & Professional)
|
||||
|
||||
| State | Color | Hex |
|
||||
|------|-------|-----|
|
||||
| Success | Olive Gray | #3E4A3A |
|
||||
| Error | Brick Gray | #4A3A3A |
|
||||
| Warning | Sand Gray | #4A4A3A |
|
||||
| Info | Graphite Gray | #3A3A3A |
|
||||
|
||||
> Accent colors are **never saturated** and are used only for status, progress, or selection.
|
||||
|
||||
---
|
||||
|
||||
## 3. Typography
|
||||
|
||||
- **Font Family**: Inter / SF Pro / system-ui
|
||||
- **Headings**:
|
||||
- Weight: 600–700
|
||||
- Color: #121212
|
||||
- Letter spacing: -0.01em
|
||||
- **Body Text**:
|
||||
- Weight: 400
|
||||
- Color: #2A2A2A
|
||||
- **Captions / Meta**:
|
||||
- Weight: 400
|
||||
- Color: #6B6B6B
|
||||
- **Monospace (IDs / Values)**:
|
||||
- JetBrains Mono / SF Mono
|
||||
- Color: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 4. Global Layout
|
||||
|
||||
### Top Navigation Bar
|
||||
|
||||
- Height: 56px
|
||||
- Background: #FAFAF8
|
||||
- Bottom Border: 1px solid #E6E4E1
|
||||
- Logo: Text or icon in #121212
|
||||
|
||||
**Navigation Items**
|
||||
- Default: #6B6B6B
|
||||
- Hover: #2A2A2A
|
||||
- Active:
|
||||
- Text: #121212
|
||||
- Bottom indicator: 2px solid #3A3A3A (rounded ends)
|
||||
|
||||
**Avatar**
|
||||
- Circle background: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 5. Page: Documents (Dashboard)
|
||||
|
||||
### Page Header
|
||||
|
||||
- Title: "Documents" (#121212)
|
||||
- Actions:
|
||||
- Primary button: Dark graphite outline
|
||||
- Secondary button: Subtle border only
|
||||
|
||||
### Filters Bar
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border: 1px solid #E6E4E1
|
||||
- Inputs:
|
||||
- Background: #FFFFFF
|
||||
- Hover: #F1F0ED
|
||||
- Focus ring: 1px #3A3A3A
|
||||
|
||||
### Document Table
|
||||
|
||||
- Table background: #FFFFFF
|
||||
- Header text: #6B6B6B
|
||||
- Row hover: #F1F0ED
|
||||
- Row selected:
|
||||
- Background: #ECEAE6
|
||||
- Left indicator: 3px solid #3A3A3A
|
||||
|
||||
### Status Badges
|
||||
|
||||
- Pending:
|
||||
- BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Text: #2A2A2A
|
||||
|
||||
- Labeled:
|
||||
- BG: #2A2A2A
|
||||
- Text: #FFFFFF
|
||||
|
||||
- Exported:
|
||||
- BG: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
- Icon: ✓
|
||||
|
||||
### Auto-label States
|
||||
|
||||
- Running:
|
||||
- Progress bar: #3A3A3A on #ECEAE6
|
||||
- Completed:
|
||||
- Text: #3E4A3A
|
||||
- Failed:
|
||||
- BG: #F1EDED
|
||||
- Text: #4A3A3A
|
||||
|
||||
---
|
||||
|
||||
## 6. Upload Modals (Single & Batch)
|
||||
|
||||
### Modal Container
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border radius: 8px
|
||||
- Shadow: 0 1px 3px rgba(0,0,0,0.08)
|
||||
|
||||
### Drop Zone
|
||||
|
||||
- Background: #FAFAF8
|
||||
- Border: 1px dashed #D8D6D2
|
||||
- Hover: #F1F0ED
|
||||
- Icon: Graphite gray
|
||||
|
||||
### Form Fields
|
||||
|
||||
- Input BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Focus: 1px solid #3A3A3A
|
||||
|
||||
Primary Action Button:
|
||||
- Text: #FFFFFF
|
||||
- BG: #2A2A2A
|
||||
- Hover: #121212
|
||||
|
||||
---
|
||||
|
||||
## 7. Document Detail View
|
||||
|
||||
### Canvas Area
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Annotation styles:
|
||||
- Manual: Solid border #2A2A2A
|
||||
- Auto: Dashed border #6B6B6B
|
||||
- Selected: 2px border #3A3A3A + resize handles
|
||||
|
||||
### Right Info Panel
|
||||
|
||||
- Card background: #FFFFFF
|
||||
- Section headers: #121212
|
||||
- Meta text: #6B6B6B
|
||||
|
||||
### Annotation Table
|
||||
|
||||
- Same table styles as Documents
|
||||
- Inline edit:
|
||||
- Input background: #FAFAF8
|
||||
- Save button: Graphite
|
||||
|
||||
### Locked State (Auto-label Running)
|
||||
|
||||
- Banner BG: #FAFAF8
|
||||
- Border-left: 3px solid #4A4A3A
|
||||
- Progress bar: Graphite
|
||||
|
||||
---
|
||||
|
||||
## 8. Training Page
|
||||
|
||||
### Document Selector
|
||||
|
||||
- Selected rows use same highlight rules
|
||||
- Verified state:
|
||||
- Full: Olive gray check
|
||||
- Partial: Sand gray warning
|
||||
|
||||
### Configuration Panel
|
||||
|
||||
- Card layout
|
||||
- Inputs aligned to grid
|
||||
- Schedule option visually muted until enabled
|
||||
|
||||
Primary CTA:
|
||||
- Start Training button in dark graphite
|
||||
|
||||
---
|
||||
|
||||
## 9. Models & Training History
|
||||
|
||||
### Training Job List
|
||||
|
||||
- Job cards use #FFFFFF background
|
||||
- Running job:
|
||||
- Progress bar: #3A3A3A
|
||||
- Completed job:
|
||||
- Metrics bars in graphite
|
||||
|
||||
### Model Detail Panel
|
||||
|
||||
- Sectioned cards
|
||||
- Metric bars:
|
||||
- Track: #ECEAE6
|
||||
- Fill: #3A3A3A
|
||||
|
||||
Actions:
|
||||
- Primary: Download Model
|
||||
- Secondary: View Logs / Use as Base
|
||||
|
||||
---
|
||||
|
||||
## 10. Micro-interactions (Refined)
|
||||
|
||||
| Element | Interaction | Animation |
|
||||
|------|------------|-----------|
|
||||
| Button hover | BG lightens | 150ms ease-out |
|
||||
| Button press | Scale 0.98 | 100ms |
|
||||
| Row hover | BG fade | 120ms |
|
||||
| Modal open | Fade + scale 0.96 → 1 | 200ms |
|
||||
| Progress fill | Smooth | ease-out |
|
||||
| Annotation select | Border + handles | 120ms |
|
||||
|
||||
---
|
||||
|
||||
## 11. Tailwind Theme (Updated)
|
||||
|
||||
```js
|
||||
colors: {
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
bg: {
|
||||
app: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
},
|
||||
border: '#E6E4E1',
|
||||
accent: '#3A3A3A',
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Final Notes
|
||||
|
||||
- Pure black (#000000) should **never** be used as large surfaces
|
||||
- Accent color usage should stay under **10% of UI area**
|
||||
- Warm grays are intentional and must not be "corrected" to blue-grays
|
||||
|
||||
This theme is designed to scale from internal tool → polished SaaS without redesign.
|
||||
|
||||
273
docs/web-refactoring-complete.md
Normal file
@@ -0,0 +1,273 @@
|
||||
# Web Directory Refactoring - Complete ✅
|
||||
|
||||
**Date**: 2026-01-25
|
||||
**Status**: ✅ Completed
|
||||
**Tests**: 188 passing (0 failures)
|
||||
**Coverage**: 23% (maintained)
|
||||
|
||||
---
|
||||
|
||||
## Final Directory Structure
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── __init__.py
|
||||
│ └── v1/
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public inference API
|
||||
│ ├── admin/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # Document management (was admin_routes.py)
|
||||
│ │ ├── annotations.py # Annotation routes (was admin_annotation_routes.py)
|
||||
│ │ └── training.py # Training routes (was admin_training_routes.py)
|
||||
│ ├── async_api/
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # Async processing API (was async_routes.py)
|
||||
│ └── batch/
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # Batch upload API (was batch_upload_routes.py)
|
||||
│
|
||||
├── schemas/
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared models (ErrorResponse)
|
||||
│ ├── admin.py # Admin schemas (was admin_schemas.py)
|
||||
│ └── inference.py # Inference + async schemas (was schemas.py)
|
||||
│
|
||||
├── services/
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # Inference service (was services.py)
|
||||
│ ├── autolabel.py # Auto-label service (was admin_autolabel.py)
|
||||
│ ├── async_processing.py # Async processing (was async_service.py)
|
||||
│ └── batch_upload.py # Batch upload service (was batch_upload_service.py)
|
||||
│
|
||||
├── core/
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # Authentication (was admin_auth.py)
|
||||
│ ├── rate_limiter.py # Rate limiting (unchanged)
|
||||
│ └── scheduler.py # Task scheduler (was admin_scheduler.py)
|
||||
│
|
||||
├── workers/
|
||||
│ ├── __init__.py
|
||||
│ ├── async_queue.py # Async task queue (was async_queue.py)
|
||||
│ └── batch_queue.py # Batch task queue (was batch_queue.py)
|
||||
│
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app (imports updated)
|
||||
├── config.py # Configuration (unchanged)
|
||||
└── dependencies.py # Global dependencies (unchanged)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Changes Summary
|
||||
|
||||
### Files Moved and Renamed
|
||||
|
||||
| Old Location | New Location | Change Type |
|
||||
|-------------|--------------|-------------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
|
||||
| `admin_auth.py` | `core/auth.py` | Moved |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Moved |
|
||||
| `routes.py` | `api/v1/routes.py` | Moved |
|
||||
| `schemas.py` | `schemas/inference.py` | Moved |
|
||||
| `services.py` | `services/inference.py` | Moved |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Moved |
|
||||
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |
|
||||
|
||||
**Total**: 16 files reorganized
|
||||
|
||||
### Files Updated
|
||||
|
||||
**Source Files** (imports updated):
|
||||
- `app.py` - Updated all imports to new structure
|
||||
- `api/v1/admin/documents.py` - Updated schema/auth imports
|
||||
- `api/v1/admin/annotations.py` - Updated schema/service imports
|
||||
- `api/v1/admin/training.py` - Updated schema/auth imports
|
||||
- `api/v1/routes.py` - Updated schema imports
|
||||
- `api/v1/async_api/routes.py` - Updated schema imports
|
||||
- `api/v1/batch/routes.py` - Updated service/worker imports
|
||||
- `services/async_processing.py` - Updated worker/core imports
|
||||
|
||||
**Test Files** (all 16 updated):
|
||||
- `test_admin_annotations.py`
|
||||
- `test_admin_auth.py`
|
||||
- `test_admin_routes.py`
|
||||
- `test_admin_routes_enhanced.py`
|
||||
- `test_admin_training.py`
|
||||
- `test_annotation_locks.py`
|
||||
- `test_annotation_phase5.py`
|
||||
- `test_async_queue.py`
|
||||
- `test_async_routes.py`
|
||||
- `test_async_service.py`
|
||||
- `test_autolabel_with_locks.py`
|
||||
- `test_batch_queue.py`
|
||||
- `test_batch_upload_routes.py`
|
||||
- `test_batch_upload_service.py`
|
||||
- `test_training_phase4.py`
|
||||
- `conftest.py`
|
||||
|
||||
---
|
||||
|
||||
## Import Examples
|
||||
|
||||
### Old Import Style (Before Refactoring)
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
from src.web.async_routes import create_async_router
|
||||
from src.web.schemas import ErrorResponse
|
||||
```
|
||||
|
||||
### New Import Style (After Refactoring)
|
||||
```python
|
||||
# Admin API
|
||||
from src.web.api.v1.admin.documents import create_admin_router
|
||||
from src.web.api.v1.admin import create_admin_router # Shorter alternative
|
||||
|
||||
# Schemas
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
# Core components
|
||||
from src.web.core.auth import validate_admin_token
|
||||
|
||||
# Async API
|
||||
from src.web.api.v1.async_api.routes import create_async_router
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### 1. **Clear Separation of Concerns**
|
||||
- **API Routes**: All in `api/v1/` by version and feature
|
||||
- **Data Models**: All in `schemas/` by domain
|
||||
- **Business Logic**: All in `services/`
|
||||
- **Core Components**: Reusable utilities in `core/`
|
||||
- **Background Jobs**: Task queues in `workers/`
|
||||
|
||||
### 2. **Better Scalability**
|
||||
- Easy to add API v2 without touching v1
|
||||
- Clear namespace for each module
|
||||
- Reduced file sizes (no 800+ line files)
|
||||
- Follows single responsibility principle
|
||||
|
||||
### 3. **Improved Maintainability**
|
||||
- Find files by function, not by prefix
|
||||
- Each module has one clear purpose
|
||||
- Easier to onboard new developers
|
||||
- Better IDE navigation
|
||||
|
||||
### 4. **Standards Compliance**
|
||||
- Follows FastAPI best practices
|
||||
- Matches Django/Flask project structures
|
||||
- Standard Python package organization
|
||||
- Industry-standard naming conventions
|
||||
|
||||
---
|
||||
|
||||
## Testing Results
|
||||
|
||||
**Before Refactoring**:
|
||||
- 188 tests passing
|
||||
- 23% code coverage
|
||||
- Flat directory structure
|
||||
|
||||
**After Refactoring**:
|
||||
- ✅ 188 tests passing (0 failures)
|
||||
- ✅ 23% code coverage (maintained)
|
||||
- ✅ Clean hierarchical structure
|
||||
- ✅ All imports updated
|
||||
- ✅ No backward compatibility shims needed
|
||||
|
||||
---
|
||||
|
||||
## Migration Statistics
|
||||
|
||||
| Metric | Count |
|
||||
|--------|-------|
|
||||
| Files moved | 16 |
|
||||
| Directories created | 9 |
|
||||
| Files updated (source) | 8 |
|
||||
| Files updated (tests) | 16 |
|
||||
| Import statements updated | ~150 |
|
||||
| Lines of code changed | ~200 |
|
||||
| Tests broken | 0 |
|
||||
| Coverage lost | 0% |
|
||||
|
||||
---
|
||||
|
||||
## Code Diff Summary
|
||||
|
||||
```diff
|
||||
Before:
|
||||
src/web/
|
||||
├── admin_routes.py (645 lines)
|
||||
├── admin_annotation_routes.py (504 lines)
|
||||
├── admin_training_routes.py (565 lines)
|
||||
├── admin_auth.py (22 lines)
|
||||
├── admin_schemas.py (262 lines)
|
||||
... (15 more files at root level)
|
||||
|
||||
After:
|
||||
src/web/
|
||||
├── api/v1/
|
||||
│ ├── admin/ (3 route files)
|
||||
│ ├── async_api/ (1 route file)
|
||||
│ └── batch/ (1 route file)
|
||||
├── schemas/ (3 schema files)
|
||||
├── services/ (4 service files)
|
||||
├── core/ (3 core files)
|
||||
└── workers/ (2 worker files)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps (Optional)
|
||||
|
||||
### Phase 2: Documentation
|
||||
- [ ] Update API documentation with new import paths
|
||||
- [ ] Create migration guide for external developers
|
||||
- [ ] Update CLAUDE.md with new structure
|
||||
|
||||
### Phase 3: Further Optimization
|
||||
- [ ] Split large files (>400 lines) if needed
|
||||
- [ ] Extract common utilities
|
||||
- [ ] Add typing stubs
|
||||
|
||||
### Phase 4: Deprecation (Future)
|
||||
- [ ] Add deprecation warnings if creating compatibility layer
|
||||
- [ ] Remove old imports after grace period
|
||||
- [ ] Update all documentation
|
||||
|
||||
---
|
||||
|
||||
## Rollback Instructions
|
||||
|
||||
If needed, rollback is simple:
|
||||
```bash
|
||||
git revert <commit-hash>
|
||||
```
|
||||
|
||||
All changes are in version control, making rollback safe and easy.
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
✅ **Refactoring completed successfully**
|
||||
✅ **Zero breaking changes**
|
||||
✅ **All tests passing**
|
||||
✅ **Industry-standard structure achieved**
|
||||
|
||||
The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.
|
||||
186
docs/web-refactoring-plan.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Web Directory Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat structure**: All files in one directory (20 Python files)
|
||||
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
|
||||
3. **Mixed concerns**: Routes, schemas, services, and workers in same directory
|
||||
4. **Poor scalability**: Hard to navigate and maintain as project grows
|
||||
|
||||
## Proposed Structure (Best Practices)
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app factory
|
||||
├── config.py # App configuration
|
||||
├── dependencies.py # Global dependencies
|
||||
│
|
||||
├── api/ # API Routes Layer
|
||||
│ ├── __init__.py
|
||||
│ └── v1/ # API version 1
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public API routes (inference)
|
||||
│ ├── admin/ # Admin API routes
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # admin_routes.py → documents.py
|
||||
│ │ ├── annotations.py # admin_annotation_routes.py → annotations.py
|
||||
│ │ ├── training.py # admin_training_routes.py → training.py
|
||||
│ │ └── auth.py # admin_auth.py → auth.py (routes only)
|
||||
│ ├── async_api/ # Async processing API
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # async_routes.py → routes.py
|
||||
│ └── batch/ # Batch upload API
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # batch_upload_routes.py → routes.py
|
||||
│
|
||||
├── schemas/ # Pydantic Models
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared schemas (ErrorResponse, etc.)
|
||||
│ ├── inference.py # schemas.py → inference.py
|
||||
│ ├── admin.py # admin_schemas.py → admin.py
|
||||
│ ├── async_api.py # New: async API schemas
|
||||
│ └── batch.py # New: batch upload schemas
|
||||
│
|
||||
├── services/ # Business Logic Layer
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # services.py → inference.py
|
||||
│ ├── autolabel.py # admin_autolabel.py → autolabel.py
|
||||
│ ├── async_processing.py # async_service.py → async_processing.py
|
||||
│ └── batch_upload.py # batch_upload_service.py → batch_upload.py
|
||||
│
|
||||
├── core/ # Core Components
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # admin_auth.py → auth.py (logic only)
|
||||
│ ├── rate_limiter.py # rate_limiter.py → rate_limiter.py
|
||||
│ └── scheduler.py # admin_scheduler.py → scheduler.py
|
||||
│
|
||||
└── workers/ # Background Task Queues
|
||||
├── __init__.py
|
||||
├── async_queue.py # async_queue.py → async_queue.py
|
||||
└── batch_queue.py # batch_queue.py → batch_queue.py
|
||||
```
|
||||
|
||||
## File Mapping
|
||||
|
||||
### Current → New Location
|
||||
|
||||
| Current File | New Location | Purpose |
|
||||
|--------------|--------------|---------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
|
||||
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
|
||||
| `routes.py` | `api/v1/routes.py` | Public inference API |
|
||||
| `schemas.py` | `schemas/inference.py` | Inference models |
|
||||
| `services.py` | `services/inference.py` | Inference service |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
|
||||
| `async_service.py` | `services/async_processing.py` | Async processing service |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
|
||||
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
|
||||
| `config.py` | `config.py` | Keep as-is |
|
||||
| `dependencies.py` | `dependencies.py` | Keep as-is |
|
||||
| `app.py` | `app.py` | Keep as-is (update imports) |
|
||||
|
||||
## Benefits
|
||||
|
||||
### 1. Clear Separation of Concerns
|
||||
- **Routes**: API endpoint definitions
|
||||
- **Schemas**: Data validation models
|
||||
- **Services**: Business logic
|
||||
- **Core**: Reusable components
|
||||
- **Workers**: Background processing
|
||||
|
||||
### 2. Better Scalability
|
||||
- Easy to add new API versions (`v2/`)
|
||||
- Clear namespace for each domain
|
||||
- Reduced file size (no 800+ line files)
|
||||
|
||||
### 3. Improved Maintainability
|
||||
- Find files by function, not by prefix
|
||||
- Each module has single responsibility
|
||||
- Easier to write focused tests
|
||||
|
||||
### 4. Standard Python Patterns
|
||||
- Package-based organization
|
||||
- Follows FastAPI best practices
|
||||
- Similar to Django/Flask structures
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Phase 1: Create New Structure (No Breaking Changes)
|
||||
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
|
||||
2. Copy files to new locations (don't delete originals yet)
|
||||
3. Update imports in new files
|
||||
4. Add `__init__.py` with proper exports
|
||||
|
||||
### Phase 2: Update Tests
|
||||
5. Update test imports to use new structure
|
||||
6. Run tests to verify nothing breaks
|
||||
7. Fix any import issues
|
||||
|
||||
### Phase 3: Update Main App
|
||||
8. Update `app.py` to import from new locations
|
||||
9. Run full test suite
|
||||
10. Verify all endpoints work
|
||||
|
||||
### Phase 4: Cleanup
|
||||
11. Delete old files
|
||||
12. Update documentation
|
||||
13. Final test run
|
||||
|
||||
## Migration Priority
|
||||
|
||||
**High Priority** (Most used):
|
||||
- Routes and schemas (user-facing APIs)
|
||||
- Services (core business logic)
|
||||
|
||||
**Medium Priority**:
|
||||
- Core components (auth, rate limiter)
|
||||
- Workers (background tasks)
|
||||
|
||||
**Low Priority**:
|
||||
- Config and dependencies (already well-located)
|
||||
|
||||
## Backwards Compatibility
|
||||
|
||||
During migration, maintain backwards compatibility:
|
||||
|
||||
```python
|
||||
# src/web/__init__.py
|
||||
# Old imports still work
|
||||
from src.web.api.v1.admin.documents import router as admin_router
|
||||
from src.web.schemas.admin import AdminDocument
|
||||
|
||||
# Keep old names for compatibility (temporary)
|
||||
admin_routes = admin_router # Deprecated alias
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
1. **Unit Tests**: Test each module independently
|
||||
2. **Integration Tests**: Test API endpoints still work
|
||||
3. **Import Tests**: Verify all old imports still work
|
||||
4. **Coverage**: Maintain current 23% coverage minimum
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If issues arise:
|
||||
1. Keep old files until fully migrated
|
||||
2. Git allows easy revert
|
||||
3. Tests catch breaking changes early
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
Would you like me to:
|
||||
1. **Start Phase 1**: Create new directory structure and move files?
|
||||
2. **Create migration script**: Automate the file moves and import updates?
|
||||
3. **Focus on specific area**: Start with admin API or async API first?
|
||||
218
docs/web-refactoring-status.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Web Directory Refactoring - Current Status
|
||||
|
||||
## ✅ Completed Steps
|
||||
|
||||
### 1. Directory Structure Created
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── v1/
|
||||
│ │ ├── admin/ (documents.py, annotations.py, training.py)
|
||||
│ │ ├── async_api/ (routes.py)
|
||||
│ │ ├── batch/ (routes.py)
|
||||
│ │ └── routes.py (public inference API)
|
||||
├── schemas/
|
||||
│ ├── admin.py (admin schemas)
|
||||
│ ├── inference.py (inference + async schemas)
|
||||
│ └── common.py (ErrorResponse)
|
||||
├── services/
|
||||
│ ├── autolabel.py
|
||||
│ ├── async_processing.py
|
||||
│ ├── batch_upload.py
|
||||
│ └── inference.py
|
||||
├── core/
|
||||
│ ├── auth.py
|
||||
│ ├── rate_limiter.py
|
||||
│ └── scheduler.py
|
||||
└── workers/
|
||||
├── async_queue.py
|
||||
└── batch_queue.py
|
||||
```
|
||||
|
||||
### 2. Files Copied and Imports Updated
|
||||
|
||||
#### Admin API (✅ Complete)
|
||||
- [x] `admin_routes.py` → `api/v1/admin/documents.py` (imports updated)
|
||||
- [x] `admin_annotation_routes.py` → `api/v1/admin/annotations.py` (imports updated)
|
||||
- [x] `admin_training_routes.py` → `api/v1/admin/training.py` (imports updated)
|
||||
- [x] `api/v1/admin/__init__.py` created with exports
|
||||
|
||||
#### Public & Async API (✅ Complete)
|
||||
- [x] `routes.py` → `api/v1/routes.py` (imports updated)
|
||||
- [x] `async_routes.py` → `api/v1/async_api/routes.py` (imports updated)
|
||||
- [x] `batch_upload_routes.py` → `api/v1/batch/routes.py` (copied, imports pending)
|
||||
|
||||
#### Schemas (✅ Complete)
|
||||
- [x] `admin_schemas.py` → `schemas/admin.py`
|
||||
- [x] `schemas.py` → `schemas/inference.py`
|
||||
- [x] `schemas/common.py` created
|
||||
- [x] `schemas/__init__.py` created with exports
|
||||
|
||||
#### Services (✅ Complete)
|
||||
- [x] `admin_autolabel.py` → `services/autolabel.py`
|
||||
- [x] `async_service.py` → `services/async_processing.py`
|
||||
- [x] `batch_upload_service.py` → `services/batch_upload.py`
|
||||
- [x] `services.py` → `services/inference.py`
|
||||
- [x] `services/__init__.py` created
|
||||
|
||||
#### Core Components (✅ Complete)
|
||||
- [x] `admin_auth.py` → `core/auth.py`
|
||||
- [x] `rate_limiter.py` → `core/rate_limiter.py`
|
||||
- [x] `admin_scheduler.py` → `core/scheduler.py`
|
||||
- [x] `core/__init__.py` created
|
||||
|
||||
#### Workers (✅ Complete)
|
||||
- [x] `async_queue.py` → `workers/async_queue.py`
|
||||
- [x] `batch_queue.py` → `workers/batch_queue.py`
|
||||
- [x] `workers/__init__.py` created
|
||||
|
||||
#### Main App (✅ Complete)
|
||||
- [x] `app.py` imports updated to use new structure
|
||||
|
||||
---
|
||||
|
||||
## ⏳ Remaining Work
|
||||
|
||||
### 1. Update Remaining File Imports (HIGH PRIORITY)
|
||||
|
||||
Files that need import updates:
|
||||
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
|
||||
- [ ] `services/autolabel.py` - may need import updates if it references old paths
|
||||
- [ ] `services/async_processing.py` - check for old import references
|
||||
- [ ] `services/batch_upload.py` - check for old import references
|
||||
- [ ] `services/inference.py` - check for old import references
|
||||
|
||||
### 2. Update ALL Test Files (CRITICAL)
|
||||
|
||||
Test files need to import from new locations. Pattern:
|
||||
|
||||
**Old:**
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
```
|
||||
|
||||
**New:**
|
||||
```python
|
||||
from src.web.api.v1.admin import create_admin_router
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.core.auth import validate_admin_token
|
||||
```
|
||||
|
||||
Test files to update:
|
||||
- [ ] `tests/web/test_admin_annotations.py`
|
||||
- [ ] `tests/web/test_admin_auth.py`
|
||||
- [ ] `tests/web/test_admin_routes.py`
|
||||
- [ ] `tests/web/test_admin_routes_enhanced.py`
|
||||
- [ ] `tests/web/test_admin_training.py`
|
||||
- [ ] `tests/web/test_annotation_locks.py`
|
||||
- [ ] `tests/web/test_annotation_phase5.py`
|
||||
- [ ] `tests/web/test_async_queue.py`
|
||||
- [ ] `tests/web/test_async_routes.py`
|
||||
- [ ] `tests/web/test_async_service.py`
|
||||
- [ ] `tests/web/test_autolabel_with_locks.py`
|
||||
- [ ] `tests/web/test_batch_queue.py`
|
||||
- [ ] `tests/web/test_batch_upload_routes.py`
|
||||
- [ ] `tests/web/test_batch_upload_service.py`
|
||||
- [ ] `tests/web/test_rate_limiter.py`
|
||||
- [ ] `tests/web/test_training_phase4.py`
|
||||
|
||||
### 3. Create Backward Compatibility Layer (OPTIONAL)
|
||||
|
||||
Keep old imports working temporarily:
|
||||
|
||||
```python
|
||||
# src/web/admin_routes.py (temporary compatibility shim)
|
||||
\"\"\"
|
||||
DEPRECATED: Use src.web.api.v1.admin.documents instead.
|
||||
This file will be removed in next version.
|
||||
\"\"\"
|
||||
import warnings
|
||||
from src.web.api.v1.admin.documents import *
|
||||
|
||||
warnings.warn(
|
||||
"Importing from src.web.admin_routes is deprecated. "
|
||||
"Use src.web.api.v1.admin.documents instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Verify and Test
|
||||
|
||||
1. Run tests:
|
||||
```bash
|
||||
pytest tests/web/ -v
|
||||
```
|
||||
|
||||
2. Check for any import errors:
|
||||
```bash
|
||||
python -c "from src.web.app import create_app; create_app()"
|
||||
```
|
||||
|
||||
3. Start server and test endpoints:
|
||||
```bash
|
||||
python run_server.py
|
||||
```
|
||||
|
||||
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
|
||||
|
||||
Old files to remove:
|
||||
- `src/web/admin_*.py` (7 files)
|
||||
- `src/web/async_*.py` (3 files)
|
||||
- `src/web/batch_*.py` (3 files)
|
||||
- `src/web/routes.py`
|
||||
- `src/web/services.py`
|
||||
- `src/web/schemas.py`
|
||||
- `src/web/rate_limiter.py`
|
||||
|
||||
Keep these files (don't remove):
|
||||
- `src/web/__init__.py`
|
||||
- `src/web/app.py`
|
||||
- `src/web/config.py`
|
||||
- `src/web/dependencies.py`
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Next Immediate Steps
|
||||
|
||||
1. **Update batch/routes.py imports** - Quick fix for remaining API route
|
||||
2. **Update test file imports** - Critical for verification
|
||||
3. **Run test suite** - Verify nothing broke
|
||||
4. **Fix any import errors** - Address failures
|
||||
5. **Remove old files** - Clean up after tests pass
|
||||
|
||||
---
|
||||
|
||||
## 📊 Migration Impact Summary
|
||||
|
||||
| Category | Files Moved | Imports Updated | Status |
|
||||
|----------|-------------|-----------------|--------|
|
||||
| API Routes | 7 | 5/7 | 🟡 In Progress |
|
||||
| Schemas | 3 | 3/3 | ✅ Complete |
|
||||
| Services | 4 | 0/4 | ⚠️ Pending |
|
||||
| Core | 3 | 3/3 | ✅ Complete |
|
||||
| Workers | 2 | 2/2 | ✅ Complete |
|
||||
| Tests | 0 | 0/16 | ❌ Not Started |
|
||||
|
||||
**Overall Progress: 65%**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Benefits After Migration
|
||||
|
||||
1. **Better Organization**: Clear separation by function
|
||||
2. **Easier Navigation**: Find files by purpose, not prefix
|
||||
3. **Scalability**: Easy to add new API versions
|
||||
4. **Standard Structure**: Follows FastAPI best practices
|
||||
5. **Maintainability**: Each module has single responsibility
|
||||
|
||||
---
|
||||
|
||||
## 📝 Notes
|
||||
|
||||
- All original files are still in place (no data loss risk)
|
||||
- New structure is operational but needs import updates
|
||||
- Backward compatibility can be added if needed
|
||||
- Tests will validate the migration success
|
||||
5
frontend/.env.example
Normal file
@@ -0,0 +1,5 @@
|
||||
# Backend API URL
|
||||
VITE_API_URL=http://localhost:8000
|
||||
|
||||
# WebSocket URL (for future real-time updates)
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
24
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
20
frontend/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
<div align="center">
|
||||
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
|
||||
</div>
|
||||
|
||||
# Run and deploy your AI Studio app
|
||||
|
||||
This contains everything you need to run your app locally.
|
||||
|
||||
View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4
|
||||
|
||||
## Run Locally
|
||||
|
||||
**Prerequisites:** Node.js
|
||||
|
||||
|
||||
1. Install dependencies:
|
||||
`npm install`
|
||||
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
|
||||
3. Run the app:
|
||||
`npm run dev`
|
||||
240
frontend/REFACTORING_PLAN.md
Normal file
@@ -0,0 +1,240 @@
|
||||
# Frontend Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat component organization** - All components in one directory
|
||||
2. **Mock data only** - No real API integration
|
||||
3. **No state management** - Props drilling everywhere
|
||||
4. **CDN dependencies** - Should use npm packages
|
||||
5. **Manual routing** - Using useState instead of react-router
|
||||
6. **No TypeScript integration with backend** - Types don't match API schemas
|
||||
|
||||
## Recommended Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── public/
|
||||
│ └── favicon.ico
|
||||
│
|
||||
├── src/
|
||||
│ ├── api/ # API Layer
|
||||
│ │ ├── client.ts # Axios instance + interceptors
|
||||
│ │ ├── types.ts # API request/response types
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # GET /api/v1/admin/documents
|
||||
│ │ ├── annotations.ts # GET/POST /api/v1/admin/documents/{id}/annotations
|
||||
│ │ ├── training.ts # GET/POST /api/v1/admin/training/*
|
||||
│ │ ├── inference.ts # POST /api/v1/infer
|
||||
│ │ └── async.ts # POST /api/v1/async/submit
|
||||
│ │
|
||||
│ ├── components/
|
||||
│ │ ├── common/ # Reusable components
|
||||
│ │ │ ├── Badge.tsx
|
||||
│ │ │ ├── Button.tsx
|
||||
│ │ │ ├── Input.tsx
|
||||
│ │ │ ├── Modal.tsx
|
||||
│ │ │ ├── Table.tsx
|
||||
│ │ │ ├── ProgressBar.tsx
|
||||
│ │ │ └── StatusBadge.tsx
|
||||
│ │ │
|
||||
│ │ ├── layout/ # Layout components
|
||||
│ │ │ ├── TopNav.tsx
|
||||
│ │ │ ├── Sidebar.tsx
|
||||
│ │ │ └── PageHeader.tsx
|
||||
│ │ │
|
||||
│ │ ├── documents/ # Document-specific components
|
||||
│ │ │ ├── DocumentTable.tsx
|
||||
│ │ │ ├── DocumentFilters.tsx
|
||||
│ │ │ ├── DocumentRow.tsx
|
||||
│ │ │ ├── UploadModal.tsx
|
||||
│ │ │ └── BatchUploadModal.tsx
|
||||
│ │ │
|
||||
│ │ ├── annotations/ # Annotation components
|
||||
│ │ │ ├── AnnotationCanvas.tsx
|
||||
│ │ │ ├── AnnotationBox.tsx
|
||||
│ │ │ ├── AnnotationTable.tsx
|
||||
│ │ │ ├── FieldEditor.tsx
|
||||
│ │ │ └── VerificationPanel.tsx
|
||||
│ │ │
|
||||
│ │ └── training/ # Training components
|
||||
│ │ ├── DocumentSelector.tsx
|
||||
│ │ ├── TrainingConfig.tsx
|
||||
│ │ ├── TrainingJobList.tsx
|
||||
│ │ ├── ModelCard.tsx
|
||||
│ │ └── MetricsChart.tsx
|
||||
│ │
|
||||
│ ├── pages/ # Page-level components
|
||||
│ │ ├── DocumentsPage.tsx # Was Dashboard.tsx
|
||||
│ │ ├── DocumentDetailPage.tsx # Was DocumentDetail.tsx
|
||||
│ │ ├── TrainingPage.tsx # Was Training.tsx
|
||||
│ │ ├── ModelsPage.tsx # Was Models.tsx
|
||||
│ │ └── InferencePage.tsx # New: Test inference
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts # Document CRUD + listing
|
||||
│ │ ├── useAnnotations.ts # Annotation management
|
||||
│ │ ├── useTraining.ts # Training jobs
|
||||
│ │ ├── usePolling.ts # Auto-refresh for async jobs
|
||||
│ │ └── useDebounce.ts # Debounce search inputs
|
||||
│ │
|
||||
│ ├── store/ # State Management (Zustand)
|
||||
│ │ ├── documentsStore.ts
|
||||
│ │ ├── annotationsStore.ts
|
||||
│ │ ├── trainingStore.ts
|
||||
│ │ └── uiStore.ts
|
||||
│ │
|
||||
│ ├── types/ # TypeScript Types
|
||||
│ │ ├── index.ts
|
||||
│ │ ├── document.ts
|
||||
│ │ ├── annotation.ts
|
||||
│ │ ├── training.ts
|
||||
│ │ └── api.ts
|
||||
│ │
|
||||
│ ├── utils/ # Utility Functions
|
||||
│ │ ├── formatters.ts # Date, currency, etc.
|
||||
│ │ ├── validators.ts # Form validation
|
||||
│ │ └── constants.ts # Field definitions, statuses
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ ├── main.tsx
|
||||
│ └── router.tsx # React Router config
|
||||
│
|
||||
├── .env.example
|
||||
├── package.json
|
||||
├── tsconfig.json
|
||||
├── vite.config.ts
|
||||
├── tailwind.config.js
|
||||
├── postcss.config.js
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Migration Steps
|
||||
|
||||
### Phase 1: Setup Infrastructure
|
||||
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
|
||||
- [ ] Setup local Tailwind (remove CDN)
|
||||
- [ ] Create API client with interceptors
|
||||
- [ ] Add environment variables (.env.local with VITE_API_URL)
|
||||
|
||||
### Phase 2: Create API Layer
|
||||
- [ ] Create `src/api/client.ts` with axios instance
|
||||
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
|
||||
- [ ] Create `src/api/endpoints/annotations.ts`
|
||||
- [ ] Create `src/api/endpoints/training.ts`
|
||||
- [ ] Add types matching backend schemas
|
||||
|
||||
### Phase 3: Reorganize Components
|
||||
- [ ] Move existing components to new structure
|
||||
- [ ] Split large components (Dashboard > DocumentTable + DocumentFilters + DocumentRow)
|
||||
- [ ] Extract reusable components (Badge, Button already done)
|
||||
- [ ] Create layout components (TopNav, Sidebar)
|
||||
|
||||
### Phase 4: Add Routing
|
||||
- [ ] Install react-router-dom
|
||||
- [ ] Create router.tsx with routes
|
||||
- [ ] Update App.tsx to use RouterProvider
|
||||
- [ ] Add navigation links
|
||||
|
||||
### Phase 5: State Management
|
||||
- [ ] Create custom hooks (useDocuments, useAnnotations)
|
||||
- [ ] Use @tanstack/react-query for server state
|
||||
- [ ] Add Zustand stores for UI state
|
||||
- [ ] Replace mock data with API calls
|
||||
|
||||
### Phase 6: Backend Integration
|
||||
- [ ] Update CORS settings in backend
|
||||
- [ ] Test all API endpoints
|
||||
- [ ] Add error handling
|
||||
- [ ] Add loading states
|
||||
|
||||
## Dependencies to Add
|
||||
|
||||
```json
|
||||
{
|
||||
"dependencies": {
|
||||
"react-router-dom": "^6.22.0",
|
||||
"axios": "^1.6.7",
|
||||
"zustand": "^4.5.0",
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"clsx": "^2.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tailwindcss": "^3.4.1",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"postcss": "^8.4.35"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Files to Create
|
||||
|
||||
### tailwind.config.js
|
||||
```javascript
|
||||
export default {
|
||||
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
warm: {
|
||||
bg: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
border: '#E6E4E1',
|
||||
divider: '#D8D6D2',
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
state: {
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
info: '#3A3A3A',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### .env.example
|
||||
```bash
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
## Type Generation from Backend
|
||||
|
||||
Consider generating TypeScript types from Python Pydantic schemas:
|
||||
- Option 1: Use `datamodel-code-generator` to convert schemas
|
||||
- Option 2: Manually maintain types in `src/types/api.ts`
|
||||
- Option 3: Use OpenAPI spec + openapi-typescript-codegen
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit tests: Vitest for components (example after this list)
|
||||
- Integration tests: React Testing Library
|
||||
- E2E tests: Playwright (matching backend)
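As a rough starting point, a component test could look like the following. This assumes `vitest`, `@testing-library/react`, and a jsdom test environment are added as dev dependencies and configured in `vite.config.ts`; none of these are in the dependency list above yet:

```typescript
// src/components/__tests__/Button.test.tsx - illustrative sketch
import { describe, it, expect } from 'vitest'
import { render, screen } from '@testing-library/react'
import { Button } from '../Button'

describe('Button', () => {
  it('renders its children and forwards the disabled prop', () => {
    render(<Button disabled>Upload Documents</Button>)
    const button = screen.getByRole('button', { name: 'Upload Documents' })
    expect(button).toBeDefined()
    expect(button).toHaveProperty('disabled', true)
  })
})
```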
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Code splitting by route
|
||||
- Lazy load heavy components (AnnotationCanvas); see the sketch after this list
|
||||
- Optimize re-renders with React.memo
|
||||
- Use virtual scrolling for large tables
|
||||
- Image lazy loading for document previews
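Lazy loading the annotation canvas could be as simple as the sketch below. The `AnnotationCanvas` component and its path are assumptions based on the target structure; it is not implemented yet:

```typescript
// Illustrative sketch: defer loading the heavy canvas until it is actually rendered
import React, { Suspense, lazy } from 'react'

// .then() maps the named export to the default-export shape that React.lazy expects
const AnnotationCanvas = lazy(() =>
  import('./components/annotation/AnnotationCanvas').then((m) => ({ default: m.AnnotationCanvas }))
)

export const AnnotationPanel: React.FC<{ documentId: string }> = ({ documentId }) => (
  <Suspense fallback={<div className="p-8 text-warm-text-muted">Loading canvas...</div>}>
    <AnnotationCanvas documentId={documentId} />
  </Suspense>
)
```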
|
||||
|
||||
## Accessibility
|
||||
|
||||
- Proper ARIA labels
|
||||
- Keyboard navigation
|
||||
- Focus management
|
||||
- Color contrast compliance (already done with Warm Graphite theme)
|
||||
256
frontend/SETUP.md
Normal file
@@ -0,0 +1,256 @@
|
||||
# Frontend Setup Guide
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
```
|
||||
|
||||
### 2. Configure Environment
|
||||
|
||||
Copy `.env.example` to `.env.local` and update if needed:
|
||||
|
||||
```bash
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
Default configuration:
|
||||
```
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
### 3. Start Backend API
|
||||
|
||||
Make sure the backend is running first:
|
||||
|
||||
```bash
|
||||
# From project root
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
|
||||
```
|
||||
|
||||
Backend will be available at: http://localhost:8000
|
||||
|
||||
### 4. Start Frontend Dev Server
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Frontend will be available at: http://localhost:3000
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── src/
|
||||
│ ├── api/ # API client layer
|
||||
│ │ ├── client.ts # Axios instance with interceptors
|
||||
│ │ ├── types.ts # API type definitions
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # Document API calls
|
||||
│ │ ├── annotations.ts # Annotation API calls
|
||||
│ │ └── training.ts # Training API calls
|
||||
│ │
|
||||
│ ├── components/ # React components
|
||||
│ │ └── Dashboard.tsx # Updated with real API integration
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts
|
||||
│ │ ├── useDocumentDetail.ts
|
||||
│ │ ├── useAnnotations.ts
|
||||
│ │ └── useTraining.ts
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind CSS entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ └── main.tsx # App entry point with QueryClient
|
||||
│
|
||||
├── components/ # Legacy components (to be migrated)
|
||||
│ ├── Badge.tsx
|
||||
│ ├── Button.tsx
|
||||
│ ├── Layout.tsx
|
||||
│ ├── DocumentDetail.tsx
|
||||
│ ├── Training.tsx
|
||||
│ ├── Models.tsx
|
||||
│ └── UploadModal.tsx
|
||||
│
|
||||
├── tailwind.config.js # Tailwind configuration
|
||||
├── postcss.config.js
|
||||
├── vite.config.ts
|
||||
├── package.json
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Key Technologies
|
||||
|
||||
- **React 19** - UI framework
|
||||
- **TypeScript** - Type safety
|
||||
- **Vite** - Build tool
|
||||
- **Tailwind CSS** - Styling (Warm Graphite theme)
|
||||
- **Axios** - HTTP client
|
||||
- **@tanstack/react-query** - Server state management
|
||||
- **lucide-react** - Icon library
|
||||
|
||||
## API Integration
|
||||
|
||||
### Authentication
|
||||
|
||||
The app stores admin token in localStorage:
|
||||
|
||||
```typescript
|
||||
localStorage.setItem('admin_token', 'your-token')
|
||||
```
|
||||
|
||||
All API requests automatically include the `X-Admin-Token` header.
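A quick way to confirm the token is picked up (uses the shared `apiClient` from `src/api/client.ts`; the endpoint is the same one `documentsApi.list` calls):

```typescript
import apiClient from './api/client'

// Resolves with a document list once a valid admin_token is in localStorage;
// a 401 means the token is missing or wrong.
apiClient
  .get('/api/v1/admin/documents', { params: { limit: 1 } })
  .then((res) => console.log('Authenticated - total documents:', res.data.total))
  .catch((err) => console.error('Auth check failed:', err.response?.status))
```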
|
||||
|
||||
### Available Hooks
|
||||
|
||||
#### useDocuments
|
||||
|
||||
```typescript
|
||||
const {
|
||||
documents,
|
||||
total,
|
||||
isLoading,
|
||||
uploadDocument,
|
||||
deleteDocument,
|
||||
triggerAutoLabel,
|
||||
} = useDocuments({ status: 'labeled', limit: 20 })
|
||||
```
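Wiring the actions is then a matter of calling the returned functions. A sketch, assuming `uploadDocument` and `triggerAutoLabel` are the mutate functions exposed by the hook as listed above:

```typescript
const handleUpload = (file: File) => {
  uploadDocument(file)          // POST /api/v1/admin/documents
}

const handleAutoLabel = (documentId: string) => {
  triggerAutoLabel(documentId)  // POST /api/v1/admin/documents/{id}/auto-label
}
```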
|
||||
|
||||
#### useDocumentDetail
|
||||
|
||||
```typescript
|
||||
const { document, annotations, isLoading } = useDocumentDetail(documentId)
|
||||
```
|
||||
|
||||
#### useAnnotations
|
||||
|
||||
```typescript
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
verifyAnnotation,
|
||||
overrideAnnotation,
|
||||
} = useAnnotations(documentId)
|
||||
```
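Creating an annotation uses pixel coordinates in the same shape as `CreateAnnotationRequest` in `src/api/types.ts`; the values below are placeholders:

```typescript
createAnnotation({
  page_number: 1,
  class_id: 6, // amount
  bbox: { x: 120, y: 340, width: 180, height: 28 },
  text_value: '1 234,56',
})
```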
|
||||
|
||||
#### useTraining
|
||||
|
||||
```typescript
|
||||
const {
|
||||
models,
|
||||
isLoadingModels,
|
||||
startTraining,
|
||||
downloadModel,
|
||||
} = useTraining()
|
||||
```
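Starting a run could then look like this. Field names follow the `startTraining` config in `src/api/endpoints/training.ts`; the document ids and hyperparameters are placeholders:

```typescript
startTraining({
  name: 'invoice-fields-v2',
  description: 'Retrain on newly verified documents',
  document_ids: ['<document-id-1>', '<document-id-2>'],
  epochs: 50,
  batch_size: 16,
})
```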
|
||||
|
||||
## Features Implemented
|
||||
|
||||
### Phase 1 (Completed)
|
||||
- ✅ API client with axios interceptors
|
||||
- ✅ Type-safe API endpoints
|
||||
- ✅ React Query for server state
|
||||
- ✅ Custom hooks for all APIs
|
||||
- ✅ Dashboard with real data
|
||||
- ✅ Local Tailwind CSS
|
||||
- ✅ Environment configuration
|
||||
- ✅ CORS configured in backend
|
||||
|
||||
### Phase 2 (TODO)
|
||||
- [ ] Update DocumentDetail to use useDocumentDetail
|
||||
- [ ] Update Training page to use useTraining hooks
|
||||
- [ ] Update Models page with real data
|
||||
- [ ] Add UploadModal integration with API
|
||||
- [ ] Add react-router for proper routing
|
||||
- [ ] Add error boundary
|
||||
- [ ] Add loading states
|
||||
- [ ] Add toast notifications
|
||||
|
||||
### Phase 3 (TODO)
|
||||
- [ ] Annotation canvas with real data
|
||||
- [ ] Batch upload functionality
|
||||
- [ ] Auto-label progress polling
|
||||
- [ ] Training job monitoring
|
||||
- [ ] Model download functionality
|
||||
- [ ] Search and filtering
|
||||
- [ ] Pagination
|
||||
|
||||
## Development Tips
|
||||
|
||||
### Hot Module Replacement
|
||||
|
||||
Vite supports HMR; changes are reflected immediately without a full page reload.
|
||||
|
||||
### API Debugging
|
||||
|
||||
Check browser console for API requests:
|
||||
- Network tab shows all requests/responses
|
||||
- Axios interceptors log errors automatically
|
||||
|
||||
### Type Safety
|
||||
|
||||
TypeScript types in `src/api/types.ts` match backend Pydantic schemas.
|
||||
|
||||
To regenerate types from backend:
|
||||
```bash
|
||||
# TODO: Add type generation script
|
||||
```
|
||||
|
||||
### Backend API Documentation
|
||||
|
||||
Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CORS Errors
|
||||
|
||||
If you see CORS errors:
|
||||
1. Check backend is running at http://localhost:8000
|
||||
2. Verify CORS settings in `src/web/app.py` (or use the dev-server proxy sketched below)
|
||||
3. Check `.env.local` has correct `VITE_API_URL`
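If changing the backend CORS settings is inconvenient during development, an alternative is to proxy API calls through the Vite dev server. A sketch, assuming the default `@vitejs/plugin-react` setup; with a proxy, the client should call relative `/api` paths (or `VITE_API_URL` should point at the frontend origin):

```typescript
// vite.config.ts - dev-only proxy sketch
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

export default defineConfig({
  plugins: [react()],
  server: {
    port: 3000,
    proxy: {
      // Forward /api/* to the FastAPI backend so the browser only talks to localhost:3000
      '/api': {
        target: 'http://localhost:8000',
        changeOrigin: true,
      },
    },
  },
})
```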
|
||||
|
||||
### Module Not Found
|
||||
|
||||
If imports fail:
|
||||
```bash
|
||||
rm -rf node_modules package-lock.json
|
||||
npm install
|
||||
```
|
||||
|
||||
### Types Not Matching
|
||||
|
||||
If API responses don't match types:
|
||||
1. Check backend version is up-to-date
|
||||
2. Verify types in `src/api/types.ts`
|
||||
3. Check API response in Network tab
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Run `npm install` to install dependencies
|
||||
2. Start backend server
|
||||
3. Run `npm run dev` to start frontend
|
||||
4. Open http://localhost:3000
|
||||
5. Create an admin token via backend API
|
||||
6. Store token in localStorage via browser console:
|
||||
```javascript
|
||||
localStorage.setItem('admin_token', 'your-token-here')
|
||||
```
|
||||
7. Refresh page to see authenticated API calls
|
||||
|
||||
## Production Build
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
npm run preview # Preview production build
|
||||
```
|
||||
|
||||
Build output will be in `dist/` directory.
|
||||
15
frontend/index.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Graphite Annotator - Invoice Field Extraction</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
5
frontend/metadata.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Graphite Annotator",
|
||||
"description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
|
||||
"requestFramePermissions": []
|
||||
}
|
||||
3510
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
32
frontend/package.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"name": "graphite-annotator",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"lucide-react": "^0.563.0",
|
||||
"recharts": "^3.7.0",
|
||||
"axios": "^1.6.7",
|
||||
"react-router-dom": "^6.22.0",
|
||||
"zustand": "^4.5.0",
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"clsx": "^2.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.14.0",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"typescript": "~5.8.2",
|
||||
"vite": "^6.2.0",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"postcss": "^8.4.35"
|
||||
}
|
||||
}
|
||||
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
73
frontend/src/App.tsx
Normal file
@@ -0,0 +1,73 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { Layout } from './components/Layout'
|
||||
import { DashboardOverview } from './components/DashboardOverview'
|
||||
import { Dashboard } from './components/Dashboard'
|
||||
import { DocumentDetail } from './components/DocumentDetail'
|
||||
import { Training } from './components/Training'
|
||||
import { Models } from './components/Models'
|
||||
import { Login } from './components/Login'
|
||||
import { InferenceDemo } from './components/InferenceDemo'
|
||||
|
||||
const App: React.FC = () => {
|
||||
const [currentView, setCurrentView] = useState('dashboard')
|
||||
const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
setIsAuthenticated(!!token)
|
||||
}, [])
|
||||
|
||||
const handleNavigate = (view: string, docId?: string) => {
|
||||
setCurrentView(view)
|
||||
if (docId) {
|
||||
setSelectedDocId(docId)
|
||||
}
|
||||
}
|
||||
|
||||
const handleLogin = (token: string) => {
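// The Login component is assumed to have stored the token in localStorage already; this handler only flips the auth flag.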
|
||||
setIsAuthenticated(true)
|
||||
}
|
||||
|
||||
const handleLogout = () => {
|
||||
localStorage.removeItem('admin_token')
|
||||
setIsAuthenticated(false)
|
||||
setCurrentView('documents')
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return <Login onLogin={handleLogin} />
|
||||
}
|
||||
|
||||
const renderContent = () => {
|
||||
switch (currentView) {
|
||||
case 'dashboard':
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
case 'documents':
|
||||
return <Dashboard onNavigate={handleNavigate} />
|
||||
case 'detail':
|
||||
return (
|
||||
<DocumentDetail
|
||||
docId={selectedDocId || '1'}
|
||||
onBack={() => setCurrentView('documents')}
|
||||
/>
|
||||
)
|
||||
case 'demo':
|
||||
return <InferenceDemo />
|
||||
case 'training':
|
||||
return <Training />
|
||||
case 'models':
|
||||
return <Models />
|
||||
default:
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
|
||||
{renderContent()}
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
||||
41
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import axios, { AxiosInstance, AxiosError } from 'axios'
|
||||
|
||||
const apiClient: AxiosInstance = axios.create({
|
||||
baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 30000,
|
||||
})
|
||||
|
||||
apiClient.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
if (token) {
|
||||
config.headers['X-Admin-Token'] = token
|
||||
}
|
||||
return config
|
||||
},
|
||||
(error) => {
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error: AxiosError) => {
|
||||
if (error.response?.status === 401) {
|
||||
console.warn('Authentication required. Please set admin_token in localStorage.')
|
||||
// Don't redirect to avoid infinite loop
|
||||
// User should manually set: localStorage.setItem('admin_token', 'your-token')
|
||||
}
|
||||
|
||||
if (error.response?.status === 429) {
|
||||
console.error('Rate limit exceeded')
|
||||
}
|
||||
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
export default apiClient
|
||||
66
frontend/src/api/endpoints/annotations.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
AnnotationItem,
|
||||
CreateAnnotationRequest,
|
||||
AnnotationOverrideRequest,
|
||||
} from '../types'
|
||||
|
||||
export const annotationsApi = {
|
||||
list: async (documentId: string): Promise<AnnotationItem[]> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`
|
||||
)
|
||||
return data.annotations
|
||||
},
|
||||
|
||||
create: async (
|
||||
documentId: string,
|
||||
annotation: CreateAnnotationRequest
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`,
|
||||
annotation
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
update: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
|
||||
updates
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string, annotationId: string): Promise<void> => {
|
||||
await apiClient.delete(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
|
||||
)
|
||||
},
|
||||
|
||||
verify: async (
|
||||
documentId: string,
|
||||
annotationId: string
|
||||
): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
override: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
overrideData: AnnotationOverrideRequest
|
||||
): Promise<{ annotation_id: string; source: string; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
|
||||
overrideData
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
80
frontend/src/api/endpoints/documents.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DocumentListResponse,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
UploadDocumentResponse,
|
||||
} from '../types'
|
||||
|
||||
export const documentsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/documents', { params })
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
upload: async (file: File): Promise<UploadDocumentResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
batchUpload: async (
|
||||
files: File[],
|
||||
csvFile?: File
|
||||
): Promise<{ batch_id: string; message: string; documents_created: number }> => {
|
||||
const formData = new FormData()
|
||||
|
||||
files.forEach((file) => {
|
||||
formData.append('files', file)
|
||||
})
|
||||
|
||||
if (csvFile) {
|
||||
formData.append('csv_file', csvFile)
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string): Promise<void> => {
|
||||
await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
|
||||
},
|
||||
|
||||
updateStatus: async (
|
||||
documentId: string,
|
||||
status: string
|
||||
): Promise<DocumentItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/status`,
|
||||
null,
|
||||
{ params: { status } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/auto-label`
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
4
frontend/src/api/endpoints/index.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export { documentsApi } from './documents'
|
||||
export { annotationsApi } from './annotations'
|
||||
export { trainingApi } from './training'
|
||||
export { inferenceApi } from './inference'
|
||||
16
frontend/src/api/endpoints/inference.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import apiClient from '../client'
|
||||
import type { InferenceResponse } from '../types'
|
||||
|
||||
export const inferenceApi = {
|
||||
processDocument: async (file: File): Promise<InferenceResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/infer', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
}
|
||||
74
frontend/src/api/endpoints/training.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import apiClient from '../client'
|
||||
import type { TrainingModelsResponse, DocumentListResponse } from '../types'
|
||||
|
||||
export const trainingApi = {
|
||||
getDocumentsForTraining: async (params?: {
|
||||
has_annotations?: boolean
|
||||
min_annotation_count?: number
|
||||
exclude_used_in_training?: boolean
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/documents', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getModels: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<TrainingModelsResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getTaskDetail: async (taskId: string) => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
startTraining: async (config: {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
model_base?: string
|
||||
}) => {
|
||||
// Convert frontend config to backend TrainingTaskCreate format
|
||||
const taskRequest = {
|
||||
name: config.name,
|
||||
task_type: 'yolo',
|
||||
description: config.description,
|
||||
config: {
|
||||
document_ids: config.document_ids,
|
||||
epochs: config.epochs,
|
||||
batch_size: config.batch_size,
|
||||
base_model: config.model_base,
|
||||
},
|
||||
}
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
|
||||
return data
|
||||
},
|
||||
|
||||
cancelTask: async (taskId: string) => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/tasks/${taskId}/cancel`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
downloadModel: async (taskId: string): Promise<Blob> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/models/${taskId}/download`,
|
||||
{
|
||||
responseType: 'blob',
|
||||
}
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
173
frontend/src/api/types.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
export interface DocumentItem {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
annotation_count?: number
|
||||
annotation_sources?: {
|
||||
manual: number
|
||||
auto: number
|
||||
verified: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface DocumentListResponse {
|
||||
documents: DocumentItem[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface AnnotationItem {
|
||||
annotation_id: string
|
||||
page_number: number
|
||||
class_id: number
|
||||
class_name: string
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
normalized_bbox: {
|
||||
x_center: number
|
||||
y_center: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value: string | null
|
||||
confidence: number | null
|
||||
source: 'manual' | 'auto'
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DocumentDetailResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
batch_id: string | null
|
||||
csv_field_values: Record<string, string> | null
|
||||
can_annotate: boolean
|
||||
annotation_lock_until: string | null
|
||||
annotations: AnnotationItem[]
|
||||
image_urls: string[]
|
||||
training_history: Array<{
|
||||
task_id: string
|
||||
name: string
|
||||
trained_at: string
|
||||
model_metrics: {
|
||||
mAP: number | null
|
||||
precision: number | null
|
||||
recall: number | null
|
||||
} | null
|
||||
}>
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface TrainingTask {
|
||||
task_id: string
|
||||
admin_token: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: 'pending' | 'running' | 'completed' | 'failed'
|
||||
task_type: string
|
||||
config: Record<string, unknown>
|
||||
started_at: string | null
|
||||
completed_at: string | null
|
||||
error_message: string | null
|
||||
result_metrics: Record<string, unknown>
|
||||
model_path: string | null
|
||||
document_count: number
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface TrainingModelsResponse {
|
||||
models: TrainingTask[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface ErrorResponse {
|
||||
detail: string
|
||||
}
|
||||
|
||||
export interface UploadDocumentResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface CreateAnnotationRequest {
|
||||
page_number: number
|
||||
class_id: number
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value?: string
|
||||
}
|
||||
|
||||
export interface AnnotationOverrideRequest {
|
||||
text_value?: string
|
||||
bbox?: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
class_id?: number
|
||||
class_name?: string
|
||||
reason?: string
|
||||
}
|
||||
|
||||
export interface CrossValidationResult {
|
||||
is_valid: boolean
|
||||
payment_line_ocr: string | null
|
||||
payment_line_amount: string | null
|
||||
payment_line_account: string | null
|
||||
payment_line_account_type: 'bankgiro' | 'plusgiro' | null
|
||||
ocr_match: boolean | null
|
||||
amount_match: boolean | null
|
||||
bankgiro_match: boolean | null
|
||||
plusgiro_match: boolean | null
|
||||
details: string[]
|
||||
}
|
||||
|
||||
export interface InferenceResult {
|
||||
document_id: string
|
||||
document_type: string
|
||||
success: boolean
|
||||
fields: Record<string, string>
|
||||
confidence: Record<string, number>
|
||||
cross_validation: CrossValidationResult | null
|
||||
processing_time_ms: number
|
||||
visualization_url: string | null
|
||||
errors: string[]
|
||||
fallback_used: boolean
|
||||
}
|
||||
|
||||
export interface InferenceResponse {
|
||||
result: InferenceResult
|
||||
}
|
||||
39
frontend/src/components/Badge.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import React from 'react';
|
||||
import { DocumentStatus } from '../types';
|
||||
import { Check } from 'lucide-react';
|
||||
|
||||
interface BadgeProps {
|
||||
status: DocumentStatus | 'Exported';
|
||||
}
|
||||
|
||||
export const Badge: React.FC<BadgeProps> = ({ status }) => {
|
||||
if (status === 'Exported') {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
|
||||
<Check size={12} strokeWidth={3} />
|
||||
Exported
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
const styles = {
|
||||
[DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
|
||||
[DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
|
||||
[DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
|
||||
[DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
|
||||
};
|
||||
|
||||
const icons = {
|
||||
[DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
|
||||
[DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
|
||||
[DocumentStatus.PENDING]: null,
|
||||
[DocumentStatus.LABELED]: null,
|
||||
}
|
||||
|
||||
return (
|
||||
<span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
|
||||
{icons[status]}
|
||||
{status}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
38
frontend/src/components/Button.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import React from 'react';
|
||||
|
||||
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
variant?: 'primary' | 'secondary' | 'outline' | 'text';
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
}
|
||||
|
||||
export const Button: React.FC<ButtonProps> = ({
|
||||
variant = 'primary',
|
||||
size = 'md',
|
||||
className = '',
|
||||
children,
|
||||
...props
|
||||
}) => {
|
||||
const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-[0.98] disabled:opacity-50 disabled:pointer-events-none";
|
||||
|
||||
const variants = {
|
||||
primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
|
||||
secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
|
||||
outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
|
||||
text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
|
||||
};
|
||||
|
||||
const sizes = {
|
||||
sm: "h-8 px-3 text-xs",
|
||||
md: "h-10 px-4 text-sm",
|
||||
lg: "h-12 px-6 text-base",
|
||||
};
|
||||
|
||||
return (
|
||||
<button
|
||||
className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
};
|
||||
266
frontend/src/components/Dashboard.tsx
Normal file
@@ -0,0 +1,266 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
|
||||
import { Badge } from './Badge'
|
||||
import { Button } from './Button'
|
||||
import { UploadModal } from './UploadModal'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import type { DocumentItem } from '../api/types'
|
||||
|
||||
interface DashboardProps {
|
||||
onNavigate: (view: string, docId?: string) => void
|
||||
}
|
||||
|
||||
const getStatusForBadge = (status: string): string => {
|
||||
const statusMap: Record<string, string> = {
|
||||
pending: 'Pending',
|
||||
labeled: 'Labeled',
|
||||
verified: 'Verified',
|
||||
exported: 'Exported',
|
||||
}
|
||||
return statusMap[status] || status
|
||||
}
|
||||
|
||||
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||
if (doc.auto_label_status === 'running') {
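// Hardcoded placeholder: DocumentItem carries no incremental progress field, so "running" is shown at a fixed 45%.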
|
||||
return 45
|
||||
}
|
||||
if (doc.auto_label_status === 'completed') {
|
||||
return 100
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
const [isUploadOpen, setIsUploadOpen] = useState(false)
|
||||
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
|
||||
const [statusFilter, setStatusFilter] = useState<string>('')
|
||||
const [limit] = useState(20)
|
||||
const [offset] = useState(0)
|
||||
|
||||
const { documents, total, isLoading, error, refetch } = useDocuments({
|
||||
status: statusFilter || undefined,
|
||||
limit,
|
||||
offset,
|
||||
})
|
||||
|
||||
const toggleSelection = (id: string) => {
|
||||
const newSet = new Set(selectedDocs)
|
||||
if (newSet.has(id)) {
|
||||
newSet.delete(id)
|
||||
} else {
|
||||
newSet.add(id)
|
||||
}
|
||||
setSelectedDocs(newSet)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
|
||||
Error loading documents. Please check your connection to the backend API.
|
||||
<button
|
||||
onClick={() => refetch()}
|
||||
className="ml-4 underline hover:no-underline"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
<div className="flex items-center justify-between mb-8">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Documents
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
{isLoading ? 'Loading...' : `${total} documents total`}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" disabled={selectedDocs.size === 0}>
|
||||
Export Selection ({selectedDocs.size})
|
||||
</Button>
|
||||
<Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
|
||||
<div className="relative flex-1 min-w-[200px]">
|
||||
<Search
|
||||
className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
|
||||
size={16}
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search documents..."
|
||||
className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-3">
|
||||
<div className="relative">
|
||||
<select
|
||||
value={statusFilter}
|
||||
onChange={(e) => setStatusFilter(e.target.value)}
|
||||
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
|
||||
>
|
||||
<option value="">All Statuses</option>
|
||||
<option value="pending">Pending</option>
|
||||
<option value="labeled">Labeled</option>
|
||||
<option value="verified">Verified</option>
|
||||
<option value="exported">Exported</option>
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<table className="w-full text-left border-collapse">
|
||||
<thead>
|
||||
<tr className="border-b border-warm-border bg-white">
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
|
||||
/>
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Document Name
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Date
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Status
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Annotations
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
||||
Auto-label
|
||||
</th>
|
||||
<th className="py-3 px-4 w-12"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{isLoading ? (
|
||||
<tr>
|
||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
||||
Loading documents...
|
||||
</td>
|
||||
</tr>
|
||||
) : documents.length === 0 ? (
|
||||
<tr>
|
||||
<td colSpan={7} className="py-8 text-center text-warm-text-muted">
|
||||
No documents found. Upload your first document to get started.
|
||||
</td>
|
||||
</tr>
|
||||
) : (
|
||||
documents.map((doc) => {
|
||||
const isSelected = selectedDocs.has(doc.document_id)
|
||||
const progress = getAutoLabelProgress(doc)
|
||||
|
||||
return (
|
||||
<tr
|
||||
key={doc.document_id}
|
||||
onClick={() => onNavigate('detail', doc.document_id)}
|
||||
className={`
|
||||
group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
|
||||
${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
|
||||
`}
|
||||
>
|
||||
<td
|
||||
className="py-4 pl-6 pr-4 relative"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
toggleSelection(doc.document_id)
|
||||
}}
|
||||
>
|
||||
{isSelected && (
|
||||
<div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
|
||||
)}
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={isSelected}
|
||||
readOnly
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
|
||||
/>
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
|
||||
<FileText size={16} />
|
||||
</div>
|
||||
<span className="font-medium text-warm-text-secondary">
|
||||
{doc.filename}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
|
||||
{new Date(doc.created_at).toLocaleDateString()}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<Badge status={getStatusForBadge(doc.status)} />
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
||||
{doc.annotation_count || 0} annotations
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
{doc.auto_label_status === 'running' && progress && (
|
||||
<div className="w-full">
|
||||
<div className="flex justify-between text-xs mb-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
Running
|
||||
</span>
|
||||
<span className="text-warm-text-muted">{progress}%</span>
|
||||
</div>
|
||||
<div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-warm-state-info transition-all duration-500 ease-out"
|
||||
style={{ width: `${progress}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{doc.auto_label_status === 'completed' && (
|
||||
<span className="text-sm font-medium text-warm-state-success">
|
||||
Completed
|
||||
</span>
|
||||
)}
|
||||
{doc.auto_label_status === 'failed' && (
|
||||
<span className="text-sm font-medium text-warm-state-error">
|
||||
Failed
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="py-4 px-4 text-right">
|
||||
<button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
|
||||
<MoreHorizontal size={18} />
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<UploadModal
|
||||
isOpen={isUploadOpen}
|
||||
onClose={() => {
|
||||
setIsUploadOpen(false)
|
||||
refetch()
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
148
frontend/src/components/DashboardOverview.tsx
Normal file
@@ -0,0 +1,148 @@
|
||||
import React from 'react'
|
||||
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { useTraining } from '../hooks/useTraining'
|
||||
|
||||
interface DashboardOverviewProps {
|
||||
onNavigate: (view: string) => void
|
||||
}
|
||||
|
||||
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
|
||||
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
|
||||
const { models, isLoadingModels } = useTraining()
|
||||
|
||||
const stats = [
|
||||
{
|
||||
label: 'Total Documents',
|
||||
value: docsLoading ? '...' : totalDocs.toString(),
|
||||
icon: FileText,
|
||||
color: 'text-warm-text-primary',
|
||||
bgColor: 'bg-warm-bg',
|
||||
},
|
||||
{
|
||||
label: 'Labeled',
|
||||
value: '0',
|
||||
icon: CheckCircle,
|
||||
color: 'text-warm-state-success',
|
||||
bgColor: 'bg-green-50',
|
||||
},
|
||||
{
|
||||
label: 'Pending',
|
||||
value: '0',
|
||||
icon: Clock,
|
||||
color: 'text-warm-state-warning',
|
||||
bgColor: 'bg-yellow-50',
|
||||
},
|
||||
{
|
||||
label: 'Training Models',
|
||||
value: isLoadingModels ? '...' : models.length.toString(),
|
||||
icon: TrendingUp,
|
||||
color: 'text-warm-state-info',
|
||||
bgColor: 'bg-blue-50',
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
{/* Header */}
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Dashboard
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
Overview of your document annotation system
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Stats Grid */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
|
||||
{stats.map((stat) => (
|
||||
<div
|
||||
key={stat.label}
|
||||
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
|
||||
>
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
|
||||
<stat.icon className={stat.color} size={24} />
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-2xl font-bold text-warm-text-primary mb-1">
|
||||
{stat.value}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">{stat.label}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Quick Actions */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
Quick Actions
|
||||
</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<Button onClick={() => onNavigate('documents')} className="justify-start">
|
||||
<FileText size={18} className="mr-2" />
|
||||
Manage Documents
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
|
||||
<Activity size={18} className="mr-2" />
|
||||
Start Training
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
|
||||
<TrendingUp size={18} className="mr-2" />
|
||||
View Models
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Recent Activity */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<div className="p-6 border-b border-warm-border">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary">
|
||||
Recent Activity
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6">
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Activity size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No recent activity</p>
|
||||
<p className="text-xs mt-1">
|
||||
Start by uploading documents or creating training jobs
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* System Status */}
|
||||
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
System Status
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Backend API</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Online
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Database</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Connected
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">GPU</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Available
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
504
frontend/src/components/DocumentDetail.tsx
Normal file
@@ -0,0 +1,504 @@
|
||||
import React, { useState, useRef, useEffect } from 'react'
|
||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
||||
import { useAnnotations } from '../hooks/useAnnotations'
|
||||
import { documentsApi } from '../api/endpoints/documents'
|
||||
import type { AnnotationItem } from '../api/types'
|
||||
|
||||
interface DocumentDetailProps {
|
||||
docId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
// Field class mapping from backend
|
||||
const FIELD_CLASSES: Record<number, string> = {
|
||||
0: 'invoice_number',
|
||||
1: 'invoice_date',
|
||||
2: 'invoice_due_date',
|
||||
3: 'ocr_number',
|
||||
4: 'bankgiro',
|
||||
5: 'plusgiro',
|
||||
6: 'amount',
|
||||
7: 'supplier_organisation_number',
|
||||
8: 'payment_line',
|
||||
9: 'customer_number',
|
||||
}
|
||||
|
||||
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
||||
const { document, annotations, isLoading } = useDocumentDetail(docId)
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
isCreating,
|
||||
isDeleting,
|
||||
} = useAnnotations(docId)
|
||||
|
||||
const [selectedId, setSelectedId] = useState<string | null>(null)
|
||||
const [zoom, setZoom] = useState(100)
|
||||
const [isDrawing, setIsDrawing] = useState(false)
|
||||
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
||||
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
||||
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
||||
const [currentPage, setCurrentPage] = useState(1)
|
||||
const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
|
||||
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)
|
||||
|
||||
const canvasRef = useRef<HTMLDivElement>(null)
|
||||
const imageRef = useRef<HTMLImageElement>(null)
|
||||
|
||||
const [isMarkingComplete, setIsMarkingComplete] = useState(false)
|
||||
|
||||
const selectedAnnotation = annotations?.find((a) => a.annotation_id === selectedId)
|
||||
|
||||
// Handle mark as complete
|
||||
const handleMarkComplete = async () => {
|
||||
if (!annotations || annotations.length === 0) {
|
||||
alert('Please add at least one annotation before marking as complete.')
|
||||
return
|
||||
}
|
||||
|
||||
if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
|
||||
return
|
||||
}
|
||||
|
||||
setIsMarkingComplete(true)
|
||||
try {
|
||||
const result = await documentsApi.updateStatus(docId, 'labeled')
|
||||
alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
|
||||
onBack() // Return to document list
|
||||
} catch (error) {
|
||||
console.error('Failed to mark document as complete:', error)
|
||||
alert('Failed to mark document as complete. Please try again.')
|
||||
} finally {
|
||||
setIsMarkingComplete(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Load image via fetch with authentication header
|
||||
useEffect(() => {
|
||||
let objectUrl: string | null = null
|
||||
|
||||
const loadImage = async () => {
|
||||
if (!docId) return
|
||||
|
||||
const token = localStorage.getItem('admin_token')
|
||||
const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`
|
||||
|
||||
try {
|
||||
const response = await fetch(imageUrl, {
|
||||
headers: {
|
||||
'X-Admin-Token': token || '',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load image: ${response.status}`)
|
||||
}
|
||||
|
||||
const blob = await response.blob()
|
||||
objectUrl = URL.createObjectURL(blob)
|
||||
setImageBlobUrl(objectUrl)
|
||||
} catch (error) {
|
||||
console.error('Failed to load image:', error)
|
||||
}
|
||||
}
|
||||
|
||||
loadImage()
|
||||
|
||||
// Cleanup: revoke object URL when component unmounts or page changes
|
||||
return () => {
|
||||
if (objectUrl) {
|
||||
URL.revokeObjectURL(objectUrl)
|
||||
}
|
||||
}
|
||||
}, [currentPage, docId])
|
||||
|
||||
// Load image size
|
||||
useEffect(() => {
|
||||
if (imageRef.current && imageRef.current.complete) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}, [imageBlobUrl])
|
||||
|
||||
const handleImageLoad = () => {
|
||||
if (imageRef.current) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setIsDrawing(true)
|
||||
setDrawStart({ x, y })
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!isDrawing || !canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseUp = () => {
|
||||
if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
|
||||
setIsDrawing(false)
|
||||
return
|
||||
}
|
||||
|
||||
const bbox_x = Math.min(drawStart.x, drawEnd.x)
|
||||
const bbox_y = Math.min(drawStart.y, drawEnd.y)
|
||||
const bbox_width = Math.abs(drawEnd.x - drawStart.x)
|
||||
const bbox_height = Math.abs(drawEnd.y - drawStart.y)
|
||||
|
||||
// Only create if box is large enough (min 10x10 pixels)
|
||||
if (bbox_width > 10 && bbox_height > 10) {
|
||||
createAnnotation({
|
||||
page_number: currentPage,
|
||||
class_id: selectedClassId,
|
||||
bbox: {
|
||||
x: Math.round(bbox_x),
|
||||
y: Math.round(bbox_y),
|
||||
width: Math.round(bbox_width),
|
||||
height: Math.round(bbox_height),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
setIsDrawing(false)
|
||||
setDrawStart(null)
|
||||
setDrawEnd(null)
|
||||
}
|
||||
|
||||
const handleDeleteAnnotation = (annotationId: string) => {
|
||||
if (confirm('Are you sure you want to delete this annotation?')) {
|
||||
deleteAnnotation(annotationId)
|
||||
setSelectedId(null)
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading || !document) {
|
||||
return (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<div className="text-warm-text-muted">Loading...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Get current page annotations
|
||||
const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []
|
||||
|
||||
return (
|
||||
<div className="flex h-[calc(100vh-56px)] overflow-hidden">
|
||||
{/* Main Canvas Area */}
|
||||
<div className="flex-1 bg-warm-bg flex flex-col relative">
|
||||
{/* Toolbar */}
|
||||
<div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
|
||||
<div className="flex items-center gap-4">
|
||||
<button
|
||||
onClick={onBack}
|
||||
className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
|
||||
>
|
||||
<ChevronLeft size={20} />
|
||||
</button>
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Page {currentPage} of {document.page_count}
|
||||
</p>
|
||||
</div>
|
||||
<div className="h-6 w-px bg-warm-divider mx-2" />
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.max(50, z - 10))}
|
||||
>
|
||||
<ZoomOut size={16} />
|
||||
</button>
|
||||
<span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
|
||||
{zoom}%
|
||||
</span>
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.min(200, z + 10))}
|
||||
>
|
||||
<ZoomIn size={16} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="secondary" size="sm">
|
||||
Auto-label
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="sm"
|
||||
onClick={handleMarkComplete}
|
||||
disabled={isMarkingComplete || document.status === 'labeled'}
|
||||
>
|
||||
<CheckCircle size={16} className="mr-1" />
|
||||
{isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
|
||||
</Button>
|
||||
{document.page_count > 1 && (
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
|
||||
disabled={currentPage === 1}
|
||||
>
|
||||
Prev
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
|
||||
disabled={currentPage === document.page_count}
|
||||
>
|
||||
Next
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Canvas Scroll Area */}
|
||||
<div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
|
||||
<div
|
||||
ref={canvasRef}
|
||||
className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
|
||||
style={{
|
||||
width: imageSize?.width || 800,
|
||||
height: imageSize?.height || 1132,
|
||||
transform: `scale(${zoom / 100})`,
|
||||
marginBottom: '100px',
|
||||
cursor: isDrawing ? 'crosshair' : 'default',
|
||||
}}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onClick={() => setSelectedId(null)}
|
||||
>
|
||||
{/* Document Image */}
|
||||
{imageBlobUrl ? (
|
||||
<img
|
||||
ref={imageRef}
|
||||
src={imageBlobUrl}
|
||||
alt={`Page ${currentPage}`}
|
||||
className="w-full h-full object-contain select-none pointer-events-none"
|
||||
onLoad={handleImageLoad}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-full">
|
||||
<div className="text-warm-text-muted">Loading image...</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Annotation Overlays */}
|
||||
{pageAnnotations.map((ann) => {
|
||||
const isSelected = selectedId === ann.annotation_id
|
||||
return (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setSelectedId(ann.annotation_id)
|
||||
}}
|
||||
className={`
|
||||
absolute group cursor-pointer transition-all duration-100
|
||||
${
|
||||
ann.source === 'auto'
|
||||
? 'border border-dashed border-warm-text-muted bg-transparent'
|
||||
: 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
|
||||
}
|
||||
${
|
||||
isSelected
|
||||
? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
|
||||
: 'hover:bg-warm-state-info/5 z-10'
|
||||
}
|
||||
`}
|
||||
style={{
|
||||
left: ann.bbox.x,
|
||||
top: ann.bbox.y,
|
||||
width: ann.bbox.width,
|
||||
height: ann.bbox.height,
|
||||
}}
|
||||
>
|
||||
{/* Label Tag */}
|
||||
<div
|
||||
className={`
|
||||
absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
|
||||
${
|
||||
isSelected
|
||||
? 'bg-warm-state-info text-white'
|
||||
: 'bg-white text-warm-text-secondary border border-warm-border'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{ann.class_name}
|
||||
</div>
|
||||
|
||||
{/* Resize Handles (Visual only) */}
|
||||
{isSelected && (
|
||||
<>
|
||||
<div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
{/* Drawing Box Preview */}
|
||||
{isDrawing && drawStart && drawEnd && (
|
||||
<div
|
||||
className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
|
||||
style={{
|
||||
left: Math.min(drawStart.x, drawEnd.x),
|
||||
top: Math.min(drawStart.y, drawEnd.y),
|
||||
width: Math.abs(drawEnd.x - drawStart.x),
|
||||
height: Math.abs(drawEnd.y - drawStart.y),
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Sidebar */}
|
||||
<div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
|
||||
{/* Field Selector */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
|
||||
<div className="space-y-2">
|
||||
<label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
|
||||
<select
|
||||
value={selectedClassId}
|
||||
onChange={(e) => setSelectedClassId(Number(e.target.value))}
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
{Object.entries(FIELD_CLASSES).map(([id, name]) => (
|
||||
<option key={id} value={id}>
|
||||
{name.replace(/_/g, ' ')}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-2">
|
||||
Click and drag on the document to create a bounding box
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Document Info Card */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Status</span>
|
||||
<span className="text-warm-text-secondary font-medium capitalize">
|
||||
{document.status}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Size</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{(document.file_size / 1024 / 1024).toFixed(2)} MB
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Uploaded</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{new Date(document.created_at).toLocaleDateString()}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Annotations List */}
|
||||
<div className="flex-1 overflow-y-auto p-4">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
|
||||
<span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
|
||||
</div>
|
||||
|
||||
{pageAnnotations.length === 0 ? (
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Tag size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No annotations yet</p>
|
||||
<p className="text-xs mt-1">Draw on the document to add annotations</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
{pageAnnotations.map((ann) => (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={() => setSelectedId(ann.annotation_id)}
|
||||
className={`
|
||||
group p-3 rounded-md border transition-all duration-150 cursor-pointer
|
||||
${
|
||||
selectedId === ann.annotation_id
|
||||
? 'bg-warm-bg border-warm-state-info shadow-sm'
|
||||
: 'bg-white border-warm-border hover:border-warm-text-muted'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-1">
|
||||
<span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
|
||||
{ann.class_name.replace(/_/g, ' ')}
|
||||
</span>
|
||||
{selectedId === ann.annotation_id && (
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={() => handleDeleteAnnotation(ann.annotation_id)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error"
|
||||
disabled={isDeleting}
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<p className="text-sm text-warm-text-muted font-mono truncate">
|
||||
{ann.text_value || '(no text)'}
|
||||
</p>
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<span
|
||||
className={`text-[10px] px-1.5 py-0.5 rounded ${
|
||||
ann.source === 'auto'
|
||||
? 'bg-blue-50 text-blue-700'
|
||||
: 'bg-green-50 text-green-700'
|
||||
}`}
|
||||
>
|
||||
{ann.source}
|
||||
</span>
|
||||
{ann.confidence && (
|
||||
<span className="text-[10px] text-warm-text-muted">
|
||||
{(ann.confidence * 100).toFixed(0)}%
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
466
frontend/src/components/InferenceDemo.tsx
Normal file
@@ -0,0 +1,466 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { inferenceApi } from '../api/endpoints'
|
||||
import type { InferenceResult } from '../api/types'
|
||||
|
||||
export const InferenceDemo: React.FC = () => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null)
|
||||
const [isProcessing, setIsProcessing] = useState(false)
|
||||
const [result, setResult] = useState<InferenceResult | null>(null)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const handleFileSelect = (file: File | null) => {
|
||||
if (!file) return
|
||||
|
||||
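// Accept only PDF and common image types; anything else is rejected with an error message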
const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
|
||||
if (!validTypes.includes(file.type)) {
|
||||
setError('Please upload a PDF, PNG, or JPG file')
|
||||
return
|
||||
}
|
||||
|
||||
if (file.size > 50 * 1024 * 1024) {
|
||||
setError('File size must be less than 50MB')
|
||||
return
|
||||
}
|
||||
|
||||
setSelectedFile(file)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
if (e.dataTransfer.files.length > 0) {
|
||||
handleFileSelect(e.dataTransfer.files[0])
|
||||
}
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const handleProcess = async () => {
|
||||
if (!selectedFile) return
|
||||
|
||||
setIsProcessing(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
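// Send the selected file to the inference API and keep only the extraction result for display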
const response = await inferenceApi.processDocument(selectedFile)
|
||||
console.log('API Response:', response)
|
||||
console.log('Visualization URL:', response.result?.visualization_url)
|
||||
setResult(response.result)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Processing failed')
|
||||
} finally {
|
||||
setIsProcessing(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleReset = () => {
|
||||
setSelectedFile(null)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
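// Map internal field keys to human-readable labels, falling back to the raw key for unknown fields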
const formatFieldName = (field: string): string => {
|
||||
const fieldNames: Record<string, string> = {
|
||||
InvoiceNumber: 'Invoice Number',
|
||||
InvoiceDate: 'Invoice Date',
|
||||
InvoiceDueDate: 'Due Date',
|
||||
OCR: 'OCR Number',
|
||||
Amount: 'Amount',
|
||||
Bankgiro: 'Bankgiro',
|
||||
Plusgiro: 'Plusgiro',
|
||||
supplier_org_number: 'Supplier Org Number',
|
||||
customer_number: 'Customer Number',
|
||||
payment_line: 'Payment Line',
|
||||
}
|
||||
return fieldNames[field] || field
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
|
||||
{/* Header */}
|
||||
<div className="text-center">
|
||||
<h2 className="text-3xl font-bold text-warm-text-primary mb-2">
|
||||
Invoice Extraction Demo
|
||||
</h2>
|
||||
<p className="text-warm-text-muted">
|
||||
Upload a Swedish invoice to see our AI-powered field extraction in action
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Upload Area */}
|
||||
{!result && (
|
||||
<div className="max-w-2xl mx-auto">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
|
||||
<div
|
||||
className={`
|
||||
relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
|
||||
${isDragging
|
||||
? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
|
||||
: 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
|
||||
}
|
||||
${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
|
||||
`}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(true)
|
||||
}}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
|
||||
{isProcessing ? (
|
||||
<>
|
||||
<Loader2 size={56} className="text-warm-text-secondary animate-spin" />
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
Processing invoice...
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
This may take a few moments
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : selectedFile ? (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<FileText size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
{selectedFile.name}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
{(selectedFile.size / 1024 / 1024).toFixed(2)} MB
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<UploadCloud size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-2">
|
||||
Drag & drop invoice here
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted mb-3">
|
||||
or{' '}
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
browse files
|
||||
</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Supports PDF, PNG, JPG (up to 50MB)
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
|
||||
/>
|
||||
|
||||
{error && (
|
||||
<div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
|
||||
<AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-sm text-red-800 font-medium">{error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedFile && !isProcessing && (
|
||||
<div className="mt-6 flex gap-3 justify-end">
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleProcess}>Process Invoice</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Results */}
|
||||
{result && (
|
||||
<div className="space-y-6">
|
||||
{/* Status Header */}
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
|
||||
<div className="p-6 flex items-center justify-between border-b border-warm-divider">
|
||||
<div className="flex items-center gap-4">
|
||||
{result.success ? (
|
||||
<div className="p-3 bg-green-100 rounded-xl">
|
||||
<CheckCircle2 size={28} className="text-green-600" />
|
||||
</div>
|
||||
) : (
|
||||
<div className="p-3 bg-yellow-100 rounded-xl">
|
||||
<AlertCircle size={28} className="text-yellow-600" />
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">
|
||||
{result.success ? 'Extraction Complete' : 'Partial Results'}
|
||||
</h3>
|
||||
<p className="text-sm text-warm-text-muted mt-0.5">
|
||||
Document ID: <span className="font-mono">{result.document_id}</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Process Another
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
|
||||
<div className="flex items-center gap-2 text-warm-text-secondary">
|
||||
<Clock size={16} />
|
||||
<span className="font-medium">
|
||||
{result.processing_time_ms.toFixed(0)}ms
|
||||
</span>
|
||||
</div>
|
||||
{result.fallback_used && (
|
||||
<span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
|
||||
Fallback OCR Used
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main Content Grid */}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
{/* Left Column: Extracted Fields */}
|
||||
<div className="lg:col-span-2 space-y-6">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Extracted Fields
|
||||
</h3>
|
||||
<div className="flex flex-wrap gap-4">
|
||||
{Object.entries(result.fields).map(([field, value]) => {
|
||||
const confidence = result.confidence[field]
|
||||
return (
|
||||
<div
|
||||
key={field}
|
||||
className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
|
||||
>
|
||||
<div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
|
||||
{formatFieldName(field)}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
|
||||
{value || <span className="text-warm-text-muted italic">N/A</span>}
|
||||
</div>
|
||||
{confidence && (
|
||||
<div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
|
||||
<CheckCircle2 size={13} />
|
||||
<span>{(confidence * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Visualization */}
|
||||
{result.visualization_url && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Detection Visualization
|
||||
</h3>
|
||||
<div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
|
||||
<img
|
||||
src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
|
||||
alt="Detection visualization"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right Column: Cross-Validation & Errors */}
|
||||
<div className="space-y-6">
|
||||
{/* Cross-Validation */}
|
||||
{result.cross_validation && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Payment Line Validation
|
||||
</h3>
|
||||
|
||||
<div
|
||||
className={`
|
||||
p-4 rounded-lg mb-4 flex items-center gap-3
|
||||
${result.cross_validation.is_valid
|
||||
? 'bg-green-50 border border-green-200'
|
||||
: 'bg-yellow-50 border border-yellow-200'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{result.cross_validation.is_valid ? (
|
||||
<>
|
||||
<CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
|
||||
<span className="font-bold text-green-800">All Fields Match</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
|
||||
<span className="font-bold text-yellow-800">Mismatch Detected</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2.5">
|
||||
{result.cross_validation.payment_line_ocr && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.ocr_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.ocr_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
OCR NUMBER
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_ocr}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.ocr_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.ocr_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_amount && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.amount_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.amount_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
AMOUNT
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_amount}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.amount_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.amount_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_account && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: (result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
{result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? 'BANKGIRO'
|
||||
: 'PLUSGIRO'}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_account}
|
||||
</div>
|
||||
</div>
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{result.cross_validation.details.length > 0 && (
|
||||
<div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
|
||||
{result.cross_validation.details[result.cross_validation.details.length - 1]}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Errors */}
|
||||
{result.errors.length > 0 && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-red-500 rounded-full"></span>
|
||||
Issues
|
||||
</h3>
|
||||
<div className="space-y-2.5">
|
||||
{result.errors.map((err, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
|
||||
>
|
||||
<AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
102
frontend/src/components/Layout.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';
|
||||
|
||||
interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
activeView: string;
|
||||
onNavigate: (view: string) => void;
|
||||
onLogout?: () => void;
|
||||
}
|
||||
|
||||
export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
|
||||
const [showDropdown, setShowDropdown] = useState(false);
|
||||
const navItems = [
|
||||
{ id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
|
||||
{ id: 'demo', label: 'Demo', icon: Sparkles },
|
||||
{ id: 'training', label: 'Training', icon: Box }, // "Complaints" in the design prompt; using the logical name here
|
||||
{ id: 'documents', label: 'Documents', icon: BookOpen },
|
||||
{ id: 'models', label: 'Models', icon: Users }, // Contacts in prompt, mapped to models for this use case
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
|
||||
{/* Top Navigation */}
|
||||
<nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
|
||||
<div className="flex items-center gap-8">
|
||||
{/* Logo */}
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
|
||||
<path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Nav Links */}
|
||||
<div className="flex h-14">
|
||||
{navItems.map(item => {
|
||||
const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
|
||||
return (
|
||||
<button
|
||||
key={item.id}
|
||||
onClick={() => onNavigate(item.id)}
|
||||
className={`
|
||||
relative px-4 h-full flex items-center text-sm font-medium transition-colors
|
||||
${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
|
||||
`}
|
||||
>
|
||||
{item.label}
|
||||
{isActive && (
|
||||
<div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* User Profile */}
|
||||
<div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
|
||||
<button
|
||||
onClick={() => setShowDropdown(!showDropdown)}
|
||||
className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
|
||||
>
|
||||
AD
|
||||
</button>
|
||||
|
||||
{showDropdown && (
|
||||
<>
|
||||
<div
|
||||
className="fixed inset-0 z-10"
|
||||
onClick={() => setShowDropdown(false)}
|
||||
/>
|
||||
<div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
|
||||
<div className="p-3 border-b border-warm-border">
|
||||
<p className="text-sm font-medium text-warm-text-primary">Admin User</p>
|
||||
<p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
|
||||
</div>
|
||||
{onLogout && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowDropdown(false)
|
||||
onLogout()
|
||||
}}
|
||||
className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
|
||||
>
|
||||
<LogOut size={14} />
|
||||
Sign Out
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex-1 overflow-auto">
|
||||
{children}
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
188
frontend/src/components/Login.tsx
Normal file
@@ -0,0 +1,188 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Button } from './Button'
|
||||
|
||||
interface LoginProps {
|
||||
onLogin: (token: string) => void
|
||||
}
|
||||
|
||||
export const Login: React.FC<LoginProps> = ({ onLogin }) => {
|
||||
const [token, setToken] = useState('')
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [isCreating, setIsCreating] = useState(false)
|
||||
const [error, setError] = useState('')
|
||||
const [createdToken, setCreatedToken] = useState('')
|
||||
|
||||
const handleLoginWithToken = () => {
|
||||
if (!token.trim()) {
|
||||
setError('Please enter a token')
|
||||
return
|
||||
}
|
||||
localStorage.setItem('admin_token', token.trim())
|
||||
onLogin(token.trim())
|
||||
}
|
||||
|
||||
const handleCreateToken = async () => {
|
||||
if (!name.trim()) {
|
||||
setError('Please enter a token name')
|
||||
return
|
||||
}
|
||||
|
||||
setIsCreating(true)
|
||||
setError('')
|
||||
|
||||
try {
|
||||
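// Request a new admin token from the API (base URL hardcoded to the local dev server)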
const response = await fetch('http://localhost:8000/api/v1/admin/auth/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: name.trim(),
|
||||
description: description.trim() || undefined,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to create token')
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
setCreatedToken(data.token)
|
||||
setToken(data.token)
|
||||
setError('')
|
||||
} catch (err) {
|
||||
setError('Failed to create token. Please check your connection.')
|
||||
console.error(err)
|
||||
} finally {
|
||||
setIsCreating(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleUseCreatedToken = () => {
|
||||
if (createdToken) {
|
||||
localStorage.setItem('admin_token', createdToken)
|
||||
onLogin(createdToken)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
|
||||
<h1 className="text-2xl font-bold text-warm-text-primary mb-2">
|
||||
Admin Authentication
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mb-6">
|
||||
Sign in with an admin token to access the document management system
|
||||
</p>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{createdToken && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
|
||||
<p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
|
||||
<div className="bg-white border border-green-300 rounded p-2 mb-3">
|
||||
<code className="text-xs font-mono text-warm-text-primary break-all">
|
||||
{createdToken}
|
||||
</code>
|
||||
</div>
|
||||
<p className="text-xs text-green-700 mb-3">
|
||||
Save this token securely. You won't be able to see it again.
|
||||
</p>
|
||||
<Button onClick={handleUseCreatedToken} className="w-full">
|
||||
Use This Token
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-6">
|
||||
{/* Login with existing token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Sign in with existing token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Admin Token
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={token}
|
||||
onChange={(e) => setToken(e.target.value)}
|
||||
placeholder="Enter your admin token"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
|
||||
onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
|
||||
/>
|
||||
</div>
|
||||
<Button onClick={handleLoginWithToken} className="w-full">
|
||||
Sign In
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="relative">
|
||||
<div className="absolute inset-0 flex items-center">
|
||||
<div className="w-full border-t border-warm-border"></div>
|
||||
</div>
|
||||
<div className="relative flex justify-center text-xs">
|
||||
<span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Create new token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Create new admin token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Token Name <span className="text-red-500">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="e.g., my-laptop"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Description (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="e.g., Personal laptop access"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleCreateToken}
|
||||
variant="secondary"
|
||||
disabled={isCreating}
|
||||
className="w-full"
|
||||
>
|
||||
{isCreating ? 'Creating...' : 'Create Token'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-6 pt-4 border-t border-warm-border">
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Admin tokens are used to authenticate with the document management API.
|
||||
Keep your tokens secure and never share them.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
134
frontend/src/components/Models.tsx
Normal file
@@ -0,0 +1,134 @@
|
||||
import React from 'react';
|
||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
||||
import { Button } from './Button';
|
||||
|
||||
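// Static placeholder data for the charts below (UI mock)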
const CHART_DATA = [
|
||||
{ name: 'Model A', value: 75 },
|
||||
{ name: 'Model B', value: 82 },
|
||||
{ name: 'Model C', value: 95 },
|
||||
{ name: 'Model D', value: 68 },
|
||||
];
|
||||
|
||||
const METRICS_DATA = [
|
||||
{ name: 'Precision', value: 88 },
|
||||
{ name: 'Recall', value: 76 },
|
||||
{ name: 'F1 Score', value: 91 },
|
||||
{ name: 'Accuracy', value: 82 },
|
||||
];
|
||||
|
||||
const JOBS = [
|
||||
{ id: 1, name: 'Training Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
|
||||
{ id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
|
||||
{ id: 3, name: 'Model Training Component 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
|
||||
];
|
||||
|
||||
export const Models: React.FC = () => {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto flex gap-8">
|
||||
{/* Left: Job History */}
|
||||
<div className="flex-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>
|
||||
|
||||
<div className="space-y-4">
|
||||
{JOBS.map(job => (
|
||||
<div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div>
|
||||
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
|
||||
<p className="text-sm text-warm-text-muted">Started {job.date}</p>
|
||||
</div>
|
||||
<span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
|
||||
{job.status}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{job.status === 'Running' ? (
|
||||
<div className="mt-4">
|
||||
<div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
|
||||
<div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="mt-4 flex gap-8">
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">100%</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right: Model Detail */}
|
||||
<div className="w-[400px]">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
|
||||
<span className="text-sm font-medium text-warm-state-success">Completed</span>
|
||||
</div>
|
||||
|
||||
<div className="mb-8">
|
||||
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
|
||||
<p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-8">
|
||||
{/* Chart 1 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Bar Rate Metrics</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={CHART_DATA}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" hide />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chart 2 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={METRICS_DATA}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip cursor={{fill: '#F1F0ED'}} />
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-8 space-y-3">
|
||||
<Button className="w-full">Download Model</Button>
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" className="flex-1">View Logs</Button>
|
||||
<Button variant="secondary" className="flex-1">Use as Base</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
113
frontend/src/components/Training.tsx
Normal file
@@ -0,0 +1,113 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Check, AlertCircle } from 'lucide-react';
|
||||
import { Button } from './Button';
|
||||
import { DocumentStatus } from '../types';
|
||||
|
||||
export const Training: React.FC = () => {
|
||||
const [split, setSplit] = useState(80);
|
||||
|
||||
const docs = [
|
||||
{ id: '1', name: 'Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
{ id: '2', name: 'Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '3', name: 'Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
{ id: '4', name: 'Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '5', name: 'Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '6', name: 'Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
{ id: '8', name: 'Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
|
||||
{/* Document Selection List */}
|
||||
<div className="flex-1 flex flex-col">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>
|
||||
|
||||
<div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
|
||||
<div className="overflow-auto flex-1">
|
||||
<table className="w-full text-left">
|
||||
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
|
||||
<tr>
|
||||
<th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{docs.map(doc => (
|
||||
<tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
|
||||
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
|
||||
<td className="py-3 px-4">
|
||||
{doc.status === DocumentStatus.VERIFIED ? (
|
||||
<div className="flex items-center text-warm-state-success text-sm font-medium">
|
||||
<div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
|
||||
<Check size={12} strokeWidth={3}/>
|
||||
</div>
|
||||
Verified
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center text-warm-text-muted text-sm">
|
||||
<div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
|
||||
<span className="font-bold text-[10px]">!</span>
|
||||
</div>
|
||||
Partial
|
||||
</div>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Configuration Panel */}
|
||||
<div className="w-96">
|
||||
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>
|
||||
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="e.g. Invoices Q4"
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
|
||||
<select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
|
||||
<option>LayoutLMv3 (Standard)</option>
|
||||
<option>Donut (Beta)</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="flex justify-between mb-2">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
|
||||
<span className="text-xs font-mono text-warm-text-muted">{split}% / {100-split}%</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min="50"
|
||||
max="95"
|
||||
value={split}
|
||||
onChange={(e) => setSplit(parseInt(e.target.value))}
|
||||
className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="pt-4 border-t border-warm-border">
|
||||
<Button className="w-full h-12">Start Training</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
210
frontend/src/components/UploadModal.tsx
Normal file
@@ -0,0 +1,210 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
|
||||
interface UploadModalProps {
|
||||
isOpen: boolean
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
|
||||
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
|
||||
const [errorMessage, setErrorMessage] = useState('')
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const { uploadDocument, isUploading } = useDocuments({})
|
||||
|
||||
if (!isOpen) return null
|
||||
|
||||
const handleFileSelect = (files: FileList | null) => {
|
||||
if (!files) return
|
||||
|
||||
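// Keep only PDFs and images up to 25 MB; unsupported or oversized files are silently dropped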
const pdfFiles = Array.from(files).filter(file => {
|
||||
const isPdf = file.type === 'application/pdf'
|
||||
const isImage = file.type.startsWith('image/')
|
||||
const isUnder25MB = file.size <= 25 * 1024 * 1024
|
||||
return (isPdf || isImage) && isUnder25MB
|
||||
})
|
||||
|
||||
setSelectedFiles(prev => [...prev, ...pdfFiles])
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
handleFileSelect(e.dataTransfer.files)
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const removeFile = (index: number) => {
|
||||
setSelectedFiles(prev => prev.filter((_, i) => i !== index))
|
||||
}
|
||||
|
||||
const handleUpload = async () => {
|
||||
if (selectedFiles.length === 0) {
|
||||
setErrorMessage('Please select at least one file')
|
||||
return
|
||||
}
|
||||
|
||||
setUploadStatus('uploading')
|
||||
setErrorMessage('')
|
||||
|
||||
try {
|
||||
// Upload files one by one
|
||||
for (const file of selectedFiles) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
uploadDocument(file, {
|
||||
onSuccess: () => resolve(),
|
||||
onError: (error: Error) => reject(error),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
setUploadStatus('success')
|
||||
setTimeout(() => {
|
||||
onClose()
|
||||
setSelectedFiles([])
|
||||
setUploadStatus('idle')
|
||||
}, 1500)
|
||||
} catch (error) {
|
||||
setUploadStatus('error')
|
||||
setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
|
||||
}
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
if (uploadStatus === 'uploading') {
|
||||
return // Prevent closing during upload
|
||||
}
|
||||
setSelectedFiles([])
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
onClose()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
|
||||
<div
|
||||
className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
|
||||
<button
|
||||
onClick={handleClose}
|
||||
className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={20} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Drop Zone */}
|
||||
<div
|
||||
className={`
|
||||
w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
|
||||
${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
|
||||
${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
|
||||
`}
|
||||
onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="p-3 bg-white rounded-full shadow-sm">
|
||||
<UploadCloud size={24} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium text-warm-text-primary">
|
||||
Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files)}
|
||||
/>
|
||||
|
||||
{/* Selected Files */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6 max-h-40 overflow-y-auto">
|
||||
<p className="text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Selected Files ({selectedFiles.length})
|
||||
</p>
|
||||
<div className="space-y-2">
|
||||
{selectedFiles.map((file, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
|
||||
>
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0">
|
||||
<File size={16} className="text-warm-text-muted flex-shrink-0" />
|
||||
<span className="text-sm text-warm-text-secondary truncate">
|
||||
{file.name}
|
||||
</span>
|
||||
<span className="text-xs text-warm-text-muted flex-shrink-0">
|
||||
({(file.size / 1024 / 1024).toFixed(2)} MB)
|
||||
</span>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => removeFile(index)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={16} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Messages */}
|
||||
{uploadStatus === 'success' && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
|
||||
<CheckCircle size={16} className="text-green-600" />
|
||||
<span className="text-sm text-green-800">Upload successful!</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus === 'error' && errorMessage && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
<span className="text-sm text-red-800">{errorMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="mt-8 flex justify-end gap-3">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={handleClose}
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleUpload}
|
||||
disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
|
||||
>
|
||||
{uploadStatus === 'uploading' ? 'Uploading...' : `Upload ${selectedFiles.length > 0 ? `(${selectedFiles.length})` : ''}`}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
4
frontend/src/hooks/index.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export { useDocuments } from './useDocuments'
|
||||
export { useDocumentDetail } from './useDocumentDetail'
|
||||
export { useAnnotations } from './useAnnotations'
|
||||
export { useTraining, useTrainingDocuments } from './useTraining'
|
||||
70
frontend/src/hooks/useAnnotations.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import { useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { annotationsApi } from '../api/endpoints'
|
||||
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'
|
||||
|
||||
export const useAnnotations = (documentId: string) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
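// Each mutation invalidates the cached document detail so the editor refetches fresh annotations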
const createMutation = useMutation({
|
||||
mutationFn: (annotation: CreateAnnotationRequest) =>
|
||||
annotationsApi.create(documentId, annotation),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
updates,
|
||||
}: {
|
||||
annotationId: string
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
}) => annotationsApi.update(documentId, annotationId, updates),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.delete(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const verifyMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.verify(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const overrideMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
overrideData,
|
||||
}: {
|
||||
annotationId: string
|
||||
overrideData: AnnotationOverrideRequest
|
||||
}) => annotationsApi.override(documentId, annotationId, overrideData),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
createAnnotation: createMutation.mutate,
|
||||
isCreating: createMutation.isPending,
|
||||
updateAnnotation: updateMutation.mutate,
|
||||
isUpdating: updateMutation.isPending,
|
||||
deleteAnnotation: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
verifyAnnotation: verifyMutation.mutate,
|
||||
isVerifying: verifyMutation.isPending,
|
||||
overrideAnnotation: overrideMutation.mutate,
|
||||
isOverriding: overrideMutation.isPending,
|
||||
}
|
||||
}
|
||||
25
frontend/src/hooks/useDocumentDetail.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { documentsApi } from '../api/endpoints'
|
||||
import type { DocumentDetailResponse } from '../api/types'
|
||||
|
||||
export const useDocumentDetail = (documentId: string | null) => {
|
||||
const { data, isLoading, error, refetch } = useQuery<DocumentDetailResponse>({
|
||||
queryKey: ['document', documentId],
|
||||
queryFn: () => {
|
||||
if (!documentId) {
|
||||
throw new Error('Document ID is required')
|
||||
}
|
||||
return documentsApi.getDetail(documentId)
|
||||
},
|
||||
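// Skip the query until a document has been selected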
enabled: !!documentId,
|
||||
staleTime: 10000,
|
||||
})
|
||||
|
||||
return {
|
||||
document: data || null,
|
||||
annotations: data?.annotations || [],
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
}
|
||||
}
|
||||
78
frontend/src/hooks/useDocuments.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { documentsApi } from '../api/endpoints'
|
||||
import type { DocumentListResponse, UploadDocumentResponse } from '../api/types'
|
||||
|
||||
interface UseDocumentsParams {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
export const useDocuments = (params: UseDocumentsParams = {}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<DocumentListResponse>({
|
||||
queryKey: ['documents', params],
|
||||
queryFn: () => documentsApi.list(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
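// Invalidate the cached document list after a successful upload so the new entry appears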
const uploadMutation = useMutation({
|
||||
mutationFn: (file: File) => documentsApi.upload(file),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const batchUploadMutation = useMutation({
|
||||
mutationFn: ({ files, csvFile }: { files: File[]; csvFile?: File }) =>
|
||||
documentsApi.batchUpload(files, csvFile),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (documentId: string) => documentsApi.delete(documentId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateStatusMutation = useMutation({
|
||||
mutationFn: ({ documentId, status }: { documentId: string; status: string }) =>
|
||||
documentsApi.updateStatus(documentId, status),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const triggerAutoLabelMutation = useMutation({
|
||||
mutationFn: (documentId: string) => documentsApi.triggerAutoLabel(documentId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
documents: data?.documents || [],
|
||||
total: data?.total || 0,
|
||||
limit: data?.limit || params.limit || 20,
|
||||
offset: data?.offset || params.offset || 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
uploadDocument: uploadMutation.mutate,
|
||||
uploadDocumentAsync: uploadMutation.mutateAsync,
|
||||
isUploading: uploadMutation.isPending,
|
||||
batchUpload: batchUploadMutation.mutate,
|
||||
batchUploadAsync: batchUploadMutation.mutateAsync,
|
||||
isBatchUploading: batchUploadMutation.isPending,
|
||||
deleteDocument: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
updateStatus: updateStatusMutation.mutate,
|
||||
isUpdatingStatus: updateStatusMutation.isPending,
|
||||
triggerAutoLabel: triggerAutoLabelMutation.mutate,
|
||||
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
|
||||
}
|
||||
}
|
||||
frontend/src/hooks/useTraining.ts (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { TrainingModelsResponse } from '../api/types'
|
||||
|
||||
export const useTraining = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data: modelsData, isLoading: isLoadingModels } =
|
||||
useQuery<TrainingModelsResponse>({
|
||||
queryKey: ['training', 'models'],
|
||||
queryFn: () => trainingApi.getModels(),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
const startTrainingMutation = useMutation({
|
||||
mutationFn: (config: {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
model_base?: string
|
||||
}) => trainingApi.startTraining(config),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const cancelTaskMutation = useMutation({
|
||||
mutationFn: (taskId: string) => trainingApi.cancelTask(taskId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const downloadModelMutation = useMutation({
|
||||
mutationFn: (taskId: string) => trainingApi.downloadModel(taskId),
|
||||
onSuccess: (blob, taskId) => {
|
||||
const url = window.URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = `model-${taskId}.pt`
|
||||
document.body.appendChild(a)
|
||||
a.click()
|
||||
window.URL.revokeObjectURL(url)
|
||||
document.body.removeChild(a)
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
models: modelsData?.models || [],
|
||||
total: modelsData?.total || 0,
|
||||
isLoadingModels,
|
||||
startTraining: startTrainingMutation.mutate,
|
||||
startTrainingAsync: startTrainingMutation.mutateAsync,
|
||||
isStartingTraining: startTrainingMutation.isPending,
|
||||
cancelTask: cancelTaskMutation.mutate,
|
||||
isCancelling: cancelTaskMutation.isPending,
|
||||
downloadModel: downloadModelMutation.mutate,
|
||||
isDownloading: downloadModelMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useTrainingDocuments = (params?: {
|
||||
has_annotations?: boolean
|
||||
min_annotation_count?: number
|
||||
exclude_used_in_training?: boolean
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const { data, isLoading, error } = useQuery({
|
||||
queryKey: ['training', 'documents', params],
|
||||
queryFn: () => trainingApi.getDocumentsForTraining(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
documents: data?.documents || [],
|
||||
total: data?.total || 0,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
frontend/src/main.tsx (new file, 23 lines)
@@ -0,0 +1,23 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import App from './App'
import './styles/index.css'

const queryClient = new QueryClient({
  defaultOptions: {
    queries: {
      retry: 1,
      refetchOnWindowFocus: false,
      staleTime: 30000,
    },
  },
})

ReactDOM.createRoot(document.getElementById('root')!).render(
  <React.StrictMode>
    <QueryClientProvider client={queryClient}>
      <App />
    </QueryClientProvider>
  </React.StrictMode>
)
frontend/src/styles/index.css (new file, 26 lines)
@@ -0,0 +1,26 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

@layer base {
  body {
    @apply bg-warm-bg text-warm-text-primary;
  }

  /* Custom scrollbar */
  ::-webkit-scrollbar {
    @apply w-2 h-2;
  }

  ::-webkit-scrollbar-track {
    @apply bg-transparent;
  }

  ::-webkit-scrollbar-thumb {
    @apply bg-warm-divider rounded;
  }

  ::-webkit-scrollbar-thumb:hover {
    @apply bg-warm-text-disabled;
  }
}
frontend/src/types/index.ts (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
// Legacy types for backward compatibility with old components
|
||||
// These will be gradually replaced with API types
|
||||
|
||||
export enum DocumentStatus {
|
||||
PENDING = 'Pending',
|
||||
LABELED = 'Labeled',
|
||||
VERIFIED = 'Verified',
|
||||
PARTIAL = 'Partial'
|
||||
}
|
||||
|
||||
export interface Document {
|
||||
id: string
|
||||
name: string
|
||||
date: string
|
||||
status: DocumentStatus
|
||||
exported: boolean
|
||||
autoLabelProgress?: number
|
||||
autoLabelStatus?: 'Running' | 'Completed' | 'Failed'
|
||||
}
|
||||
|
||||
export interface Annotation {
|
||||
id: string
|
||||
text: string
|
||||
label: string
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
isAuto?: boolean
|
||||
}
|
||||
|
||||
export interface TrainingJob {
|
||||
id: string
|
||||
name: string
|
||||
startDate: string
|
||||
status: 'Running' | 'Completed' | 'Failed'
|
||||
progress: number
|
||||
metrics?: {
|
||||
accuracy: number
|
||||
precision: number
|
||||
recall: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface ModelMetric {
|
||||
name: string
|
||||
value: number
|
||||
}
|
||||
frontend/tailwind.config.js (new file, 47 lines)
@@ -0,0 +1,47 @@
|
||||
export default {
|
||||
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||
theme: {
|
||||
extend: {
|
||||
fontFamily: {
|
||||
sans: ['Inter', 'SF Pro', 'system-ui', 'sans-serif'],
|
||||
mono: ['JetBrains Mono', 'SF Mono', 'monospace'],
|
||||
},
|
||||
colors: {
|
||||
warm: {
|
||||
bg: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
border: '#E6E4E1',
|
||||
divider: '#D8D6D2',
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
state: {
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
info: '#3A3A3A',
|
||||
}
|
||||
}
|
||||
},
|
||||
boxShadow: {
|
||||
'card': '0 1px 3px rgba(0,0,0,0.08)',
|
||||
'modal': '0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)',
|
||||
},
|
||||
animation: {
|
||||
'fade-in': 'fadeIn 0.3s ease-out',
|
||||
},
|
||||
keyframes: {
|
||||
fadeIn: {
|
||||
'0%': { opacity: '0', transform: 'translateY(10px)' },
|
||||
'100%': { opacity: '1', transform: 'translateY(0)' },
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
frontend/tsconfig.json (new file, 29 lines)
@@ -0,0 +1,29 @@
{
  "compilerOptions": {
    "target": "ES2022",
    "experimentalDecorators": true,
    "useDefineForClassFields": false,
    "module": "ESNext",
    "lib": [
      "ES2022",
      "DOM",
      "DOM.Iterable"
    ],
    "skipLibCheck": true,
    "types": [
      "node"
    ],
    "moduleResolution": "bundler",
    "isolatedModules": true,
    "moduleDetection": "force",
    "allowJs": true,
    "jsx": "react-jsx",
    "paths": {
      "@/*": [
        "./*"
      ]
    },
    "allowImportingTsExtensions": true,
    "noEmit": true
  }
}
frontend/vite.config.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';

export default defineConfig({
  server: {
    port: 3000,
    host: '0.0.0.0',
    proxy: {
      '/api': {
        target: 'http://localhost:8000',
        changeOrigin: true,
      },
    },
  },
  plugins: [react()],
});
@@ -21,3 +21,7 @@ pyyaml>=6.0 # YAML config files
 # Utilities
 tqdm>=4.65.0  # Progress bars
 python-dotenv>=1.0.0  # Environment variable management
+
+# Database
+psycopg2-binary>=2.9.0  # PostgreSQL driver
+sqlmodel>=0.0.22  # SQLModel ORM (SQLAlchemy + Pydantic)
@@ -16,7 +16,7 @@ from pathlib import Path
 from typing import Optional
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string
 
 from ..normalize import normalize_field
 from ..matcher import FieldMatcher

@@ -12,7 +12,7 @@ from collections import defaultdict
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string
 
 
 def load_reports_from_db() -> dict:

@@ -34,7 +34,7 @@ if sys.platform == 'win32':
     multiprocessing.set_start_method('spawn', force=True)
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string, PATHS, AUTOLABEL
+from src.config import get_db_connection_string, PATHS, AUTOLABEL
 
 # Global OCR engine for worker processes (initialized once per worker)
 _worker_ocr_engine = None

@@ -16,7 +16,7 @@ from psycopg2.extras import execute_values
 
 # Add project root to path
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string, PATHS
+from src.config import get_db_connection_string, PATHS
 
 
 def create_tables(conn):

@@ -10,6 +10,9 @@ import json
 import sys
 from pathlib import Path
 
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+from src.config import DEFAULT_DPI
+
 
 def main():
     parser = argparse.ArgumentParser(

@@ -38,8 +41,8 @@ def main():
     parser.add_argument(
         '--dpi',
         type=int,
-        default=150,
-        help='DPI for PDF rendering (default: 150, must match training)'
+        default=DEFAULT_DPI,
+        help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
     )
     parser.add_argument(
         '--no-fallback',

@@ -17,6 +17,7 @@ from tqdm import tqdm
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
+from src.config import DEFAULT_DPI
 from src.data.db import DocumentDB
 from src.data.csv_loader import CSVLoader
 from src.normalize.normalizer import normalize_field

@@ -144,7 +145,7 @@ def process_single_document(args):
     ocr_engine = OCREngine()
     for page_no in range(pdf_doc.page_count):
         # Render page to image
-        img = pdf_doc.render_page(page_no, dpi=150)
+        img = pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
         if img is None:
             continue
 

@@ -15,6 +15,8 @@ from pathlib import Path
 project_root = Path(__file__).parent.parent.parent
 sys.path.insert(0, str(project_root))
 
+from src.config import DEFAULT_DPI
+
 
 def setup_logging(debug: bool = False) -> None:
     """Configure logging."""

@@ -65,8 +67,8 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--dpi",
         type=int,
-        default=150,
-        help="DPI for PDF rendering (must match training DPI)",
+        default=DEFAULT_DPI,
+        help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
     )
 
     parser.add_argument(

@@ -11,7 +11,7 @@ import sys
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import PATHS
+from src.config import DEFAULT_DPI, PATHS
 
 
 def main():

@@ -103,8 +103,8 @@ def main():
     parser.add_argument(
         '--dpi',
         type=int,
-        default=150,
-        help='DPI used for rendering (default: 150, must match autolabel rendering)'
+        default=DEFAULT_DPI,
+        help=f'DPI used for rendering (default: {DEFAULT_DPI}, must match autolabel rendering)'
     )
     parser.add_argument(
         '--export-only',

@@ -8,9 +8,13 @@ from pathlib import Path
 from dotenv import load_dotenv
 
 # Load environment variables from .env file
-env_path = Path(__file__).parent / '.env'
+# .env is at project root, config.py is in src/
+env_path = Path(__file__).parent.parent / '.env'
 load_dotenv(dotenv_path=env_path)
 
+# Global DPI setting - must match training DPI for optimal model performance
+DEFAULT_DPI = 150
+
 
 def _is_wsl() -> bool:
     """Check if running inside WSL (Windows Subsystem for Linux)."""
@@ -69,7 +73,7 @@ else:
 # Auto-labeling Configuration
 AUTOLABEL = {
     'workers': 2,
-    'dpi': 150,
+    'dpi': DEFAULT_DPI,
     'min_confidence': 0.5,
     'train_ratio': 0.8,
     'val_ratio': 0.1,
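The `src/config.py` hunk above is the anchor for the DPI unification: every other hunk in this commit swaps a hard-coded `150` or a bare `from config import ...` for this shared constant. A minimal sketch of the intended consumption pattern, assuming the `pdf_doc.render_page(page_no, dpi=...)` renderer seen in the autolabel hunk; the helper name `render_all_pages` is illustrative only:

```python
# Illustrative sketch, not part of the diff: callers pick up the shared DPI
# constant so inference renders match the 150 DPI used during training.
from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import DEFAULT_DPI  # resolves to 150


def render_all_pages(pdf_doc):
    # Assumed helper: renders every page at the training DPI so detection
    # boxes keep the same pixel scale the model was trained on.
    return [
        pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
        for page_no in range(pdf_doc.page_count)
    ]
```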
src/data/admin_db.py (new file, 1156 lines)
File diff suppressed because it is too large.
src/data/admin_models.py (new file, 339 lines)
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Admin API SQLModel Database Models
|
||||
|
||||
Defines the database schema for admin document management, annotations, and training tasks.
|
||||
Includes batch upload support, training document links, and annotation history.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, SQLModel, Column, JSON
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CSV to Field Class Mapping
|
||||
# =============================================================================
|
||||
|
||||
CSV_TO_CLASS_MAPPING: dict[str, int] = {
|
||||
"InvoiceNumber": 0, # invoice_number
|
||||
"InvoiceDate": 1, # invoice_date
|
||||
"InvoiceDueDate": 2, # invoice_due_date
|
||||
"OCR": 3, # ocr_number
|
||||
"Bankgiro": 4, # bankgiro
|
||||
"Plusgiro": 5, # plusgiro
|
||||
"Amount": 6, # amount
|
||||
"supplier_organisation_number": 7, # supplier_organisation_number
|
||||
# 8: payment_line (derived from OCR/Bankgiro/Amount)
|
||||
"customer_number": 9, # customer_number
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Core Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AdminToken(SQLModel, table=True):
|
||||
"""Admin authentication token."""
|
||||
|
||||
__tablename__ = "admin_tokens"
|
||||
|
||||
token: str = Field(primary_key=True, max_length=255)
|
||||
name: str = Field(max_length=255)
|
||||
is_active: bool = Field(default=True)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_used_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class AdminDocument(SQLModel, table=True):
|
||||
"""Document uploaded for labeling/annotation."""
|
||||
|
||||
__tablename__ = "admin_documents"
|
||||
|
||||
document_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
|
||||
filename: str = Field(max_length=255)
|
||||
file_size: int
|
||||
content_type: str = Field(max_length=100)
|
||||
file_path: str = Field(max_length=512) # Path to stored file
|
||||
page_count: int = Field(default=1)
|
||||
status: str = Field(default="pending", max_length=20, index=True)
|
||||
# Status: pending, auto_labeling, labeled, exported
|
||||
auto_label_status: str | None = Field(default=None, max_length=20)
|
||||
# Auto-label status: running, completed, failed
|
||||
auto_label_error: str | None = Field(default=None)
|
||||
# v2: Upload source tracking
|
||||
upload_source: str = Field(default="ui", max_length=20)
|
||||
# Upload source: ui, api
|
||||
batch_id: UUID | None = Field(default=None, index=True)
|
||||
# Link to batch upload (if uploaded via ZIP)
|
||||
csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Original CSV values for reference
|
||||
auto_label_queued_at: datetime | None = Field(default=None)
|
||||
# When auto-label was queued
|
||||
annotation_lock_until: datetime | None = Field(default=None)
|
||||
# Lock for manual annotation while auto-label runs
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class AdminAnnotation(SQLModel, table=True):
|
||||
"""Annotation for a document (bounding box + label)."""
|
||||
|
||||
__tablename__ = "admin_annotations"
|
||||
|
||||
annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
|
||||
page_number: int = Field(default=1) # 1-indexed
|
||||
class_id: int # 0-9 for invoice fields
|
||||
class_name: str = Field(max_length=50) # e.g., "invoice_number"
|
||||
# Bounding box (normalized 0-1 coordinates)
|
||||
x_center: float
|
||||
y_center: float
|
||||
width: float
|
||||
height: float
|
||||
# Original pixel coordinates (for display)
|
||||
bbox_x: int
|
||||
bbox_y: int
|
||||
bbox_width: int
|
||||
bbox_height: int
|
||||
# OCR extracted text (if available)
|
||||
text_value: str | None = Field(default=None)
|
||||
confidence: float | None = Field(default=None)
|
||||
# Source: manual, auto, imported
|
||||
source: str = Field(default="manual", max_length=20, index=True)
|
||||
# v2: Verification fields
|
||||
is_verified: bool = Field(default=False, index=True)
|
||||
verified_at: datetime | None = Field(default=None)
|
||||
verified_by: str | None = Field(default=None, max_length=255)
|
||||
# v2: Override tracking
|
||||
override_source: str | None = Field(default=None, max_length=20)
|
||||
# If this annotation overrides another: 'auto' or 'imported'
|
||||
original_annotation_id: UUID | None = Field(default=None)
|
||||
# Reference to the annotation this overrides
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class TrainingTask(SQLModel, table=True):
|
||||
"""Training/fine-tuning task."""
|
||||
|
||||
__tablename__ = "training_tasks"
|
||||
|
||||
task_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
|
||||
name: str = Field(max_length=255)
|
||||
description: str | None = Field(default=None)
|
||||
status: str = Field(default="pending", max_length=20, index=True)
|
||||
# Status: pending, scheduled, running, completed, failed, cancelled
|
||||
task_type: str = Field(default="train", max_length=20)
|
||||
# Task type: train, finetune
|
||||
# Training configuration
|
||||
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Schedule settings
|
||||
scheduled_at: datetime | None = Field(default=None)
|
||||
cron_expression: str | None = Field(default=None, max_length=50)
|
||||
is_recurring: bool = Field(default=False)
|
||||
# Execution details
|
||||
started_at: datetime | None = Field(default=None)
|
||||
completed_at: datetime | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
# Result metrics
|
||||
result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
model_path: str | None = Field(default=None, max_length=512)
|
||||
# v2: Document count and extracted metrics
|
||||
document_count: int = Field(default=0)
|
||||
# Count of documents used in training
|
||||
metrics_mAP: float | None = Field(default=None, index=True)
|
||||
metrics_precision: float | None = Field(default=None)
|
||||
metrics_recall: float | None = Field(default=None)
|
||||
# Extracted metrics for easy querying
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class TrainingLog(SQLModel, table=True):
|
||||
"""Training log entry."""
|
||||
|
||||
__tablename__ = "training_logs"
|
||||
|
||||
log_id: int | None = Field(default=None, primary_key=True)
|
||||
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
|
||||
level: str = Field(max_length=20) # INFO, WARNING, ERROR
|
||||
message: str
|
||||
details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Batch Upload Models (v2)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BatchUpload(SQLModel, table=True):
|
||||
"""Batch upload of multiple documents via ZIP file."""
|
||||
|
||||
__tablename__ = "batch_uploads"
|
||||
|
||||
batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
|
||||
filename: str = Field(max_length=255) # ZIP filename
|
||||
file_size: int
|
||||
upload_source: str = Field(default="ui", max_length=20)
|
||||
# Upload source: ui, api
|
||||
status: str = Field(default="processing", max_length=20, index=True)
|
||||
# Status: processing, completed, partial, failed
|
||||
total_files: int = Field(default=0)
|
||||
processed_files: int = Field(default=0)
|
||||
# Number of files processed so far
|
||||
successful_files: int = Field(default=0)
|
||||
failed_files: int = Field(default=0)
|
||||
csv_filename: str | None = Field(default=None, max_length=255)
|
||||
# CSV file used for auto-labeling
|
||||
csv_row_count: int | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class BatchUploadFile(SQLModel, table=True):
|
||||
"""Individual file within a batch upload."""
|
||||
|
||||
__tablename__ = "batch_upload_files"
|
||||
|
||||
file_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
|
||||
filename: str = Field(max_length=255) # PDF filename within ZIP
|
||||
document_id: UUID | None = Field(default=None)
|
||||
# Link to created AdminDocument (if successful)
|
||||
status: str = Field(default="pending", max_length=20, index=True)
|
||||
# Status: pending, processing, completed, failed, skipped
|
||||
error_message: str | None = Field(default=None)
|
||||
annotation_count: int = Field(default=0)
|
||||
# Number of annotations created for this file
|
||||
csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# CSV row data for this file (if available)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
processed_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Training Document Link (v2)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TrainingDocumentLink(SQLModel, table=True):
|
||||
"""Junction table linking training tasks to documents."""
|
||||
|
||||
__tablename__ = "training_document_links"
|
||||
|
||||
link_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
|
||||
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
|
||||
annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Snapshot of annotations at training time (includes count, verified count, etc.)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Annotation History (v2)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AnnotationHistory(SQLModel, table=True):
|
||||
"""History of annotation changes (for override tracking)."""
|
||||
|
||||
__tablename__ = "annotation_history"
|
||||
|
||||
history_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
|
||||
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
|
||||
# Change action: created, updated, deleted, override
|
||||
action: str = Field(max_length=20, index=True)
|
||||
# Previous value (for updates/deletes)
|
||||
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# New value (for creates/updates)
|
||||
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
# Change metadata
|
||||
changed_by: str | None = Field(default=None, max_length=255)
|
||||
# User/token who made the change
|
||||
change_reason: str | None = Field(default=None)
|
||||
# Optional reason for change
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
|
||||
|
||||
# Field class mapping (same as src/cli/train.py)
|
||||
FIELD_CLASSES = {
|
||||
0: "invoice_number",
|
||||
1: "invoice_date",
|
||||
2: "invoice_due_date",
|
||||
3: "ocr_number",
|
||||
4: "bankgiro",
|
||||
5: "plusgiro",
|
||||
6: "amount",
|
||||
7: "supplier_organisation_number",
|
||||
8: "payment_line",
|
||||
9: "customer_number",
|
||||
}
|
||||
|
||||
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
|
||||
|
||||
|
||||
# Read-only models for API responses
|
||||
class AdminDocumentRead(SQLModel):
|
||||
"""Admin document response model."""
|
||||
|
||||
document_id: UUID
|
||||
filename: str
|
||||
file_size: int
|
||||
content_type: str
|
||||
page_count: int
|
||||
status: str
|
||||
auto_label_status: str | None
|
||||
auto_label_error: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class AdminAnnotationRead(SQLModel):
|
||||
"""Admin annotation response model."""
|
||||
|
||||
annotation_id: UUID
|
||||
document_id: UUID
|
||||
page_number: int
|
||||
class_id: int
|
||||
class_name: str
|
||||
x_center: float
|
||||
y_center: float
|
||||
width: float
|
||||
height: float
|
||||
bbox_x: int
|
||||
bbox_y: int
|
||||
bbox_width: int
|
||||
bbox_height: int
|
||||
text_value: str | None
|
||||
confidence: float | None
|
||||
source: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class TrainingTaskRead(SQLModel):
|
||||
"""Training task response model."""
|
||||
|
||||
task_id: UUID
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
task_type: str
|
||||
config: dict[str, Any] | None
|
||||
scheduled_at: datetime | None
|
||||
is_recurring: bool
|
||||
started_at: datetime | None
|
||||
completed_at: datetime | None
|
||||
error_message: str | None
|
||||
result_metrics: dict[str, Any] | None
|
||||
model_path: str | None
|
||||
created_at: datetime
|
||||
src/data/async_request_db.py (new file, 374 lines)
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Async Request Database Operations
|
||||
|
||||
Database interface for async invoice processing requests using SQLModel.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, text
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from src.data.database import get_session_context, create_db_and_tables, close_engine
|
||||
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Legacy dataclasses for backward compatibility
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiKeyConfig:
|
||||
"""API key configuration and limits (legacy compatibility)."""
|
||||
|
||||
api_key: str
|
||||
name: str
|
||||
is_active: bool
|
||||
requests_per_minute: int
|
||||
max_concurrent_jobs: int
|
||||
max_file_size_mb: int
|
||||
|
||||
|
||||
class AsyncRequestDB:
|
||||
"""Database interface for async processing requests using SQLModel."""
|
||||
|
||||
def __init__(self, connection_string: str | None = None) -> None:
|
||||
# connection_string is kept for backward compatibility but ignored
|
||||
# SQLModel uses the global engine from database.py
|
||||
self._initialized = False
|
||||
|
||||
def connect(self):
|
||||
"""Legacy method - returns self for compatibility."""
|
||||
return self
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close database connections."""
|
||||
close_engine()
|
||||
|
||||
def __enter__(self) -> "AsyncRequestDB":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
pass # Sessions are managed per-operation
|
||||
|
||||
def create_tables(self) -> None:
|
||||
"""Create async processing tables if they don't exist."""
|
||||
create_db_and_tables()
|
||||
self._initialized = True
|
||||
|
||||
# ==========================================================================
|
||||
# API Key Operations
|
||||
# ==========================================================================
|
||||
|
||||
def is_valid_api_key(self, api_key: str) -> bool:
|
||||
"""Check if API key exists and is active."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(ApiKey, api_key)
|
||||
return result is not None and result.is_active is True
|
||||
|
||||
def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
|
||||
"""Get API key configuration and limits."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(ApiKey, api_key)
|
||||
if result is None:
|
||||
return None
|
||||
return ApiKeyConfig(
|
||||
api_key=result.api_key,
|
||||
name=result.name,
|
||||
is_active=result.is_active,
|
||||
requests_per_minute=result.requests_per_minute,
|
||||
max_concurrent_jobs=result.max_concurrent_jobs,
|
||||
max_file_size_mb=result.max_file_size_mb,
|
||||
)
|
||||
|
||||
def create_api_key(
|
||||
self,
|
||||
api_key: str,
|
||||
name: str,
|
||||
requests_per_minute: int = 10,
|
||||
max_concurrent_jobs: int = 3,
|
||||
max_file_size_mb: int = 50,
|
||||
) -> None:
|
||||
"""Create a new API key."""
|
||||
with get_session_context() as session:
|
||||
existing = session.get(ApiKey, api_key)
|
||||
if existing:
|
||||
existing.name = name
|
||||
existing.requests_per_minute = requests_per_minute
|
||||
existing.max_concurrent_jobs = max_concurrent_jobs
|
||||
existing.max_file_size_mb = max_file_size_mb
|
||||
session.add(existing)
|
||||
else:
|
||||
new_key = ApiKey(
|
||||
api_key=api_key,
|
||||
name=name,
|
||||
requests_per_minute=requests_per_minute,
|
||||
max_concurrent_jobs=max_concurrent_jobs,
|
||||
max_file_size_mb=max_file_size_mb,
|
||||
)
|
||||
session.add(new_key)
|
||||
|
||||
def update_api_key_usage(self, api_key: str) -> None:
|
||||
"""Update API key last used timestamp and increment total requests."""
|
||||
with get_session_context() as session:
|
||||
key = session.get(ApiKey, api_key)
|
||||
if key:
|
||||
key.last_used_at = datetime.utcnow()
|
||||
key.total_requests += 1
|
||||
session.add(key)
|
||||
|
||||
# ==========================================================================
|
||||
# Async Request Operations
|
||||
# ==========================================================================
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
api_key: str,
|
||||
filename: str,
|
||||
file_size: int,
|
||||
content_type: str,
|
||||
expires_at: datetime,
|
||||
request_id: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new async request."""
|
||||
with get_session_context() as session:
|
||||
request = AsyncRequest(
|
||||
api_key=api_key,
|
||||
filename=filename,
|
||||
file_size=file_size,
|
||||
content_type=content_type,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
if request_id:
|
||||
request.request_id = UUID(request_id)
|
||||
session.add(request)
|
||||
session.flush() # To get the generated ID
|
||||
return str(request.request_id)
|
||||
|
||||
def get_request(self, request_id: str) -> AsyncRequest | None:
|
||||
"""Get a single async request by ID."""
|
||||
with get_session_context() as session:
|
||||
result = session.get(AsyncRequest, UUID(request_id))
|
||||
if result:
|
||||
# Detach from session for use outside context
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def get_request_by_api_key(
|
||||
self,
|
||||
request_id: str,
|
||||
api_key: str,
|
||||
) -> AsyncRequest | None:
|
||||
"""Get a request only if it belongs to the given API key."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AsyncRequest).where(
|
||||
AsyncRequest.request_id == UUID(request_id),
|
||||
AsyncRequest.api_key == api_key,
|
||||
)
|
||||
result = session.exec(statement).first()
|
||||
if result:
|
||||
session.expunge(result)
|
||||
return result
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
request_id: str,
|
||||
status: str,
|
||||
error_message: str | None = None,
|
||||
increment_retry: bool = False,
|
||||
) -> None:
|
||||
"""Update request status."""
|
||||
with get_session_context() as session:
|
||||
request = session.get(AsyncRequest, UUID(request_id))
|
||||
if request:
|
||||
request.status = status
|
||||
if status == "processing":
|
||||
request.started_at = datetime.utcnow()
|
||||
if error_message is not None:
|
||||
request.error_message = error_message
|
||||
if increment_retry:
|
||||
request.retry_count += 1
|
||||
session.add(request)
|
||||
|
||||
def complete_request(
|
||||
self,
|
||||
request_id: str,
|
||||
document_id: str,
|
||||
result: dict[str, Any],
|
||||
processing_time_ms: float,
|
||||
visualization_path: str | None = None,
|
||||
) -> None:
|
||||
"""Mark request as completed with result."""
|
||||
with get_session_context() as session:
|
||||
request = session.get(AsyncRequest, UUID(request_id))
|
||||
if request:
|
||||
request.status = "completed"
|
||||
request.document_id = document_id
|
||||
request.result = result
|
||||
request.processing_time_ms = processing_time_ms
|
||||
request.visualization_path = visualization_path
|
||||
request.completed_at = datetime.utcnow()
|
||||
session.add(request)
|
||||
|
||||
def get_requests_by_api_key(
|
||||
self,
|
||||
api_key: str,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[AsyncRequest], int]:
|
||||
"""Get paginated requests for an API key."""
|
||||
with get_session_context() as session:
|
||||
# Count query
|
||||
count_stmt = select(func.count()).select_from(AsyncRequest).where(
|
||||
AsyncRequest.api_key == api_key
|
||||
)
|
||||
if status:
|
||||
count_stmt = count_stmt.where(AsyncRequest.status == status)
|
||||
total = session.exec(count_stmt).one()
|
||||
|
||||
# Fetch query
|
||||
statement = select(AsyncRequest).where(
|
||||
AsyncRequest.api_key == api_key
|
||||
)
|
||||
if status:
|
||||
statement = statement.where(AsyncRequest.status == status)
|
||||
statement = statement.order_by(AsyncRequest.created_at.desc())
|
||||
statement = statement.offset(offset).limit(limit)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
# Detach results from session
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results), total
|
||||
|
||||
def count_active_jobs(self, api_key: str) -> int:
|
||||
"""Count active (pending + processing) jobs for an API key."""
|
||||
with get_session_context() as session:
|
||||
statement = select(func.count()).select_from(AsyncRequest).where(
|
||||
AsyncRequest.api_key == api_key,
|
||||
AsyncRequest.status.in_(["pending", "processing"]),
|
||||
)
|
||||
return session.exec(statement).one()
|
||||
|
||||
def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
|
||||
"""Get pending requests ordered by creation time."""
|
||||
with get_session_context() as session:
|
||||
statement = select(AsyncRequest).where(
|
||||
AsyncRequest.status == "pending"
|
||||
).order_by(AsyncRequest.created_at).limit(limit)
|
||||
results = session.exec(statement).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_queue_position(self, request_id: str) -> int | None:
|
||||
"""Get position of a request in the pending queue."""
|
||||
with get_session_context() as session:
|
||||
# Get the request's created_at
|
||||
request = session.get(AsyncRequest, UUID(request_id))
|
||||
if not request:
|
||||
return None
|
||||
|
||||
# Count pending requests created before this one
|
||||
statement = select(func.count()).select_from(AsyncRequest).where(
|
||||
AsyncRequest.status == "pending",
|
||||
AsyncRequest.created_at < request.created_at,
|
||||
)
|
||||
count = session.exec(statement).one()
|
||||
return count + 1 # 1-based position
|
||||
|
||||
# ==========================================================================
|
||||
# Rate Limit Operations
|
||||
# ==========================================================================
|
||||
|
||||
def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
|
||||
"""Record a rate limit event."""
|
||||
with get_session_context() as session:
|
||||
event = RateLimitEvent(
|
||||
api_key=api_key,
|
||||
event_type=event_type,
|
||||
)
|
||||
session.add(event)
|
||||
|
||||
def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
|
||||
"""Count requests in the last N seconds."""
|
||||
with get_session_context() as session:
|
||||
cutoff = datetime.utcnow() - timedelta(seconds=seconds)
|
||||
statement = select(func.count()).select_from(RateLimitEvent).where(
|
||||
RateLimitEvent.api_key == api_key,
|
||||
RateLimitEvent.event_type == "request",
|
||||
RateLimitEvent.created_at > cutoff,
|
||||
)
|
||||
return session.exec(statement).one()
|
||||
|
||||
# ==========================================================================
|
||||
# Cleanup Operations
|
||||
# ==========================================================================
|
||||
|
||||
def delete_expired_requests(self) -> int:
|
||||
"""Delete requests that have expired. Returns count of deleted rows."""
|
||||
with get_session_context() as session:
|
||||
now = datetime.utcnow()
|
||||
statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
|
||||
expired = session.exec(statement).all()
|
||||
count = len(expired)
|
||||
for request in expired:
|
||||
session.delete(request)
|
||||
logger.info(f"Deleted {count} expired async requests")
|
||||
return count
|
||||
|
||||
def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
|
||||
"""Delete rate limit events older than N hours."""
|
||||
with get_session_context() as session:
|
||||
cutoff = datetime.utcnow() - timedelta(hours=hours)
|
||||
statement = select(RateLimitEvent).where(
|
||||
RateLimitEvent.created_at < cutoff
|
||||
)
|
||||
old_events = session.exec(statement).all()
|
||||
count = len(old_events)
|
||||
for event in old_events:
|
||||
session.delete(event)
|
||||
return count
|
||||
|
||||
def reset_stale_processing_requests(
|
||||
self,
|
||||
stale_minutes: int = 10,
|
||||
max_retries: int = 3,
|
||||
) -> int:
|
||||
"""
|
||||
Reset requests stuck in 'processing' status.
|
||||
|
||||
Requests that have been processing for more than stale_minutes
|
||||
are considered stale. They are either reset to 'pending' (if under
|
||||
max_retries) or set to 'failed'.
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
|
||||
reset_count = 0
|
||||
|
||||
# Find stale processing requests
|
||||
statement = select(AsyncRequest).where(
|
||||
AsyncRequest.status == "processing",
|
||||
AsyncRequest.started_at < cutoff,
|
||||
)
|
||||
stale_requests = session.exec(statement).all()
|
||||
|
||||
for request in stale_requests:
|
||||
if request.retry_count < max_retries:
|
||||
request.status = "pending"
|
||||
request.started_at = None
|
||||
else:
|
||||
request.status = "failed"
|
||||
request.error_message = "Processing timeout after max retries"
|
||||
session.add(request)
|
||||
reset_count += 1
|
||||
|
||||
if reset_count > 0:
|
||||
logger.warning(f"Reset {reset_count} stale processing requests")
|
||||
return reset_count
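Taken together, `AsyncRequestDB` covers the whole request lifecycle: create, mark processing, complete, and expire. A rough usage sketch using only the methods defined above; the 24-hour expiry window and the literal values are illustrative, and the worker loop that would drive the status changes lives elsewhere:

```python
# Sketch of the expected call sequence for a single async request.
from datetime import datetime, timedelta

from src.data.async_request_db import AsyncRequestDB

db = AsyncRequestDB()
db.create_tables()

request_id = db.create_request(
    api_key="dev-api-key-12345",          # default key seeded by the migration
    filename="invoice.pdf",
    file_size=123_456,
    content_type="application/pdf",
    expires_at=datetime.utcnow() + timedelta(hours=24),
)

db.update_status(request_id, "processing")
db.complete_request(
    request_id,
    document_id="doc-001",
    result={"invoice_number": "12345"},
    processing_time_ms=850.0,
)
```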
src/data/database.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Database Engine and Session Management
|
||||
|
||||
Provides SQLModel database engine and session handling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from src.config import get_db_connection_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global engine instance
|
||||
_engine = None
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""Get or create the database engine."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
connection_string = get_db_connection_string()
|
||||
# Convert psycopg2 format to SQLAlchemy format
|
||||
if connection_string.startswith("postgresql://"):
|
||||
# Already in correct format
|
||||
pass
|
||||
elif "host=" in connection_string:
|
||||
# Convert DSN format to URL format
|
||||
parts = dict(item.split("=") for item in connection_string.split())
|
||||
connection_string = (
|
||||
f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
|
||||
f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
|
||||
f"/{parts.get('dbname', 'docmaster')}"
|
||||
)
|
||||
|
||||
_engine = create_engine(
|
||||
connection_string,
|
||||
echo=False, # Set to True for SQL debugging
|
||||
pool_pre_ping=True, # Verify connections before use
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def create_db_and_tables() -> None:
|
||||
"""Create all database tables."""
|
||||
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401
|
||||
from src.data.admin_models import ( # noqa: F401
|
||||
AdminToken,
|
||||
AdminDocument,
|
||||
AdminAnnotation,
|
||||
TrainingTask,
|
||||
TrainingLog,
|
||||
)
|
||||
|
||||
engine = get_engine()
|
||||
SQLModel.metadata.create_all(engine)
|
||||
logger.info("Database tables created/verified")
|
||||
|
||||
|
||||
def get_session() -> Session:
|
||||
"""Get a new database session."""
|
||||
engine = get_engine()
|
||||
return Session(engine)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_context() -> Generator[Session, None, None]:
|
||||
"""Context manager for database sessions with auto-commit/rollback."""
|
||||
session = get_session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def close_engine() -> None:
|
||||
"""Close the database engine and release connections."""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
_engine.dispose()
|
||||
_engine = None
|
||||
logger.info("Database engine closed")
|
||||
|
||||
|
||||
def execute_raw_sql(sql: str) -> None:
|
||||
"""Execute raw SQL (for migrations)."""
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(sql))
|
||||
conn.commit()
|
@@ -10,7 +10,7 @@ import sys
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string
 
 
 class DocumentDB:
src/data/migrations/001_async_tables.sql (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
-- Async Invoice Processing Tables
|
||||
-- Migration: 001_async_tables.sql
|
||||
-- Created: 2024-01-15
|
||||
|
||||
-- API Keys table for authentication and rate limiting
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
api_key TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
|
||||
-- Rate limits
|
||||
requests_per_minute INTEGER DEFAULT 10,
|
||||
max_concurrent_jobs INTEGER DEFAULT 3,
|
||||
max_file_size_mb INTEGER DEFAULT 50,
|
||||
|
||||
-- Usage tracking
|
||||
total_requests INTEGER DEFAULT 0,
|
||||
total_processed INTEGER DEFAULT 0,
|
||||
|
||||
-- Timestamps
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_used_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
-- Async processing requests table
|
||||
CREATE TABLE IF NOT EXISTS async_requests (
|
||||
request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
filename TEXT NOT NULL,
|
||||
file_size INTEGER NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
|
||||
-- Processing metadata
|
||||
document_id TEXT,
|
||||
error_message TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
|
||||
-- Timestamps
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
started_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
|
||||
-- Result storage (JSONB for flexibility)
|
||||
result JSONB,
|
||||
|
||||
-- Processing time
|
||||
processing_time_ms REAL,
|
||||
|
||||
-- Visualization path
|
||||
visualization_path TEXT,
|
||||
|
||||
CONSTRAINT valid_status CHECK (status IN ('pending', 'processing', 'completed', 'failed'))
|
||||
);
|
||||
|
||||
-- Indexes for async_requests
|
||||
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key ON async_requests(api_key);
|
||||
CREATE INDEX IF NOT EXISTS idx_async_requests_status ON async_requests(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_async_requests_created_at ON async_requests(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_async_requests_expires_at ON async_requests(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key_status ON async_requests(api_key, status);
|
||||
|
||||
-- Rate limit tracking table
|
||||
CREATE TABLE IF NOT EXISTS rate_limit_events (
|
||||
id SERIAL PRIMARY KEY,
|
||||
api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
|
||||
event_type TEXT NOT NULL, -- 'request', 'complete', 'fail'
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for rate limiting queries (recent events only)
|
||||
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_api_key_time
|
||||
ON rate_limit_events(api_key, created_at DESC);
|
||||
|
||||
-- Cleanup old rate limit events index
|
||||
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_cleanup
|
||||
ON rate_limit_events(created_at);
|
||||
|
||||
-- Insert default API key for development/testing
|
||||
INSERT INTO api_keys (api_key, name, requests_per_minute, max_concurrent_jobs)
|
||||
VALUES ('dev-api-key-12345', 'Development Key', 100, 10)
|
||||
ON CONFLICT (api_key) DO NOTHING;
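The migration can be applied with any PostgreSQL client; one option is to feed the file through `execute_raw_sql` from `src/data/database.py`, as sketched below. The relative path assumes the repository layout shown in this commit:

```python
# Sketch: apply the async-tables migration through the project's own helper.
from pathlib import Path

from src.data.database import execute_raw_sql

migration = Path("src/data/migrations/001_async_tables.sql").read_text()
execute_raw_sql(migration)  # runs the CREATE TABLE / INDEX / INSERT statements
```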
src/data/migrations/002_nullable_admin_token.sql (new file, 5 lines)
@@ -0,0 +1,5 @@
-- Migration: Make admin_token nullable in admin_documents table
-- This allows documents uploaded via public API to not require an admin token

ALTER TABLE admin_documents
ALTER COLUMN admin_token DROP NOT NULL;
src/data/models.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
SQLModel Database Models
|
||||
|
||||
Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, SQLModel, Column, JSON
|
||||
|
||||
|
||||
class ApiKey(SQLModel, table=True):
|
||||
"""API key configuration and limits."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
api_key: str = Field(primary_key=True, max_length=255)
|
||||
name: str = Field(max_length=255)
|
||||
is_active: bool = Field(default=True)
|
||||
requests_per_minute: int = Field(default=10)
|
||||
max_concurrent_jobs: int = Field(default=3)
|
||||
max_file_size_mb: int = Field(default=50)
|
||||
total_requests: int = Field(default=0)
|
||||
total_processed: int = Field(default=0)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_used_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class AsyncRequest(SQLModel, table=True):
|
||||
"""Async request record."""
|
||||
|
||||
__tablename__ = "async_requests"
|
||||
|
||||
request_id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
|
||||
status: str = Field(default="pending", max_length=20, index=True)
|
||||
filename: str = Field(max_length=255)
|
||||
file_size: int
|
||||
content_type: str = Field(max_length=100)
|
||||
document_id: str | None = Field(default=None, max_length=100)
|
||||
error_message: str | None = Field(default=None)
|
||||
retry_count: int = Field(default=0)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
started_at: datetime | None = Field(default=None)
|
||||
completed_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime = Field(index=True)
|
||||
result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||
processing_time_ms: float | None = Field(default=None)
|
||||
visualization_path: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class RateLimitEvent(SQLModel, table=True):
|
||||
"""Rate limit event record."""
|
||||
|
||||
__tablename__ = "rate_limit_events"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
|
||||
event_type: str = Field(max_length=50)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
|
||||
|
||||
# Read-only models for responses (without table=True)
|
||||
class ApiKeyRead(SQLModel):
|
||||
"""API key response model (read-only)."""
|
||||
|
||||
api_key: str
|
||||
name: str
|
||||
is_active: bool
|
||||
requests_per_minute: int
|
||||
max_concurrent_jobs: int
|
||||
max_file_size_mb: int
|
||||
|
||||
|
||||
class AsyncRequestRead(SQLModel):
|
||||
"""Async request response model (read-only)."""
|
||||
|
||||
request_id: UUID
|
||||
api_key: str
|
||||
status: str
|
||||
filename: str
|
||||
file_size: int
|
||||
content_type: str
|
||||
document_id: str | None
|
||||
error_message: str | None
|
||||
retry_count: int
|
||||
created_at: datetime
|
||||
started_at: datetime | None
|
||||
completed_at: datetime | None
|
||||
expires_at: datetime
|
||||
result: dict[str, Any] | None
|
||||
processing_time_ms: float | None
|
||||
visualization_path: str | None
|
@@ -12,6 +12,8 @@ import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+from src.config import DEFAULT_DPI
+
 # Global OCR instance (initialized once per GPU worker process)
 _ocr_engine: Optional[Any] = None
 

@@ -94,7 +96,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     row_dict = task_data["row_dict"]
     pdf_path = Path(task_data["pdf_path"])
     output_dir = Path(task_data["output_dir"])
-    dpi = task_data.get("dpi", 150)
+    dpi = task_data.get("dpi", DEFAULT_DPI)
     min_confidence = task_data.get("min_confidence", 0.5)
 
     start_time = time.time()

@@ -212,7 +214,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     row_dict = task_data["row_dict"]
     pdf_path = Path(task_data["pdf_path"])
     output_dir = Path(task_data["output_dir"])
-    dpi = task_data.get("dpi", 150)
+    dpi = task_data.get("dpi", DEFAULT_DPI)
     min_confidence = task_data.get("min_confidence", 0.5)
 
     start_time = time.time()

@@ -16,6 +16,8 @@ from datetime import datetime
 import psycopg2
 from psycopg2.extras import execute_values
 
+from src.config import DEFAULT_DPI
+
 
 @dataclass
 class LLMExtractionResult:

@@ -265,7 +267,7 @@ Return ONLY the JSON object, no other text."""
         self,
         pdf_path: Path,
         page_no: int = 0,
-        dpi: int = 150,
+        dpi: int = DEFAULT_DPI,
         max_size_mb: float = 18.0
     ) -> bytes:
         """
src/web/admin_routes_new.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py

DEPRECATED: Import from src.web.api.v1.admin.documents instead.
"""
from src.web.api.v1.admin.documents import *

__all__ = ["create_admin_router"]
src/web/api/__init__.py (new file, empty)
src/web/api/v1/__init__.py (new file, empty)
src/web/api/v1/admin/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""
Admin API v1

Document management, annotations, and training endpoints.
"""

from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.api.v1.admin.auth import create_auth_router
from src.web.api.v1.admin.documents import create_documents_router
from src.web.api.v1.admin.locks import create_locks_router
from src.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
    "create_training_router",
]
|
||||
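A hedged sketch of mounting these factories on a FastAPI app; the `/api/v1` prefix matches the image URLs generated in `documents.py`, but the application wiring itself is not part of this commit, and default `StorageConfig()` construction is an assumption:

```python
from fastapi import FastAPI

from src.web.api.v1.admin import (
    create_annotation_router,
    create_auth_router,
    create_documents_router,
    create_locks_router,
    create_training_router,
)
from src.web.config import StorageConfig

app = FastAPI()
storage_config = StorageConfig()  # assumed default construction

for router in (
    create_auth_router(),
    create_documents_router(storage_config),
    create_annotation_router(),
    create_locks_router(),
    create_training_router(),
):
    app.include_router(router, prefix="/api/v1")
```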
644  src/web/api/v1/admin/annotations.py  Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Admin Annotation API Routes
|
||||
|
||||
FastAPI endpoints for annotation management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
|
||||
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from src.web.services.autolabel import get_auto_label_service
|
||||
from src.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
AnnotationItem,
|
||||
AnnotationListResponse,
|
||||
AnnotationOverrideRequest,
|
||||
AnnotationOverrideResponse,
|
||||
AnnotationResponse,
|
||||
AnnotationSource,
|
||||
AnnotationUpdate,
|
||||
AnnotationVerifyRequest,
|
||||
AnnotationVerifyResponse,
|
||||
AutoLabelRequest,
|
||||
AutoLabelResponse,
|
||||
BoundingBox,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Image storage directory
|
||||
ADMIN_IMAGES_DIR = Path("data/admin_images")
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_annotation_router() -> APIRouter:
|
||||
"""Create annotation API router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])
|
||||
|
||||
# =========================================================================
|
||||
# Image Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/images/{page_number}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Get page image",
|
||||
description="Get the image for a specific page.",
|
||||
)
|
||||
async def get_page_image(
|
||||
document_id: str,
|
||||
page_number: int,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> FileResponse:
|
||||
"""Get page image."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if page_number < 1 or page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
|
||||
)
|
||||
|
||||
# Find image file
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
|
||||
if not image_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Image for page {page_number} not found",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=f"{document.filename}_page_{page_number}.png",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Annotation Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="List annotations",
|
||||
description="Get all annotations for a document.",
|
||||
)
|
||||
async def list_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
page_number: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Filter by page number"),
|
||||
] = None,
|
||||
) -> AnnotationListResponse:
|
||||
"""List annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = db.get_annotations_for_document(document_id, page_number)
|
||||
annotations = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
return AnnotationListResponse(
|
||||
document_id=document_id,
|
||||
page_count=document.page_count,
|
||||
total_annotations=len(annotations),
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Create annotation",
|
||||
description="Create a new annotation for a document.",
|
||||
)
|
||||
async def create_annotation(
|
||||
document_id: str,
|
||||
request: AnnotationCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Create a new annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate page number
|
||||
if request.page_number > document.page_count:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
|
||||
)
|
||||
|
||||
# Get image dimensions for normalization
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
|
||||
if not image_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image for page {request.page_number} not available",
|
||||
)
|
||||
|
||||
from PIL import Image
|
||||
with Image.open(image_path) as img:
|
||||
image_width, image_height = img.size
|
||||
|
||||
# Calculate normalized coordinates
|
||||
x_center = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
y_center = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
width = request.bbox.width / image_width
|
||||
height = request.bbox.height / image_height
|
||||
|
||||
# Get class name
|
||||
class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")
|
||||
|
||||
# Create annotation
|
||||
annotation_id = db.create_annotation(
|
||||
document_id=document_id,
|
||||
page_number=request.page_number,
|
||||
class_id=request.class_id,
|
||||
class_name=class_name,
|
||||
x_center=x_center,
|
||||
y_center=y_center,
|
||||
width=width,
|
||||
height=height,
|
||||
bbox_x=request.bbox.x,
|
||||
bbox_y=request.bbox.y,
|
||||
bbox_width=request.bbox.width,
|
||||
bbox_height=request.bbox.height,
|
||||
text_value=request.text_value,
|
||||
source="manual",
|
||||
)
|
||||
|
||||
# Keep status as pending - user must click "Mark Complete" to finalize
|
||||
# This allows user to add multiple annotations before saving to PostgreSQL
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation created successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
response_model=AnnotationResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Update annotation",
|
||||
description="Update an existing annotation.",
|
||||
)
|
||||
async def update_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationUpdate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationResponse:
|
||||
"""Update an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = db.get_annotation(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Prepare update data
|
||||
update_kwargs = {}
|
||||
|
||||
if request.class_id is not None:
|
||||
update_kwargs["class_id"] = request.class_id
|
||||
update_kwargs["class_name"] = FIELD_CLASSES.get(
|
||||
request.class_id, f"class_{request.class_id}"
|
||||
)
|
||||
|
||||
if request.text_value is not None:
|
||||
update_kwargs["text_value"] = request.text_value
|
||||
|
||||
if request.bbox is not None:
|
||||
# Get image dimensions
|
||||
image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
|
||||
from PIL import Image
|
||||
with Image.open(image_path) as img:
|
||||
image_width, image_height = img.size
|
||||
|
||||
# Calculate normalized coordinates
|
||||
update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
|
||||
update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
|
||||
update_kwargs["width"] = request.bbox.width / image_width
|
||||
update_kwargs["height"] = request.bbox.height / image_height
|
||||
update_kwargs["bbox_x"] = request.bbox.x
|
||||
update_kwargs["bbox_y"] = request.bbox.y
|
||||
update_kwargs["bbox_width"] = request.bbox.width
|
||||
update_kwargs["bbox_height"] = request.bbox.height
|
||||
|
||||
# Update annotation
|
||||
if update_kwargs:
|
||||
success = db.update_annotation(annotation_id, **update_kwargs)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update annotation",
|
||||
)
|
||||
|
||||
return AnnotationResponse(
|
||||
annotation_id=annotation_id,
|
||||
message="Annotation updated successfully",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations/{annotation_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Not found"},
|
||||
},
|
||||
summary="Delete annotation",
|
||||
description="Delete an annotation.",
|
||||
)
|
||||
async def delete_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get existing annotation
|
||||
annotation = db.get_annotation(annotation_id)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Verify annotation belongs to document
|
||||
if str(annotation.document_id) != document_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation does not belong to this document",
|
||||
)
|
||||
|
||||
# Delete annotation
|
||||
db.delete_annotation(annotation_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"annotation_id": annotation_id,
|
||||
"message": "Annotation deleted successfully",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Auto-Labeling Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/auto-label",
|
||||
response_model=AutoLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Trigger auto-labeling",
|
||||
description="Trigger auto-labeling for a document using field values.",
|
||||
)
|
||||
async def trigger_auto_label(
|
||||
document_id: str,
|
||||
request: AutoLabelRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AutoLabelResponse:
|
||||
"""Trigger auto-labeling for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Validate field values
|
||||
if not request.field_values:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one field value is required",
|
||||
)
|
||||
|
||||
# Run auto-labeling
|
||||
service = get_auto_label_service()
|
||||
result = service.auto_label_document(
|
||||
document_id=document_id,
|
||||
file_path=document.file_path,
|
||||
field_values=request.field_values,
|
||||
db=db,
|
||||
replace_existing=request.replace_existing,
|
||||
)
|
||||
|
||||
if result["status"] == "failed":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return AutoLabelResponse(
|
||||
document_id=document_id,
|
||||
status=result["status"],
|
||||
annotations_created=result["annotations_created"],
|
||||
message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}/annotations",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete all annotations",
|
||||
description="Delete all annotations for a document (optionally filter by source).",
|
||||
)
|
||||
async def delete_all_annotations(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by source (manual, auto, imported)"),
|
||||
] = None,
|
||||
) -> dict:
|
||||
"""Delete all annotations for a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate source
|
||||
if source and source not in ("manual", "auto", "imported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid source: {source}",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete annotations
|
||||
deleted_count = db.delete_annotations_for_document(document_id, source)
|
||||
|
||||
# Update document status if all annotations deleted
|
||||
remaining = db.get_annotations_for_document(document_id)
|
||||
if not remaining:
|
||||
db.update_document_status(document_id, "pending")
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"Deleted {deleted_count} annotations",
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Phase 5: Annotation Enhancement
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/{document_id}/annotations/{annotation_id}/verify",
|
||||
response_model=AnnotationVerifyResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Verify annotation",
|
||||
description="Mark an annotation as verified by a human reviewer.",
|
||||
)
|
||||
async def verify_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
|
||||
) -> AnnotationVerifyResponse:
|
||||
"""Verify an annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership of document
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Verify the annotation
|
||||
annotation = db.verify_annotation(annotation_id, admin_token)
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
return AnnotationVerifyResponse(
|
||||
annotation_id=annotation_id,
|
||||
is_verified=annotation.is_verified,
|
||||
verified_at=annotation.verified_at,
|
||||
verified_by=annotation.verified_by,
|
||||
message="Annotation verified successfully",
|
||||
)
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/annotations/{annotation_id}/override",
|
||||
response_model=AnnotationOverrideResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Annotation not found"},
|
||||
},
|
||||
summary="Override annotation",
|
||||
description="Override an auto-generated annotation with manual corrections.",
|
||||
)
|
||||
async def override_annotation(
|
||||
document_id: str,
|
||||
annotation_id: str,
|
||||
request: AnnotationOverrideRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> AnnotationOverrideResponse:
|
||||
"""Override an auto-generated annotation."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
_validate_uuid(annotation_id, "annotation_id")
|
||||
|
||||
# Verify ownership of document
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Build updates dict from request
|
||||
updates = {}
|
||||
if request.text_value is not None:
|
||||
updates["text_value"] = request.text_value
|
||||
if request.class_id is not None:
|
||||
updates["class_id"] = request.class_id
|
||||
# Update class_name if class_id changed
|
||||
if request.class_id in FIELD_CLASSES:
|
||||
updates["class_name"] = FIELD_CLASSES[request.class_id]
|
||||
if request.class_name is not None:
|
||||
updates["class_name"] = request.class_name
|
||||
if request.bbox:
|
||||
# Update bbox fields
|
||||
if "x" in request.bbox:
|
||||
updates["bbox_x"] = request.bbox["x"]
|
||||
if "y" in request.bbox:
|
||||
updates["bbox_y"] = request.bbox["y"]
|
||||
if "width" in request.bbox:
|
||||
updates["bbox_width"] = request.bbox["width"]
|
||||
if "height" in request.bbox:
|
||||
updates["bbox_height"] = request.bbox["height"]
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No updates provided. Specify at least one field to update.",
|
||||
)
|
||||
|
||||
# Override the annotation
|
||||
annotation = db.override_annotation(
|
||||
annotation_id=annotation_id,
|
||||
admin_token=admin_token,
|
||||
change_reason=request.reason,
|
||||
**updates,
|
||||
)
|
||||
|
||||
if annotation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Annotation not found",
|
||||
)
|
||||
|
||||
# Get history to return history_id
|
||||
history_records = db.get_annotation_history(UUID(annotation_id))
|
||||
latest_history = history_records[0] if history_records else None
|
||||
|
||||
return AnnotationOverrideResponse(
|
||||
annotation_id=annotation_id,
|
||||
source=annotation.source,
|
||||
override_source=annotation.override_source,
|
||||
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
|
||||
message="Annotation overridden successfully",
|
||||
history_id=str(latest_history.history_id) if latest_history else "",
|
||||
)
|
||||
|
||||
return router
|
||||
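For reference, a hedged example of calling the manual-annotation endpoint above with `requests`; the host, the `/api/v1` prefix, and the `Authorization: Bearer` header are assumptions about how the app mounts the router and how `AdminTokenDep` reads the token:

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/admin/documents/<document_id>/annotations",
    headers={"Authorization": "Bearer <admin_token>"},
    json={
        "page_number": 1,
        "class_id": 0,  # illustrative field class
        "bbox": {"x": 120, "y": 340, "width": 200, "height": 32},
        "text_value": "2024-01-31",
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # AnnotationResponse: annotation_id + message
```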
82  src/web/api/v1/admin/auth.py  Normal file
@@ -0,0 +1,82 @@
"""
Admin Auth Routes

FastAPI endpoints for admin token management.
"""

import logging
import secrets
from datetime import datetime, timedelta

from fastapi import APIRouter

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def create_auth_router() -> APIRouter:
    """Create admin auth router."""
    router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])

    @router.post(
        "/token",
        response_model=AdminTokenResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
        },
        summary="Create admin token",
        description="Create a new admin authentication token.",
    )
    async def create_token(
        request: AdminTokenCreate,
        db: AdminDBDep,
    ) -> AdminTokenResponse:
        """Create a new admin token."""
        # Generate secure token
        token = secrets.token_urlsafe(32)

        # Calculate expiration
        expires_at = None
        if request.expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)

        # Create token in database
        db.create_admin_token(
            token=token,
            name=request.name,
            expires_at=expires_at,
        )

        return AdminTokenResponse(
            token=token,
            name=request.name,
            expires_at=expires_at,
            message="Admin token created successfully",
        )

    @router.delete(
        "/token",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Revoke admin token",
        description="Revoke the current admin token.",
    )
    async def revoke_token(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Revoke the current admin token."""
        db.deactivate_admin_token(admin_token)
        return {
            "status": "revoked",
            "message": "Admin token has been revoked",
        }

    return router
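A hedged usage sketch for the token endpoints above (same caveats about host, prefix, and auth header as elsewhere); note that `create_token` itself requires no authentication as written:

```python
import requests

# Create a token, then revoke it with the returned value.
created = requests.post(
    "http://localhost:8000/api/v1/admin/auth/token",
    json={"name": "labeling-session", "expires_in_days": 7},
    timeout=30,
).json()

requests.delete(
    "http://localhost:8000/api/v1/admin/auth/token",
    headers={"Authorization": f"Bearer {created['token']}"},
    timeout=30,
)
```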
551  src/web/api/v1/admin/documents.py  Normal file
@@ -0,0 +1,551 @@
|
||||
"""
|
||||
Admin Document Routes
|
||||
|
||||
FastAPI endpoints for admin document management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from src.web.config import DEFAULT_DPI, StorageConfig
|
||||
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from src.web.schemas.admin import (
|
||||
AnnotationItem,
|
||||
AnnotationSource,
|
||||
AutoLabelStatus,
|
||||
BoundingBox,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
DocumentListResponse,
|
||||
DocumentStatus,
|
||||
DocumentStatsResponse,
|
||||
DocumentUploadResponse,
|
||||
ModelMetrics,
|
||||
TrainingHistoryItem,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
# Render at configured DPI for consistency with training
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
"""Create admin documents router."""
|
||||
router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])
|
||||
|
||||
# Directories are created by StorageConfig.__post_init__
|
||||
allowed_extensions = storage_config.allowed_extensions
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=DocumentUploadResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Upload document",
|
||||
description="Upload a PDF or image document for labeling.",
|
||||
)
|
||||
async def upload_document(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
file: UploadFile = File(..., description="PDF or image file"),
|
||||
auto_label: Annotated[
|
||||
bool,
|
||||
Query(description="Trigger auto-labeling after upload"),
|
||||
] = True,
|
||||
) -> DocumentUploadResponse:
|
||||
"""Upload a document for labeling."""
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Get page count (for PDF)
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record (token only used for auth, not stored)
|
||||
document_id = db.create_document(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to save file")
|
||||
|
||||
# Update file path in database
|
||||
from src.data.database import get_session_context
|
||||
from src.data.admin_models import AdminDocument
|
||||
with get_session_context() as session:
|
||||
doc = session.get(AdminDocument, UUID(document_id))
|
||||
if doc:
|
||||
doc.file_path = str(file_path)
|
||||
session.add(doc)
|
||||
|
||||
# Convert PDF to images for annotation
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling if requested
|
||||
auto_label_started = False
|
||||
if auto_label:
|
||||
# Auto-labeling will be triggered by a background task
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="running",
|
||||
)
|
||||
auto_label_started = True
|
||||
|
||||
return DocumentUploadResponse(
|
||||
document_id=document_id,
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
page_count=page_count,
|
||||
status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
|
||||
auto_label_started=auto_label_started,
|
||||
message="Document uploaded successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=DocumentListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List documents",
|
||||
description="List all documents for the current admin.",
|
||||
)
|
||||
async def list_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
upload_source: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by upload source (ui or api)"),
|
||||
] = None,
|
||||
has_annotations: Annotated[
|
||||
bool | None,
|
||||
Query(description="Filter by annotation presence"),
|
||||
] = None,
|
||||
auto_label_status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by auto-label status"),
|
||||
] = None,
|
||||
batch_id: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by batch ID"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> DocumentListResponse:
|
||||
"""List documents."""
|
||||
# Validate status
|
||||
if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Validate upload_source
|
||||
if upload_source and upload_source not in ("ui", "api"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid upload_source: {upload_source}",
|
||||
)
|
||||
|
||||
# Validate auto_label_status
|
||||
if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid auto_label_status: {auto_label_status}",
|
||||
)
|
||||
|
||||
documents, total = db.get_documents_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
upload_source=upload_source,
|
||||
has_annotations=has_annotations,
|
||||
auto_label_status=auto_label_status,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Get annotation counts and build items
|
||||
items = []
|
||||
for doc in documents:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
items.append(
|
||||
DocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
file_size=doc.file_size,
|
||||
page_count=doc.page_count,
|
||||
status=DocumentStatus(doc.status),
|
||||
auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
|
||||
annotation_count=len(annotations),
|
||||
upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
|
||||
batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
|
||||
can_annotate=can_annotate,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=DocumentStatsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get document statistics",
|
||||
description="Get document count by status.",
|
||||
)
|
||||
async def get_document_stats(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DocumentStatsResponse:
|
||||
"""Get document statistics."""
|
||||
counts = db.count_documents_by_status(admin_token)
|
||||
|
||||
return DocumentStatsResponse(
|
||||
total=sum(counts.values()),
|
||||
pending=counts.get("pending", 0),
|
||||
auto_labeling=counts.get("auto_labeling", 0),
|
||||
labeled=counts.get("labeled", 0),
|
||||
exported=counts.get("exported", 0),
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{document_id}",
|
||||
response_model=DocumentDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Get document detail",
|
||||
description="Get document details with annotations.",
|
||||
)
|
||||
async def get_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> DocumentDetailResponse:
|
||||
"""Get document details."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get annotations
|
||||
raw_annotations = db.get_annotations_for_document(document_id)
|
||||
annotations = [
|
||||
AnnotationItem(
|
||||
annotation_id=str(ann.annotation_id),
|
||||
page_number=ann.page_number,
|
||||
class_id=ann.class_id,
|
||||
class_name=ann.class_name,
|
||||
bbox=BoundingBox(
|
||||
x=ann.bbox_x,
|
||||
y=ann.bbox_y,
|
||||
width=ann.bbox_width,
|
||||
height=ann.bbox_height,
|
||||
),
|
||||
normalized_bbox={
|
||||
"x_center": ann.x_center,
|
||||
"y_center": ann.y_center,
|
||||
"width": ann.width,
|
||||
"height": ann.height,
|
||||
},
|
||||
text_value=ann.text_value,
|
||||
confidence=ann.confidence,
|
||||
source=AnnotationSource(ann.source),
|
||||
created_at=ann.created_at,
|
||||
)
|
||||
for ann in raw_annotations
|
||||
]
|
||||
|
||||
# Generate image URLs
|
||||
image_urls = []
|
||||
for page in range(1, document.page_count + 1):
|
||||
image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")
|
||||
|
||||
# Determine if document can be annotated (not locked)
|
||||
can_annotate = True
|
||||
annotation_lock_until = None
|
||||
if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
|
||||
from datetime import datetime, timezone
|
||||
annotation_lock_until = document.annotation_lock_until
|
||||
can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)
|
||||
|
||||
# Get CSV field values if available
|
||||
csv_field_values = None
|
||||
if hasattr(document, 'csv_field_values') and document.csv_field_values:
|
||||
csv_field_values = document.csv_field_values
|
||||
|
||||
# Get training history (Phase 5)
|
||||
training_history = []
|
||||
training_links = db.get_document_training_tasks(document.document_id)
|
||||
for link in training_links:
|
||||
# Get task details
|
||||
task = db.get_training_task(str(link.task_id))
|
||||
if task:
|
||||
# Build metrics
|
||||
metrics = None
|
||||
if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
training_history.append(
|
||||
TrainingHistoryItem(
|
||||
task_id=str(link.task_id),
|
||||
name=task.name,
|
||||
trained_at=link.created_at,
|
||||
model_metrics=metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return DocumentDetailResponse(
|
||||
document_id=str(document.document_id),
|
||||
filename=document.filename,
|
||||
file_size=document.file_size,
|
||||
content_type=document.content_type,
|
||||
page_count=document.page_count,
|
||||
status=DocumentStatus(document.status),
|
||||
auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
|
||||
auto_label_error=document.auto_label_error,
|
||||
upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
|
||||
batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
|
||||
csv_field_values=csv_field_values,
|
||||
can_annotate=can_annotate,
|
||||
annotation_lock_until=annotation_lock_until,
|
||||
annotations=annotations,
|
||||
image_urls=image_urls,
|
||||
training_history=training_history,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{document_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Delete document",
|
||||
description="Delete a document and its annotations.",
|
||||
)
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> dict:
|
||||
"""Delete a document."""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Delete file
|
||||
file_path = Path(document.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
# Delete images
|
||||
images_dir = storage_config.admin_images_dir / document_id
|
||||
if images_dir.exists():
|
||||
import shutil
|
||||
shutil.rmtree(images_dir)
|
||||
|
||||
# Delete from database
|
||||
db.delete_document(document_id)
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"document_id": document_id,
|
||||
"message": "Document deleted successfully",
|
||||
}
|
||||
|
||||
@router.patch(
|
||||
"/{document_id}/status",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Document not found"},
|
||||
},
|
||||
summary="Update document status",
|
||||
description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
|
||||
)
|
||||
async def update_document_status(
|
||||
document_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str,
|
||||
Query(description="New status"),
|
||||
],
|
||||
) -> dict:
|
||||
"""Update document status.
|
||||
|
||||
When status is set to 'labeled', the annotations are automatically
|
||||
saved to PostgreSQL documents/field_results tables for consistency
|
||||
with CLI auto-label workflow.
|
||||
"""
|
||||
_validate_uuid(document_id, "document_id")
|
||||
|
||||
# Validate status
|
||||
if status not in ("pending", "labeled", "exported"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
document = db.get_document_by_token(document_id, admin_token)
|
||||
if document is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Document not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# If marking as labeled, save annotations to PostgreSQL DocumentDB
|
||||
db_save_result = None
|
||||
if status == "labeled":
|
||||
from src.web.services.db_autolabel import save_manual_annotations_to_document_db
|
||||
|
||||
# Get all annotations for this document
|
||||
annotations = db.get_annotations_for_document(document_id)
|
||||
|
||||
if annotations:
|
||||
db_save_result = save_manual_annotations_to_document_db(
|
||||
document=document,
|
||||
annotations=annotations,
|
||||
db=db,
|
||||
)
|
||||
|
||||
db.update_document_status(document_id, status)
|
||||
|
||||
response = {
|
||||
"status": "updated",
|
||||
"document_id": document_id,
|
||||
"new_status": status,
|
||||
"message": "Document status updated",
|
||||
}
|
||||
|
||||
# Include PostgreSQL save result if applicable
|
||||
if db_save_result:
|
||||
response["document_db_saved"] = db_save_result.get("success", False)
|
||||
response["fields_saved"] = db_save_result.get("fields_saved", 0)
|
||||
|
||||
return response
|
||||
|
||||
return router
|
||||
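A hedged upload example against the endpoint above; host, prefix, and auth header are assumptions, and the file name is illustrative:

```python
import requests

with open("invoice.pdf", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/api/v1/admin/documents",
        params={"auto_label": "true"},
        headers={"Authorization": "Bearer <admin_token>"},
        files={"file": ("invoice.pdf", fh, "application/pdf")},
        timeout=120,
    )
print(resp.json()["document_id"])  # DocumentUploadResponse
```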
184  src/web/api/v1/admin/locks.py  Normal file
@@ -0,0 +1,184 @@
"""
Admin Document Lock Routes

FastAPI endpoints for annotation lock management.
"""

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
    AnnotationLockRequest,
    AnnotationLockResponse,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )


def create_locks_router() -> APIRouter:
    """Create annotation locks router."""
    router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])

    @router.post(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
            409: {"model": ErrorResponse, "description": "Document already locked"},
        },
        summary="Acquire annotation lock",
        description="Acquire a lock on a document to prevent concurrent annotation edits.",
    )
    async def acquire_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Acquire annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Attempt to acquire lock
        updated_doc = db.acquire_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            duration_seconds=request.duration_seconds,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=409,
                detail="Document is already locked. Please try again later.",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=True,
            lock_expires_at=updated_doc.annotation_lock_until,
            message=f"Lock acquired for {request.duration_seconds} seconds",
        )

    @router.delete(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Release annotation lock",
        description="Release the annotation lock on a document.",
    )
    async def release_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        force: Annotated[
            bool,
            Query(description="Force release (admin override)"),
        ] = False,
    ) -> AnnotationLockResponse:
        """Release annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Release lock
        updated_doc = db.release_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            force=force,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=404,
                detail="Failed to release lock",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=False,
            lock_expires_at=None,
            message="Lock released successfully",
        )

    @router.patch(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
            409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
        },
        summary="Extend annotation lock",
        description="Extend an existing annotation lock.",
    )
    async def extend_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Extend annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Attempt to extend lock
        updated_doc = db.extend_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            additional_seconds=request.duration_seconds,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=409,
                detail="Lock doesn't exist or has expired. Please acquire a new lock.",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=True,
            lock_expires_at=updated_doc.annotation_lock_until,
            message=f"Lock extended by {request.duration_seconds} seconds",
        )

    return router
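A hedged sketch of the lock lifecycle against the three routes above; host, prefix, and auth header are assumptions, and 300 seconds is an arbitrary illustrative duration:

```python
import requests

base = "http://localhost:8000/api/v1/admin/documents/<document_id>/lock"
headers = {"Authorization": "Bearer <admin_token>"}

# Acquire, extend, then release the annotation lock.
requests.post(base, headers=headers, json={"duration_seconds": 300}, timeout=30)
requests.patch(base, headers=headers, json={"duration_seconds": 300}, timeout=30)
requests.delete(base, headers=headers, params={"force": "false"}, timeout=30)
```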
622  src/web/api/v1/admin/training.py  Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
Admin Training API Routes
|
||||
|
||||
FastAPI endpoints for training task management and scheduling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
||||
from src.web.schemas.admin import (
|
||||
ExportRequest,
|
||||
ExportResponse,
|
||||
ModelMetrics,
|
||||
TrainingConfig,
|
||||
TrainingDocumentItem,
|
||||
TrainingDocumentsResponse,
|
||||
TrainingHistoryItem,
|
||||
TrainingLogItem,
|
||||
TrainingLogsResponse,
|
||||
TrainingModelItem,
|
||||
TrainingModelsResponse,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
TrainingTaskDetailResponse,
|
||||
TrainingTaskItem,
|
||||
TrainingTaskListResponse,
|
||||
TrainingTaskResponse,
|
||||
TrainingType,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
def create_training_router() -> APIRouter:
|
||||
"""Create training API router."""
|
||||
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
|
||||
|
||||
# =========================================================================
|
||||
# Training Task Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Create training task",
|
||||
description="Create a new training task.",
|
||||
)
|
||||
async def create_training_task(
|
||||
request: TrainingTaskCreate,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a new training task."""
|
||||
# Convert config to dict
|
||||
config_dict = request.config.model_dump() if request.config else {}
|
||||
|
||||
# Create task
|
||||
task_id = db.create_training_task(
|
||||
admin_token=admin_token,
|
||||
name=request.name,
|
||||
task_type=request.task_type.value,
|
||||
description=request.description,
|
||||
config=config_dict,
|
||||
scheduled_at=request.scheduled_at,
|
||||
cron_expression=request.cron_expression,
|
||||
is_recurring=bool(request.cron_expression),
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
|
||||
message="Training task created successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks",
|
||||
response_model=TrainingTaskListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="List training tasks",
|
||||
description="List all training tasks.",
|
||||
)
|
||||
async def list_training_tasks(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingTaskListResponse:
|
||||
"""List training tasks."""
|
||||
# Validate status
|
||||
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
|
||||
if status and status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
tasks, total = db.get_training_tasks_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
items = [
|
||||
TrainingTaskItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
scheduled_at=task.scheduled_at,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
return TrainingTaskListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
tasks=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
response_model=TrainingTaskDetailResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training task detail",
|
||||
description="Get training task details.",
|
||||
)
|
||||
async def get_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskDetailResponse:
|
||||
"""Get training task details."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
return TrainingTaskDetailResponse(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
description=task.description,
|
||||
task_type=TrainingType(task.task_type),
|
||||
status=TrainingStatus(task.status),
|
||||
config=task.config,
|
||||
scheduled_at=task.scheduled_at,
|
||||
cron_expression=task.cron_expression,
|
||||
is_recurring=task.is_recurring,
|
||||
started_at=task.started_at,
|
||||
completed_at=task.completed_at,
|
||||
error_message=task.error_message,
|
||||
result_metrics=task.result_metrics,
|
||||
model_path=task.model_path,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/tasks/{task_id}/cancel",
|
||||
response_model=TrainingTaskResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
|
||||
},
|
||||
summary="Cancel training task",
|
||||
description="Cancel a pending or scheduled training task.",
|
||||
)
|
||||
async def cancel_training_task(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Cancel a training task."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
# Verify ownership
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Check if can be cancelled
|
||||
if task.status not in ("pending", "scheduled"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Cannot cancel task with status: {task.status}",
|
||||
)
|
||||
|
||||
# Cancel task
|
||||
success = db.cancel_training_task(task_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to cancel training task",
|
||||
)
|
||||
|
||||
return TrainingTaskResponse(
|
||||
task_id=task_id,
|
||||
status=TrainingStatus.CANCELLED,
|
||||
message="Training task cancelled successfully",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/logs",
|
||||
response_model=TrainingLogsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||
},
|
||||
summary="Get training logs",
|
||||
description="Get training task logs.",
|
||||
)
|
||||
async def get_training_logs(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=500, description="Maximum logs to return"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingLogsResponse:
|
||||
"""Get training logs."""
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
# Verify ownership
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Get logs
|
||||
logs = db.get_training_logs(task_id, limit, offset)
|
||||
|
||||
items = [
|
||||
TrainingLogItem(
|
||||
level=log.level,
|
||||
message=log.message,
|
||||
details=log.details,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return TrainingLogsResponse(
|
||||
task_id=task_id,
|
||||
logs=items,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Phase 4: Training Data Management
|
||||
# =========================================================================
|
||||
|
||||
@router.get(
|
||||
"/documents",
|
||||
response_model=TrainingDocumentsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get documents for training",
|
||||
description="Get labeled documents available for training with filtering options.",
|
||||
)
|
||||
async def get_training_documents(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
has_annotations: Annotated[
|
||||
bool,
|
||||
Query(description="Only include documents with annotations"),
|
||||
] = True,
|
||||
min_annotation_count: Annotated[
|
||||
int | None,
|
||||
Query(ge=1, description="Minimum annotation count"),
|
||||
] = None,
|
||||
exclude_used_in_training: Annotated[
|
||||
bool,
|
||||
Query(description="Exclude documents already used in training"),
|
||||
] = False,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingDocumentsResponse:
|
||||
"""Get documents available for training."""
|
||||
# Get documents
|
||||
documents, total = db.get_documents_for_training(
|
||||
admin_token=admin_token,
|
||||
status="labeled",
|
||||
has_annotations=has_annotations,
|
||||
min_annotation_count=min_annotation_count,
|
||||
exclude_used_in_training=exclude_used_in_training,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Build response items with annotation details and training history
|
||||
items = []
|
||||
for doc in documents:
|
||||
# Get annotations for this document
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
# Count annotations by source
|
||||
sources = {"manual": 0, "auto": 0}
|
||||
for ann in annotations:
|
||||
if ann.source in sources:
|
||||
sources[ann.source] += 1
|
||||
|
||||
# Get training history
|
||||
training_links = db.get_document_training_tasks(doc.document_id)
|
||||
used_in_training = [str(link.task_id) for link in training_links]
|
||||
|
||||
items.append(
|
||||
TrainingDocumentItem(
|
||||
document_id=str(doc.document_id),
|
||||
filename=doc.filename,
|
||||
annotation_count=len(annotations),
|
||||
annotation_sources=sources,
|
||||
used_in_training=used_in_training,
|
||||
last_modified=doc.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingDocumentsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
documents=items,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models/{task_id}/download",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
404: {"model": ErrorResponse, "description": "Model not found"},
|
||||
},
|
||||
summary="Download trained model",
|
||||
description="Download trained model weights file.",
|
||||
)
|
||||
async def download_model(
|
||||
task_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
):
|
||||
"""Download trained model."""
|
||||
from fastapi.responses import FileResponse
|
||||
from pathlib import Path
|
||||
|
||||
_validate_uuid(task_id, "task_id")
|
||||
|
||||
# Verify ownership
|
||||
task = db.get_training_task_by_token(task_id, admin_token)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Training task not found or does not belong to this token",
|
||||
)
|
||||
|
||||
# Check if model exists
|
||||
if not task.model_path:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not available for this task",
|
||||
)
|
||||
|
||||
model_path = Path(task.model_path)
|
||||
if not model_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model file not found on disk",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=str(model_path),
|
||||
media_type="application/octet-stream",
|
||||
filename=f"{task.name}_model.pt",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=TrainingModelsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Get trained models",
|
||||
description="Get list of trained models with metrics and download links.",
|
||||
)
|
||||
async def get_training_models(
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (completed, failed, etc.)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Page size"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Offset"),
|
||||
] = 0,
|
||||
) -> TrainingModelsResponse:
|
||||
"""Get list of trained models."""
|
||||
# Get training tasks
|
||||
tasks, total = db.get_training_tasks_by_token(
|
||||
admin_token=admin_token,
|
||||
status=status if status else "completed",
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Build response items
|
||||
items = []
|
||||
for task in tasks:
|
||||
# Build metrics
|
||||
metrics = ModelMetrics(
|
||||
mAP=task.metrics_mAP,
|
||||
precision=task.metrics_precision,
|
||||
recall=task.metrics_recall,
|
||||
)
|
||||
|
||||
# Build download URL if model exists
|
||||
download_url = None
|
||||
if task.model_path and task.status == "completed":
|
||||
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
|
||||
|
||||
items.append(
|
||||
TrainingModelItem(
|
||||
task_id=str(task.task_id),
|
||||
name=task.name,
|
||||
status=TrainingStatus(task.status),
|
||||
document_count=task.document_count,
|
||||
created_at=task.created_at,
|
||||
completed_at=task.completed_at,
|
||||
metrics=metrics,
|
||||
model_path=task.model_path,
|
||||
download_url=download_url,
|
||||
)
|
||||
)
|
||||
|
||||
return TrainingModelsResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
models=items,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Export Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post(
|
||||
"/export",
|
||||
response_model=ExportResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||
},
|
||||
summary="Export annotations",
|
||||
description="Export annotations in YOLO format for training.",
|
||||
)
|
||||
async def export_annotations(
|
||||
request: ExportRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
db: AdminDBDep,
|
||||
) -> ExportResponse:
|
||||
"""Export annotations for training."""
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
# Validate format
|
||||
if request.format not in ("yolo", "coco", "voc"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported export format: {request.format}",
|
||||
)
|
||||
|
||||
# Get labeled documents
|
||||
documents = db.get_labeled_documents_for_export(admin_token)
|
||||
|
||||
if not documents:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No labeled documents available for export",
|
||||
)
|
||||
|
||||
# Create export directory
|
||||
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# YOLO format directories
|
||||
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
|
||||
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Calculate train/val split
|
||||
total_docs = len(documents)
|
||||
train_count = int(total_docs * request.split_ratio)
|
||||
train_docs = documents[:train_count]
|
||||
val_docs = documents[train_count:]
|
||||
|
||||
total_images = 0
|
||||
total_annotations = 0
|
||||
|
||||
# Export documents
|
||||
for split, docs in [("train", train_docs), ("val", val_docs)]:
|
||||
for doc in docs:
|
||||
# Get annotations
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
|
||||
if not annotations:
|
||||
continue
|
||||
|
||||
# Export each page
|
||||
for page_num in range(1, doc.page_count + 1):
|
||||
page_annotations = [a for a in annotations if a.page_number == page_num]
|
||||
|
||||
if not page_annotations and not request.include_images:
|
||||
continue
|
||||
|
||||
# Copy image
|
||||
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
|
||||
if not src_image.exists():
|
||||
continue
|
||||
|
||||
image_name = f"{doc.document_id}_page{page_num}.png"
|
||||
dst_image = export_dir / "images" / split / image_name
|
||||
shutil.copy(src_image, dst_image)
|
||||
total_images += 1
|
||||
|
||||
# Write YOLO label file
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
# YOLO format: class_id x_center y_center width height
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
# Create data.yaml
|
||||
from src.data.admin_models import FIELD_CLASSES
|
||||
|
||||
yaml_content = f"""# Auto-generated YOLO dataset config
|
||||
path: {export_dir.absolute()}
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
nc: {len(FIELD_CLASSES)}
|
||||
names: {list(FIELD_CLASSES.values())}
|
||||
"""
|
||||
(export_dir / "data.yaml").write_text(yaml_content)
|
||||
|
||||
return ExportResponse(
|
||||
status="completed",
|
||||
export_path=str(export_dir),
|
||||
total_images=total_images,
|
||||
total_annotations=total_annotations,
|
||||
train_count=len(train_docs),
|
||||
val_count=len(val_docs),
|
||||
message=f"Exported {total_images} images with {total_annotations} annotations",
|
||||
)
|
||||
|
||||
return router
|
||||
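Note (illustration, not part of this commit): the export loop above writes one YOLO line per annotation in the form `class_id x_center y_center width height`. A minimal sketch of that formatting with invented values — the real class ids come from FIELD_CLASSES:

```python
# Hypothetical annotation values, purely to illustrate the YOLO label-line format used above.
ann = {"class_id": 3, "x_center": 0.512345, "y_center": 0.104211, "width": 0.21, "height": 0.0315}
line = f"{ann['class_id']} {ann['x_center']:.6f} {ann['y_center']:.6f} {ann['width']:.6f} {ann['height']:.6f}"
print(line)  # -> 3 0.512345 0.104211 0.210000 0.031500
```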
0
src/web/api/v1/batch/__init__.py
Normal file
236
src/web/api/v1/batch/routes.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Batch Upload API Routes
|
||||
|
||||
Endpoints for batch uploading documents via ZIP files with CSV metadata.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db
|
||||
from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
|
||||
from src.web.workers.batch_queue import BatchTask, get_batch_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_batch(
|
||||
file: UploadFile = File(...),
|
||||
upload_source: str = Form(default="ui"),
|
||||
async_mode: bool = Form(default=True),
|
||||
auto_label: bool = Form(default=True),
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
) -> dict:
|
||||
"""Upload a batch of documents via ZIP file.
|
||||
|
||||
The ZIP file can contain:
|
||||
- Multiple PDF files
|
||||
- Optional CSV file with field values for auto-labeling
|
||||
|
||||
CSV format:
|
||||
- Required column: DocumentId (matches PDF filename without extension)
|
||||
- Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
|
||||
OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number
|
||||
|
||||
Args:
|
||||
file: ZIP file upload
|
||||
upload_source: Upload source (ui or api)
async_mode: If True (default), queue the upload and return 202 immediately; if False, process synchronously
auto_label: If True (default), trigger auto-labeling for the uploaded documents
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
|
||||
Returns:
|
||||
Batch upload result with batch_id and status
|
||||
"""
|
||||
if not file.filename.lower().endswith('.zip'):
|
||||
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
|
||||
|
||||
# Check compressed size
|
||||
if file.size and file.size > MAX_COMPRESSED_SIZE:
|
||||
max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size exceeds {max_mb:.0f}MB limit"
|
||||
)
|
||||
|
||||
try:
|
||||
# Read file content
|
||||
zip_content = await file.read()
|
||||
|
||||
# Additional security validation before processing
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
|
||||
# Quick validation of ZIP structure
|
||||
test_zip.testzip()
|
||||
except zipfile.BadZipFile:
|
||||
raise HTTPException(status_code=400, detail="Invalid ZIP file format")
|
||||
|
||||
if async_mode:
|
||||
# Async mode: Queue task and return immediately
|
||||
from uuid import uuid4
|
||||
|
||||
batch_id = uuid4()
|
||||
|
||||
# Create batch task for background processing
|
||||
task = BatchTask(
|
||||
batch_id=batch_id,
|
||||
admin_token=admin_token,
|
||||
zip_content=zip_content,
|
||||
zip_filename=file.filename,
|
||||
upload_source=upload_source,
|
||||
auto_label=auto_label,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Submit to queue
|
||||
queue = get_batch_queue()
|
||||
if not queue.submit(task):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Processing queue is full. Please try again later."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload queued: batch_id={batch_id}, "
|
||||
f"filename={file.filename}, async_mode=True"
|
||||
)
|
||||
|
||||
# Return 202 Accepted with batch_id and status URL
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "accepted",
|
||||
"batch_id": str(batch_id),
|
||||
"message": "Batch upload queued for processing",
|
||||
"status_url": f"/api/v1/admin/batch/status/{batch_id}",
|
||||
"queue_depth": queue.get_queue_depth(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Sync mode: Process immediately and return results
|
||||
service = BatchUploadService(admin_db)
|
||||
result = service.process_zip_upload(
|
||||
admin_token=admin_token,
|
||||
zip_filename=file.filename,
|
||||
zip_content=zip_content,
|
||||
upload_source=upload_source,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Batch upload completed: batch_id={result.get('batch_id')}, "
|
||||
f"status={result.get('status')}, files={result.get('successful_files')}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch upload: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to process batch upload. Please contact support."
|
||||
)
|
||||
|
||||
|
||||
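Illustrative client call for the upload endpoint above (not part of this commit); the base URL, admin token, and ZIP path are placeholders, and the form field names mirror the upload_batch signature:

```python
import requests  # assumed available in the client environment

BASE = "http://localhost:8000"   # assumption: local dev server
TOKEN = "<admin-token>"          # hypothetical admin token

with open("invoices.zip", "rb") as fh:
    resp = requests.post(
        f"{BASE}/api/v1/admin/batch/upload",
        headers={"X-Admin-Token": TOKEN},
        files={"file": ("invoices.zip", fh, "application/zip")},
        data={"upload_source": "api", "async_mode": "true", "auto_label": "true"},
    )
resp.raise_for_status()
body = resp.json()  # async mode responds 202 with batch_id and status_url
print(requests.get(f"{BASE}{body['status_url']}", headers={"X-Admin-Token": TOKEN}).json())
```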
@router.get("/status/{batch_id}")
|
||||
async def get_batch_status(
|
||||
batch_id: str,
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
) -> dict:
|
||||
"""Get batch upload status and file processing details.
|
||||
|
||||
Args:
|
||||
batch_id: Batch upload ID
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
|
||||
Returns:
|
||||
Batch status with file processing details
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
batch_uuid = UUID(batch_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid batch ID format")
|
||||
|
||||
# Check batch exists and verify ownership
|
||||
batch = admin_db.get_batch_upload(batch_uuid)
|
||||
if not batch:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
|
||||
# CRITICAL: Verify ownership
|
||||
if batch.admin_token != admin_token:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You do not have access to this batch"
|
||||
)
|
||||
|
||||
# Now safe to return details
|
||||
service = BatchUploadService(admin_db)
|
||||
result = service.get_batch_status(batch_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_batch_uploads(
|
||||
admin_token: Annotated[str, Depends(validate_admin_token)] = None,
|
||||
admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
"""List batch uploads for the current admin token.
|
||||
|
||||
Args:
|
||||
admin_token: Admin authentication token
|
||||
admin_db: Admin database interface
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of batch uploads
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if limit < 1 or limit > 100:
|
||||
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
|
||||
if offset < 0:
|
||||
raise HTTPException(status_code=400, detail="Offset must be non-negative")
|
||||
|
||||
# Get batch uploads filtered by admin token
|
||||
batches, total = admin_db.get_batch_uploads_by_token(
|
||||
admin_token=admin_token,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return {
|
||||
"batches": [
|
||||
{
|
||||
"batch_id": str(b.batch_id),
|
||||
"filename": b.filename,
|
||||
"status": b.status,
|
||||
"total_files": b.total_files,
|
||||
"successful_files": b.successful_files,
|
||||
"failed_files": b.failed_files,
|
||||
"created_at": b.created_at.isoformat() if b.created_at else None,
|
||||
"completed_at": b.completed_at.isoformat() if b.completed_at else None,
|
||||
}
|
||||
for b in batches
|
||||
],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
16
src/web/api/v1/public/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Public API v1
|
||||
|
||||
Customer-facing endpoints for inference, async processing, and labeling.
|
||||
"""
|
||||
|
||||
from src.web.api.v1.public.inference import create_inference_router
|
||||
from src.web.api.v1.public.async_api import create_async_router, set_async_service
|
||||
from src.web.api.v1.public.labeling import create_labeling_router
|
||||
|
||||
__all__ = [
|
||||
"create_inference_router",
|
||||
"create_async_router",
|
||||
"set_async_service",
|
||||
"create_labeling_router",
|
||||
]
|
||||
372
src/web/api/v1/public/async_api.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
Async API Routes
|
||||
|
||||
FastAPI endpoints for async invoice processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||
|
||||
from src.web.dependencies import (
|
||||
ApiKeyDep,
|
||||
AsyncDBDep,
|
||||
PollRateLimitDep,
|
||||
SubmitRateLimitDep,
|
||||
)
|
||||
from src.web.schemas.inference import (
|
||||
AsyncRequestItem,
|
||||
AsyncRequestsListResponse,
|
||||
AsyncResultResponse,
|
||||
AsyncStatus,
|
||||
AsyncStatusResponse,
|
||||
AsyncSubmitResponse,
|
||||
DetectionResult,
|
||||
InferenceResult,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
|
||||
def _validate_request_id(request_id: str) -> None:
|
||||
"""Validate that request_id is a valid UUID format."""
|
||||
try:
|
||||
UUID(request_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request ID format. Must be a valid UUID.",
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global reference to async processing service (set during app startup)
|
||||
_async_service = None
|
||||
|
||||
|
||||
def set_async_service(service) -> None:
|
||||
"""Set the async processing service instance."""
|
||||
global _async_service
|
||||
_async_service = service
|
||||
|
||||
|
||||
def get_async_service():
|
||||
"""Get the async processing service instance."""
|
||||
if _async_service is None:
|
||||
raise RuntimeError("AsyncProcessingService not initialized")
|
||||
return _async_service
|
||||
|
||||
|
||||
def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
|
||||
"""Create async API router."""
|
||||
router = APIRouter(prefix="/async", tags=["Async Processing"])
|
||||
|
||||
@router.post(
|
||||
"/submit",
|
||||
response_model=AsyncSubmitResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
|
||||
503: {"model": ErrorResponse, "description": "Queue full"},
|
||||
},
|
||||
summary="Submit PDF for async processing",
|
||||
description="Submit a PDF or image file for asynchronous processing. "
|
||||
"Returns a request_id that can be used to poll for results.",
|
||||
)
|
||||
async def submit_document(
|
||||
api_key: SubmitRateLimitDep,
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> AsyncSubmitResponse:
|
||||
"""Submit a document for async processing."""
|
||||
# Validate filename
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required")
|
||||
|
||||
# Validate file extension
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {file_ext}. "
|
||||
f"Allowed: {', '.join(allowed_extensions)}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Check file size (get from config via service)
|
||||
service = get_async_service()
|
||||
max_size = service._async_config.max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File too large. Maximum size: "
|
||||
f"{service._async_config.max_file_size_mb}MB",
|
||||
)
|
||||
|
||||
# Submit request
|
||||
result = service.submit_request(
|
||||
api_key=api_key,
|
||||
file_content=content,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
if "queue" in (result.error or "").lower():
|
||||
raise HTTPException(status_code=503, detail=result.error)
|
||||
raise HTTPException(status_code=500, detail=result.error)
|
||||
|
||||
return AsyncSubmitResponse(
|
||||
status="accepted",
|
||||
message="Request submitted for processing",
|
||||
request_id=result.request_id,
|
||||
estimated_wait_seconds=result.estimated_wait_seconds,
|
||||
poll_url=f"/api/v1/async/status/{result.request_id}",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/status/{request_id}",
|
||||
response_model=AsyncStatusResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get request status",
|
||||
description="Get the current processing status of an async request.",
|
||||
)
|
||||
async def get_status(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncStatusResponse:
|
||||
"""Get the status of an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Get queue position for pending requests
|
||||
position = None
|
||||
if request.status == "pending":
|
||||
position = db.get_queue_position(request_id)
|
||||
|
||||
# Build result URL for completed requests
|
||||
result_url = None
|
||||
if request.status == "completed":
|
||||
result_url = f"/api/v1/async/result/{request_id}"
|
||||
|
||||
return AsyncStatusResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
filename=request.filename,
|
||||
created_at=request.created_at,
|
||||
started_at=request.started_at,
|
||||
completed_at=request.completed_at,
|
||||
position_in_queue=position,
|
||||
error_message=request.error_message,
|
||||
result_url=result_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/result/{request_id}",
|
||||
response_model=AsyncResultResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Request not completed"},
|
||||
429: {"model": ErrorResponse, "description": "Polling too frequently"},
|
||||
},
|
||||
summary="Get extraction results",
|
||||
description="Get the extraction results for a completed async request.",
|
||||
)
|
||||
async def get_result(
|
||||
request_id: str,
|
||||
api_key: PollRateLimitDep,
|
||||
db: AsyncDBDep,
|
||||
) -> AsyncResultResponse:
|
||||
"""Get the results of a completed async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database (validates API key ownership)
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Check if completed or failed
|
||||
if request.status not in ("completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Request not yet completed. Current status: {request.status}",
|
||||
)
|
||||
|
||||
# Build inference result from stored data
|
||||
inference_result = None
|
||||
if request.result:
|
||||
# Convert detections to DetectionResult objects
|
||||
detections = []
|
||||
for d in request.result.get("detections", []):
|
||||
detections.append(DetectionResult(
|
||||
field=d.get("field", ""),
|
||||
confidence=d.get("confidence", 0.0),
|
||||
bbox=d.get("bbox", [0, 0, 0, 0]),
|
||||
))
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=request.result.get("document_id", str(request.request_id)[:8]),
|
||||
success=request.result.get("success", False),
|
||||
document_type=request.result.get("document_type", "invoice"),
|
||||
fields=request.result.get("fields", {}),
|
||||
confidence=request.result.get("confidence", {}),
|
||||
detections=detections,
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
errors=request.result.get("errors", []),
|
||||
)
|
||||
|
||||
# Build visualization URL
|
||||
viz_url = None
|
||||
if request.visualization_path:
|
||||
viz_url = f"/api/v1/results/{request.visualization_path}"
|
||||
|
||||
return AsyncResultResponse(
|
||||
request_id=str(request.request_id),
|
||||
status=AsyncStatus(request.status),
|
||||
processing_time_ms=request.processing_time_ms or 0.0,
|
||||
result=inference_result,
|
||||
visualization_url=viz_url,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/requests",
|
||||
response_model=AsyncRequestsListResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
},
|
||||
summary="List requests",
|
||||
description="List all async requests for the authenticated API key.",
|
||||
)
|
||||
async def list_requests(
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
status: Annotated[
|
||||
str | None,
|
||||
Query(description="Filter by status (pending, processing, completed, failed)"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(ge=1, le=100, description="Maximum number of results"),
|
||||
] = 20,
|
||||
offset: Annotated[
|
||||
int,
|
||||
Query(ge=0, description="Pagination offset"),
|
||||
] = 0,
|
||||
) -> AsyncRequestsListResponse:
|
||||
"""List all requests for the authenticated API key."""
|
||||
# Validate status filter
|
||||
if status and status not in ("pending", "processing", "completed", "failed"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status filter: {status}. "
|
||||
"Must be one of: pending, processing, completed, failed",
|
||||
)
|
||||
|
||||
# Get requests from database
|
||||
requests, total = db.get_requests_by_api_key(
|
||||
api_key=api_key,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Convert to response items
|
||||
items = [
|
||||
AsyncRequestItem(
|
||||
request_id=str(r.request_id),
|
||||
status=AsyncStatus(r.status),
|
||||
filename=r.filename,
|
||||
file_size=r.file_size,
|
||||
created_at=r.created_at,
|
||||
completed_at=r.completed_at,
|
||||
)
|
||||
for r in requests
|
||||
]
|
||||
|
||||
return AsyncRequestsListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
requests=items,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/requests/{request_id}",
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Invalid API key"},
|
||||
404: {"model": ErrorResponse, "description": "Request not found"},
|
||||
409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
|
||||
},
|
||||
summary="Cancel/delete request",
|
||||
description="Cancel a pending request or delete a completed/failed request.",
|
||||
)
|
||||
async def delete_request(
|
||||
request_id: str,
|
||||
api_key: ApiKeyDep,
|
||||
db: AsyncDBDep,
|
||||
) -> dict:
|
||||
"""Delete or cancel an async request."""
|
||||
# Validate UUID format
|
||||
_validate_request_id(request_id)
|
||||
|
||||
# Get request from database
|
||||
request = db.get_request_by_api_key(request_id, api_key)
|
||||
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Request not found or does not belong to this API key",
|
||||
)
|
||||
|
||||
# Cannot delete processing requests
|
||||
if request.status == "processing":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete a request that is currently processing",
|
||||
)
|
||||
|
||||
# Delete from database (will cascade delete related records)
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"DELETE FROM async_requests WHERE request_id = %s",
|
||||
(request_id,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
"status": "deleted",
|
||||
"request_id": request_id,
|
||||
"message": "Request deleted successfully",
|
||||
}
|
||||
|
||||
return router
|
||||
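For orientation (not part of the commit), a sketch of the submit → poll → result flow defined above. The API-key header name below is an assumption; the real dependency lives in src.web.dependencies:

```python
import time
import requests

BASE = "http://localhost:8000"        # assumption: local dev server
HEADERS = {"X-API-Key": "<api-key>"}  # hypothetical header name and value

with open("invoice.pdf", "rb") as fh:
    submit = requests.post(f"{BASE}/api/v1/async/submit", headers=HEADERS,
                           files={"file": ("invoice.pdf", fh, "application/pdf")})
submit.raise_for_status()
request_id = submit.json()["request_id"]

# Poll the status endpoint until the request completes or fails.
while True:
    status = requests.get(f"{BASE}/api/v1/async/status/{request_id}", headers=HEADERS).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(2)  # respect the server-side minimum poll interval

print(requests.get(f"{BASE}/api/v1/async/result/{request_id}", headers=HEADERS).json())
```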
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
API Routes
|
||||
Inference API Routes
|
||||
|
||||
FastAPI route definitions for the inference API.
|
||||
"""
|
||||
@@ -15,23 +15,22 @@ from typing import TYPE_CHECKING
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .schemas import (
|
||||
BatchInferenceResponse,
|
||||
from src.web.schemas.inference import (
|
||||
DetectionResult,
|
||||
ErrorResponse,
|
||||
HealthResponse,
|
||||
InferenceResponse,
|
||||
InferenceResult,
|
||||
)
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .services import InferenceService
|
||||
from .config import StorageConfig
|
||||
from src.web.services import InferenceService
|
||||
from src.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_api_router(
|
||||
def create_inference_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
203
src/web/api/v1/public/labeling.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Labeling API Routes
|
||||
|
||||
FastAPI endpoints for pre-labeling documents with expected field values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.web.schemas.labeling import PreLabelResponse
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.web.services import InferenceService
|
||||
from src.web.config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage directory for pre-label uploads (legacy, now uses storage_config)
|
||||
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
|
||||
|
||||
|
||||
def _convert_pdf_to_images(
|
||||
document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
|
||||
) -> None:
|
||||
"""Convert PDF pages to images for annotation."""
|
||||
import fitz
|
||||
|
||||
doc_images_dir = images_dir / document_id
|
||||
doc_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = pdf_doc[page_num]
|
||||
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
image_path = doc_images_dir / f"page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
pdf_doc.close()
|
||||
|
||||
|
||||
def get_admin_db() -> AdminDB:
|
||||
"""Get admin database instance."""
|
||||
return AdminDB()
|
||||
|
||||
|
||||
def create_labeling_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with labeling endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["labeling"])
|
||||
|
||||
# Ensure upload directory exists
|
||||
PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@router.post(
|
||||
"/pre-label",
|
||||
response_model=PreLabelResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file or field values"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
summary="Pre-label document with expected values",
|
||||
description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
|
||||
)
|
||||
async def pre_label(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
field_values: str = Form(
|
||||
...,
|
||||
description="JSON object with expected field values. "
|
||||
"Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
|
||||
"Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
|
||||
),
|
||||
db: AdminDB = Depends(get_admin_db),
|
||||
) -> PreLabelResponse:
|
||||
"""
|
||||
Upload a document with expected field values for pre-labeling.
|
||||
|
||||
Returns document_id which can be used to retrieve results later.
|
||||
|
||||
Example field_values JSON:
|
||||
```json
|
||||
{
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": "1500.00",
|
||||
"Bankgiro": "123-4567",
|
||||
"OCR": "1234567890"
|
||||
}
|
||||
```
|
||||
"""
|
||||
# Parse field_values JSON
|
||||
try:
|
||||
expected_values = json.loads(field_values)
|
||||
if not isinstance(expected_values, dict):
|
||||
raise ValueError("field_values must be a JSON object")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON in field_values: {e}",
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read file",
|
||||
)
|
||||
|
||||
# Get page count for PDF
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
pdf_doc = fitz.open(stream=content, filetype="pdf")
|
||||
page_count = len(pdf_doc)
|
||||
pdf_doc.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get PDF page count: {e}")
|
||||
|
||||
# Create document record with field_values
|
||||
document_id = db.create_document(
|
||||
filename=file.filename,
|
||||
file_size=len(content),
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
file_path="", # Will update after saving
|
||||
page_count=page_count,
|
||||
upload_source="api",
|
||||
csv_field_values=expected_values,
|
||||
)
|
||||
|
||||
# Save file to admin uploads
|
||||
file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
|
||||
try:
|
||||
file_path.write_bytes(content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save file",
|
||||
)
|
||||
|
||||
# Update file path in database
|
||||
db.update_document_file_path(document_id, str(file_path))
|
||||
|
||||
# Convert PDF to images for annotation UI
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
_convert_pdf_to_images(
|
||||
document_id, content, page_count,
|
||||
storage_config.admin_images_dir, storage_config.dpi
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert PDF to images: {e}")
|
||||
|
||||
# Trigger auto-labeling
|
||||
db.update_document_status(
|
||||
document_id=document_id,
|
||||
status="auto_labeling",
|
||||
auto_label_status="pending",
|
||||
)
|
||||
|
||||
logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")
|
||||
|
||||
return PreLabelResponse(document_id=document_id)
|
||||
|
||||
return router
|
||||
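A usage sketch for the /api/v1/pre-label endpoint above (not part of the commit); the server URL and file path are assumptions, and the field names follow the docstring example:

```python
import json
import requests

field_values = {"InvoiceNumber": "12345", "Amount": "1500.00", "Bankgiro": "123-4567", "OCR": "1234567890"}

with open("invoice.pdf", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/api/v1/pre-label",   # assumption: local dev server
        files={"file": ("invoice.pdf", fh, "application/pdf")},
        data={"field_values": json.dumps(field_values)},
    )
resp.raise_for_status()
print(resp.json()["document_id"])  # use this id to retrieve results later
```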
158
src/web/app.py
@@ -17,8 +17,39 @@ from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from .config import AppConfig, default_config
|
||||
from .routes import create_api_router
|
||||
from .services import InferenceService
|
||||
from src.web.services import InferenceService
|
||||
|
||||
# Public API imports
|
||||
from src.web.api.v1.public import (
|
||||
create_inference_router,
|
||||
create_async_router,
|
||||
set_async_service,
|
||||
create_labeling_router,
|
||||
)
|
||||
|
||||
# Async processing imports
|
||||
from src.data.async_request_db import AsyncRequestDB
|
||||
from src.web.workers.async_queue import AsyncTaskQueue
|
||||
from src.web.services.async_processing import AsyncProcessingService
|
||||
from src.web.dependencies import init_dependencies
|
||||
from src.web.core.rate_limiter import RateLimiter
|
||||
|
||||
# Admin API imports
|
||||
from src.web.api.v1.admin import (
|
||||
create_annotation_router,
|
||||
create_auth_router,
|
||||
create_documents_router,
|
||||
create_locks_router,
|
||||
create_training_router,
|
||||
)
|
||||
from src.web.core.scheduler import start_scheduler, stop_scheduler
|
||||
from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
|
||||
|
||||
# Batch upload imports
|
||||
from src.web.api.v1.batch.routes import router as batch_upload_router
|
||||
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from src.web.services.batch_upload import BatchUploadService
|
||||
from src.data.admin_db import AdminDB
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -44,11 +75,38 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
# Create async processing components
|
||||
async_db = AsyncRequestDB()
|
||||
rate_limiter = RateLimiter(async_db)
|
||||
task_queue = AsyncTaskQueue(
|
||||
max_size=config.async_processing.queue_max_size,
|
||||
worker_count=config.async_processing.worker_count,
|
||||
)
|
||||
async_service = AsyncProcessingService(
|
||||
inference_service=inference_service,
|
||||
db=async_db,
|
||||
queue=task_queue,
|
||||
rate_limiter=rate_limiter,
|
||||
async_config=config.async_processing,
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
# Initialize dependencies for FastAPI
|
||||
init_dependencies(async_db, rate_limiter)
|
||||
set_async_service(async_service)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting Invoice Inference API...")
|
||||
|
||||
# Initialize database tables
|
||||
try:
|
||||
async_db.create_tables()
|
||||
logger.info("Async database tables ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize async database: {e}")
|
||||
|
||||
# Initialize inference service on startup
|
||||
try:
|
||||
inference_service.initialize()
|
||||
@@ -57,10 +115,75 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
# Continue anyway - service will retry on first request
|
||||
|
||||
# Start async processing service
|
||||
try:
|
||||
async_service.start()
|
||||
logger.info("Async processing service started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start async processing: {e}")
|
||||
|
||||
# Start batch upload queue
|
||||
try:
|
||||
admin_db = AdminDB()
|
||||
batch_service = BatchUploadService(admin_db)
|
||||
init_batch_queue(batch_service)
|
||||
logger.info("Batch upload queue started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start batch upload queue: {e}")
|
||||
|
||||
# Start training scheduler
|
||||
try:
|
||||
start_scheduler()
|
||||
logger.info("Training scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start training scheduler: {e}")
|
||||
|
||||
# Start auto-label scheduler
|
||||
try:
|
||||
start_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start autolabel scheduler: {e}")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down Invoice Inference API...")
|
||||
|
||||
# Stop auto-label scheduler
|
||||
try:
|
||||
stop_autolabel_scheduler()
|
||||
logger.info("AutoLabel scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping autolabel scheduler: {e}")
|
||||
|
||||
# Stop training scheduler
|
||||
try:
|
||||
stop_scheduler()
|
||||
logger.info("Training scheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping training scheduler: {e}")
|
||||
|
||||
# Stop batch upload queue
|
||||
try:
|
||||
shutdown_batch_queue()
|
||||
logger.info("Batch upload queue stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping batch upload queue: {e}")
|
||||
|
||||
# Stop async processing service
|
||||
try:
|
||||
async_service.stop(timeout=30.0)
|
||||
logger.info("Async processing service stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping async service: {e}")
|
||||
|
||||
# Close database connection
|
||||
try:
|
||||
async_db.close()
|
||||
logger.info("Database connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing database: {e}")
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Invoice Field Extraction API",
|
||||
@@ -106,9 +229,34 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
name="results",
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
api_router = create_api_router(inference_service, config.storage)
|
||||
app.include_router(api_router)
|
||||
# Include public API routes
|
||||
inference_router = create_inference_router(inference_service, config.storage)
|
||||
app.include_router(inference_router)
|
||||
|
||||
async_router = create_async_router(config.storage.allowed_extensions)
|
||||
app.include_router(async_router, prefix="/api/v1")
|
||||
|
||||
labeling_router = create_labeling_router(inference_service, config.storage)
|
||||
app.include_router(labeling_router)
|
||||
|
||||
# Include admin API routes
|
||||
auth_router = create_auth_router()
|
||||
app.include_router(auth_router, prefix="/api/v1")
|
||||
|
||||
documents_router = create_documents_router(config.storage)
|
||||
app.include_router(documents_router, prefix="/api/v1")
|
||||
|
||||
locks_router = create_locks_router()
|
||||
app.include_router(locks_router, prefix="/api/v1")
|
||||
|
||||
annotation_router = create_annotation_router()
|
||||
app.include_router(annotation_router, prefix="/api/v1")
|
||||
|
||||
training_router = create_training_router()
|
||||
app.include_router(training_router, prefix="/api/v1")
|
||||
|
||||
# Include batch upload routes
|
||||
app.include_router(batch_upload_router)
|
||||
|
||||
# Root endpoint - serve HTML UI
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
|
||||
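As a rough sketch (not part of the commit) of how the factory above is typically served; the module path and port are assumptions:

```python
import uvicorn
from src.web.app import create_app

app = create_app()  # passing no AppConfig presumably falls back to default_config
uvicorn.run(app, host="0.0.0.0", port=8000)
```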
@@ -8,6 +8,8 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.config import DEFAULT_DPI, PATHS
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelConfig:
|
||||
@@ -16,7 +18,7 @@ class ModelConfig:
|
||||
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
|
||||
confidence_threshold: float = 0.5
|
||||
use_gpu: bool = True
|
||||
dpi: int = 150
|
||||
dpi: int = DEFAULT_DPI
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -32,19 +34,59 @@ class ServerConfig:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageConfig:
|
||||
"""File storage configuration."""
|
||||
"""File storage configuration.
|
||||
|
||||
Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
|
||||
directly in raw_pdfs directory. This ensures consistency with CLI autolabel
|
||||
and avoids storing duplicate files.
|
||||
"""
|
||||
|
||||
upload_dir: Path = Path("uploads")
|
||||
result_dir: Path = Path("results")
|
||||
admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
|
||||
admin_images_dir: Path = Path("data/admin_images")
|
||||
max_file_size_mb: int = 50
|
||||
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
|
||||
dpi: int = DEFAULT_DPI
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
|
||||
object.__setattr__(self, "result_dir", Path(self.result_dir))
|
||||
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
|
||||
object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
|
||||
self.upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncConfig:
|
||||
"""Async processing configuration."""
|
||||
|
||||
# Queue settings
|
||||
queue_max_size: int = 100
|
||||
worker_count: int = 1
|
||||
task_timeout_seconds: int = 300
|
||||
|
||||
# Rate limiting defaults
|
||||
default_requests_per_minute: int = 10
|
||||
default_max_concurrent_jobs: int = 3
|
||||
default_min_poll_interval_ms: int = 1000
|
||||
|
||||
# Storage
|
||||
result_retention_days: int = 7
|
||||
temp_upload_dir: Path = Path("uploads/async")
|
||||
max_file_size_mb: int = 50
|
||||
|
||||
# Cleanup
|
||||
cleanup_interval_hours: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
|
||||
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -54,6 +96,7 @@ class AppConfig:
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
|
||||
@@ -62,6 +105,7 @@ class AppConfig:
|
||||
model=ModelConfig(**config_dict.get("model", {})),
|
||||
server=ServerConfig(**config_dict.get("server", {})),
|
||||
storage=StorageConfig(**config_dict.get("storage", {})),
|
||||
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
|
||||
)
|
||||
|
||||
|
||||
|
||||
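A minimal sketch (not part of the commit) of building an AppConfig from a plain dict via the from_dict shown above; the override values are invented:

```python
from src.web.config import AppConfig

config = AppConfig.from_dict({
    "model": {"confidence_threshold": 0.4, "use_gpu": False},   # dpi stays at DEFAULT_DPI
    "server": {},
    "storage": {"max_file_size_mb": 25},
    "async_processing": {"worker_count": 2, "queue_max_size": 50},
})
print(config.model.dpi, config.storage.dpi)  # both inherit DEFAULT_DPI (150)
```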
28
src/web/core/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Core Components
|
||||
|
||||
Reusable core functionality: authentication, rate limiting, scheduling.
|
||||
"""
|
||||
|
||||
from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
|
||||
from src.web.core.rate_limiter import RateLimiter
|
||||
from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
|
||||
from src.web.core.autolabel_scheduler import (
|
||||
start_autolabel_scheduler,
|
||||
stop_autolabel_scheduler,
|
||||
get_autolabel_scheduler,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"validate_admin_token",
|
||||
"get_admin_db",
|
||||
"AdminTokenDep",
|
||||
"AdminDBDep",
|
||||
"RateLimiter",
|
||||
"start_scheduler",
|
||||
"stop_scheduler",
|
||||
"get_training_scheduler",
|
||||
"start_autolabel_scheduler",
|
||||
"stop_autolabel_scheduler",
|
||||
"get_autolabel_scheduler",
|
||||
]
|
||||
60
src/web/core/auth.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Admin Authentication
|
||||
|
||||
FastAPI dependencies for admin token authentication.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
|
||||
from src.data.admin_db import AdminDB
|
||||
from src.data.database import get_session_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global AdminDB instance
|
||||
_admin_db: AdminDB | None = None
|
||||
|
||||
|
||||
def get_admin_db() -> AdminDB:
|
||||
"""Get the AdminDB instance."""
|
||||
global _admin_db
|
||||
if _admin_db is None:
|
||||
_admin_db = AdminDB()
|
||||
return _admin_db
|
||||
|
||||
|
||||
def reset_admin_db() -> None:
|
||||
"""Reset the AdminDB instance (for testing)."""
|
||||
global _admin_db
|
||||
_admin_db = None
|
||||
|
||||
|
||||
async def validate_admin_token(
|
||||
x_admin_token: Annotated[str | None, Header()] = None,
|
||||
admin_db: AdminDB = Depends(get_admin_db),
|
||||
) -> str:
|
||||
"""Validate admin token from header."""
|
||||
if not x_admin_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Admin token required. Provide X-Admin-Token header.",
|
||||
)
|
||||
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or expired admin token.",
|
||||
)
|
||||
|
||||
# Update last used timestamp
|
||||
admin_db.update_admin_token_usage(x_admin_token)
|
||||
|
||||
return x_admin_token
|
||||
|
||||
|
||||
# Type alias for dependency injection
|
||||
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
|
||||
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
|
||||
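To illustrate the header contract above (not part of the commit): any admin endpoint is called with X-Admin-Token. The training-tasks path below is inferred from the download URLs built earlier in this commit and the token is a placeholder:

```python
import requests

resp = requests.get(
    "http://localhost:8000/api/v1/admin/training/tasks",  # path inferred, not confirmed here
    headers={"X-Admin-Token": "<admin-token>"},
    params={"limit": 20, "offset": 0},
)
print(resp.status_code, resp.json())
```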
153
src/web/core/autolabel_scheduler.py
Normal file
@@ -0,0 +1,153 @@
"""
Auto-Label Scheduler

Background scheduler for processing documents pending auto-labeling.
"""

import logging
import threading
from pathlib import Path

from src.data.admin_db import AdminDB
from src.web.services.db_autolabel import (
    get_pending_autolabel_documents,
    process_document_autolabel,
)

logger = logging.getLogger(__name__)


class AutoLabelScheduler:
    """Scheduler for auto-labeling tasks."""

    def __init__(
        self,
        check_interval_seconds: int = 10,
        batch_size: int = 5,
        output_dir: Path | None = None,
    ):
        """
        Initialize auto-label scheduler.

        Args:
            check_interval_seconds: Interval to check for pending tasks
            batch_size: Number of documents to process per batch
            output_dir: Output directory for temporary files
        """
        self._check_interval = check_interval_seconds
        self._batch_size = batch_size
        self._output_dir = output_dir or Path("data/autolabel_output")
        self._running = False
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._db = AdminDB()

    def start(self) -> None:
        """Start the scheduler."""
        if self._running:
            logger.warning("AutoLabel scheduler already running")
            return

        self._running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        logger.info("AutoLabel scheduler started")

    def stop(self) -> None:
        """Stop the scheduler."""
        if not self._running:
            return

        self._running = False
        self._stop_event.set()

        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

        logger.info("AutoLabel scheduler stopped")

    @property
    def is_running(self) -> bool:
        """Check if scheduler is running."""
        return self._running

    def _run_loop(self) -> None:
        """Main scheduler loop."""
        while self._running:
            try:
                self._process_pending_documents()
            except Exception as e:
                logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)

            # Wait for next check interval
            self._stop_event.wait(timeout=self._check_interval)

    def _process_pending_documents(self) -> None:
        """Check and process pending auto-label documents."""
        try:
            documents = get_pending_autolabel_documents(
                self._db, limit=self._batch_size
            )

            if not documents:
                return

            logger.info(f"Processing {len(documents)} pending autolabel documents")

            for doc in documents:
                if self._stop_event.is_set():
                    break

                try:
                    result = process_document_autolabel(
                        document=doc,
                        db=self._db,
                        output_dir=self._output_dir,
                    )

                    if result.get("success"):
                        logger.info(
                            f"AutoLabel completed for document {doc.document_id}"
                        )
                    else:
                        logger.warning(
                            f"AutoLabel failed for document {doc.document_id}: "
                            f"{result.get('error', 'Unknown error')}"
                        )

                except Exception as e:
                    logger.error(
                        f"Error processing document {doc.document_id}: {e}",
                        exc_info=True,
                    )

        except Exception as e:
            logger.error(f"Error fetching pending documents: {e}", exc_info=True)


# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None


def get_autolabel_scheduler() -> AutoLabelScheduler:
    """Get the auto-label scheduler instance."""
    global _autolabel_scheduler
    if _autolabel_scheduler is None:
        _autolabel_scheduler = AutoLabelScheduler()
    return _autolabel_scheduler


def start_autolabel_scheduler() -> None:
    """Start the global auto-label scheduler."""
    scheduler = get_autolabel_scheduler()
    scheduler.start()


def stop_autolabel_scheduler() -> None:
    """Stop the global auto-label scheduler."""
    global _autolabel_scheduler
    if _autolabel_scheduler:
        _autolabel_scheduler.stop()
        _autolabel_scheduler = None
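A sketch of how the module-level start/stop helpers might be wired into application startup; the lifespan hook shown here is an assumption, not code from this commit (the repo's actual app factory may differ):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from src.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # The background thread polls every 10 s (default check_interval_seconds)
    # and processes at most 5 pending documents per batch (default batch_size).
    start_autolabel_scheduler()
    try:
        yield
    finally:
        stop_autolabel_scheduler()


app = FastAPI(lifespan=lifespan)
```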
211  src/web/core/rate_limiter.py  Normal file
@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation

Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""

import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from src.data.async_request_db import AsyncRequestDB

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RateLimitConfig:
    """Rate limit configuration for an API key."""

    requests_per_minute: int = 10
    max_concurrent_jobs: int = 3
    min_poll_interval_ms: int = 1000  # Minimum time between status polls


@dataclass
class RateLimitStatus:
    """Current rate limit status."""

    allowed: bool
    remaining_requests: int
    reset_at: datetime
    retry_after_seconds: int | None = None
    reason: str | None = None


class RateLimiter:
    """
    Thread-safe rate limiter with sliding window algorithm.

    Tracks:
    - Requests per minute (sliding window)
    - Concurrent active jobs
    - Poll frequency per request_id
    """

    def __init__(self, db: "AsyncRequestDB") -> None:
        self._db = db
        self._lock = Lock()
        # In-memory tracking for fast checks
        self._request_windows: dict[str, list[float]] = defaultdict(list)
        # (api_key, request_id) -> last_poll timestamp
        self._poll_timestamps: dict[tuple[str, str], float] = {}
        # Cache for API key configs (TTL 60 seconds)
        self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
        self._config_cache_ttl = 60.0

    def check_submit_limit(self, api_key: str) -> RateLimitStatus:
        """Check if API key can submit a new request."""
        config = self._get_config(api_key)

        with self._lock:
            now = time.time()
            window_start = now - 60  # 1 minute window

            # Clean old entries
            self._request_windows[api_key] = [
                ts for ts in self._request_windows[api_key]
                if ts > window_start
            ]

            current_count = len(self._request_windows[api_key])

            if current_count >= config.requests_per_minute:
                oldest = min(self._request_windows[api_key])
                retry_after = int(oldest + 60 - now) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
                    retry_after_seconds=max(1, retry_after),
                    reason="Rate limit exceeded: too many requests per minute",
                )

            # Check concurrent jobs (query database) - inside lock for thread safety
            active_jobs = self._db.count_active_jobs(api_key)
            if active_jobs >= config.max_concurrent_jobs:
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=config.requests_per_minute - current_count,
                    reset_at=datetime.utcnow() + timedelta(seconds=30),
                    retry_after_seconds=30,
                    reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
                )

            return RateLimitStatus(
                allowed=True,
                remaining_requests=config.requests_per_minute - current_count - 1,
                reset_at=datetime.utcnow() + timedelta(seconds=60),
            )

    def record_request(self, api_key: str) -> None:
        """Record a successful request submission."""
        with self._lock:
            self._request_windows[api_key].append(time.time())

        # Also record in database for persistence
        try:
            self._db.record_rate_limit_event(api_key, "request")
        except Exception as e:
            logger.warning(f"Failed to record rate limit event: {e}")

    def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
        """Check if polling is allowed (prevent abuse)."""
        config = self._get_config(api_key)
        key = (api_key, request_id)

        with self._lock:
            now = time.time()
            last_poll = self._poll_timestamps.get(key, 0)
            elapsed_ms = (now - last_poll) * 1000

            if elapsed_ms < config.min_poll_interval_ms:
                # Suggest exponential backoff
                wait_ms = min(
                    config.min_poll_interval_ms * 2,
                    5000,  # Max 5 seconds
                )
                retry_after = int(wait_ms / 1000) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
                    retry_after_seconds=retry_after,
                    reason="Polling too frequently. Please wait before retrying.",
                )

            # Update poll timestamp
            self._poll_timestamps[key] = now

            return RateLimitStatus(
                allowed=True,
                remaining_requests=999,  # No limit on poll count, just frequency
                reset_at=datetime.utcnow(),
            )

    def _get_config(self, api_key: str) -> RateLimitConfig:
        """Get rate limit config for API key with caching."""
        now = time.time()

        # Check cache
        if api_key in self._config_cache:
            cached_config, cached_at = self._config_cache[api_key]
            if now - cached_at < self._config_cache_ttl:
                return cached_config

        # Query database
        db_config = self._db.get_api_key_config(api_key)
        if db_config:
            config = RateLimitConfig(
                requests_per_minute=db_config.requests_per_minute,
                max_concurrent_jobs=db_config.max_concurrent_jobs,
            )
        else:
            config = RateLimitConfig()  # Default limits

        # Cache result
        self._config_cache[api_key] = (config, now)
        return config

    def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
        """Clean up old poll timestamps to prevent memory leak."""
        with self._lock:
            now = time.time()
            cutoff = now - max_age_seconds
            old_keys = [
                k for k, v in self._poll_timestamps.items()
                if v < cutoff
            ]
            for key in old_keys:
                del self._poll_timestamps[key]
            return len(old_keys)

    def cleanup_request_windows(self) -> None:
        """Clean up expired entries from request windows."""
        with self._lock:
            now = time.time()
            window_start = now - 60

            for api_key in list(self._request_windows.keys()):
                self._request_windows[api_key] = [
                    ts for ts in self._request_windows[api_key]
                    if ts > window_start
                ]
                # Remove empty entries
                if not self._request_windows[api_key]:
                    del self._request_windows[api_key]

    def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
        """Generate rate limit headers for HTTP response."""
        headers = {
            "X-RateLimit-Remaining": str(status.remaining_requests),
            "X-RateLimit-Reset": status.reset_at.isoformat(),
        }
        if status.retry_after_seconds:
            headers["Retry-After"] = str(status.retry_after_seconds)
        return headers
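A sketch of how a submit endpoint might consult the limiter; the `enforce_submit_limit` helper and the 429 handling are assumptions for illustration, not code from this commit:

```python
from fastapi import HTTPException

from src.web.core.rate_limiter import RateLimiter


def enforce_submit_limit(limiter: RateLimiter, api_key: str) -> None:
    """Raise HTTP 429 with Retry-After / X-RateLimit-* headers when over limit."""
    status = limiter.check_submit_limit(api_key)
    if not status.allowed:
        raise HTTPException(
            status_code=429,
            detail=status.reason,
            headers=limiter.get_rate_limit_headers(status),
        )
    # Count this submission in the sliding window only once it is accepted.
    limiter.record_request(api_key)
```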
329  src/web/core/scheduler.py  Normal file
@@ -0,0 +1,329 @@
"""
Admin Training Scheduler

Background scheduler for training tasks using APScheduler.
"""

import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any

from src.data.admin_db import AdminDB

logger = logging.getLogger(__name__)


class TrainingScheduler:
    """Scheduler for training tasks."""

    def __init__(
        self,
        check_interval_seconds: int = 60,
    ):
        """
        Initialize training scheduler.

        Args:
            check_interval_seconds: Interval to check for pending tasks
        """
        self._check_interval = check_interval_seconds
        self._running = False
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._db = AdminDB()

    def start(self) -> None:
        """Start the scheduler."""
        if self._running:
            logger.warning("Training scheduler already running")
            return

        self._running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        logger.info("Training scheduler started")

    def stop(self) -> None:
        """Stop the scheduler."""
        if not self._running:
            return

        self._running = False
        self._stop_event.set()

        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

        logger.info("Training scheduler stopped")

    def _run_loop(self) -> None:
        """Main scheduler loop."""
        while self._running:
            try:
                self._check_pending_tasks()
            except Exception as e:
                logger.error(f"Error in scheduler loop: {e}")

            # Wait for next check interval
            self._stop_event.wait(timeout=self._check_interval)

    def _check_pending_tasks(self) -> None:
        """Check and execute pending training tasks."""
        try:
            tasks = self._db.get_pending_training_tasks()

            for task in tasks:
                task_id = str(task.task_id)

                # Check if scheduled time has passed
                if task.scheduled_at and task.scheduled_at > datetime.utcnow():
                    continue

                logger.info(f"Starting training task: {task_id}")

                try:
                    self._execute_task(task_id, task.config or {})
                except Exception as e:
                    logger.error(f"Training task {task_id} failed: {e}")
                    self._db.update_training_task_status(
                        task_id=task_id,
                        status="failed",
                        error_message=str(e),
                    )

        except Exception as e:
            logger.error(f"Error checking pending tasks: {e}")

    def _execute_task(self, task_id: str, config: dict[str, Any]) -> None:
        """Execute a training task."""
        # Update status to running
        self._db.update_training_task_status(task_id, "running")
        self._db.add_training_log(task_id, "INFO", "Training task started")

        try:
            # Get training configuration
            model_name = config.get("model_name", "yolo11n.pt")
            epochs = config.get("epochs", 100)
            batch_size = config.get("batch_size", 16)
            image_size = config.get("image_size", 640)
            learning_rate = config.get("learning_rate", 0.01)
            device = config.get("device", "0")
            project_name = config.get("project_name", "invoice_fields")

            # Export annotations for training
            export_result = self._export_training_data(task_id)
            if not export_result:
                raise ValueError("Failed to export training data")

            data_yaml = export_result["data_yaml"]

            self._db.add_training_log(
                task_id, "INFO",
                f"Exported {export_result['total_images']} images for training",
            )

            # Run YOLO training
            result = self._run_yolo_training(
                task_id=task_id,
                model_name=model_name,
                data_yaml=data_yaml,
                epochs=epochs,
                batch_size=batch_size,
                image_size=image_size,
                learning_rate=learning_rate,
                device=device,
                project_name=project_name,
            )

            # Update task with results
            self._db.update_training_task_status(
                task_id=task_id,
                status="completed",
                result_metrics=result.get("metrics"),
                model_path=result.get("model_path"),
            )
            self._db.add_training_log(task_id, "INFO", "Training completed successfully")

        except Exception as e:
            logger.error(f"Training task {task_id} failed: {e}")
            self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
            raise

    def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
        """Export training data for a task."""
        from pathlib import Path
        import shutil
        from src.data.admin_models import FIELD_CLASSES

        # Get all labeled documents
        documents = self._db.get_labeled_documents_for_export()

        if not documents:
            self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
            return None

        # Create export directory
        export_dir = Path("data/training") / task_id
        export_dir.mkdir(parents=True, exist_ok=True)

        # YOLO format directories
        (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)

        # 80/20 train/val split
        total_docs = len(documents)
        train_count = int(total_docs * 0.8)
        train_docs = documents[:train_count]
        val_docs = documents[train_count:]

        total_images = 0
        total_annotations = 0

        # Export documents
        for split, docs in [("train", train_docs), ("val", val_docs)]:
            for doc in docs:
                annotations = self._db.get_annotations_for_document(str(doc.document_id))

                if not annotations:
                    continue

                for page_num in range(1, doc.page_count + 1):
                    page_annotations = [a for a in annotations if a.page_number == page_num]

                    # Copy image
                    src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
                    if not src_image.exists():
                        continue

                    image_name = f"{doc.document_id}_page{page_num}.png"
                    dst_image = export_dir / "images" / split / image_name
                    shutil.copy(src_image, dst_image)
                    total_images += 1

                    # Write YOLO label
                    label_name = f"{doc.document_id}_page{page_num}.txt"
                    label_path = export_dir / "labels" / split / label_name

                    with open(label_path, "w") as f:
                        for ann in page_annotations:
                            line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
                            f.write(line)
                            total_annotations += 1

        # Create data.yaml
        yaml_path = export_dir / "data.yaml"
        yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val

nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
        yaml_path.write_text(yaml_content)

        return {
            "data_yaml": str(yaml_path),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _run_yolo_training(
        self,
        task_id: str,
        model_name: str,
        data_yaml: str,
        epochs: int,
        batch_size: int,
        image_size: int,
        learning_rate: float,
        device: str,
        project_name: str,
    ) -> dict[str, Any]:
        """Run YOLO training."""
        try:
            from ultralytics import YOLO

            # Log training start
            self._db.add_training_log(
                task_id, "INFO",
                f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
            )

            # Load model
            model = YOLO(model_name)

            # Train
            results = model.train(
                data=data_yaml,
                epochs=epochs,
                batch=batch_size,
                imgsz=image_size,
                lr0=learning_rate,
                device=device,
                project=f"runs/train/{project_name}",
                name=f"task_{task_id[:8]}",
                exist_ok=True,
                verbose=True,
            )

            # Get best model path
            best_model = Path(results.save_dir) / "weights" / "best.pt"

            # Extract metrics
            metrics = {}
            if hasattr(results, "results_dict"):
                metrics = {
                    "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
                    "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
                    "precision": results.results_dict.get("metrics/precision(B)", 0),
                    "recall": results.results_dict.get("metrics/recall(B)", 0),
                }

            self._db.add_training_log(
                task_id, "INFO",
                f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
            )

            return {
                "model_path": str(best_model) if best_model.exists() else None,
                "metrics": metrics,
            }

        except ImportError:
            self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
            raise ValueError("Ultralytics (YOLO) not installed")
        except Exception as e:
            self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
            raise


# Global scheduler instance
_scheduler: TrainingScheduler | None = None


def get_training_scheduler() -> TrainingScheduler:
    """Get the training scheduler instance."""
    global _scheduler
    if _scheduler is None:
        _scheduler = TrainingScheduler()
    return _scheduler


def start_scheduler() -> None:
    """Start the global training scheduler."""
    scheduler = get_training_scheduler()
    scheduler.start()


def stop_scheduler() -> None:
    """Stop the global training scheduler."""
    global _scheduler
    if _scheduler:
        _scheduler.stop()
        _scheduler = None
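For reference, an example of the task config dict this scheduler accepts; every key is optional and falls back to the defaults read via `config.get(...)` in `_execute_task` (the values below simply restate those defaults):

```python
task_config = {
    "model_name": "yolo11n.pt",
    "epochs": 100,
    "batch_size": 16,
    "image_size": 640,
    "learning_rate": 0.01,
    "device": "0",
    "project_name": "invoice_fields",
}
```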
Some files were not shown because too many files have changed in this diff.