diff --git a/.claude/settings.local.json b/.claude/settings.local.json index da6d4f0..2c51df3 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -87,7 +87,10 @@ "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")", "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")", "Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)", - "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")" + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")", + "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")" ], "deny": [], "ask": [], diff --git a/.coverage b/.coverage index 482cda9..932eb87 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.gitignore b/.gitignore index e5e2a5f..676afe9 100644 --- a/.gitignore +++ b/.gitignore @@ -52,6 +52,10 @@ reports/*.jsonl logs/ *.log +# Coverage +htmlcov/ +.coverage + # Jupyter .ipynb_checkpoints/ diff --git a/README.md b/README.md index cb72623..2d23c67 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,25 @@ 1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注 2. **模型训练**: 使用 YOLOv11 训练字段检测模型 -3. **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化 +3. 
**推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化 + +### 架构 + +项目采用 **monorepo + 三包分离** 架构,训练和推理可独立部署: + +``` +packages/ +├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 工具) +├── training/ # 训练服务 (GPU, 按需启动) +└── inference/ # 推理服务 (常驻运行) +``` + +| 服务 | 部署目标 | GPU | 生命周期 | +|------|---------|-----|---------| +| **Inference** | Azure App Service | 可选 | 常驻 7x24 | +| **Training** | Azure ACI | 必需 | 按需启动/销毁 | + +两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。 ### 当前进度 @@ -16,6 +34,8 @@ |------|------| | **已标注文档** | 9,738 (9,709 成功) | | **总体字段匹配率** | 94.8% (82,604/87,121) | +| **测试** | 922 passed | +| **模型 mAP@0.5** | 93.5% | **各字段匹配率:** @@ -42,24 +62,83 @@ |------|------| | **WSL** | WSL 2 + Ubuntu 22.04 | | **Conda** | Miniconda 或 Anaconda | -| **Python** | 3.10+ (通过 Conda 管理) | +| **Python** | 3.11+ (通过 Conda 管理) | | **GPU** | NVIDIA GPU + CUDA 12.x (强烈推荐) | | **数据库** | PostgreSQL (存储标注结果) | -## 功能特点 +## 安装 -- **双模式 PDF 处理**: 支持文本层 PDF 和扫描图 PDF -- **自动标注**: 利用已有 CSV 结构化数据自动生成 YOLO 训练数据 -- **多策略字段匹配**: 精确匹配、子串匹配、规范化匹配 -- **数据库存储**: 标注结果存储在 PostgreSQL,支持增量处理和断点续传 -- **YOLO 检测**: 使用 YOLOv11 检测发票字段区域 -- **OCR 识别**: 使用 PaddleOCR v5 提取检测区域的文本 -- **统一解析器**: payment_line 和 customer_number 采用独立解析器模块 -- **交叉验证**: payment_line 数据与单独检测字段交叉验证,优先采用 payment_line 值 -- **文档类型识别**: 自动区分 invoice (有 payment_line) 和 letter (无 payment_line) -- **Web 应用**: 提供 REST API 和可视化界面 -- **增量训练**: 支持在已训练模型基础上继续训练 -- **内存优化**: 支持低内存模式训练 (--low-memory) +```bash +# 1. 进入 WSL +wsl -d Ubuntu-22.04 + +# 2. 创建 Conda 环境 +conda create -n invoice-py311 python=3.11 -y +conda activate invoice-py311 + +# 3. 进入项目目录 +cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 + +# 4. 安装三个包 (editable mode) +pip install -e packages/shared +pip install -e packages/training +pip install -e packages/inference +``` + +## 项目结构 + +``` +invoice-master-poc-v2/ +├── packages/ +│ ├── shared/ # 共享库 +│ │ ├── setup.py +│ │ └── shared/ +│ │ ├── pdf/ # PDF 处理 (提取, 渲染, 检测) +│ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析 +│ │ ├── normalize/ # 字段规范化 (10 种 normalizer) +│ │ ├── matcher/ # 字段匹配 (精确/子串/模糊) +│ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配) +│ │ ├── data/ # DocumentDB, CSVLoader +│ │ ├── config.py # 全局配置 (数据库, 路径, DPI) +│ │ └── exceptions.py # 异常定义 +│ │ +│ ├── training/ # 训练服务 (GPU, 按需) +│ │ ├── setup.py +│ │ ├── Dockerfile +│ │ ├── run_training.py # 入口 (--task-id 或 --poll) +│ │ └── training/ +│ │ ├── cli/ # train, autolabel, analyze_*, validate +│ │ ├── yolo/ # db_dataset, annotation_generator +│ │ ├── processing/ # CPU/GPU worker pool, task dispatcher +│ │ └── data/ # training_db, autolabel_report +│ │ +│ └── inference/ # 推理服务 (常驻) +│ ├── setup.py +│ ├── Dockerfile +│ ├── run_server.py # Web 服务器入口 +│ └── inference/ +│ ├── cli/ # infer, serve +│ ├── pipeline/ # YOLO 检测, 字段提取, 解析器 +│ ├── web/ # FastAPI 应用 +│ │ ├── api/v1/ # REST API (admin, public, batch) +│ │ ├── schemas/ # Pydantic 数据模型 +│ │ ├── services/ # 业务逻辑 +│ │ ├── core/ # 认证, 调度器, 限流 +│ │ └── workers/ # 后台任务队列 +│ ├── validation/ # LLM 验证器 +│ ├── data/ # AdminDB, AsyncRequestDB, Models +│ └── azure/ # ACI 训练触发器 +│ +├── migrations/ # 数据库迁移 +│ ├── 001_async_tables.sql +│ ├── 002_nullable_admin_token.sql +│ └── 003_training_tasks.sql +├── frontend/ # React 前端 (Vite + TypeScript) +├── tests/ # 测试 (922 tests) +├── docker-compose.yml # 本地开发 (postgres + inference + training) +├── run_server.py # 快捷启动脚本 +└── runs/train/ # 训练输出 (weights, curves) +``` ## 支持的字段 @@ -76,476 +155,129 @@ | 8 | payment_line | 支付行 (机器可读格式) | | 9 | customer_number | 客户编号 | -## DPI 配置 - -**重要**: 系统所有组件统一使用 **150 DPI**,确保训练和推理的一致性。 - 
-DPI(每英寸点数)设置必须在训练和推理时保持一致,否则会导致: -- 检测框尺寸失配 -- mAP显著下降(可能从93.5%降到60-70%) -- 字段漏检或误检 - -### 配置位置 - -| 组件 | 配置文件 | 配置项 | -|------|---------|--------| -| **全局常量** | `src/config.py` | `DEFAULT_DPI = 150` | -| **Web推理** | `src/web/config.py` | `ModelConfig.dpi` (导入自 `src.config`) | -| **CLI推理** | `src/cli/infer.py` | `--dpi` 默认值 = `DEFAULT_DPI` | -| **自动标注** | `src/config.py` | `AUTOLABEL['dpi'] = DEFAULT_DPI` | -| **PDF转图** | `src/web/api/v1/admin/documents.py` | 使用 `DEFAULT_DPI` | - -### 使用示例 - -```bash -# 训练(使用默认150 DPI) -python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1 - -# 推理(默认150 DPI,与训练一致) -python -m src.cli.infer -m runs/train/invoice_fields/weights/best.pt -i invoice.pdf - -# 手动指定DPI(仅当需要与非默认训练DPI的模型配合时) -python -m src.cli.infer -m custom_model.pt -i invoice.pdf --dpi 150 -``` - -## 安装 - -```bash -# 1. 进入 WSL -wsl -d Ubuntu-22.04 - -# 2. 创建 Conda 环境 -conda create -n invoice-py311 python=3.11 -y -conda activate invoice-py311 - -# 3. 进入项目目录 -cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 - -# 4. 安装依赖 -pip install -r requirements.txt - -# 5. 安装 Web 依赖 -pip install uvicorn fastapi python-multipart pydantic -``` - ## 快速开始 -### 1. 准备数据 - -``` -~/invoice-data/ -├── raw_pdfs/ -│ ├── {DocumentId}.pdf -│ └── ... -├── structured_data/ -│ └── document_export_YYYYMMDD.csv -└── dataset/ - └── temp/ (渲染的图片) -``` - -CSV 格式: -```csv -DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount -3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114 -``` - -### 2. 自动标注 +### 1. 自动标注 ```bash # 使用双池模式 (CPU + GPU) -python -m src.cli.autolabel \ +python -m training.cli.autolabel \ --dual-pool \ --cpu-workers 3 \ --gpu-workers 1 # 单线程模式 -python -m src.cli.autolabel --workers 4 +python -m training.cli.autolabel --workers 4 ``` -### 3. 训练模型 +### 2. 训练模型 ```bash # 从预训练模型开始训练 -python -m src.cli.train \ +python -m training.cli.train \ --model yolo11n.pt \ --epochs 100 \ --batch 16 \ --name invoice_fields \ --dpi 150 -# 低内存模式 (适用于内存不足场景) -python -m src.cli.train \ +# 低内存模式 +python -m training.cli.train \ --model yolo11n.pt \ --epochs 100 \ --name invoice_fields \ - --low-memory \ - --workers 4 \ - --no-cache + --low-memory -# 从检查点恢复训练 (训练中断后) -python -m src.cli.train \ +# 从检查点恢复训练 +python -m training.cli.train \ --model runs/train/invoice_fields/weights/last.pt \ --epochs 100 \ --name invoice_fields \ --resume ``` -### 4. 增量训练 - -当添加新数据后,可以在已训练模型基础上继续训练: - -```bash -# 从已训练的 best.pt 继续训练 -python -m src.cli.train \ - --model runs/train/invoice_yolo11n_full/weights/best.pt \ - --epochs 30 \ - --batch 16 \ - --name invoice_yolo11n_v2 \ - --dpi 150 -``` - -**增量训练建议**: - -| 场景 | 建议 | -|------|------| -| 添加少量新数据 (<20%) | 继续训练 10-30 epochs | -| 添加大量新数据 (>50%) | 继续训练 50-100 epochs | -| 修正大量标注错误 | 从头训练 | -| 添加新的字段类型 | 从头训练 | - -### 5. 推理 +### 3. 推理 ```bash # 命令行推理 -python -m src.cli.infer \ +python -m inference.cli.infer \ --model runs/train/invoice_fields/weights/best.pt \ --input path/to/invoice.pdf \ --output result.json \ --gpu - -# 批量推理 -python -m src.cli.infer \ - --model runs/train/invoice_fields/weights/best.pt \ - --input invoices/*.pdf \ - --output results/ \ - --gpu ``` -**推理结果包含**: -- `fields`: 提取的字段值 (InvoiceNumber, Amount, payment_line, customer_number 等) -- `confidence`: 各字段的置信度 -- `document_type`: 文档类型 ("invoice" 或 "letter") -- `cross_validation`: payment_line 交叉验证结果 (如果有) - -### 6. Web 应用 - -**在 WSL 环境中启动**: +### 4. 
Web 应用 ```bash -# 方法 1: 从 Windows PowerShell 启动 (推荐) +# 从 Windows PowerShell 启动 wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000" -# 方法 2: 在 WSL 内启动 -conda activate invoice-py311 -cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 -python run_server.py --port 8000 - -# 方法 3: 使用启动脚本 -./start_web.sh +# 启动前端 +cd frontend && npm install && npm run dev +# 访问 http://localhost:5173 ``` -**服务启动后**: -- 访问 **http://localhost:8000** 使用 Web 界面 -- 服务会自动加载模型 `runs/train/invoice_fields/weights/best.pt` -- GPU 默认启用,置信度阈值 0.5 +### 5. Docker 本地开发 -#### Web API 端点 +```bash +docker-compose up +# inference: http://localhost:8000 +# training: 轮询模式自动拾取任务 +``` + +## 训练触发流程 + +推理服务通过 API 触发训练,训练在独立的 GPU 实例上执行: + +``` +Inference API PostgreSQL Training (ACI) + | | | + POST /admin/training/trigger | | + |-> INSERT training_tasks ------>| status=pending | + |-> Azure SDK: create ACI --------------------------------> 启动 + | | | + | |<-- SELECT pending -----+ + | |--- UPDATE running -----+ + | | 执行训练... + | |<-- UPDATE completed ---+ + | | + model_path | + | | + metrics 自动关机 + | | | + GET /admin/training/{id} | | + |-> SELECT training_tasks ------>| | + +-- return status + metrics | | +``` + +## Web API 端点 + +**Public API:** | 方法 | 端点 | 描述 | |------|------|------| -| GET | `/` | Web UI 界面 | | GET | `/api/v1/health` | 健康检查 | | POST | `/api/v1/infer` | 上传文件并推理 | | GET | `/api/v1/results/{filename}` | 获取可视化图片 | +| POST | `/api/v1/async/infer` | 异步推理 | +| GET | `/api/v1/async/status/{task_id}` | 查询异步任务状态 | -#### API 响应格式 +**Admin API** (需要 `X-Admin-Token` header): -```json -{ - "status": "success", - "result": { - "document_id": "abc123", - "document_type": "invoice", - "fields": { - "InvoiceNumber": "12345", - "Amount": "1234.56", - "payment_line": "# 94228110015950070 # > 48666036#14#", - "customer_number": "UMJ 436-R" - }, - "confidence": { - "InvoiceNumber": 0.95, - "Amount": 0.92 - }, - "cross_validation": { - "is_valid": true, - "ocr_match": true, - "amount_match": true - } - } -} -``` - -## 训练配置 - -### YOLO 训练参数 - -```bash -python -m src.cli.train [OPTIONS] - -Options: - --model, -m 基础模型 (默认: yolo11n.pt) - --epochs, -e 训练轮数 (默认: 100) - --batch, -b 批大小 (默认: 16) - --imgsz 图像尺寸 (默认: 1280) - --dpi PDF 渲染 DPI (默认: 150) - --name 训练名称 - --limit 限制文档数 (用于测试) - --device 设备 (0=GPU, cpu) - --resume 从检查点恢复训练 - --low-memory 启用低内存模式 (batch=8, workers=4, no-cache) - --workers 数据加载 worker 数 (默认: 8) - --cache 缓存图像到内存 -``` - -### 训练最佳实践 - -1. **禁用翻转增强** (文本检测): - ```python - fliplr=0.0, flipud=0.0 - ``` - -2. **使用 Early Stopping**: - ```python - patience=20 - ``` - -3. **启用 AMP** (混合精度训练): - ```python - amp=True - ``` - -4. 
**保存检查点**: - ```python - save_period=10 - ``` - -### 训练结果示例 - -**最新训练结果** (100 epochs, 2026-01-22): - -| 指标 | 值 | -|------|-----| -| **mAP@0.5** | 93.5% | -| **mAP@0.5-0.95** | 83.0% | -| **训练集** | ~10,000 张标注图片 | -| **字段类型** | 10 个字段 (新增 payment_line, customer_number) | -| **模型位置** | `runs/train/invoice_fields/weights/best.pt` | - -**各字段检测性能**: -- 发票基础信息 (InvoiceNumber, InvoiceDate, InvoiceDueDate): >95% mAP -- 支付信息 (OCR, Bankgiro, Plusgiro, Amount): >90% mAP -- 组织信息 (supplier_org_number, customer_number): >85% mAP -- 支付行 (payment_line): >80% mAP - -**模型文件**: -``` -runs/train/invoice_fields/weights/ -├── best.pt # 最佳模型 (mAP@0.5 最高) ⭐ 推荐用于生产 -└── last.pt # 最后检查点 (用于继续训练) -``` - -> 注:目前仍在持续标注更多数据,预计最终将有 25,000+ 张标注图片用于训练。 - -## 项目结构 - -``` -invoice-master-poc-v2/ -├── src/ -│ ├── cli/ # 命令行工具 -│ │ ├── autolabel.py # 自动标注 -│ │ ├── train.py # 模型训练 -│ │ ├── infer.py # 推理 -│ │ └── serve.py # Web 服务器 -│ ├── pdf/ # PDF 处理 -│ │ ├── extractor.py # 文本提取 -│ │ ├── renderer.py # 图像渲染 -│ │ └── detector.py # 类型检测 -│ ├── ocr/ # PaddleOCR 封装 -│ │ └── machine_code_parser.py # 机器可读付款行解析器 -│ ├── normalize/ # 字段规范化 -│ ├── matcher/ # 字段匹配 -│ ├── yolo/ # YOLO 相关 -│ │ ├── annotation_generator.py -│ │ └── db_dataset.py -│ ├── inference/ # 推理管道 -│ │ ├── pipeline.py # 主推理流程 -│ │ ├── yolo_detector.py # YOLO 检测 -│ │ ├── field_extractor.py # 字段提取 -│ │ ├── payment_line_parser.py # 支付行解析器 -│ │ └── customer_number_parser.py # 客户编号解析器 -│ ├── processing/ # 多池处理架构 -│ │ ├── worker_pool.py -│ │ ├── cpu_pool.py -│ │ ├── gpu_pool.py -│ │ ├── task_dispatcher.py -│ │ └── dual_pool_coordinator.py -│ ├── web/ # Web 应用 -│ │ ├── app.py # FastAPI 应用入口 -│ │ ├── routes.py # API 路由 -│ │ ├── services.py # 业务逻辑 -│ │ └── schemas.py # 数据模型 -│ ├── utils/ # 工具模块 -│ │ ├── text_cleaner.py # 文本清理 -│ │ ├── validators.py # 字段验证 -│ │ ├── fuzzy_matcher.py # 模糊匹配 -│ │ └── ocr_corrections.py # OCR 错误修正 -│ └── data/ # 数据处理 -├── tests/ # 测试文件 -│ ├── ocr/ # OCR 模块测试 -│ │ └── test_machine_code_parser.py -│ ├── inference/ # 推理模块测试 -│ ├── normalize/ # 规范化模块测试 -│ └── utils/ # 工具模块测试 -├── docs/ # 文档 -│ ├── REFACTORING_SUMMARY.md -│ └── TEST_COVERAGE_IMPROVEMENT.md -├── config.py # 配置文件 -├── run_server.py # Web 服务器启动脚本 -├── runs/ # 训练输出 -│ └── train/ -│ └── invoice_fields/ -│ └── weights/ -│ ├── best.pt # 最佳模型 -│ └── last.pt # 最后检查点 -└── requirements.txt -``` - -## 多池处理架构 - -项目使用 CPU + GPU 双池架构处理不同类型的 PDF: - -``` -┌─────────────────────────────────────────────────────┐ -│ DualPoolCoordinator │ -│ ┌─────────────────┐ ┌─────────────────┐ │ -│ │ CPU Pool │ │ GPU Pool │ │ -│ │ (3 workers) │ │ (1 worker) │ │ -│ │ │ │ │ │ -│ │ Text PDFs │ │ Scanned PDFs │ │ -│ │ ~50-87 it/s │ │ ~1-2 it/s │ │ -│ └─────────────────┘ └─────────────────┘ │ -│ │ -│ TaskDispatcher: 根据 PDF 类型分配任务 │ -└─────────────────────────────────────────────────────┘ -``` - -### 关键设计 - -- **spawn 启动方式**: 兼容 CUDA 多进程 -- **as_completed()**: 无死锁结果收集 -- **进程初始化器**: 每个 worker 加载一次模型 -- **协调器持久化**: 跨 CSV 文件复用 worker 池 - -## 配置文件 - -### config.py - -```python -# 数据库配置 -DATABASE = { - 'host': '192.168.68.31', - 'port': 5432, - 'database': 'docmaster', - 'user': 'docmaster', - 'password': '******', -} - -# 路径配置 -PATHS = { - 'csv_dir': '~/invoice-data/structured_data', - 'pdf_dir': '~/invoice-data/raw_pdfs', - 'output_dir': '~/invoice-data/dataset', -} -``` - -## CLI 命令参考 - -### autolabel - -```bash -python -m src.cli.autolabel [OPTIONS] - -Options: - --csv, -c CSV 文件路径 (支持 glob) - --pdf-dir, -p PDF 文件目录 - --output, -o 输出目录 - --workers, -w 单线程模式 worker 数 (默认: 4) - --dual-pool 启用双池模式 - --cpu-workers CPU 池 worker 数 
(默认: 3) - --gpu-workers GPU 池 worker 数 (默认: 1) - --dpi 渲染 DPI (默认: 150) - --limit, -l 限制处理文档数 -``` - -### train - -```bash -python -m src.cli.train [OPTIONS] - -Options: - --model, -m 基础模型路径 - --epochs, -e 训练轮数 (默认: 100) - --batch, -b 批大小 (默认: 16) - --imgsz 图像尺寸 (默认: 1280) - --dpi PDF 渲染 DPI (默认: 150) - --name 训练名称 - --limit 限制文档数 -``` - -### infer - -```bash -python -m src.cli.infer [OPTIONS] - -Options: - --model, -m 模型路径 - --input, -i 输入 PDF/图像 - --output, -o 输出 JSON 路径 - --confidence 置信度阈值 (默认: 0.5) - --dpi 渲染 DPI (默认: 150, 必须与训练DPI一致) - --gpu 使用 GPU -``` - -### serve - -```bash -python run_server.py [OPTIONS] - -Options: - --host 绑定地址 (默认: 0.0.0.0) - --port 端口 (默认: 8000) - --model, -m 模型路径 - --confidence 置信度阈值 (默认: 0.3) - --dpi 渲染 DPI (默认: 150) - --no-gpu 禁用 GPU - --reload 开发模式自动重载 - --debug 调试模式 -``` +| 方法 | 端点 | 描述 | +|------|------|------| +| POST | `/api/v1/admin/auth/login` | 管理员登录 | +| GET | `/api/v1/admin/documents` | 文档列表 | +| POST | `/api/v1/admin/documents/upload` | 上传 PDF | +| GET | `/api/v1/admin/documents/{id}` | 文档详情 | +| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 | +| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 | +| POST | `/api/v1/admin/training/trigger` | 触发训练任务 | +| GET | `/api/v1/admin/training/{id}/status` | 查询训练状态 | ## Python API ```python -from src.inference.pipeline import InferencePipeline +from inference.pipeline import InferencePipeline # 初始化 pipeline = InferencePipeline( @@ -559,41 +291,25 @@ pipeline = InferencePipeline( # 处理 PDF result = pipeline.process_pdf('invoice.pdf') -# 处理图片 -result = pipeline.process_image('invoice.png') - -# 获取结果 print(result.fields) -# { -# 'InvoiceNumber': '12345', -# 'Amount': '1234.56', -# 'payment_line': '# 94228110015950070 # > 48666036#14#', -# 'customer_number': 'UMJ 436-R', -# ... 
-# } +# {'InvoiceNumber': '12345', 'Amount': '1234.56', ...} -print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...} -print(result.to_json()) # JSON 格式输出 +print(result.confidence) +# {'InvoiceNumber': 0.95, 'Amount': 0.92, ...} -# 访问交叉验证结果 +# 交叉验证 if result.cross_validation: print(f"OCR match: {result.cross_validation.ocr_match}") - print(f"Amount match: {result.cross_validation.amount_match}") - print(f"Details: {result.cross_validation.details}") ``` -### 统一解析器使用 - ```python -from src.inference.payment_line_parser import PaymentLineParser -from src.inference.customer_number_parser import CustomerNumberParser +from inference.pipeline.payment_line_parser import PaymentLineParser +from inference.pipeline.customer_number_parser import CustomerNumberParser # Payment Line 解析 parser = PaymentLineParser() result = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#") -print(f"OCR: {result.ocr_number}") -print(f"Amount: {result.amount}") -print(f"Account: {result.account_number}") +print(f"OCR: {result.ocr_number}, Amount: {result.amount}") # Customer Number 解析 parser = CustomerNumberParser() @@ -601,156 +317,38 @@ result = parser.parse("Said, Shakar Umj 436-R Billo") print(f"Customer Number: {result}") # "UMJ 436-R" ``` +## DPI 配置 + +系统所有组件统一使用 **150 DPI**。DPI 必须在训练和推理时保持一致。 + +| 组件 | 配置位置 | +|------|---------| +| 全局常量 | `packages/shared/shared/config.py` -> `DEFAULT_DPI = 150` | +| Web 推理 | `packages/inference/inference/web/config.py` -> `ModelConfig.dpi` | +| CLI 推理 | `python -m inference.cli.infer --dpi 150` | +| 自动标注 | `packages/shared/shared/config.py` -> `AUTOLABEL['dpi']` | + +## 数据库架构 + +| 数据库 | 用途 | 存储内容 | +|--------|------|----------| +| **PostgreSQL** | 标注结果 | `documents`, `field_results`, `training_tasks` | +| **SQLite** (AdminDB) | Web 应用 | 文档管理, 标注编辑, 用户认证 | + ## 测试 -### 测试统计 - -| 指标 | 数值 | -|------|------| -| **测试总数** | 688 | -| **通过率** | 100% | -| **整体覆盖率** | 37% | - -### 关键模块覆盖率 - -| 模块 | 覆盖率 | 测试数 | -|------|--------|--------| -| `machine_code_parser.py` | 65% | 79 | -| `payment_line_parser.py` | 85% | 45 | -| `customer_number_parser.py` | 90% | 32 | - -### 运行测试 - ```bash # 运行所有测试 -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest" +DB_PASSWORD=xxx pytest tests/ -q # 运行并查看覆盖率 -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src --cov-report=term-missing" - -# 运行特定模块测试 -wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/ocr/test_machine_code_parser.py -v" +DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing ``` -### 测试结构 - -``` -tests/ -├── ocr/ -│ ├── test_machine_code_parser.py # 支付行解析 (79 tests) -│ └── test_ocr_engine.py # OCR 引擎测试 -├── inference/ -│ ├── test_payment_line_parser.py # 支付行解析器 -│ └── test_customer_number_parser.py # 客户编号解析器 -├── normalize/ -│ └── test_normalizers.py # 字段规范化 -└── utils/ - └── test_validators.py # 字段验证 -``` - -## 开发状态 - -**已完成功能**: -- [x] 文本层 PDF 自动标注 -- [x] 扫描图 OCR 自动标注 -- [x] 多策略字段匹配 (精确/子串/规范化) -- [x] PostgreSQL 数据库存储 (断点续传) -- [x] 信号处理和超时保护 -- [x] YOLO 训练 (93.5% mAP@0.5, 10 个字段) -- [x] 推理管道 -- [x] 字段规范化和验证 -- [x] Web 应用 (FastAPI + REST API) -- [x] 增量训练支持 -- [x] 内存优化训练 (--low-memory, --resume) -- [x] Payment Line 解析器 (统一模块) -- [x] Customer Number 解析器 (统一模块) -- [x] 
Payment Line 交叉验证 (OCR, Amount, Account) -- [x] 文档类型识别 (invoice/letter) -- [x] 单元测试覆盖 (688 tests, 37% coverage) - -**进行中**: -- [ ] 完成全部 25,000+ 文档标注 -- [ ] 多源融合增强 (Multi-source fusion) -- [ ] OCR 错误修正集成 -- [ ] 提升测试覆盖率到 60%+ - -**计划中**: -- [ ] 表格 items 提取 -- [ ] 模型量化部署 (ONNX/TensorRT) -- [ ] 多语言支持扩展 - -## 关键技术特性 - -### 1. Payment Line 交叉验证 - -瑞典发票的 payment_line (支付行) 包含完整的支付信息:OCR 参考号、金额、账号。我们实现了交叉验证机制: - -``` -Payment Line: # 94228110015950070 # 15658 00 8 > 48666036#14# - ↓ ↓ ↓ - OCR Number Amount Bankgiro Account -``` - -**验证流程**: -1. 从 payment_line 提取 OCR、Amount、Account -2. 与单独检测的字段对比验证 -3. **payment_line 值优先** - 如有不匹配,采用 payment_line 的值 -4. 返回验证结果和详细信息 - -**优势**: -- 提高数据准确性 (payment_line 是机器可读格式,更可靠) -- 发现 OCR 或检测错误 -- 为数据质量提供信心指标 - -### 2. 统一解析器架构 - -采用独立解析器模块处理复杂字段: - -**PaymentLineParser**: -- 解析瑞典标准支付行格式 -- 提取 OCR、Amount (包含 Kronor + Öre)、Account + Check digits -- 支持多种变体格式 - -**CustomerNumberParser**: -- 支持多种瑞典客户编号格式 (`UMJ 436-R`, `JTY 576-3`, `FFL 019N`) -- 从混合文本中提取 (如地址行中的客户编号) -- 大小写不敏感,输出统一大写格式 - -**优势**: -- 代码模块化、可测试 -- 易于扩展新格式 -- 统一的解析逻辑,减少重复代码 - -### 3. 文档类型自动识别 - -根据 payment_line 字段自动判断文档类型: - -- **invoice**: 包含 payment_line 的发票文档 -- **letter**: 不含 payment_line 的信函文档 - -这个特性帮助下游系统区分处理流程。 - -### 4. 低内存模式训练 - -支持在内存受限环境下训练: - -```bash -python -m src.cli.train --low-memory -``` - -自动调整: -- batch size: 16 → 8 -- workers: 8 → 4 -- cache: disabled -- 推荐用于 GPU 内存 < 8GB 或系统内存 < 16GB 的场景 - -### 5. 断点续传训练 - -训练中断后可从检查点恢复: - -```bash -python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.pt -``` +| 指标 | 数值 | +|------|------| +| **测试总数** | 922 | +| **通过率** | 100% | ## 技术栈 @@ -762,32 +360,7 @@ python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last. | **数据库** | PostgreSQL + psycopg2 | | **Web 框架** | FastAPI + Uvicorn | | **深度学习** | PyTorch + CUDA 12.x | - -## 常见问题 - -**Q: 为什么必须在 WSL 环境运行?** - -A: PaddleOCR 和某些依赖在 Windows 原生环境存在兼容性问题。WSL 提供完整的 Linux 环境,确保所有依赖正常工作。 - -**Q: 训练过程中出现 OOM (内存不足) 错误怎么办?** - -A: 使用 `--low-memory` 模式,或手动调整 `--batch` 和 `--workers` 参数。 - -**Q: payment_line 和单独检测字段不匹配时怎么处理?** - -A: 系统默认优先采用 payment_line 的值,因为 payment_line 是机器可读格式,通常更准确。验证结果会记录在 `cross_validation` 字段中。 - -**Q: 如何添加新的字段类型?** - -A: -1. 在 `src/inference/constants.py` 添加字段定义 -2. 在 `field_extractor.py` 添加规范化方法 -3. 重新生成标注数据 -4. 从头训练模型 - -**Q: 可以用 CPU 训练吗?** - -A: 可以,但速度会非常慢 (慢 10-50 倍)。强烈建议使用 GPU 训练。 +| **部署** | Docker + Azure ACI (训练) / App Service (推理) | ## 许可证 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..0f79a69 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,60 @@ +version: "3.8" + +services: + postgres: + image: postgres:15 + environment: + POSTGRES_DB: docmaster + POSTGRES_USER: docmaster + POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword} + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + - ./migrations:/docker-entrypoint-initdb.d + + inference: + build: + context: . + dockerfile: packages/inference/Dockerfile + ports: + - "8000:8000" + environment: + - DB_HOST=postgres + - DB_PORT=5432 + - DB_NAME=docmaster + - DB_USER=docmaster + - DB_PASSWORD=${DB_PASSWORD:-devpassword} + - MODEL_PATH=/app/models/best.pt + volumes: + - ./models:/app/models + depends_on: + - postgres + + training: + build: + context: . 
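+      # Build context is the repo root, presumably so packages/shared can be
+      # installed into the image alongside packages/training (mirrors the
+      # inference service above)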
+ dockerfile: packages/training/Dockerfile + environment: + - DB_HOST=postgres + - DB_PORT=5432 + - DB_NAME=docmaster + - DB_USER=docmaster + - DB_PASSWORD=${DB_PASSWORD:-devpassword} + volumes: + - ./models:/app/models + - ./temp:/app/temp + depends_on: + - postgres + # Override CMD for local dev polling mode + command: ["python", "run_training.py", "--poll", "--poll-interval", "30"] + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + +volumes: + pgdata: diff --git a/docs/training-flow.mmd b/docs/training-flow.mmd new file mode 100644 index 0000000..b4ed0c8 --- /dev/null +++ b/docs/training-flow.mmd @@ -0,0 +1,54 @@ +flowchart TD + A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.] + B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD] + + C --> D[Load Data from DB\nsrc/yolo/db_dataset.py] + D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs] + D --> D2[Batch load field_results\nfrom database - batch 500] + + D1 --> E[Create DBYOLODataset] + D2 --> E + + E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42] + + F --> G[Export to YOLO Format] + G --> G1[Copy images to\ntrain/val/test dirs] + G --> G2[Generate .txt labels\nclass x_center y_center w h] + G --> G3[Generate dataset.yaml\n+ classes.txt] + G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1] + + G1 --> H{--export-only?} + G2 --> H + G3 --> H + G4 --> H + + H -- Yes --> Z[Done - Dataset exported] + H -- No --> I[Load YOLO Model] + + I --> I1{--resume?} + I1 -- Yes --> I2[Load last.pt checkpoint] + I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt] + + I2 --> J[Configure Training] + I3 --> J + + J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup] + J --> J2[imgsz=1280, pretrained=True] + + J1 --> K[model.train\nUltralytics Training Loop] + J2 --> K + + K --> L[Training Outputs\nruns/train/name/] + L --> L1[weights/best.pt\nweights/last.pt] + L --> L2[results.csv + results.png\nTraining curves] + L --> L3[PR curves, F1 curves\nConfusion matrix] + + L1 --> M[Test Set Validation\nmodel.val split=test] + M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95] + + N --> O[Close DB Connection] + + style A fill:#4a90d9,color:#fff + style K fill:#e67e22,color:#fff + style N fill:#27ae60,color:#fff + style Z fill:#95a5a6,color:#fff diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 46ab77d..994e9c1 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -20,15 +20,35 @@ "zustand": "^4.5.0" }, "devDependencies": { + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.2", + "@testing-library/user-event": "^14.6.1", "@types/node": "^22.14.0", "@vitejs/plugin-react": "^5.0.0", + "@vitest/coverage-v8": "^4.0.18", "autoprefixer": "^10.4.17", + "jsdom": "^27.4.0", "postcss": "^8.4.35", "tailwindcss": "^3.4.1", "typescript": "~5.8.2", - "vite": "^6.2.0" + "vite": "^6.2.0", + "vitest": "^4.0.18" } }, + "node_modules/@acemir/cssom": { + "version": "0.9.31", + "resolved": "https://registry.npmjs.org/@acemir/cssom/-/cssom-0.9.31.tgz", + "integrity": "sha512-ZnR3GSaH+/vJ0YlHau21FjfLYjMpYVIzTD8M8vIEQvIGxeOXyXdzCI140rrCY862p/C/BbzWsjc1dgnM9mkoTA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@adobe/css-tools": { + "version": "4.4.4", + "resolved": "https://registry.npmjs.org/@adobe/css-tools/-/css-tools-4.4.4.tgz", + "integrity": 
"sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg==", + "dev": true, + "license": "MIT" + }, "node_modules/@alloc/quick-lru": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", @@ -42,6 +62,61 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@asamuzakjp/css-color": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@asamuzakjp/css-color/-/css-color-4.1.1.tgz", + "integrity": "sha512-B0Hv6G3gWGMn0xKJ0txEi/jM5iFpT3MfDxmhZFb4W047GvytCf1DHQ1D69W3zHI4yWe2aTZAA0JnbMZ7Xc8DuQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@csstools/css-calc": "^2.1.4", + "@csstools/css-color-parser": "^3.1.0", + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4", + "lru-cache": "^11.2.4" + } + }, + "node_modules/@asamuzakjp/css-color/node_modules/lru-cache": { + "version": "11.2.5", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.5.tgz", + "integrity": "sha512-vFrFJkWtJvJnD5hg+hJvVE8Lh/TcMzKnTgCWmtBipwI5yLX/iX+5UB2tfuyODF5E7k9xEzMdYgGqaSb1c0c5Yw==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/dom-selector": { + "version": "6.7.6", + "resolved": "https://registry.npmjs.org/@asamuzakjp/dom-selector/-/dom-selector-6.7.6.tgz", + "integrity": "sha512-hBaJER6A9MpdG3WgdlOolHmbOYvSk46y7IQN/1+iqiCuUu6iWdQrs9DGKF8ocqsEqWujWf/V7b7vaDgiUmIvUg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/nwsapi": "^2.3.9", + "bidi-js": "^1.0.3", + "css-tree": "^3.1.0", + "is-potential-custom-element-name": "^1.0.1", + "lru-cache": "^11.2.4" + } + }, + "node_modules/@asamuzakjp/dom-selector/node_modules/lru-cache": { + "version": "11.2.5", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.5.tgz", + "integrity": "sha512-vFrFJkWtJvJnD5hg+hJvVE8Lh/TcMzKnTgCWmtBipwI5yLX/iX+5UB2tfuyODF5E7k9xEzMdYgGqaSb1c0c5Yw==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/nwsapi": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/@asamuzakjp/nwsapi/-/nwsapi-2.3.9.tgz", + "integrity": "sha512-n8GuYSrI9bF7FFZ/SjhwevlHc8xaVlb/7HmHelnc/PZXBD2ZR49NnN9sMMuDdEGPeeRQ5d0hqlSlEpgCX3Wl0Q==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -276,6 +351,16 @@ "@babel/core": "^7.0.0-0" } }, + "node_modules/@babel/runtime": { + "version": "7.28.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.6.tgz", + "integrity": "sha512-05WQkdpL9COIMz4LjTxGpPNCdlpyimKppYNoJ5Di5EUObifl8t4tuLuUBBZEpoLYOmfvIWrsp9fCl0HoPRVTdA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, "node_modules/@babel/template": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.28.6.tgz", @@ -324,6 +409,148 @@ "node": ">=6.9.0" } }, + "node_modules/@bcoe/v8-coverage": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-1.0.2.tgz", + "integrity": "sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/@csstools/color-helpers": { + "version": "5.1.0", + "resolved": 
"https://registry.npmjs.org/@csstools/color-helpers/-/color-helpers-5.1.0.tgz", + "integrity": "sha512-S11EXWJyy0Mz5SYvRmY8nJYTFFd1LCNV+7cXyAgQtOOuzb4EsgfqDufL+9esx72/eLhsRdGZwaldu/h+E4t4BA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "engines": { + "node": ">=18" + } + }, + "node_modules/@csstools/css-calc": { + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/@csstools/css-calc/-/css-calc-2.1.4.tgz", + "integrity": "sha512-3N8oaj+0juUw/1H3YwmDDJXCgTB1gKU6Hc/bB502u9zR0q2vd786XJH9QfrKIEgFlZmhZiq6epXl4rHqhzsIgQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-color-parser": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@csstools/css-color-parser/-/css-color-parser-3.1.0.tgz", + "integrity": "sha512-nbtKwh3a6xNVIp/VRuXV64yTKnb1IjTAEEh3irzS+HkKjAOYLTGNb9pmVNntZ8iVBHcWDA2Dof0QtPgFI1BaTA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "dependencies": { + "@csstools/color-helpers": "^5.1.0", + "@csstools/css-calc": "^2.1.4" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^3.0.5", + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-parser-algorithms": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@csstools/css-parser-algorithms/-/css-parser-algorithms-3.0.5.tgz", + "integrity": "sha512-DaDeUkXZKjdGhgYaHNJTV9pV7Y9B3b644jCLs9Upc3VeNGg6LWARAT6O+Q+/COo+2gg/bM5rhpMAtf70WqfBdQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@csstools/css-tokenizer": "^3.0.4" + } + }, + "node_modules/@csstools/css-syntax-patches-for-csstree": { + "version": "1.0.26", + "resolved": "https://registry.npmjs.org/@csstools/css-syntax-patches-for-csstree/-/css-syntax-patches-for-csstree-1.0.26.tgz", + "integrity": "sha512-6boXK0KkzT5u5xOgF6TKB+CLq9SOpEGmkZw0g5n9/7yg85wab3UzSxB8TxhLJ31L4SGJ6BCFRw/iftTha1CJXA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0" + }, + "node_modules/@csstools/css-tokenizer": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@csstools/css-tokenizer/-/css-tokenizer-3.0.4.tgz", + "integrity": "sha512-Vd/9EVDiu6PPJt9yAh6roZP6El1xHrdvIVGjyBsHR0RYwNHgL7FJPyIIW4fANJNG6FtyZfvlRPpFI4ZM/lubvw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=18" + } + }, 
"node_modules/@esbuild/aix-ppc64": { "version": "0.25.12", "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz", @@ -766,6 +993,24 @@ "node": ">=18" } }, + "node_modules/@exodus/bytes": { + "version": "1.10.0", + "resolved": "https://registry.npmjs.org/@exodus/bytes/-/bytes-1.10.0.tgz", + "integrity": "sha512-tf8YdcbirXdPnJ+Nd4UN1EXnz+IP2DI45YVEr3vvzcVTOyrApkmIB4zvOQVd3XPr7RXnfBtAx+PXImXOIU0Ajg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "@noble/hashes": "^1.8.0 || ^2.0.0" + }, + "peerDependenciesMeta": { + "@noble/hashes": { + "optional": true + } + } + }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.13", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", @@ -1294,6 +1539,104 @@ "react": "^18 || ^19" } }, + "node_modules/@testing-library/dom": { + "version": "10.4.1", + "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-10.4.1.tgz", + "integrity": "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/code-frame": "^7.10.4", + "@babel/runtime": "^7.12.5", + "@types/aria-query": "^5.0.1", + "aria-query": "5.3.0", + "dom-accessibility-api": "^0.5.9", + "lz-string": "^1.5.0", + "picocolors": "1.1.1", + "pretty-format": "^27.0.2" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@testing-library/jest-dom": { + "version": "6.9.1", + "resolved": "https://registry.npmjs.org/@testing-library/jest-dom/-/jest-dom-6.9.1.tgz", + "integrity": "sha512-zIcONa+hVtVSSep9UT3jZ5rizo2BsxgyDYU7WFD5eICBE7no3881HGeb/QkGfsJs6JTkY1aQhT7rIPC7e+0nnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@adobe/css-tools": "^4.4.0", + "aria-query": "^5.0.0", + "css.escape": "^1.5.1", + "dom-accessibility-api": "^0.6.3", + "picocolors": "^1.1.1", + "redent": "^3.0.0" + }, + "engines": { + "node": ">=14", + "npm": ">=6", + "yarn": ">=1" + } + }, + "node_modules/@testing-library/jest-dom/node_modules/dom-accessibility-api": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.6.3.tgz", + "integrity": "sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@testing-library/react": { + "version": "16.3.2", + "resolved": "https://registry.npmjs.org/@testing-library/react/-/react-16.3.2.tgz", + "integrity": "sha512-XU5/SytQM+ykqMnAnvB2umaJNIOsLF3PVv//1Ew4CTcpz0/BRyy/af40qqrt7SjKpDdT1saBMc42CUok5gaw+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.12.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@testing-library/dom": "^10.0.0", + "@types/react": "^18.0.0 || ^19.0.0", + "@types/react-dom": "^18.0.0 || ^19.0.0", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@testing-library/user-event": { + "version": "14.6.1", + "resolved": "https://registry.npmjs.org/@testing-library/user-event/-/user-event-14.6.1.tgz", + "integrity": "sha512-vq7fv0rnt+QTXgPxr5Hjc210p6YKq2kmdziLgnsZGgLJ9e6VAShx1pACLuRjd/AS/sr7phAR58OIIpf0LlmQNw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12", + "npm": ">=6" + }, + 
"peerDependencies": { + "@testing-library/dom": ">=7.21.4" + } + }, + "node_modules/@types/aria-query": { + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", + "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -1339,6 +1682,17 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/chai": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", + "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, "node_modules/@types/d3-array": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", @@ -1402,6 +1756,13 @@ "integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==", "license": "MIT" }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", + "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/estree": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", @@ -1446,6 +1807,183 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, + "node_modules/@vitest/coverage-v8": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-4.0.18.tgz", + "integrity": "sha512-7i+N2i0+ME+2JFZhfuz7Tg/FqKtilHjGyGvoHYQ6iLV0zahbsJ9sljC9OcFcPDbhYKCet+sG8SsVqlyGvPflZg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@bcoe/v8-coverage": "^1.0.2", + "@vitest/utils": "4.0.18", + "ast-v8-to-istanbul": "^0.3.10", + "istanbul-lib-coverage": "^3.2.2", + "istanbul-lib-report": "^3.0.1", + "istanbul-reports": "^3.2.0", + "magicast": "^0.5.1", + "obug": "^2.1.1", + "std-env": "^3.10.0", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@vitest/browser": "4.0.18", + "vitest": "4.0.18" + }, + "peerDependenciesMeta": { + "@vitest/browser": { + "optional": true + } + } + }, + "node_modules/@vitest/expect": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.0.18.tgz", + "integrity": "sha512-8sCWUyckXXYvx4opfzVY03EOiYVxyNrHS5QxX3DAIi5dpJAAkyJezHCP77VMX4HKA2LDT/Jpfo8i2r5BE3GnQQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@standard-schema/spec": "^1.0.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.0.18", + "@vitest/utils": "4.0.18", + "chai": "^6.2.1", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/mocker": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.0.18.tgz", + "integrity": "sha512-HhVd0MDnzzsgevnOWCBj5Otnzobjy5wLBe4EdeeFGv8luMsGcYqDuFRMcttKWZA5vVO8RFjexVovXvAM4JoJDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "4.0.18", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": 
"https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0-0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.0.18.tgz", + "integrity": "sha512-P24GK3GulZWC5tz87ux0m8OADrQIUVDPIjjj65vBXYG17ZeU3qD7r+MNZ1RNv4l8CGU2vtTRqixrOi9fYk/yKw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.0.18.tgz", + "integrity": "sha512-rpk9y12PGa22Jg6g5M3UVVnTS7+zycIGk9ZNGN+m6tZHKQb7jrP7/77WfZy13Y/EUDd52NDsLRQhYKtv7XfPQw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "4.0.18", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.0.18.tgz", + "integrity": "sha512-PCiV0rcl7jKQjbgYqjtakly6T1uwv/5BQ9SwBLekVg/EaYeQFPiXcgrC2Y7vDMA8dM1SUEAEV82kgSQIlXNMvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.18", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.0.18.tgz", + "integrity": "sha512-cbQt3PTSD7P2OARdVW3qWER5EGq7PHlvE+QfzSC0lbwO+xnt7+XH06ZzFjFRgzUX//JmpxrCu92VdwvEPlWSNw==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/utils": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.0.18.tgz", + "integrity": "sha512-msMRKLMVLWygpK3u2Hybgi4MNjcYJvwTb0Ru09+fOyCXIgT5raYP041DRRdiJiI3k/2U6SEbAETB3YtBrUkCFA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.18", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "peer": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "peer": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, "node_modules/any-promise": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", @@ 
-1474,6 +2012,45 @@ "dev": true, "license": "MIT" }, + "node_modules/aria-query": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz", + "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "dequal": "^2.0.3" + } + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, + "node_modules/ast-v8-to-istanbul": { + "version": "0.3.10", + "resolved": "https://registry.npmjs.org/ast-v8-to-istanbul/-/ast-v8-to-istanbul-0.3.10.tgz", + "integrity": "sha512-p4K7vMz2ZSk3wN8l5o3y2bJAoZXT3VuJI5OLTATY/01CYWumWvwkUw0SqDBnNq6IiTO3qDa1eSQDibAV8g7XOQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.31", + "estree-walker": "^3.0.3", + "js-tokens": "^9.0.1" + } + }, + "node_modules/ast-v8-to-istanbul/node_modules/js-tokens": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-9.0.1.tgz", + "integrity": "sha512-mxa9E9ITFOt0ban3j6L5MpjwegGz6lBQmM1IJkWeBZGcMxto50+eWdjC/52xDbS2vy0k7vIMK0Fe2wfL9OQSpQ==", + "dev": true, + "license": "MIT" + }, "node_modules/asynckit": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", @@ -1538,6 +2115,16 @@ "baseline-browser-mapping": "dist/cli.js" } }, + "node_modules/bidi-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/bidi-js/-/bidi-js-1.0.3.tgz", + "integrity": "sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==", + "dev": true, + "license": "MIT", + "dependencies": { + "require-from-string": "^2.0.2" + } + }, "node_modules/binary-extensions": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", @@ -1642,6 +2229,16 @@ ], "license": "CC-BY-4.0" }, + "node_modules/chai": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.2.tgz", + "integrity": "sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/chokidar": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", @@ -1718,6 +2315,27 @@ "dev": true, "license": "MIT" }, + "node_modules/css-tree": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.1.0.tgz", + "integrity": "sha512-0eW44TGN5SQXU1mWSkKwFstI/22X2bG1nYzZTYMAWjylYURhse752YgbE4Cx46AC+bAvI+/dYTPRk1LqSUnu6w==", + "dev": true, + "license": "MIT", + "dependencies": { + "mdn-data": "2.12.2", + "source-map-js": "^1.0.1" + }, + "engines": { + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0" + } + }, + "node_modules/css.escape": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/css.escape/-/css.escape-1.5.1.tgz", + "integrity": "sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg==", + "dev": true, + "license": "MIT" + }, "node_modules/cssesc": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", @@ -1731,6 +2349,32 @@ "node": ">=4" } }, + 
"node_modules/cssstyle": { + "version": "5.3.7", + "resolved": "https://registry.npmjs.org/cssstyle/-/cssstyle-5.3.7.tgz", + "integrity": "sha512-7D2EPVltRrsTkhpQmksIu+LxeWAIEk6wRDMJ1qljlv+CKHJM+cJLlfhWIzNA44eAsHXSNe3+vO6DW1yCYx8SuQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/css-color": "^4.1.1", + "@csstools/css-syntax-patches-for-csstree": "^1.0.21", + "css-tree": "^3.1.0", + "lru-cache": "^11.2.4" + }, + "engines": { + "node": ">=20" + } + }, + "node_modules/cssstyle/node_modules/lru-cache": { + "version": "11.2.5", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.5.tgz", + "integrity": "sha512-vFrFJkWtJvJnD5hg+hJvVE8Lh/TcMzKnTgCWmtBipwI5yLX/iX+5UB2tfuyODF5E7k9xEzMdYgGqaSb1c0c5Yw==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, "node_modules/d3-array": { "version": "3.2.4", "resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz", @@ -1852,6 +2496,30 @@ "node": ">=12" } }, + "node_modules/data-urls": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/data-urls/-/data-urls-6.0.1.tgz", + "integrity": "sha512-euIQENZg6x8mj3fO6o9+fOW8MimUI4PpD/fZBhJfeioZVy9TUpM4UY7KjQNVZFlqwJ0UdzRDzkycB997HEq1BQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-mimetype": "^5.0.0", + "whatwg-url": "^15.1.0" + }, + "engines": { + "node": ">=20" + } + }, + "node_modules/data-urls/node_modules/whatwg-mimetype": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-5.0.0.tgz", + "integrity": "sha512-sXcNcHOC51uPGF0P/D4NVtrkjSU2fNsm9iog4ZvZJsL3rjoDAzXZhkm2MWt1y+PUdggKAYVoMAIYcs78wJ51Cw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20" + } + }, "node_modules/date-fns": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-3.6.0.tgz", @@ -1880,6 +2548,13 @@ } } }, + "node_modules/decimal.js": { + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", + "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", + "dev": true, + "license": "MIT" + }, "node_modules/decimal.js-light": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz", @@ -1895,6 +2570,16 @@ "node": ">=0.4.0" } }, + "node_modules/dequal": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", + "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/didyoumean": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", @@ -1909,6 +2594,14 @@ "dev": true, "license": "MIT" }, + "node_modules/dom-accessibility-api": { + "version": "0.5.16", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", + "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/dunder-proto": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", @@ -1930,6 +2623,19 @@ "dev": true, "license": "ISC" }, + "node_modules/entities": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/entities/-/entities-6.0.1.tgz", + "integrity": 
"sha512-aN97NXWF6AWBTahfVOIrB/NShkzi5H7F9r1s9mD3cDj4Ko5f2qhhVoYMibXF7GlLveb/D2ioWay8lxI97Ven3g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, "node_modules/es-define-property": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", @@ -1948,6 +2654,13 @@ "node": ">= 0.4" } }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz", + "integrity": "sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==", + "dev": true, + "license": "MIT" + }, "node_modules/es-object-atoms": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", @@ -2037,12 +2750,32 @@ "node": ">=6" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/eventemitter3": { "version": "5.0.4", "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-5.0.4.tgz", "integrity": "sha512-mlsTRyGaPBjPedk6Bvw+aqbsXDtoAyAzm5MO7JgU+yVRyMQ5O8bD4Kcci7BS85f93veegeCPkL8R4GLClnjLFw==", "license": "MIT" }, + "node_modules/expect-type": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz", + "integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/fast-glob": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", @@ -2242,6 +2975,16 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/has-symbols": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", @@ -2281,6 +3024,54 @@ "node": ">= 0.4" } }, + "node_modules/html-encoding-sniffer": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-6.0.0.tgz", + "integrity": "sha512-CV9TW3Y3f8/wT0BRFc1/KAVQ3TUHiXmaAb6VW9vtiMFf7SLoMd1PdAc4W3KFOFETBJUb90KatHqlsZMWV+R9Gg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@exodus/bytes": "^1.6.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, + "node_modules/html-escaper": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-2.0.2.tgz", + "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", + "dev": true, + "license": "MIT" + }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + 
"dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, "node_modules/immer": { "version": "10.2.0", "resolved": "https://registry.npmjs.org/immer/-/immer-10.2.0.tgz", @@ -2291,6 +3082,16 @@ "url": "https://opencollective.com/immer" } }, + "node_modules/indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/internmap": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/internmap/-/internmap-2.0.3.tgz", @@ -2362,6 +3163,52 @@ "node": ">=0.12.0" } }, + "node_modules/is-potential-custom-element-name": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", + "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/istanbul-lib-coverage": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.2.tgz", + "integrity": "sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=8" + } + }, + "node_modules/istanbul-lib-report": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.1.tgz", + "integrity": "sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "istanbul-lib-coverage": "^3.0.0", + "make-dir": "^4.0.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-reports": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/istanbul-reports/-/istanbul-reports-3.2.0.tgz", + "integrity": "sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "html-escaper": "^2.0.0", + "istanbul-lib-report": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/jiti": { "version": "1.21.7", "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", @@ -2379,6 +3226,46 @@ "dev": true, "license": "MIT" }, + "node_modules/jsdom": { + "version": "27.4.0", + "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-27.4.0.tgz", + "integrity": "sha512-mjzqwWRD9Y1J1KUi7W97Gja1bwOOM5Ug0EZ6UDK3xS7j7mndrkwozHtSblfomlzyB4NepioNt+B2sOSzczVgtQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@acemir/cssom": "^0.9.28", + "@asamuzakjp/dom-selector": "^6.7.6", + "@exodus/bytes": "^1.6.0", + "cssstyle": "^5.3.4", + "data-urls": "^6.0.0", + "decimal.js": "^10.6.0", + "html-encoding-sniffer": "^6.0.0", + "http-proxy-agent": "^7.0.2", + 
"https-proxy-agent": "^7.0.6", + "is-potential-custom-element-name": "^1.0.1", + "parse5": "^8.0.0", + "saxes": "^6.0.0", + "symbol-tree": "^3.2.4", + "tough-cookie": "^6.0.0", + "w3c-xmlserializer": "^5.0.0", + "webidl-conversions": "^8.0.0", + "whatwg-mimetype": "^4.0.0", + "whatwg-url": "^15.1.0", + "ws": "^8.18.3", + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "canvas": "^3.0.0" + }, + "peerDependenciesMeta": { + "canvas": { + "optional": true + } + } + }, "node_modules/jsesc": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", @@ -2444,6 +3331,68 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/lz-string": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", + "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", + "dev": true, + "license": "MIT", + "peer": true, + "bin": { + "lz-string": "bin/bin.js" + } + }, + "node_modules/magic-string": { + "version": "0.30.21", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz", + "integrity": "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.5" + } + }, + "node_modules/magicast": { + "version": "0.5.1", + "resolved": "https://registry.npmjs.org/magicast/-/magicast-0.5.1.tgz", + "integrity": "sha512-xrHS24IxaLrvuo613F719wvOIv9xPHFWQHuvGUBmPnCA/3MQxKI3b+r7n1jAoDHmsbC5bRhTZYR77invLAxVnw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.28.5", + "@babel/types": "^7.28.5", + "source-map-js": "^1.2.1" + } + }, + "node_modules/make-dir": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-4.0.0.tgz", + "integrity": "sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==", + "dev": true, + "license": "MIT", + "dependencies": { + "semver": "^7.5.3" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/make-dir/node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/math-intrinsics": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", @@ -2453,6 +3402,13 @@ "node": ">= 0.4" } }, + "node_modules/mdn-data": { + "version": "2.12.2", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.12.2.tgz", + "integrity": "sha512-IEn+pegP1aManZuckezWCO+XZQDplx1366JoVhTpMpBB1sPey/SbveZQUosKiKiGYjg1wH4pMlNgXbCiYgihQA==", + "dev": true, + "license": "CC0-1.0" + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -2498,6 +3454,16 @@ "node": ">= 0.6" } }, + "node_modules/min-indent": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz", + "integrity": "sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg==", + "dev": true, + "license": "MIT", + 
"engines": { + "node": ">=4" + } + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -2573,6 +3539,30 @@ "node": ">= 6" } }, + "node_modules/obug": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz", + "integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, + "node_modules/parse5": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/parse5/-/parse5-8.0.0.tgz", + "integrity": "sha512-9m4m5GSgXjL4AjumKzq1Fgfp3Z8rsvjRNbnkVwfu2ImRqE5D0LnY2QfDen18FSY9C573YU5XxSapdHZTZ2WolA==", + "dev": true, + "license": "MIT", + "dependencies": { + "entities": "^6.0.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, "node_modules/path-parse": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", @@ -2580,6 +3570,13 @@ "dev": true, "license": "MIT" }, + "node_modules/pathe": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", + "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", + "dev": true, + "license": "MIT" + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -2783,12 +3780,46 @@ "dev": true, "license": "MIT" }, + "node_modules/pretty-format": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz", + "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "ansi-regex": "^5.0.1", + "ansi-styles": "^5.0.0", + "react-is": "^17.0.1" + }, + "engines": { + "node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0" + } + }, + "node_modules/pretty-format/node_modules/react-is": { + "version": "17.0.2", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", + "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/proxy-from-env": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", "license": "MIT" }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/queue-microtask": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", @@ -2956,6 +3987,20 @@ "react-is": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/redent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/redent/-/redent-3.0.0.tgz", + "integrity": "sha512-6tDA8g98We0zd0GvVeMT9arEOnTw9qM03L9cJXaCjrip1OO764RDBLBfrB4cwzNGDj5OA5ioymC9GkizgWJDUg==", + "dev": true, + "license": "MIT", + "dependencies": { + "indent-string": "^4.0.0", + "strip-indent": "^3.0.0" + }, + "engines": { 
+ "node": ">=8" + } + }, "node_modules/redux": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", @@ -2971,6 +4016,16 @@ "redux": "^5.0.0" } }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/reselect": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/reselect/-/reselect-5.1.1.tgz", @@ -3078,6 +4133,19 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/saxes": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/saxes/-/saxes-6.0.0.tgz", + "integrity": "sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==", + "dev": true, + "license": "ISC", + "dependencies": { + "xmlchars": "^2.2.0" + }, + "engines": { + "node": ">=v12.22.7" + } + }, "node_modules/scheduler": { "version": "0.27.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", @@ -3094,6 +4162,13 @@ "semver": "bin/semver.js" } }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, "node_modules/source-map-js": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", @@ -3104,6 +4179,33 @@ "node": ">=0.10.0" } }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.10.0.tgz", + "integrity": "sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==", + "dev": true, + "license": "MIT" + }, + "node_modules/strip-indent": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-3.0.0.tgz", + "integrity": "sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "min-indent": "^1.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/sucrase": { "version": "3.35.1", "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.1.tgz", @@ -3127,6 +4229,19 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/supports-preserve-symlinks-flag": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", @@ -3140,6 +4255,13 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/symbol-tree": { + "version": "3.2.4", + "resolved": 
"https://registry.npmjs.org/symbol-tree/-/symbol-tree-3.2.4.tgz", + "integrity": "sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==", + "dev": true, + "license": "MIT" + }, "node_modules/tailwindcss": { "version": "3.4.19", "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.19.tgz", @@ -3207,6 +4329,23 @@ "integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==", "license": "MIT" }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.2.tgz", + "integrity": "sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -3255,6 +4394,36 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/tinyrainbow": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.0.3.tgz", + "integrity": "sha512-PSkbLUoxOFRzJYjjxHJt9xro7D+iilgMX/C9lawzVuYiIdcihh9DXmVibBe8lmcFrRi/VzlPjBxbN7rH24q8/Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tldts": { + "version": "7.0.19", + "resolved": "https://registry.npmjs.org/tldts/-/tldts-7.0.19.tgz", + "integrity": "sha512-8PWx8tvC4jDB39BQw1m4x8y5MH1BcQ5xHeL2n7UVFulMPH/3Q0uiamahFJ3lXA0zO2SUyRXuVVbWSDmstlt9YA==", + "dev": true, + "license": "MIT", + "dependencies": { + "tldts-core": "^7.0.19" + }, + "bin": { + "tldts": "bin/cli.js" + } + }, + "node_modules/tldts-core": { + "version": "7.0.19", + "resolved": "https://registry.npmjs.org/tldts-core/-/tldts-core-7.0.19.tgz", + "integrity": "sha512-lJX2dEWx0SGH4O6p+7FPwYmJ/bu1JbcGJ8RLaG9b7liIgZ85itUVEPbMtWRVrde/0fnDPEPHW10ZsKW3kVsE9A==", + "dev": true, + "license": "MIT" + }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -3268,6 +4437,32 @@ "node": ">=8.0" } }, + "node_modules/tough-cookie": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-6.0.0.tgz", + "integrity": "sha512-kXuRi1mtaKMrsLUxz3sQYvVl37B0Ns6MzfrtV5DvJceE9bPyspOqk9xxv7XbZWcfLWbFmm997vl83qUWVJA64w==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "tldts": "^7.0.5" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/tr46": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-6.0.0.tgz", + "integrity": "sha512-bLVMLPtstlZ4iMQHpFHTR7GAGj2jxi8Dg0s2h2MafAE4uSWF98FC/3MomU51iQAMf8/qDUbKWf5GxuvvVcXEhw==", + "dev": true, + "license": "MIT", + "dependencies": { + "punycode": "^2.3.1" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/ts-interface-checker": { "version": "0.1.13", "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", @@ -3471,6 +4666,200 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/vitest": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.0.18.tgz", + 
"integrity": "sha512-hOQuK7h0FGKgBAas7v0mSAsnvrIgAvWmRFjmzpJ7SwFHH3g1k2u37JtYwOwmEKhK6ZO3v9ggDBBm0La1LCK4uQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "4.0.18", + "@vitest/mocker": "4.0.18", + "@vitest/pretty-format": "4.0.18", + "@vitest/runner": "4.0.18", + "@vitest/snapshot": "4.0.18", + "@vitest/spy": "4.0.18", + "@vitest/utils": "4.0.18", + "es-module-lexer": "^1.7.0", + "expect-type": "^1.2.2", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^3.10.0", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.0.3", + "vite": "^6.0.0 || ^7.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.0.18", + "@vitest/browser-preview": "4.0.18", + "@vitest/browser-webdriverio": "4.0.18", + "@vitest/ui": "4.0.18", + "happy-dom": "*", + "jsdom": "*" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + } + } + }, + "node_modules/vitest/node_modules/picomatch": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/w3c-xmlserializer": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz", + "integrity": "sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==", + "dev": true, + "license": "MIT", + "dependencies": { + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/webidl-conversions": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.1.tgz", + "integrity": "sha512-BMhLD/Sw+GbJC21C/UgyaZX41nPt8bUTg+jWyDeg7e7YN4xOM05YPSIXceACnXVtqyEw/LMClUQMtMZ+PGGpqQ==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-mimetype": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-4.0.0.tgz", + "integrity": "sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/whatwg-url": { + "version": "15.1.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-15.1.0.tgz", + "integrity": "sha512-2ytDk0kiEj/yu90JOAp44PVPUkO9+jVhyf+SybKlRHSDlvOOZhdPIrr7xTH64l4WixO2cP+wQIcgujkGBPPz6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "tr46": "^6.0.0", + "webidl-conversions": "^8.0.0" + 
}, + "engines": { + "node": ">=20" + } + }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/ws": { + "version": "8.19.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz", + "integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xml-name-validator": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-5.0.0.tgz", + "integrity": "sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18" + } + }, + "node_modules/xmlchars": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz", + "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==", + "dev": true, + "license": "MIT" + }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index 3fd68ad..1d7e41e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -6,27 +6,36 @@ "scripts": { "dev": "vite", "build": "vite build", - "preview": "vite preview" + "preview": "vite preview", + "test": "vitest run", + "test:watch": "vitest", + "test:coverage": "vitest run --coverage" }, "dependencies": { + "@tanstack/react-query": "^5.20.0", + "axios": "^1.6.7", + "clsx": "^2.1.0", + "date-fns": "^3.3.0", + "lucide-react": "^0.563.0", "react": "^19.2.3", "react-dom": "^19.2.3", - "lucide-react": "^0.563.0", - "recharts": "^3.7.0", - "axios": "^1.6.7", "react-router-dom": "^6.22.0", - "zustand": "^4.5.0", - "@tanstack/react-query": "^5.20.0", - "date-fns": "^3.3.0", - "clsx": "^2.1.0" + "recharts": "^3.7.0", + "zustand": "^4.5.0" }, "devDependencies": { + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.2", + "@testing-library/user-event": "^14.6.1", "@types/node": "^22.14.0", "@vitejs/plugin-react": "^5.0.0", + "@vitest/coverage-v8": "^4.0.18", + "autoprefixer": "^10.4.17", + "jsdom": "^27.4.0", + "postcss": "^8.4.35", + "tailwindcss": "^3.4.1", "typescript": "~5.8.2", "vite": "^6.2.0", - "tailwindcss": "^3.4.1", - "autoprefixer": "^10.4.17", - "postcss": "^8.4.35" + "vitest": "^4.0.18" } } diff --git a/frontend/src/components/Badge.test.tsx b/frontend/src/components/Badge.test.tsx new file mode 100644 index 0000000..029a5d4 --- /dev/null +++ b/frontend/src/components/Badge.test.tsx @@ -0,0 +1,32 @@ +import { render, screen } from '@testing-library/react'; +import { describe, it, expect } from 'vitest'; +import { Badge } from './Badge'; +import { DocumentStatus } from '../types'; + +describe('Badge', () => { + it('renders Exported badge with check icon', 
() => {
+    render(<Badge status={DocumentStatus.Exported} />);
+    expect(screen.getByText('Exported')).toBeInTheDocument();
+  });
+
+  it('renders Pending status', () => {
+    render(<Badge status={DocumentStatus.Pending} />);
+    expect(screen.getByText('Pending')).toBeInTheDocument();
+  });
+
+  it('renders Verified status', () => {
+    render(<Badge status={DocumentStatus.Verified} />);
+    expect(screen.getByText('Verified')).toBeInTheDocument();
+  });
+
+  it('renders Labeled status', () => {
+    render(<Badge status={DocumentStatus.Labeled} />);
+    expect(screen.getByText('Labeled')).toBeInTheDocument();
+  });
+
+  it('renders Partial status with warning indicator', () => {
+    render(<Badge status={DocumentStatus.Partial} />);
+    expect(screen.getByText('Partial')).toBeInTheDocument();
+    expect(screen.getByText('!')).toBeInTheDocument();
+  });
+});
diff --git a/frontend/src/components/Button.test.tsx b/frontend/src/components/Button.test.tsx
new file mode 100644
index 0000000..c62cec6
--- /dev/null
+++ b/frontend/src/components/Button.test.tsx
@@ -0,0 +1,38 @@
+import { render, screen } from '@testing-library/react';
+import userEvent from '@testing-library/user-event';
+import { describe, it, expect, vi } from 'vitest';
+import { Button } from './Button';
+
+describe('Button', () => {
+  it('renders children text', () => {
+    render(<Button>Click me</Button>);
+    expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
+  });
+
+  it('calls onClick handler', async () => {
+    const user = userEvent.setup();
+    const onClick = vi.fn();
+    render(<Button onClick={onClick}>Click</Button>);
+    await user.click(screen.getByRole('button'));
+    expect(onClick).toHaveBeenCalledOnce();
+  });
+
+  it('is disabled when disabled prop is set', () => {
+    render(<Button disabled>Disabled</Button>);
+    expect(screen.getByRole('button')).toBeDisabled();
+  });
+
+  it('applies variant styles', () => {
+    const { rerender } = render(<Button variant="secondary">Save</Button>);
+    const btn = screen.getByRole('button');
+    expect(btn.className).toContain('bg-warm-text-secondary');
+
+    rerender(<Button variant="outline">Save</Button>);
+    expect(screen.getByRole('button').className).toContain('border');
+  });
+
+  it('applies size styles', () => {
+    render(<Button size="sm">Small</Button>);
+    expect(screen.getByRole('button').className).toContain('h-8');
+  });
+});
diff --git a/frontend/tests/setup.ts b/frontend/tests/setup.ts
new file mode 100644
index 0000000..7b0828b
--- /dev/null
+++ b/frontend/tests/setup.ts
@@ -0,0 +1 @@
+import '@testing-library/jest-dom';
diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts
new file mode 100644
index 0000000..c2c081f
--- /dev/null
+++ b/frontend/vitest.config.ts
@@ -0,0 +1,19 @@
+/// <reference types="vitest" />
+import { defineConfig } from 'vite';
+import react from '@vitejs/plugin-react';
+
+export default defineConfig({
+  plugins: [react()],
+  test: {
+    globals: true,
+    environment: 'jsdom',
+    setupFiles: ['./tests/setup.ts'],
+    include: ['src/**/*.test.{ts,tsx}', 'tests/**/*.test.{ts,tsx}'],
+    coverage: {
+      provider: 'v8',
+      reporter: ['text', 'lcov'],
+      include: ['src/**/*.{ts,tsx}'],
+      exclude: ['src/**/*.test.{ts,tsx}', 'src/main.tsx'],
+    },
+  },
+});
diff --git a/migrations/003_training_tasks.sql b/migrations/003_training_tasks.sql
new file mode 100644
index 0000000..f84b94d
--- /dev/null
+++ b/migrations/003_training_tasks.sql
@@ -0,0 +1,18 @@
+-- Training tasks table for async training job management.
+-- Inference service writes pending tasks; training service polls and executes.
+ +CREATE TABLE IF NOT EXISTS training_tasks ( + task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + status VARCHAR(20) NOT NULL DEFAULT 'pending', + config JSONB, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + scheduled_at TIMESTAMP WITH TIME ZONE, + started_at TIMESTAMP WITH TIME ZONE, + completed_at TIMESTAMP WITH TIME ZONE, + error_message TEXT, + model_path TEXT, + metrics JSONB +); + +CREATE INDEX IF NOT EXISTS idx_training_tasks_status ON training_tasks(status); +CREATE INDEX IF NOT EXISTS idx_training_tasks_created ON training_tasks(created_at); diff --git a/migrations/004_training_datasets.sql b/migrations/004_training_datasets.sql new file mode 100644 index 0000000..5228953 --- /dev/null +++ b/migrations/004_training_datasets.sql @@ -0,0 +1,39 @@ +-- Training Datasets Management +-- Tracks dataset-document relationships and train/val/test splits + +CREATE TABLE IF NOT EXISTS training_datasets ( + dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name VARCHAR(255) NOT NULL, + description TEXT, + status VARCHAR(20) NOT NULL DEFAULT 'building', + train_ratio FLOAT NOT NULL DEFAULT 0.8, + val_ratio FLOAT NOT NULL DEFAULT 0.1, + seed INTEGER NOT NULL DEFAULT 42, + total_documents INTEGER NOT NULL DEFAULT 0, + total_images INTEGER NOT NULL DEFAULT 0, + total_annotations INTEGER NOT NULL DEFAULT 0, + dataset_path VARCHAR(512), + error_message TEXT, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status); + +CREATE TABLE IF NOT EXISTS dataset_documents ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE, + document_id UUID NOT NULL REFERENCES admin_documents(document_id), + split VARCHAR(10) NOT NULL, + page_count INTEGER NOT NULL DEFAULT 0, + annotation_count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + UNIQUE(dataset_id, document_id) +); + +CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id); +CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id); + +-- Add dataset_id to training_tasks +ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id); +CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id); diff --git a/packages/inference/Dockerfile b/packages/inference/Dockerfile new file mode 100644 index 0000000..c4d6b91 --- /dev/null +++ b/packages/inference/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install shared package +COPY packages/shared /app/packages/shared +RUN pip install --no-cache-dir -e /app/packages/shared + +# Install inference package +COPY packages/inference /app/packages/inference +RUN pip install --no-cache-dir -e /app/packages/inference + +# Copy frontend (if needed) +COPY frontend /app/frontend + +WORKDIR /app/packages/inference + +EXPOSE 8000 + +CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"] diff --git a/src/web/api/__init__.py b/packages/inference/inference/__init__.py similarity index 100% rename from src/web/api/__init__.py rename to 
packages/inference/inference/__init__.py diff --git a/src/web/api/v1/__init__.py b/packages/inference/inference/azure/__init__.py similarity index 100% rename from src/web/api/v1/__init__.py rename to packages/inference/inference/azure/__init__.py diff --git a/packages/inference/inference/azure/aci_trigger.py b/packages/inference/inference/azure/aci_trigger.py new file mode 100644 index 0000000..36210e0 --- /dev/null +++ b/packages/inference/inference/azure/aci_trigger.py @@ -0,0 +1,105 @@ +"""Trigger training jobs on Azure Container Instances.""" + +import logging +import os + +logger = logging.getLogger(__name__) + +# Azure SDK is optional; only needed if using ACI trigger +try: + from azure.identity import DefaultAzureCredential + from azure.mgmt.containerinstance import ContainerInstanceManagementClient + from azure.mgmt.containerinstance.models import ( + Container, + ContainerGroup, + EnvironmentVariable, + GpuResource, + ResourceRequests, + ResourceRequirements, + ) + + _AZURE_SDK_AVAILABLE = True +except ImportError: + _AZURE_SDK_AVAILABLE = False + + +def start_training_container(task_id: str) -> str | None: + """ + Start an Azure Container Instance for a training task. + + Returns the container group name if successful, None otherwise. + Requires environment variables: + AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE + """ + if not _AZURE_SDK_AVAILABLE: + logger.warning( + "Azure SDK not installed. Install azure-mgmt-containerinstance " + "and azure-identity to use ACI trigger." + ) + return None + + subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "") + resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg") + image = os.environ.get( + "AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest" + ) + gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100") + location = os.environ.get("AZURE_LOCATION", "eastus") + + if not subscription_id: + logger.error("AZURE_SUBSCRIPTION_ID not set. 
Cannot start ACI.") + return None + + credential = DefaultAzureCredential() + client = ContainerInstanceManagementClient(credential, subscription_id) + + container_name = f"training-{task_id[:8]}" + + env_vars = [ + EnvironmentVariable(name="TASK_ID", value=task_id), + ] + + # Pass DB connection securely + for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"): + val = os.environ.get(var, "") + if val: + env_vars.append(EnvironmentVariable(name=var, value=val)) + + db_password = os.environ.get("DB_PASSWORD", "") + if db_password: + env_vars.append( + EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password) + ) + + container = Container( + name=container_name, + image=image, + resources=ResourceRequirements( + requests=ResourceRequests( + cpu=4, + memory_in_gb=16, + gpu=GpuResource(count=1, sku=gpu_sku), + ) + ), + environment_variables=env_vars, + command=[ + "python", + "run_training.py", + "--task-id", + task_id, + ], + ) + + group = ContainerGroup( + location=location, + containers=[container], + os_type="Linux", + restart_policy="Never", + ) + + logger.info("Creating ACI container group: %s", container_name) + client.container_groups.begin_create_or_update( + resource_group, container_name, group + ) + + return container_name diff --git a/src/web/api/v1/batch/__init__.py b/packages/inference/inference/cli/__init__.py similarity index 100% rename from src/web/api/v1/batch/__init__.py rename to packages/inference/inference/cli/__init__.py diff --git a/src/cli/infer.py b/packages/inference/inference/cli/infer.py similarity index 96% rename from src/cli/infer.py rename to packages/inference/inference/cli/infer.py index c4ec682..834123c 100644 --- a/src/cli/infer.py +++ b/packages/inference/inference/cli/infer.py @@ -10,8 +10,7 @@ import json import sys from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import DEFAULT_DPI +from shared.config import DEFAULT_DPI def main(): @@ -91,7 +90,7 @@ def main(): print(f"Processing {len(pdf_files)} PDF file(s)") print(f"Model: {model_path}") - from ..inference import InferencePipeline + from inference.pipeline import InferencePipeline # Initialize pipeline pipeline = InferencePipeline( diff --git a/src/cli/serve.py b/packages/inference/inference/cli/serve.py similarity index 94% rename from src/cli/serve.py rename to packages/inference/inference/cli/serve.py index d87fff2..cb71ce4 100644 --- a/src/cli/serve.py +++ b/packages/inference/inference/cli/serve.py @@ -13,9 +13,8 @@ from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(project_root)) -from src.config import DEFAULT_DPI +from shared.config import DEFAULT_DPI def setup_logging(debug: bool = False) -> None: @@ -121,7 +120,7 @@ def main() -> None: logger.info("=" * 60) # Create config - from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig + from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig config = AppConfig( model=ModelConfig( @@ -142,7 +141,7 @@ def main() -> None: # Create and run app import uvicorn - from src.web.app import create_app + from inference.web.app import create_app app = create_app(config) diff --git a/packages/inference/inference/data/__init__.py b/packages/inference/inference/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/admin_db.py b/packages/inference/inference/data/admin_db.py similarity index 87% rename from src/data/admin_db.py rename to 
packages/inference/inference/data/admin_db.py index f55a36f..02f9d8c 100644 --- a/src/data/admin_db.py +++ b/packages/inference/inference/data/admin_db.py @@ -12,8 +12,8 @@ from uuid import UUID from sqlalchemy import func from sqlmodel import select -from src.data.database import get_session_context -from src.data.admin_models import ( +from inference.data.database import get_session_context +from inference.data.admin_models import ( AdminToken, AdminDocument, AdminAnnotation, @@ -23,6 +23,8 @@ from src.data.admin_models import ( BatchUploadFile, TrainingDocumentLink, AnnotationHistory, + TrainingDataset, + DatasetDocument, ) logger = logging.getLogger(__name__) @@ -174,7 +176,7 @@ class AdminDB: # For has_annotations filter, we need to join with annotations if has_annotations is not None: - from src.data.admin_models import AdminAnnotation + from inference.data.admin_models import AdminAnnotation if has_annotations: # Documents WITH annotations @@ -200,7 +202,7 @@ class AdminDB: # Apply has_annotations filter if has_annotations is not None: - from src.data.admin_models import AdminAnnotation + from inference.data.admin_models import AdminAnnotation if has_annotations: statement = ( @@ -456,6 +458,7 @@ class AdminDB: scheduled_at: datetime | None = None, cron_expression: str | None = None, is_recurring: bool = False, + dataset_id: str | None = None, ) -> str: """Create a new training task.""" with get_session_context() as session: @@ -469,6 +472,7 @@ class AdminDB: cron_expression=cron_expression, is_recurring=is_recurring, status="scheduled" if scheduled_at else "pending", + dataset_id=dataset_id, ) session.add(task) session.flush() @@ -1154,3 +1158,159 @@ class AdminDB: session.refresh(annotation) session.expunge(annotation) return annotation + + # ========================================================================== + # Training Dataset Operations + # ========================================================================== + + def create_dataset( + self, + name: str, + description: str | None = None, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + seed: int = 42, + ) -> TrainingDataset: + """Create a new training dataset.""" + with get_session_context() as session: + dataset = TrainingDataset( + name=name, + description=description, + train_ratio=train_ratio, + val_ratio=val_ratio, + seed=seed, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + session.expunge(dataset) + return dataset + + def get_dataset(self, dataset_id: str | UUID) -> TrainingDataset | None: + """Get a dataset by ID.""" + with get_session_context() as session: + dataset = session.get(TrainingDataset, UUID(str(dataset_id))) + if dataset: + session.expunge(dataset) + return dataset + + def get_datasets( + self, + status: str | None = None, + limit: int = 20, + offset: int = 0, + ) -> tuple[list[TrainingDataset], int]: + """List datasets with optional status filter.""" + with get_session_context() as session: + query = select(TrainingDataset) + count_query = select(func.count()).select_from(TrainingDataset) + if status: + query = query.where(TrainingDataset.status == status) + count_query = count_query.where(TrainingDataset.status == status) + total = session.exec(count_query).one() + datasets = session.exec( + query.order_by(TrainingDataset.created_at.desc()).offset(offset).limit(limit) + ).all() + for d in datasets: + session.expunge(d) + return list(datasets), total + + def update_dataset_status( + self, + dataset_id: str | UUID, + status: str, + error_message: str | None = 
None, + total_documents: int | None = None, + total_images: int | None = None, + total_annotations: int | None = None, + dataset_path: str | None = None, + ) -> None: + """Update dataset status and optional totals.""" + with get_session_context() as session: + dataset = session.get(TrainingDataset, UUID(str(dataset_id))) + if not dataset: + return + dataset.status = status + dataset.updated_at = datetime.utcnow() + if error_message is not None: + dataset.error_message = error_message + if total_documents is not None: + dataset.total_documents = total_documents + if total_images is not None: + dataset.total_images = total_images + if total_annotations is not None: + dataset.total_annotations = total_annotations + if dataset_path is not None: + dataset.dataset_path = dataset_path + session.add(dataset) + session.commit() + + def add_dataset_documents( + self, + dataset_id: str | UUID, + documents: list[dict[str, Any]], + ) -> None: + """Batch insert documents into a dataset. + + Each dict: {document_id, split, page_count, annotation_count} + """ + with get_session_context() as session: + for doc in documents: + dd = DatasetDocument( + dataset_id=UUID(str(dataset_id)), + document_id=UUID(str(doc["document_id"])), + split=doc["split"], + page_count=doc.get("page_count", 0), + annotation_count=doc.get("annotation_count", 0), + ) + session.add(dd) + session.commit() + + def get_dataset_documents( + self, dataset_id: str | UUID + ) -> list[DatasetDocument]: + """Get all documents in a dataset.""" + with get_session_context() as session: + results = session.exec( + select(DatasetDocument) + .where(DatasetDocument.dataset_id == UUID(str(dataset_id))) + ).all() + for r in results: + session.expunge(r) + return list(results) + + def get_documents_by_ids( + self, document_ids: list[str] + ) -> list[AdminDocument]: + """Get documents by list of IDs.""" + with get_session_context() as session: + uuids = [UUID(str(did)) for did in document_ids] + results = session.exec( + select(AdminDocument).where(AdminDocument.document_id.in_(uuids)) + ).all() + for r in results: + session.expunge(r) + return list(results) + + def get_annotations_for_document( + self, document_id: str | UUID + ) -> list[AdminAnnotation]: + """Get all annotations for a document.""" + with get_session_context() as session: + results = session.exec( + select(AdminAnnotation) + .where(AdminAnnotation.document_id == UUID(str(document_id))) + ).all() + for r in results: + session.expunge(r) + return list(results) + + def delete_dataset(self, dataset_id: str | UUID) -> bool: + """Delete a dataset and its document links (CASCADE).""" + with get_session_context() as session: + dataset = session.get(TrainingDataset, UUID(str(dataset_id))) + if not dataset: + return False + session.delete(dataset) + session.commit() + return True diff --git a/src/data/admin_models.py b/packages/inference/inference/data/admin_models.py similarity index 84% rename from src/data/admin_models.py rename to packages/inference/inference/data/admin_models.py index 748bfd4..a374d07 100644 --- a/src/data/admin_models.py +++ b/packages/inference/inference/data/admin_models.py @@ -131,6 +131,7 @@ class TrainingTask(SQLModel, table=True): # Status: pending, scheduled, running, completed, failed, cancelled task_type: str = Field(default="train", max_length=20) # Task type: train, finetune + dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True) # Training configuration config: dict[str, Any] | None = Field(default=None, 
sa_column=Column(JSON)) # Schedule settings @@ -225,6 +226,42 @@ class BatchUploadFile(SQLModel, table=True): # ============================================================================= +class TrainingDataset(SQLModel, table=True): + """Training dataset containing selected documents with train/val/test splits.""" + + __tablename__ = "training_datasets" + + dataset_id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str = Field(max_length=255) + description: str | None = Field(default=None) + status: str = Field(default="building", max_length=20, index=True) + # Status: building, ready, training, archived, failed + train_ratio: float = Field(default=0.8) + val_ratio: float = Field(default=0.1) + seed: int = Field(default=42) + total_documents: int = Field(default=0) + total_images: int = Field(default=0) + total_annotations: int = Field(default=0) + dataset_path: str | None = Field(default=None, max_length=512) + error_message: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + +class DatasetDocument(SQLModel, table=True): + """Junction table linking datasets to documents with split assignment.""" + + __tablename__ = "dataset_documents" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + dataset_id: UUID = Field(foreign_key="training_datasets.dataset_id", index=True) + document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True) + split: str = Field(max_length=10) # train, val, test + page_count: int = Field(default=0) + annotation_count: int = Field(default=0) + created_at: datetime = Field(default_factory=datetime.utcnow) + + class TrainingDocumentLink(SQLModel, table=True): """Junction table linking training tasks to documents.""" @@ -336,4 +373,35 @@ class TrainingTaskRead(SQLModel): error_message: str | None result_metrics: dict[str, Any] | None model_path: str | None + dataset_id: UUID | None created_at: datetime + + +class TrainingDatasetRead(SQLModel): + """Training dataset response model.""" + + dataset_id: UUID + name: str + description: str | None + status: str + train_ratio: float + val_ratio: float + seed: int + total_documents: int + total_images: int + total_annotations: int + dataset_path: str | None + error_message: str | None + created_at: datetime + updated_at: datetime + + +class DatasetDocumentRead(SQLModel): + """Dataset document response model.""" + + id: UUID + dataset_id: UUID + document_id: UUID + split: str + page_count: int + annotation_count: int diff --git a/src/data/async_request_db.py b/packages/inference/inference/data/async_request_db.py similarity index 98% rename from src/data/async_request_db.py rename to packages/inference/inference/data/async_request_db.py index d3853f5..9a24a2f 100644 --- a/src/data/async_request_db.py +++ b/packages/inference/inference/data/async_request_db.py @@ -12,8 +12,8 @@ from uuid import UUID from sqlalchemy import func, text from sqlmodel import Session, select -from src.data.database import get_session_context, create_db_and_tables, close_engine -from src.data.models import ApiKey, AsyncRequest, RateLimitEvent +from inference.data.database import get_session_context, create_db_and_tables, close_engine +from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent logger = logging.getLogger(__name__) diff --git a/src/data/database.py b/packages/inference/inference/data/database.py similarity index 91% rename from src/data/database.py rename to 
packages/inference/inference/data/database.py index d356653..7613b6f 100644 --- a/src/data/database.py +++ b/packages/inference/inference/data/database.py @@ -13,8 +13,7 @@ from sqlalchemy import text from sqlmodel import Session, SQLModel, create_engine import sys -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string +from shared.config import get_db_connection_string logger = logging.getLogger(__name__) @@ -52,8 +51,8 @@ def get_engine(): def create_db_and_tables() -> None: """Create all database tables.""" - from src.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401 - from src.data.admin_models import ( # noqa: F401 + from inference.data.models import ApiKey, AsyncRequest, RateLimitEvent # noqa: F401 + from inference.data.admin_models import ( # noqa: F401 AdminToken, AdminDocument, AdminAnnotation, diff --git a/src/data/models.py b/packages/inference/inference/data/models.py similarity index 100% rename from src/data/models.py rename to packages/inference/inference/data/models.py diff --git a/src/inference/__init__.py b/packages/inference/inference/pipeline/__init__.py similarity index 100% rename from src/inference/__init__.py rename to packages/inference/inference/pipeline/__init__.py diff --git a/src/inference/constants.py b/packages/inference/inference/pipeline/constants.py similarity index 98% rename from src/inference/constants.py rename to packages/inference/inference/pipeline/constants.py index ef8a14c..a462975 100644 --- a/src/inference/constants.py +++ b/packages/inference/inference/pipeline/constants.py @@ -92,7 +92,7 @@ constructors or methods. The values here serve as sensible defaults based on Swedish invoice processing requirements. Example: - from src.inference.constants import DEFAULT_CONFIDENCE_THRESHOLD + from inference.pipeline.constants import DEFAULT_CONFIDENCE_THRESHOLD detector = YOLODetector( model_path="model.pt", diff --git a/src/inference/customer_number_parser.py b/packages/inference/inference/pipeline/customer_number_parser.py similarity index 99% rename from src/inference/customer_number_parser.py rename to packages/inference/inference/pipeline/customer_number_parser.py index 39f2256..3492700 100644 --- a/src/inference/customer_number_parser.py +++ b/packages/inference/inference/pipeline/customer_number_parser.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, List -from src.exceptions import CustomerNumberParseError +from shared.exceptions import CustomerNumberParseError @dataclass diff --git a/src/inference/field_extractor.py b/packages/inference/inference/pipeline/field_extractor.py similarity index 99% rename from src/inference/field_extractor.py rename to packages/inference/inference/pipeline/field_extractor.py index c6e4938..4846f1c 100644 --- a/src/inference/field_extractor.py +++ b/packages/inference/inference/pipeline/field_extractor.py @@ -4,7 +4,7 @@ Field Extractor Module Extracts and validates field values from detected regions. This module is used during inference to extract values from OCR text. -It uses shared utilities from src.utils for text cleaning and validation. +It uses shared utilities from shared.utils for text cleaning and validation. 
Enhanced features: - Multi-source fusion with confidence weighting @@ -24,10 +24,10 @@ from PIL import Image from .yolo_detector import Detection, CLASS_TO_FIELD # Import shared utilities for text cleaning and validation -from src.utils.text_cleaner import TextCleaner -from src.utils.validators import FieldValidators -from src.utils.fuzzy_matcher import FuzzyMatcher -from src.utils.ocr_corrections import OCRCorrections +from shared.utils.text_cleaner import TextCleaner +from shared.utils.validators import FieldValidators +from shared.utils.fuzzy_matcher import FuzzyMatcher +from shared.utils.ocr_corrections import OCRCorrections # Import new unified parsers from .payment_line_parser import PaymentLineParser @@ -104,7 +104,7 @@ class FieldExtractor: def ocr_engine(self): """Lazy-load OCR engine only when needed.""" if self._ocr_engine is None: - from ..ocr import OCREngine + from shared.ocr import OCREngine self._ocr_engine = OCREngine(lang=self.ocr_lang) return self._ocr_engine diff --git a/src/inference/payment_line_parser.py b/packages/inference/inference/pipeline/payment_line_parser.py similarity index 99% rename from src/inference/payment_line_parser.py rename to packages/inference/inference/pipeline/payment_line_parser.py index e294652..86c9068 100644 --- a/src/inference/payment_line_parser.py +++ b/packages/inference/inference/pipeline/payment_line_parser.py @@ -21,7 +21,7 @@ import logging from dataclasses import dataclass from typing import Optional -from src.exceptions import PaymentLineParseError +from shared.exceptions import PaymentLineParseError @dataclass diff --git a/src/inference/pipeline.py b/packages/inference/inference/pipeline/pipeline.py similarity index 99% rename from src/inference/pipeline.py rename to packages/inference/inference/pipeline/pipeline.py index c865402..c9ade47 100644 --- a/src/inference/pipeline.py +++ b/packages/inference/inference/pipeline/pipeline.py @@ -144,7 +144,7 @@ class InferencePipeline: Returns: InferenceResult with extracted fields """ - from ..pdf.renderer import render_pdf_to_images + from shared.pdf.renderer import render_pdf_to_images from PIL import Image import io import numpy as np @@ -381,8 +381,8 @@ class InferencePipeline: def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None: """Run full-page OCR fallback.""" - from ..pdf.renderer import render_pdf_to_images - from ..ocr import OCREngine + from shared.pdf.renderer import render_pdf_to_images + from shared.ocr import OCREngine from PIL import Image import io import numpy as np diff --git a/src/inference/yolo_detector.py b/packages/inference/inference/pipeline/yolo_detector.py similarity index 98% rename from src/inference/yolo_detector.py rename to packages/inference/inference/pipeline/yolo_detector.py index 4732aaf..395dc7c 100644 --- a/src/inference/yolo_detector.py +++ b/packages/inference/inference/pipeline/yolo_detector.py @@ -189,7 +189,7 @@ class YOLODetector: Returns: Dict mapping page number to list of detections """ - from ..pdf.renderer import render_pdf_to_images + from shared.pdf.renderer import render_pdf_to_images from PIL import Image import io diff --git a/src/validation/__init__.py b/packages/inference/inference/validation/__init__.py similarity index 100% rename from src/validation/__init__.py rename to packages/inference/inference/validation/__init__.py diff --git a/src/validation/llm_validator.py b/packages/inference/inference/validation/llm_validator.py similarity index 99% rename from src/validation/llm_validator.py rename to 
packages/inference/inference/validation/llm_validator.py index 66a60b3..1ee3a92 100644 --- a/src/validation/llm_validator.py +++ b/packages/inference/inference/validation/llm_validator.py @@ -16,7 +16,7 @@ from datetime import datetime import psycopg2 from psycopg2.extras import execute_values -from src.config import DEFAULT_DPI +from shared.config import DEFAULT_DPI @dataclass diff --git a/src/web/__init__.py b/packages/inference/inference/web/__init__.py similarity index 100% rename from src/web/__init__.py rename to packages/inference/inference/web/__init__.py diff --git a/packages/inference/inference/web/admin_routes_new.py b/packages/inference/inference/web/admin_routes_new.py new file mode 100644 index 0000000..296827a --- /dev/null +++ b/packages/inference/inference/web/admin_routes_new.py @@ -0,0 +1,8 @@ +""" +Backward compatibility shim for admin_routes.py + +DEPRECATED: Import from inference.web.api.v1.admin.documents instead. +""" +from inference.web.api.v1.admin.documents import * + +__all__ = ["create_admin_router"] diff --git a/packages/inference/inference/web/api/__init__.py b/packages/inference/inference/web/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/inference/inference/web/api/v1/__init__.py b/packages/inference/inference/web/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/inference/inference/web/api/v1/admin/__init__.py b/packages/inference/inference/web/api/v1/admin/__init__.py new file mode 100644 index 0000000..8d5081b --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/__init__.py @@ -0,0 +1,19 @@ +""" +Admin API v1 + +Document management, annotations, and training endpoints. +""" + +from inference.web.api.v1.admin.annotations import create_annotation_router +from inference.web.api.v1.admin.auth import create_auth_router +from inference.web.api.v1.admin.documents import create_documents_router +from inference.web.api.v1.admin.locks import create_locks_router +from inference.web.api.v1.admin.training import create_training_router + +__all__ = [ + "create_annotation_router", + "create_auth_router", + "create_documents_router", + "create_locks_router", + "create_training_router", +] diff --git a/src/web/api/v1/admin/annotations.py b/packages/inference/inference/web/api/v1/admin/annotations.py similarity index 98% rename from src/web/api/v1/admin/annotations.py rename to packages/inference/inference/web/api/v1/admin/annotations.py index b67cb09..609db93 100644 --- a/src/web/api/v1/admin/annotations.py +++ b/packages/inference/inference/web/api/v1/admin/annotations.py @@ -12,11 +12,11 @@ from uuid import UUID from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse -from src.data.admin_db import AdminDB -from src.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS -from src.web.core.auth import AdminTokenDep, AdminDBDep -from src.web.services.autolabel import get_auto_label_service -from src.web.schemas.admin import ( +from inference.data.admin_db import AdminDB +from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.services.autolabel import get_auto_label_service +from inference.web.schemas.admin import ( AnnotationCreate, AnnotationItem, AnnotationListResponse, @@ -31,7 +31,7 @@ from src.web.schemas.admin import ( AutoLabelResponse, BoundingBox, ) -from src.web.schemas.common import ErrorResponse +from inference.web.schemas.common import 
ErrorResponse logger = logging.getLogger(__name__) diff --git a/src/web/api/v1/admin/auth.py b/packages/inference/inference/web/api/v1/admin/auth.py similarity index 92% rename from src/web/api/v1/admin/auth.py rename to packages/inference/inference/web/api/v1/admin/auth.py index daee30f..913be49 100644 --- a/src/web/api/v1/admin/auth.py +++ b/packages/inference/inference/web/api/v1/admin/auth.py @@ -10,12 +10,12 @@ from datetime import datetime, timedelta from fastapi import APIRouter -from src.web.core.auth import AdminTokenDep, AdminDBDep -from src.web.schemas.admin import ( +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( AdminTokenCreate, AdminTokenResponse, ) -from src.web.schemas.common import ErrorResponse +from inference.web.schemas.common import ErrorResponse logger = logging.getLogger(__name__) diff --git a/src/web/api/v1/admin/documents.py b/packages/inference/inference/web/api/v1/admin/documents.py similarity index 97% rename from src/web/api/v1/admin/documents.py rename to packages/inference/inference/web/api/v1/admin/documents.py index 3d48b2a..fd2f355 100644 --- a/src/web/api/v1/admin/documents.py +++ b/packages/inference/inference/web/api/v1/admin/documents.py @@ -11,9 +11,9 @@ from uuid import UUID from fastapi import APIRouter, File, HTTPException, Query, UploadFile -from src.web.config import DEFAULT_DPI, StorageConfig -from src.web.core.auth import AdminTokenDep, AdminDBDep -from src.web.schemas.admin import ( +from inference.web.config import DEFAULT_DPI, StorageConfig +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( AnnotationItem, AnnotationSource, AutoLabelStatus, @@ -27,7 +27,7 @@ from src.web.schemas.admin import ( ModelMetrics, TrainingHistoryItem, ) -from src.web.schemas.common import ErrorResponse +from inference.web.schemas.common import ErrorResponse logger = logging.getLogger(__name__) @@ -142,8 +142,8 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: raise HTTPException(status_code=500, detail="Failed to save file") # Update file path in database - from src.data.database import get_session_context - from src.data.admin_models import AdminDocument + from inference.data.database import get_session_context + from inference.data.admin_models import AdminDocument with get_session_context() as session: doc = session.get(AdminDocument, UUID(document_id)) if doc: @@ -520,7 +520,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: # If marking as labeled, save annotations to PostgreSQL DocumentDB db_save_result = None if status == "labeled": - from src.web.services.db_autolabel import save_manual_annotations_to_document_db + from inference.web.services.db_autolabel import save_manual_annotations_to_document_db # Get all annotations for this document annotations = db.get_annotations_for_document(document_id) diff --git a/src/web/api/v1/admin/locks.py b/packages/inference/inference/web/api/v1/admin/locks.py similarity index 97% rename from src/web/api/v1/admin/locks.py rename to packages/inference/inference/web/api/v1/admin/locks.py index 1b5f46e..7e23393 100644 --- a/src/web/api/v1/admin/locks.py +++ b/packages/inference/inference/web/api/v1/admin/locks.py @@ -10,12 +10,12 @@ from uuid import UUID from fastapi import APIRouter, HTTPException, Query -from src.web.core.auth import AdminTokenDep, AdminDBDep -from src.web.schemas.admin import ( +from inference.web.core.auth import AdminTokenDep, AdminDBDep 
+from inference.web.schemas.admin import (
     AnnotationLockRequest,
     AnnotationLockResponse,
 )
-from src.web.schemas.common import ErrorResponse
+from inference.web.schemas.common import ErrorResponse
 
 logger = logging.getLogger(__name__)
diff --git a/packages/inference/inference/web/api/v1/admin/training/__init__.py b/packages/inference/inference/web/api/v1/admin/training/__init__.py
new file mode 100644
index 0000000..d2fba3c
--- /dev/null
+++ b/packages/inference/inference/web/api/v1/admin/training/__init__.py
@@ -0,0 +1,28 @@
+"""
+Admin Training API Routes
+
+FastAPI endpoints for training task management and scheduling.
+"""
+
+from fastapi import APIRouter
+
+from ._utils import _validate_uuid
+from .tasks import register_task_routes
+from .documents import register_document_routes
+from .export import register_export_routes
+from .datasets import register_dataset_routes
+
+
+def create_training_router() -> APIRouter:
+    """Create training API router."""
+    router = APIRouter(prefix="/admin/training", tags=["Admin Training"])
+
+    register_task_routes(router)
+    register_document_routes(router)
+    register_export_routes(router)
+    register_dataset_routes(router)
+
+    return router
+
+
+__all__ = ["create_training_router", "_validate_uuid"]
diff --git a/packages/inference/inference/web/api/v1/admin/training/_utils.py b/packages/inference/inference/web/api/v1/admin/training/_utils.py
new file mode 100644
index 0000000..9cb70a8
--- /dev/null
+++ b/packages/inference/inference/web/api/v1/admin/training/_utils.py
@@ -0,0 +1,16 @@
+"""Shared utilities for training routes."""
+
+from uuid import UUID
+
+from fastapi import HTTPException
+
+
+def _validate_uuid(value: str, name: str = "ID") -> None:
+    """Validate UUID format."""
+    try:
+        UUID(value)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid {name} format. 
Must be a valid UUID.", + ) diff --git a/packages/inference/inference/web/api/v1/admin/training/datasets.py b/packages/inference/inference/web/api/v1/admin/training/datasets.py new file mode 100644 index 0000000..a46c4b3 --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/training/datasets.py @@ -0,0 +1,209 @@ +"""Training Dataset Endpoints.""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, HTTPException, Query + +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( + DatasetCreateRequest, + DatasetDetailResponse, + DatasetDocumentItem, + DatasetListItem, + DatasetListResponse, + DatasetResponse, + DatasetTrainRequest, + TrainingStatus, + TrainingTaskResponse, +) + +from ._utils import _validate_uuid + +logger = logging.getLogger(__name__) + + +def register_dataset_routes(router: APIRouter) -> None: + """Register dataset endpoints on the router.""" + + @router.post( + "/datasets", + response_model=DatasetResponse, + summary="Create training dataset", + description="Create a dataset from selected documents with train/val/test splits.", + ) + async def create_dataset( + request: DatasetCreateRequest, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> DatasetResponse: + """Create a training dataset from document IDs.""" + from pathlib import Path + from inference.web.services.dataset_builder import DatasetBuilder + + dataset = db.create_dataset( + name=request.name, + description=request.description, + train_ratio=request.train_ratio, + val_ratio=request.val_ratio, + seed=request.seed, + ) + + builder = DatasetBuilder(db=db, base_dir=Path("data/datasets")) + try: + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=request.document_ids, + train_ratio=request.train_ratio, + val_ratio=request.val_ratio, + seed=request.seed, + admin_images_dir=Path("data/admin_images"), + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return DatasetResponse( + dataset_id=str(dataset.dataset_id), + name=dataset.name, + status="ready", + message="Dataset created successfully", + ) + + @router.get( + "/datasets", + response_model=DatasetListResponse, + summary="List datasets", + ) + async def list_datasets( + admin_token: AdminTokenDep, + db: AdminDBDep, + status: Annotated[str | None, Query(description="Filter by status")] = None, + limit: Annotated[int, Query(ge=1, le=100)] = 20, + offset: Annotated[int, Query(ge=0)] = 0, + ) -> DatasetListResponse: + """List training datasets.""" + datasets, total = db.get_datasets(status=status, limit=limit, offset=offset) + return DatasetListResponse( + total=total, + limit=limit, + offset=offset, + datasets=[ + DatasetListItem( + dataset_id=str(d.dataset_id), + name=d.name, + description=d.description, + status=d.status, + total_documents=d.total_documents, + total_images=d.total_images, + total_annotations=d.total_annotations, + created_at=d.created_at, + ) + for d in datasets + ], + ) + + @router.get( + "/datasets/{dataset_id}", + response_model=DatasetDetailResponse, + summary="Get dataset detail", + ) + async def get_dataset( + dataset_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> DatasetDetailResponse: + """Get dataset details with document list.""" + _validate_uuid(dataset_id, "dataset_id") + dataset = db.get_dataset(dataset_id) + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + docs = db.get_dataset_documents(dataset_id) + return 
DatasetDetailResponse( + dataset_id=str(dataset.dataset_id), + name=dataset.name, + description=dataset.description, + status=dataset.status, + train_ratio=dataset.train_ratio, + val_ratio=dataset.val_ratio, + seed=dataset.seed, + total_documents=dataset.total_documents, + total_images=dataset.total_images, + total_annotations=dataset.total_annotations, + dataset_path=dataset.dataset_path, + error_message=dataset.error_message, + documents=[ + DatasetDocumentItem( + document_id=str(d.document_id), + split=d.split, + page_count=d.page_count, + annotation_count=d.annotation_count, + ) + for d in docs + ], + created_at=dataset.created_at, + updated_at=dataset.updated_at, + ) + + @router.delete( + "/datasets/{dataset_id}", + summary="Delete dataset", + ) + async def delete_dataset( + dataset_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> dict: + """Delete a dataset and its files.""" + import shutil + from pathlib import Path + + _validate_uuid(dataset_id, "dataset_id") + dataset = db.get_dataset(dataset_id) + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + if dataset.dataset_path: + dataset_dir = Path(dataset.dataset_path) + if dataset_dir.exists(): + shutil.rmtree(dataset_dir) + + db.delete_dataset(dataset_id) + return {"message": "Dataset deleted"} + + @router.post( + "/datasets/{dataset_id}/train", + response_model=TrainingTaskResponse, + summary="Start training from dataset", + ) + async def train_from_dataset( + dataset_id: str, + request: DatasetTrainRequest, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> TrainingTaskResponse: + """Create a training task from a dataset.""" + _validate_uuid(dataset_id, "dataset_id") + dataset = db.get_dataset(dataset_id) + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + if dataset.status != "ready": + raise HTTPException( + status_code=400, + detail=f"Dataset is not ready (status: {dataset.status})", + ) + + config_dict = request.config.model_dump() + task_id = db.create_training_task( + admin_token=admin_token, + name=request.name, + task_type="train", + config=config_dict, + dataset_id=str(dataset.dataset_id), + ) + + return TrainingTaskResponse( + task_id=task_id, + status=TrainingStatus.PENDING, + message="Training task created from dataset", + ) diff --git a/packages/inference/inference/web/api/v1/admin/training/documents.py b/packages/inference/inference/web/api/v1/admin/training/documents.py new file mode 100644 index 0000000..27e935a --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/training/documents.py @@ -0,0 +1,211 @@ +"""Training Documents and Models Endpoints.""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, HTTPException, Query + +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( + ModelMetrics, + TrainingDocumentItem, + TrainingDocumentsResponse, + TrainingModelItem, + TrainingModelsResponse, + TrainingStatus, +) +from inference.web.schemas.common import ErrorResponse + +from ._utils import _validate_uuid + +logger = logging.getLogger(__name__) + + +def register_document_routes(router: APIRouter) -> None: + """Register training document and model endpoints on the router.""" + + @router.get( + "/documents", + response_model=TrainingDocumentsResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="Get documents for training", + description="Get labeled documents available for training with 
filtering options.", + ) + async def get_training_documents( + admin_token: AdminTokenDep, + db: AdminDBDep, + has_annotations: Annotated[ + bool, + Query(description="Only include documents with annotations"), + ] = True, + min_annotation_count: Annotated[ + int | None, + Query(ge=1, description="Minimum annotation count"), + ] = None, + exclude_used_in_training: Annotated[ + bool, + Query(description="Exclude documents already used in training"), + ] = False, + limit: Annotated[ + int, + Query(ge=1, le=100, description="Page size"), + ] = 100, + offset: Annotated[ + int, + Query(ge=0, description="Offset"), + ] = 0, + ) -> TrainingDocumentsResponse: + """Get documents available for training.""" + documents, total = db.get_documents_for_training( + admin_token=admin_token, + status="labeled", + has_annotations=has_annotations, + min_annotation_count=min_annotation_count, + exclude_used_in_training=exclude_used_in_training, + limit=limit, + offset=offset, + ) + + items = [] + for doc in documents: + annotations = db.get_annotations_for_document(str(doc.document_id)) + + sources = {"manual": 0, "auto": 0} + for ann in annotations: + if ann.source in sources: + sources[ann.source] += 1 + + training_links = db.get_document_training_tasks(doc.document_id) + used_in_training = [str(link.task_id) for link in training_links] + + items.append( + TrainingDocumentItem( + document_id=str(doc.document_id), + filename=doc.filename, + annotation_count=len(annotations), + annotation_sources=sources, + used_in_training=used_in_training, + last_modified=doc.updated_at, + ) + ) + + return TrainingDocumentsResponse( + total=total, + limit=limit, + offset=offset, + documents=items, + ) + + @router.get( + "/models/{task_id}/download", + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + 404: {"model": ErrorResponse, "description": "Model not found"}, + }, + summary="Download trained model", + description="Download trained model weights file.", + ) + async def download_model( + task_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + ): + """Download trained model.""" + from fastapi.responses import FileResponse + from pathlib import Path + + _validate_uuid(task_id, "task_id") + + task = db.get_training_task_by_token(task_id, admin_token) + if task is None: + raise HTTPException( + status_code=404, + detail="Training task not found or does not belong to this token", + ) + + if not task.model_path: + raise HTTPException( + status_code=404, + detail="Model file not available for this task", + ) + + model_path = Path(task.model_path) + if not model_path.exists(): + raise HTTPException( + status_code=404, + detail="Model file not found on disk", + ) + + return FileResponse( + path=str(model_path), + media_type="application/octet-stream", + filename=f"{task.name}_model.pt", + ) + + @router.get( + "/models", + response_model=TrainingModelsResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="Get trained models", + description="Get list of trained models with metrics and download links.", + ) + async def get_training_models( + admin_token: AdminTokenDep, + db: AdminDBDep, + status: Annotated[ + str | None, + Query(description="Filter by status (completed, failed, etc.)"), + ] = None, + limit: Annotated[ + int, + Query(ge=1, le=100, description="Page size"), + ] = 20, + offset: Annotated[ + int, + Query(ge=0, description="Offset"), + ] = 0, + ) -> TrainingModelsResponse: + """Get list of trained models.""" + tasks, total = 
db.get_training_tasks_by_token( + admin_token=admin_token, + status=status if status else "completed", + limit=limit, + offset=offset, + ) + + items = [] + for task in tasks: + metrics = ModelMetrics( + mAP=task.metrics_mAP, + precision=task.metrics_precision, + recall=task.metrics_recall, + ) + + download_url = None + if task.model_path and task.status == "completed": + download_url = f"/api/v1/admin/training/models/{task.task_id}/download" + + items.append( + TrainingModelItem( + task_id=str(task.task_id), + name=task.name, + status=TrainingStatus(task.status), + document_count=task.document_count, + created_at=task.created_at, + completed_at=task.completed_at, + metrics=metrics, + model_path=task.model_path, + download_url=download_url, + ) + ) + + return TrainingModelsResponse( + total=total, + limit=limit, + offset=offset, + models=items, + ) diff --git a/packages/inference/inference/web/api/v1/admin/training/export.py b/packages/inference/inference/web/api/v1/admin/training/export.py new file mode 100644 index 0000000..6ce2cc3 --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/training/export.py @@ -0,0 +1,121 @@ +"""Training Export Endpoints.""" + +import logging +from datetime import datetime + +from fastapi import APIRouter, HTTPException + +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( + ExportRequest, + ExportResponse, +) +from inference.web.schemas.common import ErrorResponse + +logger = logging.getLogger(__name__) + + +def register_export_routes(router: APIRouter) -> None: + """Register export endpoints on the router.""" + + @router.post( + "/export", + response_model=ExportResponse, + responses={ + 400: {"model": ErrorResponse, "description": "Invalid request"}, + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="Export annotations", + description="Export annotations in YOLO format for training.", + ) + async def export_annotations( + request: ExportRequest, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> ExportResponse: + """Export annotations for training.""" + from pathlib import Path + import shutil + + if request.format not in ("yolo", "coco", "voc"): + raise HTTPException( + status_code=400, + detail=f"Unsupported export format: {request.format}", + ) + + documents = db.get_labeled_documents_for_export(admin_token) + + if not documents: + raise HTTPException( + status_code=400, + detail="No labeled documents available for export", + ) + + export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}" + export_dir.mkdir(parents=True, exist_ok=True) + + (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True) + (export_dir / "images" / "val").mkdir(parents=True, exist_ok=True) + (export_dir / "labels" / "train").mkdir(parents=True, exist_ok=True) + (export_dir / "labels" / "val").mkdir(parents=True, exist_ok=True) + + total_docs = len(documents) + train_count = int(total_docs * request.split_ratio) + train_docs = documents[:train_count] + val_docs = documents[train_count:] + + total_images = 0 + total_annotations = 0 + + for split, docs in [("train", train_docs), ("val", val_docs)]: + for doc in docs: + annotations = db.get_annotations_for_document(str(doc.document_id)) + + if not annotations: + continue + + for page_num in range(1, doc.page_count + 1): + page_annotations = [a for a in annotations if a.page_number == page_num] + + if not page_annotations and not request.include_images: + continue + + src_image = 
Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png" + if not src_image.exists(): + continue + + image_name = f"{doc.document_id}_page{page_num}.png" + dst_image = export_dir / "images" / split / image_name + shutil.copy(src_image, dst_image) + total_images += 1 + + label_name = f"{doc.document_id}_page{page_num}.txt" + label_path = export_dir / "labels" / split / label_name + + with open(label_path, "w") as f: + for ann in page_annotations: + line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n" + f.write(line) + total_annotations += 1 + + from inference.data.admin_models import FIELD_CLASSES + + yaml_content = f"""# Auto-generated YOLO dataset config +path: {export_dir.absolute()} +train: images/train +val: images/val + +nc: {len(FIELD_CLASSES)} +names: {list(FIELD_CLASSES.values())} +""" + (export_dir / "data.yaml").write_text(yaml_content) + + return ExportResponse( + status="completed", + export_path=str(export_dir), + total_images=total_images, + total_annotations=total_annotations, + train_count=len(train_docs), + val_count=len(val_docs), + message=f"Exported {total_images} images with {total_annotations} annotations", + ) diff --git a/packages/inference/inference/web/api/v1/admin/training/tasks.py b/packages/inference/inference/web/api/v1/admin/training/tasks.py new file mode 100644 index 0000000..9ed3da2 --- /dev/null +++ b/packages/inference/inference/web/api/v1/admin/training/tasks.py @@ -0,0 +1,263 @@ +"""Training Task Endpoints.""" + +import logging +from typing import Annotated + +from fastapi import APIRouter, HTTPException, Query + +from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.schemas.admin import ( + TrainingLogItem, + TrainingLogsResponse, + TrainingStatus, + TrainingTaskCreate, + TrainingTaskDetailResponse, + TrainingTaskItem, + TrainingTaskListResponse, + TrainingTaskResponse, + TrainingType, +) +from inference.web.schemas.common import ErrorResponse + +from ._utils import _validate_uuid + +logger = logging.getLogger(__name__) + + +def register_task_routes(router: APIRouter) -> None: + """Register training task endpoints on the router.""" + + @router.post( + "/tasks", + response_model=TrainingTaskResponse, + responses={ + 400: {"model": ErrorResponse, "description": "Invalid request"}, + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="Create training task", + description="Create a new training task.", + ) + async def create_training_task( + request: TrainingTaskCreate, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> TrainingTaskResponse: + """Create a new training task.""" + config_dict = request.config.model_dump() if request.config else {} + + task_id = db.create_training_task( + admin_token=admin_token, + name=request.name, + task_type=request.task_type.value, + description=request.description, + config=config_dict, + scheduled_at=request.scheduled_at, + cron_expression=request.cron_expression, + is_recurring=bool(request.cron_expression), + ) + + return TrainingTaskResponse( + task_id=task_id, + status=TrainingStatus.SCHEDULED if request.scheduled_at else TrainingStatus.PENDING, + message="Training task created successfully", + ) + + @router.get( + "/tasks", + response_model=TrainingTaskListResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="List training tasks", + description="List all training tasks.", + ) + async def list_training_tasks( + admin_token: AdminTokenDep, + db: 
AdminDBDep, + status: Annotated[ + str | None, + Query(description="Filter by status"), + ] = None, + limit: Annotated[ + int, + Query(ge=1, le=100, description="Page size"), + ] = 20, + offset: Annotated[ + int, + Query(ge=0, description="Offset"), + ] = 0, + ) -> TrainingTaskListResponse: + """List training tasks.""" + valid_statuses = ("pending", "scheduled", "running", "completed", "failed", "cancelled") + if status and status not in valid_statuses: + raise HTTPException( + status_code=400, + detail=f"Invalid status: {status}. Must be one of: {', '.join(valid_statuses)}", + ) + + tasks, total = db.get_training_tasks_by_token( + admin_token=admin_token, + status=status, + limit=limit, + offset=offset, + ) + + items = [ + TrainingTaskItem( + task_id=str(task.task_id), + name=task.name, + task_type=TrainingType(task.task_type), + status=TrainingStatus(task.status), + scheduled_at=task.scheduled_at, + is_recurring=task.is_recurring, + started_at=task.started_at, + completed_at=task.completed_at, + created_at=task.created_at, + ) + for task in tasks + ] + + return TrainingTaskListResponse( + total=total, + limit=limit, + offset=offset, + tasks=items, + ) + + @router.get( + "/tasks/{task_id}", + response_model=TrainingTaskDetailResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + 404: {"model": ErrorResponse, "description": "Task not found"}, + }, + summary="Get training task detail", + description="Get training task details.", + ) + async def get_training_task( + task_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> TrainingTaskDetailResponse: + """Get training task details.""" + _validate_uuid(task_id, "task_id") + + task = db.get_training_task_by_token(task_id, admin_token) + if task is None: + raise HTTPException( + status_code=404, + detail="Training task not found or does not belong to this token", + ) + + return TrainingTaskDetailResponse( + task_id=str(task.task_id), + name=task.name, + description=task.description, + task_type=TrainingType(task.task_type), + status=TrainingStatus(task.status), + config=task.config, + scheduled_at=task.scheduled_at, + cron_expression=task.cron_expression, + is_recurring=task.is_recurring, + started_at=task.started_at, + completed_at=task.completed_at, + error_message=task.error_message, + result_metrics=task.result_metrics, + model_path=task.model_path, + created_at=task.created_at, + ) + + @router.post( + "/tasks/{task_id}/cancel", + response_model=TrainingTaskResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + 404: {"model": ErrorResponse, "description": "Task not found"}, + 409: {"model": ErrorResponse, "description": "Cannot cancel task"}, + }, + summary="Cancel training task", + description="Cancel a pending or scheduled training task.", + ) + async def cancel_training_task( + task_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> TrainingTaskResponse: + """Cancel a training task.""" + _validate_uuid(task_id, "task_id") + + task = db.get_training_task_by_token(task_id, admin_token) + if task is None: + raise HTTPException( + status_code=404, + detail="Training task not found or does not belong to this token", + ) + + if task.status not in ("pending", "scheduled"): + raise HTTPException( + status_code=409, + detail=f"Cannot cancel task with status: {task.status}", + ) + + success = db.cancel_training_task(task_id) + if not success: + raise HTTPException( + status_code=500, + detail="Failed to cancel training task", + ) + + return 
TrainingTaskResponse( + task_id=task_id, + status=TrainingStatus.CANCELLED, + message="Training task cancelled successfully", + ) + + @router.get( + "/tasks/{task_id}/logs", + response_model=TrainingLogsResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + 404: {"model": ErrorResponse, "description": "Task not found"}, + }, + summary="Get training logs", + description="Get training task logs.", + ) + async def get_training_logs( + task_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + limit: Annotated[ + int, + Query(ge=1, le=500, description="Maximum logs to return"), + ] = 100, + offset: Annotated[ + int, + Query(ge=0, description="Offset"), + ] = 0, + ) -> TrainingLogsResponse: + """Get training logs.""" + _validate_uuid(task_id, "task_id") + + task = db.get_training_task_by_token(task_id, admin_token) + if task is None: + raise HTTPException( + status_code=404, + detail="Training task not found or does not belong to this token", + ) + + logs = db.get_training_logs(task_id, limit, offset) + + items = [ + TrainingLogItem( + level=log.level, + message=log.message, + details=log.details, + created_at=log.created_at, + ) + for log in logs + ] + + return TrainingLogsResponse( + task_id=task_id, + logs=items, + ) diff --git a/packages/inference/inference/web/api/v1/batch/__init__.py b/packages/inference/inference/web/api/v1/batch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/web/api/v1/batch/routes.py b/packages/inference/inference/web/api/v1/batch/routes.py similarity index 96% rename from src/web/api/v1/batch/routes.py rename to packages/inference/inference/web/api/v1/batch/routes.py index c97819c..2a29c75 100644 --- a/src/web/api/v1/batch/routes.py +++ b/packages/inference/inference/web/api/v1/batch/routes.py @@ -14,10 +14,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form from fastapi.responses import JSONResponse -from src.data.admin_db import AdminDB -from src.web.core.auth import validate_admin_token, get_admin_db -from src.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE -from src.web.workers.batch_queue import BatchTask, get_batch_queue +from inference.data.admin_db import AdminDB +from inference.web.core.auth import validate_admin_token, get_admin_db +from inference.web.services.batch_upload import BatchUploadService, MAX_COMPRESSED_SIZE, MAX_UNCOMPRESSED_SIZE +from inference.web.workers.batch_queue import BatchTask, get_batch_queue logger = logging.getLogger(__name__) diff --git a/packages/inference/inference/web/api/v1/public/__init__.py b/packages/inference/inference/web/api/v1/public/__init__.py new file mode 100644 index 0000000..15bd0c0 --- /dev/null +++ b/packages/inference/inference/web/api/v1/public/__init__.py @@ -0,0 +1,16 @@ +""" +Public API v1 + +Customer-facing endpoints for inference, async processing, and labeling. 
+""" + +from inference.web.api.v1.public.inference import create_inference_router +from inference.web.api.v1.public.async_api import create_async_router, set_async_service +from inference.web.api.v1.public.labeling import create_labeling_router + +__all__ = [ + "create_inference_router", + "create_async_router", + "set_async_service", + "create_labeling_router", +] diff --git a/src/web/api/v1/public/async_api.py b/packages/inference/inference/web/api/v1/public/async_api.py similarity index 98% rename from src/web/api/v1/public/async_api.py rename to packages/inference/inference/web/api/v1/public/async_api.py index 6d5e3f2..234a063 100644 --- a/src/web/api/v1/public/async_api.py +++ b/packages/inference/inference/web/api/v1/public/async_api.py @@ -11,13 +11,13 @@ from uuid import UUID from fastapi import APIRouter, File, HTTPException, Query, UploadFile -from src.web.dependencies import ( +from inference.web.dependencies import ( ApiKeyDep, AsyncDBDep, PollRateLimitDep, SubmitRateLimitDep, ) -from src.web.schemas.inference import ( +from inference.web.schemas.inference import ( AsyncRequestItem, AsyncRequestsListResponse, AsyncResultResponse, @@ -27,7 +27,7 @@ from src.web.schemas.inference import ( DetectionResult, InferenceResult, ) -from src.web.schemas.common import ErrorResponse +from inference.web.schemas.common import ErrorResponse def _validate_request_id(request_id: str) -> None: diff --git a/src/web/api/v1/public/inference.py b/packages/inference/inference/web/api/v1/public/inference.py similarity index 96% rename from src/web/api/v1/public/inference.py rename to packages/inference/inference/web/api/v1/public/inference.py index a3d0849..4861b66 100644 --- a/src/web/api/v1/public/inference.py +++ b/packages/inference/inference/web/api/v1/public/inference.py @@ -15,17 +15,17 @@ from typing import TYPE_CHECKING from fastapi import APIRouter, File, HTTPException, UploadFile, status from fastapi.responses import FileResponse -from src.web.schemas.inference import ( +from inference.web.schemas.inference import ( DetectionResult, HealthResponse, InferenceResponse, InferenceResult, ) -from src.web.schemas.common import ErrorResponse +from inference.web.schemas.common import ErrorResponse if TYPE_CHECKING: - from src.web.services import InferenceService - from src.web.config import StorageConfig + from inference.web.services import InferenceService + from inference.web.config import StorageConfig logger = logging.getLogger(__name__) diff --git a/src/web/api/v1/public/labeling.py b/packages/inference/inference/web/api/v1/public/labeling.py similarity index 96% rename from src/web/api/v1/public/labeling.py rename to packages/inference/inference/web/api/v1/public/labeling.py index 75e5125..f029036 100644 --- a/src/web/api/v1/public/labeling.py +++ b/packages/inference/inference/web/api/v1/public/labeling.py @@ -13,13 +13,13 @@ from typing import TYPE_CHECKING from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status -from src.data.admin_db import AdminDB -from src.web.schemas.labeling import PreLabelResponse -from src.web.schemas.common import ErrorResponse +from inference.data.admin_db import AdminDB +from inference.web.schemas.labeling import PreLabelResponse +from inference.web.schemas.common import ErrorResponse if TYPE_CHECKING: - from src.web.services import InferenceService - from src.web.config import StorageConfig + from inference.web.services import InferenceService + from inference.web.config import StorageConfig logger = logging.getLogger(__name__) diff 
--git a/src/web/app.py b/packages/inference/inference/web/app.py similarity index 96% rename from src/web/app.py rename to packages/inference/inference/web/app.py index 2f46fd8..2cfbfb6 100644 --- a/src/web/app.py +++ b/packages/inference/inference/web/app.py @@ -17,10 +17,10 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse from .config import AppConfig, default_config -from src.web.services import InferenceService +from inference.web.services import InferenceService # Public API imports -from src.web.api.v1.public import ( +from inference.web.api.v1.public import ( create_inference_router, create_async_router, set_async_service, @@ -28,28 +28,28 @@ from src.web.api.v1.public import ( ) # Async processing imports -from src.data.async_request_db import AsyncRequestDB -from src.web.workers.async_queue import AsyncTaskQueue -from src.web.services.async_processing import AsyncProcessingService -from src.web.dependencies import init_dependencies -from src.web.core.rate_limiter import RateLimiter +from inference.data.async_request_db import AsyncRequestDB +from inference.web.workers.async_queue import AsyncTaskQueue +from inference.web.services.async_processing import AsyncProcessingService +from inference.web.dependencies import init_dependencies +from inference.web.core.rate_limiter import RateLimiter # Admin API imports -from src.web.api.v1.admin import ( +from inference.web.api.v1.admin import ( create_annotation_router, create_auth_router, create_documents_router, create_locks_router, create_training_router, ) -from src.web.core.scheduler import start_scheduler, stop_scheduler -from src.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler +from inference.web.core.scheduler import start_scheduler, stop_scheduler +from inference.web.core.autolabel_scheduler import start_autolabel_scheduler, stop_autolabel_scheduler # Batch upload imports -from src.web.api.v1.batch.routes import router as batch_upload_router -from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue -from src.web.services.batch_upload import BatchUploadService -from src.data.admin_db import AdminDB +from inference.web.api.v1.batch.routes import router as batch_upload_router +from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue +from inference.web.services.batch_upload import BatchUploadService +from inference.data.admin_db import AdminDB if TYPE_CHECKING: from collections.abc import AsyncGenerator diff --git a/src/web/config.py b/packages/inference/inference/web/config.py similarity index 98% rename from src/web/config.py rename to packages/inference/inference/web/config.py index d3701e6..4eab8e3 100644 --- a/src/web/config.py +++ b/packages/inference/inference/web/config.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any -from src.config import DEFAULT_DPI, PATHS +from shared.config import DEFAULT_DPI, PATHS @dataclass(frozen=True) diff --git a/src/web/core/__init__.py b/packages/inference/inference/web/core/__init__.py similarity index 61% rename from src/web/core/__init__.py rename to packages/inference/inference/web/core/__init__.py index 44c32e1..39cd2d7 100644 --- a/src/web/core/__init__.py +++ b/packages/inference/inference/web/core/__init__.py @@ -4,10 +4,10 @@ Core Components Reusable core functionality: authentication, rate limiting, scheduling. 
""" -from src.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep -from src.web.core.rate_limiter import RateLimiter -from src.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler -from src.web.core.autolabel_scheduler import ( +from inference.web.core.auth import validate_admin_token, get_admin_db, AdminTokenDep, AdminDBDep +from inference.web.core.rate_limiter import RateLimiter +from inference.web.core.scheduler import start_scheduler, stop_scheduler, get_training_scheduler +from inference.web.core.autolabel_scheduler import ( start_autolabel_scheduler, stop_autolabel_scheduler, get_autolabel_scheduler, diff --git a/src/web/core/auth.py b/packages/inference/inference/web/core/auth.py similarity index 93% rename from src/web/core/auth.py rename to packages/inference/inference/web/core/auth.py index 0e23cac..0cc069f 100644 --- a/src/web/core/auth.py +++ b/packages/inference/inference/web/core/auth.py @@ -9,8 +9,8 @@ from typing import Annotated from fastapi import Depends, Header, HTTPException -from src.data.admin_db import AdminDB -from src.data.database import get_session_context +from inference.data.admin_db import AdminDB +from inference.data.database import get_session_context logger = logging.getLogger(__name__) diff --git a/src/web/core/autolabel_scheduler.py b/packages/inference/inference/web/core/autolabel_scheduler.py similarity index 97% rename from src/web/core/autolabel_scheduler.py rename to packages/inference/inference/web/core/autolabel_scheduler.py index a3b3c0d..ded452b 100644 --- a/src/web/core/autolabel_scheduler.py +++ b/packages/inference/inference/web/core/autolabel_scheduler.py @@ -8,8 +8,8 @@ import logging import threading from pathlib import Path -from src.data.admin_db import AdminDB -from src.web.services.db_autolabel import ( +from inference.data.admin_db import AdminDB +from inference.web.services.db_autolabel import ( get_pending_autolabel_documents, process_document_autolabel, ) diff --git a/src/web/rate_limiter.py b/packages/inference/inference/web/core/rate_limiter.py similarity index 99% rename from src/web/rate_limiter.py rename to packages/inference/inference/web/core/rate_limiter.py index 95297a9..eca8574 100644 --- a/src/web/rate_limiter.py +++ b/packages/inference/inference/web/core/rate_limiter.py @@ -13,7 +13,7 @@ from threading import Lock from typing import TYPE_CHECKING if TYPE_CHECKING: - from src.data.async_request_db import AsyncRequestDB + from inference.data.async_request_db import AsyncRequestDB logger = logging.getLogger(__name__) diff --git a/src/web/core/scheduler.py b/packages/inference/inference/web/core/scheduler.py similarity index 88% rename from src/web/core/scheduler.py rename to packages/inference/inference/web/core/scheduler.py index 42814f7..ec36469 100644 --- a/src/web/core/scheduler.py +++ b/packages/inference/inference/web/core/scheduler.py @@ -10,7 +10,7 @@ from datetime import datetime from pathlib import Path from typing import Any -from src.data.admin_db import AdminDB +from inference.data.admin_db import AdminDB logger = logging.getLogger(__name__) @@ -86,7 +86,8 @@ class TrainingScheduler: logger.info(f"Starting training task: {task_id}") try: - self._execute_task(task_id, task.config or {}) + dataset_id = getattr(task, "dataset_id", None) + self._execute_task(task_id, task.config or {}, dataset_id=dataset_id) except Exception as e: logger.error(f"Training task {task_id} failed: {e}") self._db.update_training_task_status( @@ -98,7 +99,9 @@ class 
TrainingScheduler: except Exception as e: logger.error(f"Error checking pending tasks: {e}") - def _execute_task(self, task_id: str, config: dict[str, Any]) -> None: + def _execute_task( + self, task_id: str, config: dict[str, Any], dataset_id: str | None = None + ) -> None: """Execute a training task.""" # Update status to running self._db.update_training_task_status(task_id, "running") @@ -114,17 +117,25 @@ class TrainingScheduler: device = config.get("device", "0") project_name = config.get("project_name", "invoice_fields") - # Export annotations for training - export_result = self._export_training_data(task_id) - if not export_result: - raise ValueError("Failed to export training data") - - data_yaml = export_result["data_yaml"] - - self._db.add_training_log( - task_id, "INFO", - f"Exported {export_result['total_images']} images for training", - ) + # Use dataset if available, otherwise export from scratch + if dataset_id: + dataset = self._db.get_dataset(dataset_id) + if not dataset or not dataset.dataset_path: + raise ValueError(f"Dataset {dataset_id} not found or has no path") + data_yaml = str(Path(dataset.dataset_path) / "data.yaml") + self._db.add_training_log( + task_id, "INFO", + f"Using pre-built dataset: {dataset.name} ({dataset.total_images} images)", + ) + else: + export_result = self._export_training_data(task_id) + if not export_result: + raise ValueError("Failed to export training data") + data_yaml = export_result["data_yaml"] + self._db.add_training_log( + task_id, "INFO", + f"Exported {export_result['total_images']} images for training", + ) # Run YOLO training result = self._run_yolo_training( @@ -157,7 +168,7 @@ class TrainingScheduler: """Export training data for a task.""" from pathlib import Path import shutil - from src.data.admin_models import FIELD_CLASSES + from inference.data.admin_models import FIELD_CLASSES # Get all labeled documents documents = self._db.get_labeled_documents_for_export() diff --git a/src/web/dependencies.py b/packages/inference/inference/web/dependencies.py similarity index 97% rename from src/web/dependencies.py rename to packages/inference/inference/web/dependencies.py index e33755a..bae0644 100644 --- a/src/web/dependencies.py +++ b/packages/inference/inference/web/dependencies.py @@ -9,8 +9,8 @@ from typing import Annotated from fastapi import Depends, Header, HTTPException, Request -from src.data.async_request_db import AsyncRequestDB -from src.web.rate_limiter import RateLimiter +from inference.data.async_request_db import AsyncRequestDB +from inference.web.rate_limiter import RateLimiter logger = logging.getLogger(__name__) diff --git a/src/web/core/rate_limiter.py b/packages/inference/inference/web/rate_limiter.py similarity index 99% rename from src/web/core/rate_limiter.py rename to packages/inference/inference/web/rate_limiter.py index 95297a9..eca8574 100644 --- a/src/web/core/rate_limiter.py +++ b/packages/inference/inference/web/rate_limiter.py @@ -13,7 +13,7 @@ from threading import Lock from typing import TYPE_CHECKING if TYPE_CHECKING: - from src.data.async_request_db import AsyncRequestDB + from inference.data.async_request_db import AsyncRequestDB logger = logging.getLogger(__name__) diff --git a/packages/inference/inference/web/schemas/__init__.py b/packages/inference/inference/web/schemas/__init__.py new file mode 100644 index 0000000..f9b8b64 --- /dev/null +++ b/packages/inference/inference/web/schemas/__init__.py @@ -0,0 +1,11 @@ +""" +API Schemas + +Pydantic models for request/response validation. 
+""" + +# Import everything from sub-modules for backward compatibility +from inference.web.schemas.common import * # noqa: F401, F403 +from inference.web.schemas.admin import * # noqa: F401, F403 +from inference.web.schemas.inference import * # noqa: F401, F403 +from inference.web.schemas.labeling import * # noqa: F401, F403 diff --git a/packages/inference/inference/web/schemas/admin/__init__.py b/packages/inference/inference/web/schemas/admin/__init__.py new file mode 100644 index 0000000..1300b4e --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/__init__.py @@ -0,0 +1,17 @@ +""" +Admin API Request/Response Schemas + +Pydantic models for admin API validation and serialization. +""" + +from .enums import * # noqa: F401, F403 +from .auth import * # noqa: F401, F403 +from .documents import * # noqa: F401, F403 +from .annotations import * # noqa: F401, F403 +from .training import * # noqa: F401, F403 +from .datasets import * # noqa: F401, F403 + +# Resolve forward references for DocumentDetailResponse +from .documents import DocumentDetailResponse + +DocumentDetailResponse.model_rebuild() diff --git a/packages/inference/inference/web/schemas/admin/annotations.py b/packages/inference/inference/web/schemas/admin/annotations.py new file mode 100644 index 0000000..eb43047 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/annotations.py @@ -0,0 +1,152 @@ +"""Admin Annotation Schemas.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + +from .enums import AnnotationSource + + +class BoundingBox(BaseModel): + """Bounding box coordinates.""" + + x: int = Field(..., ge=0, description="X coordinate (pixels)") + y: int = Field(..., ge=0, description="Y coordinate (pixels)") + width: int = Field(..., ge=1, description="Width (pixels)") + height: int = Field(..., ge=1, description="Height (pixels)") + + +class AnnotationCreate(BaseModel): + """Request to create an annotation.""" + + page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)") + class_id: int = Field(..., ge=0, le=9, description="Class ID (0-9)") + bbox: BoundingBox = Field(..., description="Bounding box in pixels") + text_value: str | None = Field(None, description="Text value (optional)") + + +class AnnotationUpdate(BaseModel): + """Request to update an annotation.""" + + class_id: int | None = Field(None, ge=0, le=9, description="New class ID") + bbox: BoundingBox | None = Field(None, description="New bounding box") + text_value: str | None = Field(None, description="New text value") + + +class AnnotationItem(BaseModel): + """Single annotation item.""" + + annotation_id: str = Field(..., description="Annotation UUID") + page_number: int = Field(..., ge=1, description="Page number") + class_id: int = Field(..., ge=0, le=9, description="Class ID") + class_name: str = Field(..., description="Class name") + bbox: BoundingBox = Field(..., description="Bounding box in pixels") + normalized_bbox: dict[str, float] = Field( + ..., description="Normalized bbox (x_center, y_center, width, height)" + ) + text_value: str | None = Field(None, description="Text value") + confidence: float | None = Field(None, ge=0, le=1, description="Confidence score") + source: AnnotationSource = Field(..., description="Annotation source") + created_at: datetime = Field(..., description="Creation timestamp") + + +class AnnotationResponse(BaseModel): + """Response for annotation operation.""" + + annotation_id: str = Field(..., description="Annotation UUID") + message: str = Field(..., 
description="Status message") + + +class AnnotationListResponse(BaseModel): + """Response for annotation list.""" + + document_id: str = Field(..., description="Document UUID") + page_count: int = Field(..., ge=1, description="Total pages") + total_annotations: int = Field(..., ge=0, description="Total annotations") + annotations: list[AnnotationItem] = Field( + default_factory=list, description="Annotation list" + ) + + +class AnnotationLockRequest(BaseModel): + """Request to acquire annotation lock.""" + + duration_seconds: int = Field( + default=300, + ge=60, + le=3600, + description="Lock duration in seconds (60-3600)", + ) + + +class AnnotationLockResponse(BaseModel): + """Response for annotation lock operation.""" + + document_id: str = Field(..., description="Document UUID") + locked: bool = Field(..., description="Whether lock was acquired/released") + lock_expires_at: datetime | None = Field( + None, description="Lock expiration time" + ) + message: str = Field(..., description="Status message") + + +class AutoLabelRequest(BaseModel): + """Request to trigger auto-labeling.""" + + field_values: dict[str, str] = Field( + ..., + description="Field values to match (e.g., {'invoice_number': '12345'})", + ) + replace_existing: bool = Field( + default=False, description="Replace existing auto annotations" + ) + + +class AutoLabelResponse(BaseModel): + """Response for auto-labeling.""" + + document_id: str = Field(..., description="Document UUID") + status: str = Field(..., description="Auto-labeling status") + annotations_created: int = Field( + default=0, ge=0, description="Number of annotations created" + ) + message: str = Field(..., description="Status message") + + +class AnnotationVerifyRequest(BaseModel): + """Request to verify an annotation.""" + + pass # No body needed, just POST to verify + + +class AnnotationVerifyResponse(BaseModel): + """Response for annotation verification.""" + + annotation_id: str = Field(..., description="Annotation UUID") + is_verified: bool = Field(..., description="Verification status") + verified_at: datetime = Field(..., description="Verification timestamp") + verified_by: str = Field(..., description="Admin token who verified") + message: str = Field(..., description="Status message") + + +class AnnotationOverrideRequest(BaseModel): + """Request to override an annotation.""" + + bbox: dict[str, int] | None = Field( + None, description="Updated bounding box {x, y, width, height}" + ) + text_value: str | None = Field(None, description="Updated text value") + class_id: int | None = Field(None, ge=0, le=9, description="Updated class ID") + class_name: str | None = Field(None, description="Updated class name") + reason: str | None = Field(None, description="Reason for override") + + +class AnnotationOverrideResponse(BaseModel): + """Response for annotation override.""" + + annotation_id: str = Field(..., description="Annotation UUID") + source: str = Field(..., description="New source (manual)") + override_source: str | None = Field(None, description="Original source (auto)") + original_annotation_id: str | None = Field(None, description="Original annotation ID") + message: str = Field(..., description="Status message") + history_id: str = Field(..., description="History record UUID") diff --git a/packages/inference/inference/web/schemas/admin/auth.py b/packages/inference/inference/web/schemas/admin/auth.py new file mode 100644 index 0000000..9b724e6 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/auth.py @@ -0,0 +1,23 @@ +"""Admin 
Auth Schemas.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + + +class AdminTokenCreate(BaseModel): + """Request to create an admin token.""" + + name: str = Field(..., min_length=1, max_length=255, description="Token name") + expires_in_days: int | None = Field( + None, ge=1, le=365, description="Token expiration in days (optional)" + ) + + +class AdminTokenResponse(BaseModel): + """Response with created admin token.""" + + token: str = Field(..., description="Admin token") + name: str = Field(..., description="Token name") + expires_at: datetime | None = Field(None, description="Token expiration time") + message: str = Field(..., description="Status message") diff --git a/packages/inference/inference/web/schemas/admin/datasets.py b/packages/inference/inference/web/schemas/admin/datasets.py new file mode 100644 index 0000000..f7e38c9 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/datasets.py @@ -0,0 +1,85 @@ +"""Admin Dataset Schemas.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + +from .training import TrainingConfig + + +class DatasetCreateRequest(BaseModel): + """Request to create a training dataset.""" + + name: str = Field(..., min_length=1, max_length=255, description="Dataset name") + description: str | None = Field(None, description="Optional description") + document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include") + train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio") + val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio") + seed: int = Field(42, description="Random seed for split") + + +class DatasetDocumentItem(BaseModel): + """Document within a dataset.""" + + document_id: str + split: str + page_count: int + annotation_count: int + + +class DatasetResponse(BaseModel): + """Response after creating a dataset.""" + + dataset_id: str + name: str + status: str + message: str + + +class DatasetDetailResponse(BaseModel): + """Detailed dataset info with documents.""" + + dataset_id: str + name: str + description: str | None + status: str + train_ratio: float + val_ratio: float + seed: int + total_documents: int + total_images: int + total_annotations: int + dataset_path: str | None + error_message: str | None + documents: list[DatasetDocumentItem] + created_at: datetime + updated_at: datetime + + +class DatasetListItem(BaseModel): + """Dataset in list view.""" + + dataset_id: str + name: str + description: str | None + status: str + total_documents: int + total_images: int + total_annotations: int + created_at: datetime + + +class DatasetListResponse(BaseModel): + """Paginated dataset list.""" + + total: int + limit: int + offset: int + datasets: list[DatasetListItem] + + +class DatasetTrainRequest(BaseModel): + """Request to start training from a dataset.""" + + name: str = Field(..., min_length=1, max_length=255, description="Training task name") + config: TrainingConfig = Field(..., description="Training configuration") diff --git a/packages/inference/inference/web/schemas/admin/documents.py b/packages/inference/inference/web/schemas/admin/documents.py new file mode 100644 index 0000000..fdf3874 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/documents.py @@ -0,0 +1,103 @@ +"""Admin Document Schemas.""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from pydantic import BaseModel, Field + +from .enums import AutoLabelStatus, 
DocumentStatus + +if TYPE_CHECKING: + from .annotations import AnnotationItem + from .training import TrainingHistoryItem + + +class DocumentUploadResponse(BaseModel): + """Response for document upload.""" + + document_id: str = Field(..., description="Document UUID") + filename: str = Field(..., description="Original filename") + file_size: int = Field(..., ge=0, description="File size in bytes") + page_count: int = Field(..., ge=1, description="Number of pages") + status: DocumentStatus = Field(..., description="Document status") + auto_label_started: bool = Field( + default=False, description="Whether auto-labeling was started" + ) + message: str = Field(..., description="Status message") + + +class DocumentItem(BaseModel): + """Single document in list.""" + + document_id: str = Field(..., description="Document UUID") + filename: str = Field(..., description="Original filename") + file_size: int = Field(..., ge=0, description="File size in bytes") + page_count: int = Field(..., ge=1, description="Number of pages") + status: DocumentStatus = Field(..., description="Document status") + auto_label_status: AutoLabelStatus | None = Field( + None, description="Auto-labeling status" + ) + annotation_count: int = Field(default=0, ge=0, description="Number of annotations") + upload_source: str = Field(default="ui", description="Upload source (ui or api)") + batch_id: str | None = Field(None, description="Batch ID if uploaded via batch") + can_annotate: bool = Field(default=True, description="Whether document can be annotated") + created_at: datetime = Field(..., description="Creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + +class DocumentListResponse(BaseModel): + """Response for document list.""" + + total: int = Field(..., ge=0, description="Total documents") + limit: int = Field(..., ge=1, description="Page size") + offset: int = Field(..., ge=0, description="Current offset") + documents: list[DocumentItem] = Field( + default_factory=list, description="Document list" + ) + + +class DocumentDetailResponse(BaseModel): + """Response for document detail.""" + + document_id: str = Field(..., description="Document UUID") + filename: str = Field(..., description="Original filename") + file_size: int = Field(..., ge=0, description="File size in bytes") + content_type: str = Field(..., description="MIME type") + page_count: int = Field(..., ge=1, description="Number of pages") + status: DocumentStatus = Field(..., description="Document status") + auto_label_status: AutoLabelStatus | None = Field( + None, description="Auto-labeling status" + ) + auto_label_error: str | None = Field(None, description="Auto-labeling error") + upload_source: str = Field(default="ui", description="Upload source (ui or api)") + batch_id: str | None = Field(None, description="Batch ID if uploaded via batch") + csv_field_values: dict[str, str] | None = Field( + None, description="CSV field values if uploaded via batch" + ) + can_annotate: bool = Field(default=True, description="Whether document can be annotated") + annotation_lock_until: datetime | None = Field( + None, description="Lock expiration time if document is locked" + ) + annotations: list["AnnotationItem"] = Field( + default_factory=list, description="Document annotations" + ) + image_urls: list[str] = Field( + default_factory=list, description="URLs to page images" + ) + training_history: list["TrainingHistoryItem"] = Field( + default_factory=list, description="Training tasks that used this document" + ) + 
created_at: datetime = Field(..., description="Creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + +class DocumentStatsResponse(BaseModel): + """Document statistics response.""" + + total: int = Field(..., ge=0, description="Total documents") + pending: int = Field(default=0, ge=0, description="Pending documents") + auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents") + labeled: int = Field(default=0, ge=0, description="Labeled documents") + exported: int = Field(default=0, ge=0, description="Exported documents") diff --git a/packages/inference/inference/web/schemas/admin/enums.py b/packages/inference/inference/web/schemas/admin/enums.py new file mode 100644 index 0000000..c4ea592 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/enums.py @@ -0,0 +1,46 @@ +"""Admin API Enums.""" + +from enum import Enum + + +class DocumentStatus(str, Enum): + """Document status enum.""" + + PENDING = "pending" + AUTO_LABELING = "auto_labeling" + LABELED = "labeled" + EXPORTED = "exported" + + +class AutoLabelStatus(str, Enum): + """Auto-labeling status enum.""" + + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class TrainingStatus(str, Enum): + """Training task status enum.""" + + PENDING = "pending" + SCHEDULED = "scheduled" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class TrainingType(str, Enum): + """Training task type enum.""" + + TRAIN = "train" + FINETUNE = "finetune" + + +class AnnotationSource(str, Enum): + """Annotation source enum.""" + + MANUAL = "manual" + AUTO = "auto" + IMPORTED = "imported" diff --git a/packages/inference/inference/web/schemas/admin/training.py b/packages/inference/inference/web/schemas/admin/training.py new file mode 100644 index 0000000..6958692 --- /dev/null +++ b/packages/inference/inference/web/schemas/admin/training.py @@ -0,0 +1,202 @@ +"""Admin Training Schemas.""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + +from .enums import TrainingStatus, TrainingType + + +class TrainingConfig(BaseModel): + """Training configuration.""" + + model_name: str = Field(default="yolo11n.pt", description="Base model name") + epochs: int = Field(default=100, ge=1, le=1000, description="Training epochs") + batch_size: int = Field(default=16, ge=1, le=128, description="Batch size") + image_size: int = Field(default=640, ge=320, le=1280, description="Image size") + learning_rate: float = Field(default=0.01, gt=0, le=1, description="Learning rate") + device: str = Field(default="0", description="Device (0 for GPU, cpu for CPU)") + project_name: str = Field( + default="invoice_fields", description="Training project name" + ) + + +class TrainingTaskCreate(BaseModel): + """Request to create a training task.""" + + name: str = Field(..., min_length=1, max_length=255, description="Task name") + description: str | None = Field(None, max_length=1000, description="Description") + task_type: TrainingType = Field( + default=TrainingType.TRAIN, description="Task type" + ) + config: TrainingConfig = Field( + default_factory=TrainingConfig, description="Training configuration" + ) + scheduled_at: datetime | None = Field( + None, description="Scheduled execution time" + ) + cron_expression: str | None = Field( + None, max_length=50, description="Cron expression for recurring tasks" + ) + + +class TrainingTaskItem(BaseModel): + """Single training task in list.""" + + task_id: str 
= Field(..., description="Task UUID") + name: str = Field(..., description="Task name") + task_type: TrainingType = Field(..., description="Task type") + status: TrainingStatus = Field(..., description="Task status") + scheduled_at: datetime | None = Field(None, description="Scheduled time") + is_recurring: bool = Field(default=False, description="Is recurring task") + started_at: datetime | None = Field(None, description="Start time") + completed_at: datetime | None = Field(None, description="Completion time") + created_at: datetime = Field(..., description="Creation timestamp") + + +class TrainingTaskListResponse(BaseModel): + """Response for training task list.""" + + total: int = Field(..., ge=0, description="Total tasks") + limit: int = Field(..., ge=1, description="Page size") + offset: int = Field(..., ge=0, description="Current offset") + tasks: list[TrainingTaskItem] = Field(default_factory=list, description="Task list") + + +class TrainingTaskDetailResponse(BaseModel): + """Response for training task detail.""" + + task_id: str = Field(..., description="Task UUID") + name: str = Field(..., description="Task name") + description: str | None = Field(None, description="Description") + task_type: TrainingType = Field(..., description="Task type") + status: TrainingStatus = Field(..., description="Task status") + config: dict[str, Any] | None = Field(None, description="Training configuration") + scheduled_at: datetime | None = Field(None, description="Scheduled time") + cron_expression: str | None = Field(None, description="Cron expression") + is_recurring: bool = Field(default=False, description="Is recurring task") + started_at: datetime | None = Field(None, description="Start time") + completed_at: datetime | None = Field(None, description="Completion time") + error_message: str | None = Field(None, description="Error message") + result_metrics: dict[str, Any] | None = Field(None, description="Result metrics") + model_path: str | None = Field(None, description="Trained model path") + created_at: datetime = Field(..., description="Creation timestamp") + + +class TrainingTaskResponse(BaseModel): + """Response for training task operation.""" + + task_id: str = Field(..., description="Task UUID") + status: TrainingStatus = Field(..., description="Task status") + message: str = Field(..., description="Status message") + + +class TrainingLogItem(BaseModel): + """Single training log entry.""" + + level: str = Field(..., description="Log level") + message: str = Field(..., description="Log message") + details: dict[str, Any] | None = Field(None, description="Additional details") + created_at: datetime = Field(..., description="Timestamp") + + +class TrainingLogsResponse(BaseModel): + """Response for training logs.""" + + task_id: str = Field(..., description="Task UUID") + logs: list[TrainingLogItem] = Field(default_factory=list, description="Log entries") + + +class ExportRequest(BaseModel): + """Request to export annotations.""" + + format: str = Field( + default="yolo", description="Export format (yolo, coco, voc)" + ) + include_images: bool = Field( + default=True, description="Include images in export" + ) + split_ratio: float = Field( + default=0.8, ge=0.5, le=1.0, description="Train/val split ratio" + ) + + +class ExportResponse(BaseModel): + """Response for export operation.""" + + status: str = Field(..., description="Export status") + export_path: str = Field(..., description="Path to exported dataset") + total_images: int = Field(..., ge=0, description="Total images exported") + 
total_annotations: int = Field(..., ge=0, description="Total annotations") + train_count: int = Field(..., ge=0, description="Training set count") + val_count: int = Field(..., ge=0, description="Validation set count") + message: str = Field(..., description="Status message") + + +class TrainingDocumentItem(BaseModel): + """Document item for training page.""" + + document_id: str = Field(..., description="Document UUID") + filename: str = Field(..., description="Filename") + annotation_count: int = Field(..., ge=0, description="Total annotations") + annotation_sources: dict[str, int] = Field( + ..., description="Annotation counts by source (manual, auto)" + ) + used_in_training: list[str] = Field( + default_factory=list, description="List of training task IDs that used this document" + ) + last_modified: datetime = Field(..., description="Last modification time") + + +class TrainingDocumentsResponse(BaseModel): + """Response for GET /admin/training/documents.""" + + total: int = Field(..., ge=0, description="Total document count") + limit: int = Field(..., ge=1, le=100, description="Page size") + offset: int = Field(..., ge=0, description="Pagination offset") + documents: list[TrainingDocumentItem] = Field( + default_factory=list, description="Documents available for training" + ) + + +class ModelMetrics(BaseModel): + """Training model metrics.""" + + mAP: float | None = Field(None, ge=0.0, le=1.0, description="Mean Average Precision") + precision: float | None = Field(None, ge=0.0, le=1.0, description="Precision") + recall: float | None = Field(None, ge=0.0, le=1.0, description="Recall") + + +class TrainingModelItem(BaseModel): + """Trained model item for model list.""" + + task_id: str = Field(..., description="Training task UUID") + name: str = Field(..., description="Model name") + status: TrainingStatus = Field(..., description="Training status") + document_count: int = Field(..., ge=0, description="Documents used in training") + created_at: datetime = Field(..., description="Creation timestamp") + completed_at: datetime | None = Field(None, description="Completion timestamp") + metrics: ModelMetrics = Field(..., description="Model metrics") + model_path: str | None = Field(None, description="Path to model weights") + download_url: str | None = Field(None, description="Download URL for model") + + +class TrainingModelsResponse(BaseModel): + """Response for GET /admin/training/models.""" + + total: int = Field(..., ge=0, description="Total model count") + limit: int = Field(..., ge=1, le=100, description="Page size") + offset: int = Field(..., ge=0, description="Pagination offset") + models: list[TrainingModelItem] = Field( + default_factory=list, description="Trained models" + ) + + +class TrainingHistoryItem(BaseModel): + """Training history for a document.""" + + task_id: str = Field(..., description="Training task UUID") + name: str = Field(..., description="Training task name") + trained_at: datetime = Field(..., description="Training timestamp") + model_metrics: ModelMetrics | None = Field(None, description="Model metrics") diff --git a/src/web/schemas/common.py b/packages/inference/inference/web/schemas/common.py similarity index 100% rename from src/web/schemas/common.py rename to packages/inference/inference/web/schemas/common.py diff --git a/src/web/schemas/inference.py b/packages/inference/inference/web/schemas/inference.py similarity index 100% rename from src/web/schemas/inference.py rename to packages/inference/inference/web/schemas/inference.py diff --git 
a/src/web/schemas/labeling.py b/packages/inference/inference/web/schemas/labeling.py similarity index 100% rename from src/web/schemas/labeling.py rename to packages/inference/inference/web/schemas/labeling.py diff --git a/packages/inference/inference/web/services/__init__.py b/packages/inference/inference/web/services/__init__.py new file mode 100644 index 0000000..1cc6e54 --- /dev/null +++ b/packages/inference/inference/web/services/__init__.py @@ -0,0 +1,18 @@ +""" +Business Logic Services + +Service layer for processing requests and orchestrating data operations. +""" + +from inference.web.services.autolabel import AutoLabelService, get_auto_label_service +from inference.web.services.inference import InferenceService +from inference.web.services.async_processing import AsyncProcessingService +from inference.web.services.batch_upload import BatchUploadService + +__all__ = [ + "AutoLabelService", + "get_auto_label_service", + "InferenceService", + "AsyncProcessingService", + "BatchUploadService", +] diff --git a/src/web/services/async_processing.py b/packages/inference/inference/web/services/async_processing.py similarity index 97% rename from src/web/services/async_processing.py rename to packages/inference/inference/web/services/async_processing.py index 54e2e08..11a0173 100644 --- a/src/web/services/async_processing.py +++ b/packages/inference/inference/web/services/async_processing.py @@ -14,13 +14,13 @@ from pathlib import Path from threading import Event, Thread from typing import TYPE_CHECKING -from src.data.async_request_db import AsyncRequestDB -from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue -from src.web.core.rate_limiter import RateLimiter +from inference.data.async_request_db import AsyncRequestDB +from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue +from inference.web.core.rate_limiter import RateLimiter if TYPE_CHECKING: - from src.web.config import AsyncConfig, StorageConfig - from src.web.services.inference import InferenceService + from inference.web.config import AsyncConfig, StorageConfig + from inference.web.services.inference import InferenceService logger = logging.getLogger(__name__) diff --git a/src/web/services/autolabel.py b/packages/inference/inference/web/services/autolabel.py similarity index 96% rename from src/web/services/autolabel.py rename to packages/inference/inference/web/services/autolabel.py index 2d50380..a2f3728 100644 --- a/src/web/services/autolabel.py +++ b/packages/inference/inference/web/services/autolabel.py @@ -11,11 +11,11 @@ from typing import Any import numpy as np from PIL import Image -from src.config import DEFAULT_DPI -from src.data.admin_db import AdminDB -from src.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES -from src.matcher.field_matcher import FieldMatcher -from src.ocr.paddle_ocr import OCREngine, OCRToken +from shared.config import DEFAULT_DPI +from inference.data.admin_db import AdminDB +from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES +from shared.matcher.field_matcher import FieldMatcher +from shared.ocr.paddle_ocr import OCREngine, OCRToken logger = logging.getLogger(__name__) @@ -144,7 +144,7 @@ class AutoLabelService: db: AdminDB, ) -> int: """Process PDF document and create annotations.""" - from src.pdf.renderer import render_pdf_to_images + from shared.pdf.renderer import render_pdf_to_images import io total_annotations = 0 @@ -222,7 +222,7 @@ class AutoLabelService: image_height: int, ) -> list[dict[str, Any]]: """Find annotations for field 
values using token matching.""" - from src.normalize import normalize_field + from shared.normalize import normalize_field annotations = [] diff --git a/src/web/services/batch_upload.py b/packages/inference/inference/web/services/batch_upload.py similarity index 99% rename from src/web/services/batch_upload.py rename to packages/inference/inference/web/services/batch_upload.py index db15e3f..5ac903b 100644 --- a/src/web/services/batch_upload.py +++ b/packages/inference/inference/web/services/batch_upload.py @@ -15,8 +15,8 @@ from uuid import UUID from pydantic import BaseModel, Field, field_validator -from src.data.admin_db import AdminDB -from src.data.admin_models import CSV_TO_CLASS_MAPPING +from inference.data.admin_db import AdminDB +from inference.data.admin_models import CSV_TO_CLASS_MAPPING logger = logging.getLogger(__name__) diff --git a/packages/inference/inference/web/services/dataset_builder.py b/packages/inference/inference/web/services/dataset_builder.py new file mode 100644 index 0000000..30c69ce --- /dev/null +++ b/packages/inference/inference/web/services/dataset_builder.py @@ -0,0 +1,188 @@ +""" +Dataset Builder Service + +Creates training datasets by copying images from admin storage, +generating YOLO label files, and splitting into train/val/test sets. +""" + +import logging +import random +import shutil +from pathlib import Path + +import yaml + +from inference.data.admin_models import FIELD_CLASSES + +logger = logging.getLogger(__name__) + + +class DatasetBuilder: + """Builds YOLO training datasets from admin documents.""" + + def __init__(self, db, base_dir: Path): + self._db = db + self._base_dir = Path(base_dir) + + def build_dataset( + self, + dataset_id: str, + document_ids: list[str], + train_ratio: float, + val_ratio: float, + seed: int, + admin_images_dir: Path, + ) -> dict: + """Build a complete YOLO dataset from document IDs. + + Args: + dataset_id: UUID of the dataset record. + document_ids: List of document UUIDs to include. + train_ratio: Fraction for training set. + val_ratio: Fraction for validation set. + seed: Random seed for reproducible splits. + admin_images_dir: Root directory of admin images. + + Returns: + Summary dict with total_documents, total_images, total_annotations. + + Raises: + ValueError: If no valid documents found. + """ + try: + return self._do_build( + dataset_id, document_ids, train_ratio, val_ratio, seed, admin_images_dir + ) + except Exception as e: + self._db.update_dataset_status( + dataset_id=dataset_id, + status="failed", + error_message=str(e), + ) + raise + + def _do_build( + self, + dataset_id: str, + document_ids: list[str], + train_ratio: float, + val_ratio: float, + seed: int, + admin_images_dir: Path, + ) -> dict: + # 1. Fetch documents + documents = self._db.get_documents_by_ids(document_ids) + if not documents: + raise ValueError("No valid documents found for the given IDs") + + # 2. Create directory structure + dataset_dir = self._base_dir / dataset_id + for split in ["train", "val", "test"]: + (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True) + (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True) + + # 3. Shuffle and split documents + doc_list = list(documents) + rng = random.Random(seed) + rng.shuffle(doc_list) + + n = len(doc_list) + n_train = max(1, round(n * train_ratio)) + n_val = max(0, round(n * val_ratio)) + n_test = n - n_train - n_val + + splits = ( + ["train"] * n_train + + ["val"] * n_val + + ["test"] * n_test + ) + + # 4. 
Process each document + total_images = 0 + total_annotations = 0 + dataset_docs = [] + + for doc, split in zip(doc_list, splits): + doc_id = str(doc.document_id) + annotations = self._db.get_annotations_for_document(doc.document_id) + + # Group annotations by page + page_annotations: dict[int, list] = {} + for ann in annotations: + page_annotations.setdefault(ann.page_number, []).append(ann) + + doc_image_count = 0 + doc_ann_count = 0 + + # Copy images and write labels for each page + for page_num in range(1, doc.page_count + 1): + src_image = Path(admin_images_dir) / doc_id / f"page_{page_num}.png" + if not src_image.exists(): + logger.warning("Image not found: %s", src_image) + continue + + dst_name = f"{doc_id}_page{page_num}" + dst_image = dataset_dir / "images" / split / f"{dst_name}.png" + shutil.copy2(src_image, dst_image) + doc_image_count += 1 + + # Write YOLO label file + page_anns = page_annotations.get(page_num, []) + label_lines = [] + for ann in page_anns: + label_lines.append( + f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} " + f"{ann.width:.6f} {ann.height:.6f}" + ) + doc_ann_count += 1 + + label_path = dataset_dir / "labels" / split / f"{dst_name}.txt" + label_path.write_text("\n".join(label_lines)) + + total_images += doc_image_count + total_annotations += doc_ann_count + + dataset_docs.append({ + "document_id": doc_id, + "split": split, + "page_count": doc_image_count, + "annotation_count": doc_ann_count, + }) + + # 5. Record document-split assignments in DB + self._db.add_dataset_documents( + dataset_id=dataset_id, + documents=dataset_docs, + ) + + # 6. Generate data.yaml + self._generate_data_yaml(dataset_dir) + + # 7. Update dataset status + self._db.update_dataset_status( + dataset_id=dataset_id, + status="ready", + total_documents=len(doc_list), + total_images=total_images, + total_annotations=total_annotations, + dataset_path=str(dataset_dir), + ) + + return { + "total_documents": len(doc_list), + "total_images": total_images, + "total_annotations": total_annotations, + } + + def _generate_data_yaml(self, dataset_dir: Path) -> None: + """Generate YOLO data.yaml configuration file.""" + data = { + "path": str(dataset_dir.absolute()), + "train": "images/train", + "val": "images/val", + "test": "images/test", + "nc": len(FIELD_CLASSES), + "names": FIELD_CLASSES, + } + yaml_path = dataset_dir / "data.yaml" + yaml_path.write_text(yaml.dump(data, default_flow_style=False, allow_unicode=True)) diff --git a/src/web/services/db_autolabel.py b/packages/inference/inference/web/services/db_autolabel.py similarity index 96% rename from src/web/services/db_autolabel.py rename to packages/inference/inference/web/services/db_autolabel.py index 231c0fe..44da968 100644 --- a/src/web/services/db_autolabel.py +++ b/packages/inference/inference/web/services/db_autolabel.py @@ -11,11 +11,11 @@ import logging from pathlib import Path from typing import Any -from src.config import DEFAULT_DPI -from src.data.admin_db import AdminDB -from src.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING -from src.data.db import DocumentDB -from src.web.config import StorageConfig +from shared.config import DEFAULT_DPI +from inference.data.admin_db import AdminDB +from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING +from shared.data.db import DocumentDB +from inference.web.config import StorageConfig logger = logging.getLogger(__name__) @@ -81,8 +81,8 @@ def get_pending_autolabel_documents( List of AdminDocument records with status='auto_labeling' and 
auto_label_status='pending' """ from sqlmodel import select - from src.data.database import get_session_context - from src.data.admin_models import AdminDocument + from inference.data.database import get_session_context + from inference.data.admin_models import AdminDocument with get_session_context() as session: statement = select(AdminDocument).where( @@ -116,8 +116,8 @@ def process_document_autolabel( Returns: Result dictionary with success status and annotations """ - from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf - from src.pdf import PDFDocument + from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf + from shared.pdf import PDFDocument document_id = str(document.document_id) file_path = Path(document.file_path) @@ -247,7 +247,7 @@ def _save_annotations_to_db( Number of annotations saved """ from PIL import Image - from src.data.admin_models import FIELD_CLASS_IDS + from inference.data.admin_models import FIELD_CLASS_IDS # Mapping from CSV field names to internal field names CSV_TO_INTERNAL_FIELD: dict[str, str] = { @@ -480,7 +480,7 @@ def save_manual_annotations_to_document_db( pdf_type = "unknown" if pdf_path.exists(): try: - from src.pdf import PDFDocument + from shared.pdf import PDFDocument with PDFDocument(pdf_path) as pdf_doc: tokens = list(pdf_doc.extract_text_tokens(0)) pdf_type = "scanned" if len(tokens) < 10 else "text" diff --git a/src/web/services/inference.py b/packages/inference/inference/web/services/inference.py similarity index 97% rename from src/web/services/inference.py rename to packages/inference/inference/web/services/inference.py index c30a16a..f087569 100644 --- a/src/web/services/inference.py +++ b/packages/inference/inference/web/services/inference.py @@ -71,8 +71,8 @@ class InferenceService: start_time = time.time() try: - from src.inference.pipeline import InferencePipeline - from src.inference.yolo_detector import YOLODetector + from inference.pipeline.pipeline import InferencePipeline + from inference.pipeline.yolo_detector import YOLODetector # Initialize YOLO detector for visualization self._detector = YOLODetector( @@ -257,7 +257,7 @@ class InferenceService: def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path: """Save visualization for PDF (first page).""" - from src.pdf.renderer import render_pdf_to_images + from shared.pdf.renderer import render_pdf_to_images from ultralytics import YOLO import io diff --git a/src/web/workers/__init__.py b/packages/inference/inference/web/workers/__init__.py similarity index 75% rename from src/web/workers/__init__.py rename to packages/inference/inference/web/workers/__init__.py index 8b8834d..f76f8b8 100644 --- a/src/web/workers/__init__.py +++ b/packages/inference/inference/web/workers/__init__.py @@ -4,8 +4,8 @@ Background Task Queues Worker queues for asynchronous and batch processing. 
""" -from src.web.workers.async_queue import AsyncTaskQueue, AsyncTask -from src.web.workers.batch_queue import ( +from inference.web.workers.async_queue import AsyncTaskQueue, AsyncTask +from inference.web.workers.batch_queue import ( BatchTaskQueue, BatchTask, init_batch_queue, diff --git a/src/web/workers/async_queue.py b/packages/inference/inference/web/workers/async_queue.py similarity index 100% rename from src/web/workers/async_queue.py rename to packages/inference/inference/web/workers/async_queue.py diff --git a/src/web/workers/batch_queue.py b/packages/inference/inference/web/workers/batch_queue.py similarity index 100% rename from src/web/workers/batch_queue.py rename to packages/inference/inference/web/workers/batch_queue.py diff --git a/packages/inference/requirements.txt b/packages/inference/requirements.txt new file mode 100644 index 0000000..dcb9ff4 --- /dev/null +++ b/packages/inference/requirements.txt @@ -0,0 +1,8 @@ +-e ../shared +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 +sqlmodel>=0.0.22 +ultralytics>=8.1.0 +httpx>=0.25.0 +openai>=1.0.0 diff --git a/packages/inference/run_server.py b/packages/inference/run_server.py new file mode 100644 index 0000000..09ef573 --- /dev/null +++ b/packages/inference/run_server.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +""" +Quick start script for the web server. + +Usage: + python run_server.py + python run_server.py --port 8080 + python run_server.py --debug --reload +""" + +from inference.cli.serve import main + +if __name__ == "__main__": + main() diff --git a/packages/inference/setup.py b/packages/inference/setup.py new file mode 100644 index 0000000..359255b --- /dev/null +++ b/packages/inference/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +setup( + name="invoice-inference", + version="0.1.0", + packages=find_packages(), + python_requires=">=3.11", + install_requires=[ + "invoice-shared", + "fastapi>=0.104.0", + "uvicorn[standard]>=0.24.0", + "python-multipart>=0.0.6", + "sqlmodel>=0.0.22", + "ultralytics>=8.1.0", + "httpx>=0.25.0", + ], +) diff --git a/packages/shared/requirements.txt b/packages/shared/requirements.txt new file mode 100644 index 0000000..0e4fdc9 --- /dev/null +++ b/packages/shared/requirements.txt @@ -0,0 +1,9 @@ +PyMuPDF>=1.23.0 +paddleocr>=2.7.0 +Pillow>=10.0.0 +numpy>=1.24.0 +opencv-python>=4.8.0 +psycopg2-binary>=2.9.0 +python-dotenv>=1.0.0 +pyyaml>=6.0 +thefuzz>=0.20.0 diff --git a/packages/shared/setup.py b/packages/shared/setup.py new file mode 100644 index 0000000..018e6c6 --- /dev/null +++ b/packages/shared/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, find_packages + +setup( + name="invoice-shared", + version="0.1.0", + packages=find_packages(), + python_requires=">=3.11", + install_requires=[ + "PyMuPDF>=1.23.0", + "paddleocr>=2.7.0", + "Pillow>=10.0.0", + "numpy>=1.24.0", + "opencv-python>=4.8.0", + "psycopg2-binary>=2.9.0", + "python-dotenv>=1.0.0", + "pyyaml>=6.0", + "thefuzz>=0.20.0", + ], +) diff --git a/src/__init__.py b/packages/shared/shared/__init__.py similarity index 100% rename from src/__init__.py rename to packages/shared/shared/__init__.py diff --git a/src/config.py b/packages/shared/shared/config.py similarity index 84% rename from src/config.py rename to packages/shared/shared/config.py index 6f183f5..425f0e2 100644 --- a/src/config.py +++ b/packages/shared/shared/config.py @@ -7,10 +7,16 @@ import platform from pathlib import Path from dotenv import load_dotenv -# Load environment variables from .env file -# .env is at 
project root, config.py is in src/ -env_path = Path(__file__).parent.parent / '.env' -load_dotenv(dotenv_path=env_path) +# Load environment variables from .env file at project root +# Walk up from packages/shared/shared/config.py to find project root +_config_dir = Path(__file__).parent +for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]: + _env_path = _candidate / '.env' + if _env_path.exists(): + load_dotenv(dotenv_path=_env_path) + break +else: + load_dotenv() # fallback: search cwd and parents # Global DPI setting - must match training DPI for optimal model performance DEFAULT_DPI = 150 diff --git a/packages/shared/shared/data/__init__.py b/packages/shared/shared/data/__init__.py new file mode 100644 index 0000000..dec60f4 --- /dev/null +++ b/packages/shared/shared/data/__init__.py @@ -0,0 +1,3 @@ +from .csv_loader import CSVLoader, InvoiceRow + +__all__ = ['CSVLoader', 'InvoiceRow'] diff --git a/src/data/csv_loader.py b/packages/shared/shared/data/csv_loader.py similarity index 100% rename from src/data/csv_loader.py rename to packages/shared/shared/data/csv_loader.py diff --git a/src/data/db.py b/packages/shared/shared/data/db.py similarity index 99% rename from src/data/db.py rename to packages/shared/shared/data/db.py index ff22a98..36106f9 100644 --- a/src/data/db.py +++ b/packages/shared/shared/data/db.py @@ -9,8 +9,7 @@ from typing import Set, Dict, Any, Optional import sys from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string +from shared.config import get_db_connection_string class DocumentDB: diff --git a/src/exceptions.py b/packages/shared/shared/exceptions.py similarity index 100% rename from src/exceptions.py rename to packages/shared/shared/exceptions.py diff --git a/src/matcher/__init__.py b/packages/shared/shared/matcher/__init__.py similarity index 100% rename from src/matcher/__init__.py rename to packages/shared/shared/matcher/__init__.py diff --git a/src/matcher/context.py b/packages/shared/shared/matcher/context.py similarity index 100% rename from src/matcher/context.py rename to packages/shared/shared/matcher/context.py diff --git a/src/matcher/field_matcher.py b/packages/shared/shared/matcher/field_matcher.py similarity index 100% rename from src/matcher/field_matcher.py rename to packages/shared/shared/matcher/field_matcher.py diff --git a/src/matcher/field_matcher_old.py b/packages/shared/shared/matcher/field_matcher_old.py similarity index 100% rename from src/matcher/field_matcher_old.py rename to packages/shared/shared/matcher/field_matcher_old.py diff --git a/src/matcher/models.py b/packages/shared/shared/matcher/models.py similarity index 100% rename from src/matcher/models.py rename to packages/shared/shared/matcher/models.py diff --git a/src/matcher/strategies/__init__.py b/packages/shared/shared/matcher/strategies/__init__.py similarity index 100% rename from src/matcher/strategies/__init__.py rename to packages/shared/shared/matcher/strategies/__init__.py diff --git a/src/matcher/strategies/base.py b/packages/shared/shared/matcher/strategies/base.py similarity index 100% rename from src/matcher/strategies/base.py rename to packages/shared/shared/matcher/strategies/base.py diff --git a/src/matcher/strategies/concatenated_matcher.py b/packages/shared/shared/matcher/strategies/concatenated_matcher.py similarity index 100% rename from src/matcher/strategies/concatenated_matcher.py rename to 
packages/shared/shared/matcher/strategies/concatenated_matcher.py diff --git a/src/matcher/strategies/exact_matcher.py b/packages/shared/shared/matcher/strategies/exact_matcher.py similarity index 100% rename from src/matcher/strategies/exact_matcher.py rename to packages/shared/shared/matcher/strategies/exact_matcher.py diff --git a/src/matcher/strategies/flexible_date_matcher.py b/packages/shared/shared/matcher/strategies/flexible_date_matcher.py similarity index 100% rename from src/matcher/strategies/flexible_date_matcher.py rename to packages/shared/shared/matcher/strategies/flexible_date_matcher.py diff --git a/src/matcher/strategies/fuzzy_matcher.py b/packages/shared/shared/matcher/strategies/fuzzy_matcher.py similarity index 100% rename from src/matcher/strategies/fuzzy_matcher.py rename to packages/shared/shared/matcher/strategies/fuzzy_matcher.py diff --git a/src/matcher/strategies/substring_matcher.py b/packages/shared/shared/matcher/strategies/substring_matcher.py similarity index 100% rename from src/matcher/strategies/substring_matcher.py rename to packages/shared/shared/matcher/strategies/substring_matcher.py diff --git a/src/matcher/token_index.py b/packages/shared/shared/matcher/token_index.py similarity index 100% rename from src/matcher/token_index.py rename to packages/shared/shared/matcher/token_index.py diff --git a/src/matcher/utils.py b/packages/shared/shared/matcher/utils.py similarity index 100% rename from src/matcher/utils.py rename to packages/shared/shared/matcher/utils.py diff --git a/src/normalize/__init__.py b/packages/shared/shared/normalize/__init__.py similarity index 100% rename from src/normalize/__init__.py rename to packages/shared/shared/normalize/__init__.py diff --git a/src/normalize/normalizer.py b/packages/shared/shared/normalize/normalizer.py similarity index 99% rename from src/normalize/normalizer.py rename to packages/shared/shared/normalize/normalizer.py index 9bb48a5..eda9183 100644 --- a/src/normalize/normalizer.py +++ b/packages/shared/shared/normalize/normalizer.py @@ -9,7 +9,7 @@ Each normalizer is a separate, reusable module that can be used independently. from dataclasses import dataclass from typing import Callable -from src.utils.text_cleaner import TextCleaner +from shared.utils.text_cleaner import TextCleaner # Import individual normalizers from .normalizers import ( diff --git a/src/normalize/normalizers/__init__.py b/packages/shared/shared/normalize/normalizers/__init__.py similarity index 100% rename from src/normalize/normalizers/__init__.py rename to packages/shared/shared/normalize/normalizers/__init__.py diff --git a/src/normalize/normalizers/amount_normalizer.py b/packages/shared/shared/normalize/normalizers/amount_normalizer.py similarity index 100% rename from src/normalize/normalizers/amount_normalizer.py rename to packages/shared/shared/normalize/normalizers/amount_normalizer.py diff --git a/src/normalize/normalizers/bankgiro_normalizer.py b/packages/shared/shared/normalize/normalizers/bankgiro_normalizer.py similarity index 89% rename from src/normalize/normalizers/bankgiro_normalizer.py rename to packages/shared/shared/normalize/normalizers/bankgiro_normalizer.py index 2fe3cad..4879293 100644 --- a/src/normalize/normalizers/bankgiro_normalizer.py +++ b/packages/shared/shared/normalize/normalizers/bankgiro_normalizer.py @@ -5,8 +5,8 @@ Normalizes Swedish Bankgiro account numbers. 
""" from .base import BaseNormalizer -from src.utils.format_variants import FormatVariants -from src.utils.text_cleaner import TextCleaner +from shared.utils.format_variants import FormatVariants +from shared.utils.text_cleaner import TextCleaner class BankgiroNormalizer(BaseNormalizer): diff --git a/src/normalize/normalizers/base.py b/packages/shared/shared/normalize/normalizers/base.py similarity index 94% rename from src/normalize/normalizers/base.py rename to packages/shared/shared/normalize/normalizers/base.py index 9586b1e..4e99dcf 100644 --- a/src/normalize/normalizers/base.py +++ b/packages/shared/shared/normalize/normalizers/base.py @@ -3,7 +3,7 @@ Base class for field normalizers. """ from abc import ABC, abstractmethod -from src.utils.text_cleaner import TextCleaner +from shared.utils.text_cleaner import TextCleaner class BaseNormalizer(ABC): diff --git a/src/normalize/normalizers/customer_number_normalizer.py b/packages/shared/shared/normalize/normalizers/customer_number_normalizer.py similarity index 100% rename from src/normalize/normalizers/customer_number_normalizer.py rename to packages/shared/shared/normalize/normalizers/customer_number_normalizer.py diff --git a/src/normalize/normalizers/date_normalizer.py b/packages/shared/shared/normalize/normalizers/date_normalizer.py similarity index 100% rename from src/normalize/normalizers/date_normalizer.py rename to packages/shared/shared/normalize/normalizers/date_normalizer.py diff --git a/src/normalize/normalizers/invoice_number_normalizer.py b/packages/shared/shared/normalize/normalizers/invoice_number_normalizer.py similarity index 100% rename from src/normalize/normalizers/invoice_number_normalizer.py rename to packages/shared/shared/normalize/normalizers/invoice_number_normalizer.py diff --git a/src/normalize/normalizers/ocr_normalizer.py b/packages/shared/shared/normalize/normalizers/ocr_normalizer.py similarity index 100% rename from src/normalize/normalizers/ocr_normalizer.py rename to packages/shared/shared/normalize/normalizers/ocr_normalizer.py diff --git a/src/normalize/normalizers/organisation_number_normalizer.py b/packages/shared/shared/normalize/normalizers/organisation_number_normalizer.py similarity index 92% rename from src/normalize/normalizers/organisation_number_normalizer.py rename to packages/shared/shared/normalize/normalizers/organisation_number_normalizer.py index 3a4c003..778b525 100644 --- a/src/normalize/normalizers/organisation_number_normalizer.py +++ b/packages/shared/shared/normalize/normalizers/organisation_number_normalizer.py @@ -5,8 +5,8 @@ Normalizes Swedish organisation numbers and VAT numbers. """ from .base import BaseNormalizer -from src.utils.format_variants import FormatVariants -from src.utils.text_cleaner import TextCleaner +from shared.utils.format_variants import FormatVariants +from shared.utils.text_cleaner import TextCleaner class OrganisationNumberNormalizer(BaseNormalizer): diff --git a/src/normalize/normalizers/plusgiro_normalizer.py b/packages/shared/shared/normalize/normalizers/plusgiro_normalizer.py similarity index 89% rename from src/normalize/normalizers/plusgiro_normalizer.py rename to packages/shared/shared/normalize/normalizers/plusgiro_normalizer.py index ec4f788..7939a3b 100644 --- a/src/normalize/normalizers/plusgiro_normalizer.py +++ b/packages/shared/shared/normalize/normalizers/plusgiro_normalizer.py @@ -5,8 +5,8 @@ Normalizes Swedish Plusgiro account numbers. 
""" from .base import BaseNormalizer -from src.utils.format_variants import FormatVariants -from src.utils.text_cleaner import TextCleaner +from shared.utils.format_variants import FormatVariants +from shared.utils.text_cleaner import TextCleaner class PlusgiroNormalizer(BaseNormalizer): diff --git a/src/normalize/normalizers/supplier_accounts_normalizer.py b/packages/shared/shared/normalize/normalizers/supplier_accounts_normalizer.py similarity index 100% rename from src/normalize/normalizers/supplier_accounts_normalizer.py rename to packages/shared/shared/normalize/normalizers/supplier_accounts_normalizer.py diff --git a/src/ocr/__init__.py b/packages/shared/shared/ocr/__init__.py similarity index 100% rename from src/ocr/__init__.py rename to packages/shared/shared/ocr/__init__.py diff --git a/src/ocr/machine_code_parser.py b/packages/shared/shared/ocr/machine_code_parser.py similarity index 99% rename from src/ocr/machine_code_parser.py rename to packages/shared/shared/ocr/machine_code_parser.py index 06a2755..b008e2b 100644 --- a/src/ocr/machine_code_parser.py +++ b/packages/shared/shared/ocr/machine_code_parser.py @@ -41,8 +41,8 @@ import re from dataclasses import dataclass, field from typing import Optional -from src.pdf.extractor import Token as TextToken -from src.utils.validators import FieldValidators +from shared.pdf.extractor import Token as TextToken +from shared.utils.validators import FieldValidators @dataclass @@ -848,7 +848,7 @@ class MachineCodeParser: ... } """ - from src.normalize import normalize_field + from shared.normalize import normalize_field results = {} diff --git a/src/ocr/paddle_ocr.py b/packages/shared/shared/ocr/paddle_ocr.py similarity index 100% rename from src/ocr/paddle_ocr.py rename to packages/shared/shared/ocr/paddle_ocr.py diff --git a/src/pdf/__init__.py b/packages/shared/shared/pdf/__init__.py similarity index 100% rename from src/pdf/__init__.py rename to packages/shared/shared/pdf/__init__.py diff --git a/src/pdf/detector.py b/packages/shared/shared/pdf/detector.py similarity index 100% rename from src/pdf/detector.py rename to packages/shared/shared/pdf/detector.py diff --git a/src/pdf/extractor.py b/packages/shared/shared/pdf/extractor.py similarity index 100% rename from src/pdf/extractor.py rename to packages/shared/shared/pdf/extractor.py diff --git a/src/pdf/renderer.py b/packages/shared/shared/pdf/renderer.py similarity index 100% rename from src/pdf/renderer.py rename to packages/shared/shared/pdf/renderer.py diff --git a/src/utils/__init__.py b/packages/shared/shared/utils/__init__.py similarity index 100% rename from src/utils/__init__.py rename to packages/shared/shared/utils/__init__.py diff --git a/src/utils/context_extractor.py b/packages/shared/shared/utils/context_extractor.py similarity index 100% rename from src/utils/context_extractor.py rename to packages/shared/shared/utils/context_extractor.py diff --git a/src/utils/format_variants.py b/packages/shared/shared/utils/format_variants.py similarity index 100% rename from src/utils/format_variants.py rename to packages/shared/shared/utils/format_variants.py diff --git a/src/utils/fuzzy_matcher.py b/packages/shared/shared/utils/fuzzy_matcher.py similarity index 100% rename from src/utils/fuzzy_matcher.py rename to packages/shared/shared/utils/fuzzy_matcher.py diff --git a/src/utils/ocr_corrections.py b/packages/shared/shared/utils/ocr_corrections.py similarity index 100% rename from src/utils/ocr_corrections.py rename to packages/shared/shared/utils/ocr_corrections.py diff 
--git a/src/utils/text_cleaner.py b/packages/shared/shared/utils/text_cleaner.py similarity index 100% rename from src/utils/text_cleaner.py rename to packages/shared/shared/utils/text_cleaner.py diff --git a/src/utils/validators.py b/packages/shared/shared/utils/validators.py similarity index 100% rename from src/utils/validators.py rename to packages/shared/shared/utils/validators.py diff --git a/packages/training/Dockerfile b/packages/training/Dockerfile new file mode 100644 index 0000000..2be3d62 --- /dev/null +++ b/packages/training/Dockerfile @@ -0,0 +1,20 @@ +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install shared package +COPY packages/shared /app/packages/shared +RUN pip install --no-cache-dir -e /app/packages/shared + +# Install training package +COPY packages/training /app/packages/training +RUN pip install --no-cache-dir -e /app/packages/training + +WORKDIR /app/packages/training + +CMD ["sh", "-c", "python run_training.py --task-id ${TASK_ID}"] diff --git a/packages/training/requirements.txt b/packages/training/requirements.txt new file mode 100644 index 0000000..248445f --- /dev/null +++ b/packages/training/requirements.txt @@ -0,0 +1,4 @@ +-e ../shared +ultralytics>=8.1.0 +tqdm>=4.65.0 +torch>=2.0.0 diff --git a/packages/training/run_training.py b/packages/training/run_training.py new file mode 100644 index 0000000..73aad98 --- /dev/null +++ b/packages/training/run_training.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +""" +Training Service Entry Point. + +Runs a specific training task by ID (for Azure ACI on-demand mode) +or polls the database for pending tasks (for local dev).
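+
+Usage (flags defined in main() below):
+    python run_training.py --task-id <TASK_UUID>
+    python run_training.py --poll --poll-interval 30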
+""" + +import argparse +import logging +import sys +import time + +from training.data.training_db import TrainingTaskDB + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + + +def execute_training_task(db: TrainingTaskDB, task: dict) -> None: + """Execute a single training task.""" + task_id = task["task_id"] + config = task.get("config") or {} + + logger.info("Starting training task %s with config: %s", task_id, config) + db.update_status(task_id, "running") + + try: + from training.cli.train import run_training + + result = run_training( + epochs=config.get("epochs", 100), + batch=config.get("batch_size", 16), + model=config.get("base_model", "yolo11n.pt"), + imgsz=config.get("imgsz", 1280), + name=config.get("name", f"training_{task_id[:8]}"), + ) + + db.complete_task( + task_id, + model_path=result.get("model_path", ""), + metrics=result.get("metrics", {}), + ) + logger.info("Training task %s completed successfully.", task_id) + + except Exception as e: + logger.exception("Training task %s failed", task_id) + db.fail_task(task_id, str(e)) + sys.exit(1) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Invoice Training Service") + parser.add_argument( + "--task-id", + help="Specific task ID to run (ACI on-demand mode)", + ) + parser.add_argument( + "--poll", + action="store_true", + help="Poll database for pending tasks (local dev mode)", + ) + parser.add_argument( + "--poll-interval", + type=int, + default=60, + help="Seconds between polls (default: 60)", + ) + args = parser.parse_args() + + db = TrainingTaskDB() + + if args.task_id: + task = db.get_task(args.task_id) + if not task: + logger.error("Task %s not found", args.task_id) + sys.exit(1) + execute_training_task(db, task) + + elif args.poll: + logger.info( + "Starting training service in poll mode (interval=%ds)", + args.poll_interval, + ) + while True: + tasks = db.get_pending_tasks(limit=1) + for task in tasks: + execute_training_task(db, task) + time.sleep(args.poll_interval) + + else: + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/packages/training/setup.py b/packages/training/setup.py new file mode 100644 index 0000000..56125c9 --- /dev/null +++ b/packages/training/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name="invoice-training", + version="0.1.0", + packages=find_packages(), + python_requires=">=3.11", + install_requires=[ + "invoice-shared", + "ultralytics>=8.1.0", + "tqdm>=4.65.0", + ], +) diff --git a/packages/training/training/__init__.py b/packages/training/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/training/training/cli/__init__.py b/packages/training/training/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cli/analyze_labels.py b/packages/training/training/cli/analyze_labels.py similarity index 98% rename from src/cli/analyze_labels.py rename to packages/training/training/cli/analyze_labels.py index 55df436..c8e9b0f 100644 --- a/src/cli/analyze_labels.py +++ b/packages/training/training/cli/analyze_labels.py @@ -15,14 +15,13 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Optional -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string +from shared.config import get_db_connection_string -from ..normalize import normalize_field -from ..matcher 
import FieldMatcher -from ..pdf import is_text_pdf, extract_text_tokens -from ..yolo.annotation_generator import FIELD_CLASSES -from ..data.db import DocumentDB +from shared.normalize import normalize_field +from shared.matcher import FieldMatcher +from shared.pdf import is_text_pdf, extract_text_tokens +from training.yolo.annotation_generator import FIELD_CLASSES +from shared.data.db import DocumentDB @dataclass diff --git a/src/cli/analyze_report.py b/packages/training/training/cli/analyze_report.py similarity index 99% rename from src/cli/analyze_report.py rename to packages/training/training/cli/analyze_report.py index 366b1e4..d739e98 100644 --- a/src/cli/analyze_report.py +++ b/packages/training/training/cli/analyze_report.py @@ -11,13 +11,12 @@ import sys from collections import defaultdict from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string +from shared.config import get_db_connection_string def load_reports_from_db() -> dict: """Load statistics directly from database using SQL aggregation.""" - from ..data.db import DocumentDB + from shared.data.db import DocumentDB db = DocumentDB() db.connect() diff --git a/src/cli/autolabel.py b/packages/training/training/cli/autolabel.py similarity index 96% rename from src/cli/autolabel.py rename to packages/training/training/cli/autolabel.py index 7c10391..09791a4 100644 --- a/src/cli/autolabel.py +++ b/packages/training/training/cli/autolabel.py @@ -33,8 +33,7 @@ def _signal_handler(signum, frame): if sys.platform == 'win32': multiprocessing.set_start_method('spawn', force=True) -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string, PATHS, AUTOLABEL +from shared.config import get_db_connection_string, PATHS, AUTOLABEL # Global OCR engine for worker processes (initialized once per worker) _worker_ocr_engine = None @@ -81,7 +80,7 @@ def _get_ocr_engine(): # Suppress warnings during OCR initialization with warnings.catch_warnings(): warnings.filterwarnings('ignore') - from ..ocr import OCREngine + from shared.ocr import OCREngine _worker_ocr_engine = OCREngine() return _worker_ocr_engine @@ -112,10 +111,10 @@ def process_single_document(args_tuple): row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple # Import inside worker to avoid pickling issues - from ..data import AutoLabelReport - from ..pdf import PDFDocument - from ..yolo.annotation_generator import FIELD_CLASSES - from ..processing.document_processor import process_page, record_unmatched_fields + from training.data.autolabel_report import AutoLabelReport + from shared.pdf import PDFDocument + from training.yolo.annotation_generator import FIELD_CLASSES + from training.processing.document_processor import process_page, record_unmatched_fields start_time = time.time() pdf_path = Path(pdf_path_str) @@ -336,14 +335,14 @@ def main(): signal.signal(signal.SIGTERM, _signal_handler) # Import here to avoid slow startup - from ..data import CSVLoader, AutoLabelReport, FieldMatchResult - from ..data.autolabel_report import ReportWriter - from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens - from ..pdf.renderer import get_render_dimensions - from ..ocr import OCREngine - from ..matcher import FieldMatcher - from ..normalize import normalize_field - from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES + from shared.data import CSVLoader + from training.data.autolabel_report import 
AutoLabelReport, FieldMatchResult, ReportWriter + from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens + from shared.pdf.renderer import get_render_dimensions + from shared.ocr import OCREngine + from shared.matcher import FieldMatcher + from shared.normalize import normalize_field + from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES # Handle comma-separated CSV paths csv_input = args.csv @@ -367,7 +366,7 @@ def main(): report_writer = ReportWriter(args.report, max_records_per_file=args.max_records) # Database connection for checking existing documents - from ..data.db import DocumentDB + from shared.data.db import DocumentDB db = DocumentDB() db.connect() db.create_tables() # Ensure tables exist @@ -450,8 +449,8 @@ def main(): use_dual_pool = args.cpu_workers is not None if use_dual_pool: - from src.processing import DualPoolCoordinator - from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf + from training.processing import DualPoolCoordinator + from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers") dual_pool_coordinator = DualPoolCoordinator( diff --git a/src/cli/import_report_to_db.py b/packages/training/training/cli/import_report_to_db.py similarity index 99% rename from src/cli/import_report_to_db.py rename to packages/training/training/cli/import_report_to_db.py index 1cb058b..49a0bec 100644 --- a/src/cli/import_report_to_db.py +++ b/packages/training/training/cli/import_report_to_db.py @@ -15,8 +15,7 @@ import psycopg2 from psycopg2.extras import execute_values # Add project root to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import get_db_connection_string, PATHS +from shared.config import get_db_connection_string, PATHS def create_tables(conn): diff --git a/src/cli/reprocess_failed.py b/packages/training/training/cli/reprocess_failed.py similarity index 98% rename from src/cli/reprocess_failed.py rename to packages/training/training/cli/reprocess_failed.py index e551317..03576fd 100644 --- a/src/cli/reprocess_failed.py +++ b/packages/training/training/cli/reprocess_failed.py @@ -15,12 +15,11 @@ from datetime import datetime from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError from tqdm import tqdm -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import DEFAULT_DPI -from src.data.db import DocumentDB -from src.data.csv_loader import CSVLoader -from src.normalize.normalizer import normalize_field +from shared.config import DEFAULT_DPI +from shared.data.db import DocumentDB +from shared.data.csv_loader import CSVLoader +from shared.normalize.normalizer import normalize_field def create_failed_match_table(db: DocumentDB): @@ -131,8 +130,8 @@ def process_single_document(args): # Try to extract OCR from PDF try: if pdf_path and os.path.exists(pdf_path): - from src.pdf import PDFDocument - from src.ocr import OCREngine + from shared.pdf import PDFDocument + from shared.ocr import OCREngine pdf_doc = PDFDocument(pdf_path) is_scanned = pdf_doc.detect_type() == "scanned" diff --git a/src/cli/train.py b/packages/training/training/cli/train.py similarity index 96% rename from src/cli/train.py rename to packages/training/training/cli/train.py index afb4ba0..ca64863 100644 --- a/src/cli/train.py +++ b/packages/training/training/cli/train.py @@ -10,8 +10,7 @@ import argparse import sys from 
pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from src.config import DEFAULT_DPI, PATHS +from shared.config import DEFAULT_DPI, PATHS def main(): @@ -151,14 +150,14 @@ def main(): print(f"Document limit: {args.limit}") # Connect to database - from ..data.db import DocumentDB + from shared.data.db import DocumentDB print("\nConnecting to database...") db = DocumentDB() db.connect() # Create datasets from database - from ..yolo.db_dataset import create_datasets + from training.yolo.db_dataset import create_datasets print("Loading dataset from database...") datasets = create_datasets( @@ -189,7 +188,7 @@ def main(): print(f" {split_name}: {count} items exported") # Generate YOLO config files - from ..yolo.annotation_generator import AnnotationGenerator + from training.yolo.annotation_generator import AnnotationGenerator AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt') AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml') diff --git a/src/cli/validate.py b/packages/training/training/cli/validate.py similarity index 99% rename from src/cli/validate.py rename to packages/training/training/cli/validate.py index ce3183f..4a5c0de 100644 --- a/src/cli/validate.py +++ b/packages/training/training/cli/validate.py @@ -10,7 +10,6 @@ import argparse import sys from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) def main(): @@ -74,7 +73,7 @@ def main(): parser.print_help() return - from src.validation import LLMValidator + from inference.validation import LLMValidator validator = LLMValidator() validator.connect() diff --git a/packages/training/training/data/__init__.py b/packages/training/training/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/autolabel_report.py b/packages/training/training/data/autolabel_report.py similarity index 100% rename from src/data/autolabel_report.py rename to packages/training/training/data/autolabel_report.py diff --git a/packages/training/training/data/training_db.py b/packages/training/training/data/training_db.py new file mode 100644 index 0000000..08100a3 --- /dev/null +++ b/packages/training/training/data/training_db.py @@ -0,0 +1,134 @@ +"""Database operations for training tasks.""" + +import json +import logging +from datetime import datetime, timezone + +import psycopg2 +import psycopg2.extras + +from shared.config import get_db_connection_string + +logger = logging.getLogger(__name__) + + +class TrainingTaskDB: + """Read/write training_tasks table.""" + + def _connect(self): + return psycopg2.connect(get_db_connection_string()) + + def get_task(self, task_id: str) -> dict | None: + """Get a single training task by ID.""" + conn = self._connect() + try: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + "SELECT * FROM training_tasks WHERE task_id = %s", + (task_id,), + ) + return cur.fetchone() + finally: + conn.close() + + def get_pending_tasks(self, limit: int = 1) -> list[dict]: + """Get pending tasks ordered by creation time.""" + conn = self._connect() + try: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + """ + SELECT * FROM training_tasks + WHERE status = 'pending' + ORDER BY created_at ASC + LIMIT %s + """, + (limit,), + ) + return cur.fetchall() + finally: + conn.close() + + def update_status(self, task_id: str, status: str) -> None: + """Update task status with timestamp.""" + conn = self._connect() + try: + with conn.cursor() as 
cur: + if status == "running": + cur.execute( + "UPDATE training_tasks SET status = %s, started_at = %s WHERE task_id = %s", + (status, datetime.now(timezone.utc), task_id), + ) + else: + cur.execute( + "UPDATE training_tasks SET status = %s WHERE task_id = %s", + (status, task_id), + ) + conn.commit() + finally: + conn.close() + + def complete_task( + self, task_id: str, model_path: str, metrics: dict + ) -> None: + """Mark task as completed with results.""" + conn = self._connect() + try: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE training_tasks + SET status = 'completed', + completed_at = %s, + model_path = %s, + metrics = %s + WHERE task_id = %s + """, + ( + datetime.now(timezone.utc), + model_path, + json.dumps(metrics), + task_id, + ), + ) + conn.commit() + finally: + conn.close() + + def fail_task(self, task_id: str, error_message: str) -> None: + """Mark task as failed.""" + conn = self._connect() + try: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE training_tasks + SET status = 'failed', + completed_at = %s, + error_message = %s + WHERE task_id = %s + """, + (datetime.now(timezone.utc), error_message[:2000], task_id), + ) + conn.commit() + finally: + conn.close() + + def create_task(self, config: dict) -> str: + """Create a new training task. Returns task_id.""" + conn = self._connect() + try: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO training_tasks (config) + VALUES (%s) + RETURNING task_id + """, + (json.dumps(config),), + ) + task_id = str(cur.fetchone()[0]) + conn.commit() + return task_id + finally: + conn.close() diff --git a/src/processing/__init__.py b/packages/training/training/processing/__init__.py similarity index 51% rename from src/processing/__init__.py rename to packages/training/training/processing/__init__.py index 3604260..3c6186f 100644 --- a/src/processing/__init__.py +++ b/packages/training/training/processing/__init__.py @@ -5,11 +5,11 @@ This module provides a robust dual-pool architecture for processing documents with both CPU-bound and GPU-bound tasks. 
""" -from src.processing.worker_pool import WorkerPool, TaskResult -from src.processing.cpu_pool import CPUWorkerPool -from src.processing.gpu_pool import GPUWorkerPool -from src.processing.task_dispatcher import TaskDispatcher, TaskType -from src.processing.dual_pool_coordinator import DualPoolCoordinator +from training.processing.worker_pool import WorkerPool, TaskResult +from training.processing.cpu_pool import CPUWorkerPool +from training.processing.gpu_pool import GPUWorkerPool +from training.processing.task_dispatcher import TaskDispatcher, TaskType +from training.processing.dual_pool_coordinator import DualPoolCoordinator __all__ = [ "WorkerPool", diff --git a/src/processing/autolabel_tasks.py b/packages/training/training/processing/autolabel_tasks.py similarity index 94% rename from src/processing/autolabel_tasks.py rename to packages/training/training/processing/autolabel_tasks.py index bdd8855..df012ab 100644 --- a/src/processing/autolabel_tasks.py +++ b/packages/training/training/processing/autolabel_tasks.py @@ -12,7 +12,7 @@ import warnings from pathlib import Path from typing import Any, Dict, Optional -from src.config import DEFAULT_DPI +from shared.config import DEFAULT_DPI # Global OCR instance (initialized once per GPU worker process) _ocr_engine: Optional[Any] = None @@ -57,7 +57,7 @@ def _get_ocr_engine(): if _ocr_engine is None: with warnings.catch_warnings(): warnings.filterwarnings("ignore") - from src.ocr import OCREngine + from shared.ocr import OCREngine _ocr_engine = OCREngine() return _ocr_engine @@ -88,10 +88,10 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: Result dictionary with success status, annotations, and report. """ import shutil - from src.data import AutoLabelReport - from src.pdf import PDFDocument - from src.yolo.annotation_generator import FIELD_CLASSES - from src.processing.document_processor import process_page, record_unmatched_fields + from training.data.autolabel_report import AutoLabelReport + from shared.pdf import PDFDocument + from training.yolo.annotation_generator import FIELD_CLASSES + from training.processing.document_processor import process_page, record_unmatched_fields row_dict = task_data["row_dict"] pdf_path = Path(task_data["pdf_path"]) @@ -206,10 +206,10 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: Result dictionary with success status, annotations, and report. 
""" import shutil - from src.data import AutoLabelReport - from src.pdf import PDFDocument - from src.yolo.annotation_generator import FIELD_CLASSES - from src.processing.document_processor import process_page, record_unmatched_fields + from training.data.autolabel_report import AutoLabelReport + from shared.pdf import PDFDocument + from training.yolo.annotation_generator import FIELD_CLASSES + from training.processing.document_processor import process_page, record_unmatched_fields row_dict = task_data["row_dict"] pdf_path = Path(task_data["pdf_path"]) diff --git a/src/processing/cpu_pool.py b/packages/training/training/processing/cpu_pool.py similarity index 97% rename from src/processing/cpu_pool.py rename to packages/training/training/processing/cpu_pool.py index bc96176..eed5e90 100644 --- a/src/processing/cpu_pool.py +++ b/packages/training/training/processing/cpu_pool.py @@ -11,7 +11,7 @@ import logging import os from typing import Callable, Optional -from src.processing.worker_pool import WorkerPool +from training.processing.worker_pool import WorkerPool logger = logging.getLogger(__name__) diff --git a/src/processing/document_processor.py b/packages/training/training/processing/document_processor.py similarity index 97% rename from src/processing/document_processor.py rename to packages/training/training/processing/document_processor.py index 7462449..fb099a9 100644 --- a/src/processing/document_processor.py +++ b/packages/training/training/processing/document_processor.py @@ -11,11 +11,11 @@ from __future__ import annotations from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple -from ..data import FieldMatchResult -from ..matcher import FieldMatcher -from ..normalize import normalize_field -from ..ocr.machine_code_parser import MachineCodeParser -from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES +from training.data.autolabel_report import FieldMatchResult +from shared.matcher import FieldMatcher +from shared.normalize import normalize_field +from shared.ocr.machine_code_parser import MachineCodeParser +from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES def match_supplier_accounts( @@ -286,7 +286,7 @@ def match_standard_fields( # Fallback: Amount not found via token matching, but payment_line # successfully extracted a matching amount. Use payment_line bbox. # This handles cases where text PDFs merge multiple values into one token. 
- from src.matcher.field_matcher import Match + from shared.matcher.field_matcher import Match fallback_match = Match( field='Amount', diff --git a/src/processing/dual_pool_coordinator.py b/packages/training/training/processing/dual_pool_coordinator.py similarity index 97% rename from src/processing/dual_pool_coordinator.py rename to packages/training/training/processing/dual_pool_coordinator.py index 406c3f8..9a9326b 100644 --- a/src/processing/dual_pool_coordinator.py +++ b/packages/training/training/processing/dual_pool_coordinator.py @@ -13,10 +13,10 @@ from concurrent.futures import Future, TimeoutError, as_completed from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional -from src.processing.cpu_pool import CPUWorkerPool -from src.processing.gpu_pool import GPUWorkerPool -from src.processing.task_dispatcher import Task, TaskDispatcher, TaskType -from src.processing.worker_pool import TaskResult +from training.processing.cpu_pool import CPUWorkerPool +from training.processing.gpu_pool import GPUWorkerPool +from training.processing.task_dispatcher import Task, TaskDispatcher, TaskType +from training.processing.worker_pool import TaskResult logger = logging.getLogger(__name__) diff --git a/src/processing/gpu_pool.py b/packages/training/training/processing/gpu_pool.py similarity index 98% rename from src/processing/gpu_pool.py rename to packages/training/training/processing/gpu_pool.py index f9c0f58..7f4286e 100644 --- a/src/processing/gpu_pool.py +++ b/packages/training/training/processing/gpu_pool.py @@ -10,7 +10,7 @@ import logging import os from typing import Any, Callable, Optional -from src.processing.worker_pool import WorkerPool +from training.processing.worker_pool import WorkerPool logger = logging.getLogger(__name__) diff --git a/src/processing/task_dispatcher.py b/packages/training/training/processing/task_dispatcher.py similarity index 100% rename from src/processing/task_dispatcher.py rename to packages/training/training/processing/task_dispatcher.py diff --git a/src/processing/worker_pool.py b/packages/training/training/processing/worker_pool.py similarity index 100% rename from src/processing/worker_pool.py rename to packages/training/training/processing/worker_pool.py diff --git a/src/yolo/__init__.py b/packages/training/training/yolo/__init__.py similarity index 100% rename from src/yolo/__init__.py rename to packages/training/training/yolo/__init__.py diff --git a/src/yolo/annotation_generator.py b/packages/training/training/yolo/annotation_generator.py similarity index 97% rename from src/yolo/annotation_generator.py rename to packages/training/training/yolo/annotation_generator.py index 69072ac..9a95a86 100644 --- a/src/yolo/annotation_generator.py +++ b/packages/training/training/yolo/annotation_generator.py @@ -328,11 +328,11 @@ def generate_annotations( Returns: List of paths to generated annotation files """ - from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens - from ..pdf.renderer import get_render_dimensions - from ..ocr import OCREngine - from ..matcher import FieldMatcher - from ..normalize import normalize_field + from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens + from shared.pdf.renderer import get_render_dimensions + from shared.ocr import OCREngine + from shared.matcher import FieldMatcher + from shared.normalize import normalize_field output_dir = Path(output_dir) images_dir = output_dir / 'images' diff --git a/src/yolo/dataset_builder.py 
b/packages/training/training/yolo/dataset_builder.py similarity index 100% rename from src/yolo/dataset_builder.py rename to packages/training/training/yolo/dataset_builder.py diff --git a/src/yolo/db_dataset.py b/packages/training/training/yolo/db_dataset.py similarity index 99% rename from src/yolo/db_dataset.py rename to packages/training/training/yolo/db_dataset.py index 4aafbae..dc0f5be 100644 --- a/src/yolo/db_dataset.py +++ b/packages/training/training/yolo/db_dataset.py @@ -17,7 +17,7 @@ from typing import Any, Optional import numpy as np from PIL import Image -from src.config import DEFAULT_DPI +from shared.config import DEFAULT_DPI from .annotation_generator import FIELD_CLASSES, YOLOAnnotation logger = logging.getLogger(__name__) diff --git a/pyproject.toml b/pyproject.toml index 2165d00..fe13e50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,4 +75,4 @@ disallow_untyped_defs = true [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] -addopts = "-v --cov=src --cov-report=term-missing" +addopts = "-v --cov=packages --cov-report=term-missing" diff --git a/run_autolabel.py b/run_autolabel.py index 40cc97a..d98b2fd 100644 --- a/run_autolabel.py +++ b/run_autolabel.py @@ -4,7 +4,7 @@ 在 WSL 中运行: python run_autolabel.py """ -from src.cli.autolabel import main +from training.cli.autolabel import main if __name__ == '__main__': main() diff --git a/run_server.py b/run_server.py index a77a461..09ef573 100644 --- a/run_server.py +++ b/run_server.py @@ -8,7 +8,7 @@ Usage: python run_server.py --debug --reload """ -from src.cli.serve import main +from inference.cli.serve import main if __name__ == "__main__": main() diff --git a/src/cli/__init__.py b/src/cli/__init__.py deleted file mode 100644 index a24f998..0000000 --- a/src/cli/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# CLI modules for Invoice Master diff --git a/src/data/__init__.py b/src/data/__init__.py deleted file mode 100644 index 454510e..0000000 --- a/src/data/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .csv_loader import CSVLoader, InvoiceRow -from .autolabel_report import AutoLabelReport, FieldMatchResult - -__all__ = ['CSVLoader', 'InvoiceRow', 'AutoLabelReport', 'FieldMatchResult'] diff --git a/src/data/migrations/001_async_tables.sql b/src/data/migrations/001_async_tables.sql deleted file mode 100644 index fb3e267..0000000 --- a/src/data/migrations/001_async_tables.sql +++ /dev/null @@ -1,83 +0,0 @@ --- Async Invoice Processing Tables --- Migration: 001_async_tables.sql --- Created: 2024-01-15 - --- API Keys table for authentication and rate limiting -CREATE TABLE IF NOT EXISTS api_keys ( - api_key TEXT PRIMARY KEY, - name TEXT NOT NULL, - is_active BOOLEAN DEFAULT true, - - -- Rate limits - requests_per_minute INTEGER DEFAULT 10, - max_concurrent_jobs INTEGER DEFAULT 3, - max_file_size_mb INTEGER DEFAULT 50, - - -- Usage tracking - total_requests INTEGER DEFAULT 0, - total_processed INTEGER DEFAULT 0, - - -- Timestamps - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - last_used_at TIMESTAMPTZ -); - --- Async processing requests table -CREATE TABLE IF NOT EXISTS async_requests ( - request_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE, - status TEXT NOT NULL DEFAULT 'pending', - filename TEXT NOT NULL, - file_size INTEGER NOT NULL, - content_type TEXT NOT NULL, - - -- Processing metadata - document_id TEXT, - error_message TEXT, - retry_count INTEGER DEFAULT 0, - - -- Timestamps - created_at TIMESTAMPTZ NOT NULL DEFAULT 
NOW(), - started_at TIMESTAMPTZ, - completed_at TIMESTAMPTZ, - expires_at TIMESTAMPTZ NOT NULL, - - -- Result storage (JSONB for flexibility) - result JSONB, - - -- Processing time - processing_time_ms REAL, - - -- Visualization path - visualization_path TEXT, - - CONSTRAINT valid_status CHECK (status IN ('pending', 'processing', 'completed', 'failed')) -); - --- Indexes for async_requests -CREATE INDEX IF NOT EXISTS idx_async_requests_api_key ON async_requests(api_key); -CREATE INDEX IF NOT EXISTS idx_async_requests_status ON async_requests(status); -CREATE INDEX IF NOT EXISTS idx_async_requests_created_at ON async_requests(created_at); -CREATE INDEX IF NOT EXISTS idx_async_requests_expires_at ON async_requests(expires_at); -CREATE INDEX IF NOT EXISTS idx_async_requests_api_key_status ON async_requests(api_key, status); - --- Rate limit tracking table -CREATE TABLE IF NOT EXISTS rate_limit_events ( - id SERIAL PRIMARY KEY, - api_key TEXT NOT NULL REFERENCES api_keys(api_key) ON DELETE CASCADE, - event_type TEXT NOT NULL, -- 'request', 'complete', 'fail' - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Index for rate limiting queries (recent events only) -CREATE INDEX IF NOT EXISTS idx_rate_limit_events_api_key_time -ON rate_limit_events(api_key, created_at DESC); - --- Cleanup old rate limit events index -CREATE INDEX IF NOT EXISTS idx_rate_limit_events_cleanup -ON rate_limit_events(created_at); - --- Insert default API key for development/testing -INSERT INTO api_keys (api_key, name, requests_per_minute, max_concurrent_jobs) -VALUES ('dev-api-key-12345', 'Development Key', 100, 10) -ON CONFLICT (api_key) DO NOTHING; diff --git a/src/data/migrations/002_nullable_admin_token.sql b/src/data/migrations/002_nullable_admin_token.sql deleted file mode 100644 index 29c406c..0000000 --- a/src/data/migrations/002_nullable_admin_token.sql +++ /dev/null @@ -1,5 +0,0 @@ --- Migration: Make admin_token nullable in admin_documents table --- This allows documents uploaded via public API to not require an admin token - -ALTER TABLE admin_documents -ALTER COLUMN admin_token DROP NOT NULL; diff --git a/src/matcher/README.md b/src/matcher/README.md deleted file mode 100644 index efa81c2..0000000 --- a/src/matcher/README.md +++ /dev/null @@ -1,358 +0,0 @@ -# Matcher Module - 字段匹配模块 - -将标准化后的字段值与PDF文档中的tokens进行匹配,返回字段在文档中的位置(bbox),用于生成YOLO训练标注。 - -## 📁 模块结构 - -``` -src/matcher/ -├── __init__.py # 导出主要接口 -├── field_matcher.py # 主类 (205行, 从876行简化) -├── models.py # 数据模型 -├── token_index.py # 空间索引 -├── context.py # 上下文关键词 -├── utils.py # 工具函数 -└── strategies/ # 匹配策略 - ├── __init__.py - ├── base.py # 基础策略类 - ├── exact_matcher.py # 精确匹配 - ├── concatenated_matcher.py # 多token拼接匹配 - ├── substring_matcher.py # 子串匹配 - ├── fuzzy_matcher.py # 模糊匹配 (金额) - └── flexible_date_matcher.py # 灵活日期匹配 -``` - -## 🎯 核心功能 - -### FieldMatcher - 字段匹配器 - -主类,协调各个匹配策略: - -```python -from src.matcher import FieldMatcher - -matcher = FieldMatcher( - context_radius=200.0, # 上下文关键词搜索半径(像素) - min_score_threshold=0.5 # 最低匹配分数 -) - -# 匹配字段 -matches = matcher.find_matches( - tokens=tokens, # PDF提取的tokens - field_name="InvoiceNumber", # 字段名 - normalized_values=["100017500321", "INV-100017500321"], # 标准化变体 - page_no=0 # 页码 -) - -# matches: List[Match] -for match in matches: - print(f"Field: {match.field}") - print(f"Value: {match.value}") - print(f"BBox: {match.bbox}") - print(f"Score: {match.score}") - print(f"Context: {match.context_keywords}") -``` - -### 5种匹配策略 - -#### 1. 
ExactMatcher - 精确匹配 -```python -from src.matcher.strategies import ExactMatcher - -matcher = ExactMatcher(context_radius=200.0) -matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber") -``` - -匹配规则: -- 完全匹配: score = 1.0 -- 大小写不敏感: score = 0.95 -- 纯数字匹配: score = 0.9 -- 上下文关键词加分: +0.1/keyword (最多+0.25) - -#### 2. ConcatenatedMatcher - 拼接匹配 -```python -from src.matcher.strategies import ConcatenatedMatcher - -matcher = ConcatenatedMatcher() -matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber") -``` - -用于处理OCR将单个值拆成多个token的情况。 - -#### 3. SubstringMatcher - 子串匹配 -```python -from src.matcher.strategies import SubstringMatcher - -matcher = SubstringMatcher() -matches = matcher.find_matches(tokens, "2026-01-09", "InvoiceDate") -``` - -匹配嵌入在长文本中的字段值: -- `"Fakturadatum: 2026-01-09"` 匹配 `"2026-01-09"` -- `"Fakturanummer: 2465027205"` 匹配 `"2465027205"` - -#### 4. FuzzyMatcher - 模糊匹配 -```python -from src.matcher.strategies import FuzzyMatcher - -matcher = FuzzyMatcher() -matches = matcher.find_matches(tokens, "1234.56", "Amount") -``` - -用于金额字段,允许小数点差异 (±0.01)。 - -#### 5. FlexibleDateMatcher - 灵活日期匹配 -```python -from src.matcher.strategies import FlexibleDateMatcher - -matcher = FlexibleDateMatcher() -matches = matcher.find_matches(tokens, "2025-01-15", "InvoiceDate") -``` - -当精确匹配失败时使用: -- 同年月: score = 0.7-0.8 -- 7天内: score = 0.75+ -- 3天内: score = 0.8+ -- 14天内: score = 0.6 -- 30天内: score = 0.55 - -### 数据模型 - -#### Match - 匹配结果 -```python -from src.matcher.models import Match - -match = Match( - field="InvoiceNumber", - value="100017500321", - bbox=(100.0, 200.0, 300.0, 220.0), - page_no=0, - score=0.95, - matched_text="100017500321", - context_keywords=["fakturanr"] -) - -# 转换为YOLO格式 -yolo_annotation = match.to_yolo_format( - image_width=1200, - image_height=1600, - class_id=0 -) -# "0 0.166667 0.131250 0.166667 0.012500" -``` - -#### TokenIndex - 空间索引 -```python -from src.matcher.token_index import TokenIndex - -# 构建索引 -index = TokenIndex(tokens, grid_size=100.0) - -# 快速查找附近tokens (O(1)平均复杂度) -nearby = index.find_nearby(token, radius=200.0) - -# 获取缓存的中心坐标 -center = index.get_center(token) - -# 获取缓存的小写文本 -text_lower = index.get_text_lower(token) -``` - -### 上下文关键词 - -```python -from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords - -# 查看字段的上下文关键词 -keywords = CONTEXT_KEYWORDS["InvoiceNumber"] -# ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', ...] 
- -# 查找附近的关键词 -found_keywords, boost_score = find_context_keywords( - tokens=tokens, - target_token=token, - field_name="InvoiceNumber", - context_radius=200.0, - token_index=index # 可选,提供则使用O(1)查找 -) -``` - -支持的字段: -- InvoiceNumber -- InvoiceDate -- InvoiceDueDate -- OCR -- Bankgiro -- Plusgiro -- Amount -- supplier_organisation_number -- supplier_accounts - -### 工具函数 - -```python -from src.matcher.utils import ( - normalize_dashes, - parse_amount, - tokens_on_same_line, - bbox_overlap, - DATE_PATTERN, - WHITESPACE_PATTERN, - NON_DIGIT_PATTERN, - DASH_PATTERN, -) - -# 标准化各种破折号 -text = normalize_dashes("123–456") # "123-456" - -# 解析瑞典金额格式 -amount = parse_amount("1 234,56 kr") # 1234.56 -amount = parse_amount("239 00") # 239.00 (öre格式) - -# 检查tokens是否在同一行 -same_line = tokens_on_same_line(token1, token2) - -# 计算bbox重叠度 (IoU) -overlap = bbox_overlap(bbox1, bbox2) # 0.0 - 1.0 -``` - -## 🧪 测试 - -```bash -# 在WSL中运行 -conda activate invoice-py311 - -# 运行所有matcher测试 -pytest tests/matcher/ -v - -# 运行特定策略测试 -pytest tests/matcher/strategies/test_exact_matcher.py -v - -# 查看覆盖率 -pytest tests/matcher/ --cov=src/matcher --cov-report=html -``` - -测试覆盖: -- ✅ 77个测试全部通过 -- ✅ TokenIndex 空间索引 -- ✅ 5种匹配策略 -- ✅ 上下文关键词 -- ✅ 工具函数 -- ✅ 去重逻辑 - -## 📊 重构成果 - -| 指标 | 重构前 | 重构后 | 改进 | -|------|--------|--------|------| -| field_matcher.py | 876行 | 205行 | ↓ 76% | -| 模块数 | 1 | 11 | 更清晰 | -| 最大文件大小 | 876行 | 154行 | 更易读 | -| 测试通过率 | - | 100% | ✅ | - -## 🚀 使用示例 - -### 完整流程 - -```python -from src.matcher import FieldMatcher, find_field_matches - -# 1. 提取PDF tokens (使用PDF模块) -from src.pdf import PDFExtractor -extractor = PDFExtractor("invoice.pdf") -tokens = extractor.extract_tokens() - -# 2. 准备字段值 (从CSV或数据库) -field_values = { - "InvoiceNumber": "100017500321", - "InvoiceDate": "2026-01-09", - "Amount": "1234.56", -} - -# 3. 查找所有字段匹配 -results = find_field_matches(tokens, field_values, page_no=0) - -# 4. 使用结果 -for field_name, matches in results.items(): - if matches: - best_match = matches[0] # 已按score降序排列 - print(f"{field_name}: {best_match.value} @ {best_match.bbox}") - print(f" Score: {best_match.score:.2f}") - print(f" Context: {best_match.context_keywords}") -``` - -### 添加自定义策略 - -```python -from src.matcher.strategies.base import BaseMatchStrategy -from src.matcher.models import Match - -class CustomMatcher(BaseMatchStrategy): - """自定义匹配策略""" - - def find_matches(self, tokens, value, field_name, token_index=None): - matches = [] - # 实现你的匹配逻辑 - for token in tokens: - if self._custom_match_logic(token.text, value): - match = Match( - field=field_name, - value=value, - bbox=token.bbox, - page_no=token.page_no, - score=0.85, - matched_text=token.text, - context_keywords=[] - ) - matches.append(match) - return matches - - def _custom_match_logic(self, token_text, value): - # 你的匹配逻辑 - return True - -# 在FieldMatcher中使用 -from src.matcher import FieldMatcher -matcher = FieldMatcher() -matcher.custom_matcher = CustomMatcher() -``` - -## 🔧 维护指南 - -### 添加新的上下文关键词 - -编辑 [src/matcher/context.py](context.py): - -```python -CONTEXT_KEYWORDS = { - 'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', '新关键词'], - # ... -} -``` - -### 调整匹配分数 - -编辑对应的策略文件: -- [exact_matcher.py](strategies/exact_matcher.py) - 精确匹配分数 -- [fuzzy_matcher.py](strategies/fuzzy_matcher.py) - 模糊匹配容差 -- [flexible_date_matcher.py](strategies/flexible_date_matcher.py) - 日期距离分数 - -### 性能优化 - -1. **TokenIndex网格大小**: 默认100px,可根据实际文档调整 -2. **上下文半径**: 默认200px,可根据扫描DPI调整 -3. 
**去重网格**: 默认50px,影响bbox重叠检测性能 - -## 📚 相关文档 - -- [PDF模块文档](../pdf/README.md) - Token提取 -- [Normalize模块文档](../normalize/README.md) - 字段值标准化 -- [YOLO模块文档](../yolo/README.md) - 标注生成 - -## ✅ 总结 - -这个模块化的matcher系统提供: -- **清晰的职责分离**: 每个策略专注一个匹配方法 -- **易于测试**: 独立测试每个组件 -- **高性能**: O(1)空间索引,智能去重 -- **可扩展**: 轻松添加新策略 -- **完整测试**: 77个测试100%通过 diff --git a/src/normalize/normalizers/README.md b/src/normalize/normalizers/README.md deleted file mode 100644 index ce99f92..0000000 --- a/src/normalize/normalizers/README.md +++ /dev/null @@ -1,225 +0,0 @@ -# Normalizer Modules - -独立的字段标准化模块,用于生成字段值的各种变体以进行匹配。 - -## 架构 - -每个字段类型都有自己的独立 normalizer 模块,便于复用和维护: - -``` -src/normalize/normalizers/ -├── __init__.py # 导出所有 normalizer -├── base.py # BaseNormalizer 基类 -├── invoice_number_normalizer.py # 发票号码 -├── ocr_normalizer.py # OCR 参考号 -├── bankgiro_normalizer.py # Bankgiro 账号 -├── plusgiro_normalizer.py # Plusgiro 账号 -├── amount_normalizer.py # 金额 -├── date_normalizer.py # 日期 -├── organisation_number_normalizer.py # 组织编号 -├── supplier_accounts_normalizer.py # 供应商账号 -└── customer_number_normalizer.py # 客户编号 -``` - -## 使用方法 - -### 方法 1: 通过 FieldNormalizer 门面类 (推荐) - -```python -from src.normalize.normalizer import FieldNormalizer - -# 标准化发票号码 -variants = FieldNormalizer.normalize_invoice_number('INV-100017500321') -# 返回: ['INV-100017500321', '100017500321'] - -# 标准化金额 -variants = FieldNormalizer.normalize_amount('1 234,56') -# 返回: ['1 234,56', '1234,56', '1234.56', ...] - -# 标准化日期 -variants = FieldNormalizer.normalize_date('2025-12-13') -# 返回: ['2025-12-13', '13/12/2025', '13.12.2025', ...] -``` - -### 方法 2: 通过主函数 (自动选择 normalizer) - -```python -from src.normalize import normalize_field - -# 自动选择合适的 normalizer -variants = normalize_field('InvoiceNumber', 'INV-12345') -variants = normalize_field('Amount', '1234.56') -variants = normalize_field('InvoiceDate', '2025-12-13') -``` - -### 方法 3: 直接使用独立 normalizer (最大灵活性) - -```python -from src.normalize.normalizers import ( - InvoiceNumberNormalizer, - AmountNormalizer, - DateNormalizer, -) - -# 实例化 -invoice_normalizer = InvoiceNumberNormalizer() -amount_normalizer = AmountNormalizer() -date_normalizer = DateNormalizer() - -# 使用 -variants = invoice_normalizer.normalize('INV-12345') -variants = amount_normalizer.normalize('1234.56') -variants = date_normalizer.normalize('2025-12-13') - -# 也可以直接调用 (支持 __call__) -variants = invoice_normalizer('INV-12345') -``` - -## 各 Normalizer 功能 - -### InvoiceNumberNormalizer -- 提取纯数字版本 -- 保留原始格式 - -示例: -```python -'INV-100017500321' -> ['INV-100017500321', '100017500321'] -``` - -### OCRNormalizer -- 与 InvoiceNumberNormalizer 类似 -- 专门用于 OCR 参考号 - -### BankgiroNormalizer -- 生成有/无分隔符的格式 -- 添加 OCR 错误变体 - -示例: -```python -'5393-9484' -> ['5393-9484', '53939484', ...] -``` - -### PlusgiroNormalizer -- 生成有/无分隔符的格式 -- 添加 OCR 错误变体 - -示例: -```python -'1234567-8' -> ['1234567-8', '12345678', ...] -``` - -### AmountNormalizer -- 处理瑞典和国际格式 -- 支持不同的千位/小数分隔符 -- 空格作为小数或千位分隔符 - -示例: -```python -'1 234,56' -> ['1234,56', '1234.56', '1 234,56', ...] -'3045 52' -> ['3045.52', '3045,52', '304552'] -``` - -### DateNormalizer -- 转换为 ISO 格式 (YYYY-MM-DD) -- 生成多种日期格式变体 -- 支持瑞典月份名称 -- 处理模糊格式 (DD/MM 和 MM/DD) - -示例: -```python -'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025', ...] -'13 december 2025' -> ['2025-12-13', ...] -``` - -### OrganisationNumberNormalizer -- 标准化瑞典组织编号 -- 生成 VAT 号码变体 -- 添加 OCR 错误变体 - -示例: -```python -'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...] 
-``` - -### SupplierAccountsNormalizer -- 处理多个账号 (用 | 分隔) -- 移除/添加前缀 (PG:, BG:) -- 生成不同格式 - -示例: -```python -'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3', ...] -'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484', ...] -``` - -### CustomerNumberNormalizer -- 移除空格和连字符 -- 生成大小写变体 - -示例: -```python -'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566', ...] -``` - -## BaseNormalizer - -所有 normalizer 继承自 `BaseNormalizer`: - -```python -from src.normalize.normalizers.base import BaseNormalizer - -class MyCustomNormalizer(BaseNormalizer): - def normalize(self, value: str) -> list[str]: - # 实现标准化逻辑 - value = self.clean_text(value) # 使用基类的清理方法 - # ... 生成变体 - return variants -``` - -## 设计原则 - -1. **单一职责**: 每个 normalizer 只负责一种字段类型 -2. **独立复用**: 每个模块可独立导入使用 -3. **一致接口**: 所有 normalizer 实现 `normalize(value) -> list[str]` -4. **向后兼容**: 保持与原 `FieldNormalizer` API 兼容 - -## 测试 - -所有 normalizer 都经过全面测试: - -```bash -# 运行所有测试 -python -m pytest src/normalize/test_normalizer.py -v - -# 85 个测试用例全部通过 ✅ -``` - -## 添加新的 Normalizer - -1. 在 `src/normalize/normalizers/` 创建新文件 `my_field_normalizer.py` -2. 继承 `BaseNormalizer` 并实现 `normalize()` 方法 -3. 在 `__init__.py` 中导出 -4. 在 `normalizer.py` 的 `FieldNormalizer` 中添加静态方法 -5. 在 `NORMALIZERS` 字典中注册 - -示例: - -```python -# my_field_normalizer.py -from .base import BaseNormalizer - -class MyFieldNormalizer(BaseNormalizer): - def normalize(self, value: str) -> list[str]: - value = self.clean_text(value) - # ... 实现逻辑 - return variants -``` - -## 优势 - -- ✅ **模块化**: 每个字段类型独立维护 -- ✅ **可复用**: 可在不同项目中独立使用 -- ✅ **可测试**: 每个模块单独测试 -- ✅ **易扩展**: 添加新字段类型很简单 -- ✅ **向后兼容**: 不影响现有代码 -- ✅ **清晰**: 代码结构更清晰易懂 diff --git a/src/web/admin_routes_new.py b/src/web/admin_routes_new.py deleted file mode 100644 index 1e64889..0000000 --- a/src/web/admin_routes_new.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Backward compatibility shim for admin_routes.py - -DEPRECATED: Import from src.web.api.v1.admin.documents instead. -""" -from src.web.api.v1.admin.documents import * - -__all__ = ["create_admin_router"] diff --git a/src/web/api/v1/admin/__init__.py b/src/web/api/v1/admin/__init__.py deleted file mode 100644 index 95ee920..0000000 --- a/src/web/api/v1/admin/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Admin API v1 - -Document management, annotations, and training endpoints. -""" - -from src.web.api.v1.admin.annotations import create_annotation_router -from src.web.api.v1.admin.auth import create_auth_router -from src.web.api.v1.admin.documents import create_documents_router -from src.web.api.v1.admin.locks import create_locks_router -from src.web.api.v1.admin.training import create_training_router - -__all__ = [ - "create_annotation_router", - "create_auth_router", - "create_documents_router", - "create_locks_router", - "create_training_router", -] diff --git a/src/web/api/v1/public/__init__.py b/src/web/api/v1/public/__init__.py deleted file mode 100644 index 8776b9b..0000000 --- a/src/web/api/v1/public/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Public API v1 - -Customer-facing endpoints for inference, async processing, and labeling. 
-""" - -from src.web.api.v1.public.inference import create_inference_router -from src.web.api.v1.public.async_api import create_async_router, set_async_service -from src.web.api.v1.public.labeling import create_labeling_router - -__all__ = [ - "create_inference_router", - "create_async_router", - "set_async_service", - "create_labeling_router", -] diff --git a/src/web/schemas/__init__.py b/src/web/schemas/__init__.py deleted file mode 100644 index 0cba086..0000000 --- a/src/web/schemas/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -API Schemas - -Pydantic models for request/response validation. -""" - -# Import everything from sub-modules for backward compatibility -from src.web.schemas.common import * # noqa: F401, F403 -from src.web.schemas.admin import * # noqa: F401, F403 -from src.web.schemas.inference import * # noqa: F401, F403 -from src.web.schemas.labeling import * # noqa: F401, F403 diff --git a/src/web/services/__init__.py b/src/web/services/__init__.py deleted file mode 100644 index e20189a..0000000 --- a/src/web/services/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Business Logic Services - -Service layer for processing requests and orchestrating data operations. -""" - -from src.web.services.autolabel import AutoLabelService, get_auto_label_service -from src.web.services.inference import InferenceService -from src.web.services.async_processing import AsyncProcessingService -from src.web.services.batch_upload import BatchUploadService - -__all__ = [ - "AutoLabelService", - "get_auto_label_service", - "InferenceService", - "AsyncProcessingService", - "BatchUploadService", -] diff --git a/tests/data/test_admin_models_v2.py b/tests/data/test_admin_models_v2.py index 4d65205..7593283 100644 --- a/tests/data/test_admin_models_v2.py +++ b/tests/data/test_admin_models_v2.py @@ -9,7 +9,7 @@ import pytest from datetime import datetime from uuid import UUID, uuid4 -from src.data.admin_models import ( +from inference.data.admin_models import ( BatchUpload, BatchUploadFile, TrainingDocumentLink, diff --git a/tests/data/test_csv_loader.py b/tests/data/test_csv_loader.py index 3282751..0ff567b 100644 --- a/tests/data/test_csv_loader.py +++ b/tests/data/test_csv_loader.py @@ -12,7 +12,7 @@ import tempfile from pathlib import Path from datetime import date from decimal import Decimal -from src.data.csv_loader import ( +from shared.data.csv_loader import ( InvoiceRow, CSVLoader, load_invoice_csv, diff --git a/tests/inference/test_field_extractor.py b/tests/inference/test_field_extractor.py index dc0fb60..627f0a0 100644 --- a/tests/inference/test_field_extractor.py +++ b/tests/inference/test_field_extractor.py @@ -11,7 +11,7 @@ Tests field normalization functions: """ import pytest -from src.inference.field_extractor import FieldExtractor +from inference.pipeline.field_extractor import FieldExtractor class TestFieldExtractorInit: diff --git a/tests/inference/test_pipeline.py b/tests/inference/test_pipeline.py index 1617a10..8564b90 100644 --- a/tests/inference/test_pipeline.py +++ b/tests/inference/test_pipeline.py @@ -10,7 +10,7 @@ Tests the cross-validation logic between payment_line and detected fields: import pytest from unittest.mock import MagicMock, patch -from src.inference.pipeline import InferencePipeline, InferenceResult, CrossValidationResult +from inference.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult class TestCrossValidationResult: diff --git a/tests/matcher/strategies/test_exact_matcher.py 
b/tests/matcher/strategies/test_exact_matcher.py index 5ff533d..3e9d7a6 100644 --- a/tests/matcher/strategies/test_exact_matcher.py +++ b/tests/matcher/strategies/test_exact_matcher.py @@ -7,7 +7,7 @@ Usage: import pytest from dataclasses import dataclass -from src.matcher.strategies.exact_matcher import ExactMatcher +from shared.matcher.strategies.exact_matcher import ExactMatcher @dataclass diff --git a/tests/matcher/test_field_matcher.py b/tests/matcher/test_field_matcher.py index d169ed2..bc03ead 100644 --- a/tests/matcher/test_field_matcher.py +++ b/tests/matcher/test_field_matcher.py @@ -9,13 +9,13 @@ Usage: import pytest from dataclasses import dataclass -from src.matcher.field_matcher import FieldMatcher, find_field_matches -from src.matcher.models import Match -from src.matcher.token_index import TokenIndex -from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords -from src.matcher import utils as matcher_utils -from src.matcher.utils import normalize_dashes as _normalize_dashes -from src.matcher.strategies import ( +from shared.matcher.field_matcher import FieldMatcher, find_field_matches +from shared.matcher.models import Match +from shared.matcher.token_index import TokenIndex +from shared.matcher.context import CONTEXT_KEYWORDS, find_context_keywords +from shared.matcher import utils as matcher_utils +from shared.matcher.utils import normalize_dashes as _normalize_dashes +from shared.matcher.strategies import ( SubstringMatcher, FlexibleDateMatcher, FuzzyMatcher, diff --git a/tests/normalize/normalizers/test_amount_normalizer.py b/tests/normalize/normalizers/test_amount_normalizer.py index bbd2042..eaa6a61 100644 --- a/tests/normalize/normalizers/test_amount_normalizer.py +++ b/tests/normalize/normalizers/test_amount_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.amount_normalizer import AmountNormalizer +from shared.normalize.normalizers.amount_normalizer import AmountNormalizer class TestAmountNormalizer: diff --git a/tests/normalize/normalizers/test_bankgiro_normalizer.py b/tests/normalize/normalizers/test_bankgiro_normalizer.py index eb1a75c..d142646 100644 --- a/tests/normalize/normalizers/test_bankgiro_normalizer.py +++ b/tests/normalize/normalizers/test_bankgiro_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.bankgiro_normalizer import BankgiroNormalizer +from shared.normalize.normalizers.bankgiro_normalizer import BankgiroNormalizer class TestBankgiroNormalizer: diff --git a/tests/normalize/normalizers/test_customer_number_normalizer.py b/tests/normalize/normalizers/test_customer_number_normalizer.py index ecbf215..9625e36 100644 --- a/tests/normalize/normalizers/test_customer_number_normalizer.py +++ b/tests/normalize/normalizers/test_customer_number_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.customer_number_normalizer import CustomerNumberNormalizer +from shared.normalize.normalizers.customer_number_normalizer import CustomerNumberNormalizer class TestCustomerNumberNormalizer: diff --git a/tests/normalize/normalizers/test_date_normalizer.py b/tests/normalize/normalizers/test_date_normalizer.py index ffcc7ce..4a0574f 100644 --- a/tests/normalize/normalizers/test_date_normalizer.py +++ b/tests/normalize/normalizers/test_date_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.date_normalizer import DateNormalizer +from shared.normalize.normalizers.date_normalizer import DateNormalizer class 
TestDateNormalizer: diff --git a/tests/normalize/normalizers/test_invoice_number_normalizer.py b/tests/normalize/normalizers/test_invoice_number_normalizer.py index fef38ee..d2e70da 100644 --- a/tests/normalize/normalizers/test_invoice_number_normalizer.py +++ b/tests/normalize/normalizers/test_invoice_number_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.invoice_number_normalizer import InvoiceNumberNormalizer +from shared.normalize.normalizers.invoice_number_normalizer import InvoiceNumberNormalizer class TestInvoiceNumberNormalizer: diff --git a/tests/normalize/normalizers/test_ocr_normalizer.py b/tests/normalize/normalizers/test_ocr_normalizer.py index 0a9ee6a..925685f 100644 --- a/tests/normalize/normalizers/test_ocr_normalizer.py +++ b/tests/normalize/normalizers/test_ocr_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.ocr_normalizer import OCRNormalizer +from shared.normalize.normalizers.ocr_normalizer import OCRNormalizer class TestOCRNormalizer: diff --git a/tests/normalize/normalizers/test_organisation_number_normalizer.py b/tests/normalize/normalizers/test_organisation_number_normalizer.py index 0113ba0..7c74c3e 100644 --- a/tests/normalize/normalizers/test_organisation_number_normalizer.py +++ b/tests/normalize/normalizers/test_organisation_number_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.organisation_number_normalizer import OrganisationNumberNormalizer +from shared.normalize.normalizers.organisation_number_normalizer import OrganisationNumberNormalizer class TestOrganisationNumberNormalizer: diff --git a/tests/normalize/normalizers/test_plusgiro_normalizer.py b/tests/normalize/normalizers/test_plusgiro_normalizer.py index 092229d..686b07f 100644 --- a/tests/normalize/normalizers/test_plusgiro_normalizer.py +++ b/tests/normalize/normalizers/test_plusgiro_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.plusgiro_normalizer import PlusgiroNormalizer +from shared.normalize.normalizers.plusgiro_normalizer import PlusgiroNormalizer class TestPlusgiroNormalizer: diff --git a/tests/normalize/normalizers/test_supplier_accounts_normalizer.py b/tests/normalize/normalizers/test_supplier_accounts_normalizer.py index f2fb709..5e51bdd 100644 --- a/tests/normalize/normalizers/test_supplier_accounts_normalizer.py +++ b/tests/normalize/normalizers/test_supplier_accounts_normalizer.py @@ -6,7 +6,7 @@ Usage: """ import pytest -from src.normalize.normalizers.supplier_accounts_normalizer import SupplierAccountsNormalizer +from shared.normalize.normalizers.supplier_accounts_normalizer import SupplierAccountsNormalizer class TestSupplierAccountsNormalizer: diff --git a/tests/normalize/test_normalizer.py b/tests/normalize/test_normalizer.py index 6d3e1c4..886c952 100644 --- a/tests/normalize/test_normalizer.py +++ b/tests/normalize/test_normalizer.py @@ -8,7 +8,7 @@ Usage: """ import pytest -from src.normalize.normalizer import ( +from shared.normalize.normalizer import ( FieldNormalizer, NormalizedValue, normalize_field, diff --git a/tests/ocr/test_machine_code_parser.py b/tests/ocr/test_machine_code_parser.py index 7893abf..b669dd6 100644 --- a/tests/ocr/test_machine_code_parser.py +++ b/tests/ocr/test_machine_code_parser.py @@ -9,8 +9,8 @@ Tests the parsing of Swedish invoice payment lines including: """ import pytest -from src.ocr.machine_code_parser import MachineCodeParser, MachineCodeResult -from src.pdf.extractor import Token as 
TextToken +from shared.ocr.machine_code_parser import MachineCodeParser, MachineCodeResult +from shared.pdf.extractor import Token as TextToken class TestParseStandardPaymentLine: diff --git a/tests/pdf/test_detector.py b/tests/pdf/test_detector.py index 8a7bd5c..0cb322c 100644 --- a/tests/pdf/test_detector.py +++ b/tests/pdf/test_detector.py @@ -13,7 +13,7 @@ Usage: import pytest from pathlib import Path from unittest.mock import patch, MagicMock -from src.pdf.detector import ( +from shared.pdf.detector import ( extract_text_first_page, is_text_pdf, get_pdf_type, @@ -54,12 +54,12 @@ class TestIsTextPDF: def test_empty_pdf_returns_false(self): """Should return False for PDF with no text.""" - with patch("src.pdf.detector.extract_text_first_page", return_value=""): + with patch("shared.pdf.detector.extract_text_first_page", return_value=""): assert is_text_pdf("test.pdf") is False def test_short_text_returns_false(self): """Should return False for PDF with very short text.""" - with patch("src.pdf.detector.extract_text_first_page", return_value="Hello"): + with patch("shared.pdf.detector.extract_text_first_page", return_value="Hello"): assert is_text_pdf("test.pdf") is False def test_readable_text_with_keywords_returns_true(self): @@ -72,7 +72,7 @@ class TestIsTextPDF: Moms: 25% """ + "a" * 200 # Ensure > 200 chars - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): assert is_text_pdf("test.pdf") is True def test_garbled_text_returns_false(self): @@ -80,7 +80,7 @@ class TestIsTextPDF: # Simulate garbled text (lots of non-printable characters) garbled = "\x00\x01\x02" * 100 + "abc" * 20 # Low readable ratio - with patch("src.pdf.detector.extract_text_first_page", return_value=garbled): + with patch("shared.pdf.detector.extract_text_first_page", return_value=garbled): assert is_text_pdf("test.pdf") is False def test_text_without_keywords_needs_high_readability(self): @@ -88,7 +88,7 @@ class TestIsTextPDF: # Text without invoice keywords text = "The quick brown fox jumps over the lazy dog. 
" * 10 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): # Should pass if readable ratio is high enough result = is_text_pdf("test.pdf") # Result depends on character ratio - ASCII text should pass @@ -98,7 +98,7 @@ class TestIsTextPDF: """Should respect custom min_chars parameter.""" text = "Short text here" # 15 chars - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): # Default min_chars=30 - should fail assert is_text_pdf("test.pdf", min_chars=30) is False # Custom min_chars=10 - should pass basic length check @@ -273,7 +273,7 @@ class TestIsTextPDFKeywordDetection: # Create text with keyword and enough content text = f"Document with {keyword} keyword here" + " more text" * 50 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): # Need at least 2 keywords for is_text_pdf to return True # So this tests if keyword is recognized when combined with others pass @@ -282,7 +282,7 @@ class TestIsTextPDFKeywordDetection: """Should detect English invoice keywords.""" text = "Invoice document with date and amount information" + " x" * 100 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): # invoice + date = 2 keywords result = is_text_pdf("test.pdf") assert result is True @@ -292,7 +292,7 @@ class TestIsTextPDFKeywordDetection: # Only one keyword text = "This is a faktura document" + " x" * 200 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): # With only 1 keyword, falls back to other checks # Should still pass if readability is high pass @@ -306,7 +306,7 @@ class TestReadabilityChecks: # Pure ASCII text text = "This is a normal document with only ASCII characters. 
" * 10 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): result = is_text_pdf("test.pdf") assert result is True @@ -314,7 +314,7 @@ class TestReadabilityChecks: """Should accept Swedish characters as readable.""" text = "Fakturadatum för årets moms på öre belopp" + " normal" * 50 - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): result = is_text_pdf("test.pdf") assert result is True @@ -326,7 +326,7 @@ class TestReadabilityChecks: unreadable = "\x80\x81\x82" * 50 # 150 unreadable chars text = readable + unreadable - with patch("src.pdf.detector.extract_text_first_page", return_value=text): + with patch("shared.pdf.detector.extract_text_first_page", return_value=text): result = is_text_pdf("test.pdf") assert result is False diff --git a/tests/pdf/test_extractor.py b/tests/pdf/test_extractor.py index 3cebd88..8a9e225 100644 --- a/tests/pdf/test_extractor.py +++ b/tests/pdf/test_extractor.py @@ -12,7 +12,7 @@ Usage: import pytest from pathlib import Path from unittest.mock import patch, MagicMock -from src.pdf.extractor import ( +from shared.pdf.extractor import ( Token, PDFDocument, extract_text_tokens, @@ -509,7 +509,7 @@ class TestPDFDocumentIsTextPDF: mock_doc = MagicMock() with patch("fitz.open", return_value=mock_doc): - with patch("src.pdf.extractor._is_text_pdf_standalone", return_value=True) as mock_check: + with patch("shared.pdf.extractor._is_text_pdf_standalone", return_value=True) as mock_check: with PDFDocument("test.pdf") as pdf: result = pdf.is_text_pdf(min_chars=50) diff --git a/tests/test_config.py b/tests/test_config.py index bc3c9b5..376445c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -18,7 +18,7 @@ class TestDatabaseConfig: def test_config_loads_from_env(self): """Test that config loads successfully from .env file.""" # Import config (should load .env automatically) - from src import config + from shared import config # Verify database config is loaded assert config.DATABASE is not None @@ -30,7 +30,7 @@ class TestDatabaseConfig: def test_database_password_loaded(self): """Test that database password is loaded from environment.""" - from src import config + from shared import config # Password should be loaded from .env assert config.DATABASE['password'] is not None @@ -38,7 +38,7 @@ class TestDatabaseConfig: def test_database_connection_string(self): """Test database connection string generation.""" - from src import config + from shared import config conn_str = config.get_db_connection_string() @@ -71,7 +71,7 @@ class TestPathsConfig: def test_paths_config_exists(self): """Test that PATHS configuration exists.""" - from src import config + from shared import config assert config.PATHS is not None assert 'csv_dir' in config.PATHS @@ -85,7 +85,7 @@ class TestAutolabelConfig: def test_autolabel_config_exists(self): """Test that AUTOLABEL configuration exists.""" - from src import config + from shared import config assert config.AUTOLABEL is not None assert 'workers' in config.AUTOLABEL @@ -95,7 +95,7 @@ class TestAutolabelConfig: def test_autolabel_ratios_sum_to_one(self): """Test that train/val/test ratios sum to 1.0.""" - from src import config + from shared import config total = ( config.AUTOLABEL['train_ratio'] + diff --git a/tests/test_customer_number_parser.py b/tests/test_customer_number_parser.py index 32ea51d..e8b76c3 100644 --- 
a/tests/test_customer_number_parser.py +++ b/tests/test_customer_number_parser.py @@ -10,7 +10,7 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from src.inference.customer_number_parser import ( +from inference.pipeline.customer_number_parser import ( CustomerNumberParser, DashFormatPattern, NoDashFormatPattern, diff --git a/tests/test_db_security.py b/tests/test_db_security.py index 5cb9a48..b8eae77 100644 --- a/tests/test_db_security.py +++ b/tests/test_db_security.py @@ -11,7 +11,7 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from src.data.db import DocumentDB +from shared.data.db import DocumentDB class TestSQLInjectionPrevention: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5dea6cd..b071531 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -10,7 +10,7 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from src.exceptions import ( +from shared.exceptions import ( InvoiceExtractionError, PDFProcessingError, OCRError, diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..26fe300 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,48 @@ +"""Import validation tests. + +Ensures all lazy imports across packages resolve correctly, +catching cross-package import errors that mocks would hide. +""" +import importlib +import pkgutil + +import pytest + + +def _collect_modules(package_name: str) -> list[str]: + """Recursively collect all module names under a package.""" + try: + package = importlib.import_module(package_name) + except Exception: + return [package_name] + + modules = [package_name] + if hasattr(package, "__path__"): + for _importer, modname, _ispkg in pkgutil.walk_packages( + package.__path__, prefix=package_name + "." 
+ ): + modules.append(modname) + return modules + + +SHARED_MODULES = _collect_modules("shared") +INFERENCE_MODULES = _collect_modules("inference") +TRAINING_MODULES = _collect_modules("training") + + +@pytest.mark.parametrize("module_name", SHARED_MODULES) +def test_shared_module_imports(module_name: str) -> None: + """Every module in the shared package should import without error.""" + importlib.import_module(module_name) + + +@pytest.mark.parametrize("module_name", INFERENCE_MODULES) +def test_inference_module_imports(module_name: str) -> None: + """Every module in the inference package should import without error.""" + importlib.import_module(module_name) + + +@pytest.mark.parametrize("module_name", TRAINING_MODULES) +def test_training_module_imports(module_name: str) -> None: + """Every module in the training package should import without error.""" + importlib.import_module(module_name) diff --git a/tests/test_payment_line_parser.py b/tests/test_payment_line_parser.py index 51bfe60..f1a45f8 100644 --- a/tests/test_payment_line_parser.py +++ b/tests/test_payment_line_parser.py @@ -10,7 +10,7 @@ from pathlib import Path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from src.inference.payment_line_parser import PaymentLineParser, PaymentLineData +from inference.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData class TestPaymentLineParser: diff --git a/tests/utils/test_advanced_utils.py b/tests/utils/test_advanced_utils.py index 588f7d1..d013d10 100644 --- a/tests/utils/test_advanced_utils.py +++ b/tests/utils/test_advanced_utils.py @@ -6,9 +6,9 @@ Tests for advanced utility modules: """ import pytest -from src.utils.fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult -from src.utils.ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants -from src.utils.context_extractor import ContextExtractor, extract_field_with_context +from shared.utils.fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult +from shared.utils.ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants +from shared.utils.context_extractor import ContextExtractor, extract_field_with_context class TestFuzzyMatcher: diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 3222f1d..1cdfb8e 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -3,9 +3,9 @@ Tests for shared utility modules. 
""" import pytest -from src.utils.text_cleaner import TextCleaner -from src.utils.format_variants import FormatVariants -from src.utils.validators import FieldValidators +from shared.utils.text_cleaner import TextCleaner +from shared.utils.format_variants import FormatVariants +from shared.utils.validators import FieldValidators class TestTextCleaner: diff --git a/tests/web/conftest.py b/tests/web/conftest.py index c9a0fa5..465c65b 100644 --- a/tests/web/conftest.py +++ b/tests/web/conftest.py @@ -10,12 +10,12 @@ from uuid import UUID import pytest -from src.data.async_request_db import ApiKeyConfig, AsyncRequestDB -from src.data.models import AsyncRequest -from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue -from src.web.services.async_processing import AsyncProcessingService -from src.web.config import AsyncConfig, StorageConfig -from src.web.core.rate_limiter import RateLimiter +from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB +from inference.data.models import AsyncRequest +from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue +from inference.web.services.async_processing import AsyncProcessingService +from inference.web.config import AsyncConfig, StorageConfig +from inference.web.core.rate_limiter import RateLimiter @pytest.fixture diff --git a/tests/web/test_admin_annotations.py b/tests/web/test_admin_annotations.py index 2396fb2..9265810 100644 --- a/tests/web/test_admin_annotations.py +++ b/tests/web/test_admin_annotations.py @@ -9,9 +9,9 @@ from uuid import UUID from fastapi import HTTPException -from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES -from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router -from src.web.schemas.admin import ( +from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES +from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router +from inference.web.schemas.admin import ( AnnotationCreate, AnnotationUpdate, AutoLabelRequest, diff --git a/tests/web/test_admin_auth.py b/tests/web/test_admin_auth.py index 2e12b02..e61bc36 100644 --- a/tests/web/test_admin_auth.py +++ b/tests/web/test_admin_auth.py @@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch from fastapi import HTTPException -from src.data.admin_db import AdminDB -from src.data.admin_models import AdminToken -from src.web.core.auth import ( +from inference.data.admin_db import AdminDB +from inference.data.admin_models import AdminToken +from inference.web.core.auth import ( get_admin_db, reset_admin_db, validate_admin_token, @@ -81,7 +81,7 @@ class TestAdminDB: def test_is_valid_admin_token_active(self): """Test valid active token.""" - with patch("src.data.admin_db.get_session_context") as mock_ctx: + with patch("inference.data.admin_db.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__.return_value = mock_session @@ -98,7 +98,7 @@ class TestAdminDB: def test_is_valid_admin_token_inactive(self): """Test inactive token.""" - with patch("src.data.admin_db.get_session_context") as mock_ctx: + with patch("inference.data.admin_db.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__.return_value = mock_session @@ -115,7 +115,7 @@ class TestAdminDB: def test_is_valid_admin_token_expired(self): """Test expired token.""" - with patch("src.data.admin_db.get_session_context") as mock_ctx: + with patch("inference.data.admin_db.get_session_context") 
as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__.return_value = mock_session @@ -132,7 +132,7 @@ class TestAdminDB: def test_is_valid_admin_token_not_found(self): """Test token not found.""" - with patch("src.data.admin_db.get_session_context") as mock_ctx: + with patch("inference.data.admin_db.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__.return_value = mock_session mock_session.get.return_value = None diff --git a/tests/web/test_admin_routes.py b/tests/web/test_admin_routes.py index 070ea6e..6af4208 100644 --- a/tests/web/test_admin_routes.py +++ b/tests/web/test_admin_routes.py @@ -12,8 +12,9 @@ from uuid import UUID from fastapi import HTTPException from fastapi.testclient import TestClient -from src.data.admin_models import AdminDocument, AdminToken -from src.web.api.v1.admin.documents import _validate_uuid, create_admin_router +from inference.data.admin_models import AdminDocument, AdminToken +from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router +from inference.web.config import StorageConfig # Test UUID @@ -42,13 +43,12 @@ class TestAdminRouter: def test_creates_router_with_endpoints(self): """Test router is created with expected endpoints.""" - router = create_admin_router((".pdf", ".png", ".jpg")) + router = create_documents_router(StorageConfig()) # Get route paths (include prefix from router) paths = [route.path for route in router.routes] - # Paths include the /admin prefix - assert any("/auth/token" in p for p in paths) + # Paths include the /admin/documents prefix assert any("/documents" in p for p in paths) assert any("/documents/stats" in p for p in paths) assert any("{document_id}" in p for p in paths) @@ -66,7 +66,7 @@ class TestCreateTokenEndpoint: def test_create_token_success(self, mock_db): """Test successful token creation.""" - from src.web.schemas.admin import AdminTokenCreate + from inference.web.schemas.admin import AdminTokenCreate request = AdminTokenCreate(name="Test Token", expires_in_days=30) diff --git a/tests/web/test_admin_routes_enhanced.py b/tests/web/test_admin_routes_enhanced.py index 5dac633..55b0563 100644 --- a/tests/web/test_admin_routes_enhanced.py +++ b/tests/web/test_admin_routes_enhanced.py @@ -9,8 +9,9 @@ from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient -from src.web.api.v1.admin.documents import create_admin_router -from src.web.core.auth import validate_admin_token, get_admin_db +from inference.web.api.v1.admin.documents import create_documents_router +from inference.web.config import StorageConfig +from inference.web.core.auth import validate_admin_token, get_admin_db class MockAdminDocument: @@ -189,7 +190,7 @@ def app(): app.dependency_overrides[get_admin_db] = lambda: mock_db # Include router - router = create_admin_router((".pdf", ".png", ".jpg")) + router = create_documents_router(StorageConfig()) app.include_router(router) return app diff --git a/tests/web/test_admin_schemas_split.py b/tests/web/test_admin_schemas_split.py new file mode 100644 index 0000000..5bdcd33 --- /dev/null +++ b/tests/web/test_admin_schemas_split.py @@ -0,0 +1,245 @@ +""" +Tests to verify admin schemas split maintains backward compatibility. + +All existing imports from inference.web.schemas.admin must continue to work. 
+""" + +import pytest + + +class TestEnumImports: + """All enums importable from inference.web.schemas.admin.""" + + def test_document_status(self): + from inference.web.schemas.admin import DocumentStatus + assert DocumentStatus.PENDING == "pending" + + def test_auto_label_status(self): + from inference.web.schemas.admin import AutoLabelStatus + assert AutoLabelStatus.RUNNING == "running" + + def test_training_status(self): + from inference.web.schemas.admin import TrainingStatus + assert TrainingStatus.PENDING == "pending" + + def test_training_type(self): + from inference.web.schemas.admin import TrainingType + assert TrainingType.TRAIN == "train" + + def test_annotation_source(self): + from inference.web.schemas.admin import AnnotationSource + assert AnnotationSource.MANUAL == "manual" + + +class TestAuthImports: + """Auth schemas importable.""" + + def test_admin_token_create(self): + from inference.web.schemas.admin import AdminTokenCreate + token = AdminTokenCreate(name="test") + assert token.name == "test" + + def test_admin_token_response(self): + from inference.web.schemas.admin import AdminTokenResponse + assert AdminTokenResponse is not None + + +class TestDocumentImports: + """Document schemas importable.""" + + def test_document_upload_response(self): + from inference.web.schemas.admin import DocumentUploadResponse + assert DocumentUploadResponse is not None + + def test_document_item(self): + from inference.web.schemas.admin import DocumentItem + assert DocumentItem is not None + + def test_document_list_response(self): + from inference.web.schemas.admin import DocumentListResponse + assert DocumentListResponse is not None + + def test_document_detail_response(self): + from inference.web.schemas.admin import DocumentDetailResponse + assert DocumentDetailResponse is not None + + def test_document_stats_response(self): + from inference.web.schemas.admin import DocumentStatsResponse + assert DocumentStatsResponse is not None + + +class TestAnnotationImports: + """Annotation schemas importable.""" + + def test_bounding_box(self): + from inference.web.schemas.admin import BoundingBox + bb = BoundingBox(x=0, y=0, width=100, height=50) + assert bb.width == 100 + + def test_annotation_create(self): + from inference.web.schemas.admin import AnnotationCreate + assert AnnotationCreate is not None + + def test_annotation_update(self): + from inference.web.schemas.admin import AnnotationUpdate + assert AnnotationUpdate is not None + + def test_annotation_item(self): + from inference.web.schemas.admin import AnnotationItem + assert AnnotationItem is not None + + def test_annotation_response(self): + from inference.web.schemas.admin import AnnotationResponse + assert AnnotationResponse is not None + + def test_annotation_list_response(self): + from inference.web.schemas.admin import AnnotationListResponse + assert AnnotationListResponse is not None + + def test_annotation_lock_request(self): + from inference.web.schemas.admin import AnnotationLockRequest + assert AnnotationLockRequest is not None + + def test_annotation_lock_response(self): + from inference.web.schemas.admin import AnnotationLockResponse + assert AnnotationLockResponse is not None + + def test_auto_label_request(self): + from inference.web.schemas.admin import AutoLabelRequest + assert AutoLabelRequest is not None + + def test_auto_label_response(self): + from inference.web.schemas.admin import AutoLabelResponse + assert AutoLabelResponse is not None + + def test_annotation_verify_request(self): + from 
inference.web.schemas.admin import AnnotationVerifyRequest + assert AnnotationVerifyRequest is not None + + def test_annotation_verify_response(self): + from inference.web.schemas.admin import AnnotationVerifyResponse + assert AnnotationVerifyResponse is not None + + def test_annotation_override_request(self): + from inference.web.schemas.admin import AnnotationOverrideRequest + assert AnnotationOverrideRequest is not None + + def test_annotation_override_response(self): + from inference.web.schemas.admin import AnnotationOverrideResponse + assert AnnotationOverrideResponse is not None + + +class TestTrainingImports: + """Training schemas importable.""" + + def test_training_config(self): + from inference.web.schemas.admin import TrainingConfig + config = TrainingConfig() + assert config.epochs == 100 + + def test_training_task_create(self): + from inference.web.schemas.admin import TrainingTaskCreate + assert TrainingTaskCreate is not None + + def test_training_task_item(self): + from inference.web.schemas.admin import TrainingTaskItem + assert TrainingTaskItem is not None + + def test_training_task_list_response(self): + from inference.web.schemas.admin import TrainingTaskListResponse + assert TrainingTaskListResponse is not None + + def test_training_task_detail_response(self): + from inference.web.schemas.admin import TrainingTaskDetailResponse + assert TrainingTaskDetailResponse is not None + + def test_training_task_response(self): + from inference.web.schemas.admin import TrainingTaskResponse + assert TrainingTaskResponse is not None + + def test_training_log_item(self): + from inference.web.schemas.admin import TrainingLogItem + assert TrainingLogItem is not None + + def test_training_logs_response(self): + from inference.web.schemas.admin import TrainingLogsResponse + assert TrainingLogsResponse is not None + + def test_export_request(self): + from inference.web.schemas.admin import ExportRequest + assert ExportRequest is not None + + def test_export_response(self): + from inference.web.schemas.admin import ExportResponse + assert ExportResponse is not None + + def test_training_document_item(self): + from inference.web.schemas.admin import TrainingDocumentItem + assert TrainingDocumentItem is not None + + def test_training_documents_response(self): + from inference.web.schemas.admin import TrainingDocumentsResponse + assert TrainingDocumentsResponse is not None + + def test_model_metrics(self): + from inference.web.schemas.admin import ModelMetrics + assert ModelMetrics is not None + + def test_training_model_item(self): + from inference.web.schemas.admin import TrainingModelItem + assert TrainingModelItem is not None + + def test_training_models_response(self): + from inference.web.schemas.admin import TrainingModelsResponse + assert TrainingModelsResponse is not None + + def test_training_history_item(self): + from inference.web.schemas.admin import TrainingHistoryItem + assert TrainingHistoryItem is not None + + +class TestDatasetImports: + """Dataset schemas importable.""" + + def test_dataset_create_request(self): + from inference.web.schemas.admin import DatasetCreateRequest + assert DatasetCreateRequest is not None + + def test_dataset_document_item(self): + from inference.web.schemas.admin import DatasetDocumentItem + assert DatasetDocumentItem is not None + + def test_dataset_response(self): + from inference.web.schemas.admin import DatasetResponse + assert DatasetResponse is not None + + def test_dataset_detail_response(self): + from inference.web.schemas.admin import 
DatasetDetailResponse + assert DatasetDetailResponse is not None + + def test_dataset_list_item(self): + from inference.web.schemas.admin import DatasetListItem + assert DatasetListItem is not None + + def test_dataset_list_response(self): + from inference.web.schemas.admin import DatasetListResponse + assert DatasetListResponse is not None + + def test_dataset_train_request(self): + from inference.web.schemas.admin import DatasetTrainRequest + assert DatasetTrainRequest is not None + + +class TestForwardReferences: + """Forward references resolve correctly.""" + + def test_document_detail_has_annotation_items(self): + from inference.web.schemas.admin import DocumentDetailResponse + fields = DocumentDetailResponse.model_fields + assert "annotations" in fields + assert "training_history" in fields + + def test_dataset_train_request_has_config(self): + from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig + req = DatasetTrainRequest(name="test", config=TrainingConfig()) + assert req.config.epochs == 100 diff --git a/tests/web/test_admin_training.py b/tests/web/test_admin_training.py index 62e84ac..a3747af 100644 --- a/tests/web/test_admin_training.py +++ b/tests/web/test_admin_training.py @@ -7,15 +7,15 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock, patch from uuid import UUID -from src.data.admin_models import TrainingTask, TrainingLog -from src.web.api.v1.admin.training import _validate_uuid, create_training_router -from src.web.core.scheduler import ( +from inference.data.admin_models import TrainingTask, TrainingLog +from inference.web.api.v1.admin.training import _validate_uuid, create_training_router +from inference.web.core.scheduler import ( TrainingScheduler, get_training_scheduler, start_scheduler, stop_scheduler, ) -from src.web.schemas.admin import ( +from inference.web.schemas.admin import ( TrainingConfig, TrainingStatus, TrainingTaskCreate, diff --git a/tests/web/test_annotation_locks.py b/tests/web/test_annotation_locks.py index dfff46d..47cbbd3 100644 --- a/tests/web/test_annotation_locks.py +++ b/tests/web/test_annotation_locks.py @@ -9,8 +9,8 @@ from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient -from src.web.api.v1.admin.documents import create_admin_router -from src.web.core.auth import validate_admin_token, get_admin_db +from inference.web.api.v1.admin.locks import create_locks_router +from inference.web.core.auth import validate_admin_token, get_admin_db class MockAdminDocument: @@ -110,7 +110,7 @@ def app(): app.dependency_overrides[get_admin_db] = lambda: mock_db # Include router - router = create_admin_router((".pdf", ".png", ".jpg")) + router = create_locks_router() app.include_router(router) return app diff --git a/tests/web/test_annotation_phase5.py b/tests/web/test_annotation_phase5.py index cba8c20..66d62ec 100644 --- a/tests/web/test_annotation_phase5.py +++ b/tests/web/test_annotation_phase5.py @@ -9,8 +9,8 @@ from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient -from src.web.api.v1.admin.annotations import create_annotation_router -from src.web.core.auth import validate_admin_token, get_admin_db +from inference.web.api.v1.admin.annotations import create_annotation_router +from inference.web.core.auth import validate_admin_token, get_admin_db class MockAdminDocument: diff --git a/tests/web/test_async_queue.py b/tests/web/test_async_queue.py index 1db16cf..cbeb99a 100644 --- a/tests/web/test_async_queue.py +++ 
b/tests/web/test_async_queue.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock import pytest -from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue +from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue class TestAsyncTask: diff --git a/tests/web/test_async_routes.py b/tests/web/test_async_routes.py index b7dfb6f..b7a58b9 100644 --- a/tests/web/test_async_routes.py +++ b/tests/web/test_async_routes.py @@ -11,12 +11,12 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from src.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB -from src.web.api.v1.async_api.routes import create_async_router, set_async_service -from src.web.services.async_processing import AsyncSubmitResult -from src.web.dependencies import init_dependencies -from src.web.rate_limiter import RateLimiter, RateLimitStatus -from src.web.schemas.inference import AsyncStatus +from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB +from inference.web.api.v1.public.async_api import create_async_router, set_async_service +from inference.web.services.async_processing import AsyncSubmitResult +from inference.web.dependencies import init_dependencies +from inference.web.rate_limiter import RateLimiter, RateLimitStatus +from inference.web.schemas.inference import AsyncStatus # Valid UUID for testing TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000" diff --git a/tests/web/test_async_service.py b/tests/web/test_async_service.py index ec8071c..556dc1e 100644 --- a/tests/web/test_async_service.py +++ b/tests/web/test_async_service.py @@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch import pytest -from src.data.async_request_db import AsyncRequest -from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue -from src.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult -from src.web.config import AsyncConfig, StorageConfig -from src.web.rate_limiter import RateLimiter +from inference.data.async_request_db import AsyncRequest +from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue +from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult +from inference.web.config import AsyncConfig, StorageConfig +from inference.web.rate_limiter import RateLimiter @pytest.fixture diff --git a/tests/web/test_autolabel_with_locks.py b/tests/web/test_autolabel_with_locks.py index dfc400a..0fbbddc 100644 --- a/tests/web/test_autolabel_with_locks.py +++ b/tests/web/test_autolabel_with_locks.py @@ -8,8 +8,8 @@ from pathlib import Path from unittest.mock import Mock, MagicMock from uuid import uuid4 -from src.web.services.autolabel import AutoLabelService -from src.data.admin_db import AdminDB +from inference.web.services.autolabel import AutoLabelService +from inference.data.admin_db import AdminDB class MockDocument: diff --git a/tests/web/test_batch_queue.py b/tests/web/test_batch_queue.py index e619313..3941ac6 100644 --- a/tests/web/test_batch_queue.py +++ b/tests/web/test_batch_queue.py @@ -9,7 +9,7 @@ from uuid import uuid4 import pytest -from src.web.workers.batch_queue import BatchTask, BatchTaskQueue +from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue class MockBatchService: diff --git a/tests/web/test_batch_upload_routes.py b/tests/web/test_batch_upload_routes.py index b039688..6a3427a 100644 --- a/tests/web/test_batch_upload_routes.py +++ b/tests/web/test_batch_upload_routes.py @@ -11,10 +11,10 @@ 
import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from src.web.api.v1.batch.routes import router -from src.web.core.auth import validate_admin_token, get_admin_db -from src.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue -from src.web.services.batch_upload import BatchUploadService +from inference.web.api.v1.batch.routes import router +from inference.web.core.auth import validate_admin_token, get_admin_db +from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue +from inference.web.services.batch_upload import BatchUploadService class MockAdminDB: diff --git a/tests/web/test_batch_upload_service.py b/tests/web/test_batch_upload_service.py index 102cf3a..5aa0d82 100644 --- a/tests/web/test_batch_upload_service.py +++ b/tests/web/test_batch_upload_service.py @@ -9,8 +9,8 @@ from uuid import uuid4 import pytest -from src.data.admin_db import AdminDB -from src.web.services.batch_upload import BatchUploadService +from inference.data.admin_db import AdminDB +from inference.web.services.batch_upload import BatchUploadService @pytest.fixture diff --git a/tests/web/test_dataset_builder.py b/tests/web/test_dataset_builder.py new file mode 100644 index 0000000..ae79912 --- /dev/null +++ b/tests/web/test_dataset_builder.py @@ -0,0 +1,331 @@ +""" +Tests for DatasetBuilder service. + +TDD: Write tests first, then implement dataset_builder.py. +""" + +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from inference.data.admin_models import ( + AdminAnnotation, + AdminDocument, + TrainingDataset, + FIELD_CLASSES, +) + + +@pytest.fixture +def tmp_admin_images(tmp_path): + """Create mock admin images directory with sample images.""" + doc_ids = [uuid4() for _ in range(5)] + for doc_id in doc_ids: + doc_dir = tmp_path / "admin_images" / str(doc_id) + doc_dir.mkdir(parents=True) + # Create 2 pages per doc + for page in range(1, 3): + img_path = doc_dir / f"page_{page}.png" + img_path.write_bytes(b"fake-png-data") + return tmp_path, doc_ids + + +@pytest.fixture +def mock_admin_db(): + """Mock AdminDB with dataset and document methods.""" + db = MagicMock() + db.create_dataset.return_value = TrainingDataset( + dataset_id=uuid4(), + name="test-dataset", + status="building", + train_ratio=0.8, + val_ratio=0.1, + seed=42, + ) + return db + + +@pytest.fixture +def sample_documents(tmp_admin_images): + """Create sample AdminDocument objects.""" + tmp_path, doc_ids = tmp_admin_images + docs = [] + for doc_id in doc_ids: + doc = MagicMock(spec=AdminDocument) + doc.document_id = doc_id + doc.filename = f"{doc_id}.pdf" + doc.page_count = 2 + doc.file_path = str(tmp_path / "admin_images" / str(doc_id)) + docs.append(doc) + return docs + + +@pytest.fixture +def sample_annotations(sample_documents): + """Create sample annotations for each document page.""" + annotations = {} + for doc in sample_documents: + doc_anns = [] + for page in range(1, 3): + ann = MagicMock(spec=AdminAnnotation) + ann.document_id = doc.document_id + ann.page_number = page + ann.class_id = 0 + ann.class_name = "invoice_number" + ann.x_center = 0.5 + ann.y_center = 0.3 + ann.width = 0.2 + ann.height = 0.05 + doc_anns.append(ann) + annotations[str(doc.document_id)] = doc_anns + return annotations + + +class TestDatasetBuilder: + """Tests for DatasetBuilder.""" + + def test_build_creates_directory_structure( + self, tmp_path, mock_admin_db, sample_documents, 
sample_annotations + ): + """Dataset builder should create images/ and labels/ with train/val/test subdirs.""" + from inference.web.services.dataset_builder import DatasetBuilder + + dataset_dir = tmp_path / "datasets" / "test" + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + + # Mock DB calls + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + result_dir = tmp_path / "datasets" / str(dataset.dataset_id) + for split in ["train", "val", "test"]: + assert (result_dir / "images" / split).exists() + assert (result_dir / "labels" / split).exists() + + def test_build_copies_images( + self, tmp_path, mock_admin_db, sample_documents, sample_annotations + ): + """Images should be copied from admin_images to dataset folder.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + result = builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + # Check total images copied + result_dir = tmp_path / "datasets" / str(dataset.dataset_id) + total_images = sum( + len(list((result_dir / "images" / split).glob("*.png"))) + for split in ["train", "val", "test"] + ) + assert total_images == 10 # 5 docs * 2 pages + + def test_build_generates_yolo_labels( + self, tmp_path, mock_admin_db, sample_documents, sample_annotations + ): + """YOLO label files should be generated with correct format.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + result_dir = tmp_path / "datasets" / str(dataset.dataset_id) + total_labels = sum( + len(list((result_dir / "labels" / split).glob("*.txt"))) + for split in ["train", "val", "test"] + ) + assert total_labels == 10 # 5 docs * 2 pages + + # Check label format: "class_id x_center y_center width height" + label_files = list((result_dir / "labels").rglob("*.txt")) + content = label_files[0].read_text().strip() + parts = content.split() + assert len(parts) == 5 + assert int(parts[0]) == 0 # class_id + assert 0 <= float(parts[1]) <= 1 # x_center + assert 0 <= float(parts[2]) <= 1 # y_center + + def test_build_generates_data_yaml( + self, tmp_path, mock_admin_db, 
sample_documents, sample_annotations + ): + """data.yaml should be generated with correct field classes.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + yaml_path = tmp_path / "datasets" / str(dataset.dataset_id) / "data.yaml" + assert yaml_path.exists() + content = yaml_path.read_text() + assert "train:" in content + assert "val:" in content + assert "nc:" in content + assert "invoice_number" in content + + def test_build_splits_documents_correctly( + self, tmp_path, mock_admin_db, sample_documents, sample_annotations + ): + """Documents should be split into train/val/test according to ratios.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + # Verify add_dataset_documents was called with correct splits + call_args = mock_admin_db.add_dataset_documents.call_args + docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1] + splits = [d["split"] for d in docs_added] + assert "train" in splits + # With 5 docs, 80/10/10 -> 4 train, 0-1 val, 0-1 test + train_count = splits.count("train") + assert train_count >= 3 # At least 3 of 5 should be train + + def test_build_updates_status_to_ready( + self, tmp_path, mock_admin_db, sample_documents, sample_annotations + ): + """After successful build, dataset status should be updated to 'ready'.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + mock_admin_db.update_dataset_status.assert_called_once() + call_kwargs = mock_admin_db.update_dataset_status.call_args[1] + assert call_kwargs["status"] == "ready" + assert call_kwargs["total_documents"] == 5 + assert call_kwargs["total_images"] == 10 + + def test_build_sets_failed_on_error( + self, tmp_path, mock_admin_db + ): + """If build fails, dataset status should be set to 'failed'.""" + from inference.web.services.dataset_builder import DatasetBuilder + + builder = 
DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = [] # No docs found + + dataset = mock_admin_db.create_dataset.return_value + with pytest.raises(ValueError): + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=["nonexistent-id"], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + + mock_admin_db.update_dataset_status.assert_called_once() + call_kwargs = mock_admin_db.update_dataset_status.call_args[1] + assert call_kwargs["status"] == "failed" + + def test_build_with_seed_produces_deterministic_splits( + self, tmp_path, mock_admin_db, sample_documents, sample_annotations + ): + """Same seed should produce same splits.""" + from inference.web.services.dataset_builder import DatasetBuilder + + results = [] + for _ in range(2): + builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets") + mock_admin_db.get_documents_by_ids.return_value = sample_documents + mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: ( + sample_annotations.get(str(doc_id), []) + ) + mock_admin_db.add_dataset_documents.reset_mock() + mock_admin_db.update_dataset_status.reset_mock() + + dataset = mock_admin_db.create_dataset.return_value + builder.build_dataset( + dataset_id=str(dataset.dataset_id), + document_ids=[str(d.document_id) for d in sample_documents], + train_ratio=0.8, + val_ratio=0.1, + seed=42, + admin_images_dir=tmp_path / "admin_images", + ) + call_args = mock_admin_db.add_dataset_documents.call_args + docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1] + results.append([(d["document_id"], d["split"]) for d in docs]) + + assert results[0] == results[1] diff --git a/tests/web/test_dataset_routes.py b/tests/web/test_dataset_routes.py new file mode 100644 index 0000000..d2add37 --- /dev/null +++ b/tests/web/test_dataset_routes.py @@ -0,0 +1,200 @@ +""" +Tests for Dataset API routes in training.py. 
+""" + +import asyncio + +import pytest +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch +from uuid import UUID + +from inference.data.admin_models import TrainingDataset, DatasetDocument +from inference.web.api.v1.admin.training import create_training_router +from inference.web.schemas.admin import ( + DatasetCreateRequest, + DatasetTrainRequest, + TrainingConfig, + TrainingStatus, +) + + +TEST_DATASET_UUID = "880e8400-e29b-41d4-a716-446655440010" +TEST_DOC_UUID_1 = "990e8400-e29b-41d4-a716-446655440011" +TEST_DOC_UUID_2 = "990e8400-e29b-41d4-a716-446655440012" +TEST_TOKEN = "test-admin-token-12345" +TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002" + + +def _make_dataset(**overrides) -> MagicMock: + defaults = dict( + dataset_id=UUID(TEST_DATASET_UUID), + name="test-dataset", + description="Test dataset", + status="ready", + train_ratio=0.8, + val_ratio=0.1, + seed=42, + total_documents=2, + total_images=4, + total_annotations=10, + dataset_path="/data/datasets/test-dataset", + error_message=None, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + defaults.update(overrides) + ds = MagicMock(spec=TrainingDataset) + for k, v in defaults.items(): + setattr(ds, k, v) + return ds + + +def _make_dataset_doc(doc_id: str, split: str = "train") -> MagicMock: + doc = MagicMock(spec=DatasetDocument) + doc.document_id = UUID(doc_id) + doc.split = split + doc.page_count = 2 + doc.annotation_count = 5 + return doc + + +def _find_endpoint(name: str): + router = create_training_router() + for route in router.routes: + if hasattr(route, "endpoint") and route.endpoint.__name__ == name: + return route.endpoint + raise AssertionError(f"Endpoint {name} not found") + + +class TestCreateDatasetRoute: + """Tests for POST /admin/training/datasets.""" + + def test_router_has_dataset_endpoints(self): + router = create_training_router() + paths = [route.path for route in router.routes] + assert any("datasets" in p for p in paths) + + def test_create_dataset_calls_builder(self): + fn = _find_endpoint("create_dataset") + + mock_db = MagicMock() + mock_db.create_dataset.return_value = _make_dataset(status="building") + + mock_builder = MagicMock() + mock_builder.build_dataset.return_value = { + "total_documents": 2, + "total_images": 4, + "total_annotations": 10, + } + + request = DatasetCreateRequest( + name="test-dataset", + document_ids=[TEST_DOC_UUID_1, TEST_DOC_UUID_2], + ) + + with patch( + "inference.web.services.dataset_builder.DatasetBuilder", + return_value=mock_builder, + ) as mock_cls: + result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db)) + + mock_db.create_dataset.assert_called_once() + mock_builder.build_dataset.assert_called_once() + assert result.dataset_id == TEST_DATASET_UUID + assert result.name == "test-dataset" + + +class TestListDatasetsRoute: + """Tests for GET /admin/training/datasets.""" + + def test_list_datasets(self): + fn = _find_endpoint("list_datasets") + + mock_db = MagicMock() + mock_db.get_datasets.return_value = ([_make_dataset()], 1) + + result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0)) + + assert result.total == 1 + assert len(result.datasets) == 1 + assert result.datasets[0].name == "test-dataset" + + +class TestGetDatasetRoute: + """Tests for GET /admin/training/datasets/{dataset_id}.""" + + def test_get_dataset_returns_detail(self): + fn = _find_endpoint("get_dataset") + + mock_db = MagicMock() + 
mock_db.get_dataset.return_value = _make_dataset() + mock_db.get_dataset_documents.return_value = [ + _make_dataset_doc(TEST_DOC_UUID_1, "train"), + _make_dataset_doc(TEST_DOC_UUID_2, "val"), + ] + + result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) + + assert result.dataset_id == TEST_DATASET_UUID + assert len(result.documents) == 2 + + def test_get_dataset_not_found(self): + fn = _find_endpoint("get_dataset") + + mock_db = MagicMock() + mock_db.get_dataset.return_value = None + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) + assert exc_info.value.status_code == 404 + + +class TestDeleteDatasetRoute: + """Tests for DELETE /admin/training/datasets/{dataset_id}.""" + + def test_delete_dataset(self): + fn = _find_endpoint("delete_dataset") + + mock_db = MagicMock() + mock_db.get_dataset.return_value = _make_dataset(dataset_path=None) + + result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db)) + + mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID) + assert result["message"] == "Dataset deleted" + + +class TestTrainFromDatasetRoute: + """Tests for POST /admin/training/datasets/{dataset_id}/train.""" + + def test_train_from_ready_dataset(self): + fn = _find_endpoint("train_from_dataset") + + mock_db = MagicMock() + mock_db.get_dataset.return_value = _make_dataset(status="ready") + mock_db.create_training_task.return_value = TEST_TASK_UUID + + request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig()) + + result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) + + assert result.task_id == TEST_TASK_UUID + assert result.status == TrainingStatus.PENDING + mock_db.create_training_task.assert_called_once() + + def test_train_from_building_dataset_fails(self): + fn = _find_endpoint("train_from_dataset") + + mock_db = MagicMock() + mock_db.get_dataset.return_value = _make_dataset(status="building") + + request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig()) + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db)) + assert exc_info.value.status_code == 400 diff --git a/tests/web/test_inference_api.py b/tests/web/test_inference_api.py index adc0d35..bc09f53 100644 --- a/tests/web/test_inference_api.py +++ b/tests/web/test_inference_api.py @@ -11,8 +11,8 @@ from fastapi.testclient import TestClient from PIL import Image import io -from src.web.app import create_app -from src.web.config import ModelConfig, StorageConfig, AppConfig +from inference.web.app import create_app +from inference.web.config import ModelConfig, StorageConfig, AppConfig @pytest.fixture @@ -87,8 +87,8 @@ class TestHealthEndpoint: class TestInferEndpoint: """Test /api/v1/infer endpoint.""" - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_infer_accepts_png_file( self, mock_yolo_detector, @@ -150,8 +150,8 @@ class TestInferEndpoint: assert response.status_code == 422 # Unprocessable Entity - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + 
@patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_infer_returns_cross_validation_if_available( self, mock_yolo_detector, @@ -210,8 +210,8 @@ class TestInferEndpoint: # This test documents that it should be added - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_infer_handles_processing_errors_gracefully( self, mock_yolo_detector, @@ -280,16 +280,16 @@ class TestInferenceServiceImports: This test will fail if there are ImportError issues like: - from ..inference.pipeline (wrong relative import) - - from src.web.inference (non-existent module) + - from inference.web.inference (non-existent module) It ensures the imports are correct before runtime. """ - from src.web.services.inference import InferenceService + from inference.web.services.inference import InferenceService # Import the modules that InferenceService tries to import - from src.inference.pipeline import InferencePipeline - from src.inference.yolo_detector import YOLODetector - from src.pdf.renderer import render_pdf_to_images + from inference.pipeline.pipeline import InferencePipeline + from inference.pipeline.yolo_detector import YOLODetector + from shared.pdf.renderer import render_pdf_to_images # If we got here, all imports work correctly assert InferencePipeline is not None diff --git a/tests/web/test_inference_service.py b/tests/web/test_inference_service.py index 4aef00b..f3d4d32 100644 --- a/tests/web/test_inference_service.py +++ b/tests/web/test_inference_service.py @@ -10,8 +10,8 @@ from unittest.mock import Mock, patch from PIL import Image import io -from src.web.services.inference import InferenceService -from src.web.config import ModelConfig, StorageConfig +from inference.web.services.inference import InferenceService +from inference.web.config import ModelConfig, StorageConfig @pytest.fixture @@ -72,8 +72,8 @@ class TestInferenceServiceInitialization: gpu_available = inference_service.gpu_available assert isinstance(gpu_available, bool) - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_initialize_imports_correctly( self, mock_yolo_detector, @@ -102,8 +102,8 @@ class TestInferenceServiceInitialization: mock_yolo_detector.assert_called_once() mock_pipeline.assert_called_once() - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_initialize_sets_up_pipeline( self, mock_yolo_detector, @@ -135,8 +135,8 @@ class TestInferenceServiceInitialization: enable_fallback=True, ) - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_initialize_idempotent( self, mock_yolo_detector, @@ -161,8 +161,8 @@ class TestInferenceServiceInitialization: class TestInferenceServiceProcessing: """Test inference processing methods.""" - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + 
@patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') @patch('ultralytics.YOLO') def test_process_image_basic_flow( self, @@ -197,8 +197,8 @@ class TestInferenceServiceProcessing: assert result.confidence == {"InvoiceNumber": 0.95} assert result.processing_time_ms > 0 - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') def test_process_image_handles_errors( self, mock_yolo_detector, @@ -228,9 +228,9 @@ class TestInferenceServiceProcessing: class TestInferenceServicePDFRendering: """Test PDF rendering imports.""" - @patch('src.inference.pipeline.InferencePipeline') - @patch('src.inference.yolo_detector.YOLODetector') - @patch('src.pdf.renderer.render_pdf_to_images') + @patch('inference.pipeline.pipeline.InferencePipeline') + @patch('inference.pipeline.yolo_detector.YOLODetector') + @patch('shared.pdf.renderer.render_pdf_to_images') @patch('ultralytics.YOLO') def test_pdf_visualization_imports_correctly( self, @@ -245,7 +245,7 @@ class TestInferenceServicePDFRendering: Test that _save_pdf_visualization imports render_pdf_to_images correctly. This catches the import error we had with: - from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct) + from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct) """ # Setup mocks mock_detector_instance = Mock() diff --git a/tests/web/test_rate_limiter.py b/tests/web/test_rate_limiter.py index 5b191ff..d02eb7b 100644 --- a/tests/web/test_rate_limiter.py +++ b/tests/web/test_rate_limiter.py @@ -8,8 +8,8 @@ from unittest.mock import MagicMock import pytest -from src.data.async_request_db import ApiKeyConfig -from src.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus +from inference.data.async_request_db import ApiKeyConfig +from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus class TestRateLimiter: diff --git a/tests/web/test_training_phase4.py b/tests/web/test_training_phase4.py index 41ae02d..1162985 100644 --- a/tests/web/test_training_phase4.py +++ b/tests/web/test_training_phase4.py @@ -9,8 +9,8 @@ from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient -from src.web.api.v1.admin.training import create_training_router -from src.web.core.auth import validate_admin_token, get_admin_db +from inference.web.api.v1.admin.training import create_training_router +from inference.web.core.auth import validate_admin_token, get_admin_db class MockTrainingTask: