restructure project

Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions


@@ -87,7 +87,10 @@
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")", "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")", "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)", "Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")" "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")"
],
"deny": [],
"ask": [],

.coverage (binary file, not shown)

.gitignore

@@ -52,6 +52,10 @@ reports/*.jsonl
logs/
*.log
# Coverage
htmlcov/
.coverage
# Jupyter
.ipynb_checkpoints/

README.md

@@ -8,7 +8,25 @@
1. **Auto-labeling**: use existing structured CSV data + OCR to generate YOLO training annotations automatically
2. **Model training**: train a field-detection model with YOLOv11
3. **Inference extraction**: detect field regions -> OCR the text -> normalize the fields
### Architecture
The project uses a **monorepo with three separate packages**, so training and inference can be deployed independently:
```
packages/
├── shared/      # shared library (PDF, OCR, normalization, matching, utilities)
├── training/    # training service (GPU, started on demand)
└── inference/   # inference service (always running)
```
| Service | Deployment target | GPU | Lifecycle |
|------|---------|-----|---------|
| **Inference** | Azure App Service | optional | always on, 24/7 |
| **Training** | Azure ACI | required | started/destroyed on demand |

The two services communicate through a shared PostgreSQL database: the inference service triggers training tasks via its API, and the training service picks them up from the database and runs them.
### Current Status
@@ -16,6 +34,8 @@
|------|------|
| **Labeled documents** | 9,738 (9,709 successful) |
| **Overall field match rate** | 94.8% (82,604/87,121) |
| **Tests** | 922 passed |
| **Model mAP@0.5** | 93.5% |

**Per-field match rates:**
@@ -42,24 +62,83 @@
|------|------|
| **WSL** | WSL 2 + Ubuntu 22.04 |
| **Conda** | Miniconda or Anaconda |
| **Python** | 3.11+ (managed via Conda) |
| **GPU** | NVIDIA GPU + CUDA 12.x (strongly recommended) |
| **Database** | PostgreSQL (stores labeling results) |
## Features
- **Dual-mode PDF processing**: supports text-layer PDFs and scanned PDFs
- **Auto-labeling**: generates YOLO training data automatically from existing structured CSV data
- **Multi-strategy field matching**: exact, substring, and normalized matching
- **Database storage**: labeling results live in PostgreSQL, with incremental processing and resume support
- **YOLO detection**: detects invoice field regions with YOLOv11
- **OCR**: extracts text from detected regions with PaddleOCR v5
- **Unified parsers**: payment_line and customer_number use dedicated parser modules
- **Cross-validation**: payment_line data is cross-checked against individually detected fields; payment_line values take precedence
- **Document type detection**: automatically distinguishes invoice (has payment_line) from letter (no payment_line)
- **Web application**: REST API plus a visual UI
- **Incremental training**: continue training from an already trained model
- **Memory optimization**: low-memory training mode (--low-memory)

## Installation
```bash
# 1. Enter WSL
wsl -d Ubuntu-22.04

# 2. Create the Conda environment
conda create -n invoice-py311 python=3.11 -y
conda activate invoice-py311

# 3. Go to the project directory
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2

# 4. Install the three packages (editable mode)
pip install -e packages/shared
pip install -e packages/training
pip install -e packages/inference
```
## Project Structure
```
invoice-master-poc-v2/
├── packages/
│   ├── shared/                  # shared library
│   │   ├── setup.py
│   │   └── shared/
│   │       ├── pdf/             # PDF handling (extraction, rendering, detection)
│   │       ├── ocr/             # PaddleOCR wrapper + machine-code parsing
│   │       ├── normalize/       # field normalization (10 normalizers)
│   │       ├── matcher/         # field matching (exact/substring/fuzzy)
│   │       ├── utils/           # utilities (validation, cleaning, fuzzy matching)
│   │       ├── data/            # DocumentDB, CSVLoader
│   │       ├── config.py        # global config (database, paths, DPI)
│   │       └── exceptions.py    # exception definitions
│   │
│   ├── training/                # training service (GPU, on demand)
│   │   ├── setup.py
│   │   ├── Dockerfile
│   │   ├── run_training.py      # entry point (--task-id or --poll)
│   │   └── training/
│   │       ├── cli/             # train, autolabel, analyze_*, validate
│   │       ├── yolo/            # db_dataset, annotation_generator
│   │       ├── processing/      # CPU/GPU worker pools, task dispatcher
│   │       └── data/            # training_db, autolabel_report
│   │
│   └── inference/               # inference service (always on)
│       ├── setup.py
│       ├── Dockerfile
│       ├── run_server.py        # web server entry point
│       └── inference/
│           ├── cli/             # infer, serve
│           ├── pipeline/        # YOLO detection, field extraction, parsers
│           ├── web/             # FastAPI application
│           │   ├── api/v1/      # REST API (admin, public, batch)
│           │   ├── schemas/     # Pydantic models
│           │   ├── services/    # business logic
│           │   ├── core/        # auth, scheduler, rate limiting
│           │   └── workers/     # background task queue
│           ├── validation/      # LLM validators
│           ├── data/            # AdminDB, AsyncRequestDB, models
│           └── azure/           # ACI training trigger
├── migrations/                  # database migrations
│   ├── 001_async_tables.sql
│   ├── 002_nullable_admin_token.sql
│   └── 003_training_tasks.sql
├── frontend/                    # React frontend (Vite + TypeScript)
├── tests/                       # tests (922 tests)
├── docker-compose.yml           # local development (postgres + inference + training)
├── run_server.py                # convenience startup script
└── runs/train/                  # training output (weights, curves)
```
## Supported Fields
@@ -76,476 +155,129 @@
| 8 | payment_line | payment line (machine-readable format) |
| 9 | customer_number | customer number |
## DPI Configuration
**Important**: every component of the system uses **150 DPI**, which keeps training and inference consistent.
The DPI (dots per inch) setting must be identical at training and inference time; otherwise it causes:
- mismatched detection-box sizes
- a significant mAP drop (potentially from 93.5% down to 60-70%)
- missed or spurious field detections

### Where it is configured
| Component | Config file | Setting |
|------|---------|--------|
| **Global constant** | `src/config.py` | `DEFAULT_DPI = 150` |
| **Web inference** | `src/web/config.py` | `ModelConfig.dpi` (imported from `src.config`) |
| **CLI inference** | `src/cli/infer.py` | `--dpi` default = `DEFAULT_DPI` |
| **Auto-labeling** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` |
| **PDF to image** | `src/web/api/v1/admin/documents.py` | uses `DEFAULT_DPI` |

### Usage examples
```bash
# Training (uses the default 150 DPI)
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1

# Inference (default 150 DPI, matching training)
python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf

# Manually specifying DPI (only when pairing with a model trained at a non-default DPI)
python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150
```

## Installation
```bash
# 1. Enter WSL
wsl -d Ubuntu-22.04

# 2. Create the Conda environment
conda create -n invoice-py311 python=3.11 -y
conda activate invoice-py311

# 3. Go to the project directory
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2

# 4. Install dependencies
pip install -r requirements.txt

# 5. Install web dependencies
pip install uvicorn fastapi python-multipart pydantic
```
## Quick Start
### 1. Prepare Data
```
~/invoice-data/
├── raw_pdfs/
│   ├── {DocumentId}.pdf
│   └── ...
├── structured_data/
│   └── document_export_YYYYMMDD.csv
└── dataset/
    └── temp/    (rendered images)
```
CSV format:
```csv
DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
```

### 2. Auto-labeling
```bash
# Dual-pool mode (CPU + GPU)
python -m training.cli.autolabel \
    --dual-pool \
    --cpu-workers 3 \
    --gpu-workers 1

# Single-pool mode
python -m training.cli.autolabel --workers 4
```

### 3. Train the Model
```bash
# Start from a pretrained model
python -m training.cli.train \
    --model yolo11n.pt \
    --epochs 100 \
    --batch 16 \
    --name invoice_fields \
    --dpi 150

# Low-memory mode
python -m training.cli.train \
    --model yolo11n.pt \
    --epochs 100 \
    --name invoice_fields \
    --low-memory

# Resume from a checkpoint
python -m training.cli.train \
    --model runs/train/invoice_fields/weights/last.pt \
    --epochs 100 \
    --name invoice_fields \
    --resume
```

### 4. Incremental Training
After adding new data, you can continue training from an already trained model:
```bash
# Continue from a trained best.pt
python -m src.cli.train \
    --model runs/train/invoice_yolo11n_full/weights/best.pt \
    --epochs 30 \
    --batch 16 \
    --name invoice_yolo11n_v2 \
    --dpi 150
```
**Incremental training guidelines**:
| Scenario | Recommendation |
|------|------|
| Small amount of new data (<20%) | continue training for 10-30 epochs |
| Large amount of new data (>50%) | continue training for 50-100 epochs |
| Many label corrections | retrain from scratch |
| New field types added | retrain from scratch |

### 5. Inference
```bash
# Command-line inference
python -m inference.cli.infer \
    --model runs/train/invoice_fields/weights/best.pt \
    --input path/to/invoice.pdf \
    --output result.json \
    --gpu

# Batch inference
python -m src.cli.infer \
    --model runs/train/invoice_fields/weights/best.pt \
    --input invoices/*.pdf \
    --output results/ \
    --gpu
```
**The inference result contains**:
- `fields`: extracted field values (InvoiceNumber, Amount, payment_line, customer_number, ...)
- `confidence`: per-field confidence scores
- `document_type`: document type ("invoice" or "letter")
- `cross_validation`: payment_line cross-validation result (if any)

### 6. Web Application
**Start inside WSL**:
```bash
# Method 1: from Windows PowerShell (recommended)
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"

# Method 2: inside WSL
conda activate invoice-py311
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
python run_server.py --port 8000

# Method 3: startup script
./start_web.sh

# Start the frontend
cd frontend && npm install && npm run dev
# Open http://localhost:5173
```
**Once the server is running**:
- open **http://localhost:8000** for the web UI
- the service automatically loads the model `runs/train/invoice_fields/weights/best.pt`
- GPU is enabled by default, confidence threshold 0.5

### 7. Docker Local Development
```bash
docker-compose up
# inference: http://localhost:8000
# training: polling mode, picks up tasks automatically
```

## Training Trigger Flow
The inference service triggers training through its API; training runs on a separate GPU instance:
```
Inference API                    PostgreSQL               Training (ACI)
     |                               |                         |
POST /admin/training/trigger         |                         |
     |-> INSERT training_tasks ----->| status=pending          |
     |-> Azure SDK: create ACI ---------------------------------> start
     |                               |                         |
     |                               |<-- SELECT pending ------+
     |                               |--- UPDATE running ------+
     |                               |         run training...
     |                               |<-- UPDATE completed ----+
     |                               |    + model_path         |
     |                               |    + metrics            auto-shutdown
     |                               |                         |
GET /admin/training/{id}             |                         |
     |-> SELECT training_tasks ----->|                         |
     +-- return status + metrics     |                         |
```
## Web API Endpoints
**Public API:**
| Method | Endpoint | Description |
|------|------|------|
| GET | `/` | web UI |
| GET | `/api/v1/health` | health check |
| POST | `/api/v1/infer` | upload a file and run inference |
| GET | `/api/v1/results/{filename}` | fetch the visualization image |
| POST | `/api/v1/async/infer` | asynchronous inference |
| GET | `/api/v1/async/status/{task_id}` | query async task status |

**Admin API** (requires the `X-Admin-Token` header):
| Method | Endpoint | Description |
|------|------|------|
| POST | `/api/v1/admin/auth/login` | admin login |
| GET | `/api/v1/admin/documents` | list documents |
| POST | `/api/v1/admin/documents/upload` | upload a PDF |
| GET | `/api/v1/admin/documents/{id}` | document details |
| PATCH | `/api/v1/admin/documents/{id}/status` | update document status |
| POST | `/api/v1/admin/documents/{id}/annotations` | create annotations |
| POST | `/api/v1/admin/training/trigger` | trigger a training task |
| GET | `/api/v1/admin/training/{id}/status` | query training status |

#### API Response Format
```json
{
  "status": "success",
  "result": {
    "document_id": "abc123",
    "document_type": "invoice",
    "fields": {
      "InvoiceNumber": "12345",
      "Amount": "1234.56",
      "payment_line": "# 94228110015950070 # > 48666036#14#",
      "customer_number": "UMJ 436-R"
    },
    "confidence": {
      "InvoiceNumber": 0.95,
      "Amount": 0.92
    },
    "cross_validation": {
      "is_valid": true,
      "ocr_match": true,
      "amount_match": true
    }
  }
}
```
## Training Configuration
### YOLO Training Parameters
```bash
python -m src.cli.train [OPTIONS]

Options:
  --model, -m    base model (default: yolo11n.pt)
  --epochs, -e   number of epochs (default: 100)
  --batch, -b    batch size (default: 16)
  --imgsz        image size (default: 1280)
  --dpi          PDF render DPI (default: 150)
  --name         run name
  --limit        limit the number of documents (for testing)
  --device       device (0=GPU, cpu)
  --resume       resume training from a checkpoint
  --low-memory   enable low-memory mode (batch=8, workers=4, no-cache)
  --workers      number of data-loading workers (default: 8)
  --cache        cache images in memory
```
### Training Best Practices
1. **Disable flip augmentation** (text detection):
```python
fliplr=0.0, flipud=0.0
```
2. **Use early stopping**:
```python
patience=20
```
3. **Enable AMP** (mixed-precision training):
```python
amp=True
```
4. **Save checkpoints**:
```python
save_period=10
```
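For reference, these settings map onto Ultralytics' `train()` keywords roughly as follows (a sketch only; the project's train CLI wraps this, and `dataset.yaml` stands in for the exported dataset file):

```python
from ultralytics import YOLO

# Illustrative sketch: the actual wrapper lives in the train CLI.
model = YOLO("yolo11n.pt")
model.train(
    data="dataset.yaml",      # produced by the dataset export step
    epochs=100,
    batch=16,
    imgsz=1280,
    fliplr=0.0, flipud=0.0,   # no flips for text detection
    patience=20,              # early stopping
    amp=True,                 # mixed precision
    save_period=10,           # periodic checkpoints
)
```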
### Example Training Results
**Latest training run** (100 epochs, 2026-01-22):
| Metric | Value |
|------|-----|
| **mAP@0.5** | 93.5% |
| **mAP@0.5-0.95** | 83.0% |
| **Training set** | ~10,000 labeled images |
| **Field types** | 10 fields (payment_line and customer_number newly added) |
| **Model location** | `runs/train/invoice_fields/weights/best.pt` |

**Per-field detection performance**:
- Basic invoice fields (InvoiceNumber, InvoiceDate, InvoiceDueDate): >95% mAP
- Payment fields (OCR, Bankgiro, Plusgiro, Amount): >90% mAP
- Organization fields (supplier_org_number, customer_number): >85% mAP
- Payment line (payment_line): >80% mAP

**Model files**:
```
runs/train/invoice_fields/weights/
├── best.pt   # best model (highest mAP@0.5) ⭐ recommended for production
└── last.pt   # last checkpoint (for resuming training)
```
> Note: more data is still being labeled; the final training set is expected to reach 25,000+ labeled images.

## Project Structure
```
invoice-master-poc-v2/
├── src/
│   ├── cli/                     # command-line tools
│   │   ├── autolabel.py         # auto-labeling
│   │   ├── train.py             # model training
│   │   ├── infer.py             # inference
│   │   └── serve.py             # web server
│   ├── pdf/                     # PDF handling
│   │   ├── extractor.py         # text extraction
│   │   ├── renderer.py          # image rendering
│   │   └── detector.py          # type detection
│   ├── ocr/                     # PaddleOCR wrapper
│   │   └── machine_code_parser.py   # machine-readable payment-line parser
│   ├── normalize/               # field normalization
│   ├── matcher/                 # field matching
│   ├── yolo/                    # YOLO-related code
│   │   ├── annotation_generator.py
│   │   └── db_dataset.py
│   ├── inference/               # inference pipeline
│   │   ├── pipeline.py          # main inference flow
│   │   ├── yolo_detector.py     # YOLO detection
│   │   ├── field_extractor.py   # field extraction
│   │   ├── payment_line_parser.py     # payment-line parser
│   │   └── customer_number_parser.py  # customer-number parser
│   ├── processing/              # multi-pool processing architecture
│   │   ├── worker_pool.py
│   │   ├── cpu_pool.py
│   │   ├── gpu_pool.py
│   │   ├── task_dispatcher.py
│   │   └── dual_pool_coordinator.py
│   ├── web/                     # web application
│   │   ├── app.py               # FastAPI application entry point
│   │   ├── routes.py            # API routes
│   │   ├── services.py          # business logic
│   │   └── schemas.py           # data models
│   ├── utils/                   # utility modules
│   │   ├── text_cleaner.py      # text cleaning
│   │   ├── validators.py        # field validation
│   │   ├── fuzzy_matcher.py     # fuzzy matching
│   │   └── ocr_corrections.py   # OCR error correction
│   └── data/                    # data handling
├── tests/                       # tests
│   ├── ocr/                     # OCR module tests
│   │   └── test_machine_code_parser.py
│   ├── inference/               # inference module tests
│   ├── normalize/               # normalization module tests
│   └── utils/                   # utility module tests
├── docs/                        # documentation
│   ├── REFACTORING_SUMMARY.md
│   └── TEST_COVERAGE_IMPROVEMENT.md
├── config.py                    # configuration
├── run_server.py                # web server startup script
├── runs/                        # training output
│   └── train/
│       └── invoice_fields/
│           └── weights/
│               ├── best.pt      # best model
│               └── last.pt      # last checkpoint
└── requirements.txt
```

## Multi-Pool Processing Architecture
The project processes different PDF types with a dual CPU + GPU pool architecture:
```
┌─────────────────────────────────────────────────────┐
│                 DualPoolCoordinator                  │
│  ┌─────────────────┐      ┌─────────────────┐        │
│  │    CPU Pool     │      │    GPU Pool     │        │
│  │   (3 workers)   │      │   (1 worker)    │        │
│  │                 │      │                 │        │
│  │   Text PDFs     │      │  Scanned PDFs   │        │
│  │   ~50-87 it/s   │      │   ~1-2 it/s     │        │
│  └─────────────────┘      └─────────────────┘        │
│                                                      │
│  TaskDispatcher: assigns tasks by PDF type           │
└─────────────────────────────────────────────────────┘
```
### Key Design Points
- **spawn start method**: compatible with CUDA multiprocessing
- **as_completed()**: deadlock-free result collection
- **process initializer**: each worker loads the model once
- **persistent coordinator**: worker pools are reused across CSV files
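A minimal sketch of that pattern with the standard library (`load_model` and the model's `run` call are illustrative placeholders, not the project's identifiers):

```python
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

_MODEL = None  # loaded once per worker process

def _init_worker(model_path: str) -> None:
    # Process initializer: load the heavy model a single time per worker.
    global _MODEL
    _MODEL = load_model(model_path)  # illustrative loader

def _process_pdf(pdf_path: str) -> dict:
    # Runs in a worker; reuses the per-process _MODEL.
    return {"pdf": pdf_path, "fields": _MODEL.run(pdf_path)}

def run_pool(pdf_paths, model_path, workers=3):
    ctx = mp.get_context("spawn")  # spawn is required for CUDA in child processes
    with ProcessPoolExecutor(max_workers=workers, mp_context=ctx,
                             initializer=_init_worker, initargs=(model_path,)) as pool:
        futures = [pool.submit(_process_pdf, p) for p in pdf_paths]
        for fut in as_completed(futures):  # collect results as they finish, no ordering deadlock
            yield fut.result()
```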
## Configuration
### config.py
```python
# Database configuration
DATABASE = {
    'host': '192.168.68.31',
    'port': 5432,
    'database': 'docmaster',
    'user': 'docmaster',
    'password': '******',
}

# Path configuration
PATHS = {
    'csv_dir': '~/invoice-data/structured_data',
    'pdf_dir': '~/invoice-data/raw_pdfs',
    'output_dir': '~/invoice-data/dataset',
}
```
## CLI Reference
### autolabel
```bash
python -m src.cli.autolabel [OPTIONS]

Options:
  --csv, -c        CSV file path (supports glob)
  --pdf-dir, -p    PDF directory
  --output, -o     output directory
  --workers, -w    worker count in single-pool mode (default: 4)
  --dual-pool      enable dual-pool mode
  --cpu-workers    CPU pool workers (default: 3)
  --gpu-workers    GPU pool workers (default: 1)
  --dpi            render DPI (default: 150)
  --limit, -l      limit the number of documents processed
```
### train
```bash
python -m src.cli.train [OPTIONS]

Options:
  --model, -m      base model path
  --epochs, -e     number of epochs (default: 100)
  --batch, -b      batch size (default: 16)
  --imgsz          image size (default: 1280)
  --dpi            PDF render DPI (default: 150)
  --name           run name
  --limit          limit the number of documents
```
### infer
```bash
python -m src.cli.infer [OPTIONS]

Options:
  --model, -m      model path
  --input, -i      input PDF/image
  --output, -o     output JSON path
  --confidence     confidence threshold (default: 0.5)
  --dpi            render DPI (default: 150, must match the training DPI)
  --gpu            use GPU
```
### serve
```bash
python run_server.py [OPTIONS]

Options:
  --host           bind address (default: 0.0.0.0)
  --port           port (default: 8000)
  --model, -m      model path
  --confidence     confidence threshold (default: 0.3)
  --dpi            render DPI (default: 150)
  --no-gpu         disable GPU
  --reload         auto-reload in development
  --debug          debug mode
```
## Python API
```python
from inference.pipeline import InferencePipeline

# Initialize
pipeline = InferencePipeline(
@@ -559,41 +291,25 @@ pipeline = InferencePipeline(
# Process a PDF
result = pipeline.process_pdf('invoice.pdf')

# Process an image
result = pipeline.process_image('invoice.png')

# Read the results
print(result.fields)
# {'InvoiceNumber': '12345', 'Amount': '1234.56', 'payment_line': '# 94228110015950070 # > 48666036#14#', 'customer_number': 'UMJ 436-R', ...}
print(result.confidence)
# {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
print(result.to_json())  # JSON output

# Cross-validation
if result.cross_validation:
    print(f"OCR match: {result.cross_validation.ocr_match}")
    print(f"Amount match: {result.cross_validation.amount_match}")
    print(f"Details: {result.cross_validation.details}")
```
### Using the Unified Parsers
```python
from inference.pipeline.payment_line_parser import PaymentLineParser
from inference.pipeline.customer_number_parser import CustomerNumberParser

# Parse a payment line
parser = PaymentLineParser()
result = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
print(f"OCR: {result.ocr_number}")
print(f"Amount: {result.amount}")
print(f"Account: {result.account_number}")

# Parse a customer number
parser = CustomerNumberParser()
@@ -601,156 +317,38 @@ result = parser.parse("Said, Shakar Umj 436-R Billo")
print(f"Customer Number: {result}") # "UMJ 436-R" print(f"Customer Number: {result}") # "UMJ 436-R"
``` ```
## DPI Configuration
Every component uses **150 DPI**. The DPI must match between training and inference.
| Component | Configured in |
|------|---------|
| Global constant | `packages/shared/shared/config.py` -> `DEFAULT_DPI = 150` |
| Web inference | `packages/inference/inference/web/config.py` -> `ModelConfig.dpi` |
| CLI inference | `python -m inference.cli.infer --dpi 150` |
| Auto-labeling | `packages/shared/shared/config.py` -> `AUTOLABEL['dpi']` |
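The reason the DPI must match is that the detector learns box sizes in rendered-pixel space: a PDF point is 1/72 inch, so a region that is `w` points wide becomes `w * dpi / 72` pixels. A short illustration (plain arithmetic, not project code):

```python
def points_to_pixels(w_points: float, dpi: int) -> float:
    # PDF coordinates are in points (1/72 inch); rendering at `dpi` scales them.
    return w_points * dpi / 72

box_width_points = 200
print(points_to_pixels(box_width_points, 150))  # ~416.7 px, the scale the model was trained at
print(points_to_pixels(box_width_points, 300))  # ~833.3 px, twice as large -> detections degrade
```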
## Database Architecture
| Database | Purpose | Contents |
|--------|------|----------|
| **PostgreSQL** | labeling results | `documents`, `field_results`, `training_tasks` |
| **SQLite** (AdminDB) | web application | document management, annotation editing, user auth |
## Tests
### Test Statistics
| Metric | Value |
|------|------|
| **Total tests** | 688 |
| **Pass rate** | 100% |
| **Overall coverage** | 37% |
### Coverage of Key Modules
| Module | Coverage | Tests |
|------|--------|--------|
| `machine_code_parser.py` | 65% | 79 |
| `payment_line_parser.py` | 85% | 45 |
| `customer_number_parser.py` | 90% | 32 |
### Running the Tests
```bash
# Run all tests
DB_PASSWORD=xxx pytest tests/ -q

# Run with coverage
DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing

# Run a specific module's tests
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/ocr/test_machine_code_parser.py -v"
```
| Metric | Value |
|------|------|
| **Total tests** | 922 |
| **Pass rate** | 100% |

### Test Layout
```
tests/
├── ocr/
│   ├── test_machine_code_parser.py      # payment-line parsing (79 tests)
│   └── test_ocr_engine.py               # OCR engine tests
├── inference/
│   ├── test_payment_line_parser.py      # payment-line parser
│   └── test_customer_number_parser.py   # customer-number parser
├── normalize/
│   └── test_normalizers.py              # field normalization
└── utils/
    └── test_validators.py               # field validation
```
## Development Status
**Completed**:
- [x] Auto-labeling for text-layer PDFs
- [x] OCR auto-labeling for scanned PDFs
- [x] Multi-strategy field matching (exact/substring/normalized)
- [x] PostgreSQL storage (resume support)
- [x] Signal handling and timeout protection
- [x] YOLO training (93.5% mAP@0.5, 10 fields)
- [x] Inference pipeline
- [x] Field normalization and validation
- [x] Web application (FastAPI + REST API)
- [x] Incremental training support
- [x] Memory-optimized training (--low-memory, --resume)
- [x] Payment line parser (unified module)
- [x] Customer number parser (unified module)
- [x] Payment line cross-validation (OCR, Amount, Account)
- [x] Document type detection (invoice/letter)
- [x] Unit test coverage (688 tests, 37% coverage)

**In progress**:
- [ ] Finish labeling all 25,000+ documents
- [ ] Multi-source fusion enhancements
- [ ] OCR error-correction integration
- [ ] Raise test coverage to 60%+

**Planned**:
- [ ] Table (line item) extraction
- [ ] Quantized model deployment (ONNX/TensorRT)
- [ ] Extended multi-language support

## Key Technical Features
### 1. Payment Line Cross-Validation
The payment_line on Swedish invoices carries the complete payment information: OCR reference number, amount, and account. A cross-validation mechanism checks it:
```
Payment Line: # 94228110015950070 # 15658 00 8 > 48666036#14#
                     ↓               ↓              ↓
               OCR Number          Amount     Bankgiro Account
```
**Validation flow** (see the sketch below):
1. Extract OCR, Amount, and Account from the payment_line
2. Compare them against the individually detected fields
3. **payment_line wins** - on any mismatch, the payment_line value is used
4. Return the validation result and details

**Benefits**:
- Better accuracy (the payment_line is machine-readable and more reliable)
- Surfaces OCR or detection errors
- Provides a confidence signal for data quality
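A compact sketch of that check (field and attribute names follow the README; the helper itself is illustrative, not the project's extractor code):

```python
# Illustrative cross-check: compare payment_line-derived values with detected fields,
# and let the payment_line win on conflict.
def cross_validate(parsed, detected: dict) -> dict:
    checks = {
        "ocr_match": parsed.ocr_number == detected.get("OCR"),
        "amount_match": parsed.amount == detected.get("Amount"),
        "account_match": parsed.account_number == detected.get("Bankgiro"),
    }
    resolved = dict(detected)
    resolved.update(OCR=parsed.ocr_number, Amount=parsed.amount, Bankgiro=parsed.account_number)
    return {"is_valid": all(checks.values()), **checks, "fields": resolved}
```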
### 2. Unified Parser Architecture
Complex fields are handled by dedicated parser modules:

**PaymentLineParser**:
- Parses the standard Swedish payment-line format
- Extracts OCR, Amount (Kronor + Öre), and Account + check digits
- Supports several format variants

**CustomerNumberParser**:
- Supports multiple Swedish customer-number formats (`UMJ 436-R`, `JTY 576-3`, `FFL 019N`)
- Extracts from mixed text, e.g. a customer number embedded in an address line (illustrated below)
- Case-insensitive, with normalized uppercase output

**Benefits**:
- Modular, testable code
- Easy to extend with new formats
- One parsing path, less duplicated code
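As a rough illustration of the kind of pattern this implies (an assumed regex, not the project's actual parser, and normalized here to an uppercase hyphenated form):

```python
import re

# Matches shapes like "UMJ 436-R", "JTY 576-3", "FFL 019N" inside mixed text.
# Purely illustrative -- the real CustomerNumberParser handles more variants.
CUSTOMER_NO = re.compile(r"\b([A-Za-z]{3})\s?(\d{3})\s?-?\s?([A-Za-z0-9])\b")

def extract_customer_number(text: str) -> str | None:
    m = CUSTOMER_NO.search(text)
    if not m:
        return None
    prefix, digits, suffix = (g.upper() for g in m.groups())
    return f"{prefix} {digits}-{suffix}"

print(extract_customer_number("Said, Shakar Umj 436-R Billo"))  # UMJ 436-R
```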
### 3. Automatic Document Type Detection
The document type is inferred from the payment_line field:
- **invoice**: documents that contain a payment_line
- **letter**: documents without one

This helps downstream systems route documents to the right workflow.

### 4. Low-Memory Training
Training is supported in memory-constrained environments:
```bash
python -m src.cli.train --low-memory
```
It automatically adjusts:
- batch size: 16 → 8
- workers: 8 → 4
- cache: disabled
- recommended when GPU memory < 8 GB or system RAM < 16 GB

### 5. Resumable Training
If training is interrupted, it can resume from a checkpoint:
```bash
python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.pt
```
## Tech Stack
@@ -762,32 +360,7 @@ python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.
| **Database** | PostgreSQL + psycopg2 |
| **Web framework** | FastAPI + Uvicorn |
| **Deep learning** | PyTorch + CUDA 12.x |
| **Deployment** | Docker + Azure ACI (training) / App Service (inference) |

## FAQ
**Q: Why does this have to run inside WSL?**
A: PaddleOCR and some dependencies have compatibility problems on native Windows. WSL provides a full Linux environment where all dependencies work.

**Q: What if training hits an OOM (out-of-memory) error?**
A: Use `--low-memory` mode, or tune `--batch` and `--workers` manually.

**Q: What happens when the payment_line and the individually detected fields disagree?**
A: By default the payment_line value wins, since the payment_line is machine-readable and usually more accurate. The validation outcome is recorded in the `cross_validation` field.

**Q: How do I add a new field type?**
A:
1. Add the field definition in `src/inference/constants.py`
2. Add a normalization method in `field_extractor.py`
3. Regenerate the annotation data
4. Retrain the model from scratch

**Q: Can I train on CPU?**
A: Yes, but it is very slow (10-50x slower). A GPU is strongly recommended.

## License

docker-compose.yml (new file)

@@ -0,0 +1,60 @@
version: "3.8"
services:
postgres:
image: postgres:15
environment:
POSTGRES_DB: docmaster
POSTGRES_USER: docmaster
POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword}
ports:
- "5432:5432"
volumes:
- pgdata:/var/lib/postgresql/data
- ./migrations:/docker-entrypoint-initdb.d
inference:
build:
context: .
dockerfile: packages/inference/Dockerfile
ports:
- "8000:8000"
environment:
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=docmaster
- DB_USER=docmaster
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
- MODEL_PATH=/app/models/best.pt
volumes:
- ./models:/app/models
depends_on:
- postgres
training:
build:
context: .
dockerfile: packages/training/Dockerfile
environment:
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=docmaster
- DB_USER=docmaster
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
volumes:
- ./models:/app/models
- ./temp:/app/temp
depends_on:
- postgres
# Override CMD for local dev polling mode
command: ["python", "run_training.py", "--poll", "--poll-interval", "30"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
volumes:
pgdata:

docs/training-flow.mmd (new file)

@@ -0,0 +1,54 @@
flowchart TD
A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.]
B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD]
C --> D[Load Data from DB\nsrc/yolo/db_dataset.py]
D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs]
D --> D2[Batch load field_results\nfrom database - batch 500]
D1 --> E[Create DBYOLODataset]
D2 --> E
E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42]
F --> G[Export to YOLO Format]
G --> G1[Copy images to\ntrain/val/test dirs]
G --> G2[Generate .txt labels\nclass x_center y_center w h]
G --> G3[Generate dataset.yaml\n+ classes.txt]
G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1]
G1 --> H{--export-only?}
G2 --> H
G3 --> H
G4 --> H
H -- Yes --> Z[Done - Dataset exported]
H -- No --> I[Load YOLO Model]
I --> I1{--resume?}
I1 -- Yes --> I2[Load last.pt checkpoint]
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
I2 --> J[Configure Training]
I3 --> J
J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup]
J --> J2[imgsz=1280, pretrained=True]
J1 --> K[model.train\nUltralytics Training Loop]
J2 --> K
K --> L[Training Outputs\nruns/train/name/]
L --> L1[weights/best.pt\nweights/last.pt]
L --> L2[results.csv + results.png\nTraining curves]
L --> L3[PR curves, F1 curves\nConfusion matrix]
L1 --> M[Test Set Validation\nmodel.val split=test]
M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95]
N --> O[Close DB Connection]
style A fill:#4a90d9,color:#fff
style K fill:#e67e22,color:#fff
style N fill:#27ae60,color:#fff
style Z fill:#95a5a6,color:#fff

File diff suppressed because it is too large.


@@ -6,27 +6,36 @@
"scripts": { "scripts": {
"dev": "vite", "dev": "vite",
"build": "vite build", "build": "vite build",
"preview": "vite preview" "preview": "vite preview",
"test": "vitest run",
"test:watch": "vitest",
"test:coverage": "vitest run --coverage"
}, },
"dependencies": { "dependencies": {
"@tanstack/react-query": "^5.20.0",
"axios": "^1.6.7",
"clsx": "^2.1.0",
"date-fns": "^3.3.0",
"lucide-react": "^0.563.0",
"react": "^19.2.3", "react": "^19.2.3",
"react-dom": "^19.2.3", "react-dom": "^19.2.3",
"lucide-react": "^0.563.0",
"recharts": "^3.7.0",
"axios": "^1.6.7",
"react-router-dom": "^6.22.0", "react-router-dom": "^6.22.0",
"zustand": "^4.5.0", "recharts": "^3.7.0",
"@tanstack/react-query": "^5.20.0", "zustand": "^4.5.0"
"date-fns": "^3.3.0",
"clsx": "^2.1.0"
}, },
"devDependencies": { "devDependencies": {
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.2",
"@testing-library/user-event": "^14.6.1",
"@types/node": "^22.14.0", "@types/node": "^22.14.0",
"@vitejs/plugin-react": "^5.0.0", "@vitejs/plugin-react": "^5.0.0",
"@vitest/coverage-v8": "^4.0.18",
"autoprefixer": "^10.4.17",
"jsdom": "^27.4.0",
"postcss": "^8.4.35",
"tailwindcss": "^3.4.1",
"typescript": "~5.8.2", "typescript": "~5.8.2",
"vite": "^6.2.0", "vite": "^6.2.0",
"tailwindcss": "^3.4.1", "vitest": "^4.0.18"
"autoprefixer": "^10.4.17",
"postcss": "^8.4.35"
} }
} }


@@ -0,0 +1,32 @@
import { render, screen } from '@testing-library/react';
import { describe, it, expect } from 'vitest';
import { Badge } from './Badge';
import { DocumentStatus } from '../types';
describe('Badge', () => {
it('renders Exported badge with check icon', () => {
render(<Badge status="Exported" />);
expect(screen.getByText('Exported')).toBeInTheDocument();
});
it('renders Pending status', () => {
render(<Badge status={DocumentStatus.PENDING} />);
expect(screen.getByText('Pending')).toBeInTheDocument();
});
it('renders Verified status', () => {
render(<Badge status={DocumentStatus.VERIFIED} />);
expect(screen.getByText('Verified')).toBeInTheDocument();
});
it('renders Labeled status', () => {
render(<Badge status={DocumentStatus.LABELED} />);
expect(screen.getByText('Labeled')).toBeInTheDocument();
});
it('renders Partial status with warning indicator', () => {
render(<Badge status={DocumentStatus.PARTIAL} />);
expect(screen.getByText('Partial')).toBeInTheDocument();
expect(screen.getByText('!')).toBeInTheDocument();
});
});


@@ -0,0 +1,38 @@
import { render, screen } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { describe, it, expect, vi } from 'vitest';
import { Button } from './Button';
describe('Button', () => {
it('renders children text', () => {
render(<Button>Click me</Button>);
expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
});
it('calls onClick handler', async () => {
const user = userEvent.setup();
const onClick = vi.fn();
render(<Button onClick={onClick}>Click</Button>);
await user.click(screen.getByRole('button'));
expect(onClick).toHaveBeenCalledOnce();
});
it('is disabled when disabled prop is set', () => {
render(<Button disabled>Disabled</Button>);
expect(screen.getByRole('button')).toBeDisabled();
});
it('applies variant styles', () => {
const { rerender } = render(<Button variant="primary">Primary</Button>);
const btn = screen.getByRole('button');
expect(btn.className).toContain('bg-warm-text-secondary');
rerender(<Button variant="secondary">Secondary</Button>);
expect(screen.getByRole('button').className).toContain('border');
});
it('applies size styles', () => {
render(<Button size="sm">Small</Button>);
expect(screen.getByRole('button').className).toContain('h-8');
});
});

frontend/tests/setup.ts (new file)

@@ -0,0 +1 @@
import '@testing-library/jest-dom';

frontend/vitest.config.ts (new file)

@@ -0,0 +1,19 @@
/// <reference types="vitest/config" />
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
export default defineConfig({
plugins: [react()],
test: {
globals: true,
environment: 'jsdom',
setupFiles: ['./tests/setup.ts'],
include: ['src/**/*.test.{ts,tsx}', 'tests/**/*.test.{ts,tsx}'],
coverage: {
provider: 'v8',
reporter: ['text', 'lcov'],
include: ['src/**/*.{ts,tsx}'],
exclude: ['src/**/*.test.{ts,tsx}', 'src/main.tsx'],
},
},
});


@@ -0,0 +1,18 @@
-- Training tasks table for async training job management.
-- Inference service writes pending tasks; training service polls and executes.
CREATE TABLE IF NOT EXISTS training_tasks (
task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
status VARCHAR(20) NOT NULL DEFAULT 'pending',
config JSONB,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
scheduled_at TIMESTAMP WITH TIME ZONE,
started_at TIMESTAMP WITH TIME ZONE,
completed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT,
model_path TEXT,
metrics JSONB
);
CREATE INDEX IF NOT EXISTS idx_training_tasks_status ON training_tasks(status);
CREATE INDEX IF NOT EXISTS idx_training_tasks_created ON training_tasks(created_at);
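One way the training service could claim a pending row from this table when polling (a psycopg2 sketch under that assumption; the actual query in run_training.py may differ):

```python
import psycopg2

CLAIM_SQL = """
UPDATE training_tasks
SET status = 'running', started_at = NOW()
WHERE task_id = (
    SELECT task_id FROM training_tasks
    WHERE status = 'pending'
    ORDER BY created_at
    FOR UPDATE SKIP LOCKED
    LIMIT 1
)
RETURNING task_id, config;
"""

def claim_next_task(conn) -> tuple | None:
    # Atomically flips one pending task to running so two pollers never grab the same row.
    with conn, conn.cursor() as cur:
        cur.execute(CLAIM_SQL)
        return cur.fetchone()
```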


@@ -0,0 +1,39 @@
-- Training Datasets Management
-- Tracks dataset-document relationships and train/val/test splits
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
-- Add dataset_id to training_tasks
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
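The `train_ratio` / `val_ratio` / `seed` columns imply a deterministic document-level split along these lines (an illustrative sketch, not the service's code):

```python
import random

def assign_splits(document_ids: list[str], train_ratio=0.8, val_ratio=0.1, seed=42) -> dict[str, str]:
    # Deterministic document-level split; whatever remains after train+val becomes the test set.
    ids = sorted(document_ids)
    random.Random(seed).shuffle(ids)
    n_train = int(len(ids) * train_ratio)
    n_val = int(len(ids) * val_ratio)
    return {
        doc_id: ("train" if i < n_train else "val" if i < n_train + n_val else "test")
        for i, doc_id in enumerate(ids)
    }
```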


@@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference
# Copy frontend (if needed)
COPY frontend /app/frontend
WORKDIR /app/packages/inference
EXPOSE 8000
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]


@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""
import logging
import os
logger = logging.getLogger(__name__)
# Azure SDK is optional; only needed if using ACI trigger
try:
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
EnvironmentVariable,
GpuResource,
ResourceRequests,
ResourceRequirements,
)
_AZURE_SDK_AVAILABLE = True
except ImportError:
_AZURE_SDK_AVAILABLE = False
def start_training_container(task_id: str) -> str | None:
"""
Start an Azure Container Instance for a training task.
Returns the container group name if successful, None otherwise.
Requires environment variables:
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
"""
if not _AZURE_SDK_AVAILABLE:
logger.warning(
"Azure SDK not installed. Install azure-mgmt-containerinstance "
"and azure-identity to use ACI trigger."
)
return None
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
image = os.environ.get(
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
)
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
location = os.environ.get("AZURE_LOCATION", "eastus")
if not subscription_id:
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
return None
credential = DefaultAzureCredential()
client = ContainerInstanceManagementClient(credential, subscription_id)
container_name = f"training-{task_id[:8]}"
env_vars = [
EnvironmentVariable(name="TASK_ID", value=task_id),
]
# Pass DB connection securely
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
val = os.environ.get(var, "")
if val:
env_vars.append(EnvironmentVariable(name=var, value=val))
db_password = os.environ.get("DB_PASSWORD", "")
if db_password:
env_vars.append(
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
)
container = Container(
name=container_name,
image=image,
resources=ResourceRequirements(
requests=ResourceRequests(
cpu=4,
memory_in_gb=16,
gpu=GpuResource(count=1, sku=gpu_sku),
)
),
environment_variables=env_vars,
command=[
"python",
"run_training.py",
"--task-id",
task_id,
],
)
group = ContainerGroup(
location=location,
containers=[container],
os_type="Linux",
restart_policy="Never",
)
logger.info("Creating ACI container group: %s", container_name)
client.container_groups.begin_create_or_update(
resource_group, container_name, group
)
return container_name
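For a local smoke test, the function above could be exercised directly once the Azure variables are set; everything below is a placeholder, not real deployment values:

```python
import os

# Placeholder values; real subscription/resource names are deployment-specific.
os.environ["AZURE_SUBSCRIPTION_ID"] = "00000000-0000-0000-0000-000000000000"
os.environ["AZURE_RESOURCE_GROUP"] = "invoice-training-rg"
os.environ["AZURE_ACR_IMAGE"] = "youracr.azurecr.io/invoice-training:latest"

group_name = start_training_container("3be53fd7-0000-0000-0000-000000000000")
print(group_name)  # e.g. "training-3be53fd7", or None if the Azure SDK is not installed
```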


@@ -10,8 +10,7 @@ import json
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
def main():
@@ -91,7 +90,7 @@ def main():
print(f"Processing {len(pdf_files)} PDF file(s)") print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}") print(f"Model: {model_path}")
from ..inference import InferencePipeline from inference.pipeline import InferencePipeline
# Initialize pipeline # Initialize pipeline
pipeline = InferencePipeline( pipeline = InferencePipeline(


@@ -13,9 +13,8 @@ from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent.parent
from shared.config import DEFAULT_DPI
def setup_logging(debug: bool = False) -> None:
@@ -121,7 +120,7 @@ def main() -> None:
logger.info("=" * 60) logger.info("=" * 60)
# Create config # Create config
from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
config = AppConfig( config = AppConfig(
model=ModelConfig( model=ModelConfig(
@@ -142,7 +141,7 @@ def main() -> None:
# Create and run app
import uvicorn
from inference.web.app import create_app
app = create_app(config)


@@ -12,8 +12,8 @@ from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from inference.data.database import get_session_context
from inference.data.admin_models import (
AdminToken,
AdminDocument,
AdminAnnotation,
@@ -23,6 +23,8 @@ from src.data.admin_models import (
BatchUploadFile,
TrainingDocumentLink,
AnnotationHistory,
TrainingDataset,
DatasetDocument,
)
logger = logging.getLogger(__name__)
@@ -174,7 +176,7 @@ class AdminDB:
# For has_annotations filter, we need to join with annotations
if has_annotations is not None:
from inference.data.admin_models import AdminAnnotation
if has_annotations:
# Documents WITH annotations
@@ -200,7 +202,7 @@ class AdminDB:
# Apply has_annotations filter
if has_annotations is not None:
from inference.data.admin_models import AdminAnnotation
if has_annotations:
statement = (
@@ -456,6 +458,7 @@ class AdminDB:
scheduled_at: datetime | None = None,
cron_expression: str | None = None,
is_recurring: bool = False,
dataset_id: str | None = None,
) -> str:
"""Create a new training task."""
with get_session_context() as session:
@@ -469,6 +472,7 @@ class AdminDB:
cron_expression=cron_expression,
is_recurring=is_recurring,
status="scheduled" if scheduled_at else "pending",
dataset_id=dataset_id,
)
session.add(task)
session.flush()
@@ -1154,3 +1158,159 @@ class AdminDB:
session.refresh(annotation)
session.expunge(annotation)
return annotation
# ==========================================================================
# Training Dataset Operations
# ==========================================================================
def create_dataset(
self,
name: str,
description: str | None = None,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
) -> TrainingDataset:
"""Create a new training dataset."""
with get_session_context() as session:
dataset = TrainingDataset(
name=name,
description=description,
train_ratio=train_ratio,
val_ratio=val_ratio,
seed=seed,
)
session.add(dataset)
session.commit()
session.refresh(dataset)
session.expunge(dataset)
return dataset
def get_dataset(self, dataset_id: str | UUID) -> TrainingDataset | None:
"""Get a dataset by ID."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if dataset:
session.expunge(dataset)
return dataset
def get_datasets(
self,
status: str | None = None,
limit: int = 20,
offset: int = 0,
) -> tuple[list[TrainingDataset], int]:
"""List datasets with optional status filter."""
with get_session_context() as session:
query = select(TrainingDataset)
count_query = select(func.count()).select_from(TrainingDataset)
if status:
query = query.where(TrainingDataset.status == status)
count_query = count_query.where(TrainingDataset.status == status)
total = session.exec(count_query).one()
datasets = session.exec(
query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit)
).all()
for d in datasets:
session.expunge(d)
return list(datasets), total
def update_dataset_status(
self,
dataset_id: str | UUID,
status: str,
error_message: str | None = None,
total_documents: int | None = None,
total_images: int | None = None,
total_annotations: int | None = None,
dataset_path: str | None = None,
) -> None:
"""Update dataset status and optional totals."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.status = status
dataset.updated_at = datetime.utcnow()
if error_message is not None:
dataset.error_message = error_message
if total_documents is not None:
dataset.total_documents = total_documents
if total_images is not None:
dataset.total_images = total_images
if total_annotations is not None:
dataset.total_annotations = total_annotations
if dataset_path is not None:
dataset.dataset_path = dataset_path
session.add(dataset)
session.commit()
def add_dataset_documents(
self,
dataset_id: str | UUID,
documents: list[dict[str, Any]],
) -> None:
"""Batch insert documents into a dataset.
Each dict: {document_id, split, page_count, annotation_count}
"""
with get_session_context() as session:
for doc in documents:
dd = DatasetDocument(
dataset_id=UUID(str(dataset_id)),
document_id=UUID(str(doc["document_id"])),
split=doc["split"],
page_count=doc.get("page_count", 0),
annotation_count=doc.get("annotation_count", 0),
)
session.add(dd)
session.commit()
def get_dataset_documents(
self, dataset_id: str | UUID
) -> list[DatasetDocument]:
"""Get all documents in a dataset."""
with get_session_context() as session:
results = session.exec(
select(DatasetDocument)
.where(DatasetDocument.dataset_id == UUID(str(dataset_id)))
).all()
for r in results:
session.expunge(r)
return list(results)
def get_documents_by_ids(
self, document_ids: list[str]
) -> list[AdminDocument]:
"""Get documents by list of IDs."""
with get_session_context() as session:
uuids = [UUID(str(did)) for did in document_ids]
results = session.exec(
select(AdminDocument).where(AdminDocument.document_id.in_(uuids))
).all()
for r in results:
session.expunge(r)
return list(results)
def get_annotations_for_document(
self, document_id: str | UUID
) -> list[AdminAnnotation]:
"""Get all annotations for a document."""
with get_session_context() as session:
results = session.exec(
select(AdminAnnotation)
.where(AdminAnnotation.document_id == UUID(str(document_id)))
).all()
for r in results:
session.expunge(r)
return list(results)
def delete_dataset(self, dataset_id: str | UUID) -> bool:
"""Delete a dataset and its document links (CASCADE)."""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return False
session.delete(dataset)
session.commit()
return True
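Taken together, these methods would support a dataset-build flow roughly like this (a sketch; it assumes `AdminDB()` can be constructed with defaults and that split assignment happens elsewhere):

```python
db = AdminDB()

# 1. Register the dataset shell.
ds = db.create_dataset(name="invoices-2026-01", train_ratio=0.8, val_ratio=0.1, seed=42)

# 2. Attach documents with their split assignment.
document_ids = [...]  # IDs of labeled admin_documents (placeholder)
docs = db.get_documents_by_ids(document_ids)
db.add_dataset_documents(ds.dataset_id, [
    {
        "document_id": d.document_id,
        "split": "train",
        "annotation_count": len(db.get_annotations_for_document(d.document_id)),
    }
    for d in docs
])

# 3. Mark it ready once the YOLO export has been written.
db.update_dataset_status(ds.dataset_id, status="ready",
                         dataset_path="/data/datasets/invoices-2026-01")
```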


@@ -131,6 +131,7 @@ class TrainingTask(SQLModel, table=True):
# Status: pending, scheduled, running, completed, failed, cancelled
task_type: str = Field(default="train", max_length=20)
# Task type: train, finetune
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
# Training configuration
config: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
# Schedule settings
@@ -225,6 +226,42 @@ class BatchUploadFile(SQLModel, table=True):
# =============================================================================
class TrainingDataset(SQLModel, table=True):
"""Training dataset containing selected documents with train/val/test splits."""
__tablename__ = "training_datasets"
dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str = Field(max_length=255)
description: str | None = Field(default=None)
status: str = Field(default="building", max_length=20, index=True)
# Status: building, ready, training, archived, failed
train_ratio: float = Field(default=0.8)
val_ratio: float = Field(default=0.1)
seed: int = Field(default=42)
total_documents: int = Field(default=0)
total_images: int = Field(default=0)
total_annotations: int = Field(default=0)
dataset_path: str | None = Field(default=None, max_length=512)
error_message: str | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class DatasetDocument(SQLModel, table=True):
"""Junction table linking datasets to documents with split assignment."""
__tablename__ = "dataset_documents"
id: UUID = Field(default_factory=uuid4, primary_key=True)
dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
split: str = Field(max_length=10) # train, val, test
page_count: int = Field(default=0)
annotation_count: int = Field(default=0)
created_at: datetime = Field(default_factory=datetime.utcnow)
class TrainingDocumentLink(SQLModel, table=True):
"""Junction table linking training tasks to documents."""
@@ -336,4 +373,35 @@ class TrainingTaskRead(SQLModel):
error_message: str | None
result_metrics: dict[str, Any] | None
model_path: str | None
dataset_id: UUID | None
created_at: datetime
class TrainingDatasetRead(SQLModel):
"""Training dataset response model."""
dataset_id: UUID
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
created_at: datetime
updated_at: datetime
class DatasetDocumentRead(SQLModel):
"""Dataset document response model."""
id: UUID
dataset_id: UUID
document_id: UUID
split: str
page_count: int
annotation_count: int


@@ -12,8 +12,8 @@ from uuid import UUID
from sqlalchemy import func, text
from sqlmodel import Session, select
from inference.data.database import get_session_context, create_db_and_tables, close_engine
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent
logger = logging.getLogger(__name__)


@@ -13,8 +13,7 @@ from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine
import sys
from shared.config import get_db_connection_string
logger = logging.getLogger(__name__)
@@ -52,8 +51,8 @@ def get_engine():
def create_db_and_tables() -> None:
"""Create all database tables."""
from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent  # noqa: F401
from inference.data.admin_models import (  # noqa: F401
AdminToken,
AdminDocument,
AdminAnnotation,


@@ -92,7 +92,7 @@ constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.
Example:
from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD
detector = YOLODetector(
model_path="model.pt",


@@ -17,7 +17,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
from shared.exceptions import CustomerNumberParseError
@dataclass


@@ -4,7 +4,7 @@ Field Extractor Module
Extracts and validates field values from detected regions.
This module is used during inference to extract values from OCR text.
It uses shared utilities from shared.utils for text cleaning and validation.
Enhanced features:
- Multi-source fusion with confidence weighting
@@ -24,10 +24,10 @@ from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD
# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner
from shared.utils.validators import FieldValidators
from shared.utils.fuzzy_matcher import FuzzyMatcher
from shared.utils.ocr_corrections import OCRCorrections
# Import new unified parsers
from .payment_line_parser import PaymentLineParser
@@ -104,7 +104,7 @@ class FieldExtractor:
def ocr_engine(self):
"""Lazy-load OCR engine only when needed."""
if self._ocr_engine is None:
from shared.ocr import OCREngine
self._ocr_engine = OCREngine(lang=self.ocr_lang)
return self._ocr_engine


@@ -21,7 +21,7 @@ import logging
from dataclasses import dataclass
from typing import Optional
from shared.exceptions import PaymentLineParseError
@dataclass


@@ -144,7 +144,7 @@ class InferencePipeline:
Returns:
InferenceResult with extracted fields
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
@@ -381,8 +381,8 @@ class InferencePipeline:
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from shared.pdf.renderer import render_pdf_to_images
from shared.ocr import OCREngine
from PIL import Image
import io
import numpy as np


@@ -189,7 +189,7 @@ class YOLODetector:
Returns:
Dict mapping page number to list of detections
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io


@@ -16,7 +16,7 @@ from datetime import datetime
import psycopg2
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@dataclass


@@ -0,0 +1,8 @@
"""
Backward compatibility shim for admin_routes.py
DEPRECATED: Import from inference.web.api.v1.admin.documents instead.
"""
from inference.web.api.v1.admin.documents import *
__all__ = ["create_admin_router"]


@@ -0,0 +1,19 @@
"""
Admin API v1
Document management, annotations, and training endpoints.
"""
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.api.v1.admin.auth import create_auth_router
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.api.v1.admin.training import create_training_router
__all__ = [
"create_annotation_router",
"create_auth_router",
"create_documents_router",
"create_locks_router",
"create_training_router",
]
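A rough sketch of wiring these factories into a FastAPI app under an /api/v1 prefix; the prefix, the default-constructed StorageConfig, and the choice of routers are illustrative assumptions (the application factory diff further down shows the real wiring):

# Illustrative composition only.
from fastapi import FastAPI

from inference.web.api.v1.admin import create_documents_router, create_training_router
from inference.web.config import StorageConfig

app = FastAPI()
storage = StorageConfig()  # assumes default construction is acceptable
app.include_router(create_documents_router(storage), prefix="/api/v1")
app.include_router(create_training_router(), prefix="/api/v1")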

View File

@@ -12,11 +12,11 @@ from uuid import UUID
 from fastapi import APIRouter, HTTPException, Query
 from fastapi.responses import FileResponse
-from src.data.admin_db import AdminDB
-from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
-from src.web.core.auth import AdminTokenDep, AdminDBDep
-from src.web.services.autolabel import get_auto_label_service
-from src.web.schemas.admin import (
+from inference.data.admin_db import AdminDB
+from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
+from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.services.autolabel import get_auto_label_service
+from inference.web.schemas.admin import (
     AnnotationCreate,
     AnnotationItem,
     AnnotationListResponse,
@@ -31,7 +31,7 @@ from src.web.schemas.admin import (
     AutoLabelResponse,
     BoundingBox,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 logger = logging.getLogger(__name__)

View File

@@ -10,12 +10,12 @@ from datetime import datetime, timedelta
 from fastapi import APIRouter
-from src.web.core.auth import AdminTokenDep, AdminDBDep
-from src.web.schemas.admin import (
+from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.schemas.admin import (
     AdminTokenCreate,
     AdminTokenResponse,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 logger = logging.getLogger(__name__)

View File

@@ -11,9 +11,9 @@ from uuid import UUID
 from fastapi import APIRouter, File, HTTPException, Query, UploadFile
-from src.web.config import DEFAULT_DPI, StorageConfig
-from src.web.core.auth import AdminTokenDep, AdminDBDep
-from src.web.schemas.admin import (
+from inference.web.config import DEFAULT_DPI, StorageConfig
+from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.schemas.admin import (
     AnnotationItem,
     AnnotationSource,
     AutoLabelStatus,
@@ -27,7 +27,7 @@ from src.web.schemas.admin import (
     ModelMetrics,
     TrainingHistoryItem,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 logger = logging.getLogger(__name__)
@@ -142,8 +142,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            raise HTTPException(status_code=500, detail="Failed to save file")
        # Update file path in database
-        from src.data.database import get_session_context
-        from src.data.admin_models import AdminDocument
+        from inference.data.database import get_session_context
+        from inference.data.admin_models import AdminDocument
        with get_session_context() as session:
            doc = session.get(AdminDocument, UUID(document_id))
            if doc:
@@ -520,7 +520,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        # If marking as labeled, save annotations to PostgreSQL DocumentDB
        db_save_result = None
        if status == "labeled":
-            from src.web.services.db_autolabel import save_manual_annotations_to_document_db
+            from inference.web.services.db_autolabel import save_manual_annotations_to_document_db
            # Get all annotations for this document
            annotations = db.get_annotations_for_document(document_id)

View File

@@ -10,12 +10,12 @@ from uuid import UUID
 from fastapi import APIRouter, HTTPException, Query
-from src.web.core.auth import AdminTokenDep, AdminDBDep
-from src.web.schemas.admin import (
+from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.schemas.admin import (
     AnnotationLockRequest,
     AnnotationLockResponse,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,28 @@
"""
Admin Training API Routes
FastAPI endpoints for training task management and scheduling.
"""
from fastapi import APIRouter
from ._utils import _validate_uuid
from .tasks import register_task_routes
from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
def create_training_router() -> APIRouter:
"""Create training API router."""
router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
register_task_routes(router)
register_document_routes(router)
register_export_routes(router)
register_dataset_routes(router)
return router
__all__ = ["create_training_router", "_validate_uuid"]

View File

@@ -0,0 +1,16 @@
"""Shared utilities for training routes."""
from uuid import UUID
from fastapi import HTTPException
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
try:
UUID(value)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid {name} format. Must be a valid UUID.",
)
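A small usage sketch of the helper; an invalid path parameter turns into a 400 before any database lookup. The module path in the import is inferred from the package imports above and is an assumption:

from fastapi import HTTPException
from inference.web.api.v1.admin.training._utils import _validate_uuid  # assumed path

_validate_uuid("4f8b2c1a-7d3e-4b6f-9a21-0c5d8e7f6a43", "task_id")  # valid: returns None

try:
    _validate_uuid("not-a-uuid", "task_id")
except HTTPException as exc:
    # 400 Invalid task_id format. Must be a valid UUID.
    print(exc.status_code, exc.detail)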

View File

@@ -0,0 +1,209 @@
"""Training Dataset Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
DatasetCreateRequest,
DatasetDetailResponse,
DatasetDocumentItem,
DatasetListItem,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingStatus,
TrainingTaskResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_dataset_routes(router: APIRouter) -> None:
"""Register dataset endpoints on the router."""
@router.post(
"/datasets",
response_model=DatasetResponse,
summary="Create training dataset",
description="Create a dataset from selected documents with train/val/test splits.",
)
async def create_dataset(
request: DatasetCreateRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetResponse:
"""Create a training dataset from document IDs."""
from pathlib import Path
from inference.web.services.dataset_builder import DatasetBuilder
dataset = db.create_dataset(
name=request.name,
description=request.description,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
)
builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
try:
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=request.document_ids,
train_ratio=request.train_ratio,
val_ratio=request.val_ratio,
seed=request.seed,
admin_images_dir=Path("data/admin_images"),
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return DatasetResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
status="ready",
message="Dataset created successfully",
)
@router.get(
"/datasets",
response_model=DatasetListResponse,
summary="List datasets",
)
async def list_datasets(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[str | None, Query(description="Filter by status")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 20,
offset: Annotated[int, Query(ge=0)] = 0,
) -> DatasetListResponse:
"""List training datasets."""
datasets, total = db.get_datasets(status=status, limit=limit, offset=offset)
return DatasetListResponse(
total=total,
limit=limit,
offset=offset,
datasets=[
DatasetListItem(
dataset_id=str(d.dataset_id),
name=d.name,
description=d.description,
status=d.status,
total_documents=d.total_documents,
total_images=d.total_images,
total_annotations=d.total_annotations,
created_at=d.created_at,
)
for d in datasets
],
)
@router.get(
"/datasets/{dataset_id}",
response_model=DatasetDetailResponse,
summary="Get dataset detail",
)
async def get_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DatasetDetailResponse:
"""Get dataset details with document list."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
docs = db.get_dataset_documents(dataset_id)
return DatasetDetailResponse(
dataset_id=str(dataset.dataset_id),
name=dataset.name,
description=dataset.description,
status=dataset.status,
train_ratio=dataset.train_ratio,
val_ratio=dataset.val_ratio,
seed=dataset.seed,
total_documents=dataset.total_documents,
total_images=dataset.total_images,
total_annotations=dataset.total_annotations,
dataset_path=dataset.dataset_path,
error_message=dataset.error_message,
documents=[
DatasetDocumentItem(
document_id=str(d.document_id),
split=d.split,
page_count=d.page_count,
annotation_count=d.annotation_count,
)
for d in docs
],
created_at=dataset.created_at,
updated_at=dataset.updated_at,
)
@router.delete(
"/datasets/{dataset_id}",
summary="Delete dataset",
)
async def delete_dataset(
dataset_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> dict:
"""Delete a dataset and its files."""
import shutil
from pathlib import Path
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.dataset_path:
dataset_dir = Path(dataset.dataset_path)
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
db.delete_dataset(dataset_id)
return {"message": "Dataset deleted"}
@router.post(
"/datasets/{dataset_id}/train",
response_model=TrainingTaskResponse,
summary="Start training from dataset",
)
async def train_from_dataset(
dataset_id: str,
request: DatasetTrainRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset."""
_validate_uuid(dataset_id, "dataset_id")
dataset = db.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="Dataset not found")
if dataset.status != "ready":
raise HTTPException(
status_code=400,
detail=f"Dataset is not ready (status: {dataset.status})",
)
config_dict = request.config.model_dump()
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type="train",
config=config_dict,
dataset_id=str(dataset.dataset_id),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.PENDING,
message="Training task created from dataset",
)
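A hedged client-side sketch of the create-then-train flow these endpoints expose, assuming the router is mounted under /api/v1 (as the model download URLs elsewhere in this commit suggest), that admin auth travels in an X-Admin-Token header (the header name is a guess), and that httpx is installed:

# Illustrative only: base URL, header name, and payload values are assumptions.
import httpx

BASE = "http://localhost:8000/api/v1/admin/training"
HEADERS = {"X-Admin-Token": "<token>"}  # actual header name depends on the auth dependency

# 1. Build a dataset from labeled documents (80/10/10 split; the test share is implied).
resp = httpx.post(
    f"{BASE}/datasets",
    json={
        "name": "invoices-2026-01",
        "document_ids": ["<document-uuid>", "<document-uuid>"],
        "train_ratio": 0.8,
        "val_ratio": 0.1,
        "seed": 42,
    },
    headers=HEADERS,
)
dataset_id = resp.json()["dataset_id"]

# 2. Start a training task from the prepared dataset.
httpx.post(
    f"{BASE}/datasets/{dataset_id}/train",
    json={"name": "yolo11n-invoices", "config": {"epochs": 100, "batch_size": 16}},
    headers=HEADERS,
)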

View File

@@ -0,0 +1,211 @@
"""Training Documents and Models Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ModelMetrics,
TrainingDocumentItem,
TrainingDocumentsResponse,
TrainingModelItem,
TrainingModelsResponse,
TrainingStatus,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_document_routes(router: APIRouter) -> None:
"""Register training document and model endpoints on the router."""
@router.get(
"/documents",
response_model=TrainingDocumentsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get documents for training",
description="Get labeled documents available for training with filtering options.",
)
async def get_training_documents(
admin_token: AdminTokenDep,
db: AdminDBDep,
has_annotations: Annotated[
bool,
Query(description="Only include documents with annotations"),
] = True,
min_annotation_count: Annotated[
int | None,
Query(ge=1, description="Minimum annotation count"),
] = None,
exclude_used_in_training: Annotated[
bool,
Query(description="Exclude documents already used in training"),
] = False,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingDocumentsResponse:
"""Get documents available for training."""
documents, total = db.get_documents_for_training(
admin_token=admin_token,
status="labeled",
has_annotations=has_annotations,
min_annotation_count=min_annotation_count,
exclude_used_in_training=exclude_used_in_training,
limit=limit,
offset=offset,
)
items = []
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
sources = {"manual": 0, "auto": 0}
for ann in annotations:
if ann.source in sources:
sources[ann.source] += 1
training_links = db.get_document_training_tasks(doc.document_id)
used_in_training = [str(link.task_id) for link in training_links]
items.append(
TrainingDocumentItem(
document_id=str(doc.document_id),
filename=doc.filename,
annotation_count=len(annotations),
annotation_sources=sources,
used_in_training=used_in_training,
last_modified=doc.updated_at,
)
)
return TrainingDocumentsResponse(
total=total,
limit=limit,
offset=offset,
documents=items,
)
@router.get(
"/models/{task_id}/download",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Model not found"},
},
summary="Download trained model",
description="Download trained model weights file.",
)
async def download_model(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
):
"""Download trained model."""
from fastapi.responses import FileResponse
from pathlib import Path
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if not task.model_path:
raise HTTPException(
status_code=404,
detail="Model file not available for this task",
)
model_path = Path(task.model_path)
if not model_path.exists():
raise HTTPException(
status_code=404,
detail="Model file not found on disk",
)
return FileResponse(
path=str(model_path),
media_type="application/octet-stream",
filename=f"{task.name}_model.pt",
)
@router.get(
"/models",
response_model=TrainingModelsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get trained models",
description="Get list of trained models with metrics and download links.",
)
async def get_training_models(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status (completed, failed, etc.)"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingModelsResponse:
"""Get list of trained models."""
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status if status else "completed",
limit=limit,
offset=offset,
)
items = []
for task in tasks:
metrics = ModelMetrics(
mAP=task.metrics_mAP,
precision=task.metrics_precision,
recall=task.metrics_recall,
)
download_url = None
if task.model_path and task.status == "completed":
download_url = f"/api/v1/admin/training/models/{task.task_id}/download"
items.append(
TrainingModelItem(
task_id=str(task.task_id),
name=task.name,
status=TrainingStatus(task.status),
document_count=task.document_count,
created_at=task.created_at,
completed_at=task.completed_at,
metrics=metrics,
model_path=task.model_path,
download_url=download_url,
)
)
return TrainingModelsResponse(
total=total,
limit=limit,
offset=offset,
models=items,
)
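A shorter sketch for the models listing above, under the same assumptions about base URL and auth header:

import httpx

BASE = "http://localhost:8000/api/v1/admin/training"
HEADERS = {"X-Admin-Token": "<token>"}  # header name is an assumption

models = httpx.get(f"{BASE}/models", params={"limit": 5}, headers=HEADERS).json()
for m in models["models"]:
    print(m["name"], m["status"], m["download_url"])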

View File

@@ -0,0 +1,121 @@
"""Training Export Endpoints."""
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
ExportRequest,
ExportResponse,
)
from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
def register_export_routes(router: APIRouter) -> None:
"""Register export endpoints on the router."""
@router.post(
"/export",
response_model=ExportResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Export annotations",
description="Export annotations in YOLO format for training.",
)
async def export_annotations(
request: ExportRequest,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> ExportResponse:
"""Export annotations for training."""
from pathlib import Path
import shutil
if request.format not in ("yolo", "coco", "voc"):
raise HTTPException(
status_code=400,
detail=f"Unsupported export format: {request.format}",
)
documents = db.get_labeled_documents_for_export(admin_token)
if not documents:
raise HTTPException(
status_code=400,
detail="No labeled documents available for export",
)
export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True)
(export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True)
total_docs = len(documents)
train_count = int(total_docs * request.split_ratio)
train_docs = documents[:train_count]
val_docs = documents[train_count:]
total_images = 0
total_annotations = 0
for split, docs in [("train", train_docs), ("val", val_docs)]:
for doc in docs:
annotations = db.get_annotations_for_document(str(doc.document_id))
if not annotations:
continue
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
if not page_annotations and not request.include_images:
continue
src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
if not src_image.exists():
continue
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
shutil.copy(src_image, dst_image)
total_images += 1
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
f.write(line)
total_annotations += 1
from inference.data.admin_models import FIELD_CLASSES
yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}
train: images/train
val: images/val
nc: {len(FIELD_CLASSES)}
names: {list(FIELD_CLASSES.values())}
"""
(export_dir / "data.yaml").write_text(yaml_content)
return ExportResponse(
status="completed",
export_path=str(export_dir),
total_images=total_images,
total_annotations=total_annotations,
train_count=len(train_docs),
val_count=len(val_docs),
message=f"Exported {total_images} images with {total_annotations} annotations",
)
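For reference, each label file written by the loop above uses the standard YOLO text format, one object per line with coordinates normalized to [0, 1]. A minimal sketch of reading such a line back:

# Parse one YOLO label line of the form "<class_id> <x_center> <y_center> <width> <height>".
def parse_yolo_line(line: str) -> tuple[int, float, float, float, float]:
    class_id, x_center, y_center, width, height = line.split()
    return int(class_id), float(x_center), float(y_center), float(width), float(height)

print(parse_yolo_line("3 0.512345 0.204500 0.310000 0.045000"))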

View File

@@ -0,0 +1,263 @@
"""Training Task Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.schemas.admin import (
TrainingLogItem,
TrainingLogsResponse,
TrainingStatus,
TrainingTaskCreate,
TrainingTaskDetailResponse,
TrainingTaskItem,
TrainingTaskListResponse,
TrainingTaskResponse,
TrainingType,
)
from inference.web.schemas.common import ErrorResponse
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_task_routes(router: APIRouter) -> None:
"""Register training task endpoints on the router."""
@router.post(
"/tasks",
response_model=TrainingTaskResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid request"},
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Create training task",
description="Create a new training task.",
)
async def create_training_task(
request: TrainingTaskCreate,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Create a new training task."""
config_dict = request.config.model_dump() if request.config else {}
task_id = db.create_training_task(
admin_token=admin_token,
name=request.name,
task_type=request.task_type.value,
description=request.description,
config=config_dict,
scheduled_at=request.scheduled_at,
cron_expression=request.cron_expression,
is_recurring=bool(request.cron_expression),
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING,
message="Training task created successfully",
)
@router.get(
"/tasks",
response_model=TrainingTaskListResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="List training tasks",
description="List all training tasks.",
)
async def list_training_tasks(
admin_token: AdminTokenDep,
db: AdminDBDep,
status: Annotated[
str | None,
Query(description="Filter by status"),
] = None,
limit: Annotated[
int,
Query(ge=1, le=100, description="Page size"),
] = 20,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingTaskListResponse:
"""List training tasks."""
valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled")
if status and status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}",
)
tasks, total = db.get_training_tasks_by_token(
admin_token=admin_token,
status=status,
limit=limit,
offset=offset,
)
items = [
TrainingTaskItem(
task_id=str(task.task_id),
name=task.name,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
scheduled_at=task.scheduled_at,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
created_at=task.created_at,
)
for task in tasks
]
return TrainingTaskListResponse(
total=total,
limit=limit,
offset=offset,
tasks=items,
)
@router.get(
"/tasks/{task_id}",
response_model=TrainingTaskDetailResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training task detail",
description="Get training task details.",
)
async def get_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskDetailResponse:
"""Get training task details."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
return TrainingTaskDetailResponse(
task_id=str(task.task_id),
name=task.name,
description=task.description,
task_type=TrainingType(task.task_type),
status=TrainingStatus(task.status),
config=task.config,
scheduled_at=task.scheduled_at,
cron_expression=task.cron_expression,
is_recurring=task.is_recurring,
started_at=task.started_at,
completed_at=task.completed_at,
error_message=task.error_message,
result_metrics=task.result_metrics,
model_path=task.model_path,
created_at=task.created_at,
)
@router.post(
"/tasks/{task_id}/cancel",
response_model=TrainingTaskResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
409: {"model": ErrorResponse, "description": "Cannot cancel task"},
},
summary="Cancel training task",
description="Cancel a pending or scheduled training task.",
)
async def cancel_training_task(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> TrainingTaskResponse:
"""Cancel a training task."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
if task.status not in ("pending", "scheduled"):
raise HTTPException(
status_code=409,
detail=f"Cannot cancel task with status: {task.status}",
)
success = db.cancel_training_task(task_id)
if not success:
raise HTTPException(
status_code=500,
detail="Failed to cancel training task",
)
return TrainingTaskResponse(
task_id=task_id,
status=TrainingStatus.CANCELLED,
message="Training task cancelled successfully",
)
@router.get(
"/tasks/{task_id}/logs",
response_model=TrainingLogsResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Task not found"},
},
summary="Get training logs",
description="Get training task logs.",
)
async def get_training_logs(
task_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
limit: Annotated[
int,
Query(ge=1, le=500, description="Maximum logs to return"),
] = 100,
offset: Annotated[
int,
Query(ge=0, description="Offset"),
] = 0,
) -> TrainingLogsResponse:
"""Get training logs."""
_validate_uuid(task_id, "task_id")
task = db.get_training_task_by_token(task_id, admin_token)
if task is None:
raise HTTPException(
status_code=404,
detail="Training task not found or does not belong to this token",
)
logs = db.get_training_logs(task_id, limit, offset)
items = [
TrainingLogItem(
level=log.level,
message=log.message,
details=log.details,
created_at=log.created_at,
)
for log in logs
]
return TrainingLogsResponse(
task_id=task_id,
logs=items,
)
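A sketch of the request body a client might send to the tasks endpoint for a nightly recurring fine-tune, based on the TrainingTaskCreate schema later in this commit; all values are illustrative:

# Example payload for POST /api/v1/admin/training/tasks (values are illustrative).
payload = {
    "name": "nightly-finetune",
    "description": "Fine-tune on newly labeled invoices",
    "task_type": "finetune",
    "config": {
        "model_name": "yolo11n.pt",
        "epochs": 50,
        "batch_size": 16,
        "image_size": 640,
        "learning_rate": 0.01,
        "device": "0",
    },
    "cron_expression": "0 2 * * *",  # a cron expression makes the task recurring
}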

View File

@@ -14,10 +14,10 @@ from uuid import UUID
 from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
 from fastapi.responses import JSONResponse
-from src.data.admin_db import AdminDB
-from src.web.core.auth import validate_admin_token, get_admin_db
-from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
-from src.web.workers.batch_queue import BatchTask, get_batch_queue
+from inference.data.admin_db import AdminDB
+from inference.web.core.auth import validate_admin_token, get_admin_db
+from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE
+from inference.web.workers.batch_queue import BatchTask, get_batch_queue
 logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,16 @@
"""
Public API v1
Customer-facing endpoints for inference, async processing, and labeling.
"""
from inference.web.api.v1.public.inference import create_inference_router
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.api.v1.public.labeling import create_labeling_router
__all__ = [
"create_inference_router",
"create_async_router",
"set_async_service",
"create_labeling_router",
]

View File

@@ -11,13 +11,13 @@ from uuid import UUID
 from fastapi import APIRouter, File, HTTPException, Query, UploadFile
-from src.web.dependencies import (
+from inference.web.dependencies import (
     ApiKeyDep,
     AsyncDBDep,
     PollRateLimitDep,
     SubmitRateLimitDep,
 )
-from src.web.schemas.inference import (
+from inference.web.schemas.inference import (
     AsyncRequestItem,
     AsyncRequestsListResponse,
     AsyncResultResponse,
@@ -27,7 +27,7 @@ from src.web.schemas.inference import (
     DetectionResult,
     InferenceResult,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 def _validate_request_id(request_id: str) -> None:

View File

@@ -15,17 +15,17 @@ from typing import TYPE_CHECKING
 from fastapi import APIRouter, File, HTTPException, UploadFile, status
 from fastapi.responses import FileResponse
-from src.web.schemas.inference import (
+from inference.web.schemas.inference import (
     DetectionResult,
     HealthResponse,
     InferenceResponse,
     InferenceResult,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 if TYPE_CHECKING:
-    from src.web.services import InferenceService
-    from src.web.config import StorageConfig
+    from inference.web.services import InferenceService
+    from inference.web.config import StorageConfig
 logger = logging.getLogger(__name__)

View File

@@ -13,13 +13,13 @@ from typing import TYPE_CHECKING
 from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
-from src.data.admin_db import AdminDB
-from src.web.schemas.labeling import PreLabelResponse
-from src.web.schemas.common import ErrorResponse
+from inference.data.admin_db import AdminDB
+from inference.web.schemas.labeling import PreLabelResponse
+from inference.web.schemas.common import ErrorResponse
 if TYPE_CHECKING:
-    from src.web.services import InferenceService
-    from src.web.config import StorageConfig
+    from inference.web.services import InferenceService
+    from inference.web.config import StorageConfig
 logger = logging.getLogger(__name__)

View File

@@ -17,10 +17,10 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.responses import HTMLResponse
 from .config import AppConfig, default_config
-from src.web.services import InferenceService
+from inference.web.services import InferenceService
 # Public API imports
-from src.web.api.v1.public import (
+from inference.web.api.v1.public import (
     create_inference_router,
     create_async_router,
     set_async_service,
@@ -28,28 +28,28 @@ from src.web.api.v1.public import (
 )
 # Async processing imports
-from src.data.async_request_db import AsyncRequestDB
-from src.web.workers.async_queue import AsyncTaskQueue
-from src.web.services.async_processing import AsyncProcessingService
-from src.web.dependencies import init_dependencies
-from src.web.core.rate_limiter import RateLimiter
+from inference.data.async_request_db import AsyncRequestDB
+from inference.web.workers.async_queue import AsyncTaskQueue
+from inference.web.services.async_processing import AsyncProcessingService
+from inference.web.dependencies import init_dependencies
+from inference.web.core.rate_limiter import RateLimiter
 # Admin API imports
-from src.web.api.v1.admin import (
+from inference.web.api.v1.admin import (
     create_annotation_router,
     create_auth_router,
     create_documents_router,
     create_locks_router,
     create_training_router,
 )
-from src.web.core.scheduler import start_scheduler, stop_scheduler
-from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
+from inference.web.core.scheduler import start_scheduler, stop_scheduler
+from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler
 # Batch upload imports
-from src.web.api.v1.batch.routes import router as batch_upload_router
-from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
-from src.web.services.batch_upload import BatchUploadService
-from src.data.admin_db import AdminDB
+from inference.web.api.v1.batch.routes import router as batch_upload_router
+from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
+from inference.web.services.batch_upload import BatchUploadService
+from inference.data.admin_db import AdminDB
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator

View File

@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
-from src.config import DEFAULT_DPI, PATHS
+from shared.config import DEFAULT_DPI, PATHS
 @dataclass(frozen=True)

View File

@@ -4,10 +4,10 @@ Core Components
 Reusable core functionality: authentication, rate limiting, scheduling.
 """
-from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
-from src.web.core.rate_limiter import RateLimiter
-from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
-from src.web.core.autolabel_scheduler import (
+from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep
+from inference.web.core.rate_limiter import RateLimiter
+from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler
+from inference.web.core.autolabel_scheduler import (
     start_autolabel_scheduler,
     stop_autolabel_scheduler,
     get_autolabel_scheduler,

View File

@@ -9,8 +9,8 @@ from typing import Annotated
 from fastapi import Depends, Header, HTTPException
-from src.data.admin_db import AdminDB
-from src.data.database import get_session_context
+from inference.data.admin_db import AdminDB
+from inference.data.database import get_session_context
 logger = logging.getLogger(__name__)

View File

@@ -8,8 +8,8 @@ import logging
 import threading
 from pathlib import Path
-from src.data.admin_db import AdminDB
-from src.web.services.db_autolabel import (
+from inference.data.admin_db import AdminDB
+from inference.web.services.db_autolabel import (
     get_pending_autolabel_documents,
     process_document_autolabel,
 )

View File

@@ -13,7 +13,7 @@ from threading import Lock
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from src.data.async_request_db import AsyncRequestDB
+    from inference.data.async_request_db import AsyncRequestDB
 logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,7 @@ from datetime import datetime
 from pathlib import Path
 from typing import Any
-from src.data.admin_db import AdminDB
+from inference.data.admin_db import AdminDB
 logger = logging.getLogger(__name__)
@@ -86,7 +86,8 @@ class TrainingScheduler:
             logger.info(f"Starting training task: {task_id}")
             try:
-                self._execute_task(task_id, task.config or {})
+                dataset_id = getattr(task, "dataset_id", None)
+                self._execute_task(task_id, task.config or {}, dataset_id=dataset_id)
             except Exception as e:
                 logger.error(f"Training task {task_id} failed: {e}")
                 self._db.update_training_task_status(
@@ -98,7 +99,9 @@ class TrainingScheduler:
         except Exception as e:
             logger.error(f"Error checking pending tasks: {e}")
-    def _execute_task(self, task_id: str, config: dict[str, Any]) -> None:
+    def _execute_task(
+        self, task_id: str, config: dict[str, Any], dataset_id: str | None = None
+    ) -> None:
         """Execute a training task."""
         # Update status to running
         self._db.update_training_task_status(task_id, "running")
@@ -114,17 +117,25 @@ class TrainingScheduler:
         device = config.get("device", "0")
         project_name = config.get("project_name", "invoice_fields")
-        # Export annotations for training
-        export_result = self._export_training_data(task_id)
-        if not export_result:
-            raise ValueError("Failed to export training data")
-        data_yaml = export_result["data_yaml"]
-        self._db.add_training_log(
-            task_id, "INFO",
-            f"Exported {export_result['total_images']} images for training",
-        )
+        # Use dataset if available, otherwise export from scratch
+        if dataset_id:
+            dataset = self._db.get_dataset(dataset_id)
+            if not dataset or not dataset.dataset_path:
+                raise ValueError(f"Dataset {dataset_id} not found or has no path")
+            data_yaml = str(Path(dataset.dataset_path) / "data.yaml")
+            self._db.add_training_log(
+                task_id, "INFO",
+                f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)",
+            )
+        else:
+            export_result = self._export_training_data(task_id)
+            if not export_result:
+                raise ValueError("Failed to export training data")
+            data_yaml = export_result["data_yaml"]
+            self._db.add_training_log(
+                task_id, "INFO",
+                f"Exported {export_result['total_images']} images for training",
+            )
         # Run YOLO training
         result = self._run_yolo_training(
@@ -157,7 +168,7 @@ class TrainingScheduler:
         """Export training data for a task."""
         from pathlib import Path
         import shutil
-        from src.data.admin_models import FIELD_CLASSES
+        from inference.data.admin_models import FIELD_CLASSES
         # Get all labeled documents
         documents = self._db.get_labeled_documents_for_export()

View File

@@ -9,8 +9,8 @@ from typing import Annotated
 from fastapi import Depends, Header, HTTPException, Request
-from src.data.async_request_db import AsyncRequestDB
-from src.web.rate_limiter import RateLimiter
+from inference.data.async_request_db import AsyncRequestDB
+from inference.web.rate_limiter import RateLimiter
 logger = logging.getLogger(__name__)

View File

@@ -13,7 +13,7 @@ from threading import Lock
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from src.data.async_request_db import AsyncRequestDB
+    from inference.data.async_request_db import AsyncRequestDB
 logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,11 @@
"""
API Schemas
Pydantic models for request/response validation.
"""
# Import everything from sub-modules for backward compatibility
from inference.web.schemas.common import * # noqa: F401, F403
from inference.web.schemas.admin import * # noqa: F401, F403
from inference.web.schemas.inference import * # noqa: F401, F403
from inference.web.schemas.labeling import * # noqa: F401, F403

View File

@@ -0,0 +1,17 @@
"""
Admin API Request/Response Schemas
Pydantic models for admin API validation and serialization.
"""
from .enums import * # noqa: F401, F403
from .auth import * # noqa: F401, F403
from .documents import * # noqa: F401, F403
from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse
from .documents import DocumentDetailResponse
DocumentDetailResponse.model_rebuild()
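The model_rebuild() call resolves the string forward references that the .documents schema module declares under TYPE_CHECKING ("AnnotationItem", "TrainingHistoryItem"). A stripped-down illustration of the same Pydantic v2 pattern (in a single module Pydantic can often resolve this lazily; the explicit rebuild mirrors the cross-module case here):

from pydantic import BaseModel


class Detail(BaseModel):
    items: list["Item"] = []  # forward reference, "Item" not defined yet


class Item(BaseModel):
    name: str


Detail.model_rebuild()  # resolves "Item" now that it exists
print(Detail(items=[{"name": "a"}]))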

View File

@@ -0,0 +1,152 @@
"""Admin Annotation Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .enums import AnnotationSource
class BoundingBox(BaseModel):
"""Bounding box coordinates."""
x: int = Field(..., ge=0, description="X coordinate (pixels)")
y: int = Field(..., ge=0, description="Y coordinate (pixels)")
width: int = Field(..., ge=1, description="Width (pixels)")
height: int = Field(..., ge=1, description="Height (pixels)")
class AnnotationCreate(BaseModel):
"""Request to create an annotation."""
page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
text_value: str | None = Field(None, description="Text value (optional)")
class AnnotationUpdate(BaseModel):
"""Request to update an annotation."""
class_id: int | None = Field(None, ge=0, le=9, description="New class ID")
bbox: BoundingBox | None = Field(None, description="New bounding box")
text_value: str | None = Field(None, description="New text value")
class AnnotationItem(BaseModel):
"""Single annotation item."""
annotation_id: str = Field(..., description="Annotation UUID")
page_number: int = Field(..., ge=1, description="Page number")
class_id: int = Field(..., ge=0, le=9, description="Class ID")
class_name: str = Field(..., description="Class name")
bbox: BoundingBox = Field(..., description="Bounding box in pixels")
normalized_bbox: dict[str, float] = Field(
..., description="Normalized bbox (x_center, y_center, width, height)"
)
text_value: str | None = Field(None, description="Text value")
confidence: float | None = Field(None, ge=0, le=1, description="Confidence score")
source: AnnotationSource = Field(..., description="Annotation source")
created_at: datetime = Field(..., description="Creation timestamp")
class AnnotationResponse(BaseModel):
"""Response for annotation operation."""
annotation_id: str = Field(..., description="Annotation UUID")
message: str = Field(..., description="Status message")
class AnnotationListResponse(BaseModel):
"""Response for annotation list."""
document_id: str = Field(..., description="Document UUID")
page_count: int = Field(..., ge=1, description="Total pages")
total_annotations: int = Field(..., ge=0, description="Total annotations")
annotations: list[AnnotationItem] = Field(
default_factory=list, description="Annotation list"
)
class AnnotationLockRequest(BaseModel):
"""Request to acquire annotation lock."""
duration_seconds: int = Field(
default=300,
ge=60,
le=3600,
description="Lock duration in seconds (60-3600)",
)
class AnnotationLockResponse(BaseModel):
"""Response for annotation lock operation."""
document_id: str = Field(..., description="Document UUID")
locked: bool = Field(..., description="Whether lock was acquired/released")
lock_expires_at: datetime | None = Field(
None, description="Lock expiration time"
)
message: str = Field(..., description="Status message")
class AutoLabelRequest(BaseModel):
"""Request to trigger auto-labeling."""
field_values: dict[str, str] = Field(
...,
description="Field values to match (e.g., {'invoice_number': '12345'})",
)
replace_existing: bool = Field(
default=False, description="Replace existing auto annotations"
)
class AutoLabelResponse(BaseModel):
"""Response for auto-labeling."""
document_id: str = Field(..., description="Document UUID")
status: str = Field(..., description="Auto-labeling status")
annotations_created: int = Field(
default=0, ge=0, description="Number of annotations created"
)
message: str = Field(..., description="Status message")
class AnnotationVerifyRequest(BaseModel):
"""Request to verify an annotation."""
pass # No body needed, just POST to verify
class AnnotationVerifyResponse(BaseModel):
"""Response for annotation verification."""
annotation_id: str = Field(..., description="Annotation UUID")
is_verified: bool = Field(..., description="Verification status")
verified_at: datetime = Field(..., description="Verification timestamp")
verified_by: str = Field(..., description="Admin token who verified")
message: str = Field(..., description="Status message")
class AnnotationOverrideRequest(BaseModel):
"""Request to override an annotation."""
bbox: dict[str, int] | None = Field(
None, description="Updated bounding box {x, y, width, height}"
)
text_value: str | None = Field(None, description="Updated text value")
class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID")
class_name: str | None = Field(None, description="Updated class name")
reason: str | None = Field(None, description="Reason for override")
class AnnotationOverrideResponse(BaseModel):
"""Response for annotation override."""
annotation_id: str = Field(..., description="Annotation UUID")
source: str = Field(..., description="New source (manual)")
override_source: str | None = Field(None, description="Original source (auto)")
original_annotation_id: str | None = Field(None, description="Original annotation ID")
message: str = Field(..., description="Status message")
history_id: str = Field(..., description="History record UUID")
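As an example of the shapes these models accept, a hand-built AnnotationCreate for a pixel-space box; the field values are illustrative and the import relies on the package re-exports shown earlier in this commit:

from inference.web.schemas.admin import AnnotationCreate, BoundingBox

ann = AnnotationCreate(
    page_number=1,
    class_id=2,
    bbox=BoundingBox(x=120, y=540, width=310, height=42),
    text_value="INV-2026-00017",
)
print(ann.model_dump())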

View File

@@ -0,0 +1,23 @@
"""Admin Auth Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class AdminTokenCreate(BaseModel):
"""Request to create an admin token."""
name: str = Field(..., min_length=1, max_length=255, description="Token name")
expires_in_days: int | None = Field(
None, ge=1, le=365, description="Token expiration in days (optional)"
)
class AdminTokenResponse(BaseModel):
"""Response with created admin token."""
token: str = Field(..., description="Admin token")
name: str = Field(..., description="Token name")
expires_at: datetime | None = Field(None, description="Token expiration time")
message: str = Field(..., description="Status message")

View File

@@ -0,0 +1,85 @@
"""Admin Dataset Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from .training import TrainingConfig
class DatasetCreateRequest(BaseModel):
"""Request to create a training dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
class DatasetDocumentItem(BaseModel):
"""Document within a dataset."""
document_id: str
split: str
page_count: int
annotation_count: int
class DatasetResponse(BaseModel):
"""Response after creating a dataset."""
dataset_id: str
name: str
status: str
message: str
class DatasetDetailResponse(BaseModel):
"""Detailed dataset info with documents."""
dataset_id: str
name: str
description: str | None
status: str
train_ratio: float
val_ratio: float
seed: int
total_documents: int
total_images: int
total_annotations: int
dataset_path: str | None
error_message: str | None
documents: list[DatasetDocumentItem]
created_at: datetime
updated_at: datetime
class DatasetListItem(BaseModel):
"""Dataset in list view."""
dataset_id: str
name: str
description: str | None
status: str
total_documents: int
total_images: int
total_annotations: int
created_at: datetime
class DatasetListResponse(BaseModel):
"""Paginated dataset list."""
total: int
limit: int
offset: int
datasets: list[DatasetListItem]
class DatasetTrainRequest(BaseModel):
"""Request to start training from a dataset."""
name: str = Field(..., min_length=1, max_length=255, description="Training task name")
config: TrainingConfig = Field(..., description="Training configuration")
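Note that the request schema only carries train and val ratios; the test share is whatever remains after the split. A quick illustration with the defaults (the document UUID is illustrative):

from inference.web.schemas.admin import DatasetCreateRequest

req = DatasetCreateRequest(
    name="invoices-jan",
    document_ids=["6a0d5a1e-2f4b-4c3d-9e8f-1a2b3c4d5e6f"],
)
test_ratio = 1.0 - req.train_ratio - req.val_ratio
print(f"train={req.train_ratio} val={req.val_ratio} test={test_ratio:.2f}")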

View File

@@ -0,0 +1,103 @@
"""Admin Document Schemas."""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from .enums import AutoLabelStatus, DocumentStatus
if TYPE_CHECKING:
from .annotations import AnnotationItem
from .training import TrainingHistoryItem
class DocumentUploadResponse(BaseModel):
"""Response for document upload."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
)
message: str = Field(..., description="Status message")
class DocumentItem(BaseModel):
"""Single document in list."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
annotation_count: int = Field(default=0, ge=0, description="Number of annotations")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentListResponse(BaseModel):
"""Response for document list."""
total: int = Field(..., ge=0, description="Total documents")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
documents: list[DocumentItem] = Field(
default_factory=list, description="Document list"
)
class DocumentDetailResponse(BaseModel):
"""Response for document detail."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Original filename")
file_size: int = Field(..., ge=0, description="File size in bytes")
content_type: str = Field(..., description="MIME type")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
auto_label_status: AutoLabelStatus | None = Field(
None, description="Auto-labeling status"
)
auto_label_error: str | None = Field(None, description="Auto-labeling error")
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
annotation_lock_until: datetime | None = Field(
None, description="Lock expiration time if document is locked"
)
annotations: list["AnnotationItem"] = Field(
default_factory=list, description="Document annotations"
)
image_urls: list[str] = Field(
default_factory=list, description="URLs to page images"
)
training_history: list["TrainingHistoryItem"] = Field(
default_factory=list, description="Training tasks that used this document"
)
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
class DocumentStatsResponse(BaseModel):
"""Document statistics response."""
total: int = Field(..., ge=0, description="Total documents")
pending: int = Field(default=0, ge=0, description="Pending documents")
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")

View File

@@ -0,0 +1,46 @@
"""Admin API Enums."""
from enum import Enum
class DocumentStatus(str, Enum):
"""Document status enum."""
PENDING = "pending"
AUTO_LABELING = "auto_labeling"
LABELED = "labeled"
EXPORTED = "exported"
class AutoLabelStatus(str, Enum):
"""Auto-labeling status enum."""
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class TrainingStatus(str, Enum):
"""Training task status enum."""
PENDING = "pending"
SCHEDULED = "scheduled"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingType(str, Enum):
"""Training task type enum."""
TRAIN = "train"
FINETUNE = "finetune"
class AnnotationSource(str, Enum):
"""Annotation source enum."""
MANUAL = "manual"
AUTO = "auto"
IMPORTED = "imported"

View File

@@ -0,0 +1,202 @@
"""Admin Training Schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name")
epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs")
batch_size: int = Field(default=16, ge=1, le=128, description="Batch size")
image_size: int = Field(default=640, ge=320, le=1280, description="Image size")
learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate")
device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)")
project_name: str = Field(
default="invoice_fields", description="Training project name"
)
class TrainingTaskCreate(BaseModel):
"""Request to create a training task."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
description: str | None = Field(None, max_length=1000, description="Description")
task_type: TrainingType = Field(
default=TrainingType.TRAIN, description="Task type"
)
config: TrainingConfig = Field(
default_factory=TrainingConfig, description="Training configuration"
)
scheduled_at: datetime | None = Field(
None, description="Scheduled execution time"
)
cron_expression: str | None = Field(
None, max_length=50, description="Cron expression for recurring tasks"
)
class TrainingTaskItem(BaseModel):
"""Single training task in list."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskListResponse(BaseModel):
"""Response for training task list."""
total: int = Field(..., ge=0, description="Total tasks")
limit: int = Field(..., ge=1, description="Page size")
offset: int = Field(..., ge=0, description="Current offset")
tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list")
class TrainingTaskDetailResponse(BaseModel):
"""Response for training task detail."""
task_id: str = Field(..., description="Task UUID")
name: str = Field(..., description="Task name")
description: str | None = Field(None, description="Description")
task_type: TrainingType = Field(..., description="Task type")
status: TrainingStatus = Field(..., description="Task status")
config: dict[str, Any] | None = Field(None, description="Training configuration")
scheduled_at: datetime | None = Field(None, description="Scheduled time")
cron_expression: str | None = Field(None, description="Cron expression")
is_recurring: bool = Field(default=False, description="Is recurring task")
started_at: datetime | None = Field(None, description="Start time")
completed_at: datetime | None = Field(None, description="Completion time")
error_message: str | None = Field(None, description="Error message")
result_metrics: dict[str, Any] | None = Field(None, description="Result metrics")
model_path: str | None = Field(None, description="Trained model path")
created_at: datetime = Field(..., description="Creation timestamp")
class TrainingTaskResponse(BaseModel):
"""Response for training task operation."""
task_id: str = Field(..., description="Task UUID")
status: TrainingStatus = Field(..., description="Task status")
message: str = Field(..., description="Status message")
class TrainingLogItem(BaseModel):
"""Single training log entry."""
level: str = Field(..., description="Log level")
message: str = Field(..., description="Log message")
details: dict[str, Any] | None = Field(None, description="Additional details")
created_at: datetime = Field(..., description="Timestamp")
class TrainingLogsResponse(BaseModel):
"""Response for training logs."""
task_id: str = Field(..., description="Task UUID")
logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries")
class ExportRequest(BaseModel):
"""Request to export annotations."""
format: str = Field(
default="yolo", description="Export format (yolo, coco, voc)"
)
include_images: bool = Field(
default=True, description="Include images in export"
)
split_ratio: float = Field(
default=0.8, ge=0.5, le=1.0, description="Train/val split ratio"
)
class ExportResponse(BaseModel):
"""Response for export operation."""
status: str = Field(..., description="Export status")
export_path: str = Field(..., description="Path to exported dataset")
total_images: int = Field(..., ge=0, description="Total images exported")
total_annotations: int = Field(..., ge=0, description="Total annotations")
train_count: int = Field(..., ge=0, description="Training set count")
val_count: int = Field(..., ge=0, description="Validation set count")
message: str = Field(..., description="Status message")
class TrainingDocumentItem(BaseModel):
"""Document item for training page."""
document_id: str = Field(..., description="Document UUID")
filename: str = Field(..., description="Filename")
annotation_count: int = Field(..., ge=0, description="Total annotations")
annotation_sources: dict[str, int] = Field(
..., description="Annotation counts by source (manual, auto)"
)
used_in_training: list[str] = Field(
default_factory=list, description="List of training task IDs that used this document"
)
last_modified: datetime = Field(..., description="Last modification time")
class TrainingDocumentsResponse(BaseModel):
"""Response for GET /admin/training/documents."""
total: int = Field(..., ge=0, description="Total document count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
documents: list[TrainingDocumentItem] = Field(
default_factory=list, description="Documents available for training"
)
class ModelMetrics(BaseModel):
"""Training model metrics."""
mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision")
precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision")
recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall")
class TrainingModelItem(BaseModel):
"""Trained model item for model list."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Model name")
status: TrainingStatus = Field(..., description="Training status")
document_count: int = Field(..., ge=0, description="Documents used in training")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: datetime | None = Field(None, description="Completion timestamp")
metrics: ModelMetrics = Field(..., description="Model metrics")
model_path: str | None = Field(None, description="Path to model weights")
download_url: str | None = Field(None, description="Download URL for model")
class TrainingModelsResponse(BaseModel):
"""Response for GET /admin/training/models."""
total: int = Field(..., ge=0, description="Total model count")
limit: int = Field(..., ge=1, le=100, description="Page size")
offset: int = Field(..., ge=0, description="Pagination offset")
models: list[TrainingModelItem] = Field(
default_factory=list, description="Trained models"
)
class TrainingHistoryItem(BaseModel):
"""Training history for a document."""
task_id: str = Field(..., description="Training task UUID")
name: str = Field(..., description="Training task name")
trained_at: datetime = Field(..., description="Training timestamp")
model_metrics: ModelMetrics | None = Field(None, description="Model metrics")
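
As a usage illustration for the request schemas above, here is a minimal client-side sketch (Pydantic v2 assumed; the import path and field values are placeholders, not part of this change):

# Hedged sketch: building a training-task request against the models above.
# The module path "admin_schemas" is assumed for illustration only.
from admin_schemas import TrainingConfig, TrainingTaskCreate, TrainingType

payload = TrainingTaskCreate(
    name="invoice-fields-nightly",                   # illustrative task name
    task_type=TrainingType.TRAIN,
    config=TrainingConfig(epochs=50, batch_size=8),  # remaining fields keep their defaults
    cron_expression="0 2 * * *",                     # optional recurring schedule
)
print(payload.model_dump_json(indent=2))             # JSON body as the API would receive it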

View File

@@ -0,0 +1,18 @@
"""
Business Logic Services
Service layer for processing requests and orchestrating data operations.
"""
from inference.web.services.autolabel import AutoLabelService, get_auto_label_service
from inference.web.services.inference import InferenceService
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.services.batch_upload import BatchUploadService
__all__ = [
"AutoLabelService",
"get_auto_label_service",
"InferenceService",
"AsyncProcessingService",
"BatchUploadService",
]

View File

@@ -14,13 +14,13 @@ from pathlib import Path
 from threading import Event, Thread
 from typing import TYPE_CHECKING

-from src.data.async_request_db import AsyncRequestDB
-from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
-from src.web.core.rate_limiter import RateLimiter
+from inference.data.async_request_db import AsyncRequestDB
+from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
+from inference.web.core.rate_limiter import RateLimiter

 if TYPE_CHECKING:
-    from src.web.config import AsyncConfig, StorageConfig
-    from src.web.services.inference import InferenceService
+    from inference.web.config import AsyncConfig, StorageConfig
+    from inference.web.services.inference import InferenceService

 logger = logging.getLogger(__name__)

View File

@@ -11,11 +11,11 @@ from typing import Any
 import numpy as np
 from PIL import Image

-from src.config import DEFAULT_DPI
-from src.data.admin_db import AdminDB
-from src.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
-from src.matcher.field_matcher import FieldMatcher
-from src.ocr.paddle_ocr import OCREngine, OCRToken
+from shared.config import DEFAULT_DPI
+from inference.data.admin_db import AdminDB
+from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
+from shared.matcher.field_matcher import FieldMatcher
+from shared.ocr.paddle_ocr import OCREngine, OCRToken

 logger = logging.getLogger(__name__)
@@ -144,7 +144,7 @@ class AutoLabelService:
         db: AdminDB,
     ) -> int:
         """Process PDF document and create annotations."""
-        from src.pdf.renderer import render_pdf_to_images
+        from shared.pdf.renderer import render_pdf_to_images
         import io

         total_annotations = 0
@@ -222,7 +222,7 @@ class AutoLabelService:
         image_height: int,
     ) -> list[dict[str, Any]]:
         """Find annotations for field values using token matching."""
-        from src.normalize import normalize_field
+        from shared.normalize import normalize_field

         annotations = []

View File

@@ -15,8 +15,8 @@ from uuid import UUID
 from pydantic import BaseModel, Field, field_validator

-from src.data.admin_db import AdminDB
-from src.data.admin_models import CSV_TO_CLASS_MAPPING
+from inference.data.admin_db import AdminDB
+from inference.data.admin_models import CSV_TO_CLASS_MAPPING

 logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,188 @@
"""
Dataset Builder Service
Creates training datasets by copying images from admin storage,
generating YOLO label files, and splitting into train/val/test sets.
"""
import logging
import random
import shutil
from pathlib import Path
import yaml
from inference.data.admin_models import FIELD_CLASSES
logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Builds YOLO training datasets from admin documents."""
def __init__(self, db, base_dir: Path):
self._db = db
self._base_dir = Path(base_dir)
def build_dataset(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
"""Build a complete YOLO dataset from document IDs.
Args:
dataset_id: UUID of the dataset record.
document_ids: List of document UUIDs to include.
train_ratio: Fraction for training set.
val_ratio: Fraction for validation set.
seed: Random seed for reproducible splits.
admin_images_dir: Root directory of admin images.
Returns:
Summary dict with total_documents, total_images, total_annotations.
Raises:
ValueError: If no valid documents found.
"""
try:
return self._do_build(
dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir
)
except Exception as e:
self._db.update_dataset_status(
dataset_id=dataset_id,
status="failed",
error_message=str(e),
)
raise
def _do_build(
self,
dataset_id: str,
document_ids: list[str],
train_ratio: float,
val_ratio: float,
seed: int,
admin_images_dir: Path,
) -> dict:
# 1. Fetch documents
documents = self._db.get_documents_by_ids(document_ids)
if not documents:
raise ValueError("No valid documents found for the given IDs")
# 2. Create directory structure
dataset_dir = self._base_dir / dataset_id
for split in ["train", "val", "test"]:
(dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
(dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
# 3. Shuffle and split documents
doc_list = list(documents)
rng = random.Random(seed)
rng.shuffle(doc_list)
n = len(doc_list)
n_train = max(1, round(n * train_ratio))
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val
splits = (
["train"] * n_train
+ ["val"] * n_val
+ ["test"] * n_test
)
# 4. Process each document
total_images = 0
total_annotations = 0
dataset_docs = []
for doc, split in zip(doc_list, splits):
doc_id = str(doc.document_id)
annotations = self._db.get_annotations_for_document(doc.document_id)
# Group annotations by page
page_annotations: dict[int, list] = {}
for ann in annotations:
page_annotations.setdefault(ann.page_number, []).append(ann)
doc_image_count = 0
doc_ann_count = 0
# Copy images and write labels for each page
for page_num in range(1, doc.page_count + 1):
src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png"
if not src_image.exists():
logger.warning("Image not found: %s", src_image)
continue
dst_name = f"{doc_id}_page{page_num}"
dst_image = dataset_dir / "images" / split / f"{dst_name}.png"
shutil.copy2(src_image, dst_image)
doc_image_count += 1
# Write YOLO label file
page_anns = page_annotations.get(page_num, [])
label_lines = []
for ann in page_anns:
label_lines.append(
f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} "
f"{ann.width:.6f} {ann.height:.6f}"
)
doc_ann_count += 1
label_path = dataset_dir / "labels" / split / f"{dst_name}.txt"
label_path.write_text("\n".join(label_lines))
total_images += doc_image_count
total_annotations += doc_ann_count
dataset_docs.append({
"document_id": doc_id,
"split": split,
"page_count": doc_image_count,
"annotation_count": doc_ann_count,
})
# 5. Record document-split assignments in DB
self._db.add_dataset_documents(
dataset_id=dataset_id,
documents=dataset_docs,
)
# 6. Generate data.yaml
self._generate_data_yaml(dataset_dir)
# 7. Update dataset status
self._db.update_dataset_status(
dataset_id=dataset_id,
status="ready",
total_documents=len(doc_list),
total_images=total_images,
total_annotations=total_annotations,
dataset_path=str(dataset_dir),
)
return {
"total_documents": len(doc_list),
"total_images": total_images,
"total_annotations": total_annotations,
}
def _generate_data_yaml(self, dataset_dir: Path) -> None:
"""Generate YOLO data.yaml configuration file."""
data = {
"path": str(dataset_dir.absolute()),
"train": "images/train",
"val": "images/val",
"test": "images/test",
"nc": len(FIELD_CLASSES),
"names": FIELD_CLASSES,
}
yaml_path = dataset_dir / "data.yaml"
yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True))
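
To make the split behaviour of _do_build concrete, here is a small self-contained sketch of the same shuffle-and-round logic (document IDs are placeholders):

# Standalone illustration of the train/val/test split used above:
# deterministic shuffle with a seed, then rounded ratio counts.
import random

doc_ids = [f"doc-{i}" for i in range(10)]   # placeholder document IDs
train_ratio, val_ratio, seed = 0.7, 0.2, 42

rng = random.Random(seed)
rng.shuffle(doc_ids)

n = len(doc_ids)
n_train = max(1, round(n * train_ratio))    # at least one training document
n_val = max(0, round(n * val_ratio))
n_test = n - n_train - n_val                # remainder goes to test
splits = ["train"] * n_train + ["val"] * n_val + ["test"] * n_test

for doc_id, split in zip(doc_ids, splits):
    print(split, doc_id)                    # 7 train / 2 val / 1 test with these ratios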

View File

@@ -11,11 +11,11 @@ import logging
 from pathlib import Path
 from typing import Any

-from src.config import DEFAULT_DPI
-from src.data.admin_db import AdminDB
-from src.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
-from src.data.db import DocumentDB
-from src.web.config import StorageConfig
+from shared.config import DEFAULT_DPI
+from inference.data.admin_db import AdminDB
+from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
+from shared.data.db import DocumentDB
+from inference.web.config import StorageConfig

 logger = logging.getLogger(__name__)
@@ -81,8 +81,8 @@ def get_pending_autolabel_documents(
         List of AdminDocument records with status='auto_labeling' and auto_label_status='pending'
     """
     from sqlmodel import select
-    from src.data.database import get_session_context
-    from src.data.admin_models import AdminDocument
+    from inference.data.database import get_session_context
+    from inference.data.admin_models import AdminDocument

     with get_session_context() as session:
         statement = select(AdminDocument).where(
@@ -116,8 +116,8 @@ def process_document_autolabel(
     Returns:
         Result dictionary with success status and annotations
     """
-    from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
-    from src.pdf import PDFDocument
+    from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
+    from shared.pdf import PDFDocument

     document_id = str(document.document_id)
     file_path = Path(document.file_path)
@@ -247,7 +247,7 @@ def _save_annotations_to_db(
         Number of annotations saved
     """
     from PIL import Image
-    from src.data.admin_models import FIELD_CLASS_IDS
+    from inference.data.admin_models import FIELD_CLASS_IDS

     # Mapping from CSV field names to internal field names
     CSV_TO_INTERNAL_FIELD: dict[str, str] = {
@@ -480,7 +480,7 @@ def save_manual_annotations_to_document_db(
     pdf_type = "unknown"
     if pdf_path.exists():
         try:
-            from src.pdf import PDFDocument
+            from shared.pdf import PDFDocument
             with PDFDocument(pdf_path) as pdf_doc:
                 tokens = list(pdf_doc.extract_text_tokens(0))
                 pdf_type = "scanned" if len(tokens) < 10 else "text"

View File

@@ -71,8 +71,8 @@ class InferenceService:
         start_time = time.time()

         try:
-            from src.inference.pipeline import InferencePipeline
-            from src.inference.yolo_detector import YOLODetector
+            from inference.pipeline.pipeline import InferencePipeline
+            from inference.pipeline.yolo_detector import YOLODetector

             # Initialize YOLO detector for visualization
             self._detector = YOLODetector(
@@ -257,7 +257,7 @@ class InferenceService:
     def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
         """Save visualization for PDF (first page)."""
-        from src.pdf.renderer import render_pdf_to_images
+        from shared.pdf.renderer import render_pdf_to_images
         from ultralytics import YOLO
         import io

View File

@@ -4,8 +4,8 @@ Background Task Queues
 Worker queues for asynchronous and batch processing.
 """

-from src.web.workers.async_queue import AsyncTaskQueue, AsyncTask
-from src.web.workers.batch_queue import (
+from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask
+from inference.web.workers.batch_queue import (
     BatchTaskQueue,
     BatchTask,
     init_batch_queue,

View File

@@ -0,0 +1,8 @@
-e ../shared
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
httpx>=0.25.0
openai>=1.0.0

View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.
Usage:
python run_server.py
python run_server.py --port 8080
python run_server.py --debug --reload
"""
from inference.cli.serve import main
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,17 @@
from setuptools import setup, find_packages
setup(
name="invoice-inference",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"invoice-shared",
"fastapi>=0.104.0",
"uvicorn[standard]>=0.24.0",
"python-multipart>=0.0.6",
"sqlmodel>=0.0.22",
"ultralytics>=8.1.0",
"httpx>=0.25.0",
],
)

View File

@@ -0,0 +1,9 @@
PyMuPDF>=1.23.0
paddleocr>=2.7.0
Pillow>=10.0.0
numpy>=1.24.0
opencv-python>=4.8.0
psycopg2-binary>=2.9.0
python-dotenv>=1.0.0
pyyaml>=6.0
thefuzz>=0.20.0

packages/shared/setup.py
View File

@@ -0,0 +1,19 @@
from setuptools import setup, find_packages
setup(
name="invoice-shared",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"PyMuPDF>=1.23.0",
"paddleocr>=2.7.0",
"Pillow>=10.0.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"psycopg2-binary>=2.9.0",
"python-dotenv>=1.0.0",
"pyyaml>=6.0",
"thefuzz>=0.20.0",
],
)
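
With both packages declared, a quick import smoke test can verify the new layout; the editable-install command is an assumed workflow, not something this diff mandates:

# Assumed local-dev install, then a minimal import check:
#   pip install -e packages/shared -e packages/inference
# Package names come from the two setup.py files above; the imports rely on the
# __init__.py re-exports shown earlier (e.g. inference.web.services).
from shared.config import DEFAULT_DPI
from inference.web.services import InferenceService

print(DEFAULT_DPI)                 # expected: 150, per shared/config.py
print(InferenceService.__name__)   # confirms the inference package resolves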

View File

@@ -7,10 +7,16 @@ import platform
 from pathlib import Path
 from dotenv import load_dotenv

-# Load environment variables from .env file
-# .env is at project root, config.py is in src/
-env_path = Path(__file__).parent.parent / '.env'
-load_dotenv(dotenv_path=env_path)
+# Load environment variables from .env file at project root
+# Walk up from packages/shared/shared/config.py to find project root
+_config_dir = Path(__file__).parent
+for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]:
+    _env_path = _candidate / '.env'
+    if _env_path.exists():
+        load_dotenv(dotenv_path=_env_path)
+        break
+else:
+    load_dotenv()  # fallback: search cwd and parents

 # Global DPI setting - must match training DPI for optimal model performance
 DEFAULT_DPI = 150
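
The for/else above only falls back to a bare load_dotenv() when none of the candidate directories contains a .env file; a standalone sketch of the same walk-up pattern (paths are illustrative, python-dotenv assumed):

# Minimal sketch of the .env discovery pattern.
from pathlib import Path
from dotenv import load_dotenv

def load_project_env(start: Path) -> Path | None:
    """Try fixed candidate directories; fall back to dotenv's own search."""
    for candidate in (start.parent.parent.parent, start.parent.parent, start.parent):
        env_path = candidate / ".env"
        if env_path.exists():
            load_dotenv(dotenv_path=env_path)
            return env_path
    load_dotenv()  # nothing found: let dotenv search cwd and parents
    return None

# e.g. mimic packages/shared/shared/config.py resolving the repository root
print(load_project_env(Path("packages/shared/shared").resolve()))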

View File

@@ -0,0 +1,3 @@
from .csv_loader import CSVLoader, InvoiceRow
__all__ = ['CSVLoader', 'InvoiceRow']

View File

@@ -9,8 +9,7 @@ from typing import Set, Dict, Any, Optional
 import sys
 from pathlib import Path

-sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-from src.config import get_db_connection_string
+from shared.config import get_db_connection_string

 class DocumentDB:

Some files were not shown because too many files have changed in this diff.