restructure project
@@ -87,7 +87,10 @@
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
|
||||
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")"
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
|
||||
4  .gitignore (vendored)
@@ -52,6 +52,10 @@ reports/*.jsonl
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Coverage
|
||||
htmlcov/
|
||||
.coverage
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
|
||||
807  README.md
@@ -8,7 +8,25 @@
|
||||
|
||||
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
||||
2. **模型训练**: 使用 YOLOv11 训练字段检测模型
|
||||
3. **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
|
||||
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
|
||||
|
||||
### 架构
|
||||
|
||||
项目采用 **monorepo + 三包分离** 架构,训练和推理可独立部署:
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 工具)
|
||||
├── training/ # 训练服务 (GPU, 按需启动)
|
||||
└── inference/ # 推理服务 (常驻运行)
|
||||
```
|
||||
|
||||
| 服务 | 部署目标 | GPU | 生命周期 |
|
||||
|------|---------|-----|---------|
|
||||
| **Inference** | Azure App Service | 可选 | 常驻 7x24 |
|
||||
| **Training** | Azure ACI | 必需 | 按需启动/销毁 |
|
||||
|
||||
两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。
|
||||
|
||||
### 当前进度
|
||||
|
||||
@@ -16,6 +34,8 @@
|
||||
|------|------|
|
||||
| **已标注文档** | 9,738 (9,709 成功) |
|
||||
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
||||
| **测试** | 922 passed |
|
||||
| **模型 mAP@0.5** | 93.5% |
|
||||
|
||||
**各字段匹配率:**
|
||||
|
||||
@@ -42,24 +62,83 @@
|
||||
|------|------|
|
||||
| **WSL** | WSL 2 + Ubuntu 22.04 |
|
||||
| **Conda** | Miniconda 或 Anaconda |
|
||||
| **Python** | 3.10+ (通过 Conda 管理) |
|
||||
| **Python** | 3.11+ (通过 Conda 管理) |
|
||||
| **GPU** | NVIDIA GPU + CUDA 12.x (强烈推荐) |
|
||||
| **数据库** | PostgreSQL (存储标注结果) |
|
||||
|
||||
## 功能特点
|
||||
## 安装
|
||||
|
||||
- **双模式 PDF 处理**: 支持文本层 PDF 和扫描图 PDF
|
||||
- **自动标注**: 利用已有 CSV 结构化数据自动生成 YOLO 训练数据
|
||||
- **多策略字段匹配**: 精确匹配、子串匹配、规范化匹配
|
||||
- **数据库存储**: 标注结果存储在 PostgreSQL,支持增量处理和断点续传
|
||||
- **YOLO 检测**: 使用 YOLOv11 检测发票字段区域
|
||||
- **OCR 识别**: 使用 PaddleOCR v5 提取检测区域的文本
|
||||
- **统一解析器**: payment_line 和 customer_number 采用独立解析器模块
|
||||
- **交叉验证**: payment_line 数据与单独检测字段交叉验证,优先采用 payment_line 值
|
||||
- **文档类型识别**: 自动区分 invoice (有 payment_line) 和 letter (无 payment_line)
|
||||
- **Web 应用**: 提供 REST API 和可视化界面
|
||||
- **增量训练**: 支持在已训练模型基础上继续训练
|
||||
- **内存优化**: 支持低内存模式训练 (--low-memory)
|
||||
```bash
|
||||
# 1. 进入 WSL
|
||||
wsl -d Ubuntu-22.04
|
||||
|
||||
# 2. 创建 Conda 环境
|
||||
conda create -n invoice-py311 python=3.11 -y
|
||||
conda activate invoice-py311
|
||||
|
||||
# 3. 进入项目目录
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
|
||||
# 4. 安装三个包 (editable mode)
|
||||
pip install -e packages/shared
|
||||
pip install -e packages/training
|
||||
pip install -e packages/inference
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── packages/
|
||||
│ ├── shared/ # 共享库
|
||||
│ │ ├── setup.py
|
||||
│ │ └── shared/
|
||||
│ │ ├── pdf/ # PDF 处理 (提取, 渲染, 检测)
|
||||
│ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析
|
||||
│ │ ├── normalize/ # 字段规范化 (10 种 normalizer)
|
||||
│ │ ├── matcher/ # 字段匹配 (精确/子串/模糊)
|
||||
│ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配)
|
||||
│ │ ├── data/ # DocumentDB, CSVLoader
|
||||
│ │ ├── config.py # 全局配置 (数据库, 路径, DPI)
|
||||
│ │ └── exceptions.py # 异常定义
|
||||
│ │
|
||||
│ ├── training/ # 训练服务 (GPU, 按需)
|
||||
│ │ ├── setup.py
|
||||
│ │ ├── Dockerfile
|
||||
│ │ ├── run_training.py # 入口 (--task-id 或 --poll)
|
||||
│ │ └── training/
|
||||
│ │ ├── cli/ # train, autolabel, analyze_*, validate
|
||||
│ │ ├── yolo/ # db_dataset, annotation_generator
|
||||
│ │ ├── processing/ # CPU/GPU worker pool, task dispatcher
|
||||
│ │ └── data/ # training_db, autolabel_report
|
||||
│ │
|
||||
│ └── inference/ # 推理服务 (常驻)
|
||||
│ ├── setup.py
|
||||
│ ├── Dockerfile
|
||||
│ ├── run_server.py # Web 服务器入口
|
||||
│ └── inference/
|
||||
│ ├── cli/ # infer, serve
|
||||
│ ├── pipeline/ # YOLO 检测, 字段提取, 解析器
|
||||
│ ├── web/ # FastAPI 应用
|
||||
│ │ ├── api/v1/ # REST API (admin, public, batch)
|
||||
│ │ ├── schemas/ # Pydantic 数据模型
|
||||
│ │ ├── services/ # 业务逻辑
|
||||
│ │ ├── core/ # 认证, 调度器, 限流
|
||||
│ │ └── workers/ # 后台任务队列
|
||||
│ ├── validation/ # LLM 验证器
|
||||
│ ├── data/ # AdminDB, AsyncRequestDB, Models
|
||||
│ └── azure/ # ACI 训练触发器
|
||||
│
|
||||
├── migrations/ # 数据库迁移
|
||||
│ ├── 001_async_tables.sql
|
||||
│ ├── 002_nullable_admin_token.sql
|
||||
│ └── 003_training_tasks.sql
|
||||
├── frontend/ # React 前端 (Vite + TypeScript)
|
||||
├── tests/ # 测试 (922 tests)
|
||||
├── docker-compose.yml # 本地开发 (postgres + inference + training)
|
||||
├── run_server.py # 快捷启动脚本
|
||||
└── runs/train/ # 训练输出 (weights, curves)
|
||||
```
|
||||
|
||||
## 支持的字段
|
||||
|
||||
@@ -76,476 +155,129 @@
|
||||
| 8 | payment_line | 支付行 (机器可读格式) |
|
||||
| 9 | customer_number | 客户编号 |
|
||||
|
||||
## DPI 配置
|
||||
|
||||
**重要**: 系统所有组件统一使用 **150 DPI**,确保训练和推理的一致性。
|
||||
|
||||
DPI(每英寸点数)设置必须在训练和推理时保持一致,否则会导致:
|
||||
- 检测框尺寸失配
|
||||
- mAP显著下降(可能从93.5%降到60-70%)
|
||||
- 字段漏检或误检
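To make the scale concrete, here is a minimal sketch (the helper name is hypothetical, not from the repo) of the point-to-pixel conversion a render DPI implies; PDF coordinates are stored in points (1/72 inch), so every box grows by `dpi / 72` in pixel space:

```python
# Hypothetical helper: map a PDF-point bbox to pixels on a page rendered at `dpi`.
def pdf_points_to_pixels(bbox_pt, dpi=150):
    scale = dpi / 72.0                               # 150 DPI -> scale ≈ 2.083
    return tuple(v * scale for v in bbox_pt)

# The same field is roughly 2x larger in pixels at 150 DPI than at 72 DPI, so a
# model trained on 150 DPI renders expects fields at that pixel scale.
print(pdf_points_to_pixels((100, 200, 300, 240)))    # ≈ (208.3, 416.7, 625.0, 500.0)
```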
|
||||
|
||||
### 配置位置
|
||||
|
||||
| 组件 | 配置文件 | 配置项 |
|
||||
|------|---------|--------|
|
||||
| **全局常量** | `src/config.py` | `DEFAULT_DPI = 150` |
|
||||
| **Web推理** | `src/web/config.py` | `ModelConfig.dpi` (导入自 `src.config`) |
|
||||
| **CLI推理** | `src/cli/infer.py` | `--dpi` 默认值 = `DEFAULT_DPI` |
|
||||
| **自动标注** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` |
|
||||
| **PDF转图** | `src/web/api/v1/admin/documents.py` | 使用 `DEFAULT_DPI` |
|
||||
|
||||
### 使用示例
|
||||
|
||||
```bash
|
||||
# 训练(使用默认150 DPI)
|
||||
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
|
||||
|
||||
# 推理(默认150 DPI,与训练一致)
|
||||
python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf
|
||||
|
||||
# 手动指定DPI(仅当需要与非默认训练DPI的模型配合时)
|
||||
python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
# 1. 进入 WSL
|
||||
wsl -d Ubuntu-22.04
|
||||
|
||||
# 2. 创建 Conda 环境
|
||||
conda create -n invoice-py311 python=3.11 -y
|
||||
conda activate invoice-py311
|
||||
|
||||
# 3. 进入项目目录
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
|
||||
# 4. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 5. 安装 Web 依赖
|
||||
pip install uvicorn fastapi python-multipart pydantic
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 准备数据
|
||||
|
||||
```
|
||||
~/invoice-data/
|
||||
├── raw_pdfs/
|
||||
│ ├── {DocumentId}.pdf
|
||||
│ └── ...
|
||||
├── structured_data/
|
||||
│ └── document_export_YYYYMMDD.csv
|
||||
└── dataset/
|
||||
└── temp/ (渲染的图片)
|
||||
```
|
||||
|
||||
CSV 格式:
|
||||
```csv
|
||||
DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
|
||||
3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
|
||||
```
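As a sketch of how this CSV is consumed (the column names are the ones shown above; the file name and loop body are illustrative only):

```python
# Read the structured-data export that drives auto-labelling.
import csv

with open("document_export_20250101.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        doc_id = row["DocumentId"]            # pairs with raw_pdfs/{DocumentId}.pdf
        fields = {k: v for k, v in row.items() if k != "DocumentId" and v}
        # fields now holds the ground-truth values to match against OCR output
        print(doc_id, fields)
```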
|
||||
|
||||
### 2. 自动标注
|
||||
### 1. 自动标注
|
||||
|
||||
```bash
|
||||
# 使用双池模式 (CPU + GPU)
|
||||
python -m src.cli.autolabel \
|
||||
python -m training.cli.autolabel \
|
||||
--dual-pool \
|
||||
--cpu-workers 3 \
|
||||
--gpu-workers 1
|
||||
|
||||
# 单线程模式
|
||||
python -m src.cli.autolabel --workers 4
|
||||
python -m training.cli.autolabel --workers 4
|
||||
```
|
||||
|
||||
### 3. 训练模型
|
||||
### 2. 训练模型
|
||||
|
||||
```bash
|
||||
# 从预训练模型开始训练
|
||||
python -m src.cli.train \
|
||||
python -m training.cli.train \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--name invoice_fields \
|
||||
--dpi 150
|
||||
|
||||
# 低内存模式 (适用于内存不足场景)
|
||||
python -m src.cli.train \
|
||||
# 低内存模式
|
||||
python -m training.cli.train \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--name invoice_fields \
|
||||
--low-memory \
|
||||
--workers 4 \
|
||||
--no-cache
|
||||
--low-memory
|
||||
|
||||
# 从检查点恢复训练 (训练中断后)
|
||||
python -m src.cli.train \
|
||||
# 从检查点恢复训练
|
||||
python -m training.cli.train \
|
||||
--model runs/train/invoice_fields/weights/last.pt \
|
||||
--epochs 100 \
|
||||
--name invoice_fields \
|
||||
--resume
|
||||
```
|
||||
|
||||
### 4. 增量训练
|
||||
|
||||
当添加新数据后,可以在已训练模型基础上继续训练:
|
||||
|
||||
```bash
|
||||
# 从已训练的 best.pt 继续训练
|
||||
python -m src.cli.train \
|
||||
--model runs/train/invoice_yolo11n_full/weights/best.pt \
|
||||
--epochs 30 \
|
||||
--batch 16 \
|
||||
--name invoice_yolo11n_v2 \
|
||||
--dpi 150
|
||||
```
|
||||
|
||||
**增量训练建议**:
|
||||
|
||||
| 场景 | 建议 |
|
||||
|------|------|
|
||||
| 添加少量新数据 (<20%) | 继续训练 10-30 epochs |
|
||||
| 添加大量新数据 (>50%) | 继续训练 50-100 epochs |
|
||||
| 修正大量标注错误 | 从头训练 |
|
||||
| 添加新的字段类型 | 从头训练 |
|
||||
|
||||
### 5. 推理
|
||||
### 3. 推理
|
||||
|
||||
```bash
|
||||
# 命令行推理
|
||||
python -m src.cli.infer \
|
||||
python -m inference.cli.infer \
|
||||
--model runs/train/invoice_fields/weights/best.pt \
|
||||
--input path/to/invoice.pdf \
|
||||
--output result.json \
|
||||
--gpu
|
||||
|
||||
# 批量推理
|
||||
python -m src.cli.infer \
|
||||
--model runs/train/invoice_fields/weights/best.pt \
|
||||
--input invoices/*.pdf \
|
||||
--output results/ \
|
||||
--gpu
|
||||
```
|
||||
|
||||
**推理结果包含**:
|
||||
- `fields`: 提取的字段值 (InvoiceNumber, Amount, payment_line, customer_number 等)
|
||||
- `confidence`: 各字段的置信度
|
||||
- `document_type`: 文档类型 ("invoice" 或 "letter")
|
||||
- `cross_validation`: payment_line 交叉验证结果 (如果有)
|
||||
|
||||
### 6. Web 应用
|
||||
|
||||
**在 WSL 环境中启动**:
|
||||
### 4. Web 应用
|
||||
|
||||
```bash
|
||||
# 方法 1: 从 Windows PowerShell 启动 (推荐)
|
||||
# 从 Windows PowerShell 启动
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
|
||||
|
||||
# 方法 2: 在 WSL 内启动
|
||||
conda activate invoice-py311
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
python run_server.py --port 8000
|
||||
|
||||
# 方法 3: 使用启动脚本
|
||||
./start_web.sh
|
||||
# 启动前端
|
||||
cd frontend && npm install && npm run dev
|
||||
# 访问 http://localhost:5173
|
||||
```
|
||||
|
||||
**服务启动后**:
|
||||
- 访问 **http://localhost:8000** 使用 Web 界面
|
||||
- 服务会自动加载模型 `runs/train/invoice_fields/weights/best.pt`
|
||||
- GPU 默认启用,置信度阈值 0.5
|
||||
### 5. Docker 本地开发
|
||||
|
||||
#### Web API 端点
|
||||
```bash
|
||||
docker-compose up
|
||||
# inference: http://localhost:8000
|
||||
# training: 轮询模式自动拾取任务
|
||||
```
|
||||
|
||||
## 训练触发流程
|
||||
|
||||
推理服务通过 API 触发训练,训练在独立的 GPU 实例上执行:
|
||||
|
||||
```
|
||||
Inference API PostgreSQL Training (ACI)
|
||||
| | |
|
||||
POST /admin/training/trigger | |
|
||||
|-> INSERT training_tasks ------>| status=pending |
|
||||
|-> Azure SDK: create ACI --------------------------------> 启动
|
||||
| | |
|
||||
| |<-- SELECT pending -----+
|
||||
| |--- UPDATE running -----+
|
||||
| | 执行训练...
|
||||
| |<-- UPDATE completed ---+
|
||||
| | + model_path |
|
||||
| | + metrics 自动关机
|
||||
| | |
|
||||
GET /admin/training/{id} | |
|
||||
|-> SELECT training_tasks ------>| |
|
||||
+-- return status + metrics | |
|
||||
```
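A simplified sketch of the training-side half of this loop is shown below. The real entry point is `run_training.py --poll`; the table and column names come from `migrations/003_training_tasks.sql`, while the function itself is illustrative:

```python
# Poll training_tasks for pending work, claim it, and (eventually) mark it done.
import time
import psycopg2

def poll_loop(conn_str: str, interval: int = 30) -> None:
    conn = psycopg2.connect(conn_str)
    conn.autocommit = True
    while True:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT task_id, config FROM training_tasks "
                "WHERE status = 'pending' ORDER BY created_at LIMIT 1"
            )
            row = cur.fetchone()
            if row:
                task_id, config = row
                cur.execute(
                    "UPDATE training_tasks SET status = 'running', started_at = NOW() "
                    "WHERE task_id = %s",
                    (task_id,),
                )
                # run the YOLO training here, then UPDATE status='completed',
                # model_path and metrics for the inference API to read back
        time.sleep(interval)
```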
|
||||
|
||||
## Web API 端点
|
||||
|
||||
**Public API:**
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| GET | `/` | Web UI 界面 |
|
||||
| GET | `/api/v1/health` | 健康检查 |
|
||||
| POST | `/api/v1/infer` | 上传文件并推理 |
|
||||
| GET | `/api/v1/results/{filename}` | 获取可视化图片 |
|
||||
| POST | `/api/v1/async/infer` | 异步推理 |
|
||||
| GET | `/api/v1/async/status/{task_id}` | 查询异步任务状态 |
|
||||
|
||||
#### API 响应格式
|
||||
**Admin API** (需要 `X-Admin-Token` header):
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"result": {
|
||||
"document_id": "abc123",
|
||||
"document_type": "invoice",
|
||||
"fields": {
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": "1234.56",
|
||||
"payment_line": "# 94228110015950070 # > 48666036#14#",
|
||||
"customer_number": "UMJ 436-R"
|
||||
},
|
||||
"confidence": {
|
||||
"InvoiceNumber": 0.95,
|
||||
"Amount": 0.92
|
||||
},
|
||||
"cross_validation": {
|
||||
"is_valid": true,
|
||||
"ocr_match": true,
|
||||
"amount_match": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
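A client call against the public inference endpoint might look like the sketch below; the multipart field name `file` is an assumption (it is not confirmed anywhere in this README), and the response handling follows the JSON shape above:

```python
import requests

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/infer",
        files={"file": ("invoice.pdf", f, "application/pdf")},  # field name assumed
    )
resp.raise_for_status()
payload = resp.json()
print(payload["result"]["fields"].get("InvoiceNumber"))
print(payload["result"].get("cross_validation"))
```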
|
||||
|
||||
## 训练配置
|
||||
|
||||
### YOLO 训练参数
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型 (默认: yolo11n.pt)
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数 (用于测试)
|
||||
--device 设备 (0=GPU, cpu)
|
||||
--resume 从检查点恢复训练
|
||||
--low-memory 启用低内存模式 (batch=8, workers=4, no-cache)
|
||||
--workers 数据加载 worker 数 (默认: 8)
|
||||
--cache 缓存图像到内存
|
||||
```
|
||||
|
||||
### 训练最佳实践
|
||||
|
||||
1. **禁用翻转增强** (文本检测):
|
||||
```python
|
||||
fliplr=0.0, flipud=0.0
|
||||
```
|
||||
|
||||
2. **使用 Early Stopping**:
|
||||
```python
|
||||
patience=20
|
||||
```
|
||||
|
||||
3. **启用 AMP** (混合精度训练):
|
||||
```python
|
||||
amp=True
|
||||
```
|
||||
|
||||
4. **保存检查点**:
|
||||
```python
|
||||
save_period=10
|
||||
```
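Put together, these practices translate into a single Ultralytics call along these lines (the dataset path is illustrative; the keyword names are standard Ultralytics arguments):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.train(
    data="dataset/dataset.yaml",   # illustrative path to the exported dataset
    epochs=100,
    imgsz=1280,
    batch=16,
    fliplr=0.0, flipud=0.0,        # no flips for text fields
    patience=20,                   # early stopping
    amp=True,                      # mixed precision
    save_period=10,                # periodic checkpoints
)
```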
|
||||
|
||||
### 训练结果示例
|
||||
|
||||
**最新训练结果** (100 epochs, 2026-01-22):
|
||||
|
||||
| 指标 | 值 |
|
||||
|------|-----|
|
||||
| **mAP@0.5** | 93.5% |
|
||||
| **mAP@0.5-0.95** | 83.0% |
|
||||
| **训练集** | ~10,000 张标注图片 |
|
||||
| **字段类型** | 10 个字段 (新增 payment_line, customer_number) |
|
||||
| **模型位置** | `runs/train/invoice_fields/weights/best.pt` |
|
||||
|
||||
**各字段检测性能**:
|
||||
- 发票基础信息 (InvoiceNumber, InvoiceDate, InvoiceDueDate): >95% mAP
|
||||
- 支付信息 (OCR, Bankgiro, Plusgiro, Amount): >90% mAP
|
||||
- 组织信息 (supplier_org_number, customer_number): >85% mAP
|
||||
- 支付行 (payment_line): >80% mAP
|
||||
|
||||
**模型文件**:
|
||||
```
|
||||
runs/train/invoice_fields/weights/
|
||||
├── best.pt # 最佳模型 (mAP@0.5 最高) ⭐ 推荐用于生产
|
||||
└── last.pt # 最后检查点 (用于继续训练)
|
||||
```
|
||||
|
||||
> 注:目前仍在持续标注更多数据,预计最终将有 25,000+ 张标注图片用于训练。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── src/
|
||||
│ ├── cli/ # 命令行工具
|
||||
│ │ ├── autolabel.py # 自动标注
|
||||
│ │ ├── train.py # 模型训练
|
||||
│ │ ├── infer.py # 推理
|
||||
│ │ └── serve.py # Web 服务器
|
||||
│ ├── pdf/ # PDF 处理
|
||||
│ │ ├── extractor.py # 文本提取
|
||||
│ │ ├── renderer.py # 图像渲染
|
||||
│ │ └── detector.py # 类型检测
|
||||
│ ├── ocr/ # PaddleOCR 封装
|
||||
│ │ └── machine_code_parser.py # 机器可读付款行解析器
|
||||
│ ├── normalize/ # 字段规范化
|
||||
│ ├── matcher/ # 字段匹配
|
||||
│ ├── yolo/ # YOLO 相关
|
||||
│ │ ├── annotation_generator.py
|
||||
│ │ └── db_dataset.py
|
||||
│ ├── inference/ # 推理管道
|
||||
│ │ ├── pipeline.py # 主推理流程
|
||||
│ │ ├── yolo_detector.py # YOLO 检测
|
||||
│ │ ├── field_extractor.py # 字段提取
|
||||
│ │ ├── payment_line_parser.py # 支付行解析器
|
||||
│ │ └── customer_number_parser.py # 客户编号解析器
|
||||
│ ├── processing/ # 多池处理架构
|
||||
│ │ ├── worker_pool.py
|
||||
│ │ ├── cpu_pool.py
|
||||
│ │ ├── gpu_pool.py
|
||||
│ │ ├── task_dispatcher.py
|
||||
│ │ └── dual_pool_coordinator.py
|
||||
│ ├── web/ # Web 应用
|
||||
│ │ ├── app.py # FastAPI 应用入口
|
||||
│ │ ├── routes.py # API 路由
|
||||
│ │ ├── services.py # 业务逻辑
|
||||
│ │ └── schemas.py # 数据模型
|
||||
│ ├── utils/ # 工具模块
|
||||
│ │ ├── text_cleaner.py # 文本清理
|
||||
│ │ ├── validators.py # 字段验证
|
||||
│ │ ├── fuzzy_matcher.py # 模糊匹配
|
||||
│ │ └── ocr_corrections.py # OCR 错误修正
|
||||
│ └── data/ # 数据处理
|
||||
├── tests/ # 测试文件
|
||||
│ ├── ocr/ # OCR 模块测试
|
||||
│ │ └── test_machine_code_parser.py
|
||||
│ ├── inference/ # 推理模块测试
|
||||
│ ├── normalize/ # 规范化模块测试
|
||||
│ └── utils/ # 工具模块测试
|
||||
├── docs/ # 文档
|
||||
│ ├── REFACTORING_SUMMARY.md
|
||||
│ └── TEST_COVERAGE_IMPROVEMENT.md
|
||||
├── config.py # 配置文件
|
||||
├── run_server.py # Web 服务器启动脚本
|
||||
├── runs/ # 训练输出
|
||||
│ └── train/
|
||||
│ └── invoice_fields/
|
||||
│ └── weights/
|
||||
│ ├── best.pt # 最佳模型
|
||||
│ └── last.pt # 最后检查点
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
## 多池处理架构
|
||||
|
||||
项目使用 CPU + GPU 双池架构处理不同类型的 PDF:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ DualPoolCoordinator │
|
||||
│ ┌─────────────────┐ ┌─────────────────┐ │
|
||||
│ │ CPU Pool │ │ GPU Pool │ │
|
||||
│ │ (3 workers) │ │ (1 worker) │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ Text PDFs │ │ Scanned PDFs │ │
|
||||
│ │ ~50-87 it/s │ │ ~1-2 it/s │ │
|
||||
│ └─────────────────┘ └─────────────────┘ │
|
||||
│ │
|
||||
│ TaskDispatcher: 根据 PDF 类型分配任务 │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 关键设计
|
||||
|
||||
- **spawn 启动方式**: 兼容 CUDA 多进程
|
||||
- **as_completed()**: 无死锁结果收集
|
||||
- **进程初始化器**: 每个 worker 加载一次模型
|
||||
- **协调器持久化**: 跨 CSV 文件复用 worker 池
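A minimal sketch of that worker-pool pattern, using the standard library (the loader and task functions are hypothetical, not the actual classes under `training/processing/`):

```python
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

_engine = None

def _init_worker():
    global _engine
    _engine = load_ocr_engine()               # hypothetical: load the model once per worker

def _process(pdf_path):
    return run_autolabel(pdf_path, _engine)   # hypothetical per-document task

def run_pool(pdf_paths, workers=3):
    ctx = mp.get_context("spawn")             # spawn start method for CUDA compatibility
    with ProcessPoolExecutor(max_workers=workers, mp_context=ctx,
                             initializer=_init_worker) as pool:
        futures = {pool.submit(_process, p): p for p in pdf_paths}
        for fut in as_completed(futures):     # deadlock-free result collection
            yield futures[fut], fut.result()
```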
|
||||
|
||||
## 配置文件
|
||||
|
||||
### config.py
|
||||
|
||||
```python
|
||||
# 数据库配置
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31',
|
||||
'port': 5432,
|
||||
'database': 'docmaster',
|
||||
'user': 'docmaster',
|
||||
'password': '******',
|
||||
}
|
||||
|
||||
# 路径配置
|
||||
PATHS = {
|
||||
'csv_dir': '~/invoice-data/structured_data',
|
||||
'pdf_dir': '~/invoice-data/raw_pdfs',
|
||||
'output_dir': '~/invoice-data/dataset',
|
||||
}
|
||||
```
|
||||
|
||||
## CLI 命令参考
|
||||
|
||||
### autolabel
|
||||
|
||||
```bash
|
||||
python -m src.cli.autolabel [OPTIONS]
|
||||
|
||||
Options:
|
||||
--csv, -c CSV 文件路径 (支持 glob)
|
||||
--pdf-dir, -p PDF 文件目录
|
||||
--output, -o 输出目录
|
||||
--workers, -w 单线程模式 worker 数 (默认: 4)
|
||||
--dual-pool 启用双池模式
|
||||
--cpu-workers CPU 池 worker 数 (默认: 3)
|
||||
--gpu-workers GPU 池 worker 数 (默认: 1)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--limit, -l 限制处理文档数
|
||||
```
|
||||
|
||||
### train
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型路径
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数
|
||||
```
|
||||
|
||||
### infer
|
||||
|
||||
```bash
|
||||
python -m src.cli.infer [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 模型路径
|
||||
--input, -i 输入 PDF/图像
|
||||
--output, -o 输出 JSON 路径
|
||||
--confidence 置信度阈值 (默认: 0.5)
|
||||
--dpi 渲染 DPI (默认: 150, 必须与训练DPI一致)
|
||||
--gpu 使用 GPU
|
||||
```
|
||||
|
||||
### serve
|
||||
|
||||
```bash
|
||||
python run_server.py [OPTIONS]
|
||||
|
||||
Options:
|
||||
--host 绑定地址 (默认: 0.0.0.0)
|
||||
--port 端口 (默认: 8000)
|
||||
--model, -m 模型路径
|
||||
--confidence 置信度阈值 (默认: 0.3)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--no-gpu 禁用 GPU
|
||||
--reload 开发模式自动重载
|
||||
--debug 调试模式
|
||||
```
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| POST | `/api/v1/admin/auth/login` | 管理员登录 |
|
||||
| GET | `/api/v1/admin/documents` | 文档列表 |
|
||||
| POST | `/api/v1/admin/documents/upload` | 上传 PDF |
|
||||
| GET | `/api/v1/admin/documents/{id}` | 文档详情 |
|
||||
| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 |
|
||||
| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 |
|
||||
| POST | `/api/v1/admin/training/trigger` | 触发训练任务 |
|
||||
| GET | `/api/v1/admin/training/{id}/status` | 查询训练状态 |
|
||||
|
||||
## Python API
|
||||
|
||||
```python
|
||||
from src.inference.pipeline import InferencePipeline
|
||||
from inference.pipeline import InferencePipeline
|
||||
|
||||
# 初始化
|
||||
pipeline = InferencePipeline(
|
||||
@@ -559,41 +291,25 @@ pipeline = InferencePipeline(
|
||||
# 处理 PDF
|
||||
result = pipeline.process_pdf('invoice.pdf')
|
||||
|
||||
# 处理图片
|
||||
result = pipeline.process_image('invoice.png')
|
||||
|
||||
# 获取结果
|
||||
print(result.fields)
|
||||
# {
|
||||
# 'InvoiceNumber': '12345',
|
||||
# 'Amount': '1234.56',
|
||||
# 'payment_line': '# 94228110015950070 # > 48666036#14#',
|
||||
# 'customer_number': 'UMJ 436-R',
|
||||
# ...
|
||||
# }
|
||||
# {'InvoiceNumber': '12345', 'Amount': '1234.56', ...}
|
||||
|
||||
print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
|
||||
print(result.to_json()) # JSON 格式输出
|
||||
print(result.confidence)
|
||||
# {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
|
||||
|
||||
# 访问交叉验证结果
|
||||
# 交叉验证
|
||||
if result.cross_validation:
|
||||
print(f"OCR match: {result.cross_validation.ocr_match}")
|
||||
print(f"Amount match: {result.cross_validation.amount_match}")
|
||||
print(f"Details: {result.cross_validation.details}")
|
||||
```
|
||||
|
||||
### 统一解析器使用
|
||||
|
||||
```python
|
||||
from src.inference.payment_line_parser import PaymentLineParser
|
||||
from src.inference.customer_number_parser import CustomerNumberParser
|
||||
from inference.pipeline.payment_line_parser import PaymentLineParser
|
||||
from inference.pipeline.customer_number_parser import CustomerNumberParser
|
||||
|
||||
# Payment Line 解析
|
||||
parser = PaymentLineParser()
|
||||
result = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
|
||||
print(f"OCR: {result.ocr_number}")
|
||||
print(f"Amount: {result.amount}")
|
||||
print(f"Account: {result.account_number}")
|
||||
print(f"OCR: {result.ocr_number}, Amount: {result.amount}")
|
||||
|
||||
# Customer Number 解析
|
||||
parser = CustomerNumberParser()
|
||||
@@ -601,156 +317,38 @@ result = parser.parse("Said, Shakar Umj 436-R Billo")
|
||||
print(f"Customer Number: {result}") # "UMJ 436-R"
|
||||
```
|
||||
|
||||
## DPI 配置
|
||||
|
||||
系统所有组件统一使用 **150 DPI**。DPI 必须在训练和推理时保持一致。
|
||||
|
||||
| 组件 | 配置位置 |
|
||||
|------|---------|
|
||||
| 全局常量 | `packages/shared/shared/config.py` -> `DEFAULT_DPI = 150` |
|
||||
| Web 推理 | `packages/inference/inference/web/config.py` -> `ModelConfig.dpi` |
|
||||
| CLI 推理 | `python -m inference.cli.infer --dpi 150` |
|
||||
| 自动标注 | `packages/shared/shared/config.py` -> `AUTOLABEL['dpi']` |
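In code this amounts to importing the shared constant instead of hard-coding a number. The import below is the one used by the repo's entry points; the render call is only illustrative:

```python
from shared.config import DEFAULT_DPI   # = 150

image = render_page("invoice.pdf", page=0, dpi=DEFAULT_DPI)  # hypothetical renderer
```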
|
||||
|
||||
## 数据库架构
|
||||
|
||||
| 数据库 | 用途 | 存储内容 |
|
||||
|--------|------|----------|
|
||||
| **PostgreSQL** | 标注结果 | `documents`, `field_results`, `training_tasks` |
|
||||
| **SQLite** (AdminDB) | Web 应用 | 文档管理, 标注编辑, 用户认证 |
|
||||
|
||||
## 测试
|
||||
|
||||
### 测试统计
|
||||
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| **测试总数** | 688 |
|
||||
| **通过率** | 100% |
|
||||
| **整体覆盖率** | 37% |
|
||||
|
||||
### 关键模块覆盖率
|
||||
|
||||
| 模块 | 覆盖率 | 测试数 |
|
||||
|------|--------|--------|
|
||||
| `machine_code_parser.py` | 65% | 79 |
|
||||
| `payment_line_parser.py` | 85% | 45 |
|
||||
| `customer_number_parser.py` | 90% | 32 |
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"
|
||||
DB_PASSWORD=xxx pytest tests/ -q
|
||||
|
||||
# 运行并查看覆盖率
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src --cov-report=term-missing"
|
||||
|
||||
# 运行特定模块测试
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/ocr/test_machine_code_parser.py -v"
|
||||
DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
|
||||
```
|
||||
|
||||
### 测试结构
|
||||
|
||||
```
|
||||
tests/
|
||||
├── ocr/
|
||||
│ ├── test_machine_code_parser.py # 支付行解析 (79 tests)
|
||||
│ └── test_ocr_engine.py # OCR 引擎测试
|
||||
├── inference/
|
||||
│ ├── test_payment_line_parser.py # 支付行解析器
|
||||
│ └── test_customer_number_parser.py # 客户编号解析器
|
||||
├── normalize/
|
||||
│ └── test_normalizers.py # 字段规范化
|
||||
└── utils/
|
||||
└── test_validators.py # 字段验证
|
||||
```
|
||||
|
||||
## 开发状态
|
||||
|
||||
**已完成功能**:
|
||||
- [x] 文本层 PDF 自动标注
|
||||
- [x] 扫描图 OCR 自动标注
|
||||
- [x] 多策略字段匹配 (精确/子串/规范化)
|
||||
- [x] PostgreSQL 数据库存储 (断点续传)
|
||||
- [x] 信号处理和超时保护
|
||||
- [x] YOLO 训练 (93.5% mAP@0.5, 10 个字段)
|
||||
- [x] 推理管道
|
||||
- [x] 字段规范化和验证
|
||||
- [x] Web 应用 (FastAPI + REST API)
|
||||
- [x] 增量训练支持
|
||||
- [x] 内存优化训练 (--low-memory, --resume)
|
||||
- [x] Payment Line 解析器 (统一模块)
|
||||
- [x] Customer Number 解析器 (统一模块)
|
||||
- [x] Payment Line 交叉验证 (OCR, Amount, Account)
|
||||
- [x] 文档类型识别 (invoice/letter)
|
||||
- [x] 单元测试覆盖 (688 tests, 37% coverage)
|
||||
|
||||
**进行中**:
|
||||
- [ ] 完成全部 25,000+ 文档标注
|
||||
- [ ] 多源融合增强 (Multi-source fusion)
|
||||
- [ ] OCR 错误修正集成
|
||||
- [ ] 提升测试覆盖率到 60%+
|
||||
|
||||
**计划中**:
|
||||
- [ ] 表格 items 提取
|
||||
- [ ] 模型量化部署 (ONNX/TensorRT)
|
||||
- [ ] 多语言支持扩展
|
||||
|
||||
## 关键技术特性
|
||||
|
||||
### 1. Payment Line 交叉验证
|
||||
|
||||
瑞典发票的 payment_line (支付行) 包含完整的支付信息:OCR 参考号、金额、账号。我们实现了交叉验证机制:
|
||||
|
||||
```
|
||||
Payment Line: # 94228110015950070 # 15658 00 8 > 48666036#14#
|
||||
↓ ↓ ↓
|
||||
OCR Number Amount Bankgiro Account
|
||||
```
|
||||
|
||||
**验证流程**:
|
||||
1. 从 payment_line 提取 OCR、Amount、Account
|
||||
2. 与单独检测的字段对比验证
|
||||
3. **payment_line 值优先** - 如有不匹配,采用 payment_line 的值
|
||||
4. 返回验证结果和详细信息
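The precedence rule can be sketched roughly as follows (the parser import path matches the unified-parser examples earlier in this README; the comparison logic is simplified):

```python
from inference.pipeline.payment_line_parser import PaymentLineParser

def cross_validate(payment_line_text: str, detected: dict) -> dict:
    parsed = PaymentLineParser().parse(payment_line_text)
    ocr_match = parsed.ocr_number == detected.get("OCR")
    amount_match = parsed.amount == detected.get("Amount")
    # payment_line wins on conflict: overwrite the individually detected fields
    detected["OCR"], detected["Amount"] = parsed.ocr_number, parsed.amount
    return {"is_valid": ocr_match and amount_match,
            "ocr_match": ocr_match, "amount_match": amount_match}
```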
|
||||
|
||||
**优势**:
|
||||
- 提高数据准确性 (payment_line 是机器可读格式,更可靠)
|
||||
- 发现 OCR 或检测错误
|
||||
- 为数据质量提供信心指标
|
||||
|
||||
### 2. 统一解析器架构
|
||||
|
||||
采用独立解析器模块处理复杂字段:
|
||||
|
||||
**PaymentLineParser**:
|
||||
- 解析瑞典标准支付行格式
|
||||
- 提取 OCR、Amount (包含 Kronor + Öre)、Account + Check digits
|
||||
- 支持多种变体格式
|
||||
|
||||
**CustomerNumberParser**:
|
||||
- 支持多种瑞典客户编号格式 (`UMJ 436-R`, `JTY 576-3`, `FFL 019N`)
|
||||
- 从混合文本中提取 (如地址行中的客户编号)
|
||||
- 大小写不敏感,输出统一大写格式
|
||||
|
||||
**优势**:
|
||||
- 代码模块化、可测试
|
||||
- 易于扩展新格式
|
||||
- 统一的解析逻辑,减少重复代码
|
||||
|
||||
### 3. 文档类型自动识别
|
||||
|
||||
根据 payment_line 字段自动判断文档类型:
|
||||
|
||||
- **invoice**: 包含 payment_line 的发票文档
|
||||
- **letter**: 不含 payment_line 的信函文档
|
||||
|
||||
这个特性帮助下游系统区分处理流程。
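As a sketch, the rule reduces to a one-liner (function name illustrative):

```python
def detect_document_type(fields: dict) -> str:
    return "invoice" if fields.get("payment_line") else "letter"
```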
|
||||
|
||||
### 4. 低内存模式训练
|
||||
|
||||
支持在内存受限环境下训练:
|
||||
|
||||
```bash
|
||||
python -m src.cli.train --low-memory
|
||||
```
|
||||
|
||||
自动调整:
|
||||
- batch size: 16 → 8
|
||||
- workers: 8 → 4
|
||||
- cache: disabled
|
||||
- 推荐用于 GPU 内存 < 8GB 或系统内存 < 16GB 的场景
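Expressed as training overrides, the flag boils down to something like this sketch (derived from the list above, not copied from the CLI code):

```python
LOW_MEMORY_OVERRIDES = {
    "batch": 8,      # down from 16
    "workers": 4,    # down from 8
    "cache": False,  # do not cache images in RAM
}
```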
|
||||
|
||||
### 5. 断点续传训练
|
||||
|
||||
训练中断后可从检查点恢复:
|
||||
|
||||
```bash
|
||||
python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.pt
|
||||
```
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| **测试总数** | 922 |
|
||||
| **通过率** | 100% |
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -762,32 +360,7 @@ python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.
|
||||
| **数据库** | PostgreSQL + psycopg2 |
|
||||
| **Web 框架** | FastAPI + Uvicorn |
|
||||
| **深度学习** | PyTorch + CUDA 12.x |
|
||||
|
||||
## 常见问题
|
||||
|
||||
**Q: 为什么必须在 WSL 环境运行?**
|
||||
|
||||
A: PaddleOCR 和某些依赖在 Windows 原生环境存在兼容性问题。WSL 提供完整的 Linux 环境,确保所有依赖正常工作。
|
||||
|
||||
**Q: 训练过程中出现 OOM (内存不足) 错误怎么办?**
|
||||
|
||||
A: 使用 `--low-memory` 模式,或手动调整 `--batch` 和 `--workers` 参数。
|
||||
|
||||
**Q: payment_line 和单独检测字段不匹配时怎么处理?**
|
||||
|
||||
A: 系统默认优先采用 payment_line 的值,因为 payment_line 是机器可读格式,通常更准确。验证结果会记录在 `cross_validation` 字段中。
|
||||
|
||||
**Q: 如何添加新的字段类型?**
|
||||
|
||||
A:
|
||||
1. 在 `src/inference/constants.py` 添加字段定义
|
||||
2. 在 `field_extractor.py` 添加规范化方法
|
||||
3. 重新生成标注数据
|
||||
4. 从头训练模型
|
||||
|
||||
**Q: 可以用 CPU 训练吗?**
|
||||
|
||||
A: 可以,但速度会非常慢 (慢 10-50 倍)。强烈建议使用 GPU 训练。
|
||||
| **部署** | Docker + Azure ACI (训练) / App Service (推理) |
|
||||
|
||||
## 许可证
|
||||
|
||||
|
||||
60  docker-compose.yml (new file)
@@ -0,0 +1,60 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
environment:
|
||||
POSTGRES_DB: docmaster
|
||||
POSTGRES_USER: docmaster
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
- ./migrations:/docker-entrypoint-initdb.d
|
||||
|
||||
inference:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/inference/Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
- MODEL_PATH=/app/models/best.pt
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
depends_on:
|
||||
- postgres
|
||||
|
||||
training:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/training/Dockerfile
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
- ./temp:/app/temp
|
||||
depends_on:
|
||||
- postgres
|
||||
# Override CMD for local dev polling mode
|
||||
command: ["python", "run_training.py", "--poll", "--poll-interval", "30"]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
54  docs/training-flow.mmd (new file)
@@ -0,0 +1,54 @@
|
||||
flowchart TD
|
||||
A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.]
|
||||
B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD]
|
||||
|
||||
C --> D[Load Data from DB\nsrc/yolo/db_dataset.py]
|
||||
D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs]
|
||||
D --> D2[Batch load field_results\nfrom database - batch 500]
|
||||
|
||||
D1 --> E[Create DBYOLODataset]
|
||||
D2 --> E
|
||||
|
||||
E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42]
|
||||
|
||||
F --> G[Export to YOLO Format]
|
||||
G --> G1[Copy images to\ntrain/val/test dirs]
|
||||
G --> G2[Generate .txt labels\nclass x_center y_center w h]
|
||||
G --> G3[Generate dataset.yaml\n+ classes.txt]
|
||||
G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1]
|
||||
|
||||
G1 --> H{--export-only?}
|
||||
G2 --> H
|
||||
G3 --> H
|
||||
G4 --> H
|
||||
|
||||
H -- Yes --> Z[Done - Dataset exported]
|
||||
H -- No --> I[Load YOLO Model]
|
||||
|
||||
I --> I1{--resume?}
|
||||
I1 -- Yes --> I2[Load last.pt checkpoint]
|
||||
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
|
||||
|
||||
I2 --> J[Configure Training]
|
||||
I3 --> J
|
||||
|
||||
J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup]
|
||||
J --> J2[imgsz=1280, pretrained=True]
|
||||
|
||||
J1 --> K[model.train\nUltralytics Training Loop]
|
||||
J2 --> K
|
||||
|
||||
K --> L[Training Outputs\nruns/train/name/]
|
||||
L --> L1[weights/best.pt\nweights/last.pt]
|
||||
L --> L2[results.csv + results.png\nTraining curves]
|
||||
L --> L3[PR curves, F1 curves\nConfusion matrix]
|
||||
|
||||
L1 --> M[Test Set Validation\nmodel.val split=test]
|
||||
M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95]
|
||||
|
||||
N --> O[Close DB Connection]
|
||||
|
||||
style A fill:#4a90d9,color:#fff
|
||||
style K fill:#e67e22,color:#fff
|
||||
style N fill:#27ae60,color:#fff
|
||||
style Z fill:#95a5a6,color:#fff
|
||||
1391  frontend/package-lock.json (generated; diff suppressed because it is too large)

frontend/package.json
@@ -6,27 +6,36 @@
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "vitest run --coverage"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"axios": "^1.6.7",
|
||||
"clsx": "^2.1.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"lucide-react": "^0.563.0",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"lucide-react": "^0.563.0",
|
||||
"recharts": "^3.7.0",
|
||||
"axios": "^1.6.7",
|
||||
"react-router-dom": "^6.22.0",
|
||||
"zustand": "^4.5.0",
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"clsx": "^2.1.0"
|
||||
"recharts": "^3.7.0",
|
||||
"zustand": "^4.5.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.2",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/node": "^22.14.0",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"@vitest/coverage-v8": "^4.0.18",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"jsdom": "^27.4.0",
|
||||
"postcss": "^8.4.35",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "~5.8.2",
|
||||
"vite": "^6.2.0",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"postcss": "^8.4.35"
|
||||
"vitest": "^4.0.18"
|
||||
}
|
||||
}
|
||||
|
||||
32  frontend/src/components/Badge.test.tsx (new file)
@@ -0,0 +1,32 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { Badge } from './Badge';
|
||||
import { DocumentStatus } from '../types';
|
||||
|
||||
describe('Badge', () => {
|
||||
it('renders Exported badge with check icon', () => {
|
||||
render(<Badge status="Exported" />);
|
||||
expect(screen.getByText('Exported')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Pending status', () => {
|
||||
render(<Badge status={DocumentStatus.PENDING} />);
|
||||
expect(screen.getByText('Pending')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Verified status', () => {
|
||||
render(<Badge status={DocumentStatus.VERIFIED} />);
|
||||
expect(screen.getByText('Verified')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Labeled status', () => {
|
||||
render(<Badge status={DocumentStatus.LABELED} />);
|
||||
expect(screen.getByText('Labeled')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Partial status with warning indicator', () => {
|
||||
render(<Badge status={DocumentStatus.PARTIAL} />);
|
||||
expect(screen.getByText('Partial')).toBeInTheDocument();
|
||||
expect(screen.getByText('!')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
38  frontend/src/components/Button.test.tsx (new file)
@@ -0,0 +1,38 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { Button } from './Button';
|
||||
|
||||
describe('Button', () => {
|
||||
it('renders children text', () => {
|
||||
render(<Button>Click me</Button>);
|
||||
expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onClick handler', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onClick = vi.fn();
|
||||
render(<Button onClick={onClick}>Click</Button>);
|
||||
await user.click(screen.getByRole('button'));
|
||||
expect(onClick).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('is disabled when disabled prop is set', () => {
|
||||
render(<Button disabled>Disabled</Button>);
|
||||
expect(screen.getByRole('button')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('applies variant styles', () => {
|
||||
const { rerender } = render(<Button variant="primary">Primary</Button>);
|
||||
const btn = screen.getByRole('button');
|
||||
expect(btn.className).toContain('bg-warm-text-secondary');
|
||||
|
||||
rerender(<Button variant="secondary">Secondary</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('border');
|
||||
});
|
||||
|
||||
it('applies size styles', () => {
|
||||
render(<Button size="sm">Small</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('h-8');
|
||||
});
|
||||
});
|
||||
1  frontend/tests/setup.ts (new file)
@@ -0,0 +1 @@
|
||||
import '@testing-library/jest-dom';
|
||||
19  frontend/vitest.config.ts (new file)
@@ -0,0 +1,19 @@
|
||||
/// <reference types="vitest/config" />
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
setupFiles: ['./tests/setup.ts'],
|
||||
include: ['src/**/*.test.{ts,tsx}', 'tests/**/*.test.{ts,tsx}'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'lcov'],
|
||||
include: ['src/**/*.{ts,tsx}'],
|
||||
exclude: ['src/**/*.test.{ts,tsx}', 'src/main.tsx'],
|
||||
},
|
||||
},
|
||||
});
|
||||
18  migrations/003_training_tasks.sql (new file)
@@ -0,0 +1,18 @@
|
||||
-- Training tasks table for async training job management.
|
||||
-- Inference service writes pending tasks; training service polls and executes.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS training_tasks (
|
||||
task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
config JSONB,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
scheduled_at TIMESTAMP WITH TIME ZONE,
|
||||
started_at TIMESTAMP WITH TIME ZONE,
|
||||
completed_at TIMESTAMP WITH TIME ZONE,
|
||||
error_message TEXT,
|
||||
model_path TEXT,
|
||||
metrics JSONB
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_status ON training_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_created ON training_tasks(created_at);
|
||||
39  migrations/004_training_datasets.sql (new file)
@@ -0,0 +1,39 @@
|
||||
-- Training Datasets Management
|
||||
-- Tracks dataset-document relationships and train/val/test splits
|
||||
|
||||
CREATE TABLE IF NOT EXISTS training_datasets (
|
||||
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'building',
|
||||
train_ratio FLOAT NOT NULL DEFAULT 0.8,
|
||||
val_ratio FLOAT NOT NULL DEFAULT 0.1,
|
||||
seed INTEGER NOT NULL DEFAULT 42,
|
||||
total_documents INTEGER NOT NULL DEFAULT 0,
|
||||
total_images INTEGER NOT NULL DEFAULT 0,
|
||||
total_annotations INTEGER NOT NULL DEFAULT 0,
|
||||
dataset_path VARCHAR(512),
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dataset_documents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
|
||||
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
|
||||
split VARCHAR(10) NOT NULL,
|
||||
page_count INTEGER NOT NULL DEFAULT 0,
|
||||
annotation_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(dataset_id, document_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
|
||||
|
||||
-- Add dataset_id to training_tasks
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
|
||||
25  packages/inference/Dockerfile (new file)
@@ -0,0 +1,25 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install shared package
|
||||
COPY packages/shared /app/packages/shared
|
||||
RUN pip install --no-cache-dir -e /app/packages/shared
|
||||
|
||||
# Install inference package
|
||||
COPY packages/inference /app/packages/inference
|
||||
RUN pip install --no-cache-dir -e /app/packages/inference
|
||||
|
||||
# Copy frontend (if needed)
|
||||
COPY frontend /app/frontend
|
||||
|
||||
WORKDIR /app/packages/inference
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]
|
||||
105  packages/inference/inference/azure/aci_trigger.py (new file)
@@ -0,0 +1,105 @@
|
||||
"""Trigger training jobs on Azure Container Instances."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Azure SDK is optional; only needed if using ACI trigger
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
|
||||
from azure.mgmt.containerinstance.models import (
|
||||
Container,
|
||||
ContainerGroup,
|
||||
EnvironmentVariable,
|
||||
GpuResource,
|
||||
ResourceRequests,
|
||||
ResourceRequirements,
|
||||
)
|
||||
|
||||
_AZURE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
_AZURE_SDK_AVAILABLE = False
|
||||
|
||||
|
||||
def start_training_container(task_id: str) -> str | None:
|
||||
"""
|
||||
Start an Azure Container Instance for a training task.
|
||||
|
||||
Returns the container group name if successful, None otherwise.
|
||||
Requires environment variables:
|
||||
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
|
||||
"""
|
||||
if not _AZURE_SDK_AVAILABLE:
|
||||
logger.warning(
|
||||
"Azure SDK not installed. Install azure-mgmt-containerinstance "
|
||||
"and azure-identity to use ACI trigger."
|
||||
)
|
||||
return None
|
||||
|
||||
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
|
||||
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
|
||||
image = os.environ.get(
|
||||
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
|
||||
)
|
||||
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
|
||||
location = os.environ.get("AZURE_LOCATION", "eastus")
|
||||
|
||||
if not subscription_id:
|
||||
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
|
||||
return None
|
||||
|
||||
credential = DefaultAzureCredential()
|
||||
client = ContainerInstanceManagementClient(credential, subscription_id)
|
||||
|
||||
container_name = f"training-{task_id[:8]}"
|
||||
|
||||
env_vars = [
|
||||
EnvironmentVariable(name="TASK_ID", value=task_id),
|
||||
]
|
||||
|
||||
# Pass DB connection securely
|
||||
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
|
||||
val = os.environ.get(var, "")
|
||||
if val:
|
||||
env_vars.append(EnvironmentVariable(name=var, value=val))
|
||||
|
||||
db_password = os.environ.get("DB_PASSWORD", "")
|
||||
if db_password:
|
||||
env_vars.append(
|
||||
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
|
||||
)
|
||||
|
||||
container = Container(
|
||||
name=container_name,
|
||||
image=image,
|
||||
resources=ResourceRequirements(
|
||||
requests=ResourceRequests(
|
||||
cpu=4,
|
||||
memory_in_gb=16,
|
||||
gpu=GpuResource(count=1, sku=gpu_sku),
|
||||
)
|
||||
),
|
||||
environment_variables=env_vars,
|
||||
command=[
|
||||
"python",
|
||||
"run_training.py",
|
||||
"--task-id",
|
||||
task_id,
|
||||
],
|
||||
)
|
||||
|
||||
group = ContainerGroup(
|
||||
location=location,
|
||||
containers=[container],
|
||||
os_type="Linux",
|
||||
restart_policy="Never",
|
||||
)
|
||||
|
||||
logger.info("Creating ACI container group: %s", container_name)
|
||||
client.container_groups.begin_create_or_update(
|
||||
resource_group, container_name, group
|
||||
)
|
||||
|
||||
return container_name
|
||||
@@ -10,8 +10,7 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from src.config import DEFAULT_DPI
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
|
||||
def main():
|
||||
@@ -91,7 +90,7 @@ def main():
|
||||
print(f"Processing {len(pdf_files)} PDF file(s)")
|
||||
print(f"Model: {model_path}")
|
||||
|
||||
from ..inference import InferencePipeline
|
||||
from inference.pipeline import InferencePipeline
|
||||
|
||||
# Initialize pipeline
|
||||
pipeline = InferencePipeline(
|
||||
@@ -13,9 +13,8 @@ from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.config import DEFAULT_DPI
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False) -> None:
|
||||
@@ -121,7 +120,7 @@ def main() -> None:
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Create config
|
||||
from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
|
||||
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
|
||||
|
||||
config = AppConfig(
|
||||
model=ModelConfig(
|
||||
@@ -142,7 +141,7 @@ def main() -> None:
|
||||
|
||||
# Create and run app
|
||||
import uvicorn
|
||||
from src.web.app import create_app
|
||||
from inference.web.app import create_app
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
0  packages/inference/inference/data/__init__.py (new file)
@@ -12,8 +12,8 @@ from uuid import UUID
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from src.data.database import get_session_context
|
||||
from src.data.admin_models import (
|
||||
from inference.data.database import get_session_context
|
||||
from inference.data.admin_models import (
|
||||
AdminToken,
|
||||
AdminDocument,
|
||||
AdminAnnotation,
|
||||
@@ -23,6 +23,8 @@ from src.data.admin_models import (
|
||||
BatchUploadFile,
|
||||
TrainingDocumentLink,
|
||||
AnnotationHistory,
|
||||
TrainingDataset,
|
||||
DatasetDocument,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -174,7 +176,7 @@ class AdminDB:
|
||||
|
||||
# For has_annotations filter, we need to join with annotations
|
||||
if has_annotations is not None:
|
||||
from src.data.admin_models import AdminAnnotation
|
||||
from inference.data.admin_models import AdminAnnotation
|
||||
|
||||
if has_annotations:
|
||||
# Documents WITH annotations
|
||||
@@ -200,7 +202,7 @@ class AdminDB:
|
||||
|
||||
# Apply has_annotations filter
|
||||
if has_annotations is not None:
|
||||
from src.data.admin_models import AdminAnnotation
|
||||
from inference.data.admin_models import AdminAnnotation
|
||||
|
||||
if has_annotations:
|
||||
statement = (
|
||||
@@ -456,6 +458,7 @@ class AdminDB:
|
||||
scheduled_at: datetime | None = None,
|
||||
cron_expression: str | None = None,
|
||||
is_recurring: bool = False,
|
||||
dataset_id: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new training task."""
|
||||
with get_session_context() as session:
|
||||
@@ -469,6 +472,7 @@ class AdminDB:
|
||||
cron_expression=cron_expression,
|
||||
is_recurring=is_recurring,
|
||||
status="scheduled" if scheduled_at else "pending",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
session.add(task)
|
||||
session.flush()
|
||||
@@ -1154,3 +1158,159 @@ class AdminDB:
|
||||
session.refresh(annotation)
|
||||
session.expunge(annotation)
|
||||
return annotation
|
||||
|
||||
# ==========================================================================
|
||||
# Training Dataset Operations
|
||||
# ==========================================================================
|
||||
|
||||
def create_dataset(
|
||||
self,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
train_ratio: float = 0.8,
|
||||
val_ratio: float = 0.1,
|
||||
seed: int = 42,
|
||||
) -> TrainingDataset:
|
||||
"""Create a new training dataset."""
|
||||
with get_session_context() as session:
|
||||
dataset = TrainingDataset(
|
||||
name=name,
|
||||
description=description,
|
||||
train_ratio=train_ratio,
|
||||
val_ratio=val_ratio,
|
||||
seed=seed,
|
||||
)
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
session.refresh(dataset)
|
||||
session.expunge(dataset)
|
||||
return dataset
|
||||
|
||||
def get_dataset(self, dataset_id: str | UUID) -> TrainingDataset | None:
|
||||
"""Get a dataset by ID."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if dataset:
|
||||
session.expunge(dataset)
|
||||
return dataset
|
||||
|
||||
def get_datasets(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[TrainingDataset], int]:
|
||||
"""List datasets with optional status filter."""
|
||||
with get_session_context() as session:
|
||||
query = select(TrainingDataset)
|
||||
count_query = select(func.count()).select_from(TrainingDataset)
|
||||
if status:
|
||||
query = query.where(TrainingDataset.status == status)
|
||||
count_query = count_query.where(TrainingDataset.status == status)
|
||||
total = session.exec(count_query).one()
|
||||
datasets = session.exec(
|
||||
query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
|
||||
).all()
|
||||
for d in datasets:
|
||||
session.expunge(d)
|
||||
return list(datasets), total
|
||||
|
||||
def update_dataset_status(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
status: str,
|
||||
error_message: str | None = None,
|
||||
total_documents: int | None = None,
|
||||
total_images: int | None = None,
|
||||
total_annotations: int | None = None,
|
||||
dataset_path: str | None = None,
|
||||
) -> None:
|
||||
"""Update dataset status and optional totals."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return
|
||||
dataset.status = status
|
||||
dataset.updated_at = datetime.utcnow()
|
||||
if error_message is not None:
|
||||
dataset.error_message = error_message
|
||||
if total_documents is not None:
|
||||
dataset.total_documents = total_documents
|
||||
if total_images is not None:
|
||||
dataset.total_images = total_images
|
||||
if total_annotations is not None:
|
||||
dataset.total_annotations = total_annotations
|
||||
if dataset_path is not None:
|
||||
dataset.dataset_path = dataset_path
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
def add_dataset_documents(
|
||||
self,
|
||||
dataset_id: str | UUID,
|
||||
documents: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Batch insert documents into a dataset.
|
||||
|
||||
Each dict: {document_id, split, page_count, annotation_count}
|
||||
"""
|
||||
with get_session_context() as session:
|
||||
for doc in documents:
|
||||
dd = DatasetDocument(
|
||||
dataset_id=UUID(str(dataset_id)),
|
||||
document_id=UUID(str(doc["document_id"])),
|
||||
split=doc["split"],
|
||||
page_count=doc.get("page_count", 0),
|
||||
annotation_count=doc.get("annotation_count", 0),
|
||||
)
|
||||
session.add(dd)
|
||||
session.commit()
|
||||
|
||||
def get_dataset_documents(
|
||||
self, dataset_id: str | UUID
|
||||
) -> list[DatasetDocument]:
|
||||
"""Get all documents in a dataset."""
|
||||
with get_session_context() as session:
|
||||
results = session.exec(
|
||||
select(DatasetDocument)
|
||||
.where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
|
||||
).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_documents_by_ids(
|
||||
self, document_ids: list[str]
|
||||
) -> list[AdminDocument]:
|
||||
"""Get documents by list of IDs."""
|
||||
with get_session_context() as session:
|
||||
uuids = [UUID(str(did)) for did in document_ids]
|
||||
results = session.exec(
|
||||
select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
|
||||
).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def get_annotations_for_document(
|
||||
self, document_id: str | UUID
|
||||
) -> list[AdminAnnotation]:
|
||||
"""Get all annotations for a document."""
|
||||
with get_session_context() as session:
|
||||
results = session.exec(
|
||||
select(AdminAnnotation)
|
||||
.where(AdminAnnotation.document_id == UUID(str(document_id)))
|
||||
).all()
|
||||
for r in results:
|
||||
session.expunge(r)
|
||||
return list(results)
|
||||
|
||||
def delete_dataset(self, dataset_id: str | UUID) -> bool:
|
||||
"""Delete a dataset and its document links (CASCADE)."""
|
||||
with get_session_context() as session:
|
||||
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
|
||||
if not dataset:
|
||||
return False
|
||||
session.delete(dataset)
|
||||
session.commit()
|
||||
return True
|
||||
@@ -131,6 +131,7 @@ class TrainingTask(SQLModel, table=True):
    # Status: pending, scheduled, running, completed, failed, cancelled
    task_type: str = Field(default="train", max_length=20)
    # Task type: train, finetune
    dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
    # Training configuration
    config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Schedule settings
@@ -225,6 +226,42 @@ class BatchUploadFile(SQLModel, table=True):
# =============================================================================


class TrainingDataset(SQLModel, table=True):
    """Training dataset containing selected documents with train/val/test splits."""

    __tablename__ = "training_datasets"

    dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="building", max_length=20, index=True)
    # Status: building, ready, training, archived, failed
    train_ratio: float = Field(default=0.8)
    val_ratio: float = Field(default=0.1)
    seed: int = Field(default=42)
    total_documents: int = Field(default=0)
    total_images: int = Field(default=0)
    total_annotations: int = Field(default=0)
    dataset_path: str | None = Field(default=None, max_length=512)
    error_message: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)


class DatasetDocument(SQLModel, table=True):
    """Junction table linking datasets to documents with split assignment."""

    __tablename__ = "dataset_documents"

    id: UUID = Field(default_factory=uuid4, primary_key=True)
    dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    split: str = Field(max_length=10)  # train, val, test
    page_count: int = Field(default=0)
    annotation_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=datetime.utcnow)


class TrainingDocumentLink(SQLModel, table=True):
    """Junction table linking training tasks to documents."""

@@ -336,4 +373,35 @@ class TrainingTaskRead(SQLModel):
    error_message: str | None
    result_metrics: dict[str, Any] | None
    model_path: str | None
    dataset_id: UUID | None
    created_at: datetime


class TrainingDatasetRead(SQLModel):
    """Training dataset response model."""

    dataset_id: UUID
    name: str
    description: str | None
    status: str
    train_ratio: float
    val_ratio: float
    seed: int
    total_documents: int
    total_images: int
    total_annotations: int
    dataset_path: str | None
    error_message: str | None
    created_at: datetime
    updated_at: datetime


class DatasetDocumentRead(SQLModel):
    """Dataset document response model."""

    id: UUID
    dataset_id: UUID
    document_id: UUID
    split: str
    page_count: int
    annotation_count: int
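For illustration, a small sketch of how the new tables could be populated directly with SQLModel; the import path of the models and the session helper are assumptions inferred from the imports elsewhere in this commit.

```python
from uuid import uuid4

from inference.data.database import get_session_context      # path as used above
from inference.data.admin_models import TrainingDataset, DatasetDocument  # assumed module

with get_session_context() as session:
    ds = TrainingDataset(name="demo-dataset", train_ratio=0.8, val_ratio=0.1, seed=42)
    session.add(ds)
    session.commit()

    # Link one document to the dataset. In practice document_id must reference an
    # existing admin_documents row; uuid4() here is only a placeholder.
    session.add(DatasetDocument(dataset_id=ds.dataset_id, document_id=uuid4(), split="train"))
    session.commit()
```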
@@ -12,8 +12,8 @@ from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select

from src.data.database import get_session_context, create_db_and_tables, close_engine
from src.data.models import ApiKey, AsyncRequest, RateLimitEvent
from inference.data.database import get_session_context, create_db_and_tables, close_engine
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent

logger = logging.getLogger(__name__)

@@ -13,8 +13,7 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine

import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import get_db_connection_string
from shared.config import get_db_connection_string

logger = logging.getLogger(__name__)

@@ -52,8 +51,8 @@ def get_engine():

def create_db_and_tables() -> None:
    """Create all database tables."""
    from src.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
    from src.data.admin_models import (  # noqa: F401
    from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
    from inference.data.admin_models import (  # noqa: F401
        AdminToken,
        AdminDocument,
        AdminAnnotation,
@@ -92,7 +92,7 @@ constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.

Example:
    from src.inference.constants import DEFAULT_CONFIDENCE_THRESHOLD
    from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD

    detector = YOLODetector(
        model_path="model.pt",
@@ -17,7 +17,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List

from src.exceptions import CustomerNumberParseError
from shared.exceptions import CustomerNumberParseError


@dataclass
@@ -4,7 +4,7 @@ Field Extractor Module
Extracts and validates field values from detected regions.

This module is used during inference to extract values from OCR text.
It uses shared utilities from src.utils for text cleaning and validation.
It uses shared utilities from shared.utils for text cleaning and validation.

Enhanced features:
- Multi-source fusion with confidence weighting
@@ -24,10 +24,10 @@ from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD

# Import shared utilities for text cleaning and validation
from src.utils.text_cleaner import TextCleaner
from src.utils.validators import FieldValidators
from src.utils.fuzzy_matcher import FuzzyMatcher
from src.utils.ocr_corrections import OCRCorrections
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.fuzzy_matcher import FuzzyMatcher
from shared.utils.ocr_corrections import OCRCorrections

# Import new unified parsers
from .payment_line_parser import PaymentLineParser
@@ -104,7 +104,7 @@ class FieldExtractor:
    def ocr_engine(self):
        """Lazy-load OCR engine only when needed."""
        if self._ocr_engine is None:
            from ..ocr import OCREngine
            from shared.ocr import OCREngine
            self._ocr_engine = OCREngine(lang=self.ocr_lang)
        return self._ocr_engine

@@ -21,7 +21,7 @@ import logging
from dataclasses import dataclass
from typing import Optional

from src.exceptions import PaymentLineParseError
from shared.exceptions import PaymentLineParseError


@dataclass
@@ -144,7 +144,7 @@ class InferencePipeline:
        Returns:
            InferenceResult with extracted fields
        """
        from ..pdf.renderer import render_pdf_to_images
        from shared.pdf.renderer import render_pdf_to_images
        from PIL import Image
        import io
        import numpy as np
@@ -381,8 +381,8 @@ class InferencePipeline:

    def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
        """Run full-page OCR fallback."""
        from ..pdf.renderer import render_pdf_to_images
        from ..ocr import OCREngine
        from shared.pdf.renderer import render_pdf_to_images
        from shared.ocr import OCREngine
        from PIL import Image
        import io
        import numpy as np
@@ -189,7 +189,7 @@ class YOLODetector:
        Returns:
            Dict mapping page number to list of detections
        """
        from ..pdf.renderer import render_pdf_to_images
        from shared.pdf.renderer import render_pdf_to_images
        from PIL import Image
        import io

@@ -16,7 +16,7 @@ from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values

from src.config import DEFAULT_DPI
from shared.config import DEFAULT_DPI


@dataclass
8 packages/inference/inference/web/admin_routes_new.py Normal file
@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py

DEPRECATED: Import from inference.web.api.v1.admin.documents instead.
"""
from inference.web.api.v1.admin.documents import *

__all__ = ["create_admin_router"]
0 packages/inference/inference/web/api/__init__.py Normal file
0 packages/inference/inference/web/api/v1/__init__.py Normal file
19 packages/inference/inference/web/api/v1/admin/__init__.py Normal file
@@ -0,0 +1,19 @@
"""
Admin API v1

Document management, annotations, and training endpoints.
"""

from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.api.v1.admin.training import create_training_router

__all__ = [
    "create_annotation_router",
    "create_auth_router",
    "create_documents_router",
    "create_locks_router",
    "create_training_router",
]
@@ -12,11 +12,11 @@ from uuid import UUID
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse

from src.data.admin_db import AdminDB
from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.services.autolabel import get_auto_label_service
from src.web.schemas.admin import (
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
from inference.web.schemas.admin import (
    AnnotationCreate,
    AnnotationItem,
    AnnotationListResponse,
@@ -31,7 +31,7 @@ from src.web.schemas.admin import (
    AutoLabelResponse,
    BoundingBox,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

@@ -10,12 +10,12 @@ from datetime import datetime, timedelta

from fastapi import APIRouter

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    AdminTokenCreate,
    AdminTokenResponse,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

@@ -11,9 +11,9 @@ from uuid import UUID

from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from src.web.config import DEFAULT_DPI, StorageConfig
from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    AnnotationItem,
    AnnotationSource,
    AutoLabelStatus,
@@ -27,7 +27,7 @@ from src.web.schemas.admin import (
    ModelMetrics,
    TrainingHistoryItem,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

@@ -142,8 +142,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            raise HTTPException(status_code=500, detail="Failed to save file")

        # Update file path in database
        from src.data.database import get_session_context
        from src.data.admin_models import AdminDocument
        from inference.data.database import get_session_context
        from inference.data.admin_models import AdminDocument
        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if doc:
@@ -520,7 +520,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        # If marking as labeled, save annotations to PostgreSQL DocumentDB
        db_save_result = None
        if status == "labeled":
            from src.web.services.db_autolabel import save_manual_annotations_to_document_db
            from inference.web.services.db_autolabel import save_manual_annotations_to_document_db

            # Get all annotations for this document
            annotations = db.get_annotations_for_document(document_id)
@@ -10,12 +10,12 @@ from uuid import UUID

from fastapi import APIRouter, HTTPException, Query

from src.web.core.auth import AdminTokenDep, AdminDBDep
from src.web.schemas.admin import (
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    AnnotationLockRequest,
    AnnotationLockResponse,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)

@@ -0,0 +1,28 @@
"""
Admin Training API Routes

FastAPI endpoints for training task management and scheduling.
"""

from fastapi import APIRouter

from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes


def create_training_router() -> APIRouter:
    """Create training API router."""
    router = APIRouter(prefix="/admin/training", tags=["Admin Training"])

    register_task_routes(router)
    register_document_routes(router)
    register_export_routes(router)
    register_dataset_routes(router)

    return router


__all__ = ["create_training_router", "_validate_uuid"]
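A minimal sketch of how this factory might be mounted; the `/api/v1` prefix is an assumption inferred from the model download URL used later in this commit.

```python
from fastapi import FastAPI

from inference.web.api.v1.admin.training import create_training_router

app = FastAPI()
# The router already carries prefix="/admin/training", so endpoints end up at
# paths like /api/v1/admin/training/tasks and /api/v1/admin/training/datasets.
app.include_router(create_training_router(), prefix="/api/v1")
```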
@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""

from uuid import UUID

from fastapi import HTTPException


def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
    try:
        UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {name} format. Must be a valid UUID.",
        )
@@ -0,0 +1,209 @@
"""Training Dataset Endpoints."""

import logging
from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    DatasetCreateRequest,
    DatasetDetailResponse,
    DatasetDocumentItem,
    DatasetListItem,
    DatasetListResponse,
    DatasetResponse,
    DatasetTrainRequest,
    TrainingStatus,
    TrainingTaskResponse,
)

from ._utils import _validate_uuid

logger = logging.getLogger(__name__)


def register_dataset_routes(router: APIRouter) -> None:
    """Register dataset endpoints on the router."""

    @router.post(
        "/datasets",
        response_model=DatasetResponse,
        summary="Create training dataset",
        description="Create a dataset from selected documents with train/val/test splits.",
    )
    async def create_dataset(
        request: DatasetCreateRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DatasetResponse:
        """Create a training dataset from document IDs."""
        from pathlib import Path
        from inference.web.services.dataset_builder import DatasetBuilder

        dataset = db.create_dataset(
            name=request.name,
            description=request.description,
            train_ratio=request.train_ratio,
            val_ratio=request.val_ratio,
            seed=request.seed,
        )

        builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
        try:
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
                document_ids=request.document_ids,
                train_ratio=request.train_ratio,
                val_ratio=request.val_ratio,
                seed=request.seed,
                admin_images_dir=Path("data/admin_images"),
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))

        return DatasetResponse(
            dataset_id=str(dataset.dataset_id),
            name=dataset.name,
            status="ready",
            message="Dataset created successfully",
        )

    @router.get(
        "/datasets",
        response_model=DatasetListResponse,
        summary="List datasets",
    )
    async def list_datasets(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[str | None, Query(description="Filter by status")] = None,
        limit: Annotated[int, Query(ge=1, le=100)] = 20,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> DatasetListResponse:
        """List training datasets."""
        datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
        return DatasetListResponse(
            total=total,
            limit=limit,
            offset=offset,
            datasets=[
                DatasetListItem(
                    dataset_id=str(d.dataset_id),
                    name=d.name,
                    description=d.description,
                    status=d.status,
                    total_documents=d.total_documents,
                    total_images=d.total_images,
                    total_annotations=d.total_annotations,
                    created_at=d.created_at,
                )
                for d in datasets
            ],
        )

    @router.get(
        "/datasets/{dataset_id}",
        response_model=DatasetDetailResponse,
        summary="Get dataset detail",
    )
    async def get_dataset(
        dataset_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> DatasetDetailResponse:
        """Get dataset details with document list."""
        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")

        docs = db.get_dataset_documents(dataset_id)
        return DatasetDetailResponse(
            dataset_id=str(dataset.dataset_id),
            name=dataset.name,
            description=dataset.description,
            status=dataset.status,
            train_ratio=dataset.train_ratio,
            val_ratio=dataset.val_ratio,
            seed=dataset.seed,
            total_documents=dataset.total_documents,
            total_images=dataset.total_images,
            total_annotations=dataset.total_annotations,
            dataset_path=dataset.dataset_path,
            error_message=dataset.error_message,
            documents=[
                DatasetDocumentItem(
                    document_id=str(d.document_id),
                    split=d.split,
                    page_count=d.page_count,
                    annotation_count=d.annotation_count,
                )
                for d in docs
            ],
            created_at=dataset.created_at,
            updated_at=dataset.updated_at,
        )

    @router.delete(
        "/datasets/{dataset_id}",
        summary="Delete dataset",
    )
    async def delete_dataset(
        dataset_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> dict:
        """Delete a dataset and its files."""
        import shutil
        from pathlib import Path

        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")

        if dataset.dataset_path:
            dataset_dir = Path(dataset.dataset_path)
            if dataset_dir.exists():
                shutil.rmtree(dataset_dir)

        db.delete_dataset(dataset_id)
        return {"message": "Dataset deleted"}

    @router.post(
        "/datasets/{dataset_id}/train",
        response_model=TrainingTaskResponse,
        summary="Start training from dataset",
    )
    async def train_from_dataset(
        dataset_id: str,
        request: DatasetTrainRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Create a training task from a dataset."""
        _validate_uuid(dataset_id, "dataset_id")
        dataset = db.get_dataset(dataset_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="Dataset not found")
        if dataset.status != "ready":
            raise HTTPException(
                status_code=400,
                detail=f"Dataset is not ready (status: {dataset.status})",
            )

        config_dict = request.config.model_dump()
        task_id = db.create_training_task(
            admin_token=admin_token,
            name=request.name,
            task_type="train",
            config=config_dict,
            dataset_id=str(dataset.dataset_id),
        )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.PENDING,
            message="Training task created from dataset",
        )
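A hedged client-side sketch of the dataset flow introduced above (create a dataset, then start training from it). The base URL and the admin auth header name are assumptions; the JSON field names follow the request schemas referenced above, but the full required fields of `DatasetCreateRequest`/`DatasetTrainRequest` are not shown in this diff.

```python
import requests

BASE = "http://localhost:8000/api/v1/admin/training"  # assumed mount point
HEADERS = {"X-Admin-Token": "<token>"}                 # header name is an assumption

created = requests.post(
    f"{BASE}/datasets",
    json={
        "name": "invoices-v1",
        "document_ids": ["<labeled-document-uuid>"],
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
    },
    headers=HEADERS,
).json()

# Once the dataset reports status "ready", kick off training from it.
requests.post(
    f"{BASE}/datasets/{created['dataset_id']}/train",
    json={"name": "train-from-dataset", "config": {"epochs": 100}},
    headers=HEADERS,
)
```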
@@ -0,0 +1,211 @@
"""Training Documents and Models Endpoints."""

import logging
from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    ModelMetrics,
    TrainingDocumentItem,
    TrainingDocumentsResponse,
    TrainingModelItem,
    TrainingModelsResponse,
    TrainingStatus,
)
from inference.web.schemas.common import ErrorResponse

from ._utils import _validate_uuid

logger = logging.getLogger(__name__)


def register_document_routes(router: APIRouter) -> None:
    """Register training document and model endpoints on the router."""

    @router.get(
        "/documents",
        response_model=TrainingDocumentsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get documents for training",
        description="Get labeled documents available for training with filtering options.",
    )
    async def get_training_documents(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        has_annotations: Annotated[
            bool,
            Query(description="Only include documents with annotations"),
        ] = True,
        min_annotation_count: Annotated[
            int | None,
            Query(ge=1, description="Minimum annotation count"),
        ] = None,
        exclude_used_in_training: Annotated[
            bool,
            Query(description="Exclude documents already used in training"),
        ] = False,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 100,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingDocumentsResponse:
        """Get documents available for training."""
        documents, total = db.get_documents_for_training(
            admin_token=admin_token,
            status="labeled",
            has_annotations=has_annotations,
            min_annotation_count=min_annotation_count,
            exclude_used_in_training=exclude_used_in_training,
            limit=limit,
            offset=offset,
        )

        items = []
        for doc in documents:
            annotations = db.get_annotations_for_document(str(doc.document_id))

            sources = {"manual": 0, "auto": 0}
            for ann in annotations:
                if ann.source in sources:
                    sources[ann.source] += 1

            training_links = db.get_document_training_tasks(doc.document_id)
            used_in_training = [str(link.task_id) for link in training_links]

            items.append(
                TrainingDocumentItem(
                    document_id=str(doc.document_id),
                    filename=doc.filename,
                    annotation_count=len(annotations),
                    annotation_sources=sources,
                    used_in_training=used_in_training,
                    last_modified=doc.updated_at,
                )
            )

        return TrainingDocumentsResponse(
            total=total,
            limit=limit,
            offset=offset,
            documents=items,
        )

    @router.get(
        "/models/{task_id}/download",
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Model not found"},
        },
        summary="Download trained model",
        description="Download trained model weights file.",
    )
    async def download_model(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ):
        """Download trained model."""
        from fastapi.responses import FileResponse
        from pathlib import Path

        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        if not task.model_path:
            raise HTTPException(
                status_code=404,
                detail="Model file not available for this task",
            )

        model_path = Path(task.model_path)
        if not model_path.exists():
            raise HTTPException(
                status_code=404,
                detail="Model file not found on disk",
            )

        return FileResponse(
            path=str(model_path),
            media_type="application/octet-stream",
            filename=f"{task.name}_model.pt",
        )

    @router.get(
        "/models",
        response_model=TrainingModelsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Get trained models",
        description="Get list of trained models with metrics and download links.",
    )
    async def get_training_models(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status (completed, failed, etc.)"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingModelsResponse:
        """Get list of trained models."""
        tasks, total = db.get_training_tasks_by_token(
            admin_token=admin_token,
            status=status if status else "completed",
            limit=limit,
            offset=offset,
        )

        items = []
        for task in tasks:
            metrics = ModelMetrics(
                mAP=task.metrics_mAP,
                precision=task.metrics_precision,
                recall=task.metrics_recall,
            )

            download_url = None
            if task.model_path and task.status == "completed":
                download_url = f"/api/v1/admin/training/models/{task.task_id}/download"

            items.append(
                TrainingModelItem(
                    task_id=str(task.task_id),
                    name=task.name,
                    status=TrainingStatus(task.status),
                    document_count=task.document_count,
                    created_at=task.created_at,
                    completed_at=task.completed_at,
                    metrics=metrics,
                    model_path=task.model_path,
                    download_url=download_url,
                )
            )

        return TrainingModelsResponse(
            total=total,
            limit=limit,
            offset=offset,
            models=items,
        )
121 packages/inference/inference/web/api/v1/admin/training/export.py Normal file
@@ -0,0 +1,121 @@
"""Training Export Endpoints."""

import logging
from datetime import datetime

from fastapi import APIRouter, HTTPException

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    ExportRequest,
    ExportResponse,
)
from inference.web.schemas.common import ErrorResponse

logger = logging.getLogger(__name__)


def register_export_routes(router: APIRouter) -> None:
    """Register export endpoints on the router."""

    @router.post(
        "/export",
        response_model=ExportResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Export annotations",
        description="Export annotations in YOLO format for training.",
    )
    async def export_annotations(
        request: ExportRequest,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> ExportResponse:
        """Export annotations for training."""
        from pathlib import Path
        import shutil

        if request.format not in ("yolo", "coco", "voc"):
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported export format: {request.format}",
            )

        documents = db.get_labeled_documents_for_export(admin_token)

        if not documents:
            raise HTTPException(
                status_code=400,
                detail="No labeled documents available for export",
            )

        export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        export_dir.mkdir(parents=True, exist_ok=True)

        (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
        (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)

        total_docs = len(documents)
        train_count = int(total_docs * request.split_ratio)
        train_docs = documents[:train_count]
        val_docs = documents[train_count:]

        total_images = 0
        total_annotations = 0

        for split, docs in [("train", train_docs), ("val", val_docs)]:
            for doc in docs:
                annotations = db.get_annotations_for_document(str(doc.document_id))

                if not annotations:
                    continue

                for page_num in range(1, doc.page_count + 1):
                    page_annotations = [a for a in annotations if a.page_number == page_num]

                    if not page_annotations and not request.include_images:
                        continue

                    src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
                    if not src_image.exists():
                        continue

                    image_name = f"{doc.document_id}_page{page_num}.png"
                    dst_image = export_dir / "images" / split / image_name
                    shutil.copy(src_image, dst_image)
                    total_images += 1

                    label_name = f"{doc.document_id}_page{page_num}.txt"
                    label_path = export_dir / "labels" / split / label_name

                    with open(label_path, "w") as f:
                        for ann in page_annotations:
                            line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
                            f.write(line)
                            total_annotations += 1

        from inference.data.admin_models import FIELD_CLASSES

        yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val

nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
        (export_dir / "data.yaml").write_text(yaml_content)

        return ExportResponse(
            status="completed",
            export_path=str(export_dir),
            total_images=total_images,
            total_annotations=total_annotations,
            train_count=len(train_docs),
            val_count=len(val_docs),
            message=f"Exported {total_images} images with {total_annotations} annotations",
        )
263 packages/inference/inference/web/api/v1/admin/training/tasks.py Normal file
@@ -0,0 +1,263 @@
"""Training Task Endpoints."""

import logging
from typing import Annotated

from fastapi import APIRouter, HTTPException, Query

from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
    TrainingLogItem,
    TrainingLogsResponse,
    TrainingStatus,
    TrainingTaskCreate,
    TrainingTaskDetailResponse,
    TrainingTaskItem,
    TrainingTaskListResponse,
    TrainingTaskResponse,
    TrainingType,
)
from inference.web.schemas.common import ErrorResponse

from ._utils import _validate_uuid

logger = logging.getLogger(__name__)


def register_task_routes(router: APIRouter) -> None:
    """Register training task endpoints on the router."""

    @router.post(
        "/tasks",
        response_model=TrainingTaskResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid request"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="Create training task",
        description="Create a new training task.",
    )
    async def create_training_task(
        request: TrainingTaskCreate,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Create a new training task."""
        config_dict = request.config.model_dump() if request.config else {}

        task_id = db.create_training_task(
            admin_token=admin_token,
            name=request.name,
            task_type=request.task_type.value,
            description=request.description,
            config=config_dict,
            scheduled_at=request.scheduled_at,
            cron_expression=request.cron_expression,
            is_recurring=bool(request.cron_expression),
        )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
            message="Training task created successfully",
        )

    @router.get(
        "/tasks",
        response_model=TrainingTaskListResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
        },
        summary="List training tasks",
        description="List all training tasks.",
    )
    async def list_training_tasks(
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        status: Annotated[
            str | None,
            Query(description="Filter by status"),
        ] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
        ] = 20,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingTaskListResponse:
        """List training tasks."""
        valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
        if status and status not in valid_statuses:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
            )

        tasks, total = db.get_training_tasks_by_token(
            admin_token=admin_token,
            status=status,
            limit=limit,
            offset=offset,
        )

        items = [
            TrainingTaskItem(
                task_id=str(task.task_id),
                name=task.name,
                task_type=TrainingType(task.task_type),
                status=TrainingStatus(task.status),
                scheduled_at=task.scheduled_at,
                is_recurring=task.is_recurring,
                started_at=task.started_at,
                completed_at=task.completed_at,
                created_at=task.created_at,
            )
            for task in tasks
        ]

        return TrainingTaskListResponse(
            total=total,
            limit=limit,
            offset=offset,
            tasks=items,
        )

    @router.get(
        "/tasks/{task_id}",
        response_model=TrainingTaskDetailResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
        },
        summary="Get training task detail",
        description="Get training task details.",
    )
    async def get_training_task(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskDetailResponse:
        """Get training task details."""
        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        return TrainingTaskDetailResponse(
            task_id=str(task.task_id),
            name=task.name,
            description=task.description,
            task_type=TrainingType(task.task_type),
            status=TrainingStatus(task.status),
            config=task.config,
            scheduled_at=task.scheduled_at,
            cron_expression=task.cron_expression,
            is_recurring=task.is_recurring,
            started_at=task.started_at,
            completed_at=task.completed_at,
            error_message=task.error_message,
            result_metrics=task.result_metrics,
            model_path=task.model_path,
            created_at=task.created_at,
        )

    @router.post(
        "/tasks/{task_id}/cancel",
        response_model=TrainingTaskResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
            409: {"model": ErrorResponse, "description": "Cannot cancel task"},
        },
        summary="Cancel training task",
        description="Cancel a pending or scheduled training task.",
    )
    async def cancel_training_task(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
    ) -> TrainingTaskResponse:
        """Cancel a training task."""
        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        if task.status not in ("pending", "scheduled"):
            raise HTTPException(
                status_code=409,
                detail=f"Cannot cancel task with status: {task.status}",
            )

        success = db.cancel_training_task(task_id)
        if not success:
            raise HTTPException(
                status_code=500,
                detail="Failed to cancel training task",
            )

        return TrainingTaskResponse(
            task_id=task_id,
            status=TrainingStatus.CANCELLED,
            message="Training task cancelled successfully",
        )

    @router.get(
        "/tasks/{task_id}/logs",
        response_model=TrainingLogsResponse,
        responses={
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Task not found"},
        },
        summary="Get training logs",
        description="Get training task logs.",
    )
    async def get_training_logs(
        task_id: str,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
        limit: Annotated[
            int,
            Query(ge=1, le=500, description="Maximum logs to return"),
        ] = 100,
        offset: Annotated[
            int,
            Query(ge=0, description="Offset"),
        ] = 0,
    ) -> TrainingLogsResponse:
        """Get training logs."""
        _validate_uuid(task_id, "task_id")

        task = db.get_training_task_by_token(task_id, admin_token)
        if task is None:
            raise HTTPException(
                status_code=404,
                detail="Training task not found or does not belong to this token",
            )

        logs = db.get_training_logs(task_id, limit, offset)

        items = [
            TrainingLogItem(
                level=log.level,
                message=log.message,
                details=log.details,
                created_at=log.created_at,
            )
            for log in logs
        ]

        return TrainingLogsResponse(
            task_id=task_id,
            logs=items,
        )
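And a matching client-side sketch of the task lifecycle exposed above (create, inspect, read logs, cancel); the base URL and auth header are the same assumptions as in the dataset example earlier.

```python
import requests

BASE = "http://localhost:8000/api/v1/admin/training"  # assumed mount point
HEADERS = {"X-Admin-Token": "<token>"}                 # header name is an assumption

task = requests.post(
    f"{BASE}/tasks",
    json={"name": "nightly-train", "task_type": "train"},
    headers=HEADERS,
).json()

detail = requests.get(f"{BASE}/tasks/{task['task_id']}", headers=HEADERS).json()
logs = requests.get(f"{BASE}/tasks/{task['task_id']}/logs", headers=HEADERS).json()
print(detail["status"], len(logs["logs"]))

# Pending or scheduled tasks can still be cancelled.
requests.post(f"{BASE}/tasks/{task['task_id']}/cancel", headers=HEADERS)
```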
@@ -14,10 +14,10 @@ from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse

from src.data.admin_db import AdminDB
from src.web.core.auth import validate_admin_token, get_admin_db
from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from src.web.workers.batch_queue import BatchTask, get_batch_queue
from inference.data.admin_db import AdminDB
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
from inference.web.workers.batch_queue import BatchTask, get_batch_queue

logger = logging.getLogger(__name__)

16 packages/inference/inference/web/api/v1/public/__init__.py Normal file
@@ -0,0 +1,16 @@
"""
Public API v1

Customer-facing endpoints for inference, async processing, and labeling.
"""

from inference.web.api.v1.public.inference import create_inference_router
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.api.v1.public.labeling import create_labeling_router

__all__ = [
    "create_inference_router",
    "create_async_router",
    "set_async_service",
    "create_labeling_router",
]
@@ -11,13 +11,13 @@ from uuid import UUID

from fastapi import APIRouter, File, HTTPException, Query, UploadFile

from src.web.dependencies import (
from inference.web.dependencies import (
    ApiKeyDep,
    AsyncDBDep,
    PollRateLimitDep,
    SubmitRateLimitDep,
)
from src.web.schemas.inference import (
from inference.web.schemas.inference import (
    AsyncRequestItem,
    AsyncRequestsListResponse,
    AsyncResultResponse,
@@ -27,7 +27,7 @@ from src.web.schemas.inference import (
    DetectionResult,
    InferenceResult,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse


def _validate_request_id(request_id: str) -> None:
@@ -15,17 +15,17 @@ from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse

from src.web.schemas.inference import (
from inference.web.schemas.inference import (
    DetectionResult,
    HealthResponse,
    InferenceResponse,
    InferenceResult,
)
from src.web.schemas.common import ErrorResponse
from inference.web.schemas.common import ErrorResponse

if TYPE_CHECKING:
    from src.web.services import InferenceService
    from src.web.config import StorageConfig
    from inference.web.services import InferenceService
    from inference.web.config import StorageConfig

logger = logging.getLogger(__name__)

@@ -13,13 +13,13 @@ from typing import TYPE_CHECKING

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status

from src.data.admin_db import AdminDB
from src.web.schemas.labeling import PreLabelResponse
from src.web.schemas.common import ErrorResponse
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse

if TYPE_CHECKING:
    from src.web.services import InferenceService
    from src.web.config import StorageConfig
    from inference.web.services import InferenceService
    from inference.web.config import StorageConfig

logger = logging.getLogger(__name__)

@@ -17,10 +17,10 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse

from .config import AppConfig, default_config
from src.web.services import InferenceService
from inference.web.services import InferenceService

# Public API imports
from src.web.api.v1.public import (
from inference.web.api.v1.public import (
    create_inference_router,
    create_async_router,
    set_async_service,
@@ -28,28 +28,28 @@ from src.web.api.v1.public import (
)

# Async processing imports
from src.data.async_request_db import AsyncRequestDB
from src.web.workers.async_queue import AsyncTaskQueue
from src.web.services.async_processing import AsyncProcessingService
from src.web.dependencies import init_dependencies
from src.web.core.rate_limiter import RateLimiter
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.dependencies import init_dependencies
from inference.web.core.rate_limiter import RateLimiter

# Admin API imports
from src.web.api.v1.admin import (
from inference.web.api.v1.admin import (
    create_annotation_router,
    create_auth_router,
    create_documents_router,
    create_locks_router,
    create_training_router,
)
from src.web.core.scheduler import start_scheduler, stop_scheduler
from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
from inference.web.core.scheduler import start_scheduler, stop_scheduler
from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler

# Batch upload imports
from src.web.api.v1.batch.routes import router as batch_upload_router
from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from src.web.services.batch_upload import BatchUploadService
from src.data.admin_db import AdminDB
from inference.web.api.v1.batch.routes import router as batch_upload_router
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from inference.data.admin_db import AdminDB

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from src.config import DEFAULT_DPI, PATHS
from shared.config import DEFAULT_DPI, PATHS


@dataclass(frozen=True)
@@ -4,10 +4,10 @@ Core Components
Reusable core functionality: authentication, rate limiting, scheduling.
"""

from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from src.web.core.rate_limiter import RateLimiter
from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from src.web.core.autolabel_scheduler import (
from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
from inference.web.core.rate_limiter import RateLimiter
from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
from inference.web.core.autolabel_scheduler import (
    start_autolabel_scheduler,
    stop_autolabel_scheduler,
    get_autolabel_scheduler,
@@ -9,8 +9,8 @@ from typing import Annotated

from fastapi import Depends, Header, HTTPException

from src.data.admin_db import AdminDB
from src.data.database import get_session_context
from inference.data.admin_db import AdminDB
from inference.data.database import get_session_context

logger = logging.getLogger(__name__)

@@ -8,8 +8,8 @@ import logging
import threading
from pathlib import Path

from src.data.admin_db import AdminDB
from src.web.services.db_autolabel import (
from inference.data.admin_db import AdminDB
from inference.web.services.db_autolabel import (
    get_pending_autolabel_documents,
    process_document_autolabel,
)
@@ -13,7 +13,7 @@ from threading import Lock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from src.data.async_request_db import AsyncRequestDB
    from inference.data.async_request_db import AsyncRequestDB

logger = logging.getLogger(__name__)

@@ -10,7 +10,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any

from src.data.admin_db import AdminDB
from inference.data.admin_db import AdminDB

logger = logging.getLogger(__name__)

@@ -86,7 +86,8 @@ class TrainingScheduler:
                logger.info(f"Starting training task: {task_id}")

                try:
                    self._execute_task(task_id, task.config or {})
                    dataset_id = getattr(task, "dataset_id", None)
                    self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
                except Exception as e:
                    logger.error(f"Training task {task_id} failed: {e}")
                    self._db.update_training_task_status(
@@ -98,7 +99,9 @@ class TrainingScheduler:
        except Exception as e:
            logger.error(f"Error checking pending tasks: {e}")

    def _execute_task(self, task_id: str, config: dict[str, Any]) -> None:
    def _execute_task(
        self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
    ) -> None:
        """Execute a training task."""
        # Update status to running
        self._db.update_training_task_status(task_id, "running")
@@ -114,17 +117,25 @@
        device = config.get("device", "0")
        project_name = config.get("project_name", "invoice_fields")

        # Export annotations for training
        export_result = self._export_training_data(task_id)
        if not export_result:
            raise ValueError("Failed to export training data")

        data_yaml = export_result["data_yaml"]

        self._db.add_training_log(
            task_id, "INFO",
            f"Exported {export_result['total_images']} images for training",
        )
        # Use dataset if available, otherwise export from scratch
        if dataset_id:
            dataset = self._db.get_dataset(dataset_id)
            if not dataset or not dataset.dataset_path:
                raise ValueError(f"Dataset {dataset_id} not found or has no path")
            data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
            self._db.add_training_log(
                task_id, "INFO",
                f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
            )
        else:
            export_result = self._export_training_data(task_id)
            if not export_result:
                raise ValueError("Failed to export training data")
            data_yaml = export_result["data_yaml"]
            self._db.add_training_log(
                task_id, "INFO",
                f"Exported {export_result['total_images']} images for training",
            )

        # Run YOLO training
        result = self._run_yolo_training(
@@ -157,7 +168,7 @@
        """Export training data for a task."""
        from pathlib import Path
        import shutil
        from src.data.admin_models import FIELD_CLASSES
        from inference.data.admin_models import FIELD_CLASSES

        # Get all labeled documents
        documents = self._db.get_labeled_documents_for_export()
@@ -9,8 +9,8 @@ from typing import Annotated

from fastapi import Depends, Header, HTTPException, Request

from src.data.async_request_db import AsyncRequestDB
from src.web.rate_limiter import RateLimiter
from inference.data.async_request_db import AsyncRequestDB
from inference.web.rate_limiter import RateLimiter

logger = logging.getLogger(__name__)

@@ -13,7 +13,7 @@ from threading import Lock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from src.data.async_request_db import AsyncRequestDB
    from inference.data.async_request_db import AsyncRequestDB

logger = logging.getLogger(__name__)

11
packages/inference/inference/web/schemas/__init__.py
Normal file
11
packages/inference/inference/web/schemas/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
API Schemas
|
||||
|
||||
Pydantic models for request/response validation.
|
||||
"""
|
||||
|
||||
# Import everything from sub-modules for backward compatibility
|
||||
from inference.web.schemas.common import * # noqa: F401, F403
|
||||
from inference.web.schemas.admin import * # noqa: F401, F403
|
||||
from inference.web.schemas.inference import * # noqa: F401, F403
|
||||
from inference.web.schemas.labeling import * # noqa: F401, F403
|
||||
17
packages/inference/inference/web/schemas/admin/__init__.py
Normal file
17
packages/inference/inference/web/schemas/admin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Admin API Request/Response Schemas
|
||||
|
||||
Pydantic models for admin API validation and serialization.
|
||||
"""
|
||||
|
||||
from .enums import * # noqa: F401, F403
|
||||
from .auth import * # noqa: F401, F403
|
||||
from .documents import * # noqa: F401, F403
|
||||
from .annotations import * # noqa: F401, F403
|
||||
from .training import * # noqa: F401, F403
|
||||
from .datasets import * # noqa: F401, F403
|
||||
|
||||
# Resolve forward references for DocumentDetailResponse
|
||||
from .documents import DocumentDetailResponse
|
||||
|
||||
DocumentDetailResponse.model_rebuild()
|
||||
152
packages/inference/inference/web/schemas/admin/annotations.py
Normal file
@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field

from .enums import AnnotationSource


class BoundingBox(BaseModel):
    """Bounding box coordinates."""

    x: int = Field(..., ge=0, description="X coordinate (pixels)")
    y: int = Field(..., ge=0, description="Y coordinate (pixels)")
    width: int = Field(..., ge=1, description="Width (pixels)")
    height: int = Field(..., ge=1, description="Height (pixels)")


class AnnotationCreate(BaseModel):
    """Request to create an annotation."""

    page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
    class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    text_value: str | None = Field(None, description="Text value (optional)")


class AnnotationUpdate(BaseModel):
    """Request to update an annotation."""

    class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
    bbox: BoundingBox | None = Field(None, description="New bounding box")
    text_value: str | None = Field(None, description="New text value")


class AnnotationItem(BaseModel):
    """Single annotation item."""

    annotation_id: str = Field(..., description="Annotation UUID")
    page_number: int = Field(..., ge=1, description="Page number")
    class_id: int = Field(..., ge=0, le=9, description="Class ID")
    class_name: str = Field(..., description="Class name")
    bbox: BoundingBox = Field(..., description="Bounding box in pixels")
    normalized_bbox: dict[str, float] = Field(
        ..., description="Normalized bbox (x_center, y_center, width, height)"
    )
    text_value: str | None = Field(None, description="Text value")
    confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
    source: AnnotationSource = Field(..., description="Annotation source")
    created_at: datetime = Field(..., description="Creation timestamp")


class AnnotationResponse(BaseModel):
    """Response for annotation operation."""

    annotation_id: str = Field(..., description="Annotation UUID")
    message: str = Field(..., description="Status message")


class AnnotationListResponse(BaseModel):
    """Response for annotation list."""

    document_id: str = Field(..., description="Document UUID")
    page_count: int = Field(..., ge=1, description="Total pages")
    total_annotations: int = Field(..., ge=0, description="Total annotations")
    annotations: list[AnnotationItem] = Field(
        default_factory=list, description="Annotation list"
    )


class AnnotationLockRequest(BaseModel):
    """Request to acquire annotation lock."""

    duration_seconds: int = Field(
        default=300,
        ge=60,
        le=3600,
        description="Lock duration in seconds (60-3600)",
    )


class AnnotationLockResponse(BaseModel):
    """Response for annotation lock operation."""

    document_id: str = Field(..., description="Document UUID")
    locked: bool = Field(..., description="Whether lock was acquired/released")
    lock_expires_at: datetime | None = Field(
        None, description="Lock expiration time"
    )
    message: str = Field(..., description="Status message")


class AutoLabelRequest(BaseModel):
    """Request to trigger auto-labeling."""

    field_values: dict[str, str] = Field(
        ...,
        description="Field values to match (e.g., {'invoice_number': '12345'})",
    )
    replace_existing: bool = Field(
        default=False, description="Replace existing auto annotations"
    )


class AutoLabelResponse(BaseModel):
    """Response for auto-labeling."""

    document_id: str = Field(..., description="Document UUID")
    status: str = Field(..., description="Auto-labeling status")
    annotations_created: int = Field(
        default=0, ge=0, description="Number of annotations created"
    )
    message: str = Field(..., description="Status message")


class AnnotationVerifyRequest(BaseModel):
    """Request to verify an annotation."""

    pass # No body needed, just POST to verify


class AnnotationVerifyResponse(BaseModel):
    """Response for annotation verification."""

    annotation_id: str = Field(..., description="Annotation UUID")
    is_verified: bool = Field(..., description="Verification status")
    verified_at: datetime = Field(..., description="Verification timestamp")
    verified_by: str = Field(..., description="Admin token who verified")
    message: str = Field(..., description="Status message")


class AnnotationOverrideRequest(BaseModel):
    """Request to override an annotation."""

    bbox: dict[str, int] | None = Field(
        None, description="Updated bounding box {x, y, width, height}"
    )
    text_value: str | None = Field(None, description="Updated text value")
    class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
    class_name: str | None = Field(None, description="Updated class name")
    reason: str | None = Field(None, description="Reason for override")


class AnnotationOverrideResponse(BaseModel):
    """Response for annotation override."""

    annotation_id: str = Field(..., description="Annotation UUID")
    source: str = Field(..., description="New source (manual)")
    override_source: str | None = Field(None, description="Original source (auto)")
    original_annotation_id: str | None = Field(None, description="Original annotation ID")
    message: str = Field(..., description="Status message")
    history_id: str = Field(..., description="History record UUID")
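`BoundingBox` is pixel-based, while `AnnotationItem.normalized_bbox` carries YOLO-style fractions of the page size. A hedged sketch of that conversion (the helper name and the 150-DPI A4 page size of roughly 1240x1754 px are assumptions, not part of the schema):

```python
def to_normalized_bbox(x: int, y: int, width: int, height: int,
                       image_width: int, image_height: int) -> dict[str, float]:
    """Convert a pixel-space box to YOLO-style normalized center coordinates."""
    return {
        "x_center": (x + width / 2) / image_width,
        "y_center": (y + height / 2) / image_height,
        "width": width / image_width,
        "height": height / image_height,
    }


# Hypothetical page rendered at 150 DPI (A4 is about 1240x1754 px at that scale).
print(to_normalized_bbox(x=100, y=50, width=200, height=40,
                         image_width=1240, image_height=1754))
# -> x_center ~0.161, y_center ~0.040, width ~0.161, height ~0.023
```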
23
packages/inference/inference/web/schemas/admin/auth.py
Normal file
@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field


class AdminTokenCreate(BaseModel):
    """Request to create an admin token."""

    name: str = Field(..., min_length=1, max_length=255, description="Token name")
    expires_in_days: int | None = Field(
        None, ge=1, le=365, description="Token expiration in days (optional)"
    )


class AdminTokenResponse(BaseModel):
    """Response with created admin token."""

    token: str = Field(..., description="Admin token")
    name: str = Field(..., description="Token name")
    expires_at: datetime | None = Field(None, description="Token expiration time")
    message: str = Field(..., description="Status message")
85
packages/inference/inference/web/schemas/admin/datasets.py
Normal file
@@ -0,0 +1,85 @@
"""Admin Dataset Schemas."""

from datetime import datetime

from pydantic import BaseModel, Field

from .training import TrainingConfig


class DatasetCreateRequest(BaseModel):
    """Request to create a training dataset."""

    name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
    description: str | None = Field(None, description="Optional description")
    document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
    train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
    val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
    seed: int = Field(42, description="Random seed for split")


class DatasetDocumentItem(BaseModel):
    """Document within a dataset."""

    document_id: str
    split: str
    page_count: int
    annotation_count: int


class DatasetResponse(BaseModel):
    """Response after creating a dataset."""

    dataset_id: str
    name: str
    status: str
    message: str


class DatasetDetailResponse(BaseModel):
    """Detailed dataset info with documents."""

    dataset_id: str
    name: str
    description: str | None
    status: str
    train_ratio: float
    val_ratio: float
    seed: int
    total_documents: int
    total_images: int
    total_annotations: int
    dataset_path: str | None
    error_message: str | None
    documents: list[DatasetDocumentItem]
    created_at: datetime
    updated_at: datetime


class DatasetListItem(BaseModel):
    """Dataset in list view."""

    dataset_id: str
    name: str
    description: str | None
    status: str
    total_documents: int
    total_images: int
    total_annotations: int
    created_at: datetime


class DatasetListResponse(BaseModel):
    """Paginated dataset list."""

    total: int
    limit: int
    offset: int
    datasets: list[DatasetListItem]


class DatasetTrainRequest(BaseModel):
    """Request to start training from a dataset."""

    name: str = Field(..., min_length=1, max_length=255, description="Training task name")
    config: TrainingConfig = Field(..., description="Training configuration")
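`DatasetCreateRequest` only carries the train and val ratios; whatever remains (1 - train_ratio - val_ratio, i.e. 0.1 with the defaults) becomes the test split when the dataset is built. A small usage sketch, assuming the package exports shown above and hypothetical document UUIDs:

```python
from inference.web.schemas.admin import DatasetCreateRequest

req = DatasetCreateRequest(
    name="invoices-2024q1",                     # dataset name (illustrative)
    document_ids=["doc-uuid-1", "doc-uuid-2"],  # hypothetical document UUIDs
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
)
print(req.model_dump_json(indent=2))
```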
103
packages/inference/inference/web/schemas/admin/documents.py
Normal file
@@ -0,0 +1,103 @@
"""Admin Document Schemas."""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field

from .enums import AutoLabelStatus, DocumentStatus

if TYPE_CHECKING:
    from .annotations import AnnotationItem
    from .training import TrainingHistoryItem


class DocumentUploadResponse(BaseModel):
    """Response for document upload."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_started: bool = Field(
        default=False, description="Whether auto-labeling was started"
    )
    message: str = Field(..., description="Status message")


class DocumentItem(BaseModel):
    """Single document in list."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        None, description="Auto-labeling status"
    )
    annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class DocumentListResponse(BaseModel):
    """Response for document list."""

    total: int = Field(..., ge=0, description="Total documents")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    documents: list[DocumentItem] = Field(
        default_factory=list, description="Document list"
    )


class DocumentDetailResponse(BaseModel):
    """Response for document detail."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Original filename")
    file_size: int = Field(..., ge=0, description="File size in bytes")
    content_type: str = Field(..., description="MIME type")
    page_count: int = Field(..., ge=1, description="Number of pages")
    status: DocumentStatus = Field(..., description="Document status")
    auto_label_status: AutoLabelStatus | None = Field(
        None, description="Auto-labeling status"
    )
    auto_label_error: str | None = Field(None, description="Auto-labeling error")
    upload_source: str = Field(default="ui", description="Upload source (ui or api)")
    batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
    csv_field_values: dict[str, str] | None = Field(
        None, description="CSV field values if uploaded via batch"
    )
    can_annotate: bool = Field(default=True, description="Whether document can be annotated")
    annotation_lock_until: datetime | None = Field(
        None, description="Lock expiration time if document is locked"
    )
    annotations: list["AnnotationItem"] = Field(
        default_factory=list, description="Document annotations"
    )
    image_urls: list[str] = Field(
        default_factory=list, description="URLs to page images"
    )
    training_history: list["TrainingHistoryItem"] = Field(
        default_factory=list, description="Training tasks that used this document"
    )
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")


class DocumentStatsResponse(BaseModel):
    """Document statistics response."""

    total: int = Field(..., ge=0, description="Total documents")
    pending: int = Field(default=0, ge=0, description="Pending documents")
    auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
    labeled: int = Field(default=0, ge=0, description="Labeled documents")
    exported: int = Field(default=0, ge=0, description="Exported documents")
46
packages/inference/inference/web/schemas/admin/enums.py
Normal file
@@ -0,0 +1,46 @@
"""Admin API Enums."""

from enum import Enum


class DocumentStatus(str, Enum):
    """Document status enum."""

    PENDING = "pending"
    AUTO_LABELING = "auto_labeling"
    LABELED = "labeled"
    EXPORTED = "exported"


class AutoLabelStatus(str, Enum):
    """Auto-labeling status enum."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class TrainingStatus(str, Enum):
    """Training task status enum."""

    PENDING = "pending"
    SCHEDULED = "scheduled"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TrainingType(str, Enum):
    """Training task type enum."""

    TRAIN = "train"
    FINETUNE = "finetune"


class AnnotationSource(str, Enum):
    """Annotation source enum."""

    MANUAL = "manual"
    AUTO = "auto"
    IMPORTED = "imported"
202
packages/inference/inference/web/schemas/admin/training.py
Normal file
@@ -0,0 +1,202 @@
"""Admin Training Schemas."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field

from .enums import TrainingStatus, TrainingType


class TrainingConfig(BaseModel):
    """Training configuration."""

    model_name: str = Field(default="yolo11n.pt", description="Base model name")
    epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
    batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
    image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
    learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
    device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
    project_name: str = Field(
        default="invoice_fields", description="Training project name"
    )


class TrainingTaskCreate(BaseModel):
    """Request to create a training task."""

    name: str = Field(..., min_length=1, max_length=255, description="Task name")
    description: str | None = Field(None, max_length=1000, description="Description")
    task_type: TrainingType = Field(
        default=TrainingType.TRAIN, description="Task type"
    )
    config: TrainingConfig = Field(
        default_factory=TrainingConfig, description="Training configuration"
    )
    scheduled_at: datetime | None = Field(
        None, description="Scheduled execution time"
    )
    cron_expression: str | None = Field(
        None, max_length=50, description="Cron expression for recurring tasks"
    )


class TrainingTaskItem(BaseModel):
    """Single training task in list."""

    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")
    scheduled_at: datetime | None = Field(None, description="Scheduled time")
    is_recurring: bool = Field(default=False, description="Is recurring task")
    started_at: datetime | None = Field(None, description="Start time")
    completed_at: datetime | None = Field(None, description="Completion time")
    created_at: datetime = Field(..., description="Creation timestamp")


class TrainingTaskListResponse(BaseModel):
    """Response for training task list."""

    total: int = Field(..., ge=0, description="Total tasks")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")


class TrainingTaskDetailResponse(BaseModel):
    """Response for training task detail."""

    task_id: str = Field(..., description="Task UUID")
    name: str = Field(..., description="Task name")
    description: str | None = Field(None, description="Description")
    task_type: TrainingType = Field(..., description="Task type")
    status: TrainingStatus = Field(..., description="Task status")
    config: dict[str, Any] | None = Field(None, description="Training configuration")
    scheduled_at: datetime | None = Field(None, description="Scheduled time")
    cron_expression: str | None = Field(None, description="Cron expression")
    is_recurring: bool = Field(default=False, description="Is recurring task")
    started_at: datetime | None = Field(None, description="Start time")
    completed_at: datetime | None = Field(None, description="Completion time")
    error_message: str | None = Field(None, description="Error message")
    result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
    model_path: str | None = Field(None, description="Trained model path")
    created_at: datetime = Field(..., description="Creation timestamp")


class TrainingTaskResponse(BaseModel):
    """Response for training task operation."""

    task_id: str = Field(..., description="Task UUID")
    status: TrainingStatus = Field(..., description="Task status")
    message: str = Field(..., description="Status message")


class TrainingLogItem(BaseModel):
    """Single training log entry."""

    level: str = Field(..., description="Log level")
    message: str = Field(..., description="Log message")
    details: dict[str, Any] | None = Field(None, description="Additional details")
    created_at: datetime = Field(..., description="Timestamp")


class TrainingLogsResponse(BaseModel):
    """Response for training logs."""

    task_id: str = Field(..., description="Task UUID")
    logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")


class ExportRequest(BaseModel):
    """Request to export annotations."""

    format: str = Field(
        default="yolo", description="Export format (yolo, coco, voc)"
    )
    include_images: bool = Field(
        default=True, description="Include images in export"
    )
    split_ratio: float = Field(
        default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
    )


class ExportResponse(BaseModel):
    """Response for export operation."""

    status: str = Field(..., description="Export status")
    export_path: str = Field(..., description="Path to exported dataset")
    total_images: int = Field(..., ge=0, description="Total images exported")
    total_annotations: int = Field(..., ge=0, description="Total annotations")
    train_count: int = Field(..., ge=0, description="Training set count")
    val_count: int = Field(..., ge=0, description="Validation set count")
    message: str = Field(..., description="Status message")


class TrainingDocumentItem(BaseModel):
    """Document item for training page."""

    document_id: str = Field(..., description="Document UUID")
    filename: str = Field(..., description="Filename")
    annotation_count: int = Field(..., ge=0, description="Total annotations")
    annotation_sources: dict[str, int] = Field(
        ..., description="Annotation counts by source (manual, auto)"
    )
    used_in_training: list[str] = Field(
        default_factory=list, description="List of training task IDs that used this document"
    )
    last_modified: datetime = Field(..., description="Last modification time")


class TrainingDocumentsResponse(BaseModel):
    """Response for GET /admin/training/documents."""

    total: int = Field(..., ge=0, description="Total document count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    documents: list[TrainingDocumentItem] = Field(
        default_factory=list, description="Documents available for training"
    )


class ModelMetrics(BaseModel):
    """Training model metrics."""

    mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
    precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
    recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")


class TrainingModelItem(BaseModel):
    """Trained model item for model list."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Model name")
    status: TrainingStatus = Field(..., description="Training status")
    document_count: int = Field(..., ge=0, description="Documents used in training")
    created_at: datetime = Field(..., description="Creation timestamp")
    completed_at: datetime | None = Field(None, description="Completion timestamp")
    metrics: ModelMetrics = Field(..., description="Model metrics")
    model_path: str | None = Field(None, description="Path to model weights")
    download_url: str | None = Field(None, description="Download URL for model")


class TrainingModelsResponse(BaseModel):
    """Response for GET /admin/training/models."""

    total: int = Field(..., ge=0, description="Total model count")
    limit: int = Field(..., ge=1, le=100, description="Page size")
    offset: int = Field(..., ge=0, description="Pagination offset")
    models: list[TrainingModelItem] = Field(
        default_factory=list, description="Trained models"
    )


class TrainingHistoryItem(BaseModel):
    """Training history for a document."""

    task_id: str = Field(..., description="Training task UUID")
    name: str = Field(..., description="Training task name")
    trained_at: datetime = Field(..., description="Training timestamp")
    model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
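A brief usage sketch for the training schemas, assuming the package exports shown above: fields left unset fall back to the declared defaults (yolo11n.pt base model, 100 epochs, batch 16, image size 640).

```python
from inference.web.schemas.admin import TrainingConfig, TrainingTaskCreate, TrainingType

task = TrainingTaskCreate(
    name="baseline-yolo11n",  # task name (illustrative)
    task_type=TrainingType.TRAIN,
    config=TrainingConfig(epochs=50, batch_size=8, device="cpu"),
)
print(task.model_dump_json(indent=2))
```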
18
packages/inference/inference/web/services/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Business Logic Services

Service layer for processing requests and orchestrating data operations.
"""

from inference.web.services.autolabel import AutoLabelService, get_auto_label_service
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.services.batch_upload import BatchUploadService

__all__ = [
    "AutoLabelService",
    "get_auto_label_service",
    "InferenceService",
    "AsyncProcessingService",
    "BatchUploadService",
]
@@ -14,13 +14,13 @@ from pathlib import Path
from threading import Event, Thread
from typing import TYPE_CHECKING

from src.data.async_request_db import AsyncRequestDB
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from src.web.core.rate_limiter import RateLimiter
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter

if TYPE_CHECKING:
    from src.web.config import AsyncConfig, StorageConfig
    from src.web.services.inference import InferenceService
    from inference.web.config import AsyncConfig, StorageConfig
    from inference.web.services.inference import InferenceService

logger = logging.getLogger(__name__)

@@ -11,11 +11,11 @@ from typing import Any
import numpy as np
from PIL import Image

from src.config import DEFAULT_DPI
from src.data.admin_db import AdminDB
from src.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from src.matcher.field_matcher import FieldMatcher
from src.ocr.paddle_ocr import OCREngine, OCRToken
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken

logger = logging.getLogger(__name__)

@@ -144,7 +144,7 @@ class AutoLabelService:
        db: AdminDB,
    ) -> int:
        """Process PDF document and create annotations."""
        from src.pdf.renderer import render_pdf_to_images
        from shared.pdf.renderer import render_pdf_to_images
        import io

        total_annotations = 0
@@ -222,7 +222,7 @@ class AutoLabelService:
        image_height: int,
    ) -> list[dict[str, Any]]:
        """Find annotations for field values using token matching."""
        from src.normalize import normalize_field
        from shared.normalize import normalize_field

        annotations = []

@@ -15,8 +15,8 @@ from uuid import UUID

from pydantic import BaseModel, Field, field_validator

from src.data.admin_db import AdminDB
from src.data.admin_models import CSV_TO_CLASS_MAPPING
from inference.data.admin_db import AdminDB
from inference.data.admin_models import CSV_TO_CLASS_MAPPING

logger = logging.getLogger(__name__)

188
packages/inference/inference/web/services/dataset_builder.py
Normal file
@@ -0,0 +1,188 @@
"""
Dataset Builder Service

Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""

import logging
import random
import shutil
from pathlib import Path

import yaml

from inference.data.admin_models import FIELD_CLASSES

logger = logging.getLogger(__name__)


class DatasetBuilder:
    """Builds YOLO training datasets from admin documents."""

    def __init__(self, db, base_dir: Path):
        self._db = db
        self._base_dir = Path(base_dir)

    def build_dataset(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        """Build a complete YOLO dataset from document IDs.

        Args:
            dataset_id: UUID of the dataset record.
            document_ids: List of document UUIDs to include.
            train_ratio: Fraction for training set.
            val_ratio: Fraction for validation set.
            seed: Random seed for reproducible splits.
            admin_images_dir: Root directory of admin images.

        Returns:
            Summary dict with total_documents, total_images, total_annotations.

        Raises:
            ValueError: If no valid documents found.
        """
        try:
            return self._do_build(
                dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
            )
        except Exception as e:
            self._db.update_dataset_status(
                dataset_id=dataset_id,
                status="failed",
                error_message=str(e),
            )
            raise

    def _do_build(
        self,
        dataset_id: str,
        document_ids: list[str],
        train_ratio: float,
        val_ratio: float,
        seed: int,
        admin_images_dir: Path,
    ) -> dict:
        # 1. Fetch documents
        documents = self._db.get_documents_by_ids(document_ids)
        if not documents:
            raise ValueError("No valid documents found for the given IDs")

        # 2. Create directory structure
        dataset_dir = self._base_dir / dataset_id
        for split in ["train", "val", "test"]:
            (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
            (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

        # 3. Shuffle and split documents
        doc_list = list(documents)
        rng = random.Random(seed)
        rng.shuffle(doc_list)

        n = len(doc_list)
        n_train = max(1, round(n * train_ratio))
        n_val = max(0, round(n * val_ratio))
        n_test = n - n_train - n_val

        splits = (
            ["train"] * n_train
            + ["val"] * n_val
            + ["test"] * n_test
        )

        # 4. Process each document
        total_images = 0
        total_annotations = 0
        dataset_docs = []

        for doc, split in zip(doc_list, splits):
            doc_id = str(doc.document_id)
            annotations = self._db.get_annotations_for_document(doc.document_id)

            # Group annotations by page
            page_annotations: dict[int, list] = {}
            for ann in annotations:
                page_annotations.setdefault(ann.page_number, []).append(ann)

            doc_image_count = 0
            doc_ann_count = 0

            # Copy images and write labels for each page
            for page_num in range(1, doc.page_count + 1):
                src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
                if not src_image.exists():
                    logger.warning("Image not found: %s", src_image)
                    continue

                dst_name = f"{doc_id}_page{page_num}"
                dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
                shutil.copy2(src_image, dst_image)
                doc_image_count += 1

                # Write YOLO label file
                page_anns = page_annotations.get(page_num, [])
                label_lines = []
                for ann in page_anns:
                    label_lines.append(
                        f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
                        f"{ann.width:.6f} {ann.height:.6f}"
                    )
                    doc_ann_count += 1

                label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
                label_path.write_text("\n".join(label_lines))

            total_images += doc_image_count
            total_annotations += doc_ann_count

            dataset_docs.append({
                "document_id": doc_id,
                "split": split,
                "page_count": doc_image_count,
                "annotation_count": doc_ann_count,
            })

        # 5. Record document-split assignments in DB
        self._db.add_dataset_documents(
            dataset_id=dataset_id,
            documents=dataset_docs,
        )

        # 6. Generate data.yaml
        self._generate_data_yaml(dataset_dir)

        # 7. Update dataset status
        self._db.update_dataset_status(
            dataset_id=dataset_id,
            status="ready",
            total_documents=len(doc_list),
            total_images=total_images,
            total_annotations=total_annotations,
            dataset_path=str(dataset_dir),
        )

        return {
            "total_documents": len(doc_list),
            "total_images": total_images,
            "total_annotations": total_annotations,
        }

    def _generate_data_yaml(self, dataset_dir: Path) -> None:
        """Generate YOLO data.yaml configuration file."""
        data = {
            "path": str(dataset_dir.absolute()),
            "train": "images/train",
            "val": "images/val",
            "test": "images/test",
            "nc": len(FIELD_CLASSES),
            "names": FIELD_CLASSES,
        }
        yaml_path = dataset_dir / "data.yaml"
        yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
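Each label file written above holds one `class_id x_center y_center width height` line per annotation, and `data.yaml` points YOLO at the train/val/test image folders. A hedged usage sketch follows; the stub DB, IDs, and paths are illustrative only, standing in for the AdminDB-backed object the builder actually receives (it only needs `get_documents_by_ids`, `get_annotations_for_document`, `add_dataset_documents`, and `update_dataset_status`):

```python
from pathlib import Path
from types import SimpleNamespace

from inference.web.services.dataset_builder import DatasetBuilder


class _StubDB:
    def get_documents_by_ids(self, ids):
        # One fake single-page document per requested ID.
        return [SimpleNamespace(document_id=i, page_count=1) for i in ids]

    def get_annotations_for_document(self, document_id):
        # One fake YOLO-normalized annotation on page 1.
        return [SimpleNamespace(page_number=1, class_id=0,
                                x_center=0.5, y_center=0.5, width=0.2, height=0.1)]

    def add_dataset_documents(self, dataset_id, documents):
        pass

    def update_dataset_status(self, dataset_id, status, **kwargs):
        print("dataset", dataset_id, "->", status)


builder = DatasetBuilder(db=_StubDB(), base_dir=Path("/tmp/datasets"))
summary = builder.build_dataset(
    dataset_id="demo-dataset",          # hypothetical dataset ID
    document_ids=["doc-1", "doc-2"],    # hypothetical document IDs
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
    admin_images_dir=Path("/tmp/admin_images"),  # page_{n}.png files expected here
)
print(summary)
```

Missing page images are only logged and skipped, so the sketch runs even without real files; with real admin storage the copied images and label files land under `base_dir/<dataset_id>/`.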
@@ -11,11 +11,11 @@ import logging
from pathlib import Path
from typing import Any

from src.config import DEFAULT_DPI
from src.data.admin_db import AdminDB
from src.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from src.data.db import DocumentDB
from src.web.config import StorageConfig
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
from shared.data.db import DocumentDB
from inference.web.config import StorageConfig

logger = logging.getLogger(__name__)

@@ -81,8 +81,8 @@ def get_pending_autolabel_documents(
        List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
    """
    from sqlmodel import select
    from src.data.database import get_session_context
    from src.data.admin_models import AdminDocument
    from inference.data.database import get_session_context
    from inference.data.admin_models import AdminDocument

    with get_session_context() as session:
        statement = select(AdminDocument).where(
@@ -116,8 +116,8 @@ def process_document_autolabel(
    Returns:
        Result dictionary with success status and annotations
    """
    from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
    from src.pdf import PDFDocument
    from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
    from shared.pdf import PDFDocument

    document_id = str(document.document_id)
    file_path = Path(document.file_path)
|
||||
Number of annotations saved
|
||||
"""
|
||||
from PIL import Image
|
||||
from src.data.admin_models import FIELD_CLASS_IDS
|
||||
from inference.data.admin_models import FIELD_CLASS_IDS
|
||||
|
||||
# Mapping from CSV field names to internal field names
|
||||
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
|
||||
@@ -480,7 +480,7 @@ def save_manual_annotations_to_document_db(
    pdf_type = "unknown"
    if pdf_path.exists():
        try:
            from src.pdf import PDFDocument
            from shared.pdf import PDFDocument
            with PDFDocument(pdf_path) as pdf_doc:
                tokens = list(pdf_doc.extract_text_tokens(0))
                pdf_type = "scanned" if len(tokens) < 10 else "text"
@@ -71,8 +71,8 @@ class InferenceService:
        start_time = time.time()

        try:
            from src.inference.pipeline import InferencePipeline
            from src.inference.yolo_detector import YOLODetector
            from inference.pipeline.pipeline import InferencePipeline
            from inference.pipeline.yolo_detector import YOLODetector

            # Initialize YOLO detector for visualization
            self._detector = YOLODetector(
@@ -257,7 +257,7 @@ class InferenceService:

    def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
        """Save visualization for PDF (first page)."""
        from src.pdf.renderer import render_pdf_to_images
        from shared.pdf.renderer import render_pdf_to_images
        from ultralytics import YOLO
        import io

@@ -4,8 +4,8 @@ Background Task Queues
Worker queues for asynchronous and batch processing.
"""

from src.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from src.web.workers.batch_queue import (
from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask
from inference.web.workers.batch_queue import (
    BatchTaskQueue,
    BatchTask,
    init_batch_queue,
8
packages/inference/requirements.txt
Normal file
@@ -0,0 +1,8 @@
-e ../shared
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
httpx>=0.25.0
openai>=1.0.0
14
packages/inference/run_server.py
Normal file
@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.

Usage:
    python run_server.py
    python run_server.py --port 8080
    python run_server.py --debug --reload
"""

from inference.cli.serve import main

if __name__ == "__main__":
    main()
17
packages/inference/setup.py
Normal file
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages

setup(
    name="invoice-inference",
    version="0.1.0",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=[
        "invoice-shared",
        "fastapi>=0.104.0",
        "uvicorn[standard]>=0.24.0",
        "python-multipart>=0.0.6",
        "sqlmodel>=0.0.22",
        "ultralytics>=8.1.0",
        "httpx>=0.25.0",
    ],
)
9
packages/shared/requirements.txt
Normal file
@@ -0,0 +1,9 @@
PyMuPDF>=1.23.0
paddleocr>=2.7.0
Pillow>=10.0.0
numpy>=1.24.0
opencv-python>=4.8.0
psycopg2-binary>=2.9.0
python-dotenv>=1.0.0
pyyaml>=6.0
thefuzz>=0.20.0
19
packages/shared/setup.py
Normal file
@@ -0,0 +1,19 @@
from setuptools import setup, find_packages

setup(
    name="invoice-shared",
    version="0.1.0",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=[
        "PyMuPDF>=1.23.0",
        "paddleocr>=2.7.0",
        "Pillow>=10.0.0",
        "numpy>=1.24.0",
        "opencv-python>=4.8.0",
        "psycopg2-binary>=2.9.0",
        "python-dotenv>=1.0.0",
        "pyyaml>=6.0",
        "thefuzz>=0.20.0",
    ],
)
@@ -7,10 +7,16 @@ import platform
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file
# .env is at project root, config.py is in src/
env_path = Path(__file__).parent.parent / '.env'
load_dotenv(dotenv_path=env_path)
# Load environment variables from .env file at project root
# Walk up from packages/shared/shared/config.py to find project root
_config_dir = Path(__file__).parent
for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]:
    _env_path = _candidate / '.env'
    if _env_path.exists():
        load_dotenv(dotenv_path=_env_path)
        break
else:
    load_dotenv() # fallback: search cwd and parents

# Global DPI setting - must match training DPI for optimal model performance
DEFAULT_DPI = 150
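`DEFAULT_DPI` drives how PDF pages are rasterized, so auto-labeling, training, and inference must render at the same scale. A hedged sketch of the DPI-to-pixels relationship with PyMuPDF (the real renderer lives in `shared.pdf.renderer` and may differ in details):

```python
import fitz  # PyMuPDF

DEFAULT_DPI = 150


def render_first_page(pdf_path: str) -> "fitz.Pixmap":
    """Rasterize page 1 at DEFAULT_DPI; an A4 page comes out at roughly 1240x1754 px."""
    with fitz.open(pdf_path) as doc:
        zoom = DEFAULT_DPI / 72  # PDF user space is 72 points per inch
        return doc[0].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
```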
3
packages/shared/shared/data/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .csv_loader import CSVLoader, InvoiceRow

__all__ = ['CSVLoader', 'InvoiceRow']
@@ -9,8 +9,7 @@ from typing import Set, Dict, Any, Optional
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import get_db_connection_string
from shared.config import get_db_connection_string


class DocumentDB:
Some files were not shown because too many files have changed in this diff.