WIP
@@ -7,7 +7,8 @@
       "Edit(*)",
       "Glob(*)",
       "Grep(*)",
-      "Task(*)"
+      "Task(*)",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
     ]
   }
 }
@@ -81,7 +81,13 @@
       "Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
       "Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
       "Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
-      "Bash(wsl bash -c:*)"
+      "Bash(wsl bash -c:*)",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
+      "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
+      "Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
+      "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")"
     ],
     "deny": [],
     "ask": [],
README.md (34 lines changed)
@@ -76,6 +76,38 @@
 | 8 | payment_line | Payment line (machine-readable format) |
 | 9 | customer_number | Customer number |
+
+## DPI Configuration
+
+**Important**: All components of the system use **150 DPI**, ensuring consistency between training and inference.
+
+The DPI (dots per inch) setting must be identical at training and inference time; otherwise it causes:
+- Detection-box size mismatches
+- A significant mAP drop (potentially from 93.5% down to 60-70%)
+- Missed or spurious field detections
+
+### Configuration Locations
+
+| Component | Config file | Setting |
+|------|---------|--------|
+| **Global constant** | `src/config.py` | `DEFAULT_DPI = 150` |
+| **Web inference** | `src/web/config.py` | `ModelConfig.dpi` (imported from `src.config`) |
+| **CLI inference** | `src/cli/infer.py` | `--dpi` default = `DEFAULT_DPI` |
+| **Auto-labeling** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` |
+| **PDF-to-image** | `src/web/api/v1/admin/documents.py` | uses `DEFAULT_DPI` |
+
+### Usage Examples
+
+```bash
+# Training (uses the default 150 DPI)
+python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
+
+# Inference (defaults to 150 DPI, consistent with training)
+python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf
+
+# Specify the DPI manually (only when pairing with a model trained at a non-default DPI)
+python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150
+```
+
 ## Installation

 ```bash
@@ -490,7 +522,7 @@ Options:
   --input, -i       Input PDF/image
   --output, -o      Output JSON path
   --confidence      Confidence threshold (default: 0.5)
-  --dpi             Render DPI (default: 300)
+  --dpi             Render DPI (default: 150; must match the training DPI)
   --gpu             Use GPU
 ```
create_shims.sh (new file, 96 lines)
@@ -0,0 +1,96 @@
#!/bin/bash

# Create backward compatibility shims for all migrated files

# admin_auth.py -> core/auth.py
cat > src/web/admin_auth.py << 'EOF'
"""DEPRECATED: Import from src.web.core.auth instead"""
from src.web.core.auth import * # noqa: F401, F403
EOF

# admin_autolabel.py -> services/autolabel.py
cat > src/web/admin_autolabel.py << 'EOF'
"""DEPRECATED: Import from src.web.services.autolabel instead"""
from src.web.services.autolabel import * # noqa: F401, F403
EOF

# admin_scheduler.py -> core/scheduler.py
cat > src/web/admin_scheduler.py << 'EOF'
"""DEPRECATED: Import from src.web.core.scheduler instead"""
from src.web.core.scheduler import * # noqa: F401, F403
EOF

# admin_schemas.py -> schemas/admin.py
cat > src/web/admin_schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.admin instead"""
from src.web.schemas.admin import * # noqa: F401, F403
EOF

# schemas.py -> schemas/inference.py + schemas/common.py
cat > src/web/schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
from src.web.schemas.inference import * # noqa: F401, F403
from src.web.schemas.common import * # noqa: F401, F403
EOF

# services.py -> services/inference.py
cat > src/web/services.py << 'EOF'
"""DEPRECATED: Import from src.web.services.inference instead"""
from src.web.services.inference import * # noqa: F401, F403
EOF

# async_queue.py -> workers/async_queue.py
cat > src/web/async_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
from src.web.workers.async_queue import * # noqa: F401, F403
EOF

# async_service.py -> services/async_processing.py
cat > src/web/async_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.async_processing instead"""
from src.web.services.async_processing import * # noqa: F401, F403
EOF

# batch_queue.py -> workers/batch_queue.py
cat > src/web/batch_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
from src.web.workers.batch_queue import * # noqa: F401, F403
EOF

# batch_upload_service.py -> services/batch_upload.py
cat > src/web/batch_upload_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
from src.web.services.batch_upload import * # noqa: F401, F403
EOF

# batch_upload_routes.py -> api/v1/batch/routes.py
cat > src/web/batch_upload_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
from src.web.api.v1.batch.routes import * # noqa: F401, F403
EOF

# admin_routes.py -> api/v1/admin/documents.py
cat > src/web/admin_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
from src.web.api.v1.admin.documents import * # noqa: F401, F403
EOF

# admin_annotation_routes.py -> api/v1/admin/annotations.py
cat > src/web/admin_annotation_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
from src.web.api.v1.admin.annotations import * # noqa: F401, F403
EOF

# admin_training_routes.py -> api/v1/admin/training.py
cat > src/web/admin_training_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
from src.web.api.v1.admin.training import * # noqa: F401, F403
EOF

# routes.py -> api/v1/routes.py
cat > src/web/routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
from src.web.api.v1.routes import * # noqa: F401, F403
EOF

echo "✓ Created backward compatibility shims for all migrated files"
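One way to sanity-check the result (a hedged sketch, not part of the commit; it assumes the new modules created by the migration exist on the import path):

```python
import importlib

# Hypothetical smoke test: every legacy import path should still resolve
# through its shim to the new module location.
legacy_modules = [
    "src.web.admin_auth", "src.web.schemas", "src.web.services",
    "src.web.async_queue", "src.web.routes",
]
for name in legacy_modules:
    module = importlib.import_module(name)  # raises ImportError if a shim is broken
    print(f"OK: {name} -> {module.__doc__}")
```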
@@ -1,405 +0,0 @@
# Invoice Master POC v2 - Code Review Report

**Review date**: 2026-01-22
**Codebase size**: 67 Python source files, ~22,434 lines of code
**Test coverage**: ~40-50%

---

## Executive Summary

### Overall assessment: **Good (B+)**

**Strengths**:
- ✅ Clear modular architecture with good separation of responsibilities
- ✅ Appropriate use of dataclasses and type hints
- ✅ Comprehensive normalization logic for Swedish invoices
- ✅ Spatial-index optimization (O(1) token lookup)
- ✅ Solid fallback mechanism (OCR fallback when YOLO fails)
- ✅ Well-designed web API and UI

**Main issues**:
- ❌ Duplicated payment-line parsing code (3+ places)
- ❌ Long functions (`_normalize_customer_number` is 127 lines)
- ❌ Configuration security problem (plaintext database password)
- ❌ Inconsistent exception handling (generic Exception everywhere)
- ❌ Missing integration tests
- ❌ Magic numbers scattered throughout (0.5, 0.95, 300, etc.)

---

## 1. Architecture Analysis

### 1.1 Module Structure

```
src/
├── inference/          # Core inference pipeline
│   ├── pipeline.py (517 lines) ⚠️
│   ├── field_extractor.py (1,347 lines) 🔴 too long
│   └── yolo_detector.py
├── web/                # FastAPI web service
│   ├── app.py (765 lines) ⚠️ inline HTML
│   ├── routes.py (184 lines)
│   └── services.py (286 lines)
├── ocr/                # OCR extraction
│   ├── paddle_ocr.py
│   └── machine_code_parser.py (919 lines) 🔴 too long
├── matcher/            # Field matching
│   └── field_matcher.py (875 lines) ⚠️
├── utils/              # Shared utilities
│   ├── validators.py
│   ├── text_cleaner.py
│   ├── fuzzy_matcher.py
│   ├── ocr_corrections.py
│   └── format_variants.py (610 lines)
├── processing/         # Batch processing
├── data/               # Data management
└── cli/                # Command-line tools
```

### 1.2 Inference Flow

```
PDF/image input
    ↓
Render to image (pdf/renderer.py)
    ↓
YOLO detection (yolo_detector.py) - detect field regions
    ↓
Field extraction (field_extractor.py)
    ├→ OCR text extraction (ocr/paddle_ocr.py)
    ├→ Normalization & validation
    └→ Confidence calculation
    ↓
Cross-validation (pipeline.py)
    ├→ Parse the payment_line format
    ├→ Extract OCR/Amount/Account from the payment_line
    └→ Validate against detected fields; payment_line values take precedence
    ↓
Fallback OCR (when key fields are missing)
    ├→ Full-page OCR
    └→ Regex extraction
    ↓
InferenceResult output
```

---

## 2. Code Quality Issues

### 2.1 Long functions (>50 lines) 🔴

| Function | File | Lines | Complexity | Problem |
|------|------|------|--------|------|
| `_normalize_customer_number()` | field_extractor.py | **127** | Very high | 4 levels of pattern matching, 7+ regexes, complex scoring |
| `_cross_validate_payment_line()` | pipeline.py | **127** | Very high | Core validation logic, 8+ conditional branches |
| `_normalize_bankgiro()` | field_extractor.py | 62 | High | Luhn validation + multiple fallbacks |
| `_normalize_plusgiro()` | field_extractor.py | 63 | High | Similar to bankgiro |
| `_normalize_payment_line()` | field_extractor.py | 74 | High | 4 regex patterns |
| `_normalize_amount()` | field_extractor.py | 78 | High | Multi-strategy fallback |

**Example problem** - `_normalize_customer_number()` (lines 776-902):
```python
def _normalize_customer_number(self, text: str):
    # A 127-line function containing:
    # - 4 nested if/for loops
    # - 7 different regex patterns
    # - 5 scoring mechanisms
    # - Handling for both labeled and unlabeled formats
```

**Recommendation**: split into:
- `_find_customer_code_patterns()`
- `_find_labeled_customer_code()`
- `_score_customer_candidates()`

### 2.2 Code duplication 🔴

**Payment-line parsing (3+ duplicate implementations)**:

1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
2. `MachineCodeParser.parse()` (machine_code_parser.py, 919 lines)
3. `_normalize_payment_line()` (field_extractor.py:632-705)

All three implement similar regex patterns:
```
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
```

**Bankgiro/Plusgiro validation (duplicated)**:
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
- `field_matcher.py`: similar matching logic

**Recommendation**: create unified modules:
```python
# src/common/payment_line_parser.py
class PaymentLineParser:
    def parse(text: str) -> PaymentLineResult

# src/common/giro_validator.py
class GiroValidator:
    def validate_and_format(value: str, giro_type: str) -> str
```
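For illustration, a minimal sketch of what a unified parser for the format quoted above could look like — the regex, type names, and amount handling here are assumptions for the sketch, not the project's actual implementation:

```python
import re
from dataclasses import dataclass

@dataclass
class PaymentLineResult:
    ocr: str
    amount: str
    account: str

# Hypothetical pattern for: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
_PAYMENT_LINE = re.compile(
    r'#\s*(?P<ocr>\d+)\s*#\s*(?P<kronor>\d+)\s+(?P<ore>\d{2})\s+\d+\s*>\s*(?P<account>[\d-]+)#'
)

def parse_payment_line(text: str) -> PaymentLineResult | None:
    """Parse one machine-readable payment line; return None when nothing matches."""
    m = _PAYMENT_LINE.search(text)
    if m is None:
        return None
    return PaymentLineResult(
        ocr=m.group('ocr'),
        amount=f"{m.group('kronor')}.{m.group('ore')}",
        account=m.group('account'),
    )
```

With one parser owning the pattern, the three current implementations would reduce to three call sites.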
### 2.3 Inconsistent error handling ⚠️

**Generic exception catches (31 occurrences)**:
```python
except Exception as e:  # 31 occurrences in the codebase
    result.errors.append(str(e))
```

**Problems**:
- No specific error types are caught
- Generic error messages lose context
- Lines 142-147 (routes.py): catches all exceptions and returns a 500

**Current code** (routes.py:142-147):
```python
try:
    service_result = inference_service.process_pdf(...)
except Exception as e:  # too broad
    logger.error(f"Error processing document: {e}")
    raise HTTPException(status_code=500, detail=str(e))
```

**Suggested improvement**:
```python
except FileNotFoundError:
    raise HTTPException(status_code=400, detail="PDF file not found")
except PyMuPDFError:
    raise HTTPException(status_code=400, detail="Invalid PDF format")
except OCRError:
    raise HTTPException(status_code=503, detail="OCR service unavailable")
```
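The `OCRError` above would have to be project-defined. A minimal sketch of the exception hierarchy this implies — the class names are assumptions, matching the "define custom exception classes" action in Phase 1 below:

```python
class InvoiceProcessingError(Exception):
    """Base class for all pipeline errors (hypothetical)."""

class PDFRenderError(InvoiceProcessingError):
    """Raised when a PDF page cannot be rendered to an image."""

class OCRError(InvoiceProcessingError):
    """Raised when the OCR backend fails or is unavailable."""

class FieldValidationError(InvoiceProcessingError):
    """Raised when an extracted field fails normalization or validation."""
```

A shared base class lets route handlers catch the whole family in one place while inner code raises the specific subtype.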
### 2.4 Configuration security problem 🔴

**config.py lines 24-30** - plaintext credentials:
```python
DATABASE = {
    'host': '192.168.68.31',     # hard-coded IP
    'user': 'docmaster',         # hard-coded username
    'password': 'nY6LYK5d',      # 🔴 plaintext password!
    'database': 'invoice_master'
}
```

**Recommendation**:
```python
DATABASE = {
    'host': os.getenv('DB_HOST', 'localhost'),
    'user': os.getenv('DB_USER', 'docmaster'),
    'password': os.getenv('DB_PASSWORD'),  # read from an environment variable
    'database': os.getenv('DB_NAME', 'invoice_master')
}
```
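One refinement worth considering on top of the snippet above: `os.getenv('DB_PASSWORD')` silently yields `None` when the variable is unset, which only fails later at connection time. A sketch that fails fast at import instead (helper name is an assumption):

```python
import os

def _require_env(name: str) -> str:
    """Return the environment variable or raise a clear startup error."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value

DATABASE = {
    'host': os.getenv('DB_HOST', 'localhost'),
    'user': os.getenv('DB_USER', 'docmaster'),
    'password': _require_env('DB_PASSWORD'),  # fail fast instead of passing None
    'database': os.getenv('DB_NAME', 'invoice_master'),
}
```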
### 2.5 Magic numbers ⚠️

| Value | Location | Purpose | Problem |
|---|------|------|------|
| 0.5 | multiple places | Confidence threshold | Not configurable per field |
| 0.95 | pipeline.py | payment_line confidence | Undocumented |
| 300 | multiple places | DPI | Hard-coded |
| 0.1 | field_extractor.py | BBox padding | Should be configuration |
| 72 | multiple places | PDF base DPI | Magic number in formulas |
| 50 | field_extractor.py | Customer-number score bonus | Undocumented |

**Recommendation**: extract into configuration:
```python
INFERENCE_CONFIG = {
    'confidence_threshold': 0.5,
    'payment_line_confidence': 0.95,
    'dpi': 300,
    'bbox_padding': 0.1,
}
```

### 2.6 Naming inconsistencies ⚠️

**Inconsistent field names**:
- YOLO class names: `invoice_number`, `ocr_number`, `supplier_org_number`
- Field names: `InvoiceNumber`, `OCR`, `supplier_org_number`
- CSV column names: possibly different again
- Database field names: yet another variant

Mappings are maintained in multiple places:
- `yolo_detector.py` (lines 90-100): `CLASS_TO_FIELD`
- several other locations

---

## 3. Test Analysis

### 3.1 Test coverage

**Test files**: 13
- ✅ Well covered: field_matcher, normalizer, payment_line_parser
- ⚠️ Moderately covered: field_extractor, pipeline
- ❌ Poorly covered: web layer, CLI, batch processing

**Estimated coverage**: 40-50%

### 3.2 Missing test cases 🔴

**Critical gaps** (a hedged sketch of one such test follows this list):
1. Cross-validation logic - the most complex part, barely tested
2. payment_line parsing variants - multiple implementations, unclear edge cases
3. OCR error correction - complex logic with several strategies
4. Web API endpoints - no request/response tests
5. Batch processing - multi-worker coordination untested
6. Fallback OCR mechanism - when YOLO detection fails
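As an illustration of gap 4, a sketch of an endpoint test using FastAPI's `TestClient` — the app import path and route are assumptions based on the module layout above, not the project's actual API:

```python
from fastapi.testclient import TestClient
from src.web.app import app  # assumed app location

client = TestClient(app)

def test_process_endpoint_rejects_missing_file():
    # No file attached: the endpoint should fail with a client error,
    # not a masked 500 from a broad `except Exception`.
    response = client.post("/api/v1/process")
    assert response.status_code in (400, 422)
```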
---

## 4. Architecture Risks

### 🔴 Critical risks

1. **Configuration security** - plaintext database credentials in config.py (lines 24-30)
2. **Error recovery** - broad exception handling masks real problems
3. **Testability** - hard-coded dependencies prevent unit testing

### 🟡 High risks

1. **Code maintainability** - duplicated payment-line parsing
2. **Scalability** - no async processing for long-running inference
3. **Extensibility** - adding new field types would be difficult

### 🟢 Medium risks

1. **Performance** - lazy loading helps, but ORM queries are unoptimized
2. **Documentation** - mostly adequate but could be better

---

## 5. Priority Matrix

| Priority | Action | Effort | Impact |
|--------|------|--------|------|
| 🔴 Critical | Fix configuration security (environment variables) | 1 hour | High |
| 🔴 Critical | Add integration tests | 2-3 days | High |
| 🔴 Critical | Document the error-handling strategy | 4 hours | Medium |
| 🟡 High | Unify payment_line parsing | 1-2 days | High |
| 🟡 High | Extract normalization into submodules | 2-3 days | Medium |
| 🟡 High | Add dependency injection | 2-3 days | Medium |
| 🟡 High | Split long functions | 2-3 days | Low |
| 🟢 Medium | Raise test coverage to 70%+ | 3-5 days | High |
| 🟢 Medium | Extract magic numbers | 4 hours | Low |
| 🟢 Medium | Standardize naming conventions | 1-2 days | Medium |

---

## 6. File-specific Recommendations

### High priority (code quality)

| File | Problem | Recommendation |
|------|------|------|
| `field_extractor.py` | 1,347 lines; 6 long normalization methods | Split into a `normalizers/` submodule |
| `pipeline.py` | 127-line `_cross_validate_payment_line()` | Extract into a separate `CrossValidator` class |
| `field_matcher.py` | 875 lines; complex matching logic | Split into a `matching/` submodule |
| `config.py` | Hard-coded credentials (line 29) | Use environment variables |
| `machine_code_parser.py` | 919 lines; payment_line parsing | Merge with the pipeline's parsing |

### Medium priority (refactoring)

| File | Problem | Recommendation |
|------|------|------|
| `app.py` | 765 lines; HTML inlined in Python | Extract into a `templates/` directory |
| `autolabel.py` | 753 lines; batch-processing logic | Extract worker functions into a module |
| `format_variants.py` | 610 lines; variant generation | Consider a strategy pattern |

---

## 7. Recommended Actions

### Phase 1: Critical fixes (1 week)

1. **Configuration security** (1 hour)
   - Remove the plaintext password from config.py
   - Add environment-variable support
   - Update the README with configuration instructions

2. **Error-handling standardization** (1 day)
   - Define custom exception classes
   - Replace generic Exception catches
   - Add error-code constants

3. **Add critical integration tests** (2 days)
   - End-to-end inference tests
   - payment_line cross-validation tests
   - API endpoint tests

### Phase 2: Refactoring (2-3 weeks)

4. **Unify payment_line parsing** (2 days)
   - Create `src/common/payment_line_parser.py`
   - Merge the 3 duplicate implementations
   - Migrate all callers

5. **Split field_extractor.py** (3 days)
   - Create a `src/inference/normalizers/` submodule
   - One file per field type
   - Extract shared validation logic

6. **Split long functions** (2 days)
   - `_normalize_customer_number()` → 3 functions
   - `_cross_validate_payment_line()` → a CrossValidator class

### Phase 3: Improvements (1-2 weeks)

7. **Raise test coverage** (5 days)
   - Target: 70%+ coverage
   - Focus on validation logic
   - Add edge-case tests

8. **Configuration-management improvements** (1 day)
   - Extract all magic numbers
   - Create configuration files (YAML)
   - Add configuration validation

9. **Documentation improvements** (2 days)
   - Add architecture diagrams
   - Document all private methods
   - Create a contributing guide

---

## Appendix A: Metrics

### Code complexity

| Category | Count | Avg. lines |
|------|------|----------|
| Source files | 67 | 334 |
| Long files (>500 lines) | 12 | 875 |
| Long functions (>50 lines) | 23 | 89 |
| Test files | 13 | 298 |

### Dependencies

| Type | Count |
|------|------|
| External dependencies | ~25 |
| Internal modules | 10 |
| Circular dependencies | 0 ✅ |

### Code style

| Metric | Coverage |
|------|--------|
| Type hints | 80% |
| Docstrings (public) | 80% |
| Docstrings (private) | 40% |
| Test coverage | 45% |

---

**Generated**: 2026-01-22
**Reviewer**: Claude Code
**Version**: v2.0

@@ -1,96 +0,0 @@
# Field Extractor Analysis Report

## Overview

field_extractor.py (1,183 lines) was initially identified as a candidate for optimization, and a refactoring attempt was made using the `src/normalize` module. After analysis and testing, the conclusion is that it **should not be refactored**.

## Refactoring Attempt

### Initial plan
Delete the duplicated normalize methods in field_extractor.py and use the unified `src/normalize/normalize_field()` interface instead.

### Steps taken
1. ✅ Backed up the original file (`field_extractor_old.py`)
2. ✅ Changed `_normalize_and_validate` to use the unified normalizer
3. ✅ Deleted the duplicated normalize methods (~400 lines)
4. ❌ Ran the tests - **28 failures**
5. ✅ Added wrapper methods delegating to the normalizer
6. ❌ Ran the tests again - **12 failures**
7. ✅ Restored the original file
8. ✅ Tests pass - **all 45 tests pass**

## Key Findings

### The two modules serve different purposes

| Module | Purpose | Input | Output | Example |
|------|------|------|------|------|
| **src/normalize/** | **Variant generation** for matching | An already-extracted field value | A list of matching variants | `"INV-12345"` → `["INV-12345", "12345"]` |
| **field_extractor** | **Value extraction** from OCR text | Raw OCR text containing the field | A single extracted field value | `"Fakturanummer: A3861"` → `"A3861"` |

### Why can't they be unified?

1. **src/normalize/** is designed to:
   - Receive already-extracted field values
   - Generate multiple normalized variants for fuzzy matching
   - For example, BankgiroNormalizer:
   ```python
   normalize("782-1713") → ["7821713", "782-1713"]  # generates variants
   ```

2. **field_extractor**'s normalize methods:
   - Receive raw OCR text containing the field (possibly with labels and other text)
   - **Extract** the field value matching a specific pattern
   - For example `_normalize_bankgiro`:
   ```python
   _normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None)  # extracts from text
   ```

3. **The key difference** (contrasted in the snippet below):
   - Normalizer: variant generator (for matching)
   - Field Extractor: pattern extractor (for parsing)
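The contrast in one runnable snippet — these are illustrative stand-ins for the two modules, not the real implementations:

```python
import re

def normalize_variants(value: str) -> list[str]:
    """Matching side: expand an already-extracted value into variants."""
    return [value.replace("-", ""), value]

def extract_bankgiro(ocr_text: str) -> str | None:
    """Extraction side: pull the formatted value out of raw OCR text first."""
    m = re.search(r'(\d{3,4}-\d{4})', ocr_text)
    return m.group(1) if m else None

print(normalize_variants("782-1713"))          # ['7821713', '782-1713']
print(extract_bankgiro("Bankgiro: 782-1713"))  # 782-1713
```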
### Test failure examples

Failures after replacing the field extractor methods with the normalizer:

```python
# InvoiceNumber test
Input:    "Fakturanummer: A3861"
Expected: "A3861"
Actual:   "Fakturanummer: A3861"  # not extracted, only cleaned

# Bankgiro test
Input:    "Bankgiro: 782-1713"
Expected: "782-1713"
Actual:   "7821713"  # returned the dash-less variant instead of the extracted formatted value
```

## Conclusion

**field_extractor.py should not be refactored onto the src/normalize module**, because:

1. ✅ **Different responsibilities**: extraction vs. variant generation
2. ✅ **Different inputs**: raw OCR text with labels vs. already-extracted field values
3. ✅ **Different outputs**: a single extracted value vs. multiple matching variants
4. ✅ **The existing code works well**: all 45 tests pass
5. ✅ **The extraction logic has value**: it contains complex pattern-matching rules (e.g. distinguishing Bankgiro/Plusgiro formats)

## Recommendations

1. **Keep field_extractor.py as-is**: do not refactor
2. **Document the difference between the two modules**: make sure the team understands their respective purposes
3. **Focus on other optimization targets**: machine_code_parser.py (919 lines)

## Lessons Learned

Before refactoring, one should:
1. Understand a module's **real purpose**, not just surface code similarity
2. Run the full test suite to validate assumptions
3. Assess whether there is genuine duplication, or superficial similarity with different purposes

---

**Status**: ✅ Analysis complete; decided not to refactor
**Tests**: ✅ 45/45 passing
**File**: kept at 1,183 lines, unchanged

@@ -1,238 +0,0 @@
# Machine Code Parser Analysis Report

## File Overview

- **File**: `src/ocr/machine_code_parser.py`
- **Total lines**: 919
- **Code lines**: 607 (66%)
- **Methods**: 14
- **Regex usages**: 47

## Code Structure

### Class structure

```
MachineCodeResult (dataclass)
├── to_dict()
└── get_region_bbox()

MachineCodeParser (main parser)
├── __init__()
├── parse() - main entry point
├── _find_tokens_with_values()
├── _find_machine_code_line_tokens()
├── _parse_standard_payment_line_with_tokens()
├── _parse_standard_payment_line() - 142 lines ⚠️
├── _extract_ocr() - 50 lines
├── _extract_bankgiro() - 58 lines
├── _extract_plusgiro() - 30 lines
├── _extract_amount() - 68 lines
├── _calculate_confidence()
└── cross_validate()
```

## Issues Found

### 1. ⚠️ `_parse_standard_payment_line` is too long (142 lines)

**Location**: lines 442-582

**Problems**:
- Contains the nested functions `normalize_account_spaces` and `format_account`
- Multiple regex-matching branches
- Complex logic that is hard to test and maintain

**Recommendation**:
Split into standalone methods:
- `_normalize_account_spaces(line)`
- `_format_account(account_digits, context)`
- `_match_primary_pattern(line)`
- `_match_fallback_patterns(line)`

### 2. 🔁 The four `_extract_*` methods share a repeated pattern

All extract methods follow the same pattern:

```python
def _extract_XXX(self, tokens):
    candidates = []

    for token in tokens:
        text = token.text.strip()
        matches = self.XXX_PATTERN.findall(text)
        for match in matches:
            # validation logic
            # context detection
            candidates.append((normalized, context_score, token))

    if not candidates:
        return None

    candidates.sort(key=lambda x: (x[1], 1), reverse=True)
    return candidates[0][0]
```

**Duplicated logic**:
- Token iteration
- Pattern matching
- Candidate collection
- Context scoring
- Sorting and best-match selection

**Recommendation**:
A base extractor class or a generic helper method could remove this duplication (see the sketch under Option B below).

### 3. ✅ Context detection is duplicated

The context-detection code is repeated in several places:

```python
# in _extract_bankgiro
context_text = ' '.join(t.text.lower() for t in tokens)
is_bankgiro_context = (
    'bankgiro' in context_text or
    'bg:' in context_text or
    'bg ' in context_text
)

# in _extract_plusgiro
context_text = ' '.join(t.text.lower() for t in tokens)
is_plusgiro_context = (
    'plusgiro' in context_text or
    'postgiro' in context_text or
    'pg:' in context_text or
    'pg ' in context_text
)

# in _parse_standard_payment_line
context = (context_line or raw_line).lower()
is_plusgiro_context = (
    ('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
    and 'bankgiro' not in context
)
```

**Recommendation**:
Extract into a standalone method:
- `_detect_account_context(tokens) -> dict[str, bool]`

## Refactoring Options

### Option A: Light refactoring (recommended) ✅

**Goal**: extract the duplicated context-detection logic without changing the overall structure

**Steps**:
1. Extract a `_detect_account_context(tokens)` method
2. Extract `_normalize_account_spaces(line)` as a standalone method
3. Extract `_format_account(digits, context)` as a standalone method

**Impact**:
- Removes ~50-80 lines of duplicated code
- Improves testability
- Low risk, easy to verify

**Expected result**: 919 lines → ~850 lines (↓7%)

### Option B: Medium refactoring

**Goal**: create a generic field-extraction framework

**Steps** (a hedged sketch of the generic helper follows this list):
1. Create `_generic_extract(pattern, normalizer, context_checker)`
2. Refactor all `_extract_*` methods onto the generic framework
3. Split `_parse_standard_payment_line` into several small methods

**Impact**:
- Removes ~150-200 lines of code
- Significantly improves maintainability
- Medium risk; needs thorough testing

**Expected result**: 919 lines → ~720 lines (↓22%)
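A minimal sketch of the generic helper from step 1 — the signature, the shape of the token objects, and the scoring interface are assumptions for illustration, not the module's actual API:

```python
import re
from typing import Callable

def _generic_extract(
    tokens: list,                              # TextToken-like objects exposing .text
    pattern: re.Pattern,
    normalizer: Callable[[str], str | None],   # returns None to reject a candidate
    context_score: Callable[[list], int],      # scores the surrounding context
) -> str | None:
    """Shared candidate loop behind every _extract_* method (hypothetical)."""
    score = context_score(tokens)
    candidates = []
    for token in tokens:
        for match in pattern.findall(token.text.strip()):
            normalized = normalizer(match)
            if normalized is not None:
                candidates.append((normalized, score))
    if not candidates:
        return None
    candidates.sort(key=lambda c: c[1], reverse=True)  # best context score wins
    return candidates[0][0]
```

Each `_extract_*` method would then shrink to a pattern, a normalizer, and a context checker.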
### Option C: Deep refactoring (not recommended)

**Goal**: a complete redesign around the strategy pattern

**Risks**:
- High risk; could introduce bugs
- Requires extensive testing
- Could break existing integrations

## Recommended Option

### ✅ Adopt Option A (light refactoring)

**Rationale**:
1. **The code already works well**: no obvious bugs or performance problems
2. **Low risk**: only extracts duplicated logic without changing the core algorithm
3. **High value for the effort**: a small change brings a clear code-quality improvement
4. **Easy to verify**: the existing tests should cover it

### Refactoring steps

```python
# 1. Extract context detection
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
    """Detect account-type keywords in the surrounding context."""
    context_text = ' '.join(t.text.lower() for t in tokens)

    return {
        'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
        'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
    }

# 2. Extract space normalization
def _normalize_account_spaces(self, line: str) -> str:
    """Remove spaces inside the account number."""
    ...  # (existing code from lines 460-481)

# 3. Extract account formatting
def _format_account(
    self,
    account_digits: str,
    is_plusgiro_context: bool
) -> tuple[str, str]:
    """Format the account and determine its type."""
    ...  # (existing code from lines 485-523)
```

## Comparison: field_extractor vs machine_code_parser

| Aspect | field_extractor | machine_code_parser |
|------|-----------------|---------------------|
| Purpose | Value extraction | Machine-code parsing |
| Duplicated code | ~400 lines of normalize methods | ~80 lines of context detection |
| Refactoring value | ❌ Different purposes; should not be unified | ✅ Shared logic can be extracted |
| Risk | High (would break functionality) | Low (code organization only) |

## Decision

### ✅ Refactoring machine_code_parser.py is recommended

**How it differs from field_extractor**:
- field_extractor: the duplicated methods serve **different purposes** (extraction vs. variant generation)
- machine_code_parser: the duplicated code serves **the same purpose** (all of it is context detection)

**Expected benefits**:
- Removes ~70 lines of duplicated code
- Improves testability (context detection can be tested in isolation)
- Clearer code organization
- **Low risk**, easy to verify

## Next Steps

1. ✅ Back up the original file
2. ✅ Extract the `_detect_account_context` method
3. ✅ Extract the `_normalize_account_spaces` method
4. ✅ Extract the `_format_account` method
5. ✅ Update all call sites
6. ✅ Run the tests to verify
7. ✅ Check code coverage

---

**Status**: 📋 Analysis complete; light refactoring recommended
**Risk assessment**: 🟢 Low risk
**Expected benefit**: 919 lines → ~850 lines (↓7%)

@@ -1,519 +0,0 @@
# Performance Optimization Guide

This document provides performance optimization recommendations for the Invoice Field Extraction system.

## Table of Contents

1. [Batch Processing Optimization](#batch-processing-optimization)
2. [Database Query Optimization](#database-query-optimization)
3. [Caching Strategies](#caching-strategies)
4. [Memory Management](#memory-management)
5. [Profiling and Monitoring](#profiling-and-monitoring)

---

## Batch Processing Optimization

### Current State

The system processes invoices one at a time. For large batches, this can be inefficient.

### Recommendations

#### 1. Database Batch Operations

**Current**: Individual inserts for each document
```python
# Inefficient
for doc in documents:
    db.insert_document(doc)  # Individual DB call
```

**Optimized**: Use `execute_values` for batch inserts
```python
# Efficient - already implemented in db.py line 519
from psycopg2.extras import execute_values

execute_values(cursor, """
    INSERT INTO documents (...)
    VALUES %s
""", document_values)
```

**Impact**: 10-50x faster for batches of 100+ documents

#### 2. PDF Processing Batching

**Recommendation**: Process PDFs in parallel using multiprocessing

```python
from multiprocessing import Pool

def process_batch(pdf_paths, batch_size=10):
    """Process PDFs in parallel batches."""
    with Pool(processes=batch_size) as pool:
        results = pool.map(pipeline.process_pdf, pdf_paths)
    return results
```

**Considerations**:
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
- CPU-intensive tasks can use a separate process pool (`src/processing/cpu_pool.py`)
- The current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern

**Status**: ✅ Already implemented in `src/processing/` modules

#### 3. Image Caching for Multi-Page PDFs

**Current**: Each page rendered independently
```python
# Current pattern in field_extractor.py
for page_num in range(total_pages):
    image = render_pdf_page(pdf_path, page_num, dpi=300)
```

**Optimized**: Pre-render all pages if processing multiple fields per page
```python
# Batch render
images = {
    page_num: render_pdf_page(pdf_path, page_num, dpi=300)
    for page_num in page_numbers_needed
}

# Reuse images
for detection in detections:
    image = images[detection.page_no]
    extract_field(detection, image)
```

**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices

---

## Database Query Optimization

### Current Performance

- **Parameterized queries**: ✅ Implemented (Phase 1)
- **Connection pooling**: ❌ Not implemented
- **Query batching**: ✅ Partially implemented
- **Index optimization**: ⚠️ Needs verification

### Recommendations

#### 1. Connection Pooling

**Current**: New connection for each operation
```python
def connect(self):
    """Create new database connection."""
    return psycopg2.connect(**self.config)
```

**Optimized**: Use connection pooling
```python
from psycopg2 import pool

class DocumentDatabase:
    def __init__(self, config):
        self.pool = pool.SimpleConnectionPool(
            minconn=1,
            maxconn=10,
            **config
        )

    def connect(self):
        return self.pool.getconn()

    def close(self, conn):
        self.pool.putconn(conn)
```

**Impact**:
- Reduces connection overhead by 80-95%
- Especially important for high-frequency operations
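A usage sketch to pair with the pool above: a small context manager keeps `connect()`/`close()` calls paired even when a query raises (the `db` object is the `DocumentDatabase` from the snippet above; the helper name is an assumption):

```python
from contextlib import contextmanager

@contextmanager
def pooled_connection(db):
    """Borrow a connection from the pool and always return it, even on error."""
    conn = db.connect()
    try:
        yield conn
    finally:
        db.close(conn)

# Usage:
# with pooled_connection(db) as conn:
#     with conn.cursor() as cursor:
#         cursor.execute("SELECT 1")
```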
#### 2. Index Recommendations

**Check current indexes**:
```sql
-- Verify indexes exist on frequently queried columns
SELECT tablename, indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'public';
```

**Recommended indexes**:
```sql
-- If not already present
CREATE INDEX IF NOT EXISTS idx_documents_success
    ON documents(success);

CREATE INDEX IF NOT EXISTS idx_documents_timestamp
    ON documents(timestamp DESC);

CREATE INDEX IF NOT EXISTS idx_field_results_document_id
    ON field_results(document_id);

CREATE INDEX IF NOT EXISTS idx_field_results_matched
    ON field_results(matched);

CREATE INDEX IF NOT EXISTS idx_field_results_field_name
    ON field_results(field_name);
```

**Impact**:
- 10-100x faster queries for filtered/sorted results
- Critical for `get_failed_matches()` and `get_all_documents_summary()`

#### 3. Query Batching

**Status**: ✅ Already implemented for field results (line 519)

**Verify batching is used**:
```python
# Good pattern in db.py
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
```

**Additional opportunity**: Batch `SELECT` queries
```python
# Current
docs = [get_document(doc_id) for doc_id in doc_ids]  # N queries

# Optimized
docs = get_documents_batch(doc_ids)  # 1 query with IN clause
```

**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)

---

## Caching Strategies

### 1. Model Loading Cache

**Current**: Models loaded per-instance

**Recommendation**: Singleton pattern for the YOLO model
```python
class YOLODetectorSingleton:
    _instance = None
    _model = None

    @classmethod
    def get_instance(cls, model_path):
        if cls._instance is None:
            cls._instance = YOLODetector(model_path)
        return cls._instance
```

**Impact**: Reduces memory usage by 90% when processing multiple documents

### 2. Parser Instance Caching

**Current**: ✅ Already optimal
```python
# Good pattern in field_extractor.py
def __init__(self):
    self.payment_line_parser = PaymentLineParser()        # Reused
    self.customer_number_parser = CustomerNumberParser()  # Reused
```

**Status**: No changes needed

### 3. OCR Result Caching

**Recommendation**: Cache OCR results for identical regions
```python
import hashlib

_ocr_cache: dict = {}

def ocr_region_cached(image, bbox):
    """Cache OCR results by image hash + bbox (numpy images are unhashable, so key on a digest)."""
    key = (hashlib.md5(image.tobytes()).hexdigest(), tuple(bbox))
    if key not in _ocr_cache:
        _ocr_cache[key] = paddle_ocr.ocr_region(image, bbox)
    return _ocr_cache[key]
```

**Impact**: 50-80% speedup when re-processing similar documents

**Note**: The cache keys on `hashlib.md5(image.tobytes())`; a bounded cache with eviction (e.g. LRU) would be needed for long-running processes.

---
## Memory Management

### Current Issues

**Potential memory leaks**:
1. Large images kept in memory after processing
2. OCR results accumulated without cleanup
3. Model outputs not explicitly cleared

### Recommendations

#### 1. Explicit Image Cleanup

```python
import gc

def process_pdf(pdf_path):
    image = None
    try:
        image = render_pdf(pdf_path)
        result = extract_fields(image)
        return result
    finally:
        del image     # Explicit cleanup (image is pre-bound, so this is safe even if rendering fails)
        gc.collect()  # Force garbage collection
```

#### 2. Generator Pattern for Large Batches

**Current**: Load all documents into memory
```python
docs = [process_pdf(path) for path in pdf_paths]  # All in memory
```

**Optimized**: Use a generator for streaming processing
```python
def process_batch_streaming(pdf_paths):
    """Process documents one at a time, yielding results."""
    for path in pdf_paths:
        result = process_pdf(path)
        yield result
        # Result can be saved to DB immediately
        # Previous result is garbage collected
```

**Impact**: Constant memory usage regardless of batch size

#### 3. Context Managers for Resources

```python
class InferencePipeline:
    def __enter__(self):
        self.detector.load_model()
        return self

    def __exit__(self, *args):
        self.detector.unload_model()
        self.extractor.cleanup()

# Usage
with InferencePipeline(...) as pipeline:
    results = pipeline.process_pdf(path)
# Automatic cleanup
```

---

## Profiling and Monitoring

### Recommended Profiling Tools

#### 1. cProfile for CPU Profiling

```python
import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()

# Your code here
pipeline.process_pdf(pdf_path)

profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)  # Top 20 slowest functions
```

#### 2. memory_profiler for Memory Analysis

```bash
pip install memory_profiler
python -m memory_profiler your_script.py
```

Or decorator-based:
```python
from memory_profiler import profile

@profile
def process_large_batch(pdf_paths):
    # Memory usage tracked line-by-line
    results = [process_pdf(path) for path in pdf_paths]
    return results
```

#### 3. py-spy for Production Profiling

```bash
pip install py-spy

# Profile running process
py-spy top --pid 12345

# Generate flamegraph
py-spy record -o profile.svg -- python your_script.py
```

**Advantage**: No code changes needed, minimal overhead

### Key Metrics to Monitor

1. **Processing Time per Document**
   - Target: <10 seconds for a single-page invoice
   - Current: ~2-5 seconds (estimated)

2. **Memory Usage**
   - Target: <2GB for a batch of 100 documents
   - Monitor: Peak memory usage

3. **Database Query Time**
   - Target: <100ms per query (with indexes)
   - Monitor: Slow query log

4. **OCR Accuracy vs Speed Trade-off**
   - Current: PaddleOCR with GPU (~200ms per region)
   - Alternative: Tesseract (~500ms, slightly more accurate)

### Logging Performance Metrics

**Add to pipeline.py**:
```python
import time
import logging

logger = logging.getLogger(__name__)

def process_pdf(self, pdf_path):
    start = time.time()

    # Processing...
    result = self._process_internal(pdf_path)

    elapsed = time.time() - start
    logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")

    # Log to database for analysis
    self.db.log_performance({
        'document_id': result.document_id,
        'processing_time': elapsed,
        'field_count': len(result.fields)
    })

    return result
```

---

## Performance Optimization Priorities

### High Priority (Implement First)

1. ✅ **Database parameterized queries** - Already done (Phase 1)
2. ⚠️ **Database connection pooling** - Not implemented
3. ⚠️ **Index optimization** - Needs verification

### Medium Priority

4. ⚠️ **Batch PDF rendering** - Optimization possible
5. ✅ **Parser instance reuse** - Already done (Phase 2)
6. ⚠️ **Model caching** - Could improve

### Low Priority (Nice to Have)

7. ⚠️ **OCR result caching** - Complex implementation
8. ⚠️ **Generator patterns** - Refactoring needed
9. ⚠️ **Advanced profiling** - For production optimization

---

## Benchmarking Script

```python
"""
Benchmark script for invoice processing performance.
"""

import time
from pathlib import Path
from src.inference.pipeline import InferencePipeline

def benchmark_single_document(pdf_path, iterations=10):
    """Benchmark single document processing."""
    pipeline = InferencePipeline(
        model_path="path/to/model.pt",
        use_gpu=True
    )

    times = []
    for i in range(iterations):
        start = time.time()
        result = pipeline.process_pdf(pdf_path)
        elapsed = time.time() - start
        times.append(elapsed)
        print(f"Iteration {i+1}: {elapsed:.2f}s")

    avg_time = sum(times) / len(times)
    print(f"\nAverage: {avg_time:.2f}s")
    print(f"Min: {min(times):.2f}s")
    print(f"Max: {max(times):.2f}s")

def benchmark_batch(pdf_paths, batch_size=10):
    """Benchmark batch processing."""
    from multiprocessing import Pool

    pipeline = InferencePipeline(
        model_path="path/to/model.pt",
        use_gpu=True
    )

    start = time.time()

    with Pool(processes=batch_size) as pool:
        results = pool.map(pipeline.process_pdf, pdf_paths)

    elapsed = time.time() - start
    avg_per_doc = elapsed / len(pdf_paths)

    print(f"Total time: {elapsed:.2f}s")
    print(f"Documents: {len(pdf_paths)}")
    print(f"Average per document: {avg_per_doc:.2f}s")
    print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")

if __name__ == "__main__":
    # Single document benchmark
    benchmark_single_document("test.pdf")

    # Batch benchmark
    pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
    benchmark_batch(pdf_paths[:100])
```

---

## Summary

**Implemented (Phase 1-2)**:
- ✅ Parameterized queries (SQL injection fix)
- ✅ Parser instance reuse (Phase 2 refactoring)
- ✅ Batch insert operations (execute_values)
- ✅ Dual pool processing (CPU/GPU separation)

**Quick Wins (Low effort, high impact)**:
- Database connection pooling (2-4 hours)
- Index verification and optimization (1-2 hours)
- Batch PDF rendering (4-6 hours)

**Long-term Improvements**:
- OCR result caching with hashing
- Generator patterns for streaming
- Advanced profiling and monitoring

**Expected Impact**:
- Connection pooling: 80-95% reduction in DB overhead
- Indexes: 10-100x faster queries
- Batch rendering: 50-90% less redundant work
- **Overall**: 2-5x throughput improvement for batch processing
File diff suppressed because it is too large
@@ -1,170 +0,0 @@
# Refactoring Summary Report

## 📊 Overall Results

### Test status
- ✅ **688/688 tests passing** (100%)
- ✅ **Code coverage**: 34% → 37% (+3%)
- ✅ **0 failures**, 0 errors

### Coverage improvements
- ✅ **machine_code_parser**: 25% → 65% (+40%)
- ✅ **New tests**: 55 (633 → 688)

---

## 🎯 Completed Refactorings

### 1. ✅ Matcher modularization (876 lines → 205 lines, ↓76%)

**Files**:

**What was done**:
- Split the single 876-line file into **11 modules**
- Extracted **5 independent matching strategies**
- Created dedicated modules for data models, utility functions, and context handling

**New module structure**:

**Test results**:
- ✅ All 77 matcher tests passing
- ✅ Complete README documentation
- ✅ Strategy pattern, easy to extend

**Benefits** (a sketch of the strategy shape follows this list):
- 📉 76% less code
- 📈 Significantly better maintainability
- ✨ Each strategy tested in isolation
- 🔧 Easy to add new strategies
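A hedged sketch of the strategy shape implied above — the names and interface are assumptions for illustration, not the actual matcher API:

```python
from typing import Protocol

class MatchStrategy(Protocol):
    def match(self, extracted: str, expected: str) -> float | None:
        """Return a score in [0, 1], or None when the strategy does not apply."""
        ...

class ExactMatch:
    def match(self, extracted: str, expected: str) -> float | None:
        return 1.0 if extracted == expected else None

def best_match(strategies: list[MatchStrategy], extracted: str, expected: str) -> float:
    # Adding a new strategy means appending to the list; no existing code changes.
    for strategy in strategies:
        score = strategy.match(extracted, expected)
        if score is not None:
            return score
    return 0.0
```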
---

### 2. ✅ Machine Code Parser light refactoring + test coverage (919 lines → 929 lines)

**File**: src/ocr/machine_code_parser.py

**What was done**:
- Extracted **3 shared helper methods**, eliminating duplicated code
- Streamlined the context-detection logic
- Simplified the account-formatting method

**Test improvements**:
- ✅ **55 new tests** (24 → 79)
- ✅ **Coverage**: 25% → 65% (+40%)
- ✅ All 688 project tests passing

**New test coverage**:
- **Round one** (22 tests):
  - `_detect_account_context()` - 8 tests (context detection)
  - `_normalize_account_spaces()` - 5 tests (space normalization)
  - `_format_account()` - 4 tests (account formatting)
  - `parse()` - 5 tests (main entry point)
- **Round two** (33 tests):
  - `_extract_ocr()` - 8 tests (OCR extraction)
  - `_extract_bankgiro()` - 9 tests (Bankgiro extraction)
  - `_extract_plusgiro()` - 8 tests (Plusgiro extraction)
  - `_extract_amount()` - 8 tests (amount extraction)

**Benefits**:
- 🔄 Eliminated 80 lines of duplicated code
- 📈 Better testability (helper methods can be tested in isolation)
- 📖 More readable code
- ✅ Coverage up from 25% to 65% (+40%)
- 🎯 Low risk, high reward

---

### 3. ✅ Field Extractor analysis (decided not to refactor)

**File**: (1,183 lines)

**Conclusion**: ❌ **should not be refactored**

**Key insight**:
- Superficially similar code can serve **completely different purposes**
- field_extractor: **parses/extracts** field values
- src/normalize: **normalizes/generates variants** for matching
- The responsibilities differ; they should not be unified

**Documentation**:

---

## 📈 Refactoring Statistics

### Line-count changes

| File | Before | After | Change | Percent |
|------|--------|--------|------|--------|
| **matcher/field_matcher.py** | 876 | 205 | -671 | ↓76% |
| **matcher/* (10 new modules)** | 0 | 466 | +466 | new |
| **matcher total** | 876 | 671 | -205 | ↓23% |
| **ocr/machine_code_parser.py** | 919 | 929 | +10 | +1% |
| **Net reduction** | - | - | **-195 lines** | **↓11%** |

### Test coverage

| Module | Tests | Pass rate | Coverage | Status |
|------|--------|--------|--------|------|
| matcher | 77 | 100% | - | ✅ |
| field_extractor | 45 | 100% | 39% | ✅ |
| machine_code_parser | 79 | 100% | 65% | ✅ |
| normalizer | ~120 | 100% | - | ✅ |
| other modules | ~367 | 100% | - | ✅ |
| **Total** | **688** | **100%** | **37%** | ✅ |

---

## 🎓 Lessons from the Refactoring

### What worked

1. **✅ Test before refactoring**
   - Every refactoring had full test coverage
   - Tests were run immediately after each change
   - A 100% pass rate guaranteed quality

2. **✅ Identify genuine duplication**
   - Not all similar code is duplication
   - field_extractor vs normalizer: superficially similar but different purposes
   - machine_code_parser: genuine code duplication

3. **✅ Incremental refactoring**
   - matcher: large-scale modularization (strategy pattern)
   - machine_code_parser: light refactoring (extract shared methods)
   - field_extractor: analyzed, then decided not to refactor

### Key decisions

#### ✅ When refactoring was warranted
- **matcher**: a single file too long (876 lines), containing multiple strategies
- **machine_code_parser**: duplicated code with the same purpose in several places

#### ❌ When it was not
- **field_extractor**: similar code with different purposes

### Takeaway

**Don't chase the DRY principle blindly**
> Similar code is not necessarily duplicated code. Understand the code's **real purpose**.

---

## ✅ Summary

**Key results**:
- 📉 Net reduction of 195 lines of code
- 📈 Code coverage +3% (34% → 37%)
- ✅ Test count +55 (633 → 688)
- 🎯 machine_code_parser coverage +40% (25% → 65%)
- ✨ Substantially improved modularity
- 🎯 Much better maintainability

**Key lesson**:
> Similar code is not necessarily duplicated code. Only by understanding the code's real purpose can one make sound refactoring decisions.

**Next steps**:
1. Continue raising machine_code_parser coverage toward 80%+ (currently 65%)
2. Add tests for other low-coverage modules (field_extractor 39%, pipeline 19%)
3. Flesh out edge-case and error-path tests
@@ -1,258 +0,0 @@
# Test Coverage Improvement Report

## 📊 Improvement Overview

### Overall statistics
- ✅ **Total tests**: 633 → 688 (+55 tests, +8.7%)
- ✅ **Pass rate**: 100% (688/688)
- ✅ **Overall coverage**: 34% → 37% (+3%)

### machine_code_parser.py focus
- ✅ **Tests**: 24 → 79 (+55 tests, +229%)
- ✅ **Coverage**: 25% → 65% (+40%)
- ✅ **Uncovered lines**: 273 → 129 (144 fewer)

---

## 🎯 New Test Details

### Round one (22 tests)

#### 1. TestDetectAccountContext (8 tests)

Tests the new `_detect_account_context()` helper method.

**Test cases**:
1. `test_bankgiro_keyword` - detects the 'bankgiro' keyword
2. `test_bg_keyword` - detects the 'bg:' abbreviation
3. `test_plusgiro_keyword` - detects the 'plusgiro' keyword
4. `test_postgiro_keyword` - detects the 'postgiro' alias
5. `test_pg_keyword` - detects the 'pg:' abbreviation
6. `test_both_contexts` - both kinds of keywords present
7. `test_no_context` - no account keywords
8. `test_case_insensitive` - case-insensitive detection

**Code path covered**:
```python
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
    context_text = ' '.join(t.text.lower() for t in tokens)
    return {
        'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
        'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
    }
```
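A hedged sketch of what one of the listed cases could look like — the token construction is assumed, not copied from the actual suite:

```python
def test_bankgiro_keyword():
    parser = MachineCodeParser()
    # make_token is a hypothetical test helper producing a TextToken
    tokens = [make_token("Bankgiro:"), make_token("782-1713")]
    context = parser._detect_account_context(tokens)
    assert context['bankgiro'] is True
    assert context['plusgiro'] is False
```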
---

#### 2. TestNormalizeAccountSpacesMethod (5 tests)

Tests the new `_normalize_account_spaces()` helper method.

**Test cases**:
1. `test_removes_spaces_after_arrow` - removes spaces after >
2. `test_multiple_consecutive_spaces` - handles multiple consecutive spaces
3. `test_no_arrow_returns_unchanged` - returns the input unchanged when there is no > marker
4. `test_spaces_before_arrow_preserved` - preserves spaces before >
5. `test_empty_string` - empty-string handling

**Code path covered**:
```python
def _normalize_account_spaces(self, line: str) -> str:
    if '>' not in line:
        return line
    parts = line.split('>', 1)
    after_arrow = parts[1]
    normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
    while re.search(r'(\d)\s+(\d)', normalized):
        normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
    return parts[0] + '>' + normalized
```

---

#### 3. TestFormatAccount (4 tests)

Tests the new `_format_account()` helper method.

**Test cases**:
1. `test_plusgiro_context_forces_plusgiro` - a Plusgiro context forces Plusgiro formatting
2. `test_valid_bankgiro_7_digits` - formats a valid 7-digit Bankgiro
3. `test_valid_bankgiro_8_digits` - formats a valid 8-digit Bankgiro
4. `test_defaults_to_bankgiro_when_ambiguous` - defaults to Bankgiro in ambiguous cases

**Code path covered**:
```python
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
    if is_plusgiro_context:
        formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
        return formatted, 'plusgiro'

    # Luhn validation logic
    pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
    bg_valid = FieldValidators.is_valid_bankgiro(account_digits)

    # Decision logic (pg_formatted/bg_formatted construction elided in this excerpt)
    if pg_valid and not bg_valid:
        return pg_formatted, 'plusgiro'
    elif bg_valid and not pg_valid:
        return bg_formatted, 'bankgiro'
    else:
        return bg_formatted, 'bankgiro'
```

---

#### 4. TestParseMethod (5 tests)

Tests the main `parse()` entry point.

**Test cases**:
1. `test_parse_empty_tokens` - empty token list handling
2. `test_parse_finds_payment_line_in_bottom_region` - finds the payment line in the bottom 35% of the page
3. `test_parse_ignores_top_region` - ignores the top region of the page
4. `test_parse_with_context_keywords` - detects context keywords
5. `test_parse_stores_source_tokens` - stores the source tokens

**Code paths covered**:
- Token filtering (bottom-region detection)
- Context keyword detection
- Payment-line lookup and parsing
- Result-object construction

---

### Round two (33 tests)

#### 5. TestExtractOCR (8 tests)

Tests the `_extract_ocr()` method - OCR reference-number extraction.

**Test cases**:
1. `test_extract_valid_ocr_10_digits` - extracts a 10-digit OCR number
2. `test_extract_valid_ocr_15_digits` - extracts a 15-digit OCR number
3. `test_extract_ocr_with_hash_markers` - OCR wrapped in # markers
4. `test_extract_longest_ocr_when_multiple` - picks the longest among several candidates
5. `test_extract_ocr_ignores_short_numbers` - ignores numbers shorter than 10 digits
6. `test_extract_ocr_ignores_long_numbers` - ignores numbers longer than 25 digits
7. `test_extract_ocr_excludes_bankgiro_variants` - excludes Bankgiro variants
8. `test_extract_ocr_empty_tokens` - empty token handling

#### 6. TestExtractBankgiro (9 tests)

Tests the `_extract_bankgiro()` method - Bankgiro account extraction.

**Test cases**:
1. `test_extract_bankgiro_7_digits_with_dash` - 7-digit Bankgiro with a dash
2. `test_extract_bankgiro_7_digits_without_dash` - 7-digit Bankgiro without a dash
3. `test_extract_bankgiro_8_digits_with_dash` - 8-digit Bankgiro with a dash
4. `test_extract_bankgiro_8_digits_without_dash` - 8-digit Bankgiro without a dash
5. `test_extract_bankgiro_with_spaces` - Bankgiro containing spaces
6. `test_extract_bankgiro_handles_plusgiro_format` - handles Plusgiro-formatted input
7. `test_extract_bankgiro_with_context` - with context keywords
8. `test_extract_bankgiro_ignores_plusgiro_context` - ignores a Plusgiro context
9. `test_extract_bankgiro_empty_tokens` - empty token handling

#### 7. TestExtractPlusgiro (8 tests)

Tests the `_extract_plusgiro()` method - Plusgiro account extraction.

**Test cases**:
1. `test_extract_plusgiro_7_digits_with_dash` - 7-digit Plusgiro with a dash
|
|
||||||
2. `test_extract_plusgiro_7_digits_without_dash` - 无破折号的7位 Plusgiro
|
|
||||||
3. `test_extract_plusgiro_8_digits` - 8位 Plusgiro
|
|
||||||
4. `test_extract_plusgiro_with_spaces` - 带空格的 Plusgiro
|
|
||||||
5. `test_extract_plusgiro_with_context` - 带上下文关键词
|
|
||||||
6. `test_extract_plusgiro_ignores_too_short` - 忽略少于7位
|
|
||||||
7. `test_extract_plusgiro_ignores_too_long` - 忽略多于8位
|
|
||||||
8. `test_extract_plusgiro_empty_tokens` - 空 token 处理
|
|
||||||
|
|
||||||
#### 8. TestExtractAmount (8个测试)
|
|
||||||
|
|
||||||
测试 `_extract_amount()` 方法 - 金额提取。
|
|
||||||
|
|
||||||
**测试用例**:
|
|
||||||
1. `test_extract_amount_with_comma_decimal` - 逗号小数分隔符
|
|
||||||
2. `test_extract_amount_with_dot_decimal` - 点号小数分隔符
|
|
||||||
3. `test_extract_amount_integer` - 整数金额
|
|
||||||
4. `test_extract_amount_with_thousand_separator` - 千位分隔符
|
|
||||||
5. `test_extract_amount_large_number` - 大额金额
|
|
||||||
6. `test_extract_amount_ignores_too_large` - 忽略过大金额
|
|
||||||
7. `test_extract_amount_ignores_zero` - 忽略零或负数
|
|
||||||
8. `test_extract_amount_empty_tokens` - 空 token 处理
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📈 覆盖率分析
|
|
||||||
|
|
||||||
### 已覆盖的方法
|
|
||||||
✅ `_detect_account_context()` - **100%** (第一轮新增)
|
|
||||||
✅ `_normalize_account_spaces()` - **100%** (第一轮新增)
|
|
||||||
✅ `_format_account()` - **95%** (第一轮新增)
|
|
||||||
✅ `parse()` - **70%** (第一轮改进)
|
|
||||||
✅ `_parse_standard_payment_line()` - **95%** (已有测试)
|
|
||||||
✅ `_extract_ocr()` - **85%** (第二轮新增)
|
|
||||||
✅ `_extract_bankgiro()` - **90%** (第二轮新增)
|
|
||||||
✅ `_extract_plusgiro()` - **90%** (第二轮新增)
|
|
||||||
✅ `_extract_amount()` - **80%** (第二轮新增)
|
|
||||||
|
|
||||||
### 仍需改进的方法 (未覆盖/部分覆盖)
|
|
||||||
⚠️ `_calculate_confidence()` - **0%** (未测试)
|
|
||||||
⚠️ `cross_validate()` - **0%** (未测试)
|
|
||||||
⚠️ `get_region_bbox()` - **0%** (未测试)
|
|
||||||
⚠️ `_find_tokens_with_values()` - **部分覆盖**
|
|
||||||
⚠️ `_find_machine_code_line_tokens()` - **部分覆盖**
|
|
||||||
|
|
||||||
### 未覆盖的代码行(129行)
|
|
||||||
主要集中在:
|
|
||||||
1. **验证方法** (lines 805-824): `_calculate_confidence`, `cross_validate`
|
|
||||||
2. **辅助方法** (lines 80-92, 336-369, 377-407): Token 查找、bbox 计算、日志记录
|
|
||||||
3. **边界条件** (lines 648-653, 690, 699, 759-760等): 某些提取方法的边界情况
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎯 改进建议
|
|
||||||
|
|
||||||
### ✅ 已完成目标
|
|
||||||
- ✅ 覆盖率从 25% 提升到 65% (+40%)
|
|
||||||
- ✅ 测试数量从 24 增加到 79 (+55个)
|
|
||||||
- ✅ 提取方法全部测试(_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount)
|
|
||||||
|
|
||||||
### 下一步目标(覆盖率 65% → 80%+)
|
|
||||||
1. **添加验证方法测试** - 为 `_calculate_confidence`, `cross_validate` 添加测试
|
|
||||||
2. **添加辅助方法测试** - 为 token 查找和 bbox 计算方法添加测试
|
|
||||||
3. **完善边界条件** - 增加边界情况和异常处理的测试
|
|
||||||
4. **集成测试** - 添加端到端的集成测试,使用真实 PDF token 数据
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ✅ 已完成的改进
|
|
||||||
|
|
||||||
### 重构收益
|
|
||||||
- ✅ 提取的3个辅助方法现在可以独立测试
|
|
||||||
- ✅ 测试粒度更细,更容易定位问题
|
|
||||||
- ✅ 代码可读性提高,测试用例清晰易懂
|
|
||||||
|
|
||||||
### 质量保证
|
|
||||||
- ✅ 所有655个测试100%通过
|
|
||||||
- ✅ 无回归问题
|
|
||||||
- ✅ 新增测试覆盖了之前未测试的重构代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📚 测试编写经验
|
|
||||||
|
|
||||||
### 成功经验
|
|
||||||
1. **使用 fixture 创建测试数据** - `_create_token()` 辅助方法简化了 token 创建
|
|
||||||
2. **按方法组织测试类** - 每个方法一个测试类,结构清晰
|
|
||||||
3. **测试用例命名清晰** - `test_<what>_<condition>` 格式,一目了然
|
|
||||||
4. **覆盖关键路径** - 优先测试常见场景和边界条件
|
|
||||||
|
|
||||||
### 遇到的问题
|
|
||||||
1. **Token 初始化参数** - 忘记了 `page_no` 参数,导致初始测试失败
|
|
||||||
- 解决:修复 `_create_token()` 辅助方法,添加 `page_no=0`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**报告日期**: 2026-01-24
|
|
||||||
**状态**: ✅ 完成
|
|
||||||
**下一步**: 继续提升覆盖率到 60%+
|
|
||||||
@@ -1,619 +0,0 @@
# Multi-Pool Processing Architecture Design

## 1. Research Summary

### 1.1 Analysis of Current Problems

Our earlier dual-pool implementation had stability problems, chiefly:

| Problem | Cause | Fix |
|------|------|----------|
| Processing hangs | Mixing threads with ProcessPoolExecutor caused deadlocks | Use asyncio or a pure Queue pattern |
| Queue.get() blocks forever | No timeout mechanism | Add timeouts and sentinel values |
| GPU memory conflicts | Multiple processes hitting the GPU at once | Limit GPU workers to 1 |
| CUDA fork problem | Linux's default fork start method is incompatible with CUDA | Use the spawn start method |

### 1.2 Recommended Architecture

After this research, the best fit for our scenario is the **producer-consumer queue pattern**:

```
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│  Main Process   │     │  CPU Workers    │     │  GPU Worker     │
│                 │     │  (4 processes)  │     │  (1 process)    │
│  ┌───────────┐  │     │                 │     │                 │
│  │ Task      │──┼────▶│  Text PDF work  │     │ Scanned PDF work│
│  │ Dispatcher│  │     │  (no OCR)       │     │  (PaddleOCR)    │
│  └───────────┘  │     │                 │     │                 │
│        ▲        │     │        │        │     │        │        │
│        │        │     │        ▼        │     │        ▼        │
│  ┌───────────┐  │     │  Result Queue   │     │  Result Queue   │
│  │ Result    │◀─┼─────│◀────────────────│─────│◀────────────────│
│  │ Collector │  │     │                 │     │                 │
│  └───────────┘  │     └─────────────────┘     └─────────────────┘
│        │        │
│        ▼        │
│  ┌───────────┐  │
│  │ Database  │  │
│  │ Batch     │  │
│  │ Writer    │  │
│  └───────────┘  │
└─────────────────┘
```
---

## 2. Core Design Principles

### 2.1 CUDA Compatibility

```python
# Key point: use the spawn start method
import multiprocessing as mp
import os

ctx = mp.get_context("spawn")

# Pin the device when the GPU worker initializes
def init_gpu_worker(gpu_id: int = 0):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    global _ocr
    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(use_gpu=True, ...)
```

### 2.2 Worker Initialization Pattern

Use the `initializer` parameter to load the model once per worker instead of reloading it for every task:

```python
import multiprocessing as mp
import os
from concurrent.futures import ProcessPoolExecutor

# Module-global that holds the model inside each worker
_ocr = None

def init_worker(use_gpu: bool, gpu_id: int = 0):
    global _ocr
    if use_gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(use_gpu=use_gpu, ...)

# Pass the initializer when creating the pool
pool = ProcessPoolExecutor(
    max_workers=1,
    initializer=init_worker,
    initargs=(True, 0),  # use_gpu=True, gpu_id=0
    mp_context=mp.get_context("spawn")
)
```

### 2.3 Queue Pattern vs as_completed

| Approach | Pros | Cons | Suitable for |
|------|------|------|----------|
| `as_completed()` | Simple; no queue management | All futures must be collected up front; no streaming hand-off between stages | Single-batch fan-out (works across pools) |
| `multiprocessing.Queue` | High performance, flexible | Manual management; deadlock risk | Multi-pool pipelines |
| `Manager().Queue()` | Picklable, shareable across pools | Lower performance | Scenarios that need Pool.map |

**Recommendation**: for the dual-pool scenario, submit to both pools and merge the resulting futures with a single `as_completed()` pass.
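
A minimal, self-contained sketch of that recommendation; the task functions and task lists here are placeholders, not project code:

```python
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

def cpu_fn(task: str) -> str:  # placeholder CPU-side work
    return f"cpu:{task}"

def gpu_fn(task: str) -> str:  # placeholder GPU-side work
    return f"gpu:{task}"

def main() -> None:
    cpu_tasks, gpu_tasks = ["a", "b"], ["c"]
    ctx = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=4, mp_context=ctx) as cpu_pool, \
         ProcessPoolExecutor(max_workers=1, mp_context=ctx) as gpu_pool:
        futures = {cpu_pool.submit(cpu_fn, t): t for t in cpu_tasks}
        futures.update({gpu_pool.submit(gpu_fn, t): t for t in gpu_tasks})
        # as_completed() takes plain Future objects, so futures from two
        # different executors can be drained in one loop.
        for future in as_completed(futures):
            print(futures[future], future.result(timeout=300))

if __name__ == "__main__":  # required under the spawn start method
    main()
```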
---

## 3. Detailed Development Plan

### Phase 1: Refactor the Foundations (2-3 days)

#### 1.1 Create the WorkerPool Abstract Class

```python
# src/processing/worker_pool.py

from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, Future
from dataclasses import dataclass
from typing import Any, Callable, Optional
import multiprocessing as mp

@dataclass
class TaskResult:
    """Container for one task's result."""
    task_id: str
    success: bool
    data: Any
    error: Optional[str] = None
    processing_time: float = 0.0

class WorkerPool(ABC):
    """Abstract base class for worker pools."""

    def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
        self.max_workers = max_workers
        self.use_gpu = use_gpu
        self.gpu_id = gpu_id
        self._executor: Optional[ProcessPoolExecutor] = None

    @abstractmethod
    def get_initializer(self) -> Callable:
        """Return the worker initializer function."""
        pass

    @abstractmethod
    def get_init_args(self) -> tuple:
        """Return the initializer arguments."""
        pass

    def start(self):
        """Start the worker pool."""
        ctx = mp.get_context("spawn")
        self._executor = ProcessPoolExecutor(
            max_workers=self.max_workers,
            mp_context=ctx,
            initializer=self.get_initializer(),
            initargs=self.get_init_args()
        )

    def submit(self, fn: Callable, *args, **kwargs) -> Future:
        """Submit a task."""
        if not self._executor:
            raise RuntimeError("Pool not started")
        return self._executor.submit(fn, *args, **kwargs)

    def shutdown(self, wait: bool = True):
        """Shut the pool down."""
        if self._executor:
            self._executor.shutdown(wait=wait)
            self._executor = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.shutdown()
```
#### 1.2 Implement the CPU and GPU Worker Pools

```python
# src/processing/cpu_pool.py
# (init_cpu_worker / init_gpu_worker are the initializer functions from §2.)

class CPUWorkerPool(WorkerPool):
    """CPU-only worker pool for text PDF processing"""

    def __init__(self, max_workers: int = 4):
        super().__init__(max_workers=max_workers, use_gpu=False)

    def get_initializer(self) -> Callable:
        return init_cpu_worker

    def get_init_args(self) -> tuple:
        return ()


# src/processing/gpu_pool.py

class GPUWorkerPool(WorkerPool):
    """GPU worker pool for OCR processing"""

    def __init__(self, max_workers: int = 1, gpu_id: int = 0):
        super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)

    def get_initializer(self) -> Callable:
        return init_gpu_worker

    def get_init_args(self) -> tuple:
        return (self.gpu_id,)
```
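
Usage then follows the context-manager protocol defined on `WorkerPool`. A sketch; the task function is a placeholder and must live at module level so the spawn start method can pickle it:

```python
def count_pages(pdf_path: str) -> int:  # placeholder; module-level for pickling
    return 1

with CPUWorkerPool(max_workers=4) as pool:
    future = pool.submit(count_pages, "invoice.pdf")
    print(future.result(timeout=60))
```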
---

### Phase 2: Implement the Dual-Pool Coordinator (2-3 days)

#### 2.1 Task Dispatcher

```python
# src/processing/task_dispatcher.py

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Tuple

class TaskType(Enum):
    CPU = auto()  # Text PDF
    GPU = auto()  # Scanned PDF

@dataclass
class Task:
    id: str
    task_type: TaskType
    data: Any

class TaskDispatcher:
    """Dispatches tasks to the appropriate pool based on PDF type."""

    def classify_task(self, doc_info: dict) -> TaskType:
        """Decide whether the document needs OCR."""
        # Classify from PDF characteristics
        if self._is_scanned_pdf(doc_info):
            return TaskType.GPU
        return TaskType.CPU

    def _is_scanned_pdf(self, doc_info: dict) -> bool:
        """Detect whether this is a scanned document."""
        # 1. Check for extractable text
        # 2. Check the image-area ratio
        # 3. Check text density
        pass

    def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
        """Split tasks into CPU and GPU groups."""
        cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
        gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
        return cpu_tasks, gpu_tasks
```
#### 2.2 Dual-Pool Coordinator

```python
# src/processing/dual_pool_coordinator.py

from concurrent.futures import as_completed
from typing import Callable, List, Optional
import logging

# Sibling modules defined earlier in this plan:
from src.processing.cpu_pool import CPUWorkerPool
from src.processing.gpu_pool import GPUWorkerPool
from src.processing.task_dispatcher import Task, TaskDispatcher
from src.processing.worker_pool import TaskResult

logger = logging.getLogger(__name__)

class DualPoolCoordinator:
    """Coordinates the CPU and GPU worker pools."""

    def __init__(
        self,
        cpu_workers: int = 4,
        gpu_workers: int = 1,
        gpu_id: int = 0
    ):
        self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
        self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
        self.dispatcher = TaskDispatcher()

    def __enter__(self):
        self.cpu_pool.start()
        self.gpu_pool.start()
        return self

    def __exit__(self, *args):
        self.cpu_pool.shutdown()
        self.gpu_pool.shutdown()

    def process_batch(
        self,
        documents: List[dict],
        cpu_task_fn: Callable,
        gpu_task_fn: Callable,
        on_result: Optional[Callable[[TaskResult], None]] = None,
        on_error: Optional[Callable[[str, Exception], None]] = None
    ) -> List[TaskResult]:
        """
        Process a batch of documents, dispatching each to the CPU or GPU pool.

        Args:
            documents: documents to process
            cpu_task_fn: task function for CPU documents
            gpu_task_fn: task function for GPU documents
            on_result: optional success callback
            on_error: optional error callback

        Returns:
            A list with one TaskResult per task.
        """
        # Classify tasks
        tasks = [
            Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
            for doc in documents
        ]
        cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)

        logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")

        # Submit tasks to their respective pools
        cpu_futures = {
            self.cpu_pool.submit(cpu_task_fn, t.data): t.id
            for t in cpu_tasks
        }
        gpu_futures = {
            self.gpu_pool.submit(gpu_task_fn, t.data): t.id
            for t in gpu_tasks
        }

        # Collect results
        results = []
        all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())

        for future in as_completed(all_futures):
            task_id = cpu_futures.get(future) or gpu_futures.get(future)
            pool_type = "CPU" if future in cpu_futures else "GPU"

            try:
                data = future.result(timeout=300)  # 5-minute per-task timeout
                result = TaskResult(task_id=task_id, success=True, data=data)
                if on_result:
                    on_result(result)
            except Exception as e:
                logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
                result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
                if on_error:
                    on_error(task_id, e)

            results.append(result)

        return results
```
---

### Phase 3: Integrate into autolabel (1-2 days)

#### 3.1 Modify autolabel.py

```python
# src/cli/autolabel.py
# (csv_files, load_documents_from_csv, save_documents_batch, process_text_pdf
#  and process_scanned_pdf are assumed to come from the existing autolabel module.)

def run_autolabel_dual_pool(args):
    """Run auto-labeling in dual-pool mode."""

    from src.processing.dual_pool_coordinator import DualPoolCoordinator

    # Database batching state
    db_batch = []
    db_batch_size = 100

    def on_result(result: TaskResult):
        """Handle a successful result."""
        nonlocal db_batch
        db_batch.append(result.data)

        if len(db_batch) >= db_batch_size:
            save_documents_batch(db_batch)
            db_batch.clear()

    def on_error(task_id: str, error: Exception):
        """Handle a failure."""
        logger.error(f"Task {task_id} failed: {error}")

    # Create the dual-pool coordinator
    with DualPoolCoordinator(
        cpu_workers=args.cpu_workers or 4,
        gpu_workers=args.gpu_workers or 1,
        gpu_id=0
    ) as coordinator:

        # Process every CSV
        for csv_file in csv_files:
            documents = load_documents_from_csv(csv_file)

            results = coordinator.process_batch(
                documents=documents,
                cpu_task_fn=process_text_pdf,
                gpu_task_fn=process_scanned_pdf,
                on_result=on_result,
                on_error=on_error
            )

            logger.info(f"CSV {csv_file}: {len(results)} processed")

    # Flush the remaining partial batch
    if db_batch:
        save_documents_batch(db_batch)
```
---

### Phase 4: Testing and Validation (1-2 days)

#### 4.1 Unit Tests

```python
# tests/unit/test_dual_pool.py
# (cpu_fn / gpu_fn are module-level test task functions; they must be
#  picklable for the spawn start method.)

import pytest
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult

class TestDualPoolCoordinator:

    def test_cpu_only_batch(self):
        """Batch of CPU-only tasks."""
        with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
            docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
            results = coord.process_batch(docs, cpu_fn, gpu_fn)
            assert len(results) == 10
            assert all(r.success for r in results)

    def test_mixed_batch(self):
        """Mixed CPU/GPU batch."""
        with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
            docs = [
                {"id": "text_1", "type": "text"},
                {"id": "scan_1", "type": "scanned"},
                {"id": "text_2", "type": "text"},
            ]
            results = coord.process_batch(docs, cpu_fn, gpu_fn)
            assert len(results) == 3

    def test_timeout_handling(self):
        """Timeout handling."""
        pass

    def test_error_recovery(self):
        """Error recovery."""
        pass
```
#### 4.2 Integration Tests

```python
# tests/integration/test_autolabel_dual_pool.py

import subprocess

def test_autolabel_with_dual_pool():
    """End-to-end test of dual-pool mode."""
    # Use a small test dataset
    result = subprocess.run([
        "python", "-m", "src.cli.autolabel",
        "--cpu-workers", "2",
        "--gpu-workers", "1",
        "--limit", "50"
    ], capture_output=True)

    assert result.returncode == 0
    # Verify the database records
```
---

## 4. Key Technical Points

### 4.1 Strategies for Avoiding Deadlocks

```python
# 1. Use timeouts
try:
    result = future.result(timeout=300)
except TimeoutError:
    logger.warning("Task timed out")

# 2. Use sentinel values
SENTINEL = object()
queue.put(SENTINEL)  # signal end-of-stream

# 3. Check worker liveness
if not worker.is_alive():
    logger.error("Worker died unexpectedly")
    break

# 4. Drain the queue before joining
while not queue.empty():
    results.append(queue.get_nowait())
worker.join(timeout=5.0)
```
### 4.2 PaddleOCR Specifics

```python
# PaddleOCR must be initialized inside the worker process
def init_paddle_worker(gpu_id: int):
    global _ocr
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Import lazily so the CUDA environment variable takes effect first
    from paddleocr import PaddleOCR
    _ocr = PaddleOCR(
        use_angle_cls=True,
        lang='en',
        use_gpu=True,
        show_log=False,
        # Important: cap GPU memory usage
        gpu_mem=2000  # GPU memory limit (MB)
    )
```
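
Inside the worker, task functions then reuse that module-global instance. A sketch, assuming the PaddleOCR 2.x `ocr()` call signature:

```python
def run_ocr(image_path: str):
    # Runs inside the worker process; _ocr was set up by init_paddle_worker().
    global _ocr
    if _ocr is None:
        raise RuntimeError("worker not initialized")
    # PaddleOCR 2.x (assumption): returns, per page, a list of
    # (bbox, (text, confidence)) entries.
    return _ocr.ocr(image_path, cls=True)
```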
### 4.3 Resource Monitoring

```python
import psutil
import GPUtil

def get_resource_usage():
    """Collect system resource usage."""
    cpu_percent = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()

    gpu_info = []
    for gpu in GPUtil.getGPUs():
        gpu_info.append({
            "id": gpu.id,
            "memory_used": gpu.memoryUsed,
            "memory_total": gpu.memoryTotal,
            "utilization": gpu.load * 100
        })

    return {
        "cpu_percent": cpu_percent,
        "memory_percent": memory.percent,
        "gpu": gpu_info
    }
```
---

## 5. Risks and Mitigations

| Risk | Likelihood | Impact | Mitigation |
|------|--------|------|----------|
| GPU out of memory | Medium | High | Limit GPU workers to 1; set the gpu_mem parameter |
| Zombie processes | Low | High | Add heartbeat checks; restart automatically on timeout |
| Task misclassification | Medium | Medium | Add a fallback: retry on the GPU after a CPU failure |
| Database write bottleneck | Low | Medium | Larger batches; asynchronous writes |

---

## 6. Alternatives

If the approach above still has problems, consider:

### 6.1 Use Ray

```python
import ray

ray.init()

@ray.remote(num_cpus=1)
def cpu_task(data):
    return process_text_pdf(data)

@ray.remote(num_gpus=1)
def gpu_task(data):
    return process_scanned_pdf(data)

# Automatic resource-aware scheduling
futures = [cpu_task.remote(d) for d in cpu_docs]
futures += [gpu_task.remote(d) for d in gpu_docs]
results = ray.get(futures)
```

### 6.2 Single Pool + Dynamic GPU Scheduling

Keep the single-pool model but decide inside each task whether to use the GPU:

```python
def process_document(doc_data):
    if is_scanned_pdf(doc_data):
        # Use the GPU (needs a global lock or semaphore to bound
        # concurrency; see the sketch below)
        with gpu_semaphore:
            return process_with_ocr(doc_data)
    else:
        return process_text_only(doc_data)
```
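
The `gpu_semaphore` above has to be created in the parent and handed to each worker, for example via the pool initializer. A sketch; the three `process_*` / `is_scanned_pdf` helpers stand in for the ones in the snippet above:

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

_gpu_semaphore = None  # set in each worker by _init_worker

def _init_worker(sem):
    global _gpu_semaphore
    _gpu_semaphore = sem

def is_scanned_pdf(doc): return doc.get("type") == "scanned"  # placeholder
def process_with_ocr(doc): return ("ocr", doc["id"])          # placeholder
def process_text_only(doc): return ("text", doc["id"])        # placeholder

def process_document(doc):
    if is_scanned_pdf(doc):
        with _gpu_semaphore:  # at most one OCR task touches the GPU
            return process_with_ocr(doc)
    return process_text_only(doc)

def main():
    ctx = mp.get_context("spawn")
    sem = ctx.BoundedSemaphore(1)  # GPU concurrency = 1
    with ProcessPoolExecutor(max_workers=4, mp_context=ctx,
                             initializer=_init_worker,
                             initargs=(sem,)) as pool:
        docs = [{"id": i, "type": "scanned" if i % 2 else "text"} for i in range(4)]
        futures = [pool.submit(process_document, d) for d in docs]
        print([f.result() for f in futures])

if __name__ == "__main__":
    main()
```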
---

## 7. Timeline Summary

| Phase | Task | Estimated effort |
|------|------|------------|
| Phase 1 | Refactor the foundations | 2-3 days |
| Phase 2 | Dual-pool coordinator | 2-3 days |
| Phase 3 | autolabel integration | 1-2 days |
| Phase 4 | Testing and validation | 1-2 days |
| **Total** | | **6-10 days** |

---

## 8. References

1. [Python concurrent.futures official documentation](https://docs.python.org/3/library/concurrent.futures.html)
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
3. [Super Fast Python - complete ProcessPoolExecutor guide](https://superfastpython.com/processpoolexecutor-in-python/)
4. [PaddleOCR parallel inference documentation](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
5. [AWS - Parallelizing across multiple CPUs/GPUs to speed up deep learning inference at the edge](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
6. [Ray distributed multiprocessing](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)
docs/product-plan-v2.md (new file, 1223 lines; diff suppressed because it is too large)

docs/ux-design-prompt-v2.md (new file, 302 lines)
@@ -0,0 +1,302 @@
# Document Annotation Tool – UX Design Spec v2

## Theme: Warm Graphite (Modern Enterprise)

---

## 1. Design Principles (Updated)

1. **Clarity** – High contrast, but never pure black-on-white
2. **Warm Neutrality** – Slightly warm grays reduce visual fatigue
3. **Focus** – Content-first layouts with restrained accents
4. **Consistency** – Reusable patterns, predictable behavior
5. **Professional Trust** – Calm, serious, enterprise-ready
6. **Longevity** – No trendy colors that age quickly

---

## 2. Color Palette (Warm Graphite)

### Core Colors

| Usage | Color Name | Hex |
|------|-----------|-----|
| Primary Text | Soft Black | #121212 |
| Secondary Text | Charcoal Gray | #2A2A2A |
| Muted Text | Warm Gray | #6B6B6B |
| Disabled Text | Light Warm Gray | #9A9A9A |

### Backgrounds

| Usage | Color | Hex |
|-----|------|-----|
| App Background | Paper White | #FAFAF8 |
| Card / Panel | White | #FFFFFF |
| Hover Surface | Subtle Warm Gray | #F1F0ED |
| Selected Row | Very Light Warm Gray | #ECEAE6 |

### Borders & Dividers

| Usage | Color | Hex |
|------|------|-----|
| Default Border | Warm Light Gray | #E6E4E1 |
| Strong Divider | Neutral Gray | #D8D6D2 |

### Semantic States (Muted & Professional)

| State | Color | Hex |
|------|-------|-----|
| Success | Olive Gray | #3E4A3A |
| Error | Brick Gray | #4A3A3A |
| Warning | Sand Gray | #4A4A3A |
| Info | Graphite Gray | #3A3A3A |

> Accent colors are **never saturated** and are used only for status, progress, or selection.

---

## 3. Typography

- **Font Family**: Inter / SF Pro / system-ui
- **Headings**:
  - Weight: 600–700
  - Color: #121212
  - Letter spacing: -0.01em
- **Body Text**:
  - Weight: 400
  - Color: #2A2A2A
- **Captions / Meta**:
  - Weight: 400
  - Color: #6B6B6B
- **Monospace (IDs / Values)**:
  - JetBrains Mono / SF Mono
  - Color: #2A2A2A

---

## 4. Global Layout

### Top Navigation Bar

- Height: 56px
- Background: #FAFAF8
- Bottom Border: 1px solid #E6E4E1
- Logo: Text or icon in #121212

**Navigation Items**
- Default: #6B6B6B
- Hover: #2A2A2A
- Active:
  - Text: #121212
  - Bottom indicator: 2px solid #3A3A3A (rounded ends)

**Avatar**
- Circle background: #ECEAE6
- Text: #2A2A2A

---

## 5. Page: Documents (Dashboard)

### Page Header

- Title: "Documents" (#121212)
- Actions:
  - Primary button: Dark graphite outline
  - Secondary button: Subtle border only

### Filters Bar

- Background: #FFFFFF
- Border: 1px solid #E6E4E1
- Inputs:
  - Background: #FFFFFF
  - Hover: #F1F0ED
  - Focus ring: 1px #3A3A3A

### Document Table

- Table background: #FFFFFF
- Header text: #6B6B6B
- Row hover: #F1F0ED
- Row selected:
  - Background: #ECEAE6
  - Left indicator: 3px solid #3A3A3A

### Status Badges

- Pending:
  - BG: #FFFFFF
  - Border: #D8D6D2
  - Text: #2A2A2A

- Labeled:
  - BG: #2A2A2A
  - Text: #FFFFFF

- Exported:
  - BG: #ECEAE6
  - Text: #2A2A2A
  - Icon: ✓

### Auto-label States

- Running:
  - Progress bar: #3A3A3A on #ECEAE6
- Completed:
  - Text: #3E4A3A
- Failed:
  - BG: #F1EDED
  - Text: #4A3A3A

---
## 6. Upload Modals (Single & Batch)

### Modal Container

- Background: #FFFFFF
- Border radius: 8px
- Shadow: 0 1px 3px rgba(0,0,0,0.08)

### Drop Zone

- Background: #FAFAF8
- Border: 1px dashed #D8D6D2
- Hover: #F1F0ED
- Icon: Graphite gray

### Form Fields

- Input BG: #FFFFFF
- Border: #D8D6D2
- Focus: 1px solid #3A3A3A

Primary Action Button:
- Text: #FFFFFF
- BG: #2A2A2A
- Hover: #121212

---

## 7. Document Detail View

### Canvas Area

- Background: #FFFFFF
- Annotation styles:
  - Manual: Solid border #2A2A2A
  - Auto: Dashed border #6B6B6B
  - Selected: 2px border #3A3A3A + resize handles

### Right Info Panel

- Card background: #FFFFFF
- Section headers: #121212
- Meta text: #6B6B6B

### Annotation Table

- Same table styles as Documents
- Inline edit:
  - Input background: #FAFAF8
  - Save button: Graphite

### Locked State (Auto-label Running)

- Banner BG: #FAFAF8
- Border-left: 3px solid #4A4A3A
- Progress bar: Graphite

---

## 8. Training Page

### Document Selector

- Selected rows use same highlight rules
- Verified state:
  - Full: Olive gray check
  - Partial: Sand gray warning

### Configuration Panel

- Card layout
- Inputs aligned to grid
- Schedule option visually muted until enabled

Primary CTA:
- Start Training button in dark graphite

---

## 9. Models & Training History

### Training Job List

- Job cards use #FFFFFF background
- Running job:
  - Progress bar: #3A3A3A
- Completed job:
  - Metrics bars in graphite

### Model Detail Panel

- Sectioned cards
- Metric bars:
  - Track: #ECEAE6
  - Fill: #3A3A3A

Actions:
- Primary: Download Model
- Secondary: View Logs / Use as Base

---

## 10. Micro-interactions (Refined)

| Element | Interaction | Animation |
|------|------------|-----------|
| Button hover | BG lightens | 150ms ease-out |
| Button press | Scale 0.98 | 100ms |
| Row hover | BG fade | 120ms |
| Modal open | Fade + scale 0.96 → 1 | 200ms |
| Progress fill | Smooth | ease-out |
| Annotation select | Border + handles | 120ms |

---

## 11. Tailwind Theme (Updated)

```js
colors: {
  text: {
    primary: '#121212',
    secondary: '#2A2A2A',
    muted: '#6B6B6B',
    disabled: '#9A9A9A',
  },
  bg: {
    app: '#FAFAF8',
    card: '#FFFFFF',
    hover: '#F1F0ED',
    selected: '#ECEAE6',
  },
  border: '#E6E4E1',
  accent: '#3A3A3A',
  success: '#3E4A3A',
  error: '#4A3A3A',
  warning: '#4A4A3A',
}
```
---

## 12. Final Notes

- Pure black (#000000) should **never** be used for large surfaces
- Accent color usage should stay under **10% of UI area**
- Warm grays are intentional and must not be "corrected" to blue-grays

This theme is designed to scale from internal tool → polished SaaS without redesign.
docs/web-refactoring-complete.md (new file, 273 lines)
@@ -0,0 +1,273 @@
# Web Directory Refactoring - Complete ✅

**Date**: 2026-01-25
**Status**: ✅ Completed
**Tests**: 188 passing (0 failures)
**Coverage**: 23% (maintained)

---

## Final Directory Structure

```
src/web/
├── api/
│   ├── __init__.py
│   └── v1/
│       ├── __init__.py
│       ├── routes.py              # Public inference API
│       ├── admin/
│       │   ├── __init__.py
│       │   ├── documents.py       # Document management (was admin_routes.py)
│       │   ├── annotations.py     # Annotation routes (was admin_annotation_routes.py)
│       │   └── training.py        # Training routes (was admin_training_routes.py)
│       ├── async_api/
│       │   ├── __init__.py
│       │   └── routes.py          # Async processing API (was async_routes.py)
│       └── batch/
│           ├── __init__.py
│           └── routes.py          # Batch upload API (was batch_upload_routes.py)
│
├── schemas/
│   ├── __init__.py
│   ├── common.py                  # Shared models (ErrorResponse)
│   ├── admin.py                   # Admin schemas (was admin_schemas.py)
│   └── inference.py               # Inference + async schemas (was schemas.py)
│
├── services/
│   ├── __init__.py
│   ├── inference.py               # Inference service (was services.py)
│   ├── autolabel.py               # Auto-label service (was admin_autolabel.py)
│   ├── async_processing.py        # Async processing (was async_service.py)
│   └── batch_upload.py            # Batch upload service (was batch_upload_service.py)
│
├── core/
│   ├── __init__.py
│   ├── auth.py                    # Authentication (was admin_auth.py)
│   ├── rate_limiter.py            # Rate limiting (unchanged)
│   └── scheduler.py               # Task scheduler (was admin_scheduler.py)
│
├── workers/
│   ├── __init__.py
│   ├── async_queue.py             # Async task queue (moved as-is)
│   └── batch_queue.py             # Batch task queue (moved as-is)
│
├── __init__.py                    # Main exports
├── app.py                         # FastAPI app (imports updated)
├── config.py                      # Configuration (unchanged)
└── dependencies.py                # Global dependencies (unchanged)
```

---

## Changes Summary

### Files Moved and Renamed

| Old Location | New Location | Change Type |
|-------------|--------------|-------------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
| `admin_auth.py` | `core/auth.py` | Moved |
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
| `admin_schemas.py` | `schemas/admin.py` | Moved |
| `routes.py` | `api/v1/routes.py` | Moved |
| `schemas.py` | `schemas/inference.py` | Moved |
| `services.py` | `services/inference.py` | Moved |
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
| `async_queue.py` | `workers/async_queue.py` | Moved |
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |

**Total**: 16 files reorganized

### Files Updated

**Source Files** (imports updated):
- `app.py` - Updated all imports to new structure
- `api/v1/admin/documents.py` - Updated schema/auth imports
- `api/v1/admin/annotations.py` - Updated schema/service imports
- `api/v1/admin/training.py` - Updated schema/auth imports
- `api/v1/routes.py` - Updated schema imports
- `api/v1/async_api/routes.py` - Updated schema imports
- `api/v1/batch/routes.py` - Updated service/worker imports
- `services/async_processing.py` - Updated worker/core imports

**Test Files** (all 16 updated, including `conftest.py`):
- `test_admin_annotations.py`
- `test_admin_auth.py`
- `test_admin_routes.py`
- `test_admin_routes_enhanced.py`
- `test_admin_training.py`
- `test_annotation_locks.py`
- `test_annotation_phase5.py`
- `test_async_queue.py`
- `test_async_routes.py`
- `test_async_service.py`
- `test_autolabel_with_locks.py`
- `test_batch_queue.py`
- `test_batch_upload_routes.py`
- `test_batch_upload_service.py`
- `test_training_phase4.py`
- `conftest.py`

---

## Import Examples

### Old Import Style (Before Refactoring)
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
from src.web.async_routes import create_async_router
from src.web.schemas import ErrorResponse
```

### New Import Style (After Refactoring)
```python
# Admin API
from src.web.api.v1.admin.documents import create_admin_router
from src.web.api.v1.admin import create_admin_router  # Shorter alternative

# Schemas
from src.web.schemas.admin import DocumentItem
from src.web.schemas.common import ErrorResponse

# Core components
from src.web.core.auth import validate_admin_token

# Async API
from src.web.api.v1.async_api.routes import create_async_router
```
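
The shorter alternative presumably works because the package `__init__.py` re-exports the router factories. A sketch of what `api/v1/admin/__init__.py` might contain; only `create_admin_router` is confirmed by the examples above, the other two names are assumptions:

```python
# src/web/api/v1/admin/__init__.py -- sketch; export list is an assumption.
from src.web.api.v1.admin.documents import create_admin_router
from src.web.api.v1.admin.annotations import create_annotation_router  # assumed name
from src.web.api.v1.admin.training import create_training_router       # assumed name

__all__ = [
    "create_admin_router",
    "create_annotation_router",
    "create_training_router",
]
```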
---

## Benefits Achieved

### 1. **Clear Separation of Concerns**
- **API Routes**: All in `api/v1/` by version and feature
- **Data Models**: All in `schemas/` by domain
- **Business Logic**: All in `services/`
- **Core Components**: Reusable utilities in `core/`
- **Background Jobs**: Task queues in `workers/`

### 2. **Better Scalability**
- Easy to add API v2 without touching v1
- Clear namespace for each module
- Reduced file sizes (no 800+ line files)
- Follows single responsibility principle

### 3. **Improved Maintainability**
- Find files by function, not by prefix
- Each module has one clear purpose
- Easier to onboard new developers
- Better IDE navigation

### 4. **Standards Compliance**
- Follows FastAPI best practices
- Matches Django/Flask project structures
- Standard Python package organization
- Industry-standard naming conventions

---

## Testing Results

**Before Refactoring**:
- 188 tests passing
- 23% code coverage
- Flat directory structure

**After Refactoring**:
- ✅ 188 tests passing (0 failures)
- ✅ 23% code coverage (maintained)
- ✅ Clean hierarchical structure
- ✅ All imports updated
- ✅ No backward compatibility shims needed

---

## Migration Statistics

| Metric | Count |
|--------|-------|
| Files moved | 16 |
| Directories created | 9 |
| Files updated (source) | 8 |
| Files updated (tests) | 16 |
| Import statements updated | ~150 |
| Lines of code changed | ~200 |
| Tests broken | 0 |
| Coverage lost | 0% |

---

## Code Diff Summary

```diff
Before:
src/web/
├── admin_routes.py (645 lines)
├── admin_annotation_routes.py (504 lines)
├── admin_training_routes.py (565 lines)
├── admin_auth.py (22 lines)
├── admin_schemas.py (262 lines)
... (15 more files at root level)

After:
src/web/
├── api/v1/
│   ├── admin/ (3 route files)
│   ├── async_api/ (1 route file)
│   └── batch/ (1 route file)
├── schemas/ (3 schema files)
├── services/ (4 service files)
├── core/ (3 core files)
└── workers/ (2 worker files)
```

---

## Next Steps (Optional)

### Phase 2: Documentation
- [ ] Update API documentation with new import paths
- [ ] Create migration guide for external developers
- [ ] Update CLAUDE.md with new structure

### Phase 3: Further Optimization
- [ ] Split large files (>400 lines) if needed
- [ ] Extract common utilities
- [ ] Add typing stubs

### Phase 4: Deprecation (Future)
- [ ] Add deprecation warnings if creating compatibility layer
- [ ] Remove old imports after grace period
- [ ] Update all documentation

---

## Rollback Instructions

If needed, rollback is simple:
```bash
git revert <commit-hash>
```

All changes are in version control, making rollback safe and easy.

---

## Conclusion

✅ **Refactoring completed successfully**
✅ **Zero breaking changes**
✅ **All tests passing**
✅ **Industry-standard structure achieved**

The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.
docs/web-refactoring-plan.md (new file, 186 lines)
@@ -0,0 +1,186 @@
# Web Directory Refactoring Plan

## Current Structure Issues

1. **Flat structure**: All files in one directory (20 Python files)
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
3. **Mixed concerns**: Routes, schemas, services, and workers in the same directory
4. **Poor scalability**: Hard to navigate and maintain as the project grows

## Proposed Structure (Best Practices)

```
src/web/
├── __init__.py                    # Main exports
├── app.py                         # FastAPI app factory
├── config.py                      # App configuration
├── dependencies.py                # Global dependencies
│
├── api/                           # API Routes Layer
│   ├── __init__.py
│   └── v1/                        # API version 1
│       ├── __init__.py
│       ├── routes.py              # Public API routes (inference)
│       ├── admin/                 # Admin API routes
│       │   ├── __init__.py
│       │   ├── documents.py       # admin_routes.py → documents.py
│       │   ├── annotations.py     # admin_annotation_routes.py → annotations.py
│       │   ├── training.py        # admin_training_routes.py → training.py
│       │   └── auth.py            # admin_auth.py → auth.py (routes only)
│       ├── async_api/             # Async processing API
│       │   ├── __init__.py
│       │   └── routes.py          # async_routes.py → routes.py
│       └── batch/                 # Batch upload API
│           ├── __init__.py
│           └── routes.py          # batch_upload_routes.py → routes.py
│
├── schemas/                       # Pydantic Models
│   ├── __init__.py
│   ├── common.py                  # Shared schemas (ErrorResponse, etc.)
│   ├── inference.py               # schemas.py → inference.py
│   ├── admin.py                   # admin_schemas.py → admin.py
│   ├── async_api.py               # New: async API schemas
│   └── batch.py                   # New: batch upload schemas
│
├── services/                      # Business Logic Layer
│   ├── __init__.py
│   ├── inference.py               # services.py → inference.py
│   ├── autolabel.py               # admin_autolabel.py → autolabel.py
│   ├── async_processing.py        # async_service.py → async_processing.py
│   └── batch_upload.py            # batch_upload_service.py → batch_upload.py
│
├── core/                          # Core Components
│   ├── __init__.py
│   ├── auth.py                    # admin_auth.py → auth.py (logic only)
│   ├── rate_limiter.py            # rate_limiter.py (moved as-is)
│   └── scheduler.py               # admin_scheduler.py → scheduler.py
│
└── workers/                       # Background Task Queues
    ├── __init__.py
    ├── async_queue.py             # async_queue.py (moved as-is)
    └── batch_queue.py             # batch_queue.py (moved as-is)
```

## File Mapping

### Current → New Location

| Current File | New Location | Purpose |
|--------------|--------------|---------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
| `routes.py` | `api/v1/routes.py` | Public inference API |
| `schemas.py` | `schemas/inference.py` | Inference models |
| `services.py` | `services/inference.py` | Inference service |
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
| `async_service.py` | `services/async_processing.py` | Async processing service |
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
| `config.py` | `config.py` | Keep as-is |
| `dependencies.py` | `dependencies.py` | Keep as-is |
| `app.py` | `app.py` | Keep as-is (update imports) |

## Benefits

### 1. Clear Separation of Concerns
- **Routes**: API endpoint definitions
- **Schemas**: Data validation models
- **Services**: Business logic
- **Core**: Reusable components
- **Workers**: Background processing

### 2. Better Scalability
- Easy to add new API versions (`v2/`)
- Clear namespace for each domain
- Reduced file size (no 800+ line files)

### 3. Improved Maintainability
- Find files by function, not by prefix
- Each module has a single responsibility
- Easier to write focused tests

### 4. Standard Python Patterns
- Package-based organization
- Follows FastAPI best practices
- Similar to Django/Flask structures

## Implementation Steps

### Phase 1: Create New Structure (No Breaking Changes)
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
2. Copy files to new locations (don't delete originals yet; see the script sketch after these steps)
3. Update imports in new files
4. Add `__init__.py` with proper exports

### Phase 2: Update Tests
5. Update test imports to use new structure
6. Run tests to verify nothing breaks
7. Fix any import issues

### Phase 3: Update Main App
8. Update `app.py` to import from new locations
9. Run full test suite
10. Verify all endpoints work

### Phase 4: Cleanup
11. Delete old files
12. Update documentation
13. Final test run
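
A minimal sketch of the migration script mentioned under Next Steps below, covering Phases 1-3. It uses `git mv` rather than Phase 1's copy-then-delete so that file history is preserved; the move and rewrite maps are abbreviated to two entries each and should be extended from the File Mapping table:

```python
# migrate_web.py -- sketch only; extend MOVES/REWRITES from the mapping table.
import subprocess
from pathlib import Path

MOVES = {
    "src/web/admin_routes.py": "src/web/api/v1/admin/documents.py",
    "src/web/admin_schemas.py": "src/web/schemas/admin.py",
    # ... remaining rows of the mapping table
}

REWRITES = {
    "src.web.admin_routes": "src.web.api.v1.admin.documents",
    "src.web.admin_schemas": "src.web.schemas.admin",
}

def migrate() -> None:
    for old, new in MOVES.items():
        Path(new).parent.mkdir(parents=True, exist_ok=True)
        subprocess.run(["git", "mv", old, new], check=True)  # keeps history
    # Rewrite import paths across source and tests
    for path in [*Path("src").rglob("*.py"), *Path("tests").rglob("*.py")]:
        text = path.read_text()
        for old, new in REWRITES.items():
            text = text.replace(old, new)
        path.write_text(text)

if __name__ == "__main__":
    migrate()
```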
## Migration Priority

**High Priority** (Most used):
- Routes and schemas (user-facing APIs)
- Services (core business logic)

**Medium Priority**:
- Core components (auth, rate limiter)
- Workers (background tasks)

**Low Priority**:
- Config and dependencies (already well-located)

## Backwards Compatibility

During migration, maintain backwards compatibility:

```python
# src/web/__init__.py
# Old imports still work
from src.web.api.v1.admin.documents import router as admin_router
from src.web.schemas.admin import AdminDocument

# Keep old names for compatibility (temporary)
admin_routes = admin_router  # Deprecated alias
```

## Testing Strategy

1. **Unit Tests**: Test each module independently
2. **Integration Tests**: Test API endpoints still work
3. **Import Tests**: Verify all old imports still work
4. **Coverage**: Maintain current 23% coverage minimum

## Rollback Plan

If issues arise:
1. Keep old files until fully migrated
2. Git allows easy revert
3. Tests catch breaking changes early

---

## Next Steps

Would you like me to:
1. **Start Phase 1**: Create new directory structure and move files?
2. **Create migration script**: Automate the file moves and import updates?
3. **Focus on specific area**: Start with admin API or async API first?
docs/web-refactoring-status.md (new file, 218 lines)
@@ -0,0 +1,218 @@
# Web Directory Refactoring - Current Status

## ✅ Completed Steps

### 1. Directory Structure Created
```
src/web/
├── api/
│   ├── v1/
│   │   ├── admin/       (documents.py, annotations.py, training.py)
│   │   ├── async_api/   (routes.py)
│   │   ├── batch/       (routes.py)
│   │   └── routes.py    (public inference API)
├── schemas/
│   ├── admin.py         (admin schemas)
│   ├── inference.py     (inference + async schemas)
│   └── common.py        (ErrorResponse)
├── services/
│   ├── autolabel.py
│   ├── async_processing.py
│   ├── batch_upload.py
│   └── inference.py
├── core/
│   ├── auth.py
│   ├── rate_limiter.py
│   └── scheduler.py
└── workers/
    ├── async_queue.py
    └── batch_queue.py
```

### 2. Files Copied and Imports Updated

#### Admin API (✅ Complete)
- [x] `admin_routes.py` → `api/v1/admin/documents.py` (imports updated)
- [x] `admin_annotation_routes.py` → `api/v1/admin/annotations.py` (imports updated)
- [x] `admin_training_routes.py` → `api/v1/admin/training.py` (imports updated)
- [x] `api/v1/admin/__init__.py` created with exports

#### Public & Async API (✅ Complete)
- [x] `routes.py` → `api/v1/routes.py` (imports updated)
- [x] `async_routes.py` → `api/v1/async_api/routes.py` (imports updated)
- [x] `batch_upload_routes.py` → `api/v1/batch/routes.py` (copied, imports pending)

#### Schemas (✅ Complete)
- [x] `admin_schemas.py` → `schemas/admin.py`
- [x] `schemas.py` → `schemas/inference.py`
- [x] `schemas/common.py` created
- [x] `schemas/__init__.py` created with exports

#### Services (✅ Complete)
- [x] `admin_autolabel.py` → `services/autolabel.py`
- [x] `async_service.py` → `services/async_processing.py`
- [x] `batch_upload_service.py` → `services/batch_upload.py`
- [x] `services.py` → `services/inference.py`
- [x] `services/__init__.py` created

#### Core Components (✅ Complete)
- [x] `admin_auth.py` → `core/auth.py`
- [x] `rate_limiter.py` → `core/rate_limiter.py`
- [x] `admin_scheduler.py` → `core/scheduler.py`
- [x] `core/__init__.py` created

#### Workers (✅ Complete)
- [x] `async_queue.py` → `workers/async_queue.py`
- [x] `batch_queue.py` → `workers/batch_queue.py`
- [x] `workers/__init__.py` created

#### Main App (✅ Complete)
- [x] `app.py` imports updated to use new structure

---

## ⏳ Remaining Work

### 1. Update Remaining File Imports (HIGH PRIORITY)

Files that need import updates:
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
- [ ] `services/autolabel.py` - may need import updates if it references old paths
- [ ] `services/async_processing.py` - check for old import references
- [ ] `services/batch_upload.py` - check for old import references
- [ ] `services/inference.py` - check for old import references

### 2. Update ALL Test Files (CRITICAL)

Test files need to import from new locations. Pattern:

**Old:**
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
```

**New:**
```python
from src.web.api.v1.admin import create_admin_router
from src.web.schemas.admin import DocumentItem
from src.web.core.auth import validate_admin_token
```
|
||||||
|
|
||||||
|
Test files to update:
|
||||||
|
- [ ] `tests/web/test_admin_annotations.py`
|
||||||
|
- [ ] `tests/web/test_admin_auth.py`
|
||||||
|
- [ ] `tests/web/test_admin_routes.py`
|
||||||
|
- [ ] `tests/web/test_admin_routes_enhanced.py`
|
||||||
|
- [ ] `tests/web/test_admin_training.py`
|
||||||
|
- [ ] `tests/web/test_annotation_locks.py`
|
||||||
|
- [ ] `tests/web/test_annotation_phase5.py`
|
||||||
|
- [ ] `tests/web/test_async_queue.py`
|
||||||
|
- [ ] `tests/web/test_async_routes.py`
|
||||||
|
- [ ] `tests/web/test_async_service.py`
|
||||||
|
- [ ] `tests/web/test_autolabel_with_locks.py`
|
||||||
|
- [ ] `tests/web/test_batch_queue.py`
|
||||||
|
- [ ] `tests/web/test_batch_upload_routes.py`
|
||||||
|
- [ ] `tests/web/test_batch_upload_service.py`
|
||||||
|
- [ ] `tests/web/test_rate_limiter.py`
|
||||||
|
- [ ] `tests/web/test_training_phase4.py`
|
||||||
|
|
||||||
|
### 3. Create Backward Compatibility Layer (OPTIONAL)
|
||||||
|
|
||||||
|
Keep old imports working temporarily:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# src/web/admin_routes.py (temporary compatibility shim)
|
||||||
|
\"\"\"
|
||||||
|
DEPRECATED: Use src.web.api.v1.admin.documents instead.
|
||||||
|
This file will be removed in next version.
|
||||||
|
\"\"\"
|
||||||
|
import warnings
|
||||||
|
from src.web.api.v1.admin.documents import *
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"Importing from src.web.admin_routes is deprecated. "
|
||||||
|
"Use src.web.api.v1.admin.documents instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Verify and Test
|
||||||
|
|
||||||
|
1. Run tests:
|
||||||
|
```bash
|
||||||
|
pytest tests/web/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Check for any import errors:
|
||||||
|
```bash
|
||||||
|
python -c "from src.web.app import create_app; create_app()"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Start server and test endpoints:
|
||||||
|
```bash
|
||||||
|
python run_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
|
||||||
|
|
||||||
|
Old files to remove:
|
||||||
|
- `src/web/admin_*.py` (7 files)
|
||||||
|
- `src/web/async_*.py` (3 files)
|
||||||
|
- `src/web/batch_*.py` (3 files)
|
||||||
|
- `src/web/routes.py`
|
||||||
|
- `src/web/services.py`
|
||||||
|
- `src/web/schemas.py`
|
||||||
|
- `src/web/rate_limiter.py`
|
||||||
|
|
||||||
|
Keep these files (don't remove):
|
||||||
|
- `src/web/__init__.py`
|
||||||
|
- `src/web/app.py`
|
||||||
|
- `src/web/config.py`
|
||||||
|
- `src/web/dependencies.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎯 Next Immediate Steps
|
||||||
|
|
||||||
|
1. **Update batch/routes.py imports** - Quick fix for remaining API route
|
||||||
|
2. **Update test file imports** - Critical for verification
|
||||||
|
3. **Run test suite** - Verify nothing broke
|
||||||
|
4. **Fix any import errors** - Address failures
|
||||||
|
5. **Remove old files** - Clean up after tests pass
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📊 Migration Impact Summary
|
||||||
|
|
||||||
|
| Category | Files Moved | Imports Updated | Status |
|
||||||
|
|----------|-------------|-----------------|--------|
|
||||||
|
| API Routes | 7 | 5/7 | 🟡 In Progress |
|
||||||
|
| Schemas | 3 | 3/3 | ✅ Complete |
|
||||||
|
| Services | 4 | 0/4 | ⚠️ Pending |
|
||||||
|
| Core | 3 | 3/3 | ✅ Complete |
|
||||||
|
| Workers | 2 | 2/2 | ✅ Complete |
|
||||||
|
| Tests | 0 | 0/16 | ❌ Not Started |
|
||||||
|
|
||||||
|
**Overall Progress: 65%**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🚀 Benefits After Migration
|
||||||
|
|
||||||
|
1. **Better Organization**: Clear separation by function
|
||||||
|
2. **Easier Navigation**: Find files by purpose, not prefix
|
||||||
|
3. **Scalability**: Easy to add new API versions
|
||||||
|
4. **Standard Structure**: Follows FastAPI best practices
|
||||||
|
5. **Maintainability**: Each module has single responsibility
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📝 Notes
|
||||||
|
|
||||||
|
- All original files are still in place (no data loss risk)
|
||||||
|
- New structure is operational but needs import updates
|
||||||
|
- Backward compatibility can be added if needed
|
||||||
|
- Tests will validate the migration success
|
||||||
frontend/.env.example (new file, 5 lines)
@@ -0,0 +1,5 @@
# Backend API URL
VITE_API_URL=http://localhost:8000

# WebSocket URL (for future real-time updates)
VITE_WS_URL=ws://localhost:8000/ws

frontend/.gitignore (new file, vendored, 24 lines)
@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

frontend/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
<div align="center">
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
</div>

# Run and deploy your AI Studio app

This contains everything you need to run your app locally.

View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4

## Run Locally

**Prerequisites:** Node.js

1. Install dependencies:
   `npm install`
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
3. Run the app:
   `npm run dev`

frontend/REFACTORING_PLAN.md (new file, 240 lines)
@@ -0,0 +1,240 @@
# Frontend Refactoring Plan

## Current Structure Issues

1. **Flat component organization** - All components in one directory
2. **Mock data only** - No real API integration
3. **No state management** - Props drilling everywhere
4. **CDN dependencies** - Should use npm packages
5. **Manual routing** - Using useState instead of react-router
6. **No TypeScript integration with backend** - Types don't match API schemas

## Recommended Structure

```
frontend/
├── public/
│   └── favicon.ico
│
├── src/
│   ├── api/                    # API Layer
│   │   ├── client.ts           # Axios instance + interceptors
│   │   ├── types.ts            # API request/response types
│   │   └── endpoints/
│   │       ├── documents.ts    # GET /api/v1/admin/documents
│   │       ├── annotations.ts  # GET/POST /api/v1/admin/documents/{id}/annotations
│   │       ├── training.ts     # GET/POST /api/v1/admin/training/*
│   │       ├── inference.ts    # POST /api/v1/infer
│   │       └── async.ts        # POST /api/v1/async/submit
│   │
│   ├── components/
│   │   ├── common/             # Reusable components
│   │   │   ├── Badge.tsx
│   │   │   ├── Button.tsx
│   │   │   ├── Input.tsx
│   │   │   ├── Modal.tsx
│   │   │   ├── Table.tsx
│   │   │   ├── ProgressBar.tsx
│   │   │   └── StatusBadge.tsx
│   │   │
│   │   ├── layout/             # Layout components
│   │   │   ├── TopNav.tsx
│   │   │   ├── Sidebar.tsx
│   │   │   └── PageHeader.tsx
│   │   │
│   │   ├── documents/          # Document-specific components
│   │   │   ├── DocumentTable.tsx
│   │   │   ├── DocumentFilters.tsx
│   │   │   ├── DocumentRow.tsx
│   │   │   ├── UploadModal.tsx
│   │   │   └── BatchUploadModal.tsx
│   │   │
│   │   ├── annotations/        # Annotation components
│   │   │   ├── AnnotationCanvas.tsx
│   │   │   ├── AnnotationBox.tsx
│   │   │   ├── AnnotationTable.tsx
│   │   │   ├── FieldEditor.tsx
│   │   │   └── VerificationPanel.tsx
│   │   │
│   │   └── training/           # Training components
│   │       ├── DocumentSelector.tsx
│   │       ├── TrainingConfig.tsx
│   │       ├── TrainingJobList.tsx
│   │       ├── ModelCard.tsx
│   │       └── MetricsChart.tsx
│   │
│   ├── pages/                  # Page-level components
│   │   ├── DocumentsPage.tsx         # Was Dashboard.tsx
│   │   ├── DocumentDetailPage.tsx    # Was DocumentDetail.tsx
│   │   ├── TrainingPage.tsx          # Was Training.tsx
│   │   ├── ModelsPage.tsx            # Was Models.tsx
│   │   └── InferencePage.tsx         # New: Test inference
│   │
│   ├── hooks/                  # Custom React Hooks
│   │   ├── useDocuments.ts     # Document CRUD + listing
│   │   ├── useAnnotations.ts   # Annotation management
│   │   ├── useTraining.ts      # Training jobs
│   │   ├── usePolling.ts       # Auto-refresh for async jobs
│   │   └── useDebounce.ts      # Debounce search inputs
│   │
│   ├── store/                  # State Management (Zustand)
│   │   ├── documentsStore.ts
│   │   ├── annotationsStore.ts
│   │   ├── trainingStore.ts
│   │   └── uiStore.ts
│   │
│   ├── types/                  # TypeScript Types
│   │   ├── index.ts
│   │   ├── document.ts
│   │   ├── annotation.ts
│   │   ├── training.ts
│   │   └── api.ts
│   │
│   ├── utils/                  # Utility Functions
│   │   ├── formatters.ts       # Date, currency, etc.
│   │   ├── validators.ts       # Form validation
│   │   └── constants.ts        # Field definitions, statuses
│   │
│   ├── styles/
│   │   └── index.css           # Tailwind entry
│   │
│   ├── App.tsx
│   ├── main.tsx
│   └── router.tsx              # React Router config
│
├── .env.example
├── package.json
├── tsconfig.json
├── vite.config.ts
├── tailwind.config.js
├── postcss.config.js
└── index.html
```

## Migration Steps

### Phase 1: Setup Infrastructure
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
- [ ] Setup local Tailwind (remove CDN)
- [ ] Create API client with interceptors
- [ ] Add environment variables (.env.local with VITE_API_URL)

### Phase 2: Create API Layer
- [ ] Create `src/api/client.ts` with axios instance
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
- [ ] Create `src/api/endpoints/annotations.ts`
- [ ] Create `src/api/endpoints/training.ts`
- [ ] Add types matching backend schemas

### Phase 3: Reorganize Components
- [ ] Move existing components to new structure
- [ ] Split large components (Dashboard > DocumentTable + DocumentFilters + DocumentRow)
- [ ] Extract reusable components (Badge, Button already done)
- [ ] Create layout components (TopNav, Sidebar)

### Phase 4: Add Routing
- [ ] Install react-router-dom
- [ ] Create router.tsx with routes (see the sketch after this list)
- [ ] Update App.tsx to use RouterProvider
- [ ] Add navigation links
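
A minimal sketch of what the router could look like, assuming `react-router-dom` v6 and the page components named above (paths, filenames, and default exports are illustrative, not final):

```tsx
// src/router.tsx, sketch only; assumes the pages above use default exports.
import { createBrowserRouter } from 'react-router-dom'
import DocumentsPage from './pages/DocumentsPage'
import DocumentDetailPage from './pages/DocumentDetailPage'
import TrainingPage from './pages/TrainingPage'
import ModelsPage from './pages/ModelsPage'

// One route per page; the document id lands in useParams() inside DocumentDetailPage.
export const router = createBrowserRouter([
  { path: '/', element: <DocumentsPage /> },
  { path: '/documents/:docId', element: <DocumentDetailPage /> },
  { path: '/training', element: <TrainingPage /> },
  { path: '/models', element: <ModelsPage /> },
])
```

App.tsx then shrinks to a thin `<RouterProvider router={router} />` wrapper.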

### Phase 5: State Management
- [ ] Create custom hooks (useDocuments, useAnnotations); see the sketch after this list
- [ ] Use @tanstack/react-query for server state
- [ ] Add Zustand stores for UI state
- [ ] Replace mock data with API calls
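
Once the API layer exists, the hook layer stays thin. A minimal sketch of `useDocuments` under these assumptions (the query-key shape and the returned fields are suggestions, not a final API):

```typescript
// src/hooks/useDocuments.ts, sketch only; assumes the documentsApi module from Phase 2.
import { useQuery, useQueryClient, useMutation } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints/documents'

export function useDocuments(params?: { status?: string; limit?: number; offset?: number }) {
  const queryClient = useQueryClient()

  // Server state lives in react-query; components never fetch directly.
  const query = useQuery({
    queryKey: ['documents', params],
    queryFn: () => documentsApi.list(params),
  })

  // Mutations invalidate the list so the table refreshes automatically.
  const deleteMutation = useMutation({
    mutationFn: (id: string) => documentsApi.delete(id),
    onSuccess: () => queryClient.invalidateQueries({ queryKey: ['documents'] }),
  })

  return {
    documents: query.data?.documents ?? [],
    total: query.data?.total ?? 0,
    isLoading: query.isLoading,
    deleteDocument: deleteMutation.mutate,
  }
}
```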

### Phase 6: Backend Integration
- [ ] Update CORS settings in backend
- [ ] Test all API endpoints
- [ ] Add error handling
- [ ] Add loading states

## Dependencies to Add

```json
{
  "dependencies": {
    "react-router-dom": "^6.22.0",
    "axios": "^1.6.7",
    "zustand": "^4.5.0",
    "@tanstack/react-query": "^5.20.0",
    "date-fns": "^3.3.0",
    "clsx": "^2.1.0"
  },
  "devDependencies": {
    "tailwindcss": "^3.4.1",
    "autoprefixer": "^10.4.17",
    "postcss": "^8.4.35"
  }
}
```

## Configuration Files to Create

### tailwind.config.js
```javascript
export default {
  content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
  theme: {
    extend: {
      colors: {
        warm: {
          bg: '#FAFAF8',
          card: '#FFFFFF',
          hover: '#F1F0ED',
          selected: '#ECEAE6',
          border: '#E6E4E1',
          divider: '#D8D6D2',
          text: {
            primary: '#121212',
            secondary: '#2A2A2A',
            muted: '#6B6B6B',
            disabled: '#9A9A9A',
          },
          state: {
            success: '#3E4A3A',
            error: '#4A3A3A',
            warning: '#4A4A3A',
            info: '#3A3A3A',
          }
        }
      }
    }
  }
}
```

### .env.example
```bash
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```

## Type Generation from Backend

Consider generating TypeScript types from Python Pydantic schemas:
- Option 1: Use `datamodel-code-generator` to convert schemas
- Option 2: Manually maintain types in `src/types/api.ts` (example below)
- Option 3: Use OpenAPI spec + openapi-typescript-codegen
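
Whichever option wins, the mapping itself is mechanical. A hedged example of Option 2, with a simplified, assumed Pydantic schema on the Python side (not the actual backend model):

```typescript
// Hand-maintained mirror of a backend Pydantic schema (Option 2).
// Assumed Python side (simplified, illustrative):
//   class DocumentItem(BaseModel):
//       document_id: str
//       status: Literal["pending", "labeled", "verified", "exported"]
//       annotation_count: int | None = None
export interface DocumentItem {
  document_id: string
  status: 'pending' | 'labeled' | 'verified' | 'exported'
  annotation_count?: number | null
}
```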

## Testing Strategy

- Unit tests: Vitest for components (see the sketch after this list)
- Integration tests: React Testing Library
- E2E tests: Playwright (matching backend)
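
For the unit level, a minimal sketch, assuming Vitest and @testing-library/react are added as dev dependencies (neither is in the dependency list above yet):

```tsx
// Button.test.tsx, sketch only; assumes vitest + @testing-library/react are installed.
import { render, screen, fireEvent } from '@testing-library/react'
import { describe, it, expect, vi } from 'vitest'
import { Button } from '../components/Button'

describe('Button', () => {
  it('renders its label and fires onClick', () => {
    const onClick = vi.fn()
    render(<Button onClick={onClick}>Save</Button>)
    fireEvent.click(screen.getByText('Save'))
    expect(onClick).toHaveBeenCalledTimes(1)
  })
})
```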

## Performance Considerations

- Code splitting by route
- Lazy load heavy components (AnnotationCanvas); see the sketch after this list
- Optimize re-renders with React.memo
- Use virtual scrolling for large tables
- Image lazy loading for document previews
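
For the lazy-loading point, a minimal sketch (the import path is illustrative and assumes `AnnotationCanvas` has a default export):

```tsx
// Lazy-loading a heavy component: the canvas bundle is only fetched
// when the detail page actually renders it.
import React, { Suspense, lazy } from 'react'

const AnnotationCanvas = lazy(() => import('./components/annotations/AnnotationCanvas'))

export const DocumentDetailPage: React.FC = () => (
  <Suspense fallback={<div>Loading canvas...</div>}>
    <AnnotationCanvas />
  </Suspense>
)
```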

## Accessibility

- Proper ARIA labels
- Keyboard navigation
- Focus management
- Color contrast compliance (already done with Warm Graphite theme)

frontend/SETUP.md (new file, 256 lines)
@@ -0,0 +1,256 @@
# Frontend Setup Guide

## Quick Start

### 1. Install Dependencies

```bash
cd frontend
npm install
```

### 2. Configure Environment

Copy `.env.example` to `.env.local` and update if needed:

```bash
cp .env.example .env.local
```

Default configuration:
```
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```

### 3. Start Backend API

Make sure the backend is running first:

```bash
# From project root
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
```

Backend will be available at: http://localhost:8000

### 4. Start Frontend Dev Server

```bash
cd frontend
npm run dev
```

Frontend will be available at: http://localhost:3000

## Project Structure

```
frontend/
├── src/
│   ├── api/                    # API client layer
│   │   ├── client.ts           # Axios instance with interceptors
│   │   ├── types.ts            # API type definitions
│   │   └── endpoints/
│   │       ├── documents.ts    # Document API calls
│   │       ├── annotations.ts  # Annotation API calls
│   │       └── training.ts     # Training API calls
│   │
│   ├── components/             # React components
│   │   └── Dashboard.tsx       # Updated with real API integration
│   │
│   ├── hooks/                  # Custom React Hooks
│   │   ├── useDocuments.ts
│   │   ├── useDocumentDetail.ts
│   │   ├── useAnnotations.ts
│   │   └── useTraining.ts
│   │
│   ├── styles/
│   │   └── index.css           # Tailwind CSS entry
│   │
│   ├── App.tsx
│   └── main.tsx                # App entry point with QueryClient
│
├── components/                 # Legacy components (to be migrated)
│   ├── Badge.tsx
│   ├── Button.tsx
│   ├── Layout.tsx
│   ├── DocumentDetail.tsx
│   ├── Training.tsx
│   ├── Models.tsx
│   └── UploadModal.tsx
│
├── tailwind.config.js          # Tailwind configuration
├── postcss.config.js
├── vite.config.ts
├── package.json
└── index.html
```

## Key Technologies

- **React 19** - UI framework
- **TypeScript** - Type safety
- **Vite** - Build tool
- **Tailwind CSS** - Styling (Warm Graphite theme)
- **Axios** - HTTP client
- **@tanstack/react-query** - Server state management
- **lucide-react** - Icon library

## API Integration

### Authentication

The app stores the admin token in localStorage:

```typescript
localStorage.setItem('admin_token', 'your-token')
```

All API requests automatically include the `X-Admin-Token` header.
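
For example, once the token is set, any call that goes through the shared client picks up the header without extra code (a minimal sketch; `documentsApi` is the module defined in `src/api/endpoints/documents.ts`):

```typescript
// The request interceptor in src/api/client.ts reads admin_token from
// localStorage and attaches it as X-Admin-Token before the request is sent.
import { documentsApi } from './api/endpoints'

localStorage.setItem('admin_token', 'your-token')

const { documents, total } = await documentsApi.list({ limit: 5 })
console.log(`Fetched ${documents.length} of ${total} documents`)
```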

### Available Hooks

#### useDocuments

```typescript
const {
  documents,
  total,
  isLoading,
  uploadDocument,
  deleteDocument,
  triggerAutoLabel,
} = useDocuments({ status: 'labeled', limit: 20 })
```

#### useDocumentDetail

```typescript
const { document, annotations, isLoading } = useDocumentDetail(documentId)
```

#### useAnnotations

```typescript
const {
  createAnnotation,
  updateAnnotation,
  deleteAnnotation,
  verifyAnnotation,
  overrideAnnotation,
} = useAnnotations(documentId)
```

#### useTraining

```typescript
const {
  models,
  isLoadingModels,
  startTraining,
  downloadModel,
} = useTraining()
```

## Features Implemented

### Phase 1 (Completed)
- ✅ API client with axios interceptors
- ✅ Type-safe API endpoints
- ✅ React Query for server state
- ✅ Custom hooks for all APIs
- ✅ Dashboard with real data
- ✅ Local Tailwind CSS
- ✅ Environment configuration
- ✅ CORS configured in backend

### Phase 2 (TODO)
- [ ] Update DocumentDetail to use useDocumentDetail
- [ ] Update Training page to use useTraining hooks
- [ ] Update Models page with real data
- [ ] Add UploadModal integration with API
- [ ] Add react-router for proper routing
- [ ] Add error boundary
- [ ] Add loading states
- [ ] Add toast notifications

### Phase 3 (TODO)
- [ ] Annotation canvas with real data
- [ ] Batch upload functionality
- [ ] Auto-label progress polling
- [ ] Training job monitoring
- [ ] Model download functionality
- [ ] Search and filtering
- [ ] Pagination

## Development Tips

### Hot Module Replacement

Vite supports HMR. Changes are reflected immediately without a page reload.

### API Debugging

Check the browser console for API requests:
- Network tab shows all requests/responses
- Axios interceptors log errors automatically

### Type Safety

TypeScript types in `src/api/types.ts` match backend Pydantic schemas.

To regenerate types from backend:
```bash
# TODO: Add type generation script
```

### Backend API Documentation

Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).

## Troubleshooting

### CORS Errors

If you see CORS errors:
1. Check that the backend is running at http://localhost:8000
2. Verify CORS settings in `src/web/app.py`
3. Check that `.env.local` has the correct `VITE_API_URL`

### Module Not Found

If imports fail:
```bash
rm -rf node_modules package-lock.json
npm install
```

### Types Not Matching

If API responses don't match types:
1. Check that the backend version is up-to-date
2. Verify types in `src/api/types.ts`
3. Check the API response in the Network tab

## Next Steps

1. Run `npm install` to install dependencies
2. Start the backend server
3. Run `npm run dev` to start the frontend
4. Open http://localhost:3000
5. Create an admin token via the backend API
6. Store the token in localStorage via the browser console:
```javascript
localStorage.setItem('admin_token', 'your-token-here')
```
7. Refresh the page to see authenticated API calls

## Production Build

```bash
npm run build
npm run preview  # Preview production build
```

Build output will be in the `dist/` directory.

frontend/index.html (new file, 15 lines)
@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Graphite Annotator - Invoice Field Extraction</title>
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
</head>
<body>
  <div id="root"></div>
  <script type="module" src="/src/main.tsx"></script>
</body>
</html>

frontend/metadata.json (new file, 5 lines)
@@ -0,0 +1,5 @@
{
  "name": "Graphite Annotator",
  "description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
  "requestFramePermissions": []
}

frontend/package-lock.json (generated, new file, 3510 lines)
Diff suppressed because it is too large.
frontend/package.json (new file, 32 lines)
@@ -0,0 +1,32 @@
{
  "name": "graphite-annotator",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "vite build",
    "preview": "vite preview"
  },
  "dependencies": {
    "react": "^19.2.3",
    "react-dom": "^19.2.3",
    "lucide-react": "^0.563.0",
    "recharts": "^3.7.0",
    "axios": "^1.6.7",
    "react-router-dom": "^6.22.0",
    "zustand": "^4.5.0",
    "@tanstack/react-query": "^5.20.0",
    "date-fns": "^3.3.0",
    "clsx": "^2.1.0"
  },
  "devDependencies": {
    "@types/node": "^22.14.0",
    "@vitejs/plugin-react": "^5.0.0",
    "typescript": "~5.8.2",
    "vite": "^6.2.0",
    "tailwindcss": "^3.4.1",
    "autoprefixer": "^10.4.17",
    "postcss": "^8.4.35"
  }
}

frontend/postcss.config.js (new file, 6 lines)
@@ -0,0 +1,6 @@
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}

frontend/src/App.tsx (new file, 73 lines)
@@ -0,0 +1,73 @@
import React, { useState, useEffect } from 'react'
import { Layout } from './components/Layout'
import { DashboardOverview } from './components/DashboardOverview'
import { Dashboard } from './components/Dashboard'
import { DocumentDetail } from './components/DocumentDetail'
import { Training } from './components/Training'
import { Models } from './components/Models'
import { Login } from './components/Login'
import { InferenceDemo } from './components/InferenceDemo'

const App: React.FC = () => {
  const [currentView, setCurrentView] = useState('dashboard')
  const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
  const [isAuthenticated, setIsAuthenticated] = useState(false)

  useEffect(() => {
    const token = localStorage.getItem('admin_token')
    setIsAuthenticated(!!token)
  }, [])

  const handleNavigate = (view: string, docId?: string) => {
    setCurrentView(view)
    if (docId) {
      setSelectedDocId(docId)
    }
  }

  const handleLogin = (token: string) => {
    setIsAuthenticated(true)
  }

  const handleLogout = () => {
    localStorage.removeItem('admin_token')
    setIsAuthenticated(false)
    setCurrentView('documents')
  }

  if (!isAuthenticated) {
    return <Login onLogin={handleLogin} />
  }

  const renderContent = () => {
    switch (currentView) {
      case 'dashboard':
        return <DashboardOverview onNavigate={handleNavigate} />
      case 'documents':
        return <Dashboard onNavigate={handleNavigate} />
      case 'detail':
        return (
          <DocumentDetail
            docId={selectedDocId || '1'}
            onBack={() => setCurrentView('documents')}
          />
        )
      case 'demo':
        return <InferenceDemo />
      case 'training':
        return <Training />
      case 'models':
        return <Models />
      default:
        return <DashboardOverview onNavigate={handleNavigate} />
    }
  }

  return (
    <Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
      {renderContent()}
    </Layout>
  )
}

export default App

frontend/src/api/client.ts (new file, 41 lines)
@@ -0,0 +1,41 @@
import axios, { AxiosInstance, AxiosError } from 'axios'

const apiClient: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
  headers: {
    'Content-Type': 'application/json',
  },
  timeout: 30000,
})

apiClient.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem('admin_token')
    if (token) {
      config.headers['X-Admin-Token'] = token
    }
    return config
  },
  (error) => {
    return Promise.reject(error)
  }
)

apiClient.interceptors.response.use(
  (response) => response,
  (error: AxiosError) => {
    if (error.response?.status === 401) {
      console.warn('Authentication required. Please set admin_token in localStorage.')
      // Don't redirect to avoid infinite loop
      // User should manually set: localStorage.setItem('admin_token', 'your-token')
    }

    if (error.response?.status === 429) {
      console.error('Rate limit exceeded')
    }

    return Promise.reject(error)
  }
)

export default apiClient

frontend/src/api/endpoints/annotations.ts (new file, 66 lines)
@@ -0,0 +1,66 @@
import apiClient from '../client'
import type {
  AnnotationItem,
  CreateAnnotationRequest,
  AnnotationOverrideRequest,
} from '../types'

export const annotationsApi = {
  list: async (documentId: string): Promise<AnnotationItem[]> => {
    const { data } = await apiClient.get(
      `/api/v1/admin/documents/${documentId}/annotations`
    )
    return data.annotations
  },

  create: async (
    documentId: string,
    annotation: CreateAnnotationRequest
  ): Promise<AnnotationItem> => {
    const { data } = await apiClient.post(
      `/api/v1/admin/documents/${documentId}/annotations`,
      annotation
    )
    return data
  },

  update: async (
    documentId: string,
    annotationId: string,
    updates: Partial<CreateAnnotationRequest>
  ): Promise<AnnotationItem> => {
    const { data } = await apiClient.patch(
      `/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
      updates
    )
    return data
  },

  delete: async (documentId: string, annotationId: string): Promise<void> => {
    await apiClient.delete(
      `/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
    )
  },

  verify: async (
    documentId: string,
    annotationId: string
  ): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
    const { data } = await apiClient.post(
      `/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
    )
    return data
  },

  override: async (
    documentId: string,
    annotationId: string,
    overrideData: AnnotationOverrideRequest
  ): Promise<{ annotation_id: string; source: string; message: string }> => {
    const { data } = await apiClient.patch(
      `/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
      overrideData
    )
    return data
  },
}

frontend/src/api/endpoints/documents.ts (new file, 80 lines)
@@ -0,0 +1,80 @@
import apiClient from '../client'
import type {
  DocumentListResponse,
  DocumentDetailResponse,
  DocumentItem,
  UploadDocumentResponse,
} from '../types'

export const documentsApi = {
  list: async (params?: {
    status?: string
    limit?: number
    offset?: number
  }): Promise<DocumentListResponse> => {
    const { data } = await apiClient.get('/api/v1/admin/documents', { params })
    return data
  },

  getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
    const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
    return data
  },

  upload: async (file: File): Promise<UploadDocumentResponse> => {
    const formData = new FormData()
    formData.append('file', file)

    const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
      headers: {
        'Content-Type': 'multipart/form-data',
      },
    })
    return data
  },

  batchUpload: async (
    files: File[],
    csvFile?: File
  ): Promise<{ batch_id: string; message: string; documents_created: number }> => {
    const formData = new FormData()

    files.forEach((file) => {
      formData.append('files', file)
    })

    if (csvFile) {
      formData.append('csv_file', csvFile)
    }

    const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
      headers: {
        'Content-Type': 'multipart/form-data',
      },
    })
    return data
  },

  delete: async (documentId: string): Promise<void> => {
    await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
  },

  updateStatus: async (
    documentId: string,
    status: string
  ): Promise<DocumentItem> => {
    const { data } = await apiClient.patch(
      `/api/v1/admin/documents/${documentId}/status`,
      null,
      { params: { status } }
    )
    return data
  },

  triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
    const { data } = await apiClient.post(
      `/api/v1/admin/documents/${documentId}/auto-label`
    )
    return data
  },
}

frontend/src/api/endpoints/index.ts (new file, 4 lines)
@@ -0,0 +1,4 @@
export { documentsApi } from './documents'
export { annotationsApi } from './annotations'
export { trainingApi } from './training'
export { inferenceApi } from './inference'

frontend/src/api/endpoints/inference.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
import apiClient from '../client'
import type { InferenceResponse } from '../types'

export const inferenceApi = {
  processDocument: async (file: File): Promise<InferenceResponse> => {
    const formData = new FormData()
    formData.append('file', file)

    const { data } = await apiClient.post('/api/v1/infer', formData, {
      headers: {
        'Content-Type': 'multipart/form-data',
      },
    })
    return data
  },
}

frontend/src/api/endpoints/training.ts (new file, 74 lines)
@@ -0,0 +1,74 @@
import apiClient from '../client'
import type { TrainingModelsResponse, DocumentListResponse } from '../types'

export const trainingApi = {
  getDocumentsForTraining: async (params?: {
    has_annotations?: boolean
    min_annotation_count?: number
    exclude_used_in_training?: boolean
    limit?: number
    offset?: number
  }): Promise<DocumentListResponse> => {
    const { data } = await apiClient.get('/api/v1/admin/training/documents', {
      params,
    })
    return data
  },

  getModels: async (params?: {
    status?: string
    limit?: number
    offset?: number
  }): Promise<TrainingModelsResponse> => {
    const { data } = await apiClient.get('/api/v1/admin/training/models', {
      params,
    })
    return data
  },

  getTaskDetail: async (taskId: string) => {
    const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
    return data
  },

  startTraining: async (config: {
    name: string
    description?: string
    document_ids: string[]
    epochs?: number
    batch_size?: number
    model_base?: string
  }) => {
    // Convert frontend config to backend TrainingTaskCreate format
    const taskRequest = {
      name: config.name,
      task_type: 'yolo',
      description: config.description,
      config: {
        document_ids: config.document_ids,
        epochs: config.epochs,
        batch_size: config.batch_size,
        base_model: config.model_base,
      },
    }
    const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
    return data
  },

  cancelTask: async (taskId: string) => {
    const { data } = await apiClient.post(
      `/api/v1/admin/training/tasks/${taskId}/cancel`
    )
    return data
  },

  downloadModel: async (taskId: string): Promise<Blob> => {
    const { data } = await apiClient.get(
      `/api/v1/admin/training/models/${taskId}/download`,
      {
        responseType: 'blob',
      }
    )
    return data
  },
}

frontend/src/api/types.ts (new file, 173 lines)
@@ -0,0 +1,173 @@
export interface DocumentItem {
  document_id: string
  filename: string
  file_size: number
  content_type: string
  page_count: number
  status: 'pending' | 'labeled' | 'verified' | 'exported'
  auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
  auto_label_error: string | null
  upload_source: string
  created_at: string
  updated_at: string
  annotation_count?: number
  annotation_sources?: {
    manual: number
    auto: number
    verified: number
  }
}

export interface DocumentListResponse {
  documents: DocumentItem[]
  total: number
  limit: number
  offset: number
}

export interface AnnotationItem {
  annotation_id: string
  page_number: number
  class_id: number
  class_name: string
  bbox: {
    x: number
    y: number
    width: number
    height: number
  }
  normalized_bbox: {
    x_center: number
    y_center: number
    width: number
    height: number
  }
  text_value: string | null
  confidence: number | null
  source: 'manual' | 'auto'
  created_at: string
}

export interface DocumentDetailResponse {
  document_id: string
  filename: string
  file_size: number
  content_type: string
  page_count: number
  status: 'pending' | 'labeled' | 'verified' | 'exported'
  auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
  auto_label_error: string | null
  upload_source: string
  batch_id: string | null
  csv_field_values: Record<string, string> | null
  can_annotate: boolean
  annotation_lock_until: string | null
  annotations: AnnotationItem[]
  image_urls: string[]
  training_history: Array<{
    task_id: string
    name: string
    trained_at: string
    model_metrics: {
      mAP: number | null
      precision: number | null
      recall: number | null
    } | null
  }>
  created_at: string
  updated_at: string
}

export interface TrainingTask {
  task_id: string
  admin_token: string
  name: string
  description: string | null
  status: 'pending' | 'running' | 'completed' | 'failed'
  task_type: string
  config: Record<string, unknown>
  started_at: string | null
  completed_at: string | null
  error_message: string | null
  result_metrics: Record<string, unknown>
  model_path: string | null
  document_count: number
  metrics_mAP: number | null
  metrics_precision: number | null
  metrics_recall: number | null
  created_at: string
  updated_at: string
}

export interface TrainingModelsResponse {
  models: TrainingTask[]
  total: number
  limit: number
  offset: number
}

export interface ErrorResponse {
  detail: string
}

export interface UploadDocumentResponse {
  document_id: string
  filename: string
  status: string
  message: string
}

export interface CreateAnnotationRequest {
  page_number: number
  class_id: number
  bbox: {
    x: number
    y: number
    width: number
    height: number
  }
  text_value?: string
}

export interface AnnotationOverrideRequest {
  text_value?: string
  bbox?: {
    x: number
    y: number
    width: number
    height: number
  }
  class_id?: number
  class_name?: string
  reason?: string
}

export interface CrossValidationResult {
  is_valid: boolean
  payment_line_ocr: string | null
  payment_line_amount: string | null
  payment_line_account: string | null
  payment_line_account_type: 'bankgiro' | 'plusgiro' | null
  ocr_match: boolean | null
  amount_match: boolean | null
  bankgiro_match: boolean | null
  plusgiro_match: boolean | null
  details: string[]
}

export interface InferenceResult {
  document_id: string
  document_type: string
  success: boolean
  fields: Record<string, string>
  confidence: Record<string, number>
  cross_validation: CrossValidationResult | null
  processing_time_ms: number
  visualization_url: string | null
  errors: string[]
  fallback_used: boolean
}

export interface InferenceResponse {
  result: InferenceResult
}

frontend/src/components/Badge.tsx (new file, 39 lines)
@@ -0,0 +1,39 @@
import React from 'react';
import { DocumentStatus } from '../types';
import { Check } from 'lucide-react';

interface BadgeProps {
  status: DocumentStatus | 'Exported';
}

export const Badge: React.FC<BadgeProps> = ({ status }) => {
  if (status === 'Exported') {
    return (
      <span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
        <Check size={12} strokeWidth={3} />
        Exported
      </span>
    );
  }

  const styles = {
    [DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
    [DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
    [DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
    [DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
  };

  const icons = {
    [DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
    [DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
    [DocumentStatus.PENDING]: null,
    [DocumentStatus.LABELED]: null,
  }

  return (
    <span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
      {icons[status]}
      {status}
    </span>
  );
};

frontend/src/components/Button.tsx (new file, 38 lines)
@@ -0,0 +1,38 @@
import React from 'react';

interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
  variant?: 'primary' | 'secondary' | 'outline' | 'text';
  size?: 'sm' | 'md' | 'lg';
}

export const Button: React.FC<ButtonProps> = ({
  variant = 'primary',
  size = 'md',
  className = '',
  children,
  ...props
}) => {
  const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-98 disabled:opacity-50 disabled:pointer-events-none";

  const variants = {
    primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
    secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
    outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
    text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
  };

  const sizes = {
    sm: "h-8 px-3 text-xs",
    md: "h-10 px-4 text-sm",
    lg: "h-12 px-6 text-base",
  };

  return (
    <button
      className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
      {...props}
    >
      {children}
    </button>
  );
};

frontend/src/components/Dashboard.tsx (new file, 266 lines)
@@ -0,0 +1,266 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
|
||||||
|
import { Badge } from './Badge'
|
||||||
|
import { Button } from './Button'
|
||||||
|
import { UploadModal } from './UploadModal'
|
||||||
|
import { useDocuments } from '../hooks/useDocuments'
|
||||||
|
import type { DocumentItem } from '../api/types'
|
||||||
|
|
||||||
|
interface DashboardProps {
|
||||||
|
onNavigate: (view: string, docId?: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const getStatusForBadge = (status: string): string => {
|
||||||
|
const statusMap: Record<string, string> = {
|
||||||
|
pending: 'Pending',
|
||||||
|
labeled: 'Labeled',
|
||||||
|
verified: 'Verified',
|
||||||
|
exported: 'Exported',
|
||||||
|
}
|
||||||
|
return statusMap[status] || status
|
||||||
|
}
|
||||||
|
|
||||||
|
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||||
|
if (doc.auto_label_status === 'running') {
|
||||||
|
return 45
|
||||||
|
}
|
||||||
|
if (doc.auto_label_status === 'completed') {
|
||||||
|
return 100
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||||
|
const [isUploadOpen, setIsUploadOpen] = useState(false)
|
||||||
|
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
|
||||||
|
const [statusFilter, setStatusFilter] = useState<string>('')
|
||||||
|
const [limit] = useState(20)
|
||||||
|
const [offset] = useState(0)
|
||||||
|
|
||||||
|
const { documents, total, isLoading, error, refetch } = useDocuments({
|
||||||
|
status: statusFilter || undefined,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
})
|
||||||
|
|
||||||
|
const toggleSelection = (id: string) => {
|
||||||
|
const newSet = new Set(selectedDocs)
|
||||||
|
if (newSet.has(id)) {
|
||||||
|
newSet.delete(id)
|
||||||
|
} else {
|
||||||
|
newSet.add(id)
|
||||||
|
}
|
||||||
|
setSelectedDocs(newSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return (
|
||||||
|
<div className="p-8 max-w-7xl mx-auto">
|
||||||
|
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
|
||||||
|
Error loading documents. Please check your connection to the backend API.
|
||||||
|
<button
|
||||||
|
onClick={() => refetch()}
|
||||||
|
className="ml-4 underline hover:no-underline"
|
||||||
|
>
|
||||||
|
Retry
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||||
|
<div className="flex items-center justify-between mb-8">
|
||||||
|
<div>
|
||||||
|
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||||
|
Documents
|
||||||
|
</h1>
|
||||||
|
<p className="text-sm text-warm-text-muted mt-1">
|
||||||
|
            {isLoading ? 'Loading...' : `${total} documents total`}
          </p>
        </div>
        <div className="flex gap-3">
          <Button variant="secondary" disabled={selectedDocs.size === 0}>
            Export Selection ({selectedDocs.size})
          </Button>
          <Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
        </div>
      </div>

      <div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
        <div className="relative flex-1 min-w-[200px]">
          <Search
            className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
            size={16}
          />
          <input
            type="text"
            placeholder="Search documents..."
            className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
          />
        </div>

        <div className="flex gap-3">
          <div className="relative">
            <select
              value={statusFilter}
              onChange={(e) => setStatusFilter(e.target.value)}
              className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
            >
              <option value="">All Statuses</option>
              <option value="pending">Pending</option>
              <option value="labeled">Labeled</option>
              <option value="verified">Verified</option>
              <option value="exported">Exported</option>
            </select>
            <ChevronDown
              className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
              size={14}
            />
          </div>
        </div>
      </div>

      <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
        <table className="w-full text-left border-collapse">
          <thead>
            <tr className="border-b border-warm-border bg-white">
              <th className="py-3 pl-6 pr-4 w-12">
                <input
                  type="checkbox"
                  className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
                />
              </th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                Document Name
              </th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                Date
              </th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                Status
              </th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                Annotations
              </th>
              <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
                Auto-label
              </th>
              <th className="py-3 px-4 w-12"></th>
            </tr>
          </thead>
          <tbody>
            {isLoading ? (
              <tr>
                <td colSpan={7} className="py-8 text-center text-warm-text-muted">
                  Loading documents...
                </td>
              </tr>
            ) : documents.length === 0 ? (
              <tr>
                <td colSpan={7} className="py-8 text-center text-warm-text-muted">
                  No documents found. Upload your first document to get started.
                </td>
              </tr>
            ) : (
              documents.map((doc) => {
                const isSelected = selectedDocs.has(doc.document_id)
                const progress = getAutoLabelProgress(doc)

                return (
                  <tr
                    key={doc.document_id}
                    onClick={() => onNavigate('detail', doc.document_id)}
                    className={`
                      group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
                      ${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
                    `}
                  >
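                    {/* Checkbox cell: stopPropagation keeps selection toggles
                        from triggering the row's navigate-to-detail click. */}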
                    <td
                      className="py-4 pl-6 pr-4 relative"
                      onClick={(e) => {
                        e.stopPropagation()
                        toggleSelection(doc.document_id)
                      }}
                    >
                      {isSelected && (
                        <div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
                      )}
                      <input
                        type="checkbox"
                        checked={isSelected}
                        readOnly
                        className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
                      />
                    </td>
                    <td className="py-4 px-4">
                      <div className="flex items-center gap-3">
                        <div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
                          <FileText size={16} />
                        </div>
                        <span className="font-medium text-warm-text-secondary">
                          {doc.filename}
                        </span>
                      </div>
                    </td>
                    <td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
                      {new Date(doc.created_at).toLocaleDateString()}
                    </td>
                    <td className="py-4 px-4">
                      <Badge status={getStatusForBadge(doc.status)} />
                    </td>
                    <td className="py-4 px-4 text-sm text-warm-text-secondary">
                      {doc.annotation_count || 0} annotations
                    </td>
                    <td className="py-4 px-4">
                      {doc.auto_label_status === 'running' && progress && (
                        <div className="w-full">
                          <div className="flex justify-between text-xs mb-1">
                            <span className="text-warm-text-secondary font-medium">
                              Running
                            </span>
                            <span className="text-warm-text-muted">{progress}%</span>
                          </div>
                          <div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
                            <div
                              className="h-full bg-warm-state-info transition-all duration-500 ease-out"
                              style={{ width: `${progress}%` }}
                            />
                          </div>
                        </div>
                      )}
                      {doc.auto_label_status === 'completed' && (
                        <span className="text-sm font-medium text-warm-state-success">
                          Completed
                        </span>
                      )}
                      {doc.auto_label_status === 'failed' && (
                        <span className="text-sm font-medium text-warm-state-error">
                          Failed
                        </span>
                      )}
                    </td>
                    <td className="py-4 px-4 text-right">
                      <button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
                        <MoreHorizontal size={18} />
                      </button>
                    </td>
                  </tr>
                )
              })
            )}
          </tbody>
        </table>
      </div>

      <UploadModal
        isOpen={isUploadOpen}
        onClose={() => {
          setIsUploadOpen(false)
          refetch()
        }}
      />
    </div>
  )
}
frontend/src/components/DashboardOverview.tsx (148 lines, Normal file)
@@ -0,0 +1,148 @@
import React from 'react'
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'
import { useTraining } from '../hooks/useTraining'

interface DashboardOverviewProps {
  onNavigate: (view: string) => void
}

export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
  const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
  const { models, isLoadingModels } = useTraining()

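  // NOTE: 'Labeled' and 'Pending' below are hardcoded to 0, and the System
  // Status rows near the bottom are static; both look like WIP placeholders.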
  const stats = [
    {
      label: 'Total Documents',
      value: docsLoading ? '...' : totalDocs.toString(),
      icon: FileText,
      color: 'text-warm-text-primary',
      bgColor: 'bg-warm-bg',
    },
    {
      label: 'Labeled',
      value: '0',
      icon: CheckCircle,
      color: 'text-warm-state-success',
      bgColor: 'bg-green-50',
    },
    {
      label: 'Pending',
      value: '0',
      icon: Clock,
      color: 'text-warm-state-warning',
      bgColor: 'bg-yellow-50',
    },
    {
      label: 'Training Models',
      value: isLoadingModels ? '...' : models.length.toString(),
      icon: TrendingUp,
      color: 'text-warm-state-info',
      bgColor: 'bg-blue-50',
    },
  ]

  return (
    <div className="p-8 max-w-7xl mx-auto animate-fade-in">
      {/* Header */}
      <div className="mb-8">
        <h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
          Dashboard
        </h1>
        <p className="text-sm text-warm-text-muted mt-1">
          Overview of your document annotation system
        </p>
      </div>

      {/* Stats Grid */}
      <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
        {stats.map((stat) => (
          <div
            key={stat.label}
            className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
          >
            <div className="flex items-center justify-between mb-4">
              <div className={`p-3 rounded-lg ${stat.bgColor}`}>
                <stat.icon className={stat.color} size={24} />
              </div>
            </div>
            <p className="text-2xl font-bold text-warm-text-primary mb-1">
              {stat.value}
            </p>
            <p className="text-sm text-warm-text-muted">{stat.label}</p>
          </div>
        ))}
      </div>

      {/* Quick Actions */}
      <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
        <h2 className="text-lg font-semibold text-warm-text-primary mb-4">
          Quick Actions
        </h2>
        <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
          <Button onClick={() => onNavigate('documents')} className="justify-start">
            <FileText size={18} className="mr-2" />
            Manage Documents
          </Button>
          <Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
            <Activity size={18} className="mr-2" />
            Start Training
          </Button>
          <Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
            <TrendingUp size={18} className="mr-2" />
            View Models
          </Button>
        </div>
      </div>

      {/* Recent Activity */}
      <div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
        <div className="p-6 border-b border-warm-border">
          <h2 className="text-lg font-semibold text-warm-text-primary">
            Recent Activity
          </h2>
        </div>
        <div className="p-6">
          <div className="text-center py-8 text-warm-text-muted">
            <Activity size={48} className="mx-auto mb-3 opacity-20" />
            <p className="text-sm">No recent activity</p>
            <p className="text-xs mt-1">
              Start by uploading documents or creating training jobs
            </p>
          </div>
        </div>
      </div>

      {/* System Status */}
      <div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
        <h2 className="text-lg font-semibold text-warm-text-primary mb-4">
          System Status
        </h2>
        <div className="space-y-3">
          <div className="flex items-center justify-between">
            <span className="text-sm text-warm-text-secondary">Backend API</span>
            <span className="flex items-center text-sm text-warm-state-success">
              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
              Online
            </span>
          </div>
          <div className="flex items-center justify-between">
            <span className="text-sm text-warm-text-secondary">Database</span>
            <span className="flex items-center text-sm text-warm-state-success">
              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
              Connected
            </span>
          </div>
          <div className="flex items-center justify-between">
            <span className="text-sm text-warm-text-secondary">GPU</span>
            <span className="flex items-center text-sm text-warm-state-success">
              <span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
              Available
            </span>
          </div>
        </div>
      </div>
    </div>
  )
}
frontend/src/components/DocumentDetail.tsx (504 lines, Normal file)
@@ -0,0 +1,504 @@
import React, { useState, useRef, useEffect } from 'react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
import { documentsApi } from '../api/endpoints/documents'
import type { AnnotationItem } from '../api/types'

interface DocumentDetailProps {
  docId: string
  onBack: () => void
}

// Field class mapping from backend
const FIELD_CLASSES: Record<number, string> = {
  0: 'invoice_number',
  1: 'invoice_date',
  2: 'invoice_due_date',
  3: 'ocr_number',
  4: 'bankgiro',
  5: 'plusgiro',
  6: 'amount',
  7: 'supplier_organisation_number',
  8: 'payment_line',
  9: 'customer_number',
}

export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
  const { document, annotations, isLoading } = useDocumentDetail(docId)
  const {
    createAnnotation,
    updateAnnotation,
    deleteAnnotation,
    isCreating,
    isDeleting,
  } = useAnnotations(docId)

  const [selectedId, setSelectedId] = useState<string | null>(null)
  const [zoom, setZoom] = useState(100)
  const [isDrawing, setIsDrawing] = useState(false)
  const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
  const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
  const [selectedClassId, setSelectedClassId] = useState<number>(0)
  const [currentPage, setCurrentPage] = useState(1)
  const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
  const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)

  const canvasRef = useRef<HTMLDivElement>(null)
  const imageRef = useRef<HTMLImageElement>(null)

  const [isMarkingComplete, setIsMarkingComplete] = useState(false)

  const selectedAnnotation = annotations?.find((a) => a.annotation_id === selectedId)

  // Handle mark as complete
  const handleMarkComplete = async () => {
    if (!annotations || annotations.length === 0) {
      alert('Please add at least one annotation before marking as complete.')
      return
    }

    if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
      return
    }

    setIsMarkingComplete(true)
    try {
      const result = await documentsApi.updateStatus(docId, 'labeled')
      alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
      onBack() // Return to document list
    } catch (error) {
      console.error('Failed to mark document as complete:', error)
      alert('Failed to mark document as complete. Please try again.')
    } finally {
      setIsMarkingComplete(false)
    }
  }

  // Load image via fetch with authentication header
  useEffect(() => {
    let objectUrl: string | null = null

    const loadImage = async () => {
      if (!docId) return

      const token = localStorage.getItem('admin_token')
      const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`

      try {
        const response = await fetch(imageUrl, {
          headers: {
            'X-Admin-Token': token || '',
          },
        })

        if (!response.ok) {
          throw new Error(`Failed to load image: ${response.status}`)
        }

        const blob = await response.blob()
        objectUrl = URL.createObjectURL(blob)
        setImageBlobUrl(objectUrl)
      } catch (error) {
        console.error('Failed to load image:', error)
      }
    }

    loadImage()

    // Cleanup: revoke object URL when component unmounts or page changes
    return () => {
      if (objectUrl) {
        URL.revokeObjectURL(objectUrl)
      }
    }
  }, [currentPage, docId])

  // Load image size
  useEffect(() => {
    if (imageRef.current && imageRef.current.complete) {
      setImageSize({
        width: imageRef.current.naturalWidth,
        height: imageRef.current.naturalHeight,
      })
    }
  }, [imageBlobUrl])

  const handleImageLoad = () => {
    if (imageRef.current) {
      setImageSize({
        width: imageRef.current.naturalWidth,
        height: imageRef.current.naturalHeight,
      })
    }
  }

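  // Mouse coordinates are divided by the zoom factor so that stored bboxes
  // are in unscaled image pixels, independent of the current zoom level.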
  const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
    if (!canvasRef.current || !imageSize) return
    const rect = canvasRef.current.getBoundingClientRect()
    const x = (e.clientX - rect.left) / (zoom / 100)
    const y = (e.clientY - rect.top) / (zoom / 100)
    setIsDrawing(true)
    setDrawStart({ x, y })
    setDrawEnd({ x, y })
  }

  const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
    if (!isDrawing || !canvasRef.current || !imageSize) return
    const rect = canvasRef.current.getBoundingClientRect()
    const x = (e.clientX - rect.left) / (zoom / 100)
    const y = (e.clientY - rect.top) / (zoom / 100)
    setDrawEnd({ x, y })
  }

  const handleMouseUp = () => {
    if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
      setIsDrawing(false)
      return
    }

    const bbox_x = Math.min(drawStart.x, drawEnd.x)
    const bbox_y = Math.min(drawStart.y, drawEnd.y)
    const bbox_width = Math.abs(drawEnd.x - drawStart.x)
    const bbox_height = Math.abs(drawEnd.y - drawStart.y)

    // Only create if box is large enough (min 10x10 pixels)
    if (bbox_width > 10 && bbox_height > 10) {
      createAnnotation({
        page_number: currentPage,
        class_id: selectedClassId,
        bbox: {
          x: Math.round(bbox_x),
          y: Math.round(bbox_y),
          width: Math.round(bbox_width),
          height: Math.round(bbox_height),
        },
      })
    }

    setIsDrawing(false)
    setDrawStart(null)
    setDrawEnd(null)
  }

  const handleDeleteAnnotation = (annotationId: string) => {
    if (confirm('Are you sure you want to delete this annotation?')) {
      deleteAnnotation(annotationId)
      setSelectedId(null)
    }
  }

  if (isLoading || !document) {
    return (
      <div className="flex h-screen items-center justify-center">
        <div className="text-warm-text-muted">Loading...</div>
      </div>
    )
  }

  // Get current page annotations
  const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []

  return (
    <div className="flex h-[calc(100vh-56px)] overflow-hidden">
      {/* Main Canvas Area */}
      <div className="flex-1 bg-warm-bg flex flex-col relative">
        {/* Toolbar */}
        <div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
          <div className="flex items-center gap-4">
            <button
              onClick={onBack}
              className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
            >
              <ChevronLeft size={20} />
            </button>
            <div>
              <h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
              <p className="text-xs text-warm-text-muted">
                Page {currentPage} of {document.page_count}
              </p>
            </div>
            <div className="h-6 w-px bg-warm-divider mx-2" />
            <div className="flex items-center gap-2">
              <button
                className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
                onClick={() => setZoom((z) => Math.max(50, z - 10))}
              >
                <ZoomOut size={16} />
              </button>
              <span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
                {zoom}%
              </span>
              <button
                className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
                onClick={() => setZoom((z) => Math.min(200, z + 10))}
              >
                <ZoomIn size={16} />
              </button>
            </div>
          </div>
          <div className="flex gap-2">
            <Button variant="secondary" size="sm">
              Auto-label
            </Button>
            <Button
              variant="primary"
              size="sm"
              onClick={handleMarkComplete}
              disabled={isMarkingComplete || document.status === 'labeled'}
            >
              <CheckCircle size={16} className="mr-1" />
              {isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
            </Button>
            {document.page_count > 1 && (
              <div className="flex gap-1">
                <Button
                  variant="secondary"
                  size="sm"
                  onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
                  disabled={currentPage === 1}
                >
                  Prev
                </Button>
                <Button
                  variant="secondary"
                  size="sm"
                  onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
                  disabled={currentPage === document.page_count}
                >
                  Next
                </Button>
              </div>
            )}
          </div>
        </div>

        {/* Canvas Scroll Area */}
        <div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
          <div
            ref={canvasRef}
            className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
            style={{
              width: imageSize?.width || 800,
              height: imageSize?.height || 1132,
              transform: `scale(${zoom / 100})`,
              marginBottom: '100px',
              cursor: isDrawing ? 'crosshair' : 'default',
            }}
            onMouseDown={handleMouseDown}
            onMouseMove={handleMouseMove}
            onMouseUp={handleMouseUp}
            onClick={() => setSelectedId(null)}
          >
            {/* Document Image */}
            {imageBlobUrl ? (
              <img
                ref={imageRef}
                src={imageBlobUrl}
                alt={`Page ${currentPage}`}
                className="w-full h-full object-contain select-none pointer-events-none"
                onLoad={handleImageLoad}
              />
            ) : (
              <div className="flex items-center justify-center h-full">
                <div className="text-warm-text-muted">Loading image...</div>
              </div>
            )}

            {/* Annotation Overlays */}
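            {/* Overlay boxes use raw bbox pixel coords: the canvas div matches
                the image's natural size and is scaled as a whole via CSS
                transform, so no per-annotation scaling is needed. */}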
            {pageAnnotations.map((ann) => {
              const isSelected = selectedId === ann.annotation_id
              return (
                <div
                  key={ann.annotation_id}
                  onClick={(e) => {
                    e.stopPropagation()
                    setSelectedId(ann.annotation_id)
                  }}
                  className={`
                    absolute group cursor-pointer transition-all duration-100
                    ${
                      ann.source === 'auto'
                        ? 'border border-dashed border-warm-text-muted bg-transparent'
                        : 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
                    }
                    ${
                      isSelected
                        ? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
                        : 'hover:bg-warm-state-info/5 z-10'
                    }
                  `}
                  style={{
                    left: ann.bbox.x,
                    top: ann.bbox.y,
                    width: ann.bbox.width,
                    height: ann.bbox.height,
                  }}
                >
                  {/* Label Tag */}
                  <div
                    className={`
                      absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
                      ${
                        isSelected
                          ? 'bg-warm-state-info text-white'
                          : 'bg-white text-warm-text-secondary border border-warm-border'
                      }
                    `}
                  >
                    {ann.class_name}
                  </div>

                  {/* Resize Handles (Visual only) */}
                  {isSelected && (
                    <>
                      <div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
                      <div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
                      <div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
                      <div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
                    </>
                  )}
                </div>
              )
            })}

            {/* Drawing Box Preview */}
            {isDrawing && drawStart && drawEnd && (
              <div
                className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
                style={{
                  left: Math.min(drawStart.x, drawEnd.x),
                  top: Math.min(drawStart.y, drawEnd.y),
                  width: Math.abs(drawEnd.x - drawStart.x),
                  height: Math.abs(drawEnd.y - drawStart.y),
                }}
              />
            )}
          </div>
        </div>
      </div>

      {/* Right Sidebar */}
      <div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
        {/* Field Selector */}
        <div className="p-4 border-b border-warm-border">
          <h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
          <div className="space-y-2">
            <label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
            <select
              value={selectedClassId}
              onChange={(e) => setSelectedClassId(Number(e.target.value))}
              className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
            >
              {Object.entries(FIELD_CLASSES).map(([id, name]) => (
                <option key={id} value={id}>
                  {name.replace(/_/g, ' ')}
                </option>
              ))}
            </select>
            <p className="text-xs text-warm-text-muted mt-2">
              Click and drag on the document to create a bounding box
            </p>
          </div>
        </div>

        {/* Document Info Card */}
        <div className="p-4 border-b border-warm-border">
          <div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
            <h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
            <div className="space-y-2">
              <div className="flex justify-between text-xs">
                <span className="text-warm-text-muted">Status</span>
                <span className="text-warm-text-secondary font-medium capitalize">
                  {document.status}
                </span>
              </div>
              <div className="flex justify-between text-xs">
                <span className="text-warm-text-muted">Size</span>
                <span className="text-warm-text-secondary font-medium">
                  {(document.file_size / 1024 / 1024).toFixed(2)} MB
                </span>
              </div>
              <div className="flex justify-between text-xs">
                <span className="text-warm-text-muted">Uploaded</span>
                <span className="text-warm-text-secondary font-medium">
                  {new Date(document.created_at).toLocaleDateString()}
                </span>
              </div>
            </div>
          </div>
        </div>

        {/* Annotations List */}
        <div className="flex-1 overflow-y-auto p-4">
          <div className="flex items-center justify-between mb-4">
            <h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
            <span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
          </div>

          {pageAnnotations.length === 0 ? (
            <div className="text-center py-8 text-warm-text-muted">
              <Tag size={48} className="mx-auto mb-3 opacity-20" />
              <p className="text-sm">No annotations yet</p>
              <p className="text-xs mt-1">Draw on the document to add annotations</p>
            </div>
          ) : (
            <div className="space-y-3">
              {pageAnnotations.map((ann) => (
                <div
                  key={ann.annotation_id}
                  onClick={() => setSelectedId(ann.annotation_id)}
                  className={`
                    group p-3 rounded-md border transition-all duration-150 cursor-pointer
                    ${
                      selectedId === ann.annotation_id
                        ? 'bg-warm-bg border-warm-state-info shadow-sm'
                        : 'bg-white border-warm-border hover:border-warm-text-muted'
                    }
                  `}
                >
                  <div className="flex justify-between items-start mb-1">
                    <span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
                      {ann.class_name.replace(/_/g, ' ')}
                    </span>
                    {selectedId === ann.annotation_id && (
                      <div className="flex gap-1">
                        <button
                          onClick={() => handleDeleteAnnotation(ann.annotation_id)}
                          className="text-warm-text-muted hover:text-warm-state-error"
                          disabled={isDeleting}
                        >
                          <Trash2 size={12} />
                        </button>
                      </div>
                    )}
                  </div>
                  <p className="text-sm text-warm-text-muted font-mono truncate">
                    {ann.text_value || '(no text)'}
                  </p>
                  <div className="flex items-center gap-2 mt-2">
                    <span
                      className={`text-[10px] px-1.5 py-0.5 rounded ${
                        ann.source === 'auto'
                          ? 'bg-blue-50 text-blue-700'
                          : 'bg-green-50 text-green-700'
                      }`}
                    >
                      {ann.source}
                    </span>
                    {ann.confidence && (
                      <span className="text-[10px] text-warm-text-muted">
                        {(ann.confidence * 100).toFixed(0)}%
                      </span>
                    )}
                  </div>
                </div>
              ))}
            </div>
          )}
        </div>
      </div>
    </div>
  )
}
frontend/src/components/InferenceDemo.tsx (466 lines, Normal file)
@@ -0,0 +1,466 @@
import React, { useState, useRef } from 'react'
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
import { Button } from './Button'
import { inferenceApi } from '../api/endpoints'
import type { InferenceResult } from '../api/types'

export const InferenceDemo: React.FC = () => {
  const [isDragging, setIsDragging] = useState(false)
  const [selectedFile, setSelectedFile] = useState<File | null>(null)
  const [isProcessing, setIsProcessing] = useState(false)
  const [result, setResult] = useState<InferenceResult | null>(null)
  const [error, setError] = useState<string | null>(null)
  const fileInputRef = useRef<HTMLInputElement>(null)

  const handleFileSelect = (file: File | null) => {
    if (!file) return

    const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
    if (!validTypes.includes(file.type)) {
      setError('Please upload a PDF, PNG, or JPG file')
      return
    }

    if (file.size > 50 * 1024 * 1024) {
      setError('File size must be less than 50MB')
      return
    }

    setSelectedFile(file)
    setResult(null)
    setError(null)
  }

  const handleDrop = (e: React.DragEvent) => {
    e.preventDefault()
    setIsDragging(false)
    if (e.dataTransfer.files.length > 0) {
      handleFileSelect(e.dataTransfer.files[0])
    }
  }

  const handleBrowseClick = () => {
    fileInputRef.current?.click()
  }

  const handleProcess = async () => {
    if (!selectedFile) return

    setIsProcessing(true)
    setError(null)

    try {
      const response = await inferenceApi.processDocument(selectedFile)
      console.log('API Response:', response)
      console.log('Visualization URL:', response.result?.visualization_url)
      setResult(response.result)
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Processing failed')
    } finally {
      setIsProcessing(false)
    }
  }

  const handleReset = () => {
    setSelectedFile(null)
    setResult(null)
    setError(null)
  }

  const formatFieldName = (field: string): string => {
    const fieldNames: Record<string, string> = {
      InvoiceNumber: 'Invoice Number',
      InvoiceDate: 'Invoice Date',
      InvoiceDueDate: 'Due Date',
      OCR: 'OCR Number',
      Amount: 'Amount',
      Bankgiro: 'Bankgiro',
      Plusgiro: 'Plusgiro',
      supplier_org_number: 'Supplier Org Number',
      customer_number: 'Customer Number',
      payment_line: 'Payment Line',
    }
    return fieldNames[field] || field
  }

  return (
    <div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
      {/* Header */}
      <div className="text-center">
        <h2 className="text-3xl font-bold text-warm-text-primary mb-2">
          Invoice Extraction Demo
        </h2>
        <p className="text-warm-text-muted">
          Upload a Swedish invoice to see our AI-powered field extraction in action
        </p>
      </div>

      {/* Upload Area */}
      {!result && (
        <div className="max-w-2xl mx-auto">
          <div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
            <div
              className={`
                relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
                ${isDragging
                  ? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
                  : 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
                }
                ${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
              `}
              onDragOver={(e) => {
                e.preventDefault()
                setIsDragging(true)
              }}
              onDragLeave={() => setIsDragging(false)}
              onDrop={handleDrop}
              onClick={handleBrowseClick}
            >
              <div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
                {isProcessing ? (
                  <>
                    <Loader2 size={56} className="text-warm-text-secondary animate-spin" />
                    <div className="text-center">
                      <p className="text-lg font-semibold text-warm-text-primary mb-1">
                        Processing invoice...
                      </p>
                      <p className="text-sm text-warm-text-muted">
                        This may take a few moments
                      </p>
                    </div>
                  </>
                ) : selectedFile ? (
                  <>
                    <div className="p-5 bg-warm-text-secondary/10 rounded-full">
                      <FileText size={40} className="text-warm-text-secondary" />
                    </div>
                    <div className="text-center px-4">
                      <p className="text-lg font-semibold text-warm-text-primary mb-1">
                        {selectedFile.name}
                      </p>
                      <p className="text-sm text-warm-text-muted">
                        {(selectedFile.size / 1024 / 1024).toFixed(2)} MB
                      </p>
                    </div>
                  </>
                ) : (
                  <>
                    <div className="p-5 bg-warm-text-secondary/10 rounded-full">
                      <UploadCloud size={40} className="text-warm-text-secondary" />
                    </div>
                    <div className="text-center px-4">
                      <p className="text-lg font-semibold text-warm-text-primary mb-2">
                        Drag & drop invoice here
                      </p>
                      <p className="text-sm text-warm-text-muted mb-3">
                        or{' '}
                        <span className="text-warm-text-secondary font-medium">
                          browse files
                        </span>
                      </p>
                      <p className="text-xs text-warm-text-muted">
                        Supports PDF, PNG, JPG (up to 50MB)
                      </p>
                    </div>
                  </>
                )}
              </div>
            </div>

            <input
              ref={fileInputRef}
              type="file"
              accept=".pdf,image/*"
              className="hidden"
              onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
            />

            {error && (
              <div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
                <AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
                <span className="text-sm text-red-800 font-medium">{error}</span>
              </div>
            )}

            {selectedFile && !isProcessing && (
              <div className="mt-6 flex gap-3 justify-end">
                <Button variant="secondary" onClick={handleReset}>
                  Cancel
                </Button>
                <Button onClick={handleProcess}>Process Invoice</Button>
              </div>
            )}
          </div>
        </div>
      )}

      {/* Results */}
      {result && (
        <div className="space-y-6">
          {/* Status Header */}
          <div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
            <div className="p-6 flex items-center justify-between border-b border-warm-divider">
              <div className="flex items-center gap-4">
                {result.success ? (
                  <div className="p-3 bg-green-100 rounded-xl">
                    <CheckCircle2 size={28} className="text-green-600" />
                  </div>
                ) : (
                  <div className="p-3 bg-yellow-100 rounded-xl">
                    <AlertCircle size={28} className="text-yellow-600" />
                  </div>
                )}
                <div>
                  <h3 className="text-xl font-bold text-warm-text-primary">
                    {result.success ? 'Extraction Complete' : 'Partial Results'}
                  </h3>
                  <p className="text-sm text-warm-text-muted mt-0.5">
                    Document ID: <span className="font-mono">{result.document_id}</span>
                  </p>
                </div>
              </div>
              <Button variant="secondary" onClick={handleReset}>
                Process Another
              </Button>
            </div>

            <div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
              <div className="flex items-center gap-2 text-warm-text-secondary">
                <Clock size={16} />
                <span className="font-medium">
                  {result.processing_time_ms.toFixed(0)}ms
                </span>
              </div>
              {result.fallback_used && (
                <span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
                  Fallback OCR Used
                </span>
              )}
            </div>
          </div>

          {/* Main Content Grid */}
          <div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
            {/* Left Column: Extracted Fields */}
            <div className="lg:col-span-2 space-y-6">
              <div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
                <h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
                  <span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
                  Extracted Fields
                </h3>
                <div className="flex flex-wrap gap-4">
                  {Object.entries(result.fields).map(([field, value]) => {
                    const confidence = result.confidence[field]
                    return (
                      <div
                        key={field}
                        className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
                      >
                        <div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
                          {formatFieldName(field)}
                        </div>
                        <div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
                          {value || <span className="text-warm-text-muted italic">N/A</span>}
                        </div>
                        {confidence && (
                          <div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
                            <CheckCircle2 size={13} />
                            <span>{(confidence * 100).toFixed(1)}%</span>
                          </div>
                        )}
                      </div>
                    )
                  })}
                </div>
              </div>

              {/* Visualization */}
              {result.visualization_url && (
                <div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
                  <h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
                    <span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
                    Detection Visualization
                  </h3>
                  <div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
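                    {/* visualization_url is server-relative, so the API base URL is prefixed. */}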
                    <img
                      src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
                      alt="Detection visualization"
                      className="w-full h-auto"
                    />
                  </div>
                </div>
              )}
            </div>

            {/* Right Column: Cross-Validation & Errors */}
            <div className="space-y-6">
              {/* Cross-Validation */}
              {result.cross_validation && (
                <div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
                  <h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
                    <span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
                    Payment Line Validation
                  </h3>

                  <div
                    className={`
                      p-4 rounded-lg mb-4 flex items-center gap-3
                      ${result.cross_validation.is_valid
                        ? 'bg-green-50 border border-green-200'
                        : 'bg-yellow-50 border border-yellow-200'
                      }
                    `}
                  >
                    {result.cross_validation.is_valid ? (
                      <>
                        <CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
                        <span className="font-bold text-green-800">All Fields Match</span>
                      </>
                    ) : (
                      <>
                        <AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
                        <span className="font-bold text-yellow-800">Mismatch Detected</span>
                      </>
                    )}
                  </div>

                  <div className="space-y-2.5">
                    {result.cross_validation.payment_line_ocr && (
                      <div
                        className={`
                          p-3 rounded-lg border transition-colors
                          ${result.cross_validation.ocr_match === true
                            ? 'bg-green-50 border-green-200'
                            : result.cross_validation.ocr_match === false
                            ? 'bg-red-50 border-red-200'
                            : 'bg-warm-bg border-warm-divider'
                          }
                        `}
                      >
                        <div className="flex items-center justify-between">
                          <div className="flex-1">
                            <div className="text-xs font-semibold text-warm-text-muted mb-1">
                              OCR NUMBER
                            </div>
                            <div className="text-sm font-bold text-warm-text-primary font-mono">
                              {result.cross_validation.payment_line_ocr}
                            </div>
                          </div>
                          {result.cross_validation.ocr_match === true && (
                            <CheckCircle2 size={16} className="text-green-600" />
                          )}
                          {result.cross_validation.ocr_match === false && (
                            <AlertCircle size={16} className="text-red-600" />
                          )}
                        </div>
                      </div>
                    )}

                    {result.cross_validation.payment_line_amount && (
                      <div
                        className={`
                          p-3 rounded-lg border transition-colors
                          ${result.cross_validation.amount_match === true
                            ? 'bg-green-50 border-green-200'
                            : result.cross_validation.amount_match === false
                            ? 'bg-red-50 border-red-200'
                            : 'bg-warm-bg border-warm-divider'
                          }
                        `}
                      >
                        <div className="flex items-center justify-between">
                          <div className="flex-1">
                            <div className="text-xs font-semibold text-warm-text-muted mb-1">
                              AMOUNT
                            </div>
                            <div className="text-sm font-bold text-warm-text-primary font-mono">
                              {result.cross_validation.payment_line_amount}
                            </div>
                          </div>
                          {result.cross_validation.amount_match === true && (
                            <CheckCircle2 size={16} className="text-green-600" />
                          )}
                          {result.cross_validation.amount_match === false && (
                            <AlertCircle size={16} className="text-red-600" />
                          )}
                        </div>
                      </div>
                    )}

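                    {/* The payment line carries either a bankgiro or a plusgiro
                        account, so the matching *_match flag is selected before
                        styling and rendering the status icon. */}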
                    {result.cross_validation.payment_line_account && (
                      <div
                        className={`
                          p-3 rounded-lg border transition-colors
                          ${(result.cross_validation.payment_line_account_type === 'bankgiro'
                            ? result.cross_validation.bankgiro_match
                            : result.cross_validation.plusgiro_match) === true
                            ? 'bg-green-50 border-green-200'
                            : (result.cross_validation.payment_line_account_type === 'bankgiro'
                              ? result.cross_validation.bankgiro_match
                              : result.cross_validation.plusgiro_match) === false
                            ? 'bg-red-50 border-red-200'
                            : 'bg-warm-bg border-warm-divider'
                          }
                        `}
                      >
                        <div className="flex items-center justify-between">
                          <div className="flex-1">
                            <div className="text-xs font-semibold text-warm-text-muted mb-1">
                              {result.cross_validation.payment_line_account_type === 'bankgiro'
                                ? 'BANKGIRO'
                                : 'PLUSGIRO'}
                            </div>
                            <div className="text-sm font-bold text-warm-text-primary font-mono">
                              {result.cross_validation.payment_line_account}
                            </div>
                          </div>
                          {(result.cross_validation.payment_line_account_type === 'bankgiro'
                            ? result.cross_validation.bankgiro_match
                            : result.cross_validation.plusgiro_match) === true && (
                            <CheckCircle2 size={16} className="text-green-600" />
                          )}
                          {(result.cross_validation.payment_line_account_type === 'bankgiro'
                            ? result.cross_validation.bankgiro_match
                            : result.cross_validation.plusgiro_match) === false && (
                            <AlertCircle size={16} className="text-red-600" />
                          )}
                        </div>
                      </div>
                    )}
                  </div>

                  {result.cross_validation.details.length > 0 && (
                    <div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
                      {result.cross_validation.details[result.cross_validation.details.length - 1]}
                    </div>
                  )}
                </div>
              )}

              {/* Errors */}
              {result.errors.length > 0 && (
                <div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
                  <h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
                    <span className="w-1 h-5 bg-red-500 rounded-full"></span>
                    Issues
                  </h3>
                  <div className="space-y-2.5">
                    {result.errors.map((err, idx) => (
                      <div
                        key={idx}
                        className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
                      >
                        <AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
                        <span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
                      </div>
                    ))}
                  </div>
                </div>
              )}
            </div>
          </div>
        </div>
      )}
    </div>
  )
}
frontend/src/components/Layout.tsx (102 lines, Normal file)
@@ -0,0 +1,102 @@
import React, { useState } from 'react';
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';

interface LayoutProps {
  children: React.ReactNode;
  activeView: string;
  onNavigate: (view: string) => void;
  onLogout?: () => void;
}

export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
  const [showDropdown, setShowDropdown] = useState(false);
  const navItems = [
    { id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
    { id: 'demo', label: 'Demo', icon: Sparkles },
    { id: 'training', label: 'Training', icon: Box }, // Mapped to "Complaints" visually in the design prompt; using the logical name here
    { id: 'documents', label: 'Documents', icon: BookOpen },
    { id: 'models', label: 'Models', icon: Users }, // "Contacts" in the design prompt, mapped to models for this use case
  ];

  return (
    <div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
      {/* Top Navigation */}
      <nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
        <div className="flex items-center gap-8">
          {/* Logo */}
          <div className="flex items-center gap-2">
            <div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
              <svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
                <path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
              </svg>
            </div>
          </div>

          {/* Nav Links */}
          <div className="flex h-14">
            {navItems.map(item => {
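              // Keep the Documents tab highlighted while a document detail view is open.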
              const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
              return (
                <button
                  key={item.id}
                  onClick={() => onNavigate(item.id)}
                  className={`
                    relative px-4 h-full flex items-center text-sm font-medium transition-colors
                    ${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
                  `}
                >
                  {item.label}
                  {isActive && (
                    <div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
                  )}
                </button>
              );
            })}
          </div>
        </div>

        {/* User Profile */}
        <div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
          <button
            onClick={() => setShowDropdown(!showDropdown)}
            className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
          >
            AD
          </button>

          {showDropdown && (
            <>
              <div
                className="fixed inset-0 z-10"
                onClick={() => setShowDropdown(false)}
              />
              <div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
                <div className="p-3 border-b border-warm-border">
                  <p className="text-sm font-medium text-warm-text-primary">Admin User</p>
                  <p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
                </div>
                {onLogout && (
                  <button
                    onClick={() => {
                      setShowDropdown(false)
                      onLogout()
                    }}
                    className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
                  >
                    <LogOut size={14} />
                    Sign Out
                  </button>
                )}
              </div>
            </>
          )}
        </div>
      </nav>

      {/* Main Content */}
      <main className="flex-1 overflow-auto">
        {children}
      </main>
    </div>
  );
};
frontend/src/components/Login.tsx (188 lines, Normal file)
@@ -0,0 +1,188 @@
import React, { useState } from 'react'
import { Button } from './Button'

interface LoginProps {
  onLogin: (token: string) => void
}

export const Login: React.FC<LoginProps> = ({ onLogin }) => {
  const [token, setToken] = useState('')
  const [name, setName] = useState('')
  const [description, setDescription] = useState('')
  const [isCreating, setIsCreating] = useState(false)
  const [error, setError] = useState('')
  const [createdToken, setCreatedToken] = useState('')

  const handleLoginWithToken = () => {
    if (!token.trim()) {
      setError('Please enter a token')
      return
    }
    localStorage.setItem('admin_token', token.trim())
    onLogin(token.trim())
  }

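  // NOTE: the token-creation endpoint below is hardcoded to localhost, unlike
  // other components that fall back through import.meta.env.VITE_API_URL.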
  const handleCreateToken = async () => {
    if (!name.trim()) {
      setError('Please enter a token name')
      return
    }

    setIsCreating(true)
    setError('')

    try {
      const response = await fetch('http://localhost:8000/api/v1/admin/auth/token', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          name: name.trim(),
          description: description.trim() || undefined,
        }),
      })

      if (!response.ok) {
        throw new Error('Failed to create token')
      }

      const data = await response.json()
      setCreatedToken(data.token)
      setToken(data.token)
      setError('')
    } catch (err) {
      setError('Failed to create token. Please check your connection.')
      console.error(err)
    } finally {
      setIsCreating(false)
    }
  }

  const handleUseCreatedToken = () => {
    if (createdToken) {
      localStorage.setItem('admin_token', createdToken)
      onLogin(createdToken)
    }
  }

  return (
    <div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
      <div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
        <h1 className="text-2xl font-bold text-warm-text-primary mb-2">
          Admin Authentication
        </h1>
        <p className="text-sm text-warm-text-muted mb-6">
          Sign in with an admin token to access the document management system
        </p>

        {error && (
          <div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
            {error}
          </div>
        )}

        {createdToken && (
          <div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
            <p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
            <div className="bg-white border border-green-300 rounded p-2 mb-3">
              <code className="text-xs font-mono text-warm-text-primary break-all">
                {createdToken}
              </code>
            </div>
            <p className="text-xs text-green-700 mb-3">
              Save this token securely. You won't be able to see it again.
            </p>
            <Button onClick={handleUseCreatedToken} className="w-full">
              Use This Token
            </Button>
          </div>
        )}

        <div className="space-y-6">
          {/* Login with existing token */}
          <div>
            <h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
              Sign in with existing token
            </h2>
            <div className="space-y-3">
              <div>
                <label className="block text-sm text-warm-text-secondary mb-1">
                  Admin Token
                </label>
                <input
                  type="text"
                  value={token}
                  onChange={(e) => setToken(e.target.value)}
                  placeholder="Enter your admin token"
                  className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
                  onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
                />
              </div>
              <Button onClick={handleLoginWithToken} className="w-full">
                Sign In
              </Button>
            </div>
          </div>

          <div className="relative">
            <div className="absolute inset-0 flex items-center">
              <div className="w-full border-t border-warm-border"></div>
            </div>
            <div className="relative flex justify-center text-xs">
              <span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
            </div>
          </div>

          {/* Create new token */}
          <div>
            <h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
              Create new admin token
            </h2>
            <div className="space-y-3">
              <div>
                <label className="block text-sm text-warm-text-secondary mb-1">
                  Token Name <span className="text-red-500">*</span>
                </label>
                <input
                  type="text"
                  value={name}
                  onChange={(e) => setName(e.target.value)}
                  placeholder="e.g., my-laptop"
                  className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
                />
              </div>
              <div>
                <label className="block text-sm text-warm-text-secondary mb-1">
                  Description (optional)
                </label>
                <input
                  type="text"
                  value={description}
                  onChange={(e) => setDescription(e.target.value)}
                  placeholder="e.g., Personal laptop access"
                  className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
                />
              </div>
              <Button
                onClick={handleCreateToken}
                variant="secondary"
                disabled={isCreating}
                className="w-full"
              >
                {isCreating ? 'Creating...' : 'Create Token'}
              </Button>
            </div>
          </div>
        </div>

        <div className="mt-6 pt-4 border-t border-warm-border">
          <p className="text-xs text-warm-text-muted">
            Admin tokens are used to authenticate with the document management API.
            Keep your tokens secure and never share them.
          </p>
        </div>
      </div>
    </div>
  )
}

frontend/src/components/Models.tsx (new file, 134 lines)
@@ -0,0 +1,134 @@
import React from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
import { Button } from './Button';

const CHART_DATA = [
  { name: 'Model A', value: 75 },
  { name: 'Model B', value: 82 },
  { name: 'Model C', value: 95 },
  { name: 'Model D', value: 68 },
];

const METRICS_DATA = [
  { name: 'Precision', value: 88 },
  { name: 'Recall', value: 76 },
  { name: 'F1 Score', value: 91 },
  { name: 'Accuracy', value: 82 },
];

const JOBS = [
  { id: 1, name: 'Training Job 1', date: '12/29/2024 10:33 PM', status: 'Running', progress: 65 },
  { id: 2, name: 'Training Job 2', date: '12/29/2024 10:33 PM', status: 'Completed', success: 37, metrics: 89 },
  { id: 3, name: 'Model Training Component 1', date: '12/29/2024 10:19 PM', status: 'Completed', success: 87, metrics: 92 },
];

export const Models: React.FC = () => {
  return (
    <div className="p-8 max-w-7xl mx-auto flex gap-8">
      {/* Left: Job History */}
      <div className="flex-1">
        <h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
        <h3 className="text-lg font-semibold text-warm-text-primary mb-4">Recent Training Jobs</h3>

        <div className="space-y-4">
          {JOBS.map(job => (
            <div key={job.id} className="bg-warm-card border border-warm-border rounded-lg p-5 shadow-sm hover:border-warm-divider transition-colors">
              <div className="flex justify-between items-start mb-2">
                <div>
                  <h4 className="font-semibold text-warm-text-primary text-lg mb-1">{job.name}</h4>
                  <p className="text-sm text-warm-text-muted">Started {job.date}</p>
                </div>
                <span className={`px-3 py-1 rounded-full text-xs font-medium ${job.status === 'Running' ? 'bg-warm-selected text-warm-text-secondary' : 'bg-warm-selected text-warm-state-success'}`}>
                  {job.status}
                </span>
              </div>

              {job.status === 'Running' ? (
                <div className="mt-4">
                  <div className="h-2 w-full bg-warm-selected rounded-full overflow-hidden">
                    <div className="h-full bg-warm-text-secondary w-[65%] rounded-full"></div>
                  </div>
                </div>
              ) : (
                <div className="mt-4 flex gap-8">
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">Success</span>
                    <span className="text-lg font-mono text-warm-text-secondary">{job.success}</span>
                  </div>
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">Performance</span>
                    <span className="text-lg font-mono text-warm-text-secondary">{job.metrics}%</span>
                  </div>
                  <div>
                    <span className="block text-xs text-warm-text-muted uppercase tracking-wide">Completed</span>
                    <span className="text-lg font-mono text-warm-text-secondary">100%</span>
                  </div>
                </div>
              )}
            </div>
          ))}
        </div>
      </div>

      {/* Right: Model Detail */}
      <div className="w-[400px]">
        <div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
          <div className="flex justify-between items-center mb-6">
            <h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
            <span className="text-sm font-medium text-warm-state-success">Completed</span>
          </div>

          <div className="mb-8">
            <p className="text-sm text-warm-text-muted mb-1">Model name</p>
            <p className="font-medium text-warm-text-primary">Invoices Q4 v2.1</p>
          </div>

          <div className="space-y-8">
            {/* Chart 1 */}
            <div>
              <h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Bar Rate Metrics</h4>
              <div className="h-40">
                <ResponsiveContainer width="100%" height="100%">
                  <BarChart data={CHART_DATA}>
                    <CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
                    <XAxis dataKey="name" hide />
                    <YAxis hide domain={[0, 100]} />
                    <Tooltip
                      cursor={{fill: '#F1F0ED'}}
                      contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
                    />
                    <Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
                  </BarChart>
                </ResponsiveContainer>
              </div>
            </div>

            {/* Chart 2 */}
            <div>
              <h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Entity Extraction Accuracy</h4>
              <div className="h-40">
                <ResponsiveContainer width="100%" height="100%">
                  <BarChart data={METRICS_DATA}>
                    <CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
                    <XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
                    <YAxis hide domain={[0, 100]} />
                    <Tooltip cursor={{fill: '#F1F0ED'}} />
                    <Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
                  </BarChart>
                </ResponsiveContainer>
              </div>
            </div>
          </div>

          <div className="mt-8 space-y-3">
            <Button className="w-full">Download Model</Button>
            <div className="flex gap-3">
              <Button variant="secondary" className="flex-1">View Logs</Button>
              <Button variant="secondary" className="flex-1">Use as Base</Button>
            </div>
          </div>
        </div>
      </div>
    </div>
  );
};

frontend/src/components/Training.tsx (new file, 113 lines)
@@ -0,0 +1,113 @@
import React, { useState } from 'react';
import { Check, AlertCircle } from 'lucide-react';
import { Button } from './Button';
import { DocumentStatus } from '../types';

export const Training: React.FC = () => {
  const [split, setSplit] = useState(80);

  const docs = [
    { id: '1', name: 'Document 1', date: '12/28/2024', status: DocumentStatus.VERIFIED },
    { id: '2', name: 'Document 2', date: '12/29/2024', status: DocumentStatus.VERIFIED },
    { id: '3', name: 'Document 3', date: '12/29/2024', status: DocumentStatus.VERIFIED },
    { id: '4', name: 'Document 4', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '5', name: 'Document 5', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '6', name: 'Document 6', date: '12/29/2024', status: DocumentStatus.PARTIAL },
    { id: '8', name: 'Document 8', date: '12/29/2024', status: DocumentStatus.VERIFIED },
  ];

  return (
    <div className="p-8 max-w-7xl mx-auto h-[calc(100vh-56px)] flex gap-8">
      {/* Document Selection List */}
      <div className="flex-1 flex flex-col">
        <h2 className="text-2xl font-bold text-warm-text-primary mb-6">Document Selection</h2>

        <div className="flex-1 bg-warm-card border border-warm-border rounded-lg overflow-hidden flex flex-col shadow-sm">
          <div className="overflow-auto flex-1">
            <table className="w-full text-left">
              <thead className="sticky top-0 bg-white border-b border-warm-border z-10">
                <tr>
                  <th className="py-3 pl-6 pr-4 w-12"><input type="checkbox" className="rounded border-warm-divider"/></th>
                  <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document name</th>
                  <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Date</th>
                  <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
                </tr>
              </thead>
              <tbody>
                {docs.map(doc => (
                  <tr key={doc.id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
                    <td className="py-3 pl-6 pr-4"><input type="checkbox" defaultChecked className="rounded border-warm-divider accent-warm-state-info"/></td>
                    <td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{doc.name}</td>
                    <td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.date}</td>
                    <td className="py-3 px-4">
                      {doc.status === DocumentStatus.VERIFIED ? (
                        <div className="flex items-center text-warm-state-success text-sm font-medium">
                          <div className="w-5 h-5 rounded-full bg-warm-state-success flex items-center justify-center text-white mr-2">
                            <Check size={12} strokeWidth={3}/>
                          </div>
                          Verified
                        </div>
                      ) : (
                        <div className="flex items-center text-warm-text-muted text-sm">
                          <div className="w-5 h-5 rounded-full bg-[#BDBBB5] flex items-center justify-center text-white mr-2">
                            <span className="font-bold text-[10px]">!</span>
                          </div>
                          Partial
                        </div>
                      )}
                    </td>
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
        </div>
      </div>

      {/* Configuration Panel */}
      <div className="w-96">
        <div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
          <h3 className="text-lg font-semibold text-warm-text-primary mb-6">Training Configuration</h3>

          <div className="space-y-6">
            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-2">Model Name</label>
              <input
                type="text"
                placeholder="e.g. Invoices Q4"
                className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
              />
            </div>

            <div>
              <label className="block text-sm font-medium text-warm-text-secondary mb-2">Base Model</label>
              <select className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none">
                <option>LayoutLMv3 (Standard)</option>
                <option>Donut (Beta)</option>
              </select>
            </div>

            <div>
              <div className="flex justify-between mb-2">
                <label className="block text-sm font-medium text-warm-text-secondary">Train/Test Split</label>
                <span className="text-xs font-mono text-warm-text-muted">{split}% / {100 - split}%</span>
              </div>
              <input
                type="range"
                min="50"
                max="95"
                value={split}
                onChange={(e) => setSplit(parseInt(e.target.value))}
                className="w-full h-1.5 bg-warm-border rounded-lg appearance-none cursor-pointer accent-warm-state-info"
              />
            </div>

            <div className="pt-4 border-t border-warm-border">
              <Button className="w-full h-12">Start Training</Button>
            </div>
          </div>
        </div>
      </div>
    </div>
  );
};

frontend/src/components/UploadModal.tsx (new file, 210 lines)
@@ -0,0 +1,210 @@
import React, { useState, useRef } from 'react'
import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'

interface UploadModalProps {
  isOpen: boolean
  onClose: () => void
}

export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
  const [isDragging, setIsDragging] = useState(false)
  const [selectedFiles, setSelectedFiles] = useState<File[]>([])
  const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
  const [errorMessage, setErrorMessage] = useState('')
  const fileInputRef = useRef<HTMLInputElement>(null)

  const { uploadDocument, isUploading } = useDocuments({})

  if (!isOpen) return null

  const handleFileSelect = (files: FileList | null) => {
    if (!files) return

    // Accept PDFs and images up to 25 MB
    const validFiles = Array.from(files).filter(file => {
      const isPdf = file.type === 'application/pdf'
      const isImage = file.type.startsWith('image/')
      const isUnder25MB = file.size <= 25 * 1024 * 1024
      return (isPdf || isImage) && isUnder25MB
    })

    setSelectedFiles(prev => [...prev, ...validFiles])
    setUploadStatus('idle')
    setErrorMessage('')
  }

  const handleDrop = (e: React.DragEvent) => {
    e.preventDefault()
    setIsDragging(false)
    handleFileSelect(e.dataTransfer.files)
  }

  const handleBrowseClick = () => {
    fileInputRef.current?.click()
  }

  const removeFile = (index: number) => {
    setSelectedFiles(prev => prev.filter((_, i) => i !== index))
  }

  const handleUpload = async () => {
    if (selectedFiles.length === 0) {
      setErrorMessage('Please select at least one file')
      return
    }

    setUploadStatus('uploading')
    setErrorMessage('')

    try {
      // Upload files one by one
      for (const file of selectedFiles) {
        await new Promise<void>((resolve, reject) => {
          uploadDocument(file, {
            onSuccess: () => resolve(),
            onError: (error: Error) => reject(error),
          })
        })
      }

      setUploadStatus('success')
      setTimeout(() => {
        onClose()
        setSelectedFiles([])
        setUploadStatus('idle')
      }, 1500)
    } catch (error) {
      setUploadStatus('error')
      setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
    }
  }

  const handleClose = () => {
    if (uploadStatus === 'uploading') {
      return // Prevent closing during upload
    }
    setSelectedFiles([])
    setUploadStatus('idle')
    setErrorMessage('')
    onClose()
  }

  return (
    <div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
      <div
        className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
        onClick={(e) => e.stopPropagation()}
      >
        <div className="flex items-center justify-between mb-6">
          <h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
          <button
            onClick={handleClose}
            className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
            disabled={uploadStatus === 'uploading'}
          >
            <X size={20} />
          </button>
        </div>

        {/* Drop Zone */}
        <div
          className={`
            w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
            ${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
            ${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
          `}
          onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
          onDragLeave={() => setIsDragging(false)}
          onDrop={handleDrop}
          onClick={handleBrowseClick}
        >
          <div className="p-3 bg-white rounded-full shadow-sm">
            <UploadCloud size={24} className="text-warm-text-secondary" />
          </div>
          <div className="text-center">
            <p className="text-sm font-medium text-warm-text-primary">
              Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
            </p>
            <p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
          </div>
        </div>

        <input
          ref={fileInputRef}
          type="file"
          multiple
          accept=".pdf,image/*"
          className="hidden"
          onChange={(e) => handleFileSelect(e.target.files)}
        />

        {/* Selected Files */}
        {selectedFiles.length > 0 && (
          <div className="mb-6 max-h-40 overflow-y-auto">
            <p className="text-sm font-medium text-warm-text-secondary mb-2">
              Selected Files ({selectedFiles.length})
            </p>
            <div className="space-y-2">
              {selectedFiles.map((file, index) => (
                <div
                  key={index}
                  className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
                >
                  <div className="flex items-center gap-2 flex-1 min-w-0">
                    <File size={16} className="text-warm-text-muted flex-shrink-0" />
                    <span className="text-sm text-warm-text-secondary truncate">
                      {file.name}
                    </span>
                    <span className="text-xs text-warm-text-muted flex-shrink-0">
                      ({(file.size / 1024 / 1024).toFixed(2)} MB)
                    </span>
                  </div>
                  <button
                    onClick={() => removeFile(index)}
                    className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
                    disabled={uploadStatus === 'uploading'}
                  >
                    <X size={16} />
                  </button>
                </div>
              ))}
            </div>
          </div>
        )}

        {/* Status Messages */}
        {uploadStatus === 'success' && (
          <div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
            <CheckCircle size={16} className="text-green-600" />
            <span className="text-sm text-green-800">Upload successful!</span>
          </div>
        )}

        {uploadStatus === 'error' && errorMessage && (
          <div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
            <AlertCircle size={16} className="text-red-600" />
            <span className="text-sm text-red-800">{errorMessage}</span>
          </div>
        )}

        {/* Actions */}
        <div className="mt-8 flex justify-end gap-3">
          <Button
            variant="secondary"
            onClick={handleClose}
            disabled={uploadStatus === 'uploading'}
          >
            Cancel
          </Button>
          <Button
            onClick={handleUpload}
            disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
          >
            {uploadStatus === 'uploading' ? 'Uploading...' : `Upload ${selectedFiles.length > 0 ? `(${selectedFiles.length})` : ''}`}
          </Button>
        </div>
      </div>
    </div>
  )
}
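
Note on the sequential loop in handleUpload above: useDocuments (defined below) also exposes uploadDocumentAsync, the mutateAsync form of the same mutation, so the manual Promise wrapper is avoidable. A minimal sketch under that assumption:

// Sketch only: sequential upload via the async mutation variant.
// Assumes uploadDocumentAsync from useDocuments (mutateAsync under the hood).
async function uploadAllSequentially(
  files: File[],
  uploadDocumentAsync: (file: File) => Promise<unknown>,
): Promise<void> {
  for (const file of files) {
    // Each upload settles before the next begins, matching the loop above.
    await uploadDocumentAsync(file)
  }
}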

frontend/src/hooks/index.ts (new file, 4 lines)
@@ -0,0 +1,4 @@
export { useDocuments } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'

frontend/src/hooks/useAnnotations.ts (new file, 70 lines)
@@ -0,0 +1,70 @@
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { annotationsApi } from '../api/endpoints'
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'

export const useAnnotations = (documentId: string) => {
  const queryClient = useQueryClient()

  const createMutation = useMutation({
    mutationFn: (annotation: CreateAnnotationRequest) =>
      annotationsApi.create(documentId, annotation),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['document', documentId] })
    },
  })

  const updateMutation = useMutation({
    mutationFn: ({
      annotationId,
      updates,
    }: {
      annotationId: string
      updates: Partial<CreateAnnotationRequest>
    }) => annotationsApi.update(documentId, annotationId, updates),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['document', documentId] })
    },
  })

  const deleteMutation = useMutation({
    mutationFn: (annotationId: string) =>
      annotationsApi.delete(documentId, annotationId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['document', documentId] })
    },
  })

  const verifyMutation = useMutation({
    mutationFn: (annotationId: string) =>
      annotationsApi.verify(documentId, annotationId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['document', documentId] })
    },
  })

  const overrideMutation = useMutation({
    mutationFn: ({
      annotationId,
      overrideData,
    }: {
      annotationId: string
      overrideData: AnnotationOverrideRequest
    }) => annotationsApi.override(documentId, annotationId, overrideData),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['document', documentId] })
    },
  })

  return {
    createAnnotation: createMutation.mutate,
    isCreating: createMutation.isPending,
    updateAnnotation: updateMutation.mutate,
    isUpdating: updateMutation.isPending,
    deleteAnnotation: deleteMutation.mutate,
    isDeleting: deleteMutation.isPending,
    verifyAnnotation: verifyMutation.mutate,
    isVerifying: verifyMutation.isPending,
    overrideAnnotation: overrideMutation.mutate,
    isOverriding: overrideMutation.isPending,
  }
}

frontend/src/hooks/useDocumentDetail.ts (new file, 25 lines)
@@ -0,0 +1,25 @@
import { useQuery } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentDetailResponse } from '../api/types'

export const useDocumentDetail = (documentId: string | null) => {
  const { data, isLoading, error, refetch } = useQuery<DocumentDetailResponse>({
    queryKey: ['document', documentId],
    queryFn: () => {
      if (!documentId) {
        throw new Error('Document ID is required')
      }
      return documentsApi.getDetail(documentId)
    },
    enabled: !!documentId,
    staleTime: 10000,
  })

  return {
    document: data || null,
    annotations: data?.annotations || [],
    isLoading,
    error,
    refetch,
  }
}

frontend/src/hooks/useDocuments.ts (new file, 78 lines)
@@ -0,0 +1,78 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentListResponse, UploadDocumentResponse } from '../api/types'

interface UseDocumentsParams {
  status?: string
  limit?: number
  offset?: number
}

export const useDocuments = (params: UseDocumentsParams = {}) => {
  const queryClient = useQueryClient()

  const { data, isLoading, error, refetch } = useQuery<DocumentListResponse>({
    queryKey: ['documents', params],
    queryFn: () => documentsApi.list(params),
    staleTime: 30000,
  })

  const uploadMutation = useMutation({
    mutationFn: (file: File) => documentsApi.upload(file),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  const batchUploadMutation = useMutation({
    mutationFn: ({ files, csvFile }: { files: File[]; csvFile?: File }) =>
      documentsApi.batchUpload(files, csvFile),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  const deleteMutation = useMutation({
    mutationFn: (documentId: string) => documentsApi.delete(documentId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  const updateStatusMutation = useMutation({
    mutationFn: ({ documentId, status }: { documentId: string; status: string }) =>
      documentsApi.updateStatus(documentId, status),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  const triggerAutoLabelMutation = useMutation({
    mutationFn: (documentId: string) => documentsApi.triggerAutoLabel(documentId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['documents'] })
    },
  })

  return {
    documents: data?.documents || [],
    total: data?.total || 0,
    limit: data?.limit || params.limit || 20,
    offset: data?.offset || params.offset || 0,
    isLoading,
    error,
    refetch,
    uploadDocument: uploadMutation.mutate,
    uploadDocumentAsync: uploadMutation.mutateAsync,
    isUploading: uploadMutation.isPending,
    batchUpload: batchUploadMutation.mutate,
    batchUploadAsync: batchUploadMutation.mutateAsync,
    isBatchUploading: batchUploadMutation.isPending,
    deleteDocument: deleteMutation.mutate,
    isDeleting: deleteMutation.isPending,
    updateStatus: updateStatusMutation.mutate,
    isUpdatingStatus: updateStatusMutation.isPending,
    triggerAutoLabel: triggerAutoLabelMutation.mutate,
    isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
  }
}
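
For orientation, a minimal consumer of this hook could look like the following; the component name and markup are illustrative, not part of the codebase:

// Hypothetical consumer of useDocuments; names are illustrative only.
import React from 'react'
import { useDocuments } from './useDocuments'

export const PendingDocumentCount: React.FC = () => {
  const { documents, total, isLoading } = useDocuments({ status: 'pending', limit: 20 })

  if (isLoading) return <span>Loading…</span>
  // The hook defaults `documents` to [], so .length is always safe.
  return <span>{documents.length} of {total} pending documents</span>
}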

frontend/src/hooks/useTraining.ts (new file, 83 lines)
@@ -0,0 +1,83 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { trainingApi } from '../api/endpoints'
import type { TrainingModelsResponse } from '../api/types'

export const useTraining = () => {
  const queryClient = useQueryClient()

  const { data: modelsData, isLoading: isLoadingModels } =
    useQuery<TrainingModelsResponse>({
      queryKey: ['training', 'models'],
      queryFn: () => trainingApi.getModels(),
      staleTime: 30000,
    })

  const startTrainingMutation = useMutation({
    mutationFn: (config: {
      name: string
      description?: string
      document_ids: string[]
      epochs?: number
      batch_size?: number
      model_base?: string
    }) => trainingApi.startTraining(config),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
    },
  })

  const cancelTaskMutation = useMutation({
    mutationFn: (taskId: string) => trainingApi.cancelTask(taskId),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
    },
  })

  const downloadModelMutation = useMutation({
    mutationFn: (taskId: string) => trainingApi.downloadModel(taskId),
    onSuccess: (blob, taskId) => {
      const url = window.URL.createObjectURL(blob)
      const a = document.createElement('a')
      a.href = url
      a.download = `model-${taskId}.pt`
      document.body.appendChild(a)
      a.click()
      window.URL.revokeObjectURL(url)
      document.body.removeChild(a)
    },
  })

  return {
    models: modelsData?.models || [],
    total: modelsData?.total || 0,
    isLoadingModels,
    startTraining: startTrainingMutation.mutate,
    startTrainingAsync: startTrainingMutation.mutateAsync,
    isStartingTraining: startTrainingMutation.isPending,
    cancelTask: cancelTaskMutation.mutate,
    isCancelling: cancelTaskMutation.isPending,
    downloadModel: downloadModelMutation.mutate,
    isDownloading: downloadModelMutation.isPending,
  }
}

export const useTrainingDocuments = (params?: {
  has_annotations?: boolean
  min_annotation_count?: number
  exclude_used_in_training?: boolean
  limit?: number
  offset?: number
}) => {
  const { data, isLoading, error } = useQuery({
    queryKey: ['training', 'documents', params],
    queryFn: () => trainingApi.getDocumentsForTraining(params),
    staleTime: 30000,
  })

  return {
    documents: data?.documents || [],
    total: data?.total || 0,
    isLoading,
    error,
  }
}
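
The downloadModel mutation above drives a browser download by creating a temporary object URL and a synthetic anchor click. A sketch of wiring it to a button (component name illustrative):

// Hypothetical button wiring for downloadModel; names are illustrative only.
import React from 'react'
import { useTraining } from './useTraining'

export const ModelDownloadButton: React.FC<{ taskId: string }> = ({ taskId }) => {
  const { downloadModel, isDownloading } = useTraining()
  return (
    <button onClick={() => downloadModel(taskId)} disabled={isDownloading}>
      {isDownloading ? 'Downloading…' : 'Download Model'}
    </button>
  )
}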

frontend/src/main.tsx (new file, 23 lines)
@@ -0,0 +1,23 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import App from './App'
import './styles/index.css'

const queryClient = new QueryClient({
  defaultOptions: {
    queries: {
      retry: 1,
      refetchOnWindowFocus: false,
      staleTime: 30000,
    },
  },
})

ReactDOM.createRoot(document.getElementById('root')!).render(
  <React.StrictMode>
    <QueryClientProvider client={queryClient}>
      <App />
    </QueryClientProvider>
  </React.StrictMode>
)

frontend/src/styles/index.css (new file, 26 lines)
@@ -0,0 +1,26 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

@layer base {
  body {
    @apply bg-warm-bg text-warm-text-primary;
  }

  /* Custom scrollbar */
  ::-webkit-scrollbar {
    @apply w-2 h-2;
  }

  ::-webkit-scrollbar-track {
    @apply bg-transparent;
  }

  ::-webkit-scrollbar-thumb {
    @apply bg-warm-divider rounded;
  }

  ::-webkit-scrollbar-thumb:hover {
    @apply bg-warm-text-disabled;
  }
}

frontend/src/types/index.ts (new file, 48 lines)
@@ -0,0 +1,48 @@
// Legacy types for backward compatibility with old components
// These will be gradually replaced with API types

export enum DocumentStatus {
  PENDING = 'Pending',
  LABELED = 'Labeled',
  VERIFIED = 'Verified',
  PARTIAL = 'Partial'
}

export interface Document {
  id: string
  name: string
  date: string
  status: DocumentStatus
  exported: boolean
  autoLabelProgress?: number
  autoLabelStatus?: 'Running' | 'Completed' | 'Failed'
}

export interface Annotation {
  id: string
  text: string
  label: string
  x: number
  y: number
  width: number
  height: number
  isAuto?: boolean
}

export interface TrainingJob {
  id: string
  name: string
  startDate: string
  status: 'Running' | 'Completed' | 'Failed'
  progress: number
  metrics?: {
    accuracy: number
    precision: number
    recall: number
  }
}

export interface ModelMetric {
  name: string
  value: number
}

frontend/tailwind.config.js (new file, 47 lines)
@@ -0,0 +1,47 @@
export default {
  content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
  theme: {
    extend: {
      fontFamily: {
        sans: ['Inter', 'SF Pro', 'system-ui', 'sans-serif'],
        mono: ['JetBrains Mono', 'SF Mono', 'monospace'],
      },
      colors: {
        warm: {
          bg: '#FAFAF8',
          card: '#FFFFFF',
          hover: '#F1F0ED',
          selected: '#ECEAE6',
          border: '#E6E4E1',
          divider: '#D8D6D2',
          text: {
            primary: '#121212',
            secondary: '#2A2A2A',
            muted: '#6B6B6B',
            disabled: '#9A9A9A',
          },
          state: {
            success: '#3E4A3A',
            error: '#4A3A3A',
            warning: '#4A4A3A',
            info: '#3A3A3A',
          }
        }
      },
      boxShadow: {
        'card': '0 1px 3px rgba(0,0,0,0.08)',
        'modal': '0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)',
      },
      animation: {
        'fade-in': 'fadeIn 0.3s ease-out',
      },
      keyframes: {
        fadeIn: {
          '0%': { opacity: '0', transform: 'translateY(10px)' },
          '100%': { opacity: '1', transform: 'translateY(0)' },
        }
      }
    }
  },
  plugins: [],
}

frontend/tsconfig.json (new file, 29 lines)
@@ -0,0 +1,29 @@
{
  "compilerOptions": {
    "target": "ES2022",
    "experimentalDecorators": true,
    "useDefineForClassFields": false,
    "module": "ESNext",
    "lib": [
      "ES2022",
      "DOM",
      "DOM.Iterable"
    ],
    "skipLibCheck": true,
    "types": [
      "node"
    ],
    "moduleResolution": "bundler",
    "isolatedModules": true,
    "moduleDetection": "force",
    "allowJs": true,
    "jsx": "react-jsx",
    "paths": {
      "@/*": [
        "./*"
      ]
    },
    "allowImportingTsExtensions": true,
    "noEmit": true
  }
}

frontend/vite.config.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';

export default defineConfig({
  server: {
    port: 3000,
    host: '0.0.0.0',
    proxy: {
      '/api': {
        target: 'http://localhost:8000',
        changeOrigin: true,
      },
    },
  },
  plugins: [react()],
});
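
With this proxy, dev-server requests to relative /api paths are forwarded to the backend on port 8000, so components could use relative URLs instead of the hard-coded http://localhost:8000 origin seen in AdminAuth above. A sketch, reusing the token-creation endpoint and body shape from that component:

// Sketch: during development the Vite proxy forwards /api to localhost:8000.
const response = await fetch('/api/v1/admin/auth/token', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ name: 'my-laptop' }),
})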

@@ -21,3 +21,7 @@ pyyaml>=6.0  # YAML config files
 # Utilities
 tqdm>=4.65.0  # Progress bars
 python-dotenv>=1.0.0  # Environment variable management
+
+# Database
+psycopg2-binary>=2.9.0  # PostgreSQL driver
+sqlmodel>=0.0.22  # SQLModel ORM (SQLAlchemy + Pydantic)

@@ -16,7 +16,7 @@ from pathlib import Path
 from typing import Optional
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string
 
 from ..normalize import normalize_field
 from ..matcher import FieldMatcher

@@ -12,7 +12,7 @@ from collections import defaultdict
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string
 
 
 def load_reports_from_db() -> dict:

@@ -34,7 +34,7 @@ if sys.platform == 'win32':
     multiprocessing.set_start_method('spawn', force=True)
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string, PATHS, AUTOLABEL
+from src.config import get_db_connection_string, PATHS, AUTOLABEL
 
 # Global OCR engine for worker processes (initialized once per worker)
 _worker_ocr_engine = None

@@ -16,7 +16,7 @@ from psycopg2.extras import execute_values
 
 # Add project root to path
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string, PATHS
+from src.config import get_db_connection_string, PATHS
 
 
 def create_tables(conn):

@@ -10,6 +10,9 @@ import json
 import sys
 from pathlib import Path
 
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+from src.config import DEFAULT_DPI
+
 
 def main():
     parser = argparse.ArgumentParser(
@@ -38,8 +41,8 @@ def main():
     parser.add_argument(
         '--dpi',
         type=int,
-        default=150,
-        help='DPI for PDF rendering (default: 150, must match training)'
+        default=DEFAULT_DPI,
+        help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
     )
     parser.add_argument(
         '--no-fallback',

@@ -17,6 +17,7 @@ from tqdm import tqdm
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
+from src.config import DEFAULT_DPI
 from src.data.db import DocumentDB
 from src.data.csv_loader import CSVLoader
 from src.normalize.normalizer import normalize_field
@@ -144,7 +145,7 @@ def process_single_document(args):
     ocr_engine = OCREngine()
     for page_no in range(pdf_doc.page_count):
         # Render page to image
-        img = pdf_doc.render_page(page_no, dpi=150)
+        img = pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
         if img is None:
             continue
 

@@ -15,6 +15,8 @@ from pathlib import Path
 project_root = Path(__file__).parent.parent.parent
 sys.path.insert(0, str(project_root))
 
+from src.config import DEFAULT_DPI
+
 
 def setup_logging(debug: bool = False) -> None:
     """Configure logging."""
@@ -65,8 +67,8 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--dpi",
         type=int,
-        default=150,
-        help="DPI for PDF rendering (must match training DPI)",
+        default=DEFAULT_DPI,
+        help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
     )
 
     parser.add_argument(

@@ -11,7 +11,7 @@ import sys
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import PATHS
+from src.config import DEFAULT_DPI, PATHS
 
 
 def main():
@@ -103,8 +103,8 @@ def main():
     parser.add_argument(
         '--dpi',
         type=int,
-        default=150,
-        help='DPI used for rendering (default: 150, must match autolabel rendering)'
+        default=DEFAULT_DPI,
+        help=f'DPI used for rendering (default: {DEFAULT_DPI}, must match autolabel rendering)'
     )
     parser.add_argument(
         '--export-only',

@@ -8,9 +8,13 @@ from pathlib import Path
 from dotenv import load_dotenv
 
 # Load environment variables from .env file
-env_path = Path(__file__).parent / '.env'
+# .env is at project root, config.py is in src/
+env_path = Path(__file__).parent.parent / '.env'
 load_dotenv(dotenv_path=env_path)
 
+# Global DPI setting - must match training DPI for optimal model performance
+DEFAULT_DPI = 150
+
 
 def _is_wsl() -> bool:
     """Check if running inside WSL (Windows Subsystem for Linux)."""
@@ -69,7 +73,7 @@ else:
 # Auto-labeling Configuration
 AUTOLABEL = {
     'workers': 2,
-    'dpi': 150,
+    'dpi': DEFAULT_DPI,
     'min_confidence': 0.5,
     'train_ratio': 0.8,
     'val_ratio': 0.1,

src/data/admin_db.py (new file, 1156 lines)
File diff suppressed because it is too large

src/data/admin_models.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""
Admin API SQLModel Database Models

Defines the database schema for admin document management, annotations, and training tasks.
Includes batch upload support, training document links, and annotation history.
"""

from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON


# =============================================================================
# CSV to Field Class Mapping
# =============================================================================

CSV_TO_CLASS_MAPPING: dict[str, int] = {
    "InvoiceNumber": 0,                  # invoice_number
    "InvoiceDate": 1,                    # invoice_date
    "InvoiceDueDate": 2,                 # invoice_due_date
    "OCR": 3,                            # ocr_number
    "Bankgiro": 4,                       # bankgiro
    "Plusgiro": 5,                       # plusgiro
    "Amount": 6,                         # amount
    "supplier_organisation_number": 7,   # supplier_organisation_number
    # 8: payment_line (derived from OCR/Bankgiro/Amount)
    "customer_number": 9,                # customer_number
}


# =============================================================================
# Core Models
# =============================================================================


class AdminToken(SQLModel, table=True):
    """Admin authentication token."""

    __tablename__ = "admin_tokens"

    token: str = Field(primary_key=True, max_length=255)
    name: str = Field(max_length=255)
    is_active: bool = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    last_used_at: datetime | None = Field(default=None)
    expires_at: datetime | None = Field(default=None)


class AdminDocument(SQLModel, table=True):
    """Document uploaded for labeling/annotation."""

    __tablename__ = "admin_documents"

    document_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str | None = Field(default=None, foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)
    file_size: int
    content_type: str = Field(max_length=100)
    file_path: str = Field(max_length=512)  # Path to stored file
    page_count: int = Field(default=1)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, auto_labeling, labeled, exported
    auto_label_status: str | None = Field(default=None, max_length=20)
    # Auto-label status: running, completed, failed
    auto_label_error: str | None = Field(default=None)
    # v2: Upload source tracking
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    batch_id: UUID | None = Field(default=None, index=True)
    # Link to batch upload (if uploaded via ZIP)
    csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Original CSV values for reference
    auto_label_queued_at: datetime | None = Field(default=None)
    # When auto-label was queued
    annotation_lock_until: datetime | None = Field(default=None)
    # Lock for manual annotation while auto-label runs
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class AdminAnnotation(SQLModel, table=True):
    """Annotation for a document (bounding box + label)."""

    __tablename__ = "admin_annotations"

    annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    page_number: int = Field(default=1)  # 1-indexed
    class_id: int  # 0-9 for invoice fields
    class_name: str = Field(max_length=50)  # e.g., "invoice_number"
    # Bounding box (normalized 0-1 coordinates)
    x_center: float
    y_center: float
    width: float
    height: float
    # Original pixel coordinates (for display)
    bbox_x: int
    bbox_y: int
    bbox_width: int
    bbox_height: int
    # OCR extracted text (if available)
    text_value: str | None = Field(default=None)
    confidence: float | None = Field(default=None)
    # Source: manual, auto, imported
    source: str = Field(default="manual", max_length=20, index=True)
    # v2: Verification fields
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    # v2: Override tracking
    override_source: str | None = Field(default=None, max_length=20)
    # If this annotation overrides another: 'auto' or 'imported'
    original_annotation_id: UUID | None = Field(default=None)
    # Reference to the annotation this overrides
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingTask(SQLModel, table=True):
    """Training/fine-tuning task."""

    __tablename__ = "training_tasks"

    task_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, scheduled, running, completed, failed, cancelled
    task_type: str = Field(default="train", max_length=20)
    # Task type: train, finetune
    # Training configuration
    config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Schedule settings
    scheduled_at: datetime | None = Field(default=None)
    cron_expression: str | None = Field(default=None, max_length=50)
    is_recurring: bool = Field(default=False)
    # Execution details
    started_at: datetime | None = Field(default=None)
    completed_at: datetime | None = Field(default=None)
    error_message: str | None = Field(default=None)
    # Result metrics
    result_metrics: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    model_path: str | None = Field(default=None, max_length=512)
    # v2: Document count and extracted metrics
    document_count: int = Field(default=0)
    # Count of documents used in training
    metrics_mAP: float | None = Field(default=None, index=True)
    metrics_precision: float | None = Field(default=None)
    metrics_recall: float | None = Field(default=None)
    # Extracted metrics for easy querying
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingLog(SQLModel, table=True):
    """Training log entry."""

    __tablename__ = "training_logs"

    log_id: int | None = Field(default=None, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    level: str = Field(max_length=20)  # INFO, WARNING, ERROR
    message: str
    details: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# =============================================================================
# Batch Upload Models (v2)
# =============================================================================


class BatchUpload(SQLModel, table=True):
    """Batch upload of multiple documents via ZIP file."""

    __tablename__ = "batch_uploads"

    batch_id: UUID = Field(default_factory=uuid4, primary_key=True)
    admin_token: str = Field(foreign_key="admin_tokens.token", max_length=255, index=True)
    filename: str = Field(max_length=255)  # ZIP filename
    file_size: int
    upload_source: str = Field(default="ui", max_length=20)
    # Upload source: ui, api
    status: str = Field(default="processing", max_length=20, index=True)
    # Status: processing, completed, partial, failed
    total_files: int = Field(default=0)
    processed_files: int = Field(default=0)
    # Number of files processed so far
    successful_files: int = Field(default=0)
    failed_files: int = Field(default=0)
    csv_filename: str | None = Field(default=None, max_length=255)
    # CSV file used for auto-labeling
    csv_row_count: int | None = Field(default=None)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    completed_at: datetime | None = Field(default=None)


class BatchUploadFile(SQLModel, table=True):
    """Individual file within a batch upload."""

    __tablename__ = "batch_upload_files"

    file_id: UUID = Field(default_factory=uuid4, primary_key=True)
    batch_id: UUID = Field(foreign_key="batch_uploads.batch_id", index=True)
    filename: str = Field(max_length=255)  # PDF filename within ZIP
    document_id: UUID | None = Field(default=None)
    # Link to created AdminDocument (if successful)
    status: str = Field(default="pending", max_length=20, index=True)
    # Status: pending, processing, completed, failed, skipped
    error_message: str | None = Field(default=None)
    annotation_count: int = Field(default=0)
    # Number of annotations created for this file
    csv_row_data: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # CSV row data for this file (if available)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    processed_at: datetime | None = Field(default=None)


# =============================================================================
# Training Document Link (v2)
# =============================================================================


class TrainingDocumentLink(SQLModel, table=True):
    """Junction table linking training tasks to documents."""

    __tablename__ = "training_document_links"

    link_id: UUID = Field(default_factory=uuid4, primary_key=True)
    task_id: UUID = Field(foreign_key="training_tasks.task_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    annotation_snapshot: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Snapshot of annotations at training time (includes count, verified count, etc.)
    created_at: datetime = Field(default_factory=datetime.utcnow)


# =============================================================================
# Annotation History (v2)
# =============================================================================


class AnnotationHistory(SQLModel, table=True):
    """History of annotation changes (for override tracking)."""

    __tablename__ = "annotation_history"

    history_id: UUID = Field(default_factory=uuid4, primary_key=True)
    annotation_id: UUID = Field(foreign_key="admin_annotations.annotation_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
|
||||||
|
# Change action: created, updated, deleted, override
|
||||||
|
action: str = Field(max_length=20, index=True)
|
||||||
|
# Previous value (for updates/deletes)
|
||||||
|
previous_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||||
|
# New value (for creates/updates)
|
||||||
|
new_value: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
|
||||||
|
# Change metadata
|
||||||
|
changed_by: str | None = Field(default=None, max_length=255)
|
||||||
|
# User/token who made the change
|
||||||
|
change_reason: str | None = Field(default=None)
|
||||||
|
# Optional reason for change
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Field class mapping (same as src/cli/train.py)
|
||||||
|
FIELD_CLASSES = {
|
||||||
|
0: "invoice_number",
|
||||||
|
1: "invoice_date",
|
||||||
|
2: "invoice_due_date",
|
||||||
|
3: "ocr_number",
|
||||||
|
4: "bankgiro",
|
||||||
|
5: "plusgiro",
|
||||||
|
6: "amount",
|
||||||
|
7: "supplier_organisation_number",
|
||||||
|
8: "payment_line",
|
||||||
|
9: "customer_number",
|
||||||
|
}
|
||||||
|
|
||||||
|
FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# Read-only models for API responses
|
||||||
|
class AdminDocumentRead(SQLModel):
|
||||||
|
"""Admin document response model."""
|
||||||
|
|
||||||
|
document_id: UUID
|
||||||
|
filename: str
|
||||||
|
file_size: int
|
||||||
|
content_type: str
|
||||||
|
page_count: int
|
||||||
|
status: str
|
||||||
|
auto_label_status: str | None
|
||||||
|
auto_label_error: str | None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AdminAnnotationRead(SQLModel):
|
||||||
|
"""Admin annotation response model."""
|
||||||
|
|
||||||
|
annotation_id: UUID
|
||||||
|
document_id: UUID
|
||||||
|
page_number: int
|
||||||
|
class_id: int
|
||||||
|
class_name: str
|
||||||
|
x_center: float
|
||||||
|
y_center: float
|
||||||
|
width: float
|
||||||
|
height: float
|
||||||
|
bbox_x: int
|
||||||
|
bbox_y: int
|
||||||
|
bbox_width: int
|
||||||
|
bbox_height: int
|
||||||
|
text_value: str | None
|
||||||
|
confidence: float | None
|
||||||
|
source: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingTaskRead(SQLModel):
|
||||||
|
"""Training task response model."""
|
||||||
|
|
||||||
|
task_id: UUID
|
||||||
|
name: str
|
||||||
|
description: str | None
|
||||||
|
status: str
|
||||||
|
task_type: str
|
||||||
|
config: dict[str, Any] | None
|
||||||
|
scheduled_at: datetime | None
|
||||||
|
is_recurring: bool
|
||||||
|
started_at: datetime | None
|
||||||
|
completed_at: datetime | None
|
||||||
|
error_message: str | None
|
||||||
|
result_metrics: dict[str, Any] | None
|
||||||
|
model_path: str | None
|
||||||
|
created_at: datetime
|
||||||
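A quick sketch of how the two mappings above translate between YOLO class indices and field names; the example values are illustrative:

    from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS

    print(FIELD_CLASSES[6])                   # "amount"
    print(FIELD_CLASS_IDS["invoice_number"])  # 0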
374  src/data/async_request_db.py  Normal file
@@ -0,0 +1,374 @@
"""
Async Request Database Operations

Database interface for async invoice processing requests using SQLModel.
"""

import logging
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy import func, text
from sqlmodel import Session, select

from src.data.database import get_session_context, create_db_and_tables, close_engine
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent

logger = logging.getLogger(__name__)


# Legacy dataclasses for backward compatibility
from dataclasses import dataclass


@dataclass(frozen=True)
class ApiKeyConfig:
    """API key configuration and limits (legacy compatibility)."""

    api_key: str
    name: str
    is_active: bool
    requests_per_minute: int
    max_concurrent_jobs: int
    max_file_size_mb: int


class AsyncRequestDB:
    """Database interface for async processing requests using SQLModel."""

    def __init__(self, connection_string: str | None = None) -> None:
        # connection_string is kept for backward compatibility but ignored
        # SQLModel uses the global engine from database.py
        self._initialized = False

    def connect(self):
        """Legacy method - returns self for compatibility."""
        return self

    def close(self) -> None:
        """Close database connections."""
        close_engine()

    def __enter__(self) -> "AsyncRequestDB":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        pass  # Sessions are managed per-operation

    def create_tables(self) -> None:
        """Create async processing tables if they don't exist."""
        create_db_and_tables()
        self._initialized = True

    # ==========================================================================
    # API Key Operations
    # ==========================================================================

    def is_valid_api_key(self, api_key: str) -> bool:
        """Check if API key exists and is active."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            return result is not None and result.is_active is True

    def get_api_key_config(self, api_key: str) -> ApiKeyConfig | None:
        """Get API key configuration and limits."""
        with get_session_context() as session:
            result = session.get(ApiKey, api_key)
            if result is None:
                return None
            return ApiKeyConfig(
                api_key=result.api_key,
                name=result.name,
                is_active=result.is_active,
                requests_per_minute=result.requests_per_minute,
                max_concurrent_jobs=result.max_concurrent_jobs,
                max_file_size_mb=result.max_file_size_mb,
            )

    def create_api_key(
        self,
        api_key: str,
        name: str,
        requests_per_minute: int = 10,
        max_concurrent_jobs: int = 3,
        max_file_size_mb: int = 50,
    ) -> None:
        """Create a new API key."""
        with get_session_context() as session:
            existing = session.get(ApiKey, api_key)
            if existing:
                existing.name = name
                existing.requests_per_minute = requests_per_minute
                existing.max_concurrent_jobs = max_concurrent_jobs
                existing.max_file_size_mb = max_file_size_mb
                session.add(existing)
            else:
                new_key = ApiKey(
                    api_key=api_key,
                    name=name,
                    requests_per_minute=requests_per_minute,
                    max_concurrent_jobs=max_concurrent_jobs,
                    max_file_size_mb=max_file_size_mb,
                )
                session.add(new_key)

    def update_api_key_usage(self, api_key: str) -> None:
        """Update API key last used timestamp and increment total requests."""
        with get_session_context() as session:
            key = session.get(ApiKey, api_key)
            if key:
                key.last_used_at = datetime.utcnow()
                key.total_requests += 1
                session.add(key)

    # ==========================================================================
    # Async Request Operations
    # ==========================================================================

    def create_request(
        self,
        api_key: str,
        filename: str,
        file_size: int,
        content_type: str,
        expires_at: datetime,
        request_id: str | None = None,
    ) -> str:
        """Create a new async request."""
        with get_session_context() as session:
            request = AsyncRequest(
                api_key=api_key,
                filename=filename,
                file_size=file_size,
                content_type=content_type,
                expires_at=expires_at,
            )
            if request_id:
                request.request_id = UUID(request_id)
            session.add(request)
            session.flush()  # To get the generated ID
            return str(request.request_id)

    def get_request(self, request_id: str) -> AsyncRequest | None:
        """Get a single async request by ID."""
        with get_session_context() as session:
            result = session.get(AsyncRequest, UUID(request_id))
            if result:
                # Detach from session for use outside context
                session.expunge(result)
            return result

    def get_request_by_api_key(
        self,
        request_id: str,
        api_key: str,
    ) -> AsyncRequest | None:
        """Get a request only if it belongs to the given API key."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.request_id == UUID(request_id),
                AsyncRequest.api_key == api_key,
            )
            result = session.exec(statement).first()
            if result:
                session.expunge(result)
            return result

    def update_status(
        self,
        request_id: str,
        status: str,
        error_message: str | None = None,
        increment_retry: bool = False,
    ) -> None:
        """Update request status."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = status
                if status == "processing":
                    request.started_at = datetime.utcnow()
                if error_message is not None:
                    request.error_message = error_message
                if increment_retry:
                    request.retry_count += 1
                session.add(request)

    def complete_request(
        self,
        request_id: str,
        document_id: str,
        result: dict[str, Any],
        processing_time_ms: float,
        visualization_path: str | None = None,
    ) -> None:
        """Mark request as completed with result."""
        with get_session_context() as session:
            request = session.get(AsyncRequest, UUID(request_id))
            if request:
                request.status = "completed"
                request.document_id = document_id
                request.result = result
                request.processing_time_ms = processing_time_ms
                request.visualization_path = visualization_path
                request.completed_at = datetime.utcnow()
                session.add(request)

    def get_requests_by_api_key(
        self,
        api_key: str,
        status: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[AsyncRequest], int]:
        """Get paginated requests for an API key."""
        with get_session_context() as session:
            # Count query
            count_stmt = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                count_stmt = count_stmt.where(AsyncRequest.status == status)
            total = session.exec(count_stmt).one()

            # Fetch query
            statement = select(AsyncRequest).where(
                AsyncRequest.api_key == api_key
            )
            if status:
                statement = statement.where(AsyncRequest.status == status)
            statement = statement.order_by(AsyncRequest.created_at.desc())
            statement = statement.offset(offset).limit(limit)

            results = session.exec(statement).all()
            # Detach results from session
            for r in results:
                session.expunge(r)
            return list(results), total

    def count_active_jobs(self, api_key: str) -> int:
        """Count active (pending + processing) jobs for an API key."""
        with get_session_context() as session:
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.api_key == api_key,
                AsyncRequest.status.in_(["pending", "processing"]),
            )
            return session.exec(statement).one()

    def get_pending_requests(self, limit: int = 10) -> list[AsyncRequest]:
        """Get pending requests ordered by creation time."""
        with get_session_context() as session:
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "pending"
            ).order_by(AsyncRequest.created_at).limit(limit)
            results = session.exec(statement).all()
            for r in results:
                session.expunge(r)
            return list(results)

    def get_queue_position(self, request_id: str) -> int | None:
        """Get position of a request in the pending queue."""
        with get_session_context() as session:
            # Get the request's created_at
            request = session.get(AsyncRequest, UUID(request_id))
            if not request:
                return None

            # Count pending requests created before this one
            statement = select(func.count()).select_from(AsyncRequest).where(
                AsyncRequest.status == "pending",
                AsyncRequest.created_at < request.created_at,
            )
            count = session.exec(statement).one()
            return count + 1  # 1-based position

    # ==========================================================================
    # Rate Limit Operations
    # ==========================================================================

    def record_rate_limit_event(self, api_key: str, event_type: str) -> None:
        """Record a rate limit event."""
        with get_session_context() as session:
            event = RateLimitEvent(
                api_key=api_key,
                event_type=event_type,
            )
            session.add(event)

    def count_recent_requests(self, api_key: str, seconds: int = 60) -> int:
        """Count requests in the last N seconds."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(seconds=seconds)
            statement = select(func.count()).select_from(RateLimitEvent).where(
                RateLimitEvent.api_key == api_key,
                RateLimitEvent.event_type == "request",
                RateLimitEvent.created_at > cutoff,
            )
            return session.exec(statement).one()

    # ==========================================================================
    # Cleanup Operations
    # ==========================================================================

    def delete_expired_requests(self) -> int:
        """Delete requests that have expired. Returns count of deleted rows."""
        with get_session_context() as session:
            now = datetime.utcnow()
            statement = select(AsyncRequest).where(AsyncRequest.expires_at < now)
            expired = session.exec(statement).all()
            count = len(expired)
            for request in expired:
                session.delete(request)
            logger.info(f"Deleted {count} expired async requests")
            return count

    def cleanup_old_rate_limit_events(self, hours: int = 1) -> int:
        """Delete rate limit events older than N hours."""
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(hours=hours)
            statement = select(RateLimitEvent).where(
                RateLimitEvent.created_at < cutoff
            )
            old_events = session.exec(statement).all()
            count = len(old_events)
            for event in old_events:
                session.delete(event)
            return count

    def reset_stale_processing_requests(
        self,
        stale_minutes: int = 10,
        max_retries: int = 3,
    ) -> int:
        """
        Reset requests stuck in 'processing' status.

        Requests that have been processing for more than stale_minutes
        are considered stale. They are either reset to 'pending' (if under
        max_retries) or set to 'failed'.
        """
        with get_session_context() as session:
            cutoff = datetime.utcnow() - timedelta(minutes=stale_minutes)
            reset_count = 0

            # Find stale processing requests
            statement = select(AsyncRequest).where(
                AsyncRequest.status == "processing",
                AsyncRequest.started_at < cutoff,
            )
            stale_requests = session.exec(statement).all()

            for request in stale_requests:
                if request.retry_count < max_retries:
                    request.status = "pending"
                    request.started_at = None
                else:
                    request.status = "failed"
                    request.error_message = "Processing timeout after max retries"
                session.add(request)
                reset_count += 1

            if reset_count > 0:
                logger.warning(f"Reset {reset_count} stale processing requests")
            return reset_count
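A minimal sketch of the intended request lifecycle through AsyncRequestDB, using the dev key seeded by the migration below; the document_id, result payload, and 24-hour expiry are made up for illustration:

    from datetime import datetime, timedelta
    from src.data.async_request_db import AsyncRequestDB

    db = AsyncRequestDB()
    db.create_tables()
    request_id = db.create_request(
        api_key="dev-api-key-12345",
        filename="invoice.pdf",
        file_size=123_456,
        content_type="application/pdf",
        expires_at=datetime.utcnow() + timedelta(hours=24),  # assumed window
    )
    db.update_status(request_id, "processing")  # stamps started_at
    db.complete_request(
        request_id,
        document_id="doc-0001",        # hypothetical
        result={"amount": "100.00"},   # hypothetical payload
        processing_time_ms=850.0,
    )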
103  src/data/database.py  Normal file
@@ -0,0 +1,103 @@
"""
Database Engine and Session Management

Provides SQLModel database engine and session handling.
"""

import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine

import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import get_db_connection_string

logger = logging.getLogger(__name__)

# Global engine instance
_engine = None


def get_engine():
    """Get or create the database engine."""
    global _engine
    if _engine is None:
        connection_string = get_db_connection_string()
        # Convert psycopg2 format to SQLAlchemy format
        if connection_string.startswith("postgresql://"):
            # Already in correct format
            pass
        elif "host=" in connection_string:
            # Convert DSN format to URL format
            parts = dict(item.split("=") for item in connection_string.split())
            connection_string = (
                f"postgresql://{parts.get('user', '')}:{parts.get('password', '')}"
                f"@{parts.get('host', 'localhost')}:{parts.get('port', '5432')}"
                f"/{parts.get('dbname', 'docmaster')}"
            )

        _engine = create_engine(
            connection_string,
            echo=False,  # Set to True for SQL debugging
            pool_pre_ping=True,  # Verify connections before use
            pool_size=5,
            max_overflow=10,
        )
    return _engine


def create_db_and_tables() -> None:
    """Create all database tables."""
    from src.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
    from src.data.admin_models import (  # noqa: F401
        AdminToken,
        AdminDocument,
        AdminAnnotation,
        TrainingTask,
        TrainingLog,
    )

    engine = get_engine()
    SQLModel.metadata.create_all(engine)
    logger.info("Database tables created/verified")


def get_session() -> Session:
    """Get a new database session."""
    engine = get_engine()
    return Session(engine)


@contextmanager
def get_session_context() -> Generator[Session, None, None]:
    """Context manager for database sessions with auto-commit/rollback."""
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def close_engine() -> None:
    """Close the database engine and release connections."""
    global _engine
    if _engine is not None:
        _engine.dispose()
        _engine = None
        logger.info("Database engine closed")


def execute_raw_sql(sql: str) -> None:
    """Execute raw SQL (for migrations)."""
    engine = get_engine()
    with engine.connect() as conn:
        conn.execute(text(sql))
        conn.commit()
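A minimal sketch of the session contract established by get_session_context above: commit on clean exit, rollback on exception:

    from src.data.database import get_session_context
    from src.data.models import ApiKey

    with get_session_context() as session:
        key = session.get(ApiKey, "dev-api-key-12345")
        if key:
            key.total_processed += 1
            session.add(key)  # committed automatically when the block exits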
@@ -10,7 +10,7 @@ import sys
 from pathlib import Path

 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from config import get_db_connection_string
+from src.config import get_db_connection_string


 class DocumentDB:
83  src/data/migrations/001_async_tables.sql  Normal file
@@ -0,0 +1,83 @@
-- Async Invoice Processing Tables
-- Migration: 001_async_tables.sql
-- Created: 2024-01-15

-- API Keys table for authentication and rate limiting
CREATE TABLE IF NOT EXISTS api_keys (
    api_key TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    is_active BOOLEAN DEFAULT true,

    -- Rate limits
    requests_per_minute INTEGER DEFAULT 10,
    max_concurrent_jobs INTEGER DEFAULT 3,
    max_file_size_mb INTEGER DEFAULT 50,

    -- Usage tracking
    total_requests INTEGER DEFAULT 0,
    total_processed INTEGER DEFAULT 0,

    -- Timestamps
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    last_used_at TIMESTAMPTZ
);

-- Async processing requests table
CREATE TABLE IF NOT EXISTS async_requests (
    request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
    status TEXT NOT NULL DEFAULT 'pending',
    filename TEXT NOT NULL,
    file_size INTEGER NOT NULL,
    content_type TEXT NOT NULL,

    -- Processing metadata
    document_id TEXT,
    error_message TEXT,
    retry_count INTEGER DEFAULT 0,

    -- Timestamps
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    started_at TIMESTAMPTZ,
    completed_at TIMESTAMPTZ,
    expires_at TIMESTAMPTZ NOT NULL,

    -- Result storage (JSONB for flexibility)
    result JSONB,

    -- Processing time
    processing_time_ms REAL,

    -- Visualization path
    visualization_path TEXT,

    CONSTRAINT valid_status CHECK (status IN ('pending', 'processing', 'completed', 'failed'))
);

-- Indexes for async_requests
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key ON async_requests(api_key);
CREATE INDEX IF NOT EXISTS idx_async_requests_status ON async_requests(status);
CREATE INDEX IF NOT EXISTS idx_async_requests_created_at ON async_requests(created_at);
CREATE INDEX IF NOT EXISTS idx_async_requests_expires_at ON async_requests(expires_at);
CREATE INDEX IF NOT EXISTS idx_async_requests_api_key_status ON async_requests(api_key, status);

-- Rate limit tracking table
CREATE TABLE IF NOT EXISTS rate_limit_events (
    id SERIAL PRIMARY KEY,
    api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE,
    event_type TEXT NOT NULL, -- 'request', 'complete', 'fail'
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Index for rate limiting queries (recent events only)
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_api_key_time
    ON rate_limit_events(api_key, created_at DESC);

-- Cleanup old rate limit events index
CREATE INDEX IF NOT EXISTS idx_rate_limit_events_cleanup
    ON rate_limit_events(created_at);

-- Insert default API key for development/testing
INSERT INTO api_keys (api_key, name, requests_per_minute, max_concurrent_jobs)
VALUES ('dev-api-key-12345', 'Development Key', 100, 10)
ON CONFLICT (api_key) DO NOTHING;
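One way to apply this migration is through execute_raw_sql from src/data/database.py; this assumes the process runs from the repository root and that the driver accepts the semicolon-separated script in a single call (psycopg2 does):

    from pathlib import Path
    from src.data.database import execute_raw_sql

    sql = Path("src/data/migrations/001_async_tables.sql").read_text()
    execute_raw_sql(sql)  # runs the whole multi-statement script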
5  src/data/migrations/002_nullable_admin_token.sql  Normal file
@@ -0,0 +1,5 @@
-- Migration: Make admin_token nullable in admin_documents table
-- This allows documents uploaded via public API to not require an admin token

ALTER TABLE admin_documents
    ALTER COLUMN admin_token DROP NOT NULL;
95  src/data/models.py  Normal file
@@ -0,0 +1,95 @@
"""
SQLModel Database Models

Defines the database schema using SQLModel (SQLAlchemy + Pydantic).
"""

from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel, Column, JSON


class ApiKey(SQLModel, table=True):
    """API key configuration and limits."""

    __tablename__ = "api_keys"

    api_key: str = Field(primary_key=True, max_length=255)
    name: str = Field(max_length=255)
    is_active: bool = Field(default=True)
    requests_per_minute: int = Field(default=10)
    max_concurrent_jobs: int = Field(default=3)
    max_file_size_mb: int = Field(default=50)
    total_requests: int = Field(default=0)
    total_processed: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    last_used_at: datetime | None = Field(default=None)


class AsyncRequest(SQLModel, table=True):
    """Async request record."""

    __tablename__ = "async_requests"

    request_id: UUID = Field(default_factory=uuid4, primary_key=True)
    api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
    status: str = Field(default="pending", max_length=20, index=True)
    filename: str = Field(max_length=255)
    file_size: int
    content_type: str = Field(max_length=100)
    document_id: str | None = Field(default=None, max_length=100)
    error_message: str | None = Field(default=None)
    retry_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: datetime | None = Field(default=None)
    completed_at: datetime | None = Field(default=None)
    expires_at: datetime = Field(index=True)
    result: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    processing_time_ms: float | None = Field(default=None)
    visualization_path: str | None = Field(default=None, max_length=255)


class RateLimitEvent(SQLModel, table=True):
    """Rate limit event record."""

    __tablename__ = "rate_limit_events"

    id: int | None = Field(default=None, primary_key=True)
    api_key: str = Field(foreign_key="api_keys.api_key", max_length=255, index=True)
    event_type: str = Field(max_length=50)
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)


# Read-only models for responses (without table=True)
class ApiKeyRead(SQLModel):
    """API key response model (read-only)."""

    api_key: str
    name: str
    is_active: bool
    requests_per_minute: int
    max_concurrent_jobs: int
    max_file_size_mb: int


class AsyncRequestRead(SQLModel):
    """Async request response model (read-only)."""

    request_id: UUID
    api_key: str
    status: str
    filename: str
    file_size: int
    content_type: str
    document_id: str | None
    error_message: str | None
    retry_count: int
    created_at: datetime
    started_at: datetime | None
    completed_at: datetime | None
    expires_at: datetime
    result: dict[str, Any] | None
    processing_time_ms: float | None
    visualization_path: str | None
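A small sketch of how the table models behave as plain Pydantic-style objects before persistence; defaults such as status and the generated request_id are filled in at construction time:

    from datetime import datetime, timedelta
    from src.data.models import AsyncRequest

    req = AsyncRequest(
        api_key="dev-api-key-12345",
        filename="invoice.pdf",
        file_size=123_456,
        content_type="application/pdf",
        expires_at=datetime.utcnow() + timedelta(hours=24),  # assumed window
    )
    print(req.status)      # "pending" (default)
    print(req.request_id)  # generated uuid4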
@@ -12,6 +12,8 @@ import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional

+from src.config import DEFAULT_DPI
+
 # Global OCR instance (initialized once per GPU worker process)
 _ocr_engine: Optional[Any] = None

@@ -94,7 +96,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     row_dict = task_data["row_dict"]
     pdf_path = Path(task_data["pdf_path"])
     output_dir = Path(task_data["output_dir"])
-    dpi = task_data.get("dpi", 150)
+    dpi = task_data.get("dpi", DEFAULT_DPI)
     min_confidence = task_data.get("min_confidence", 0.5)

     start_time = time.time()
@@ -212,7 +214,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     row_dict = task_data["row_dict"]
     pdf_path = Path(task_data["pdf_path"])
     output_dir = Path(task_data["output_dir"])
-    dpi = task_data.get("dpi", 150)
+    dpi = task_data.get("dpi", DEFAULT_DPI)
     min_confidence = task_data.get("min_confidence", 0.5)

     start_time = time.time()
@@ -16,6 +16,8 @@ from datetime import datetime
 import psycopg2
 from psycopg2.extras import execute_values

+from src.config import DEFAULT_DPI
+

 @dataclass
 class LLMExtractionResult:
@@ -265,7 +267,7 @@ Return ONLY the JSON object, no other text."""
         self,
         pdf_path: Path,
         page_no: int = 0,
-        dpi: int = 150,
+        dpi: int = DEFAULT_DPI,
         max_size_mb: float = 18.0
     ) -> bytes:
         """
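These hunks replace the hard-coded 150 fallback with a shared constant. src/config.py itself is not part of this diff; from the hunks one can only infer a definition along these lines:

    # src/config.py (assumed shape; only the name and old default are implied)
    DEFAULT_DPI = 150  # single source of truth for rendering resolution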
8  src/web/admin_routes_new.py  Normal file
@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py

DEPRECATED: Import from src.web.api.v1.admin.documents instead.
"""
from src.web.api.v1.admin.documents import *

__all__ = ["create_admin_router"]
0  src/web/api/__init__.py  Normal file
0  src/web/api/v1/__init__.py  Normal file
19  src/web/api/v1/admin/__init__.py  Normal file
@@ -0,0 +1,19 @@
"""
Admin API v1

Document management, annotations, and training endpoints.
"""

from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.api.v1.admin.auth import create_auth_router
from src.web.api.v1.admin.documents import create_documents_router
from src.web.api.v1.admin.locks import create_locks_router
from src.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
    "create_training_router",
]
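A hypothetical sketch of wiring these router factories into an application; the FastAPI app object and mounting details are assumptions, not part of this diff:

    from fastapi import FastAPI
    from src.web.api.v1.admin import (
        create_annotation_router,
        create_auth_router,
        create_documents_router,
        create_locks_router,
        create_training_router,
    )

    app = FastAPI()
    for factory in (
        create_annotation_router,
        create_auth_router,
        create_documents_router,
        create_locks_router,
        create_training_router,
    ):
        app.include_router(factory())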
644  src/web/api/v1/admin/annotations.py  Normal file
@@ -0,0 +1,644 @@
"""
Admin Annotation API Routes

FastAPI endpoints for annotation management.
"""

import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse

from src.data.admin_db import AdminDB
from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.services.autolabel import get_auto_label_service
from src.web.schemas.admin import (
    AnnotationCreate,
    AnnotationItem,
    AnnotationListResponse,
    AnnotationOverrideRequest,
    AnnotationOverrideResponse,
    AnnotationResponse,
    AnnotationSource,
    AnnotationUpdate,
    AnnotationVerifyRequest,
    AnnotationVerifyResponse,
    AutoLabelRequest,
    AutoLabelResponse,
    BoundingBox,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

# Image storage directory
ADMIN_IMAGES_DIR = Path("data/admin_images")


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )


def create_annotation_router() -> APIRouter:
    """Create annotation API router."""
    router = APIRouter(prefix="/admin/documents", tags=["Admin Annotations"])

    # =========================================================================
    # Image Endpoints
    # =========================================================================

    @router.get(
        "/{document_id}/images/{page_number}",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Not found"},
        },
        summary="Get page image",
        description="Get the image for a specific page.",
    )
    async def get_page_image(
        document_id: str,
        page_number: int,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> FileResponse:
        """Get page image."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Validate page number
        if page_number < 1 or page_number > document.page_count:
            raise HTTPException(
                status_code=404,
                detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
            )

        # Find image file
        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
        if not image_path.exists():
            raise HTTPException(
                status_code=404,
                detail=f"Image for page {page_number} not found",
            )

        return FileResponse(
            path=str(image_path),
            media_type="image/png",
            filename=f"{document.filename}_page_{page_number}.png",
        )

    # =========================================================================
    # Annotation Endpoints
    # =========================================================================

    @router.get(
        "/{document_id}/annotations",
        response_model=AnnotationListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="List annotations",
        description="Get all annotations for a document.",
    )
    async def list_annotations(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        page_number: Annotated[
            int | None,
            Query(ge=1, description="Filter by page number"),
        ] = None,
    ) -> AnnotationListResponse:
        """List annotations for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Get annotations
        raw_annotations = db.get_annotations_for_document(document_id, page_number)
        annotations = [
            AnnotationItem(
                annotation_id=str(ann.annotation_id),
                page_number=ann.page_number,
                class_id=ann.class_id,
                class_name=ann.class_name,
                bbox=BoundingBox(
                    x=ann.bbox_x,
                    y=ann.bbox_y,
                    width=ann.bbox_width,
                    height=ann.bbox_height,
                ),
                normalized_bbox={
                    "x_center": ann.x_center,
                    "y_center": ann.y_center,
                    "width": ann.width,
                    "height": ann.height,
                },
                text_value=ann.text_value,
                confidence=ann.confidence,
                source=AnnotationSource(ann.source),
                created_at=ann.created_at,
            )
            for ann in raw_annotations
        ]

        return AnnotationListResponse(
            document_id=document_id,
            page_count=document.page_count,
            total_annotations=len(annotations),
            annotations=annotations,
        )

    @router.post(
        "/{document_id}/annotations",
        response_model=AnnotationResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Create annotation",
        description="Create a new annotation for a document.",
    )
    async def create_annotation(
        document_id: str,
        request: AnnotationCreate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AnnotationResponse:
        """Create a new annotation."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Validate page number
        if request.page_number > document.page_count:
            raise HTTPException(
                status_code=400,
                detail=f"Page {request.page_number} exceeds document page count ({document.page_count})",
            )

        # Get image dimensions for normalization
        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
        if not image_path.exists():
            raise HTTPException(
                status_code=400,
                detail=f"Image for page {request.page_number} not available",
            )

        from PIL import Image
        with Image.open(image_path) as img:
            image_width, image_height = img.size

        # Calculate normalized coordinates
        x_center = (request.bbox.x + request.bbox.width / 2) / image_width
        y_center = (request.bbox.y + request.bbox.height / 2) / image_height
        width = request.bbox.width / image_width
        height = request.bbox.height / image_height

        # Get class name
        class_name = FIELD_CLASSES.get(request.class_id, f"class_{request.class_id}")

        # Create annotation
        annotation_id = db.create_annotation(
            document_id=document_id,
            page_number=request.page_number,
            class_id=request.class_id,
            class_name=class_name,
            x_center=x_center,
            y_center=y_center,
            width=width,
            height=height,
            bbox_x=request.bbox.x,
            bbox_y=request.bbox.y,
            bbox_width=request.bbox.width,
            bbox_height=request.bbox.height,
            text_value=request.text_value,
            source="manual",
        )

        # Keep status as pending - user must click "Mark Complete" to finalize
        # This allows user to add multiple annotations before saving to PostgreSQL

        return AnnotationResponse(
            annotation_id=annotation_id,
            message="Annotation created successfully",
        )

    @router.patch(
        "/{document_id}/annotations/{annotation_id}",
        response_model=AnnotationResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Not found"},
        },
        summary="Update annotation",
        description="Update an existing annotation.",
    )
    async def update_annotation(
        document_id: str,
        annotation_id: str,
        request: AnnotationUpdate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AnnotationResponse:
        """Update an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Get existing annotation
        annotation = db.get_annotation(annotation_id)
        if annotation is None:
            raise HTTPException(
                status_code=404,
                detail="Annotation not found",
            )

        # Verify annotation belongs to document
        if str(annotation.document_id) != document_id:
            raise HTTPException(
                status_code=404,
                detail="Annotation does not belong to this document",
            )

        # Prepare update data
        update_kwargs = {}

        if request.class_id is not None:
            update_kwargs["class_id"] = request.class_id
            update_kwargs["class_name"] = FIELD_CLASSES.get(
                request.class_id, f"class_{request.class_id}"
            )

        if request.text_value is not None:
            update_kwargs["text_value"] = request.text_value

        if request.bbox is not None:
            # Get image dimensions
            image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
            from PIL import Image
            with Image.open(image_path) as img:
                image_width, image_height = img.size

            # Calculate normalized coordinates
            update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
            update_kwargs["y_center"] = (request.bbox.y + request.bbox.height / 2) / image_height
            update_kwargs["width"] = request.bbox.width / image_width
            update_kwargs["height"] = request.bbox.height / image_height
            update_kwargs["bbox_x"] = request.bbox.x
            update_kwargs["bbox_y"] = request.bbox.y
            update_kwargs["bbox_width"] = request.bbox.width
            update_kwargs["bbox_height"] = request.bbox.height

        # Update annotation
        if update_kwargs:
            success = db.update_annotation(annotation_id, **update_kwargs)
            if not success:
                raise HTTPException(
                    status_code=500,
                    detail="Failed to update annotation",
                )

        return AnnotationResponse(
            annotation_id=annotation_id,
            message="Annotation updated successfully",
        )

    @router.delete(
        "/{document_id}/annotations/{annotation_id}",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Not found"},
        },
        summary="Delete annotation",
        description="Delete an annotation.",
    )
    async def delete_annotation(
        document_id: str,
        annotation_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Delete an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Get existing annotation
        annotation = db.get_annotation(annotation_id)
        if annotation is None:
            raise HTTPException(
                status_code=404,
                detail="Annotation not found",
            )

        # Verify annotation belongs to document
        if str(annotation.document_id) != document_id:
            raise HTTPException(
                status_code=404,
                detail="Annotation does not belong to this document",
            )

        # Delete annotation
        db.delete_annotation(annotation_id)

        return {
            "status": "deleted",
            "annotation_id": annotation_id,
            "message": "Annotation deleted successfully",
        }

    # =========================================================================
    # Auto-Labeling Endpoints
    # =========================================================================

    @router.post(
        "/{document_id}/auto-label",
        response_model=AutoLabelResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Trigger auto-labeling",
        description="Trigger auto-labeling for a document using field values.",
    )
    async def trigger_auto_label(
        document_id: str,
        request: AutoLabelRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AutoLabelResponse:
        """Trigger auto-labeling for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Validate field values
        if not request.field_values:
            raise HTTPException(
                status_code=400,
                detail="At least one field value is required",
            )

        # Run auto-labeling
        service = get_auto_label_service()
        result = service.auto_label_document(
            document_id=document_id,
            file_path=document.file_path,
            field_values=request.field_values,
            db=db,
            replace_existing=request.replace_existing,
        )

        if result["status"] == "failed":
            raise HTTPException(
                status_code=500,
                detail=f"Auto-labeling failed: {result.get('error', 'Unknown error')}",
            )

        return AutoLabelResponse(
            document_id=document_id,
            status=result["status"],
            annotations_created=result["annotations_created"],
            message=f"Auto-labeling completed. Created {result['annotations_created']} annotations.",
        )

    @router.delete(
        "/{document_id}/annotations",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Delete all annotations",
        description="Delete all annotations for a document (optionally filter by source).",
    )
    async def delete_all_annotations(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        source: Annotated[
            str | None,
            Query(description="Filter by source (manual, auto, imported)"),
        ] = None,
    ) -> dict:
        """Delete all annotations for a document."""
        _validate_uuid(document_id, "document_id")

        # Validate source
        if source and source not in ("manual", "auto", "imported"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid source: {source}",
            )

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Delete annotations
        deleted_count = db.delete_annotations_for_document(document_id, source)

        # Update document status if all annotations deleted
        remaining = db.get_annotations_for_document(document_id)
        if not remaining:
            db.update_document_status(document_id, "pending")

        return {
            "status": "deleted",
            "document_id": document_id,
            "deleted_count": deleted_count,
            "message": f"Deleted {deleted_count} annotations",
        }

    # =========================================================================
    # Phase 5: Annotation Enhancement
    # =========================================================================

    @router.post(
        "/{document_id}/annotations/{annotation_id}/verify",
        response_model=AnnotationVerifyResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Annotation not found"},
        },
        summary="Verify annotation",
        description="Mark an annotation as verified by a human reviewer.",
    )
    async def verify_annotation(
        document_id: str,
        annotation_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        request: AnnotationVerifyRequest = AnnotationVerifyRequest(),
    ) -> AnnotationVerifyResponse:
        """Verify an annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership of document
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Verify the annotation
        annotation = db.verify_annotation(annotation_id, admin_token)
        if annotation is None:
            raise HTTPException(
                status_code=404,
                detail="Annotation not found",
            )

        return AnnotationVerifyResponse(
            annotation_id=annotation_id,
            is_verified=annotation.is_verified,
            verified_at=annotation.verified_at,
            verified_by=annotation.verified_by,
            message="Annotation verified successfully",
        )

    @router.patch(
        "/{document_id}/annotations/{annotation_id}/override",
        response_model=AnnotationOverrideResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Annotation not found"},
        },
        summary="Override annotation",
        description="Override an auto-generated annotation with manual corrections.",
    )
    async def override_annotation(
        document_id: str,
        annotation_id: str,
        request: AnnotationOverrideRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> AnnotationOverrideResponse:
        """Override an auto-generated annotation."""
        _validate_uuid(document_id, "document_id")
        _validate_uuid(annotation_id, "annotation_id")

        # Verify ownership of document
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Build updates dict from request
        updates = {}
        if request.text_value is not None:
            updates["text_value"] = request.text_value
        if request.class_id is not None:
            updates["class_id"] = request.class_id
            # Update class_name if class_id changed
|
||||||
|
if request.class_id in FIELD_CLASSES:
|
||||||
|
updates["class_name"] = FIELD_CLASSES[request.class_id]
|
||||||
|
if request.class_name is not None:
|
||||||
|
updates["class_name"] = request.class_name
|
||||||
|
if request.bbox:
|
||||||
|
# Update bbox fields
|
||||||
|
if "x" in request.bbox:
|
||||||
|
updates["bbox_x"] = request.bbox["x"]
|
||||||
|
if "y" in request.bbox:
|
||||||
|
updates["bbox_y"] = request.bbox["y"]
|
||||||
|
if "width" in request.bbox:
|
||||||
|
updates["bbox_width"] = request.bbox["width"]
|
||||||
|
if "height" in request.bbox:
|
||||||
|
updates["bbox_height"] = request.bbox["height"]
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No updates provided. Specify at least one field to update.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override the annotation
|
||||||
|
annotation = db.override_annotation(
|
||||||
|
annotation_id=annotation_id,
|
||||||
|
admin_token=admin_token,
|
||||||
|
change_reason=request.reason,
|
||||||
|
**updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
if annotation is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Annotation not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get history to return history_id
|
||||||
|
history_records = db.get_annotation_history(UUID(annotation_id))
|
||||||
|
latest_history = history_records[0] if history_records else None
|
||||||
|
|
||||||
|
return AnnotationOverrideResponse(
|
||||||
|
annotation_id=annotation_id,
|
||||||
|
source=annotation.source,
|
||||||
|
override_source=annotation.override_source,
|
||||||
|
original_annotation_id=str(annotation.original_annotation_id) if annotation.original_annotation_id else None,
|
||||||
|
message="Annotation overridden successfully",
|
||||||
|
history_id=str(latest_history.history_id) if latest_history else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
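A minimal client sketch for the verify/override endpoints above. It assumes the router is mounted under /api/v1, a bearer-token header (the concrete scheme consumed by AdminTokenDep may differ), a local server, and httpx as the client; DOC_ID and ANN_ID stand in for real UUIDs.

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"
HEADERS = {"Authorization": "Bearer ADMIN_TOKEN"}  # auth scheme assumed

# Mark an annotation as reviewed by a human.
r = httpx.post(f"{BASE}/DOC_ID/annotations/ANN_ID/verify", headers=HEADERS)
print(r.json()["message"])

# Correct an auto-generated annotation; only the supplied fields change.
r = httpx.patch(
    f"{BASE}/DOC_ID/annotations/ANN_ID/override",
    headers=HEADERS,
    json={"text_value": "12345", "reason": "OCR misread the invoice number"},
)
print(r.json()["override_source"])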
82   src/web/api/v1/admin/auth.py   Normal file
@@ -0,0 +1,82 @@
"""
Admin Auth Routes

FastAPI endpoints for admin token management.
"""

import logging
import secrets
from datetime import datetime, timedelta

from fastapi import APIRouter

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def create_auth_router() -> APIRouter:
    """Create admin auth router."""
    router = APIRouter(prefix="/admin/auth", tags=["Admin Auth"])

    @router.post(
        "/token",
        response_model=AdminTokenResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
        },
        summary="Create admin token",
        description="Create a new admin authentication token.",
    )
    async def create_token(
        request: AdminTokenCreate,
        db: AdminDBDep,
    ) -> AdminTokenResponse:
        """Create a new admin token."""
        # Generate secure token
        token = secrets.token_urlsafe(32)

        # Calculate expiration
        expires_at = None
        if request.expires_in_days:
            expires_at = datetime.utcnow() + timedelta(days=request.expires_in_days)

        # Create token in database
        db.create_admin_token(
            token=token,
            name=request.name,
            expires_at=expires_at,
        )

        return AdminTokenResponse(
            token=token,
            name=request.name,
            expires_at=expires_at,
            message="Admin token created successfully",
        )

    @router.delete(
        "/token",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Revoke admin token",
        description="Revoke the current admin token.",
    )
    async def revoke_token(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Revoke the current admin token."""
        db.deactivate_admin_token(admin_token)
        return {
            "status": "revoked",
            "message": "Admin token has been revoked",
        }

    return router
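A sketch of the token lifecycle defined above, under the same client assumptions (local server, /api/v1 mount, httpx). Note that creation takes no credentials in this router, while revocation authenticates with the token being revoked; the bearer header is an assumption about AdminTokenDep.

import httpx

BASE = "http://localhost:8000/api/v1/admin/auth"

# Create a token that expires in 30 days (expires_in_days is optional).
created = httpx.post(
    f"{BASE}/token",
    json={"name": "ci-bot", "expires_in_days": 30},
).json()
token = created["token"]

# Revoke it again.
httpx.delete(f"{BASE}/token", headers={"Authorization": f"Bearer {token}"})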
551  src/web/api/v1/admin/documents.py   Normal file
@@ -0,0 +1,551 @@
"""
Admin Document Routes

FastAPI endpoints for admin document management.
"""

import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from src.web.config import DEFAULT_DPI, StorageConfig
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
    AnnotationItem,
    AnnotationSource,
    AutoLabelStatus,
    BoundingBox,
    DocumentDetailResponse,
    DocumentItem,
    DocumentListResponse,
    DocumentStatus,
    DocumentStatsResponse,
    DocumentUploadResponse,
    ModelMetrics,
    TrainingHistoryItem,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )


def _convert_pdf_to_images(
    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
    """Convert PDF pages to images for annotation."""
    import fitz

    doc_images_dir = images_dir / document_id
    doc_images_dir.mkdir(parents=True, exist_ok=True)

    pdf_doc = fitz.open(stream=content, filetype="pdf")

    for page_num in range(page_count):
        page = pdf_doc[page_num]
        # Render at configured DPI for consistency with training
        mat = fitz.Matrix(dpi / 72, dpi / 72)
        pix = page.get_pixmap(matrix=mat)

        image_path = doc_images_dir / f"page_{page_num + 1}.png"
        pix.save(str(image_path))

    pdf_doc.close()

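# Worked example of the rendering math above (assuming the project-wide
# 150 DPI default): PyMuPDF's base resolution is 72 DPI, so the zoom factor
# is 150 / 72, roughly 2.083, and an A4 page (595 x 842 pt) renders to about
# 1240 x 1754 px, the same geometry the training pipeline sees.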
def create_documents_router(storage_config: StorageConfig) -> APIRouter:
    """Create admin documents router."""
    router = APIRouter(prefix="/admin/documents", tags=["Admin Documents"])

    # Directories are created by StorageConfig.__post_init__
    allowed_extensions = storage_config.allowed_extensions

    @router.post(
        "",
        response_model=DocumentUploadResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid file"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Upload document",
        description="Upload a PDF or image document for labeling.",
    )
    async def upload_document(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        file: UploadFile = File(..., description="PDF or image file"),
        auto_label: Annotated[
            bool,
            Query(description="Trigger auto-labeling after upload"),
        ] = True,
    ) -> DocumentUploadResponse:
        """Upload a document for labeling."""
        # Validate filename
        if not file.filename:
            raise HTTPException(status_code=400, detail="Filename is required")

        # Validate extension
        file_ext = Path(file.filename).suffix.lower()
        if file_ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {file_ext}. "
                f"Allowed: {', '.join(allowed_extensions)}",
            )

        # Read file content
        try:
            content = await file.read()
        except Exception as e:
            logger.error(f"Failed to read uploaded file: {e}")
            raise HTTPException(status_code=400, detail="Failed to read file")

        # Get page count (for PDF)
        page_count = 1
        if file_ext == ".pdf":
            try:
                import fitz
                pdf_doc = fitz.open(stream=content, filetype="pdf")
                page_count = len(pdf_doc)
                pdf_doc.close()
            except Exception as e:
                logger.warning(f"Failed to get PDF page count: {e}")

        # Create document record (token only used for auth, not stored)
        document_id = db.create_document(
            filename=file.filename,
            file_size=len(content),
            content_type=file.content_type or "application/octet-stream",
            file_path="",  # Will update after saving
            page_count=page_count,
        )

        # Save file to admin uploads
        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
        try:
            file_path.write_bytes(content)
        except Exception as e:
            logger.error(f"Failed to save file: {e}")
            raise HTTPException(status_code=500, detail="Failed to save file")

        # Update file path in database
        from src.data.database import get_session_context
        from src.data.admin_models import AdminDocument
        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if doc:
                doc.file_path = str(file_path)
                session.add(doc)

        # Convert PDF to images for annotation
        if file_ext == ".pdf":
            try:
                _convert_pdf_to_images(
                    document_id, content, page_count,
                    storage_config.admin_images_dir, storage_config.dpi
                )
            except Exception as e:
                logger.error(f"Failed to convert PDF to images: {e}")

        # Trigger auto-labeling if requested
        auto_label_started = False
        if auto_label:
            # Auto-labeling will be triggered by a background task
            db.update_document_status(
                document_id=document_id,
                status="auto_labeling",
                auto_label_status="running",
            )
            auto_label_started = True

        return DocumentUploadResponse(
            document_id=document_id,
            filename=file.filename,
            file_size=len(content),
            page_count=page_count,
            status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
            auto_label_started=auto_label_started,
            message="Document uploaded successfully",
        )

    @router.get(
        "",
        response_model=DocumentListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="List documents",
        description="List all documents for the current admin.",
    )
    async def list_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status"),
        ] = None,
        upload_source: Annotated[
            str | None,
            Query(description="Filter by upload source (ui or api)"),
        ] = None,
        has_annotations: Annotated[
            bool | None,
            Query(description="Filter by annotation presence"),
        ] = None,
        auto_label_status: Annotated[
            str | None,
            Query(description="Filter by auto-label status"),
        ] = None,
        batch_id: Annotated[
            str | None,
            Query(description="Filter by batch ID"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> DocumentListResponse:
        """List documents."""
        # Validate status
        if status and status not in ("pending", "auto_labeling", "labeled", "exported"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}",
            )

        # Validate upload_source
        if upload_source and upload_source not in ("ui", "api"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload_source: {upload_source}",
            )

        # Validate auto_label_status
        if auto_label_status and auto_label_status not in ("pending", "running", "completed", "failed"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid auto_label_status: {auto_label_status}",
            )

        documents, total = db.get_documents_by_token(
            admin_token=admin_token,
            status=status,
            upload_source=upload_source,
            has_annotations=has_annotations,
            auto_label_status=auto_label_status,
            batch_id=batch_id,
            limit=limit,
            offset=offset,
        )

        # Get annotation counts and build items
        items = []
        for doc in documents:
            annotations = db.get_annotations_for_document(str(doc.document_id))

            # Determine if document can be annotated (not locked)
            can_annotate = True
            if hasattr(doc, 'annotation_lock_until') and doc.annotation_lock_until:
                from datetime import datetime, timezone
                can_annotate = doc.annotation_lock_until < datetime.now(timezone.utc)

            items.append(
                DocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    file_size=doc.file_size,
                    page_count=doc.page_count,
                    status=DocumentStatus(doc.status),
                    auto_label_status=AutoLabelStatus(doc.auto_label_status) if doc.auto_label_status else None,
                    annotation_count=len(annotations),
                    upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
                    batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
                    can_annotate=can_annotate,
                    created_at=doc.created_at,
                    updated_at=doc.updated_at,
                )
            )

        return DocumentListResponse(
            total=total,
            limit=limit,
            offset=offset,
            documents=items,
        )

    @router.get(
        "/stats",
        response_model=DocumentStatsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get document statistics",
        description="Get document count by status.",
    )
    async def get_document_stats(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DocumentStatsResponse:
        """Get document statistics."""
        counts = db.count_documents_by_status(admin_token)

        return DocumentStatsResponse(
            total=sum(counts.values()),
            pending=counts.get("pending", 0),
            auto_labeling=counts.get("auto_labeling", 0),
            labeled=counts.get("labeled", 0),
            exported=counts.get("exported", 0),
        )

    @router.get(
        "/{document_id}",
        response_model=DocumentDetailResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Get document detail",
        description="Get document details with annotations.",
    )
    async def get_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DocumentDetailResponse:
        """Get document details."""
        _validate_uuid(document_id, "document_id")

        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Get annotations
        raw_annotations = db.get_annotations_for_document(document_id)
        annotations = [
            AnnotationItem(
                annotation_id=str(ann.annotation_id),
                page_number=ann.page_number,
                class_id=ann.class_id,
                class_name=ann.class_name,
                bbox=BoundingBox(
                    x=ann.bbox_x,
                    y=ann.bbox_y,
                    width=ann.bbox_width,
                    height=ann.bbox_height,
                ),
                normalized_bbox={
                    "x_center": ann.x_center,
                    "y_center": ann.y_center,
                    "width": ann.width,
                    "height": ann.height,
                },
                text_value=ann.text_value,
                confidence=ann.confidence,
                source=AnnotationSource(ann.source),
                created_at=ann.created_at,
            )
            for ann in raw_annotations
        ]

        # Generate image URLs
        image_urls = []
        for page in range(1, document.page_count + 1):
            image_urls.append(f"/api/v1/admin/documents/{document_id}/images/{page}")

        # Determine if document can be annotated (not locked)
        can_annotate = True
        annotation_lock_until = None
        if hasattr(document, 'annotation_lock_until') and document.annotation_lock_until:
            from datetime import datetime, timezone
            annotation_lock_until = document.annotation_lock_until
            can_annotate = document.annotation_lock_until < datetime.now(timezone.utc)

        # Get CSV field values if available
        csv_field_values = None
        if hasattr(document, 'csv_field_values') and document.csv_field_values:
            csv_field_values = document.csv_field_values

        # Get training history (Phase 5)
        training_history = []
        training_links = db.get_document_training_tasks(document.document_id)
        for link in training_links:
            # Get task details
            task = db.get_training_task(str(link.task_id))
            if task:
                # Build metrics
                metrics = None
                if task.metrics_mAP or task.metrics_precision or task.metrics_recall:
                    metrics = ModelMetrics(
                        mAP=task.metrics_mAP,
                        precision=task.metrics_precision,
                        recall=task.metrics_recall,
                    )

                training_history.append(
                    TrainingHistoryItem(
                        task_id=str(link.task_id),
                        name=task.name,
                        trained_at=link.created_at,
                        model_metrics=metrics,
                    )
                )

        return DocumentDetailResponse(
            document_id=str(document.document_id),
            filename=document.filename,
            file_size=document.file_size,
            content_type=document.content_type,
            page_count=document.page_count,
            status=DocumentStatus(document.status),
            auto_label_status=AutoLabelStatus(document.auto_label_status) if document.auto_label_status else None,
            auto_label_error=document.auto_label_error,
            upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
            batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
            csv_field_values=csv_field_values,
            can_annotate=can_annotate,
            annotation_lock_until=annotation_lock_until,
            annotations=annotations,
            image_urls=image_urls,
            training_history=training_history,
            created_at=document.created_at,
            updated_at=document.updated_at,
        )

    @router.delete(
        "/{document_id}",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Delete document",
        description="Delete a document and its annotations.",
    )
    async def delete_document(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Delete a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Delete file
        file_path = Path(document.file_path)
        if file_path.exists():
            file_path.unlink()
        # Delete images
        images_dir = storage_config.admin_images_dir / document_id
        if images_dir.exists():
            import shutil
            shutil.rmtree(images_dir)
        # Delete from database
        db.delete_document(document_id)

        return {
            "status": "deleted",
            "document_id": document_id,
            "message": "Document deleted successfully",
        }

    @router.patch(
        "/{document_id}/status",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Update document status",
        description="Update document status (e.g., mark as labeled). When marking as 'labeled', annotations are saved to PostgreSQL.",
    )
    async def update_document_status(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str,
            Query(description="New status"),
        ],
    ) -> dict:
        """Update document status.

        When status is set to 'labeled', the annotations are automatically
        saved to the PostgreSQL documents/field_results tables for consistency
        with the CLI auto-label workflow.
        """
        _validate_uuid(document_id, "document_id")

        # Validate status
        if status not in ("pending", "labeled", "exported"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}",
            )

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # If marking as labeled, save annotations to PostgreSQL DocumentDB
        db_save_result = None
        if status == "labeled":
            from src.web.services.db_autolabel import save_manual_annotations_to_document_db

            # Get all annotations for this document
            annotations = db.get_annotations_for_document(document_id)

            if annotations:
                db_save_result = save_manual_annotations_to_document_db(
                    document=document,
                    annotations=annotations,
                    db=db,
                )

        db.update_document_status(document_id, status)

        response = {
            "status": "updated",
            "document_id": document_id,
            "new_status": status,
            "message": "Document status updated",
        }

        # Include PostgreSQL save result if applicable
        if db_save_result:
            response["document_db_saved"] = db_save_result.get("success", False)
            response["fields_saved"] = db_save_result.get("fields_saved", 0)

        return response

    return router
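An end-to-end usage sketch for the document endpoints above: upload a PDF, then poll the detail view until auto-labeling settles. The local base URL, bearer header, and httpx client are assumptions; the response field names come from the response models used in this file.

import time

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"
HEADERS = {"Authorization": "Bearer ADMIN_TOKEN"}  # auth scheme assumed

with open("invoice.pdf", "rb") as fh:
    up = httpx.post(
        BASE,
        headers=HEADERS,
        files={"file": ("invoice.pdf", fh, "application/pdf")},
    ).json()

doc_id = up["document_id"]
while True:
    detail = httpx.get(f"{BASE}/{doc_id}", headers=HEADERS).json()
    if detail["auto_label_status"] in (None, "completed", "failed"):
        break
    time.sleep(2)

print(detail["status"], len(detail["annotations"]))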
184  src/web/api/v1/admin/locks.py   Normal file
@@ -0,0 +1,184 @@
"""
Admin Document Lock Routes

FastAPI endpoints for annotation lock management.
"""

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
    AnnotationLockRequest,
    AnnotationLockResponse,
)
from src.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )


def create_locks_router() -> APIRouter:
    """Create annotation locks router."""
    router = APIRouter(prefix="/admin/documents", tags=["Admin Locks"])

    @router.post(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
            409: {"model": ErrorResponse, "description": "Document already locked"},
        },
        summary="Acquire annotation lock",
        description="Acquire a lock on a document to prevent concurrent annotation edits.",
    )
    async def acquire_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Acquire annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Attempt to acquire lock
        updated_doc = db.acquire_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            duration_seconds=request.duration_seconds,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=409,
                detail="Document is already locked. Please try again later.",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=True,
            lock_expires_at=updated_doc.annotation_lock_until,
            message=f"Lock acquired for {request.duration_seconds} seconds",
        )

    @router.delete(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
        },
        summary="Release annotation lock",
        description="Release the annotation lock on a document.",
    )
    async def release_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        force: Annotated[
            bool,
            Query(description="Force release (admin override)"),
        ] = False,
    ) -> AnnotationLockResponse:
        """Release annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Release lock
        updated_doc = db.release_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            force=force,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=404,
                detail="Failed to release lock",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=False,
            lock_expires_at=None,
            message="Lock released successfully",
        )

    @router.patch(
        "/{document_id}/lock",
        response_model=AnnotationLockResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Document not found"},
            409: {"model": ErrorResponse, "description": "Lock expired or doesn't exist"},
        },
        summary="Extend annotation lock",
        description="Extend an existing annotation lock.",
    )
    async def extend_lock(
        document_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        request: AnnotationLockRequest = AnnotationLockRequest(),
    ) -> AnnotationLockResponse:
        """Extend annotation lock for a document."""
        _validate_uuid(document_id, "document_id")

        # Verify ownership
        document = db.get_document_by_token(document_id, admin_token)
        if document is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or does not belong to this token",
            )

        # Attempt to extend lock
        updated_doc = db.extend_annotation_lock(
            document_id=document_id,
            admin_token=admin_token,
            additional_seconds=request.duration_seconds,
        )

        if updated_doc is None:
            raise HTTPException(
                status_code=409,
                detail="Lock doesn't exist or has expired. Please acquire a new lock.",
            )

        return AnnotationLockResponse(
            document_id=document_id,
            locked=True,
            lock_expires_at=updated_doc.annotation_lock_until,
            message=f"Lock extended by {request.duration_seconds} seconds",
        )

    return router
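A sketch of the intended acquire/extend/release flow, under the same client assumptions as the earlier sketches; duration_seconds is the field the router reads from AnnotationLockRequest, while the 300-second value and the JSON body shape are illustrative.

import httpx

BASE = "http://localhost:8000/api/v1/admin/documents"
HEADERS = {"Authorization": "Bearer ADMIN_TOKEN"}  # auth scheme assumed

# Take the lock before opening the annotation editor.
lock = httpx.post(
    f"{BASE}/DOC_ID/lock", headers=HEADERS, json={"duration_seconds": 300}
).json()
print(lock["lock_expires_at"])

# Heartbeat while the editor stays open, then release on close.
httpx.patch(f"{BASE}/DOC_ID/lock", headers=HEADERS, json={"duration_seconds": 300})
httpx.delete(f"{BASE}/DOC_ID/lock", headers=HEADERS)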
622  src/web/api/v1/admin/training.py   Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
"""
|
||||||
|
Admin Training API Routes
|
||||||
|
|
||||||
|
FastAPI endpoints for training task management and scheduling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated, Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
from src.data.admin_db import AdminDB
|
||||||
|
from src.web.core.auth import AdminTokenDep, AdminDBDep
|
||||||
|
from src.web.schemas.admin import (
|
||||||
|
ExportRequest,
|
||||||
|
ExportResponse,
|
||||||
|
ModelMetrics,
|
||||||
|
TrainingConfig,
|
||||||
|
TrainingDocumentItem,
|
||||||
|
TrainingDocumentsResponse,
|
||||||
|
TrainingHistoryItem,
|
||||||
|
TrainingLogItem,
|
||||||
|
TrainingLogsResponse,
|
||||||
|
TrainingModelItem,
|
||||||
|
TrainingModelsResponse,
|
||||||
|
TrainingStatus,
|
||||||
|
TrainingTaskCreate,
|
||||||
|
TrainingTaskDetailResponse,
|
||||||
|
TrainingTaskItem,
|
||||||
|
TrainingTaskListResponse,
|
||||||
|
TrainingTaskResponse,
|
||||||
|
TrainingType,
|
||||||
|
)
|
||||||
|
from src.web.schemas.common import ErrorResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||||
|
"""Validate UUID format."""
|
||||||
|
try:
|
||||||
|
UUID(value)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid {name} format. Must be a valid UUID.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_training_router() -> APIRouter:
|
||||||
|
"""Create training API router."""
|
||||||
|
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Training Task Endpoints
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/tasks",
|
||||||
|
response_model=TrainingTaskResponse,
|
||||||
|
responses={
|
||||||
|
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
},
|
||||||
|
summary="Create training task",
|
||||||
|
description="Create a new training task.",
|
||||||
|
)
|
||||||
|
async def create_training_task(
|
||||||
|
request: TrainingTaskCreate,
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
) -> TrainingTaskResponse:
|
||||||
|
"""Create a new training task."""
|
||||||
|
# Convert config to dict
|
||||||
|
config_dict = request.config.model_dump() if request.config else {}
|
||||||
|
|
||||||
|
# Create task
|
||||||
|
task_id = db.create_training_task(
|
||||||
|
admin_token=admin_token,
|
||||||
|
name=request.name,
|
||||||
|
task_type=request.task_type.value,
|
||||||
|
description=request.description,
|
||||||
|
config=config_dict,
|
||||||
|
scheduled_at=request.scheduled_at,
|
||||||
|
cron_expression=request.cron_expression,
|
||||||
|
is_recurring=bool(request.cron_expression),
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainingTaskResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
|
||||||
|
message="Training task created successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tasks",
|
||||||
|
response_model=TrainingTaskListResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
},
|
||||||
|
summary="List training tasks",
|
||||||
|
description="List all training tasks.",
|
||||||
|
)
|
||||||
|
async def list_training_tasks(
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
status: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(description="Filter by status"),
|
||||||
|
] = None,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=1, le=100, description="Page size"),
|
||||||
|
] = 20,
|
||||||
|
offset: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=0, description="Offset"),
|
||||||
|
] = 0,
|
||||||
|
) -> TrainingTaskListResponse:
|
||||||
|
"""List training tasks."""
|
||||||
|
# Validate status
|
||||||
|
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
|
||||||
|
if status and status not in valid_statuses:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks, total = db.get_training_tasks_by_token(
|
||||||
|
admin_token=admin_token,
|
||||||
|
status=status,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
TrainingTaskItem(
|
||||||
|
task_id=str(task.task_id),
|
||||||
|
name=task.name,
|
||||||
|
task_type=TrainingType(task.task_type),
|
||||||
|
status=TrainingStatus(task.status),
|
||||||
|
scheduled_at=task.scheduled_at,
|
||||||
|
is_recurring=task.is_recurring,
|
||||||
|
started_at=task.started_at,
|
||||||
|
completed_at=task.completed_at,
|
||||||
|
created_at=task.created_at,
|
||||||
|
)
|
||||||
|
for task in tasks
|
||||||
|
]
|
||||||
|
|
||||||
|
return TrainingTaskListResponse(
|
||||||
|
total=total,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
tasks=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tasks/{task_id}",
|
||||||
|
response_model=TrainingTaskDetailResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||||
|
},
|
||||||
|
summary="Get training task detail",
|
||||||
|
description="Get training task details.",
|
||||||
|
)
|
||||||
|
async def get_training_task(
|
||||||
|
task_id: str,
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
) -> TrainingTaskDetailResponse:
|
||||||
|
"""Get training task details."""
|
||||||
|
_validate_uuid(task_id, "task_id")
|
||||||
|
|
||||||
|
task = db.get_training_task_by_token(task_id, admin_token)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Training task not found or does not belong to this token",
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainingTaskDetailResponse(
|
||||||
|
task_id=str(task.task_id),
|
||||||
|
name=task.name,
|
||||||
|
description=task.description,
|
||||||
|
task_type=TrainingType(task.task_type),
|
||||||
|
status=TrainingStatus(task.status),
|
||||||
|
config=task.config,
|
||||||
|
scheduled_at=task.scheduled_at,
|
||||||
|
cron_expression=task.cron_expression,
|
||||||
|
is_recurring=task.is_recurring,
|
||||||
|
started_at=task.started_at,
|
||||||
|
completed_at=task.completed_at,
|
||||||
|
error_message=task.error_message,
|
||||||
|
result_metrics=task.result_metrics,
|
||||||
|
model_path=task.model_path,
|
||||||
|
created_at=task.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/tasks/{task_id}/cancel",
|
||||||
|
response_model=TrainingTaskResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||||
|
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
|
||||||
|
},
|
||||||
|
summary="Cancel training task",
|
||||||
|
description="Cancel a pending or scheduled training task.",
|
||||||
|
)
|
||||||
|
async def cancel_training_task(
|
||||||
|
task_id: str,
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
) -> TrainingTaskResponse:
|
||||||
|
"""Cancel a training task."""
|
||||||
|
_validate_uuid(task_id, "task_id")
|
||||||
|
|
||||||
|
# Verify ownership
|
||||||
|
task = db.get_training_task_by_token(task_id, admin_token)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Training task not found or does not belong to this token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if can be cancelled
|
||||||
|
if task.status not in ("pending", "scheduled"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=f"Cannot cancel task with status: {task.status}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cancel task
|
||||||
|
success = db.cancel_training_task(task_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Failed to cancel training task",
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainingTaskResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
status=TrainingStatus.CANCELLED,
|
||||||
|
message="Training task cancelled successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tasks/{task_id}/logs",
|
||||||
|
response_model=TrainingLogsResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
404: {"model": ErrorResponse, "description": "Task not found"},
|
||||||
|
},
|
||||||
|
summary="Get training logs",
|
||||||
|
description="Get training task logs.",
|
||||||
|
)
|
||||||
|
async def get_training_logs(
|
||||||
|
task_id: str,
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=1, le=500, description="Maximum logs to return"),
|
||||||
|
] = 100,
|
||||||
|
offset: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=0, description="Offset"),
|
||||||
|
] = 0,
|
||||||
|
) -> TrainingLogsResponse:
|
||||||
|
"""Get training logs."""
|
||||||
|
_validate_uuid(task_id, "task_id")
|
||||||
|
|
||||||
|
# Verify ownership
|
||||||
|
task = db.get_training_task_by_token(task_id, admin_token)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Training task not found or does not belong to this token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get logs
|
||||||
|
logs = db.get_training_logs(task_id, limit, offset)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
TrainingLogItem(
|
||||||
|
level=log.level,
|
||||||
|
message=log.message,
|
||||||
|
details=log.details,
|
||||||
|
created_at=log.created_at,
|
||||||
|
)
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
|
|
||||||
|
return TrainingLogsResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
logs=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Phase 4: Training Data Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/documents",
|
||||||
|
response_model=TrainingDocumentsResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
},
|
||||||
|
summary="Get documents for training",
|
||||||
|
description="Get labeled documents available for training with filtering options.",
|
||||||
|
)
|
||||||
|
async def get_training_documents(
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
has_annotations: Annotated[
|
||||||
|
bool,
|
||||||
|
Query(description="Only include documents with annotations"),
|
||||||
|
] = True,
|
||||||
|
min_annotation_count: Annotated[
|
||||||
|
int | None,
|
||||||
|
Query(ge=1, description="Minimum annotation count"),
|
||||||
|
] = None,
|
||||||
|
exclude_used_in_training: Annotated[
|
||||||
|
bool,
|
||||||
|
Query(description="Exclude documents already used in training"),
|
||||||
|
] = False,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=1, le=100, description="Page size"),
|
||||||
|
] = 100,
|
||||||
|
offset: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=0, description="Offset"),
|
||||||
|
] = 0,
|
||||||
|
) -> TrainingDocumentsResponse:
|
||||||
|
"""Get documents available for training."""
|
||||||
|
# Get documents
|
||||||
|
documents, total = db.get_documents_for_training(
|
||||||
|
admin_token=admin_token,
|
||||||
|
status="labeled",
|
||||||
|
has_annotations=has_annotations,
|
||||||
|
min_annotation_count=min_annotation_count,
|
||||||
|
exclude_used_in_training=exclude_used_in_training,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response items with annotation details and training history
|
||||||
|
items = []
|
||||||
|
for doc in documents:
|
||||||
|
# Get annotations for this document
|
||||||
|
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||||
|
|
||||||
|
# Count annotations by source
|
||||||
|
sources = {"manual": 0, "auto": 0}
|
||||||
|
for ann in annotations:
|
||||||
|
if ann.source in sources:
|
||||||
|
sources[ann.source] += 1
|
||||||
|
|
||||||
|
# Get training history
|
||||||
|
training_links = db.get_document_training_tasks(doc.document_id)
|
||||||
|
used_in_training = [str(link.task_id) for link in training_links]
|
||||||
|
|
||||||
|
items.append(
|
||||||
|
TrainingDocumentItem(
|
||||||
|
document_id=str(doc.document_id),
|
||||||
|
filename=doc.filename,
|
||||||
|
annotation_count=len(annotations),
|
||||||
|
annotation_sources=sources,
|
||||||
|
used_in_training=used_in_training,
|
||||||
|
last_modified=doc.updated_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainingDocumentsResponse(
|
||||||
|
total=total,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
documents=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/models/{task_id}/download",
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
404: {"model": ErrorResponse, "description": "Model not found"},
|
||||||
|
},
|
||||||
|
summary="Download trained model",
|
||||||
|
description="Download trained model weights file.",
|
||||||
|
)
|
||||||
|
async def download_model(
|
||||||
|
task_id: str,
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
):
|
||||||
|
"""Download trained model."""
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_validate_uuid(task_id, "task_id")
|
||||||
|
|
||||||
|
# Verify ownership
|
||||||
|
task = db.get_training_task_by_token(task_id, admin_token)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Training task not found or does not belong to this token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if model exists
|
||||||
|
if not task.model_path:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Model file not available for this task",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_path = Path(task.model_path)
|
||||||
|
if not model_path.exists():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Model file not found on disk",
|
||||||
|
)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=str(model_path),
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
filename=f"{task.name}_model.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/models",
|
||||||
|
response_model=TrainingModelsResponse,
|
||||||
|
responses={
|
||||||
|
401: {"model": ErrorResponse, "description": "Invalid token"},
|
||||||
|
},
|
||||||
|
summary="Get trained models",
|
||||||
|
description="Get list of trained models with metrics and download links.",
|
||||||
|
)
|
||||||
|
async def get_training_models(
|
||||||
|
admin_token: AdminTokenDep,
|
||||||
|
db: AdminDBDep,
|
||||||
|
status: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(description="Filter by status (completed, failed, etc.)"),
|
||||||
|
] = None,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=1, le=100, description="Page size"),
|
||||||
|
] = 20,
|
||||||
|
offset: Annotated[
|
||||||
|
int,
|
||||||
|
Query(ge=0, description="Offset"),
|
||||||
|
] = 0,
|
||||||
|
) -> TrainingModelsResponse:
|
||||||
|
"""Get list of trained models."""
|
||||||
|
# Get training tasks
|
||||||
|
tasks, total = db.get_training_tasks_by_token(
|
||||||
|
admin_token=admin_token,
|
||||||
|
status=status if status else "completed",
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response items
|
||||||
|
items = []
|
||||||
|
for task in tasks:
|
||||||
|
# Build metrics
|
||||||
|
metrics = ModelMetrics(
|
||||||
|
mAP=task.metrics_mAP,
|
||||||
|
precision=task.metrics_precision,
|
||||||
|
recall=task.metrics_recall,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build download URL if model exists
|
||||||
|
download_url = None
|
||||||
|
if task.model_path and task.status == "completed":
|
||||||
|
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
|
||||||
|
|
||||||
|
items.append(
|
||||||
|
TrainingModelItem(
|
||||||
|
task_id=str(task.task_id),
|
||||||
|
name=task.name,
|
||||||
|
status=TrainingStatus(task.status),
|
||||||
|
document_count=task.document_count,
|
||||||
|
created_at=task.created_at,
|
||||||
|
completed_at=task.completed_at,
|
||||||
|
metrics=metrics,
|
||||||
|
model_path=task.model_path,
|
||||||
|
download_url=download_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainingModelsResponse(
|
            total=total,
            limit=limit,
            offset=offset,
            models=items,
        )

    # =========================================================================
    # Export Endpoints
    # =========================================================================

    @router.post(
        "/export",
        response_model=ExportResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Export annotations",
        description="Export annotations in YOLO format for training.",
    )
    async def export_annotations(
        request: ExportRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ExportResponse:
        """Export annotations for training."""
        from pathlib import Path
        import shutil

        # Validate format
        if request.format not in ("yolo", "coco", "voc"):
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported export format: {request.format}",
            )

        # Get labeled documents
        documents = db.get_labeled_documents_for_export(admin_token)

        if not documents:
            raise HTTPException(
                status_code=400,
                detail="No labeled documents available for export",
            )

        # Create export directory
        export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        export_dir.mkdir(parents=True, exist_ok=True)

        # YOLO format directories
        (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)

        # Calculate train/val split
        total_docs = len(documents)
        train_count = int(total_docs * request.split_ratio)
        train_docs = documents[:train_count]
        val_docs = documents[train_count:]

        total_images = 0
        total_annotations = 0

        # Export documents
        for split, docs in [("train", train_docs), ("val", val_docs)]:
            for doc in docs:
                # Get annotations
                annotations = db.get_annotations_for_document(str(doc.document_id))

                if not annotations:
                    continue

                # Export each page
                for page_num in range(1, doc.page_count + 1):
                    page_annotations = [a for a in annotations if a.page_number == page_num]

                    if not page_annotations and not request.include_images:
                        continue

                    # Copy image
                    src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
                    if not src_image.exists():
                        continue

                    image_name = f"{doc.document_id}_page{page_num}.png"
                    dst_image = export_dir / "images" / split / image_name
                    shutil.copy(src_image, dst_image)
                    total_images += 1

                    # Write YOLO label file
                    label_name = f"{doc.document_id}_page{page_num}.txt"
                    label_path = export_dir / "labels" / split / label_name

                    with open(label_path, "w") as f:
                        for ann in page_annotations:
                            # YOLO format: class_id x_center y_center width height
                            line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
                            f.write(line)
                            total_annotations += 1

        # Create data.yaml
        from src.data.admin_models import FIELD_CLASSES

        yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val

nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
        (export_dir / "data.yaml").write_text(yaml_content)

        return ExportResponse(
            status="completed",
            export_path=str(export_dir),
            total_images=total_images,
            total_annotations=total_annotations,
            train_count=len(train_docs),
            val_count=len(val_docs),
            message=f"Exported {total_images} images with {total_annotations} annotations",
        )

    return router
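For reference, a hypothetical client sketch for the export endpoint above. The mount path assumes the admin router ends up under `/api/v1/admin` as in app.py, and the `X-Admin-Token` header comes from src/web/core/auth.py; token and split ratio are placeholders.

```python
# Hypothetical usage sketch; endpoint path and token are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/admin/export",
    headers={"X-Admin-Token": "<admin-token>"},
    json={"format": "yolo", "split_ratio": 0.8, "include_images": True},
)
resp.raise_for_status()
print(resp.json()["export_path"])  # e.g. data/exports/export_20250101_120000
```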
src/web/api/v1/batch/__init__.py (new file, 0 lines)

src/web/api/v1/batch/routes.py (new file, 236 lines)
@@ -0,0 +1,236 @@
"""
Batch Upload API Routes

Endpoints for batch uploading documents via ZIP files with CSV metadata.
"""

import io
import logging
import zipfile
from datetime import datetime
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse

from src.data.admin_db import AdminDB
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from src.web.workers.batch_queue import BatchTask, get_batch_queue

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/admin/batch", tags=["batch-upload"])


@router.post("/upload")
async def upload_batch(
    file: UploadFile = File(...),
    upload_source: str = Form(default="ui"),
    async_mode: bool = Form(default=True),
    auto_label: bool = Form(default=True),
    admin_token: Annotated[str, Depends(validate_admin_token)] = None,
    admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
    """Upload a batch of documents via ZIP file.

    The ZIP file can contain:
    - Multiple PDF files
    - Optional CSV file with field values for auto-labeling

    CSV format:
    - Required column: DocumentId (matches PDF filename without extension)
    - Optional columns: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount,
      OCR, Bankgiro, Plusgiro, customer_number, supplier_organisation_number

    Args:
        file: ZIP file upload
        upload_source: Upload source (ui or api)
        async_mode: If True, queue the upload and return 202 immediately
        auto_label: If True, trigger auto-labeling after upload
        admin_token: Admin authentication token
        admin_db: Admin database interface

    Returns:
        Batch upload result with batch_id and status
    """
    if not file.filename or not file.filename.lower().endswith('.zip'):
        raise HTTPException(status_code=400, detail="Only ZIP files are supported")

    # Check compressed size
    if file.size and file.size > MAX_COMPRESSED_SIZE:
        max_mb = MAX_COMPRESSED_SIZE / (1024 * 1024)
        raise HTTPException(
            status_code=400,
            detail=f"File size exceeds {max_mb:.0f}MB limit"
        )

    try:
        # Read file content
        zip_content = await file.read()

        # Additional security validation before processing
        try:
            with zipfile.ZipFile(io.BytesIO(zip_content)) as test_zip:
                # Quick validation of ZIP structure
                test_zip.testzip()
        except zipfile.BadZipFile:
            raise HTTPException(status_code=400, detail="Invalid ZIP file format")

        if async_mode:
            # Async mode: Queue task and return immediately
            from uuid import uuid4

            batch_id = uuid4()

            # Create batch task for background processing
            task = BatchTask(
                batch_id=batch_id,
                admin_token=admin_token,
                zip_content=zip_content,
                zip_filename=file.filename,
                upload_source=upload_source,
                auto_label=auto_label,
                created_at=datetime.utcnow(),
            )

            # Submit to queue
            queue = get_batch_queue()
            if not queue.submit(task):
                raise HTTPException(
                    status_code=503,
                    detail="Processing queue is full. Please try again later."
                )

            logger.info(
                f"Batch upload queued: batch_id={batch_id}, "
                f"filename={file.filename}, async_mode=True"
            )

            # Return 202 Accepted with batch_id and status URL
            return JSONResponse(
                status_code=202,
                content={
                    "status": "accepted",
                    "batch_id": str(batch_id),
                    "message": "Batch upload queued for processing",
                    "status_url": f"/api/v1/admin/batch/status/{batch_id}",
                    "queue_depth": queue.get_queue_depth(),
                }
            )
        else:
            # Sync mode: Process immediately and return results
            service = BatchUploadService(admin_db)
            result = service.process_zip_upload(
                admin_token=admin_token,
                zip_filename=file.filename,
                zip_content=zip_content,
                upload_source=upload_source,
            )

            logger.info(
                f"Batch upload completed: batch_id={result.get('batch_id')}, "
                f"status={result.get('status')}, files={result.get('successful_files')}"
            )

            return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing batch upload: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Failed to process batch upload. Please contact support."
        )


@router.get("/status/{batch_id}")
async def get_batch_status(
    batch_id: str,
    admin_token: Annotated[str, Depends(validate_admin_token)] = None,
    admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
) -> dict:
    """Get batch upload status and file processing details.

    Args:
        batch_id: Batch upload ID
        admin_token: Admin authentication token
        admin_db: Admin database interface

    Returns:
        Batch status with file processing details
    """
    # Validate UUID format
    try:
        batch_uuid = UUID(batch_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid batch ID format")

    # Check batch exists and verify ownership
    batch = admin_db.get_batch_upload(batch_uuid)
    if not batch:
        raise HTTPException(status_code=404, detail="Batch not found")

    # CRITICAL: Verify ownership
    if batch.admin_token != admin_token:
        raise HTTPException(
            status_code=403,
            detail="You do not have access to this batch"
        )

    # Now safe to return details
    service = BatchUploadService(admin_db)
    result = service.get_batch_status(batch_id)

    return result


@router.get("/list")
async def list_batch_uploads(
    admin_token: Annotated[str, Depends(validate_admin_token)] = None,
    admin_db: Annotated[AdminDB, Depends(get_admin_db)] = None,
    limit: int = 50,
    offset: int = 0,
) -> dict:
    """List batch uploads for the current admin token.

    Args:
        admin_token: Admin authentication token
        admin_db: Admin database interface
        limit: Maximum number of results
        offset: Offset for pagination

    Returns:
        List of batch uploads
    """
    # Validate pagination parameters
    if limit < 1 or limit > 100:
        raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
    if offset < 0:
        raise HTTPException(status_code=400, detail="Offset must be non-negative")

    # Get batch uploads filtered by admin token
    batches, total = admin_db.get_batch_uploads_by_token(
        admin_token=admin_token,
        limit=limit,
        offset=offset,
    )

    return {
        "batches": [
            {
                "batch_id": str(b.batch_id),
                "filename": b.filename,
                "status": b.status,
                "total_files": b.total_files,
                "successful_files": b.successful_files,
                "failed_files": b.failed_files,
                "created_at": b.created_at.isoformat() if b.created_at else None,
                "completed_at": b.completed_at.isoformat() if b.completed_at else None,
            }
            for b in batches
        ],
        "total": total,
        "limit": limit,
        "offset": offset,
    }
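A hypothetical client sketch for the batch endpoints above: upload a ZIP in async mode, then poll the `status_url` from the 202 response. Host and token are placeholders; the paths follow the `/api/v1/admin/batch` prefix declared on the router.

```python
# Hypothetical usage sketch; server address and token are assumptions.
import time
import requests

headers = {"X-Admin-Token": "<admin-token>"}

with open("invoices.zip", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/admin/batch/upload",
        headers=headers,
        files={"file": ("invoices.zip", f, "application/zip")},
        data={"upload_source": "api", "async_mode": "true", "auto_label": "true"},
    )
resp.raise_for_status()
status_url = resp.json()["status_url"]

# Poll until the background worker finishes the batch.
while True:
    status = requests.get(f"http://localhost:8000{status_url}", headers=headers).json()
    if status.get("status") in ("completed", "failed"):
        break
    time.sleep(2)
```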
src/web/api/v1/public/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""
Public API v1

Customer-facing endpoints for inference, async processing, and labeling.
"""

from src.web.api.v1.public.inference import create_inference_router
from src.web.api.v1.public.async_api import create_async_router, set_async_service
from src.web.api.v1.public.labeling import create_labeling_router

__all__ = [
    "create_inference_router",
    "create_async_router",
    "set_async_service",
    "create_labeling_router",
]
src/web/api/v1/public/async_api.py (new file, 372 lines)
@@ -0,0 +1,372 @@
"""
Async API Routes

FastAPI endpoints for async invoice processing.
"""

import logging
from pathlib import Path
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from src.web.dependencies import (
    ApiKeyDep,
    AsyncDBDep,
    PollRateLimitDep,
    SubmitRateLimitDep,
)
from src.web.schemas.inference import (
    AsyncRequestItem,
    AsyncRequestsListResponse,
    AsyncResultResponse,
    AsyncStatus,
    AsyncStatusResponse,
    AsyncSubmitResponse,
    DetectionResult,
    InferenceResult,
)
from src.web.schemas.common import ErrorResponse


def _validate_request_id(request_id: str) -> None:
    """Validate that request_id is a valid UUID format."""
    try:
        UUID(request_id)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="Invalid request ID format. Must be a valid UUID.",
        )


logger = logging.getLogger(__name__)

# Global reference to async processing service (set during app startup)
_async_service = None


def set_async_service(service) -> None:
    """Set the async processing service instance."""
    global _async_service
    _async_service = service


def get_async_service():
    """Get the async processing service instance."""
    if _async_service is None:
        raise RuntimeError("AsyncProcessingService not initialized")
    return _async_service


def create_async_router(allowed_extensions: tuple[str, ...]) -> APIRouter:
    """Create async API router."""
    router = APIRouter(prefix="/async", tags=["Async Processing"])

    @router.post(
        "/submit",
        response_model=AsyncSubmitResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid file"},
            401: {"model": ErrorResponse, "description": "Invalid API key"},
            429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
            503: {"model": ErrorResponse, "description": "Queue full"},
        },
        summary="Submit PDF for async processing",
        description="Submit a PDF or image file for asynchronous processing. "
        "Returns a request_id that can be used to poll for results.",
    )
    async def submit_document(
        api_key: SubmitRateLimitDep,
        file: UploadFile = File(..., description="PDF or image file to process"),
    ) -> AsyncSubmitResponse:
        """Submit a document for async processing."""
        # Validate filename
        if not file.filename:
            raise HTTPException(status_code=400, detail="Filename is required")

        # Validate file extension
        file_ext = Path(file.filename).suffix.lower()
        if file_ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {file_ext}. "
                f"Allowed: {', '.join(allowed_extensions)}",
            )

        # Read file content
        try:
            content = await file.read()
        except Exception as e:
            logger.error(f"Failed to read uploaded file: {e}")
            raise HTTPException(status_code=400, detail="Failed to read file")

        # Check file size (get from config via service)
        service = get_async_service()
        max_size = service._async_config.max_file_size_mb * 1024 * 1024
        if len(content) > max_size:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: "
                f"{service._async_config.max_file_size_mb}MB",
            )

        # Submit request
        result = service.submit_request(
            api_key=api_key,
            file_content=content,
            filename=file.filename,
            content_type=file.content_type or "application/octet-stream",
        )

        if not result.success:
            if "queue" in (result.error or "").lower():
                raise HTTPException(status_code=503, detail=result.error)
            raise HTTPException(status_code=500, detail=result.error)

        return AsyncSubmitResponse(
            status="accepted",
            message="Request submitted for processing",
            request_id=result.request_id,
            estimated_wait_seconds=result.estimated_wait_seconds,
            poll_url=f"/api/v1/async/status/{result.request_id}",
        )

    @router.get(
        "/status/{request_id}",
        response_model=AsyncStatusResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid API key"},
            404: {"model": ErrorResponse, "description": "Request not found"},
            429: {"model": ErrorResponse, "description": "Polling too frequently"},
        },
        summary="Get request status",
        description="Get the current processing status of an async request.",
    )
    async def get_status(
        request_id: str,
        api_key: PollRateLimitDep,
        db: AsyncDBDep,
    ) -> AsyncStatusResponse:
        """Get the status of an async request."""
        # Validate UUID format
        _validate_request_id(request_id)

        # Get request from database (validates API key ownership)
        request = db.get_request_by_api_key(request_id, api_key)

        if request is None:
            raise HTTPException(
                status_code=404,
                detail="Request not found or does not belong to this API key",
            )

        # Get queue position for pending requests
        position = None
        if request.status == "pending":
            position = db.get_queue_position(request_id)

        # Build result URL for completed requests
        result_url = None
        if request.status == "completed":
            result_url = f"/api/v1/async/result/{request_id}"

        return AsyncStatusResponse(
            request_id=str(request.request_id),
            status=AsyncStatus(request.status),
            filename=request.filename,
            created_at=request.created_at,
            started_at=request.started_at,
            completed_at=request.completed_at,
            position_in_queue=position,
            error_message=request.error_message,
            result_url=result_url,
        )

    @router.get(
        "/result/{request_id}",
        response_model=AsyncResultResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid API key"},
            404: {"model": ErrorResponse, "description": "Request not found"},
            409: {"model": ErrorResponse, "description": "Request not completed"},
            429: {"model": ErrorResponse, "description": "Polling too frequently"},
        },
        summary="Get extraction results",
        description="Get the extraction results for a completed async request.",
    )
    async def get_result(
        request_id: str,
        api_key: PollRateLimitDep,
        db: AsyncDBDep,
    ) -> AsyncResultResponse:
        """Get the results of a completed async request."""
        # Validate UUID format
        _validate_request_id(request_id)

        # Get request from database (validates API key ownership)
        request = db.get_request_by_api_key(request_id, api_key)

        if request is None:
            raise HTTPException(
                status_code=404,
                detail="Request not found or does not belong to this API key",
            )

        # Check if completed or failed
        if request.status not in ("completed", "failed"):
            raise HTTPException(
                status_code=409,
                detail=f"Request not yet completed. Current status: {request.status}",
            )

        # Build inference result from stored data
        inference_result = None
        if request.result:
            # Convert detections to DetectionResult objects
            detections = []
            for d in request.result.get("detections", []):
                detections.append(DetectionResult(
                    field=d.get("field", ""),
                    confidence=d.get("confidence", 0.0),
                    bbox=d.get("bbox", [0, 0, 0, 0]),
                ))

            inference_result = InferenceResult(
                document_id=request.result.get("document_id", str(request.request_id)[:8]),
                success=request.result.get("success", False),
                document_type=request.result.get("document_type", "invoice"),
                fields=request.result.get("fields", {}),
                confidence=request.result.get("confidence", {}),
                detections=detections,
                processing_time_ms=request.processing_time_ms or 0.0,
                errors=request.result.get("errors", []),
            )

        # Build visualization URL
        viz_url = None
        if request.visualization_path:
            viz_url = f"/api/v1/results/{request.visualization_path}"

        return AsyncResultResponse(
            request_id=str(request.request_id),
            status=AsyncStatus(request.status),
            processing_time_ms=request.processing_time_ms or 0.0,
            result=inference_result,
            visualization_url=viz_url,
        )

    @router.get(
        "/requests",
        response_model=AsyncRequestsListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid API key"},
        },
        summary="List requests",
        description="List all async requests for the authenticated API key.",
    )
    async def list_requests(
        api_key: ApiKeyDep,
        db: AsyncDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status (pending, processing, completed, failed)"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Maximum number of results"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Pagination offset"),
        ] = 0,
    ) -> AsyncRequestsListResponse:
        """List all requests for the authenticated API key."""
        # Validate status filter
        if status and status not in ("pending", "processing", "completed", "failed"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status filter: {status}. "
                "Must be one of: pending, processing, completed, failed",
            )

        # Get requests from database
        requests, total = db.get_requests_by_api_key(
            api_key=api_key,
            status=status,
            limit=limit,
            offset=offset,
        )

        # Convert to response items
        items = [
            AsyncRequestItem(
                request_id=str(r.request_id),
                status=AsyncStatus(r.status),
                filename=r.filename,
                file_size=r.file_size,
                created_at=r.created_at,
                completed_at=r.completed_at,
            )
            for r in requests
        ]

        return AsyncRequestsListResponse(
            total=total,
            limit=limit,
            offset=offset,
            requests=items,
        )

    @router.delete(
        "/requests/{request_id}",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid API key"},
            404: {"model": ErrorResponse, "description": "Request not found"},
            409: {"model": ErrorResponse, "description": "Cannot delete processing request"},
        },
        summary="Cancel/delete request",
        description="Cancel a pending request or delete a completed/failed request.",
    )
    async def delete_request(
        request_id: str,
        api_key: ApiKeyDep,
        db: AsyncDBDep,
    ) -> dict:
        """Delete or cancel an async request."""
        # Validate UUID format
        _validate_request_id(request_id)

        # Get request from database
        request = db.get_request_by_api_key(request_id, api_key)

        if request is None:
            raise HTTPException(
                status_code=404,
                detail="Request not found or does not belong to this API key",
            )

        # Cannot delete processing requests
        if request.status == "processing":
            raise HTTPException(
                status_code=409,
                detail="Cannot delete a request that is currently processing",
            )

        # Delete from database (will cascade delete related records)
        conn = db.connect()
        with conn.cursor() as cursor:
            cursor.execute(
                "DELETE FROM async_requests WHERE request_id = %s",
                (request_id,),
            )
            conn.commit()

        return {
            "status": "deleted",
            "request_id": request_id,
            "message": "Request deleted successfully",
        }

    return router
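A hypothetical submit/poll/result client sketch for the async endpoints above. It assumes the router is mounted under `/api/v1` as in app.py; the `X-API-Key` header name is an assumption, since the actual key transport lives in src/web/dependencies.py, which is not shown here.

```python
# Hypothetical usage sketch; the X-API-Key header name is an assumption.
import time
import requests

BASE = "http://localhost:8000/api/v1/async"
headers = {"X-API-Key": "<api-key>"}

with open("invoice.pdf", "rb") as f:
    submit = requests.post(f"{BASE}/submit", headers=headers, files={"file": f})
submit.raise_for_status()
request_id = submit.json()["request_id"]

# Respect the per-key min_poll_interval_ms; polling too fast returns 429.
while True:
    status = requests.get(f"{BASE}/status/{request_id}", headers=headers).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(1.5)

if status["status"] == "completed":
    result = requests.get(f"{BASE}/result/{request_id}", headers=headers).json()
    print(result["result"]["fields"])
```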
src/web/api/v1/public/inference.py (modified)
@@ -1,5 +1,5 @@
 """
-API Routes
+Inference API Routes

 FastAPI route definitions for the inference API.
 """
@@ -15,23 +15,22 @@ from typing import TYPE_CHECKING
 from fastapi import APIRouter, File, HTTPException, UploadFile, status
 from fastapi.responses import FileResponse

-from .schemas import (
-    BatchInferenceResponse,
+from src.web.schemas.inference import (
     DetectionResult,
-    ErrorResponse,
     HealthResponse,
     InferenceResponse,
     InferenceResult,
 )
+from src.web.schemas.common import ErrorResponse

 if TYPE_CHECKING:
-    from .services import InferenceService
-    from .config import StorageConfig
+    from src.web.services import InferenceService
+    from src.web.config import StorageConfig

 logger = logging.getLogger(__name__)


-def create_api_router(
+def create_inference_router(
     inference_service: "InferenceService",
     storage_config: "StorageConfig",
 ) -> APIRouter:
src/web/api/v1/public/labeling.py (new file, 203 lines)
@@ -0,0 +1,203 @@
"""
Labeling API Routes

FastAPI endpoints for pre-labeling documents with expected field values.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status

from src.data.admin_db import AdminDB
from src.web.schemas.labeling import PreLabelResponse
from src.web.schemas.common import ErrorResponse

if TYPE_CHECKING:
    from src.web.services import InferenceService
    from src.web.config import StorageConfig

logger = logging.getLogger(__name__)

# Storage directory for pre-label uploads (legacy, now uses storage_config)
PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")


def _convert_pdf_to_images(
    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
) -> None:
    """Convert PDF pages to images for annotation."""
    import fitz

    doc_images_dir = images_dir / document_id
    doc_images_dir.mkdir(parents=True, exist_ok=True)

    pdf_doc = fitz.open(stream=content, filetype="pdf")

    for page_num in range(page_count):
        page = pdf_doc[page_num]
        mat = fitz.Matrix(dpi / 72, dpi / 72)
        pix = page.get_pixmap(matrix=mat)

        image_path = doc_images_dir / f"page_{page_num + 1}.png"
        pix.save(str(image_path))

    pdf_doc.close()


def get_admin_db() -> AdminDB:
    """Get admin database instance."""
    return AdminDB()


def create_labeling_router(
    inference_service: "InferenceService",
    storage_config: "StorageConfig",
) -> APIRouter:
    """
    Create API router with labeling endpoints.

    Args:
        inference_service: Inference service instance
        storage_config: Storage configuration

    Returns:
        Configured APIRouter
    """
    router = APIRouter(prefix="/api/v1", tags=["labeling"])

    # Ensure upload directory exists
    PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

    @router.post(
        "/pre-label",
        response_model=PreLabelResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid file or field values"},
            500: {"model": ErrorResponse, "description": "Processing error"},
        },
        summary="Pre-label document with expected values",
        description="Upload a document with expected field values for pre-labeling. Returns document_id for result retrieval.",
    )
    async def pre_label(
        file: UploadFile = File(..., description="PDF or image file to process"),
        field_values: str = Form(
            ...,
            description="JSON object with expected field values. "
            "Keys: InvoiceNumber, InvoiceDate, InvoiceDueDate, Amount, OCR, "
            "Bankgiro, Plusgiro, customer_number, supplier_organisation_number",
        ),
        db: AdminDB = Depends(get_admin_db),
    ) -> PreLabelResponse:
        """
        Upload a document with expected field values for pre-labeling.

        Returns document_id which can be used to retrieve results later.

        Example field_values JSON:
        ```json
        {
            "InvoiceNumber": "12345",
            "Amount": "1500.00",
            "Bankgiro": "123-4567",
            "OCR": "1234567890"
        }
        ```
        """
        # Parse field_values JSON
        try:
            expected_values = json.loads(field_values)
            if not isinstance(expected_values, dict):
                raise ValueError("field_values must be a JSON object")
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid JSON in field_values: {e}",
            )

        # Validate file extension
        if not file.filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Filename is required",
            )

        file_ext = Path(file.filename).suffix.lower()
        if file_ext not in storage_config.allowed_extensions:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
            )

        # Read file content
        try:
            content = await file.read()
        except Exception as e:
            logger.error(f"Failed to read uploaded file: {e}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to read file",
            )

        # Get page count for PDF
        page_count = 1
        if file_ext == ".pdf":
            try:
                import fitz
                pdf_doc = fitz.open(stream=content, filetype="pdf")
                page_count = len(pdf_doc)
                pdf_doc.close()
            except Exception as e:
                logger.warning(f"Failed to get PDF page count: {e}")

        # Create document record with field_values
        document_id = db.create_document(
            filename=file.filename,
            file_size=len(content),
            content_type=file.content_type or "application/octet-stream",
            file_path="",  # Will update after saving
            page_count=page_count,
            upload_source="api",
            csv_field_values=expected_values,
        )

        # Save file to admin uploads
        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
        try:
            file_path.write_bytes(content)
        except Exception as e:
            logger.error(f"Failed to save file: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save file",
            )

        # Update file path in database
        db.update_document_file_path(document_id, str(file_path))

        # Convert PDF to images for annotation UI
        if file_ext == ".pdf":
            try:
                _convert_pdf_to_images(
                    document_id, content, page_count,
                    storage_config.admin_images_dir, storage_config.dpi
                )
            except Exception as e:
                logger.error(f"Failed to convert PDF to images: {e}")

        # Trigger auto-labeling
        db.update_document_status(
            document_id=document_id,
            status="auto_labeling",
            auto_label_status="pending",
        )

        logger.info(f"Pre-label document {document_id} created with {len(expected_values)} expected fields")

        return PreLabelResponse(document_id=document_id)

    return router
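A hypothetical client sketch for the pre-label endpoint above, using the field keys listed in the form description; host is a placeholder.

```python
# Hypothetical usage sketch; server address is an assumption.
import json
import requests

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/pre-label",
        files={"file": ("invoice.pdf", f, "application/pdf")},
        data={"field_values": json.dumps({
            "InvoiceNumber": "12345",
            "Amount": "1500.00",
            "Bankgiro": "123-4567",
        })},
    )
resp.raise_for_status()
document_id = resp.json()["document_id"]  # use later to retrieve auto-label results
```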
src/web/app.py (modified, 158 lines)
@@ -17,8 +17,39 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.responses import HTMLResponse

 from .config import AppConfig, default_config
-from .routes import create_api_router
-from .services import InferenceService
+from src.web.services import InferenceService
+
+# Public API imports
+from src.web.api.v1.public import (
+    create_inference_router,
+    create_async_router,
+    set_async_service,
+    create_labeling_router,
+)
+
+# Async processing imports
+from src.data.async_request_db import AsyncRequestDB
+from src.web.workers.async_queue import AsyncTaskQueue
+from src.web.services.async_processing import AsyncProcessingService
+from src.web.dependencies import init_dependencies
+from src.web.core.rate_limiter import RateLimiter
+
+# Admin API imports
+from src.web.api.v1.admin import (
+    create_annotation_router,
+    create_auth_router,
+    create_documents_router,
+    create_locks_router,
+    create_training_router,
+)
+from src.web.core.scheduler import start_scheduler, stop_scheduler
+from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
+
+# Batch upload imports
+from src.web.api.v1.batch.routes import router as batch_upload_router
+from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
+from src.web.services.batch_upload import BatchUploadService
+from src.data.admin_db import AdminDB

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
@@ -44,11 +75,38 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
         storage_config=config.storage,
     )

+    # Create async processing components
+    async_db = AsyncRequestDB()
+    rate_limiter = RateLimiter(async_db)
+    task_queue = AsyncTaskQueue(
+        max_size=config.async_processing.queue_max_size,
+        worker_count=config.async_processing.worker_count,
+    )
+    async_service = AsyncProcessingService(
+        inference_service=inference_service,
+        db=async_db,
+        queue=task_queue,
+        rate_limiter=rate_limiter,
+        async_config=config.async_processing,
+        storage_config=config.storage,
+    )
+
+    # Initialize dependencies for FastAPI
+    init_dependencies(async_db, rate_limiter)
+    set_async_service(async_service)
+
     @asynccontextmanager
     async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
         """Application lifespan manager."""
         logger.info("Starting Invoice Inference API...")
+
+        # Initialize database tables
+        try:
+            async_db.create_tables()
+            logger.info("Async database tables ready")
+        except Exception as e:
+            logger.error(f"Failed to initialize async database: {e}")
+
         # Initialize inference service on startup
         try:
             inference_service.initialize()
@@ -57,10 +115,75 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
             logger.error(f"Failed to initialize inference service: {e}")
             # Continue anyway - service will retry on first request

+        # Start async processing service
+        try:
+            async_service.start()
+            logger.info("Async processing service started")
+        except Exception as e:
+            logger.error(f"Failed to start async processing: {e}")
+
+        # Start batch upload queue
+        try:
+            admin_db = AdminDB()
+            batch_service = BatchUploadService(admin_db)
+            init_batch_queue(batch_service)
+            logger.info("Batch upload queue started")
+        except Exception as e:
+            logger.error(f"Failed to start batch upload queue: {e}")
+
+        # Start training scheduler
+        try:
+            start_scheduler()
+            logger.info("Training scheduler started")
+        except Exception as e:
+            logger.error(f"Failed to start training scheduler: {e}")
+
+        # Start auto-label scheduler
+        try:
+            start_autolabel_scheduler()
+            logger.info("AutoLabel scheduler started")
+        except Exception as e:
+            logger.error(f"Failed to start autolabel scheduler: {e}")
+
         yield

         logger.info("Shutting down Invoice Inference API...")
+
+        # Stop auto-label scheduler
+        try:
+            stop_autolabel_scheduler()
+            logger.info("AutoLabel scheduler stopped")
+        except Exception as e:
+            logger.error(f"Error stopping autolabel scheduler: {e}")
+
+        # Stop training scheduler
+        try:
+            stop_scheduler()
+            logger.info("Training scheduler stopped")
+        except Exception as e:
+            logger.error(f"Error stopping training scheduler: {e}")
+
+        # Stop batch upload queue
+        try:
+            shutdown_batch_queue()
+            logger.info("Batch upload queue stopped")
+        except Exception as e:
+            logger.error(f"Error stopping batch upload queue: {e}")
+
+        # Stop async processing service
+        try:
+            async_service.stop(timeout=30.0)
+            logger.info("Async processing service stopped")
+        except Exception as e:
+            logger.error(f"Error stopping async service: {e}")
+
+        # Close database connection
+        try:
+            async_db.close()
+            logger.info("Database connection closed")
+        except Exception as e:
+            logger.error(f"Error closing database: {e}")

     # Create FastAPI app
     app = FastAPI(
         title="Invoice Field Extraction API",
@@ -106,9 +229,34 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
         name="results",
     )

-    # Include API routes
-    api_router = create_api_router(inference_service, config.storage)
-    app.include_router(api_router)
+    # Include public API routes
+    inference_router = create_inference_router(inference_service, config.storage)
+    app.include_router(inference_router)
+
+    async_router = create_async_router(config.storage.allowed_extensions)
+    app.include_router(async_router, prefix="/api/v1")
+
+    labeling_router = create_labeling_router(inference_service, config.storage)
+    app.include_router(labeling_router)
+
+    # Include admin API routes
+    auth_router = create_auth_router()
+    app.include_router(auth_router, prefix="/api/v1")
+
+    documents_router = create_documents_router(config.storage)
+    app.include_router(documents_router, prefix="/api/v1")
+
+    locks_router = create_locks_router()
+    app.include_router(locks_router, prefix="/api/v1")
+
+    annotation_router = create_annotation_router()
+    app.include_router(annotation_router, prefix="/api/v1")
+
+    training_router = create_training_router()
+    app.include_router(training_router, prefix="/api/v1")
+
+    # Include batch upload routes
+    app.include_router(batch_upload_router)

     # Root endpoint - serve HTML UI
     @app.get("/", response_class=HTMLResponse)
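With the wiring above, serving the app is the usual factory-plus-uvicorn pattern; a minimal hypothetical launch sketch (the uvicorn invocation, host, and port are illustrative):

```python
# Hypothetical launch sketch; host/port are placeholders.
import uvicorn
from src.web.app import create_app

app = create_app()  # default AppConfig: model, server, storage, async_processing

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```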
src/web/config.py (modified)
@@ -8,6 +8,8 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
+
+from src.config import DEFAULT_DPI, PATHS


 @dataclass(frozen=True)
 class ModelConfig:
@@ -16,7 +18,7 @@ class ModelConfig:
     model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
     confidence_threshold: float = 0.5
     use_gpu: bool = True
-    dpi: int = 150
+    dpi: int = DEFAULT_DPI


 @dataclass(frozen=True)
@@ -32,19 +34,59 @@ class ServerConfig:

 @dataclass(frozen=True)
 class StorageConfig:
-    """File storage configuration."""
+    """File storage configuration.
+
+    Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
+    directly in raw_pdfs directory. This ensures consistency with CLI autolabel
+    and avoids storing duplicate files.
+    """
+
     upload_dir: Path = Path("uploads")
     result_dir: Path = Path("results")
+    admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
+    admin_images_dir: Path = Path("data/admin_images")
     max_file_size_mb: int = 50
     allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
+    dpi: int = DEFAULT_DPI

     def __post_init__(self) -> None:
         """Create directories if they don't exist."""
         object.__setattr__(self, "upload_dir", Path(self.upload_dir))
         object.__setattr__(self, "result_dir", Path(self.result_dir))
+        object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
+        object.__setattr__(self, "admin_images_dir", Path(self.admin_images_dir))
         self.upload_dir.mkdir(parents=True, exist_ok=True)
         self.result_dir.mkdir(parents=True, exist_ok=True)
+        self.admin_upload_dir.mkdir(parents=True, exist_ok=True)
+        self.admin_images_dir.mkdir(parents=True, exist_ok=True)
+
+
+@dataclass(frozen=True)
+class AsyncConfig:
+    """Async processing configuration."""
+
+    # Queue settings
+    queue_max_size: int = 100
+    worker_count: int = 1
+    task_timeout_seconds: int = 300
+
+    # Rate limiting defaults
+    default_requests_per_minute: int = 10
+    default_max_concurrent_jobs: int = 3
+    default_min_poll_interval_ms: int = 1000
+
+    # Storage
+    result_retention_days: int = 7
+    temp_upload_dir: Path = Path("uploads/async")
+    max_file_size_mb: int = 50
+
+    # Cleanup
+    cleanup_interval_hours: int = 1
+
+    def __post_init__(self) -> None:
+        """Create directories if they don't exist."""
+        object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
+        self.temp_upload_dir.mkdir(parents=True, exist_ok=True)


 @dataclass
@@ -54,6 +96,7 @@ class AppConfig:
     model: ModelConfig = field(default_factory=ModelConfig)
     server: ServerConfig = field(default_factory=ServerConfig)
     storage: StorageConfig = field(default_factory=StorageConfig)
+    async_processing: AsyncConfig = field(default_factory=AsyncConfig)

     @classmethod
     def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
@@ -62,6 +105,7 @@
             model=ModelConfig(**config_dict.get("model", {})),
             server=ServerConfig(**config_dict.get("server", {})),
             storage=StorageConfig(**config_dict.get("storage", {})),
+            async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
         )
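A sketch of overriding the new async section through `from_dict`, using only keys defined in the dataclasses above; the override values are illustrative.

```python
# Hypothetical configuration sketch; values are placeholders.
from src.web.config import AppConfig

config = AppConfig.from_dict({
    "model": {"confidence_threshold": 0.6},
    "storage": {"max_file_size_mb": 25},
    "async_processing": {"queue_max_size": 200, "worker_count": 2},
})
assert config.async_processing.worker_count == 2
```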
src/web/core/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""
Core Components

Reusable core functionality: authentication, rate limiting, scheduling.
"""

from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from src.web.core.rate_limiter import RateLimiter
from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from src.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
    get_autolabel_scheduler,
)

__all__ = [
    "validate_admin_token",
    "get_admin_db",
    "AdminTokenDep",
    "AdminDBDep",
    "RateLimiter",
    "start_scheduler",
    "stop_scheduler",
    "get_training_scheduler",
    "start_autolabel_scheduler",
    "stop_autolabel_scheduler",
    "get_autolabel_scheduler",
]
src/web/core/auth.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""
Admin Authentication

FastAPI dependencies for admin token authentication.
"""

import logging
from typing import Annotated

from fastapi import Depends, Header, HTTPException

from src.data.admin_db import AdminDB
from src.data.database import get_session_context

logger = logging.getLogger(__name__)

# Global AdminDB instance
_admin_db: AdminDB | None = None


def get_admin_db() -> AdminDB:
    """Get the AdminDB instance."""
    global _admin_db
    if _admin_db is None:
        _admin_db = AdminDB()
    return _admin_db


def reset_admin_db() -> None:
    """Reset the AdminDB instance (for testing)."""
    global _admin_db
    _admin_db = None


async def validate_admin_token(
    x_admin_token: Annotated[str | None, Header()] = None,
    admin_db: AdminDB = Depends(get_admin_db),
) -> str:
    """Validate admin token from header."""
    if not x_admin_token:
        raise HTTPException(
            status_code=401,
            detail="Admin token required. Provide X-Admin-Token header.",
        )

    if not admin_db.is_valid_admin_token(x_admin_token):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired admin token.",
        )

    # Update last used timestamp
    admin_db.update_admin_token_usage(x_admin_token)

    return x_admin_token


# Type aliases for dependency injection
AdminTokenDep = Annotated[str, Depends(validate_admin_token)]
AdminDBDep = Annotated[AdminDB, Depends(get_admin_db)]
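The two aliases make admin protection a one-line annotation on any route; a minimal sketch of a hypothetical protected endpoint using them (the router path is illustrative):

```python
# Hypothetical usage sketch; the route and prefix are placeholders.
from fastapi import APIRouter
from src.web.core.auth import AdminTokenDep, AdminDBDep

router = APIRouter(prefix="/api/v1/admin/example", tags=["example"])

@router.get("/whoami")
async def whoami(admin_token: AdminTokenDep, db: AdminDBDep) -> dict:
    # validate_admin_token has already run, so the token here is guaranteed valid.
    return {"token_prefix": admin_token[:8]}
```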
src/web/core/autolabel_scheduler.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""
Auto-Label Scheduler

Background scheduler for processing documents pending auto-labeling.
"""

import logging
import threading
from pathlib import Path

from src.data.admin_db import AdminDB
from src.web.services.db_autolabel import (
    get_pending_autolabel_documents,
    process_document_autolabel,
)

logger = logging.getLogger(__name__)


class AutoLabelScheduler:
    """Scheduler for auto-labeling tasks."""

    def __init__(
        self,
        check_interval_seconds: int = 10,
        batch_size: int = 5,
        output_dir: Path | None = None,
    ):
        """
        Initialize auto-label scheduler.

        Args:
            check_interval_seconds: Interval to check for pending tasks
            batch_size: Number of documents to process per batch
            output_dir: Output directory for temporary files
        """
        self._check_interval = check_interval_seconds
        self._batch_size = batch_size
        self._output_dir = output_dir or Path("data/autolabel_output")
        self._running = False
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._db = AdminDB()

    def start(self) -> None:
        """Start the scheduler."""
        if self._running:
            logger.warning("AutoLabel scheduler already running")
            return

        self._running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        logger.info("AutoLabel scheduler started")

    def stop(self) -> None:
        """Stop the scheduler."""
        if not self._running:
            return

        self._running = False
        self._stop_event.set()

        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

        logger.info("AutoLabel scheduler stopped")

    @property
    def is_running(self) -> bool:
        """Check if scheduler is running."""
        return self._running

    def _run_loop(self) -> None:
        """Main scheduler loop."""
        while self._running:
            try:
                self._process_pending_documents()
            except Exception as e:
                logger.error(f"Error in autolabel scheduler loop: {e}", exc_info=True)

            # Wait for next check interval
            self._stop_event.wait(timeout=self._check_interval)

    def _process_pending_documents(self) -> None:
        """Check and process pending auto-label documents."""
        try:
            documents = get_pending_autolabel_documents(
                self._db, limit=self._batch_size
            )

            if not documents:
                return

            logger.info(f"Processing {len(documents)} pending autolabel documents")

            for doc in documents:
                if self._stop_event.is_set():
                    break

                try:
                    result = process_document_autolabel(
                        document=doc,
                        db=self._db,
                        output_dir=self._output_dir,
                    )

                    if result.get("success"):
                        logger.info(
                            f"AutoLabel completed for document {doc.document_id}"
                        )
                    else:
                        logger.warning(
                            f"AutoLabel failed for document {doc.document_id}: "
                            f"{result.get('error', 'Unknown error')}"
                        )

                except Exception as e:
                    logger.error(
                        f"Error processing document {doc.document_id}: {e}",
                        exc_info=True,
                    )

        except Exception as e:
            logger.error(f"Error fetching pending documents: {e}", exc_info=True)


# Global scheduler instance
_autolabel_scheduler: AutoLabelScheduler | None = None


def get_autolabel_scheduler() -> AutoLabelScheduler:
    """Get the auto-label scheduler instance."""
    global _autolabel_scheduler
    if _autolabel_scheduler is None:
        _autolabel_scheduler = AutoLabelScheduler()
    return _autolabel_scheduler


def start_autolabel_scheduler() -> None:
    """Start the global auto-label scheduler."""
    scheduler = get_autolabel_scheduler()
    scheduler.start()


def stop_autolabel_scheduler() -> None:
    """Stop the global auto-label scheduler."""
    global _autolabel_scheduler
    if _autolabel_scheduler:
        _autolabel_scheduler.stop()
        _autolabel_scheduler = None
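Outside the app lifespan, the same module-level helpers can drive the scheduler in a one-off script; a minimal hypothetical sketch (the 60-second window is arbitrary):

```python
# Hypothetical standalone sketch; the sleep duration is a placeholder.
import time
from src.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
)

start_autolabel_scheduler()  # spawns the daemon polling thread
try:
    time.sleep(60)  # let it process pending documents for a while
finally:
    stop_autolabel_scheduler()
```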
src/web/core/rate_limiter.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""
Rate Limiter Implementation

Thread-safe rate limiter with sliding window algorithm for API key-based limiting.
"""

import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from threading import Lock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from src.data.async_request_db import AsyncRequestDB

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RateLimitConfig:
    """Rate limit configuration for an API key."""

    requests_per_minute: int = 10
    max_concurrent_jobs: int = 3
    min_poll_interval_ms: int = 1000  # Minimum time between status polls


@dataclass
class RateLimitStatus:
    """Current rate limit status."""

    allowed: bool
    remaining_requests: int
    reset_at: datetime
    retry_after_seconds: int | None = None
    reason: str | None = None


class RateLimiter:
    """
    Thread-safe rate limiter with sliding window algorithm.

    Tracks:
    - Requests per minute (sliding window)
    - Concurrent active jobs
    - Poll frequency per request_id
    """

    def __init__(self, db: "AsyncRequestDB") -> None:
        self._db = db
        self._lock = Lock()
        # In-memory tracking for fast checks
        self._request_windows: dict[str, list[float]] = defaultdict(list)
        # (api_key, request_id) -> last_poll timestamp
        self._poll_timestamps: dict[tuple[str, str], float] = {}
        # Cache for API key configs (TTL 60 seconds)
        self._config_cache: dict[str, tuple[RateLimitConfig, float]] = {}
        self._config_cache_ttl = 60.0

    def check_submit_limit(self, api_key: str) -> RateLimitStatus:
        """Check if API key can submit a new request."""
        config = self._get_config(api_key)

        with self._lock:
            now = time.time()
            window_start = now - 60  # 1 minute window

            # Clean old entries
            self._request_windows[api_key] = [
                ts for ts in self._request_windows[api_key]
                if ts > window_start
            ]

            current_count = len(self._request_windows[api_key])

            if current_count >= config.requests_per_minute:
                oldest = min(self._request_windows[api_key])
                retry_after = int(oldest + 60 - now) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(seconds=retry_after),
                    retry_after_seconds=max(1, retry_after),
                    reason="Rate limit exceeded: too many requests per minute",
                )

            # Check concurrent jobs (query database) - inside lock for thread safety
            active_jobs = self._db.count_active_jobs(api_key)
            if active_jobs >= config.max_concurrent_jobs:
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=config.requests_per_minute - current_count,
                    reset_at=datetime.utcnow() + timedelta(seconds=30),
                    retry_after_seconds=30,
                    reason=f"Max concurrent jobs ({config.max_concurrent_jobs}) reached",
                )

            return RateLimitStatus(
                allowed=True,
                remaining_requests=config.requests_per_minute - current_count - 1,
                reset_at=datetime.utcnow() + timedelta(seconds=60),
            )
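    # Worked example (illustrative, not part of the original file): with the
    # default requests_per_minute=10, suppose the window already holds 10
    # timestamps and the oldest was recorded 45 seconds ago. Then
    # retry_after = int(oldest + 60 - now) + 1 = int(15) + 1 = 16, so the
    # caller is told to retry in 16 s, once that oldest entry has slid out
    # of the one-minute window.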
    def record_request(self, api_key: str) -> None:
        """Record a successful request submission."""
        with self._lock:
            self._request_windows[api_key].append(time.time())

        # Also record in database for persistence
        try:
            self._db.record_rate_limit_event(api_key, "request")
        except Exception as e:
            logger.warning(f"Failed to record rate limit event: {e}")

    def check_poll_limit(self, api_key: str, request_id: str) -> RateLimitStatus:
        """Check if polling is allowed (prevent abuse)."""
        config = self._get_config(api_key)
        key = (api_key, request_id)

        with self._lock:
            now = time.time()
            last_poll = self._poll_timestamps.get(key, 0)
            elapsed_ms = (now - last_poll) * 1000

            if elapsed_ms < config.min_poll_interval_ms:
                # Suggest exponential backoff
                wait_ms = min(
                    config.min_poll_interval_ms * 2,
                    5000,  # Max 5 seconds
                )
                retry_after = int(wait_ms / 1000) + 1
                return RateLimitStatus(
                    allowed=False,
                    remaining_requests=0,
                    reset_at=datetime.utcnow() + timedelta(milliseconds=wait_ms),
                    retry_after_seconds=retry_after,
                    reason="Polling too frequently. Please wait before retrying.",
                )

            # Update poll timestamp
            self._poll_timestamps[key] = now

            return RateLimitStatus(
                allowed=True,
                remaining_requests=999,  # No limit on poll count, just frequency
                reset_at=datetime.utcnow(),
            )
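    # Worked example (illustrative): with min_poll_interval_ms=1000, a poll
    # arriving 400 ms after the previous one is rejected with
    # wait_ms = min(1000 * 2, 5000) = 2000 and
    # retry_after = int(2000 / 1000) + 1 = 3 seconds.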
    def _get_config(self, api_key: str) -> RateLimitConfig:
        """Get rate limit config for API key with caching."""
        now = time.time()

        # Check cache
        if api_key in self._config_cache:
            cached_config, cached_at = self._config_cache[api_key]
            if now - cached_at < self._config_cache_ttl:
                return cached_config

        # Query database
        db_config = self._db.get_api_key_config(api_key)
        if db_config:
            config = RateLimitConfig(
                requests_per_minute=db_config.requests_per_minute,
                max_concurrent_jobs=db_config.max_concurrent_jobs,
            )
        else:
            config = RateLimitConfig()  # Default limits

        # Cache result
        self._config_cache[api_key] = (config, now)
        return config

    def cleanup_poll_timestamps(self, max_age_seconds: int = 3600) -> int:
        """Clean up old poll timestamps to prevent memory leak."""
        with self._lock:
            now = time.time()
            cutoff = now - max_age_seconds
            old_keys = [
                k for k, v in self._poll_timestamps.items()
                if v < cutoff
            ]
            for key in old_keys:
                del self._poll_timestamps[key]
            return len(old_keys)

    def cleanup_request_windows(self) -> None:
        """Clean up expired entries from request windows."""
        with self._lock:
            now = time.time()
            window_start = now - 60

            for api_key in list(self._request_windows.keys()):
                self._request_windows[api_key] = [
                    ts for ts in self._request_windows[api_key]
                    if ts > window_start
                ]
                # Remove empty entries
                if not self._request_windows[api_key]:
                    del self._request_windows[api_key]

    def get_rate_limit_headers(self, status: RateLimitStatus) -> dict[str, str]:
        """Generate rate limit headers for HTTP response."""
        headers = {
            "X-RateLimit-Remaining": str(status.remaining_requests),
            "X-RateLimit-Reset": status.reset_at.isoformat(),
        }
        if status.retry_after_seconds:
            headers["Retry-After"] = str(status.retry_after_seconds)
        return headers
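To show how the pieces fit together, a minimal route-level sketch, assuming a FastAPI endpoint and an AsyncRequestDB instance named db (the endpoint path, limiter wiring, and enqueue_job helper are assumptions for illustration, not part of this diff):

    from fastapi import FastAPI, Header, HTTPException

    app = FastAPI()
    limiter = RateLimiter(db)  # db: AsyncRequestDB, constructed elsewhere

    @app.post("/api/v1/invoices")
    def submit_invoice(x_api_key: str = Header(...)) -> dict:
        status = limiter.check_submit_limit(x_api_key)
        if not status.allowed:
            # Surface Retry-After / X-RateLimit-* headers on the 429 response
            raise HTTPException(
                status_code=429,
                detail=status.reason,
                headers=limiter.get_rate_limit_headers(status),
            )
        request_id = enqueue_job(x_api_key)  # hypothetical job-submission helper
        limiter.record_request(x_api_key)
        return {"request_id": request_id}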
329  src/web/core/scheduler.py  Normal file
@@ -0,0 +1,329 @@
"""
Admin Training Scheduler

Background scheduler for training tasks, implemented as a thread-based polling loop.
"""

import logging
import threading
from datetime import datetime
from pathlib import Path
from typing import Any

from src.data.admin_db import AdminDB

logger = logging.getLogger(__name__)


class TrainingScheduler:
    """Scheduler for training tasks."""

    def __init__(
        self,
        check_interval_seconds: int = 60,
    ):
        """
        Initialize training scheduler.

        Args:
            check_interval_seconds: Interval to check for pending tasks
        """
        self._check_interval = check_interval_seconds
        self._running = False
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._db = AdminDB()

    def start(self) -> None:
        """Start the scheduler."""
        if self._running:
            logger.warning("Training scheduler already running")
            return

        self._running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        logger.info("Training scheduler started")

    def stop(self) -> None:
        """Stop the scheduler."""
        if not self._running:
            return

        self._running = False
        self._stop_event.set()

        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

        logger.info("Training scheduler stopped")

    def _run_loop(self) -> None:
        """Main scheduler loop."""
        while self._running:
            try:
                self._check_pending_tasks()
            except Exception as e:
                logger.error(f"Error in scheduler loop: {e}")

            # Wait for next check interval
            self._stop_event.wait(timeout=self._check_interval)

    def _check_pending_tasks(self) -> None:
        """Check and execute pending training tasks."""
        try:
            tasks = self._db.get_pending_training_tasks()

            for task in tasks:
                task_id = str(task.task_id)

                # Skip tasks whose scheduled time has not yet arrived
                if task.scheduled_at and task.scheduled_at > datetime.utcnow():
                    continue

                logger.info(f"Starting training task: {task_id}")

                try:
                    self._execute_task(task_id, task.config or {})
                except Exception as e:
                    logger.error(f"Training task {task_id} failed: {e}")
                    self._db.update_training_task_status(
                        task_id=task_id,
                        status="failed",
                        error_message=str(e),
                    )

        except Exception as e:
            logger.error(f"Error checking pending tasks: {e}")

    def _execute_task(self, task_id: str, config: dict[str, Any]) -> None:
        """Execute a training task."""
        # Update status to running
        self._db.update_training_task_status(task_id, "running")
        self._db.add_training_log(task_id, "INFO", "Training task started")

        try:
            # Get training configuration
            model_name = config.get("model_name", "yolo11n.pt")
            epochs = config.get("epochs", 100)
            batch_size = config.get("batch_size", 16)
            image_size = config.get("image_size", 640)
            learning_rate = config.get("learning_rate", 0.01)
            device = config.get("device", "0")
            project_name = config.get("project_name", "invoice_fields")

            # Export annotations for training
            export_result = self._export_training_data(task_id)
            if not export_result:
                raise ValueError("Failed to export training data")

            data_yaml = export_result["data_yaml"]

            self._db.add_training_log(
                task_id, "INFO",
                f"Exported {export_result['total_images']} images for training",
            )

            # Run YOLO training
            result = self._run_yolo_training(
                task_id=task_id,
                model_name=model_name,
                data_yaml=data_yaml,
                epochs=epochs,
                batch_size=batch_size,
                image_size=image_size,
                learning_rate=learning_rate,
                device=device,
                project_name=project_name,
            )

            # Update task with results
            self._db.update_training_task_status(
                task_id=task_id,
                status="completed",
                result_metrics=result.get("metrics"),
                model_path=result.get("model_path"),
            )
            self._db.add_training_log(task_id, "INFO", "Training completed successfully")

        except Exception as e:
            logger.error(f"Training task {task_id} failed: {e}")
            self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
            raise

    def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
        """Export training data for a task."""
        import shutil

        from src.data.admin_models import FIELD_CLASSES

        # Get all labeled documents
        documents = self._db.get_labeled_documents_for_export()

        if not documents:
            self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
            return None

        # Create export directory
        export_dir = Path("data/training") / task_id
        export_dir.mkdir(parents=True, exist_ok=True)

        # YOLO format directories
        (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)

        # 80/20 train/val split
        total_docs = len(documents)
        train_count = int(total_docs * 0.8)
        train_docs = documents[:train_count]
        val_docs = documents[train_count:]

        total_images = 0
        total_annotations = 0

        # Export documents
        for split, docs in [("train", train_docs), ("val", val_docs)]:
            for doc in docs:
                annotations = self._db.get_annotations_for_document(str(doc.document_id))

                if not annotations:
                    continue

                for page_num in range(1, doc.page_count + 1):
                    page_annotations = [a for a in annotations if a.page_number == page_num]

                    # Copy image
                    src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
                    if not src_image.exists():
                        continue

                    image_name = f"{doc.document_id}_page{page_num}.png"
                    dst_image = export_dir / "images" / split / image_name
                    shutil.copy(src_image, dst_image)
                    total_images += 1

                    # Write YOLO label
                    label_name = f"{doc.document_id}_page{page_num}.txt"
                    label_path = export_dir / "labels" / split / label_name

                    with open(label_path, "w") as f:
                        for ann in page_annotations:
                            line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
                            f.write(line)
                            total_annotations += 1

        # Create data.yaml
        yaml_path = export_dir / "data.yaml"
        yaml_content = f"""path: {export_dir.absolute()}
train: images/train
val: images/val

nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
        yaml_path.write_text(yaml_content)

        return {
            "data_yaml": str(yaml_path),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }
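    # Example of the exported artifacts (illustrative values, not from a real
    # run): each label file holds one line per annotation in normalized YOLO
    # format, "<class_id> <x_center> <y_center> <width> <height>", e.g.
    #   3 0.512340 0.084210 0.210000 0.031500
    # and the generated data.yaml looks like:
    #   path: /abs/path/data/training/<task_id>
    #   train: images/train
    #   val: images/val
    #   nc: <len(FIELD_CLASSES)>
    #   names: [the FIELD_CLASSES values, in class-id order]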
    def _run_yolo_training(
        self,
        task_id: str,
        model_name: str,
        data_yaml: str,
        epochs: int,
        batch_size: int,
        image_size: int,
        learning_rate: float,
        device: str,
        project_name: str,
    ) -> dict[str, Any]:
        """Run YOLO training."""
        try:
            from ultralytics import YOLO

            # Log training start
            self._db.add_training_log(
                task_id, "INFO",
                f"Starting YOLO training: model={model_name}, epochs={epochs}, batch={batch_size}",
            )

            # Load model
            model = YOLO(model_name)

            # Train
            results = model.train(
                data=data_yaml,
                epochs=epochs,
                batch=batch_size,
                imgsz=image_size,
                lr0=learning_rate,
                device=device,
                project=f"runs/train/{project_name}",
                name=f"task_{task_id[:8]}",
                exist_ok=True,
                verbose=True,
            )

            # Get best model path
            best_model = Path(results.save_dir) / "weights" / "best.pt"

            # Extract metrics
            metrics = {}
            if hasattr(results, "results_dict"):
                metrics = {
                    "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
                    "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
                    "precision": results.results_dict.get("metrics/precision(B)", 0),
                    "recall": results.results_dict.get("metrics/recall(B)", 0),
                }

            self._db.add_training_log(
                task_id, "INFO",
                f"Training completed. mAP@0.5: {metrics.get('mAP50', 'N/A')}",
            )

            return {
                "model_path": str(best_model) if best_model.exists() else None,
                "metrics": metrics,
            }

        except ImportError:
            self._db.add_training_log(task_id, "ERROR", "Ultralytics not installed")
            raise ValueError("Ultralytics (YOLO) not installed")
        except Exception as e:
            self._db.add_training_log(task_id, "ERROR", f"YOLO training failed: {e}")
            raise


# Global scheduler instance
_scheduler: TrainingScheduler | None = None


def get_training_scheduler() -> TrainingScheduler:
    """Get the training scheduler instance."""
    global _scheduler
    if _scheduler is None:
        _scheduler = TrainingScheduler()
    return _scheduler


def start_scheduler() -> None:
    """Start the global training scheduler."""
    scheduler = get_training_scheduler()
    scheduler.start()


def stop_scheduler() -> None:
    """Stop the global training scheduler."""
    global _scheduler
    if _scheduler:
        _scheduler.stop()
        _scheduler = None
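For reference, the config dict consumed by _execute_task falls back to a default for any missing key, so a task row can carry a sparse config. An illustrative value (keys and defaults taken from the code above; the task-creation API that stores it is not part of this diff):

    config = {
        "model_name": "yolo11n.pt",   # default if omitted
        "epochs": 150,                # overrides the default of 100
        "batch_size": 16,
        "image_size": 640,
        "learning_rate": 0.01,
        "device": "0",                # CUDA device index as a string
        "project_name": "invoice_fields",
    }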