This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -90,7 +90,23 @@
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")",
"Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")",
"Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")",
"Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")",
"Bash(npm run dev:*)",
"Bash(ping:*)",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")",
"Bash(git checkout:*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")"
],
"deny": [],
"ask": [],

BIN
.coverage

Binary file not shown.

View File

@@ -8,6 +8,23 @@ DB_NAME=docmaster
DB_USER=docmaster
DB_PASSWORD=your_password_here
# Storage Configuration
# Backend type: local, azure_blob, or s3
# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.)
STORAGE_BACKEND=local
STORAGE_BASE_PATH=./data
# Azure Blob Storage (when STORAGE_BACKEND=azure_blob)
# AZURE_STORAGE_CONNECTION_STRING=your_connection_string
# AZURE_STORAGE_CONTAINER=documents
# AWS S3 Storage (when STORAGE_BACKEND=s3)
# AWS_S3_BUCKET=your_bucket_name
# AWS_REGION=us-east-1
# AWS_ACCESS_KEY_ID=your_access_key
# AWS_SECRET_ACCESS_KEY=your_secret_key
# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO
# Model Configuration (optional)
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
# CONFIDENCE_THRESHOLD=0.5

README.md
View File

@@ -7,8 +7,9 @@
本项目实现了一个完整的发票字段自动提取流程:
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
4. **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
### 架构
@@ -16,15 +17,17 @@
```
packages/
├── shared/      # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
├── training/    # 训练服务 (GPU, 按需启动)
└── inference/   # 推理服务 (常驻运行)
frontend/        # React 前端 (Vite + TypeScript + TailwindCSS)
```
| 服务 | 部署目标 | GPU | 生命周期 |
|------|---------|-----|---------|
| **Frontend** | Vercel / Nginx | | 常驻 |
| **Inference** | Azure App Service / AWS | 可选 | 常驻 7x24 |
| **Training** | Azure ACI / AWS ECS | 必需 | 按需启动/销毁 |
两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。
@@ -34,7 +37,8 @@ packages/
|------|------|
| **已标注文档** | 9,738 (9,709 成功) |
| **总体字段匹配率** | 94.8% (82,604/87,121) |
| **测试** | 1,601 passed |
| **测试覆盖率** | 28% |
| **模型 mAP@0.5** | 93.5% |
**各字段匹配率:**
@@ -97,6 +101,9 @@ invoice-master-poc-v2/
│   │   ├── ocr/          # PaddleOCR 封装 + 机器码解析
│   │   ├── normalize/    # 字段规范化 (10 种 normalizer)
│   │   ├── matcher/      # 字段匹配 (精确/子串/模糊)
│ │ ├── storage/ # 存储抽象层 (Local/Azure/S3)
│ │ ├── training/ # 共享训练组件 (YOLOTrainer)
│ │ ├── augmentation/ # 数据增强 (DatasetAugmenter)
│   │   ├── utils/        # 工具 (验证, 清理, 模糊匹配)
│   │   ├── data/         # DocumentDB, CSVLoader
│   │   ├── config.py     # 全局配置 (数据库, 路径, DPI)
@@ -129,12 +136,29 @@ invoice-master-poc-v2/
│   ├── data/             # AdminDB, AsyncRequestDB, Models
│   └── azure/            # ACI 训练触发器
├── frontend/             # React 前端 (Vite + TypeScript + TailwindCSS)
│   ├── src/
│   │   ├── api/          # API 客户端 (axios + react-query)
│   │   ├── components/   # UI 组件
│   │   │   ├── Dashboard.tsx     # 文档管理面板
│   │   │   ├── Training.tsx      # 训练管理 (数据集/任务)
│   │   │   ├── Models.tsx        # 模型版本管理
│   │   │   ├── DatasetDetail.tsx # 数据集详情
│   │   │   └── InferenceDemo.tsx # 推理演示
│   │   └── hooks/        # React Query hooks
│   └── package.json
├── migrations/           # 数据库迁移 (SQL)
│   ├── 003_training_tasks.sql
│   ├── 004_training_datasets.sql
│   ├── 005_add_group_key.sql
│   ├── 006_model_versions.sql
│   ├── 007_training_tasks_extra_columns.sql
│   ├── 008_fix_model_versions_fk.sql
│   ├── 009_add_document_category.sql
│   └── 010_add_dataset_training_status.sql
├── tests/                # 测试 (1,601 tests)
├── docker-compose.yml    # 本地开发 (postgres + inference + training)
├── run_server.py         # 快捷启动脚本
└── runs/train/           # 训练输出 (weights, curves)
@@ -270,9 +294,32 @@ Inference API PostgreSQL Training (ACI)
| POST | `/api/v1/admin/documents/upload` | 上传 PDF |
| GET | `/api/v1/admin/documents/{id}` | 文档详情 |
| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 |
| PATCH | `/api/v1/admin/documents/{id}/category` | 更新文档分类 |
| GET | `/api/v1/admin/documents/categories` | 获取分类列表 |
| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 |
**Training API:**
| 方法 | 端点 | 描述 |
|------|------|------|
| POST | `/api/v1/admin/training/datasets` | 创建数据集 |
| GET | `/api/v1/admin/training/datasets` | 数据集列表 |
| GET | `/api/v1/admin/training/datasets/{id}` | 数据集详情 |
| DELETE | `/api/v1/admin/training/datasets/{id}` | 删除数据集 |
| POST | `/api/v1/admin/training/tasks` | 创建训练任务 |
| GET | `/api/v1/admin/training/tasks` | 任务列表 |
| GET | `/api/v1/admin/training/tasks/{id}` | 任务详情 |
| GET | `/api/v1/admin/training/tasks/{id}/logs` | 训练日志 |
**Model Versions API:**
| 方法 | 端点 | 描述 |
|------|------|------|
| GET | `/api/v1/admin/models` | 模型版本列表 |
| GET | `/api/v1/admin/models/{id}` | 模型详情 |
| POST | `/api/v1/admin/models/{id}/activate` | 激活模型 |
| POST | `/api/v1/admin/models/{id}/archive` | 归档模型 |
| DELETE | `/api/v1/admin/models/{id}` | 删除模型 |
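下面给出一个调用模型版本接口的最小示意(Python requests;base URL、鉴权头与返回字段均为假设,实际以部署环境和接口定义为准):
```python
import requests

BASE_URL = "http://localhost:8000"                   # 假设:本地运行的推理服务
HEADERS = {"Authorization": "Bearer <admin_token>"}  # 假设:管理员令牌鉴权方式

# 列出模型版本(返回结构参照前端 TrainingModelsResponse / ModelVersionItem 类型)
models = requests.get(f"{BASE_URL}/api/v1/admin/models", headers=HEADERS).json().get("models", [])

# 激活列表中的第一个版本
if models:
    version_id = models[0]["version_id"]
    requests.post(f"{BASE_URL}/api/v1/admin/models/{version_id}/activate", headers=HEADERS)
```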
## Python API
@@ -332,8 +379,41 @@ print(f"Customer Number: {result}") # "UMJ 436-R"
| 数据库 | 用途 | 存储内容 |
|--------|------|----------|
| **PostgreSQL** | 主数据库 | 文档、标注、训练任务、数据集、模型版本 |
| **SQLite** (AdminDB) | Web 应用 | 文档管理, 标注编辑, 用户认证 |
### 主要表
| 表名 | 说明 |
|------|------|
| `admin_documents` | 文档管理 (PDF 元数据, 状态, 分类) |
| `admin_annotations` | 标注数据 (YOLO 格式边界框) |
| `training_tasks` | 训练任务 (状态, 配置, 指标) |
| `training_datasets` | 数据集 (train/val/test 分割) |
| `dataset_documents` | 数据集-文档关联 |
| `model_versions` | 模型版本管理 (激活/归档) |
| `admin_tokens` | 管理员认证令牌 |
| `async_requests` | 异步推理请求 |
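排查问题时也可以直接查询这些表,例如按状态统计训练任务(示意代码;`status` 列名为假设,连接参数取自环境变量):
```python
import os
import psycopg2

conn = psycopg2.connect(
    host=os.getenv("DB_HOST", "localhost"),
    port=os.getenv("DB_PORT", "5432"),
    dbname=os.getenv("DB_NAME", "docmaster"),
    user=os.getenv("DB_USER", "docmaster"),
    password=os.environ["DB_PASSWORD"],
)
with conn, conn.cursor() as cur:
    # 按状态统计 training_tasks(状态取值见下方“训练状态”一节)
    cur.execute("SELECT status, COUNT(*) FROM training_tasks GROUP BY status")
    for status, count in cur.fetchall():
        print(status, count)
conn.close()
```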
### 数据集状态
| 状态 | 说明 |
|------|------|
| `building` | 正在构建数据集 |
| `ready` | 数据集就绪,可开始训练 |
| `trained` | 已完成训练 |
| `failed` | 构建失败 |
| `archived` | 已归档 |
### 训练状态
| 状态 | 说明 |
|------|------|
| `pending` | 等待执行 |
| `scheduled` | 已计划 |
| `running` | 正在训练 |
| `completed` | 训练完成 |
| `failed` | 训练失败 |
| `cancelled` | 已取消 |
## 测试
@@ -347,8 +427,114 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
| 指标 | 数值 |
|------|------|
| **测试总数** | 1,601 |
| **通过率** | 100% |
| **覆盖率** | 28% |
## 存储抽象层
统一的文件存储接口,支持多后端切换:
| 后端 | 用途 | 安装 |
|------|------|------|
| **Local** | 本地开发/测试 | 默认 |
| **Azure Blob** | Azure 云部署 | `pip install -e "packages/shared[azure]"` |
| **AWS S3** | AWS 云部署 | `pip install -e "packages/shared[s3]"` |
### 配置文件 (storage.yaml)
```yaml
backend: ${STORAGE_BACKEND:-local}
presigned_url_expiry: 3600
local:
base_path: ${STORAGE_BASE_PATH:-./data/storage}
azure:
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
s3:
bucket_name: ${AWS_S3_BUCKET}
region_name: ${AWS_REGION:-us-east-1}
```
### 使用示例
```python
from pathlib import Path
from shared.storage import get_storage_backend
# 从配置文件加载
storage = get_storage_backend("storage.yaml")
# 上传文件
storage.upload(Path("local.pdf"), "documents/invoice.pdf")
# 获取预签名 URL (前端访问)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=3600)
```
### 环境变量
| 变量 | 后端 | 说明 |
|------|------|------|
| `STORAGE_BACKEND` | 全部 | `local`, `azure_blob`, `s3` |
| `STORAGE_BASE_PATH` | Local | 本地存储路径 |
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | 连接字符串 |
| `AZURE_STORAGE_CONTAINER` | Azure | 容器名称 |
| `AWS_S3_BUCKET` | S3 | 存储桶名称 |
| `AWS_REGION` | S3 | 区域 (默认: us-east-1) |
## 数据增强
训练时支持多种数据增强策略:
| 增强类型 | 说明 |
|----------|------|
| `perspective_warp` | 透视变换 (模拟扫描角度) |
| `wrinkle` | 皱纹效果 |
| `edge_damage` | 边缘损坏 |
| `stain` | 污渍效果 |
| `lighting_variation` | 光照变化 |
| `shadow` | 阴影效果 |
| `gaussian_blur` | 高斯模糊 |
| `motion_blur` | 运动模糊 |
| `gaussian_noise` | 高斯噪声 |
| `salt_pepper` | 椒盐噪声 |
| `paper_texture` | 纸张纹理 |
| `scanner_artifacts` | 扫描伪影 |
增强配置示例:
```json
{
"augmentation": {
"gaussian_blur": { "enabled": true, "kernel_size": 5 },
"perspective_warp": { "enabled": true, "intensity": 0.1 }
},
"augmentation_multiplier": 2
}
```
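该增强配置通常随训练任务一起提交。下面是一个创建训练任务的示意请求(端点见上文 Training API;`dataset_id`、`epochs` 等字段名为假设,以实际接口为准):
```python
import requests

BASE_URL = "http://localhost:8000"   # 假设:本地推理服务
payload = {
    "dataset_id": "<dataset_id>",    # 假设:关联一个状态为 ready 的数据集
    "epochs": 100,
    "batch_size": 16,
    "augmentation": {
        "gaussian_blur": {"enabled": True, "kernel_size": 5},
        "perspective_warp": {"enabled": True, "intensity": 0.1},
    },
    "augmentation_multiplier": 2,
}
resp = requests.post(f"{BASE_URL}/api/v1/admin/training/tasks", json=payload)
print(resp.json())
```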
## 前端功能
React 前端提供以下功能模块:
| 模块 | 功能 |
|------|------|
| **Dashboard** | 文档列表、上传、标注状态管理、分类筛选 |
| **Training** | 数据集创建/管理、训练任务配置、增强设置 |
| **Models** | 模型版本管理、激活/归档、指标查看 |
| **Inference Demo** | 实时推理演示、结果可视化 |
### 启动前端
```bash
cd frontend
npm install
npm run dev
# 访问 http://localhost:5173
```
## 技术栈
@@ -357,10 +543,27 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
| **目标检测** | YOLOv11 (Ultralytics) |
| **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) |
| **PDF 处理** | PyMuPDF (fitz) |
| **数据库** | PostgreSQL + SQLModel |
| **Web 框架** | FastAPI + Uvicorn |
| **前端** | React + TypeScript + Vite + TailwindCSS |
| **状态管理** | React Query (TanStack Query) |
| **深度学习** | PyTorch + CUDA 12.x |
| **部署** | Docker + Azure/AWS (训练) / App Service (推理) |
## 环境变量
| 变量 | 必需 | 说明 |
|------|------|------|
| `DB_PASSWORD` | 是 | PostgreSQL 密码 |
| `DB_HOST` | 否 | 数据库主机 (默认: localhost) |
| `DB_PORT` | 否 | 数据库端口 (默认: 5432) |
| `DB_NAME` | 否 | 数据库名 (默认: docmaster) |
| `DB_USER` | 否 | 数据库用户 (默认: docmaster) |
| `STORAGE_BASE_PATH` | 否 | 存储路径 (默认: ~/invoice-data/data) |
| `MODEL_PATH` | 否 | 模型路径 |
| `CONFIDENCE_THRESHOLD` | 否 | 置信度阈值 (默认: 0.5) |
| `SERVER_HOST` | 否 | 服务器主机 (默认: 0.0.0.0) |
| `SERVER_PORT` | 否 | 服务器端口 (默认: 8000) |
## 许可证

View File

@@ -0,0 +1,772 @@
# AWS 部署方案完整指南
## 目录
- [核心问题](#核心问题)
- [存储方案](#存储方案)
- [训练方案](#训练方案)
- [推理方案](#推理方案)
- [价格对比](#价格对比)
- [推荐架构](#推荐架构)
- [实施步骤](#实施步骤)
- [AWS vs Azure 对比](#aws-vs-azure-对比)
---
## 核心问题
| 问题 | 答案 |
|------|------|
| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 |
| 能实时从 S3 读取训练吗? | 可以,SageMaker 支持 Pipe Mode 流式读取 |
| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone |
| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 |
| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda |
| 推理服务用什么? | Lambda (Serverless) 或 ECS/Fargate (容器) |
---
## 存储方案
### Amazon S3(推荐)
S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。
```bash
# 创建 S3 桶
aws s3 mb s3://invoice-training-data --region us-east-1
# 上传训练数据
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
# 创建目录结构
aws s3api put-object --bucket invoice-training-data --key datasets/
aws s3api put-object --bucket invoice-training-data --key models/
```
### Mountpoint for Amazon S3
AWS 官方的 S3 挂载客户端,性能优于 s3fs。
```bash
# 安装 Mountpoint
wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
sudo dpkg -i mount-s3.deb
# 挂载 S3
mkdir -p /mnt/s3-data
mount-s3 invoice-training-data /mnt/s3-data --region us-east-1
# 配置缓存(推荐)
mount-s3 invoice-training-data /mnt/s3-data \
--region us-east-1 \
--cache /tmp/s3-cache \
--metadata-ttl 60
```
### 本地开发挂载
**Linux/Mac (s3fs-fuse):**
```bash
# 安装
sudo apt-get install s3fs
# 配置凭证
echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs
chmod 600 ~/.passwd-s3fs
# 挂载
s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs
```
**Windows (Rclone):**
```powershell
# 安装
winget install Rclone.Rclone
# 配置
rclone config # 选择 s3
# 挂载
rclone mount aws:invoice-training-data Z: --vfs-cache-mode full
```
### 存储费用
| 层级 | 价格 | 适用场景 |
|------|------|---------|
| S3 Standard | $0.023/GB/月 | 频繁访问 |
| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 |
| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 |
| S3 Glacier | $0.004/GB/月 | 长期存档 |
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月**
### SageMaker 数据输入模式
| 模式 | 说明 | 适用场景 |
|------|------|---------|
| File Mode | 下载到本地再训练 | 小数据集 |
| Pipe Mode | 流式读取,不占本地空间 | 大数据集 |
| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 |
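以 SageMaker Python SDK 为例,输入模式通过 `TrainingInput` 的 `input_mode` 参数指定(示意代码,S3 路径沿用本文示例):
```python
from sagemaker.inputs import TrainingInput

train_input = TrainingInput(
    s3_data="s3://invoice-training-data/datasets/train/",
    input_mode="FastFile",   # 可选 "File" / "Pipe" / "FastFile"
)
# 之后传给 estimator.fit({"training": train_input, ...})
```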
---
## 训练方案
### 方案总览
| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 |
|------|---------|---------|--------|----------|
| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 |
| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 |
| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 |
### EC2 vs SageMaker
| 特性 | EC2 | SageMaker |
|------|-----|-----------|
| 本质 | 虚拟机 | 托管 ML 平台 |
| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) |
| 管理开销 | 需自己配置 | 全托管 |
| Spot 折扣 | 最高 90% | 最高 90% |
| 实验跟踪 | 无 | 内置 |
| 自动关机 | 无 | 任务完成自动停止 |
### GPU 实例价格 (2025 年 6 月降价后)
| 实例 | GPU | 显存 | On-Demand | Spot 价格 |
|------|-----|------|-----------|----------|
| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr |
| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr |
| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr |
| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr |
| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr |
**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。
### Spot 实例
```bash
# EC2 Spot 请求
aws ec2 request-spot-instances \
--instance-count 1 \
--type "one-time" \
--launch-specification '{
"ImageId": "ami-0123456789abcdef0",
"InstanceType": "p3.2xlarge",
"KeyName": "my-key"
}'
```
### SageMaker Managed Spot Training
```python
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
entry_point="train.py",
source_dir="./src",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.p3.2xlarge",
framework_version="2.0",
py_version="py310",
# 启用 Spot 实例
use_spot_instances=True,
max_run=3600, # 最长运行 1 小时
max_wait=7200, # 最长等待 2 小时
# 检查点配置(Spot 中断恢复)
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
checkpoint_local_path="/opt/ml/checkpoints",
hyperparameters={
"epochs": 100,
"batch-size": 16,
}
)
estimator.fit({
"training": "s3://invoice-training-data/datasets/train/",
"validation": "s3://invoice-training-data/datasets/val/"
})
```
---
## 推理方案
### 方案对比
| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 |
|------|---------|--------|--------|------|---------|
| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 |
| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 |
| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 |
| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 |
| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 |
| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 |
### 推荐方案 1: AWS Lambda (低流量)
对于 YOLO CPU 推理,Lambda 最经济:
```python
# lambda_function.py
import json
import boto3
from ultralytics import YOLO
# 模型在 Lambda Layer 或 /tmp 加载
model = None
def load_model():
global model
if model is None:
# 从 S3 下载模型到 /tmp
s3 = boto3.client('s3')
s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt')
model = YOLO('/tmp/best.pt')
return model
def lambda_handler(event, context):
model = load_model()
# 从 S3 获取图片
s3 = boto3.client('s3')
bucket = event['bucket']
key = event['key']
local_path = f'/tmp/{key.split("/")[-1]}'
s3.download_file(bucket, key, local_path)
# 执行推理
results = model.predict(local_path, conf=0.5)
return {
'statusCode': 200,
'body': json.dumps({
'fields': extract_fields(results),
'confidence': get_confidence(results)
})
}
```
**Lambda 配置:**
```yaml
# serverless.yml
service: invoice-inference
provider:
name: aws
runtime: python3.11
timeout: 30
memorySize: 4096 # 4GB 内存
functions:
infer:
handler: lambda_function.lambda_handler
events:
- http:
path: /infer
method: post
layers:
- arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
```
### 推荐方案 2: ECS Fargate (中流量)
**task-definition.json:**
```json
{
"family": "invoice-inference",
"networkMode": "awsvpc",
"requiresCompatibilities": ["FARGATE"],
"cpu": "2048",
"memory": "4096",
"containerDefinitions": [
{
"name": "inference",
"image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest",
"portMappings": [
{
"containerPort": 8000,
"protocol": "tcp"
}
],
"environment": [
{"name": "MODEL_PATH", "value": "/app/models/best.pt"}
],
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/ecs/invoice-inference",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "ecs"
}
}
}
]
}
```
**Auto Scaling 配置:**
```bash
# 创建 Auto Scaling Target
aws application-autoscaling register-scalable-target \
--service-namespace ecs \
--resource-id service/invoice-cluster/invoice-service \
--scalable-dimension ecs:service:DesiredCount \
--min-capacity 1 \
--max-capacity 10
# 基于 CPU 使用率扩缩容
aws application-autoscaling put-scaling-policy \
--service-namespace ecs \
--resource-id service/invoice-cluster/invoice-service \
--scalable-dimension ecs:service:DesiredCount \
--policy-name cpu-scaling \
--policy-type TargetTrackingScaling \
--target-tracking-scaling-policy-configuration '{
"TargetValue": 70,
"PredefinedMetricSpecification": {
"PredefinedMetricType": "ECSServiceAverageCPUUtilization"
},
"ScaleOutCooldown": 60,
"ScaleInCooldown": 120
}'
```
### 方案 3: SageMaker Serverless Inference
```python
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.pytorch import PyTorchModel
model = PyTorchModel(
model_data="s3://invoice-models/model.tar.gz",
role="arn:aws:iam::123456789012:role/SageMakerRole",
entry_point="inference.py",
framework_version="2.0",
py_version="py310"
)
serverless_config = ServerlessInferenceConfig(
memory_size_in_mb=4096,
max_concurrency=10
)
predictor = model.deploy(
serverless_inference_config=serverless_config,
endpoint_name="invoice-inference-serverless"
)
```
### 推理性能对比
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|------|------------|---------|---------|
| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) |
| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 |
| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 |
| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 |
---
## 价格对比
### 训练成本对比(假设每天训练 2 小时)
| 方案 | 计算方式 | 月费 |
|------|---------|------|
| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 |
| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 |
| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 |
### 本项目完整成本估算
| 组件 | 推荐方案 | 月费 |
|------|---------|------|
| 数据存储 | S3 Standard (5GB) | ~$0.12 |
| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 |
| 推理服务 | Lambda (10K 请求/月) | ~$15 |
| 推理服务 (替代) | ECS Fargate | ~$30 |
| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 |
| ECR (镜像存储) | 基本使用 | ~$1 |
| **总计 (Lambda)** | | **~$35/月** + 训练费 |
| **总计 (Fargate)** | | **~$50/月** + 训练费 |
---
## 推荐架构
### 整体架构图
```
┌─────────────────────────────────────┐
│ Amazon S3 │
│ ├── training-images/ │
│ ├── datasets/ │
│ ├── models/ │
│ └── checkpoints/ │
└─────────────────┬───────────────────┘
┌─────────────────────────────────┼─────────────────────────────────┐
│ │ │
▼ ▼ ▼
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
│ 推理服务 │ │ 训练服务 │ │ API Gateway │
│ │ │ │ │ │
│ 方案 A: Lambda │ │ SageMaker │ │ REST API │
│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │
│ │ │ ~$2-5/次训练 │ │ │
│ 方案 B: ECS Fargate │ │ │ │ │
│ ~$30/月 │ │ - 自动启动 │ │ │
│ │ │ - 训练完成自动停止 │ │ │
│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │
│ │ FastAPI + YOLO │ │ │ │ │ │
│ │ CPU 推理 │ │ │ │ │ │
│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘
└───────────┬───────────┘ │
│ │
└───────────────────────────────┼───────────────────────────────────────────┘
┌───────────────────────┐
│ Amazon RDS │
│ PostgreSQL │
│ db.t3.micro │
│ ~$15/月 │
└───────────────────────┘
```
### Lambda 推理配置
```yaml
# SAM template
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Resources:
InferenceFunction:
Type: AWS::Serverless::Function
Properties:
Handler: app.lambda_handler
Runtime: python3.11
MemorySize: 4096
Timeout: 30
Environment:
Variables:
MODEL_BUCKET: invoice-models
MODEL_KEY: best.pt
Policies:
- S3ReadPolicy:
BucketName: invoice-models
- S3ReadPolicy:
BucketName: invoice-uploads
Events:
InferApi:
Type: Api
Properties:
Path: /infer
Method: post
```
### SageMaker 训练配置
```python
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
entry_point="train.py",
source_dir="./src",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.g4dn.xlarge", # T4 GPU
framework_version="2.0",
py_version="py310",
# Spot 实例配置
use_spot_instances=True,
max_run=7200,
max_wait=14400,
# 检查点
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
hyperparameters={
"epochs": 100,
"batch-size": 16,
"model": "yolo11n.pt"
}
)
```
---
## 实施步骤
### 阶段 1: 存储设置
```bash
# 创建 S3 桶
aws s3 mb s3://invoice-training-data --region us-east-1
aws s3 mb s3://invoice-models --region us-east-1
# 上传训练数据
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
# 配置生命周期(可选,自动转冷存储)
aws s3api put-bucket-lifecycle-configuration \
--bucket invoice-training-data \
--lifecycle-configuration '{
"Rules": [{
"ID": "MoveToIA",
"Status": "Enabled",
"Transitions": [{
"Days": 30,
"StorageClass": "STANDARD_IA"
}]
}]
}'
```
### 阶段 2: 数据库设置
```bash
# 创建 RDS PostgreSQL
aws rds create-db-instance \
--db-instance-identifier invoice-db \
--db-instance-class db.t3.micro \
--engine postgres \
--engine-version 15 \
--master-username docmaster \
--master-user-password YOUR_PASSWORD \
--allocated-storage 20
# 配置安全组
aws ec2 authorize-security-group-ingress \
--group-id sg-xxx \
--protocol tcp \
--port 5432 \
--source-group sg-yyy
```
### 阶段 3: 推理服务部署
**方案 A: Lambda**
```bash
# 创建 Lambda Layer (依赖)
cd lambda-layer
pip install ultralytics opencv-python-headless -t python/
zip -r layer.zip python/
aws lambda publish-layer-version \
--layer-name yolo-deps \
--zip-file fileb://layer.zip \
--compatible-runtimes python3.11
# 部署 Lambda 函数
cd ../lambda
zip function.zip lambda_function.py
aws lambda create-function \
--function-name invoice-inference \
--runtime python3.11 \
--handler lambda_function.lambda_handler \
--role arn:aws:iam::123456789012:role/LambdaRole \
--zip-file fileb://function.zip \
--memory-size 4096 \
--timeout 30 \
--layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
# 创建 API Gateway
aws apigatewayv2 create-api \
--name invoice-api \
--protocol-type HTTP \
--target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference
```
**方案 B: ECS Fargate**
```bash
# 创建 ECR 仓库
aws ecr create-repository --repository-name invoice-inference
# 构建并推送镜像
aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com
docker build -t invoice-inference .
docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
# 创建 ECS 集群
aws ecs create-cluster --cluster-name invoice-cluster
# 注册任务定义
aws ecs register-task-definition --cli-input-json file://task-definition.json
# 创建服务
aws ecs create-service \
--cluster invoice-cluster \
--service-name invoice-service \
--task-definition invoice-inference \
--desired-count 1 \
--launch-type FARGATE \
--network-configuration '{
"awsvpcConfiguration": {
"subnets": ["subnet-xxx"],
"securityGroups": ["sg-xxx"],
"assignPublicIp": "ENABLED"
}
}'
```
### 阶段 4: 训练服务设置
```python
# setup_sagemaker.py
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
# 创建 SageMaker 执行角色
iam = boto3.client('iam')
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
# 配置训练任务
estimator = PyTorch(
entry_point="train.py",
source_dir="./src/training",
role=role_arn,
instance_count=1,
instance_type="ml.g4dn.xlarge",
framework_version="2.0",
py_version="py310",
use_spot_instances=True,
max_run=7200,
max_wait=14400,
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
)
# 保存配置供后续使用
estimator.save("training_config.json")
```
### 阶段 5: 集成训练触发 API
```python
# lambda_trigger_training.py
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
def lambda_handler(event, context):
"""触发 SageMaker 训练任务"""
epochs = event.get('epochs', 100)
estimator = PyTorch(
entry_point="train.py",
source_dir="s3://invoice-training-data/code/",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.g4dn.xlarge",
framework_version="2.0",
py_version="py310",
use_spot_instances=True,
max_run=7200,
max_wait=14400,
hyperparameters={
"epochs": epochs,
"batch-size": 16,
}
)
estimator.fit(
inputs={
"training": "s3://invoice-training-data/datasets/train/",
"validation": "s3://invoice-training-data/datasets/val/"
},
wait=False # 异步执行
)
return {
'statusCode': 200,
'body': {
'training_job_name': estimator.latest_training_job.name,
'status': 'Started'
}
}
```
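训练任务以 `wait=False` 异步启动后,可以用 boto3 轮询其状态(示意代码,`job_name` 即上面返回的 `training_job_name`):
```python
import boto3

sm = boto3.client("sagemaker", region_name="us-east-1")

def get_training_status(job_name: str) -> dict:
    """查询 SageMaker 训练任务的当前状态与模型产物位置"""
    desc = sm.describe_training_job(TrainingJobName=job_name)
    return {
        "status": desc["TrainingJobStatus"],   # InProgress / Completed / Failed / Stopped
        "secondary_status": desc.get("SecondaryStatus"),
        "model_artifacts": desc.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
    }
```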
---
## AWS vs Azure 对比
### 服务对应关系
| 功能 | AWS | Azure |
|------|-----|-------|
| 对象存储 | S3 | Blob Storage |
| 挂载工具 | Mountpoint for S3 | BlobFuse2 |
| ML 平台 | SageMaker | Azure ML |
| 容器服务 | ECS/Fargate | Container Apps |
| Serverless | Lambda | Functions |
| GPU VM | EC2 P3/G4dn | NC/ND 系列 |
| 容器注册 | ECR | ACR |
| 数据库 | RDS PostgreSQL | PostgreSQL Flexible |
### 价格对比
| 组件 | AWS | Azure |
|------|-----|-------|
| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 |
| 数据库 | ~$15/月 | ~$25/月 |
| 推理 (Serverless) | ~$15/月 | ~$30/月 |
| 推理 (容器) | ~$30/月 | ~$30/月 |
| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 |
| **总计** | **~$35-50/月** | **~$65/月** |
### 优劣对比
| 方面 | AWS 优势 | Azure 优势 |
|------|---------|-----------|
| 价格 | Lambda 更便宜 | GPU Spot 更便宜 |
| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 |
| Serverless GPU | 无原生支持 | Container Apps GPU |
| 文档 | 更丰富 | 中文文档更好 |
| 生态 | 更大 | Office 365 集成 |
---
## 总结
### 推荐配置
| 组件 | 推荐方案 | 月费估算 |
|------|---------|---------|
| 数据存储 | S3 Standard | ~$0.12 |
| 数据库 | RDS db.t3.micro | ~$15 |
| 推理服务 | Lambda 4GB | ~$15 |
| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 |
| ECR | 基本使用 | ~$1 |
| **总计** | | **~$35/月** + 训练费 |
### 关键决策
| 场景 | 选择 |
|------|------|
| 最低成本 | Lambda + SageMaker Spot |
| 稳定推理 | ECS Fargate |
| GPU 推理 | ECS + EC2 GPU |
| MLOps 集成 | SageMaker 全家桶 |
### 注意事项
1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决
2. **Spot 中断**: 配置检查点SageMaker 自动恢复
3. **S3 传输**: 同区域免费,跨区域收费
4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2
5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本
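针对第 1 条的冷启动问题,可以为已发布的版本或别名配置 Provisioned Concurrency(示意代码,函数名沿用本文示例,别名为假设):
```python
import boto3

lam = boto3.client("lambda", region_name="us-east-1")

# 注意:Provisioned Concurrency 只能配置在已发布的版本或别名上,不能用于 $LATEST
lam.put_provisioned_concurrency_config(
    FunctionName="invoice-inference",
    Qualifier="prod",                       # 假设:已创建名为 prod 的别名
    ProvisionedConcurrentExecutions=2,
)
```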

View File

@@ -0,0 +1,567 @@
# Azure 部署方案完整指南
## 目录
- [核心问题](#核心问题)
- [存储方案](#存储方案)
- [训练方案](#训练方案)
- [推理方案](#推理方案)
- [价格对比](#价格对比)
- [推荐架构](#推荐架构)
- [实施步骤](#实施步骤)
---
## 核心问题
| 问题 | 答案 |
|------|------|
| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 |
| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 |
| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) |
| VM 空闲时收费吗? | 收费,只要开机就按小时计费 |
| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster |
| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU |
---
## 存储方案
### Azure Blob Storage + BlobFuse2(推荐)
```bash
# 安装 BlobFuse2
sudo apt-get install blobfuse2
# 配置文件
cat > ~/blobfuse-config.yaml << 'EOF'
logging:
type: syslog
level: log_warning
components:
- libfuse
- file_cache
- azstorage
file_cache:
path: /tmp/blobfuse2
timeout-sec: 120
max-size-mb: 4096
azstorage:
type: block
account-name: YOUR_ACCOUNT
account-key: YOUR_KEY
container: training-images
EOF
# 挂载
mkdir -p /mnt/azure-blob
blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml
```
### 本地开发(Windows)
```powershell
# 安装
winget install WinFsp.WinFsp
winget install Rclone.Rclone
# 配置
rclone config # 选择 azureblob
# 挂载为 Z: 盘
rclone mount azure:training-images Z: --vfs-cache-mode full
```
### 存储费用
| 层级 | 价格 | 适用场景 |
|------|------|---------|
| Hot | $0.018/GB/月 | 频繁访问 |
| Cool | $0.01/GB/月 | 偶尔访问 |
| Archive | $0.002/GB/月 | 长期存档 |
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月**
---
## 训练方案
### 方案总览
| 方案 | 适用场景 | 空闲费用 | 复杂度 |
|------|---------|---------|--------|
| Azure VM | 简单直接 | 24/7 收费 | 低 |
| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 |
| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 |
| Container Apps GPU | Serverless | 自动缩到 0 | 中 |
### Azure VM vs Azure ML
| 特性 | Azure VM | Azure ML |
|------|----------|----------|
| 本质 | 虚拟机 | 托管 ML 平台 |
| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) |
| 附加费用 | ~$5/月 | ~$20-30/月 |
| 实验跟踪 | 无 | 内置 |
| 自动扩缩 | 无 | 支持 min=0 |
| 适用人群 | DevOps | 数据科学家 |
### Azure ML 附加费用明细
| 服务 | 用途 | 费用 |
|------|------|------|
| Container Registry | Docker 镜像 | ~$5-20/月 |
| Blob Storage | 日志、模型 | ~$0.10/月 |
| Application Insights | 监控 | ~$0-10/月 |
| Key Vault | 密钥管理 | <$1/月 |
### Spot 实例
Azure VM 与 Azure ML 均支持 Spot/低优先级实例,最高可节省 90%。
| 类型 | 正常价格 | Spot 价格 | 节省 |
|------|---------|----------|------|
| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% |
| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% |
### GPU 实例价格
| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 |
|------|-----|------|---------|----------|
| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 |
| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 |
| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 |
| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 |
---
## 推理方案
### 方案对比
| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 |
|------|---------|--------|------|---------|
| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) |
| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 |
| Azure App Service | 否 | 手动/自动 | ~$50/月 | 简单部署 |
| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 |
| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 |
### 推荐: Container Apps (CPU)
对于 YOLO 推理,**CPU 足够**,不需要 GPU:
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
- 比 GPU 便宜很多,适合中低流量
```yaml
# Container Apps 配置
name: invoice-inference
image: myacr.azurecr.io/invoice-inference:v1
resources:
cpu: 2.0
memory: 4Gi
scale:
minReplicas: 1 # 最少 1 个实例保持响应
maxReplicas: 10 # 最多扩展到 10 个
rules:
- name: http-scaling
http:
metadata:
concurrentRequests: "50" # 每实例 50 并发时扩容
```
### 推理服务代码示例
```dockerfile
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# 安装依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 复制代码和模型
COPY src/ ./src/
COPY models/best.pt ./models/
# 启动服务
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
```
```python
# src/web/app.py
from fastapi import FastAPI, UploadFile, File
from ultralytics import YOLO
import tempfile
app = FastAPI()
model = YOLO("models/best.pt")
@app.post("/api/v1/infer")
async def infer(file: UploadFile = File(...)):
# 保存上传文件
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
# 执行推理
results = model.predict(tmp_path, conf=0.5)
# 返回结果
return {
"fields": extract_fields(results),
"confidence": get_confidence(results)
}
@app.get("/health")
async def health():
return {"status": "healthy"}
```
### 部署命令
```bash
# 1. 创建 Container Registry
az acr create --name invoiceacr --resource-group myRG --sku Basic
# 2. 构建并推送镜像
az acr build --registry invoiceacr --image invoice-inference:v1 .
# 3. 创建 Container Apps 环境
az containerapp env create \
--name invoice-env \
--resource-group myRG \
--location eastus
# 4. 部署应用
az containerapp create \
--name invoice-inference \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference:v1 \
--registry-server invoiceacr.azurecr.io \
--cpu 2 --memory 4Gi \
--min-replicas 1 --max-replicas 10 \
--ingress external --target-port 8000
# 5. 获取 URL
az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn
```
### 高吞吐场景: Serverless GPU
如果需要 GPU 加速推理(高并发、低延迟):
```bash
# 请求 GPU 配额
az containerapp env workload-profile add \
--name invoice-env \
--resource-group myRG \
--workload-profile-name gpu \
--workload-profile-type Consumption-GPU-T4
# 部署 GPU 版本
az containerapp create \
--name invoice-inference-gpu \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \
--workload-profile-name gpu \
--cpu 4 --memory 8Gi \
--min-replicas 0 --max-replicas 5 \
--ingress external --target-port 8000
```
### 推理性能对比
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|------|------------|---------|---------|
| CPU 2核 4GB | ~300-500ms | ~50 QPS | ~$30 |
| CPU 4核 8GB | ~200-300ms | ~100 QPS | ~$60 |
| GPU T4 | ~50-100ms | ~200 QPS | 按秒计费 |
| GPU A100 | ~20-50ms | ~500 QPS | 按秒计费 |
---
## 价格对比
### 月度成本对比(假设每天训练 2 小时)
| 方案 | 计算方式 | 月费 |
|------|---------|------|
| VM 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
| VM 按需启停 | 2h × 30天 × $3.06 | ~$184 |
| VM Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
| Serverless GPU | 2h × 30天 × ~$3.50 | ~$210 |
| Azure ML (min=0) | 2h × 30天 × $3.06 | ~$184 |
### 本项目完整成本估算
| 组件 | 推荐方案 | 月费 |
|------|---------|------|
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
| 数据库 | PostgreSQL Flexible (Burstable B1ms) | ~$25 |
| 推理服务 | Container Apps CPU (2核4GB) | ~$30 |
| 训练服务 | Azure ML Spot (按需) | ~$1-5/次 |
| Container Registry | Basic | ~$5 |
| **总计** | | **~$65/月** + 训练费 |
---
## 推荐架构
### 整体架构图
```
┌─────────────────────────────────────┐
│ Azure Blob Storage │
│ ├── training-images/ │
│ ├── datasets/ │
│ └── models/ │
└─────────────────┬───────────────────┘
┌─────────────────────────────────┼─────────────────────────────────┐
│ │ │
▼ ▼ ▼
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
│ 推理服务 (24/7) │ │ 训练服务 (按需) │ │ Web UI (可选) │
│ Container Apps │ │ Azure ML Compute │ │ Static Web Apps │
│ CPU 2核 4GB │ │ min=0, Spot │ │ ~$0 (免费层) │
│ ~$30/月 │ │ ~$1-5/次训练 │ │ │
│ │ │ │ │ │
│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │
│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │
│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │
│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │
└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘
│ │ │
└───────────────────────────────┼───────────────────────────────┘
┌───────────────────────┐
│ PostgreSQL │
│ Flexible Server │
│ Burstable B1ms │
│ ~$25/月 │
└───────────────────────┘
```
### 推理服务配置
```yaml
# Container Apps - CPU (24/7 运行)
name: invoice-inference
resources:
cpu: 2
memory: 4Gi
scale:
minReplicas: 1
maxReplicas: 10
env:
- name: MODEL_PATH
value: /app/models/best.pt
- name: DB_HOST
secretRef: db-host
- name: DB_PASSWORD
secretRef: db-password
```
### 训练服务配置
**方案 A: Azure ML Compute(推荐)**
```python
from azure.ai.ml.entities import AmlCompute
gpu_cluster = AmlCompute(
name="gpu-cluster",
size="Standard_NC6s_v3",
min_instances=0, # 空闲时关机
max_instances=1,
tier="LowPriority", # Spot 实例
idle_time_before_scale_down=120
)
```
**方案 B: Container Apps Serverless GPU**
```yaml
name: invoice-training
resources:
gpu: 1
gpuType: A100
scale:
minReplicas: 0
maxReplicas: 1
```
---
## 实施步骤
### 阶段 1: 存储设置
```bash
# 创建 Storage Account
az storage account create \
--name invoicestorage \
--resource-group myRG \
--sku Standard_LRS
# 创建容器
az storage container create --name training-images --account-name invoicestorage
az storage container create --name datasets --account-name invoicestorage
az storage container create --name models --account-name invoicestorage
# 上传训练数据
az storage blob upload-batch \
--destination training-images \
--source ./data/dataset/temp \
--account-name invoicestorage
```
### 阶段 2: 数据库设置
```bash
# 创建 PostgreSQL
az postgres flexible-server create \
--name invoice-db \
--resource-group myRG \
--sku-name Standard_B1ms \
--storage-size 32 \
--admin-user docmaster \
--admin-password YOUR_PASSWORD
# 配置防火墙
az postgres flexible-server firewall-rule create \
--name allow-azure \
--resource-group myRG \
--server-name invoice-db \
--start-ip-address 0.0.0.0 \
--end-ip-address 0.0.0.0
```
### 阶段 3: 推理服务部署
```bash
# 创建 Container Registry
az acr create --name invoiceacr --resource-group myRG --sku Basic
# 构建镜像
az acr build --registry invoiceacr --image invoice-inference:v1 .
# 创建环境
az containerapp env create \
--name invoice-env \
--resource-group myRG \
--location eastus
# 部署推理服务
az containerapp create \
--name invoice-inference \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference:v1 \
--registry-server invoiceacr.azurecr.io \
--cpu 2 --memory 4Gi \
--min-replicas 1 --max-replicas 10 \
--ingress external --target-port 8000 \
--env-vars \
DB_HOST=invoice-db.postgres.database.azure.com \
DB_NAME=docmaster \
DB_USER=docmaster \
--secrets db-password=YOUR_PASSWORD
```
### 阶段 4: 训练服务设置
```bash
# 创建 Azure ML Workspace
az ml workspace create --name invoice-ml --resource-group myRG
# 创建 Compute Cluster
az ml compute create --name gpu-cluster \
--type AmlCompute \
--size Standard_NC6s_v3 \
--min-instances 0 \
--max-instances 1 \
--tier low_priority
```
### 阶段 5: 集成训练触发 API
```python
# src/web/routes/training.py
from fastapi import APIRouter
from azure.ai.ml import MLClient, command
from azure.identity import DefaultAzureCredential
from pydantic import BaseModel

# 最小请求模型(字段为假设,按实际接口调整)
class TrainingRequest(BaseModel):
    epochs: int = 100

router = APIRouter()
ml_client = MLClient(
credential=DefaultAzureCredential(),
subscription_id="your-subscription-id",
resource_group_name="myRG",
workspace_name="invoice-ml"
)
@router.post("/api/v1/train")
async def trigger_training(request: TrainingRequest):
"""触发 Azure ML 训练任务"""
training_job = command(
code="./training",
command=f"python train.py --epochs {request.epochs}",
environment="AzureML-pytorch-2.0-cuda11.8@latest",
compute="gpu-cluster",
)
job = ml_client.jobs.create_or_update(training_job)
return {
"job_id": job.name,
"status": job.status,
"studio_url": job.studio_url
}
@router.get("/api/v1/train/{job_id}/status")
async def get_training_status(job_id: str):
"""查询训练状态"""
job = ml_client.jobs.get(job_id)
return {"status": job.status}
```
---
## 总结
### 推荐配置
| 组件 | 推荐方案 | 月费估算 |
|------|---------|---------|
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
| 数据库 | PostgreSQL Flexible | ~$25 |
| 推理服务 | Container Apps CPU | ~$30 |
| 训练服务 | Azure ML (min=0, Spot) | 按需 ~$1-5/次 |
| Container Registry | Basic | ~$5 |
| **总计** | | **~$65/月** + 训练费 |
### 关键决策
| 场景 | 选择 |
|------|------|
| 偶尔训练,简单需求 | Azure VM Spot + 手动启停 |
| 需要 MLOps团队协作 | Azure ML Compute |
| 追求最低空闲成本 | Container Apps Serverless GPU |
| 生产环境推理 | Container Apps CPU |
| 高并发推理 | Container Apps Serverless GPU |
### 注意事项
1. **冷启动**: Serverless GPU 启动需要 3-8 分钟
2. **Spot 中断**: 可能被抢占,需要检查点机制
3. **网络延迟**: Blob Storage 挂载比本地 SSD 慢,建议开启缓存
4. **区域选择**: 选择有 GPU 配额的区域 (East US, West Europe 等)
5. **推理优化**: CPU 推理对于 YOLO 已经足够,无需 GPU
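针对第 2 条的 Spot 中断,Ultralytics 支持断点续训:把训练输出目录放到持久化存储(如挂载的 Blob 路径)上,实例重启后从 last.pt 恢复即可(示意代码,路径为假设):
```python
from ultralytics import YOLO

# 首次训练:检查点写入 project/name 目录下的 weights/last.pt
model = YOLO("yolo11n.pt")
model.train(data="dataset.yaml", epochs=100,
            project="/mnt/azure-blob/runs", name="invoice_fields")

# 被抢占后重启:加载 last.pt 并继续训练
model = YOLO("/mnt/azure-blob/runs/invoice_fields/weights/last.pt")
model.train(resume=True)
```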

View File

@@ -4,11 +4,13 @@ import type {
DocumentDetailResponse,
DocumentItem,
UploadDocumentResponse,
DocumentCategoriesResponse,
} from '../types'
export const documentsApi = {
list: async (params?: {
status?: string
category?: string
limit?: number
offset?: number
}): Promise<DocumentListResponse> => {
@@ -16,18 +18,29 @@ export const documentsApi = {
return data
},
getCategories: async (): Promise<DocumentCategoriesResponse> => {
const { data } = await apiClient.get('/api/v1/admin/documents/categories')
return data
},
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
return data
},
upload: async (
file: File,
options?: { groupKey?: string; category?: string }
): Promise<UploadDocumentResponse> => {
const formData = new FormData()
formData.append('file', file)
const params: Record<string, string> = {}
if (options?.groupKey) {
params.group_key = options.groupKey
}
if (options?.category) {
params.category = options.category
}
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
@@ -95,4 +108,15 @@ export const documentsApi = {
)
return data
},
updateCategory: async (
documentId: string,
category: string
): Promise<{ status: string; document_id: string; category: string; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/category`,
{ category }
)
return data
},
}

View File

@@ -9,6 +9,7 @@ export interface DocumentItem {
auto_label_error: string | null
upload_source: string
group_key: string | null
category: string
created_at: string
updated_at: string
annotation_count?: number
@@ -61,6 +62,7 @@ export interface DocumentDetailResponse {
upload_source: string
batch_id: string | null
group_key: string | null
category: string
csv_field_values: Record<string, string> | null
can_annotate: boolean
annotation_lock_until: string | null
@@ -101,8 +103,21 @@ export interface TrainingTask {
updated_at: string
}
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface TrainingModelsResponse {
models: ModelVersionItem[]
total: number
limit: number
offset: number
@@ -118,11 +133,17 @@ export interface UploadDocumentResponse {
file_size: number
page_count: number
status: string
category: string
group_key: string | null
auto_label_started: boolean
message: string
}
export interface DocumentCategoriesResponse {
categories: string[]
total: number
}
export interface CreateAnnotationRequest {
page_number: number
class_id: number
@@ -228,6 +249,8 @@ export interface DatasetDetailResponse {
name: string
description: string | null
status: string
training_status: string | null
active_training_task_id: string | null
train_ratio: number
val_ratio: number
seed: number

View File

@@ -3,7 +3,7 @@ import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
import { Badge } from './Badge'
import { Button } from './Button'
import { UploadModal } from './UploadModal'
import { useDocuments, useCategories } from '../hooks/useDocuments'
import type { DocumentItem } from '../api/types'
interface DashboardProps {
@@ -34,11 +34,15 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
const [isUploadOpen, setIsUploadOpen] = useState(false)
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
const [statusFilter, setStatusFilter] = useState<string>('')
const [categoryFilter, setCategoryFilter] = useState<string>('')
const [limit] = useState(20)
const [offset] = useState(0)
const { categories } = useCategories()
const { documents, total, isLoading, error, refetch } = useDocuments({
status: statusFilter || undefined,
category: categoryFilter || undefined,
limit,
offset,
})
@@ -102,6 +106,24 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
</div>
<div className="flex gap-3">
<div className="relative">
<select
value={categoryFilter}
onChange={(e) => setCategoryFilter(e.target.value)}
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
>
<option value="">All Categories</option>
{categories.map((cat) => (
<option key={cat} value={cat}>
{cat.charAt(0).toUpperCase() + cat.slice(1)}
</option>
))}
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
<div className="relative">
<select
value={statusFilter}
@@ -144,6 +166,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Annotations
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Category
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Group
</th>
@@ -156,13 +181,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<tbody>
{isLoading ? (
<tr>
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
Loading documents...
</td>
</tr>
) : documents.length === 0 ? (
<tr>
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
No documents found. Upload your first document to get started.
</td>
</tr>
@@ -216,6 +241,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
<td className="py-4 px-4 text-sm text-warm-text-secondary">
{doc.annotation_count || 0} annotations
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary capitalize">
{doc.category || 'invoice'}
</td>
<td className="py-4 px-4 text-sm text-warm-text-muted">
{doc.group_key || '-'}
</td>

View File

@@ -1,5 +1,5 @@
import React from 'react'
import { ArrowLeft, Loader2, Play, AlertCircle, Check, Award } from 'lucide-react'
import { Button } from './Button'
import { useDatasetDetail } from '../hooks/useDatasets'
@@ -14,6 +14,23 @@ const SPLIT_STYLES: Record<string, string> = {
test: 'bg-warm-state-success/10 text-warm-state-success',
}
const STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
building: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Building' },
ready: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Ready' },
trained: { bg: 'bg-purple-100', text: 'text-purple-700', label: 'Trained' },
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
archived: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Archived' },
}
const TRAINING_STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
pending: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Pending' },
scheduled: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Scheduled' },
running: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Training' },
completed: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Completed' },
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
cancelled: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Cancelled' },
}
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
  const { dataset, isLoading, error } = useDatasetDetail(datasetId)
@@ -36,11 +53,25 @@ export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack
  )
}
-const statusIcon = dataset.status === 'ready'
+const statusConfig = STATUS_STYLES[dataset.status] || STATUS_STYLES.ready
const trainingStatusConfig = dataset.training_status
? TRAINING_STATUS_STYLES[dataset.training_status]
: null
// Determine if training button should be shown and enabled
const isTrainingInProgress = dataset.training_status === 'running' || dataset.training_status === 'pending'
const canStartTraining = dataset.status === 'ready' && !isTrainingInProgress
// Determine status icon
const statusIcon = dataset.status === 'trained'
? <Award size={14} className="text-purple-700" />
: dataset.status === 'ready'
? <Check size={14} className="text-warm-state-success" />
: dataset.status === 'failed'
? <AlertCircle size={14} className="text-warm-state-error" />
-: <Loader2 size={14} className="animate-spin text-warm-state-info" />
+: dataset.status === 'building'
? <Loader2 size={14} className="animate-spin text-warm-state-info" />
: null
return (
  <div className="p-8 max-w-7xl mx-auto">
@@ -51,15 +82,38 @@ export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack
<div className="flex items-center justify-between mb-6">
  <div>
<div className="flex items-center gap-3 mb-1">
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
  {dataset.name} {statusIcon}
</h2>
{/* Status Badge */}
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${statusConfig.bg} ${statusConfig.text}`}>
{statusConfig.label}
</span>
{/* Training Status Badge */}
{trainingStatusConfig && (
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${trainingStatusConfig.bg} ${trainingStatusConfig.text}`}>
{isTrainingInProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
{trainingStatusConfig.label}
</span>
)}
</div>
{dataset.description && (
  <p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
)}
</div>
-{dataset.status === 'ready' && (
-  <Button><Play size={14} className="mr-1" />Start Training</Button>
+{/* Training Button */}
+{(dataset.status === 'ready' || dataset.status === 'trained') && (
<Button
disabled={isTrainingInProgress}
className={isTrainingInProgress ? 'opacity-50 cursor-not-allowed' : ''}
>
{isTrainingInProgress ? (
<><Loader2 size={14} className="mr-1 animate-spin" />Training...</>
) : (
<><Play size={14} className="mr-1" />Start Training</>
)}
</Button>
)}
</div>

View File

@@ -72,12 +72,13 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
-// Fetch available trained models
+// Fetch available trained models (active or inactive, not archived)
const { data: modelsData } = useQuery({
-  queryKey: ['training', 'models', 'completed'],
-  queryFn: () => trainingApi.getModels({ status: 'completed' }),
+  queryKey: ['training', 'models', 'available'],
+  queryFn: () => trainingApi.getModels(),
})
-const completedModels = modelsData?.models ?? []
+// Filter out archived models - only show active/inactive models for base model selection
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
const handleSubmit = () => {
  onSubmit({
@@ -128,9 +129,9 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
  <option value="pretrained">yolo11n.pt (Pretrained)</option>
-  {completedModels.map(m => (
-    <option key={m.task_id} value={m.task_id}>
-      {m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
+  {availableModels.map(m => (
+    <option key={m.version_id} value={m.version_id}>
+      {m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
    </option>
  ))}
</select>
@@ -293,8 +294,12 @@ const DatasetList: React.FC<{
</button>
)}
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
-  disabled={isDeleting}
-  className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
+  disabled={isDeleting || ds.status === 'pending' || ds.status === 'building'}
+  className={`p-1.5 rounded transition-colors ${
ds.status === 'pending' || ds.status === 'building'
? 'text-warm-text-muted/40 cursor-not-allowed'
: 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
}`}>
<Trash2 size={14} />
</button>
</div>

View File

@@ -1,7 +1,7 @@
import React, { useState, useRef } from 'react'
-import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react'
+import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react'
import { Button } from './Button'
-import { useDocuments } from '../hooks/useDocuments'
+import { useDocuments, useCategories } from '../hooks/useDocuments'
interface UploadModalProps {
  isOpen: boolean
@@ -12,11 +12,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
const [groupKey, setGroupKey] = useState('')
const [category, setCategory] = useState('invoice')
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef<HTMLInputElement>(null)
const { uploadDocument, isUploading } = useDocuments({})
const { categories } = useCategories()
if (!isOpen) return null
@@ -63,7 +65,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
for (const file of selectedFiles) {
  await new Promise<void>((resolve, reject) => {
    uploadDocument(
-      { file, groupKey: groupKey || undefined },
+      { file, groupKey: groupKey || undefined, category: category || 'invoice' },
      {
        onSuccess: () => resolve(),
        onError: (error: Error) => reject(error),
@@ -77,6 +79,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
onClose()
setSelectedFiles([])
setGroupKey('')
setCategory('invoice')
setUploadStatus('idle')
}, 1500)
} catch (error) {
@@ -91,6 +94,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
}
setSelectedFiles([])
setGroupKey('')
setCategory('invoice')
setUploadStatus('idle')
setErrorMessage('')
onClose()
@@ -179,6 +183,42 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
</div>
)}
{/* Category Select */}
{selectedFiles.length > 0 && (
<div className="mb-6">
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
Category
</label>
<div className="relative">
<select
value={category}
onChange={(e) => setCategory(e.target.value)}
className="w-full h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none cursor-pointer"
disabled={uploadStatus === 'uploading'}
>
<option value="invoice">Invoice</option>
<option value="letter">Letter</option>
<option value="receipt">Receipt</option>
<option value="contract">Contract</option>
{categories
.filter((cat) => !['invoice', 'letter', 'receipt', 'contract'].includes(cat))
.map((cat) => (
<option key={cat} value={cat}>
{cat.charAt(0).toUpperCase() + cat.slice(1)}
</option>
))}
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
<p className="text-xs text-warm-text-muted mt-1">
Select document type for training different models
</p>
</div>
)}
{/* Group Key Input */}
{selectedFiles.length > 0 && (
  <div className="mb-6">

View File

@@ -1,4 +1,4 @@
-export { useDocuments } from './useDocuments'
+export { useDocuments, useCategories } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'

View File

@@ -1,9 +1,10 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
-import type { DocumentListResponse, UploadDocumentResponse } from '../api/types'
+import type { DocumentListResponse, DocumentCategoriesResponse } from '../api/types'
interface UseDocumentsParams {
  status?: string
category?: string
  limit?: number
  offset?: number
}
@@ -18,10 +19,11 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
})
const uploadMutation = useMutation({
-  mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
-    documentsApi.upload(file, groupKey),
+  mutationFn: ({ file, groupKey, category }: { file: File; groupKey?: string; category?: string }) =>
+    documentsApi.upload(file, { groupKey, category }),
  onSuccess: () => {
    queryClient.invalidateQueries({ queryKey: ['documents'] })
queryClient.invalidateQueries({ queryKey: ['categories'] })
  },
})
@@ -63,6 +65,15 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
  },
})
const updateCategoryMutation = useMutation({
mutationFn: ({ documentId, category }: { documentId: string; category: string }) =>
documentsApi.updateCategory(documentId, category),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
queryClient.invalidateQueries({ queryKey: ['categories'] })
},
})
return {
  documents: data?.documents || [],
  total: data?.total || 0,
@@ -86,5 +97,24 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
updateGroupKey: updateGroupKeyMutation.mutate,
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
updateCategory: updateCategoryMutation.mutate,
updateCategoryAsync: updateCategoryMutation.mutateAsync,
isUpdatingCategory: updateCategoryMutation.isPending,
}
}
export const useCategories = () => {
const { data, isLoading, error, refetch } = useQuery<DocumentCategoriesResponse>({
queryKey: ['categories'],
queryFn: () => documentsApi.getCategories(),
staleTime: 60000,
})
return {
categories: data?.categories || [],
total: data?.total || 0,
isLoading,
error,
refetch,
  }
}

View File

@@ -0,0 +1,13 @@
-- Add category column to admin_documents table
-- Allows categorizing documents for training different models (e.g., invoice, letter, receipt)
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
-- Update existing NULL values to default
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
-- Make it NOT NULL after setting defaults
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
-- Create index for category filtering
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);

View File

@@ -0,0 +1,28 @@
-- Migration: Add training_status and active_training_task_id to training_datasets
-- Description: Track training status separately from dataset build status
-- Add training_status column
ALTER TABLE training_datasets
ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
-- Add active_training_task_id column
ALTER TABLE training_datasets
ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
-- Create index for training_status
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status
ON training_datasets(training_status);
-- Create index for active_training_task_id
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id
ON training_datasets(active_training_task_id);
-- Update existing datasets that have been used in completed training tasks to 'trained' status
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);

View File

@@ -120,7 +120,7 @@ def main() -> None:
logger.info("=" * 60)
# Create config
-from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
+from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig
config = AppConfig(
    model=ModelConfig(
@@ -136,7 +136,7 @@ def main() -> None:
        reload=args.reload,
        workers=args.workers,
    ),
-    storage=StorageConfig(),
+    file=FileConfig(),
)
# Create and run app

View File

@@ -112,6 +112,7 @@ class AdminDB:
        upload_source: str = "ui",
        csv_field_values: dict[str, Any] | None = None,
        group_key: str | None = None,
category: str = "invoice",
        admin_token: str | None = None,  # Deprecated, kept for compatibility
    ) -> str:
        """Create a new document record."""
@@ -125,6 +126,7 @@ class AdminDB:
                upload_source=upload_source,
                csv_field_values=csv_field_values,
                group_key=group_key,
category=category,
            )
            session.add(document)
            session.flush()
@@ -154,6 +156,7 @@ class AdminDB:
        has_annotations: bool | None = None,
        auto_label_status: str | None = None,
        batch_id: str | None = None,
category: str | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[AdminDocument], int]:
@@ -171,6 +174,8 @@ class AdminDB:
            where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
        if batch_id:
            where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
if category:
where_clauses.append(AdminDocument.category == category)
        # Count query
        count_stmt = select(func.count()).select_from(AdminDocument)
@@ -283,6 +288,32 @@ class AdminDB:
                return True
            return False
def get_document_categories(self) -> list[str]:
"""Get list of unique document categories."""
with get_session_context() as session:
statement = (
select(AdminDocument.category)
.distinct()
.order_by(AdminDocument.category)
)
categories = session.exec(statement).all()
return [c for c in categories if c is not None]
def update_document_category(
self, document_id: str, category: str
) -> AdminDocument | None:
"""Update document category."""
with get_session_context() as session:
document = session.get(AdminDocument, UUID(document_id))
if document:
document.category = category
document.updated_at = datetime.utcnow()
session.add(document)
session.commit()
session.refresh(document)
return document
return None
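A minimal usage sketch of the new category helpers above (illustrative only, not part of this commit; it assumes AdminDB can be constructed directly and that a valid document UUID is at hand):

    from inference.data.admin_db import AdminDB

    db = AdminDB()
    # Distinct categories currently present, e.g. ["invoice", "letter"]
    categories = db.get_document_categories()
    # Re-classify one document; returns the refreshed AdminDocument or None
    doc = db.update_document_category("00000000-0000-0000-0000-000000000000", "letter")
    # The category filter added to list_documents in this commit
    letters, total = db.list_documents(category="letter", limit=20, offset=0)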
    # ==========================================================================
    # Annotation Operations
    # ==========================================================================
@@ -1292,6 +1323,36 @@ class AdminDB:
            session.add(dataset)
            session.commit()
def update_dataset_training_status(
self,
dataset_id: str | UUID,
training_status: str | None,
active_training_task_id: str | UUID | None = None,
update_main_status: bool = False,
) -> None:
"""Update dataset training status and optionally the main status.
Args:
dataset_id: Dataset UUID
training_status: Training status (pending, running, completed, failed, cancelled)
active_training_task_id: Currently active training task ID
update_main_status: If True and training_status is 'completed', set main status to 'trained'
"""
with get_session_context() as session:
dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
if not dataset:
return
dataset.training_status = training_status
dataset.active_training_task_id = (
UUID(str(active_training_task_id)) if active_training_task_id else None
)
dataset.updated_at = datetime.utcnow()
# Update main status to 'trained' when training completes
if update_main_status and training_status == "completed":
dataset.status = "trained"
session.add(dataset)
session.commit()
    def add_dataset_documents(
        self,
        dataset_id: str | UUID,
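A sketch of how the training-status transitions supported by update_dataset_training_status might be driven from a training workflow (illustrative only; the caller, dataset_id and task_id variables are assumptions, not shown in this commit):

    db = AdminDB()
    # Task scheduled for the dataset
    db.update_dataset_training_status(dataset_id, "pending", active_training_task_id=task_id)
    # Task picked up by a worker
    db.update_dataset_training_status(dataset_id, "running", active_training_task_id=task_id)
    # Success: clear the active task and promote the dataset's main status to 'trained'
    db.update_dataset_training_status(dataset_id, "completed", update_main_status=True)
    # Failure
    db.update_dataset_training_status(dataset_id, "failed")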

View File

@@ -11,23 +11,8 @@ from uuid import UUID, uuid4
from sqlmodel import Field, SQLModel, Column, JSON

-# =============================================================================
-# CSV to Field Class Mapping
-# =============================================================================
-CSV_TO_CLASS_MAPPING: dict[str, int] = {
-    "InvoiceNumber": 0,  # invoice_number
-    "InvoiceDate": 1,  # invoice_date
-    "InvoiceDueDate": 2,  # invoice_due_date
-    "OCR": 3,  # ocr_number
-    "Bankgiro": 4,  # bankgiro
-    "Plusgiro": 5,  # plusgiro
-    "Amount": 6,  # amount
-    "supplier_organisation_number": 7,  # supplier_organisation_number
-    # 8: payment_line (derived from OCR/Bankgiro/Amount)
-    "customer_number": 9,  # customer_number
-}
+# Import field mappings from single source of truth
+from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS

# =============================================================================
@@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True):
    # Link to batch upload (if uploaded via ZIP)
    group_key: str | None = Field(default=None, max_length=255, index=True)
    # User-defined grouping key for document organization
category: str = Field(default="invoice", max_length=100, index=True)
# Document category for training different models (e.g., invoice, letter, receipt)
    csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
    # Original CSV values for reference
    auto_label_queued_at: datetime | None = Field(default=None)
@@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True):
    name: str = Field(max_length=255)
    description: str | None = Field(default=None)
    status: str = Field(default="building", max_length=20, index=True)
-    # Status: building, ready, training, archived, failed
+    # Status: building, ready, trained, archived, failed
training_status: str | None = Field(default=None, max_length=20, index=True)
# Training status: pending, scheduled, running, completed, failed, cancelled
active_training_task_id: UUID | None = Field(default=None, index=True)
    train_ratio: float = Field(default=0.8)
    val_ratio: float = Field(default=0.1)
    seed: int = Field(default=42)
@@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True):
    created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
-# Field class mapping (same as src/cli/train.py)
-FIELD_CLASSES = {
-    0: "invoice_number",
-    1: "invoice_date",
-    2: "invoice_due_date",
-    3: "ocr_number",
-    4: "bankgiro",
-    5: "plusgiro",
-    6: "amount",
-    7: "supplier_organisation_number",
-    8: "payment_line",
-    9: "customer_number",
-}
-FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
+# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
+# This ensures consistency with the trained YOLO model
# Read-only models for API responses
@@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel):
    status: str
    auto_label_status: str | None
    auto_label_error: str | None
category: str = "invoice"
    created_at: datetime
    updated_at: datetime
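Several files in this commit now import their field mappings from shared.fields, which is not included in this excerpt. A hypothetical sketch of what that module consolidates, based on the literals removed here and in yolo_detector.py (the two removed copies disagreed on the supplier class name and on the order of payment_line/customer_number, so the real module necessarily resolves that one way; this sketch follows the admin_models numbering):

    # shared/fields.py -- illustrative sketch, not the actual file
    FIELD_CLASSES: dict[int, str] = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_organisation_number",
        8: "payment_line",
        9: "customer_number",
    }
    FIELD_CLASS_IDS: dict[str, int] = {v: k for k, v in FIELD_CLASSES.items()}
    CLASS_NAMES: list[str] = [FIELD_CLASSES[i] for i in sorted(FIELD_CLASSES)]

    # CSV column -> class id (values as in the mapping removed above)
    CSV_TO_CLASS_MAPPING: dict[str, int] = {
        "InvoiceNumber": 0, "InvoiceDate": 1, "InvoiceDueDate": 2, "OCR": 3,
        "Bankgiro": 4, "Plusgiro": 5, "Amount": 6,
        "supplier_organisation_number": 7, "customer_number": 9,
    }

    # Detection class name -> API field name (values as in the mapping removed
    # from yolo_detector.py, keyed by the class names used here)
    CLASS_TO_FIELD: dict[str, str] = {
        "invoice_number": "InvoiceNumber", "invoice_date": "InvoiceDate",
        "invoice_due_date": "InvoiceDueDate", "ocr_number": "OCR",
        "bankgiro": "Bankgiro", "plusgiro": "Plusgiro", "amount": "Amount",
        "supplier_organisation_number": "supplier_org_number",
        "customer_number": "customer_number", "payment_line": "payment_line",
    }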

View File

@@ -141,6 +141,40 @@ def run_migrations() -> None:
            CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
            """,
        ),
# Migration 009: Add category to admin_documents
(
"admin_documents_category",
"""
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
""",
),
# Migration 010: Add training_status and active_training_task_id to training_datasets
(
"training_datasets_training_status",
"""
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
""",
),
# Migration 010b: Update existing datasets with completed training to 'trained' status
(
"training_datasets_update_trained_status",
"""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);
""",
),
    ]

    with engine.connect() as conn:
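The new entries are (name, SQL) pairs appended to the migration list inside run_migrations(); the apply loop itself lies outside this hunk. A hedged sketch of an idempotent loop consistent with the guarded statements above (the real implementation may differ, e.g. in how it records applied migrations):

    from sqlalchemy import text

    with engine.connect() as conn:
        for name, sql in migrations:
            logger.info("Applying migration: %s", name)
            # Naive split on ';' is enough for these statements; all of them are
            # written to be safe to re-run (IF NOT EXISTS / guarded UPDATEs).
            for statement in (s.strip() for s in sql.split(";")):
                if statement:
                    conn.execute(text(statement))
            conn.commit()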

View File

@@ -21,7 +21,8 @@ import re
import numpy as np
from PIL import Image

-from .yolo_detector import Detection, CLASS_TO_FIELD
+from shared.fields import CLASS_TO_FIELD
+from .yolo_detector import Detection

# Import shared utilities for text cleaning and validation
from shared.utils.text_cleaner import TextCleaner

View File

@@ -10,7 +10,8 @@ from typing import Any
import time
import re

-from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
+from shared.fields import CLASS_TO_FIELD
+from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser

View File

@@ -9,6 +9,9 @@ from pathlib import Path
from typing import Any

import numpy as np
# Import field mappings from single source of truth
from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
@dataclass
class Detection:
@@ -72,33 +75,8 @@ class Detection:
        return (x0, y0, x1, y1)
-# Class names (must match training configuration)
-CLASS_NAMES = [
-    'invoice_number',
-    'invoice_date',
-    'invoice_due_date',
-    'ocr_number',
-    'bankgiro',
-    'plusgiro',
-    'amount',
-    'supplier_org_number',  # Matches training class name
-    'customer_number',
-    'payment_line',  # Machine code payment line at bottom of invoice
-]
-# Mapping from class name to field name
-CLASS_TO_FIELD = {
-    'invoice_number': 'InvoiceNumber',
-    'invoice_date': 'InvoiceDate',
-    'invoice_due_date': 'InvoiceDueDate',
-    'ocr_number': 'OCR',
-    'bankgiro': 'Bankgiro',
-    'plusgiro': 'Plusgiro',
-    'amount': 'Amount',
-    'supplier_org_number': 'supplier_org_number',
-    'customer_number': 'customer_number',
-    'payment_line': 'payment_line',
-}
+# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
+# This ensures consistency with the trained YOLO model
class YOLODetector:

View File

@@ -4,18 +4,19 @@ Admin Annotation API Routes
FastAPI endpoints for annotation management.
"""

+import io
import logging
-from pathlib import Path
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, HTTPException, Query
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, StreamingResponse

from inference.data.admin_db import AdminDB
-from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
+from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.autolabel import get_auto_label_service
+from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
    AnnotationCreate,
    AnnotationItem,
@@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)

-# Image storage directory
-ADMIN_IMAGES_DIR = Path("data/admin_images")

def _validate_uuid(value: str, name: str = "ID") -> None:
    """Validate UUID format."""
@@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter:
    @router.get(
        "/{document_id}/images/{page_number}",
+        response_model=None,
        responses={
+            200: {"content": {"image/png": {}}, "description": "Page image"},
            401: {"model": ErrorResponse, "description": "Invalid token"},
            404: {"model": ErrorResponse, "description": "Not found"},
        },
@@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter:
        page_number: int,
        admin_token: AdminTokenDep,
        db: AdminDBDep,
-    ) -> FileResponse:
+    ) -> FileResponse | StreamingResponse:
        """Get page image."""
        _validate_uuid(document_id, "document_id")
@@ -91,20 +91,35 @@ def create_annotation_router() -> APIRouter:
                detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
            )

-        # Find image file
-        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
-        if not image_path.exists():
+        # Get storage helper
+        storage = get_storage_helper()
+        # Check if image exists
+        if not storage.admin_image_exists(document_id, page_number):
            raise HTTPException(
                status_code=404,
                detail=f"Image for page {page_number} not found",
            )

+        # Try to get local path for efficient file serving
+        local_path = storage.get_admin_image_local_path(document_id, page_number)
+        if local_path is not None:
            return FileResponse(
-                path=str(image_path),
+                path=str(local_path),
                media_type="image/png",
                filename=f"{document.filename}_page_{page_number}.png",
            )
# Fall back to streaming for cloud storage
image_content = storage.get_admin_image(document_id, page_number)
return StreamingResponse(
io.BytesIO(image_content),
media_type="image/png",
headers={
"Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
},
)
    # =========================================================================
    # Annotation Endpoints
    # =========================================================================
@@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter:
            )
        # Get image dimensions for normalization
-        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
-        if not image_path.exists():
+        storage = get_storage_helper()
+        dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
+        if dimensions is None:
            raise HTTPException(
                status_code=400,
                detail=f"Image for page {request.page_number} not available",
            )
-        from PIL import Image
-        with Image.open(image_path) as img:
-            image_width, image_height = img.size
+        image_width, image_height = dimensions
        # Calculate normalized coordinates
        x_center = (request.bbox.x + request.bbox.width / 2) / image_width
@@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter:
        if request.bbox is not None:
            # Get image dimensions
-            image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
-            from PIL import Image
-            with Image.open(image_path) as img:
-                image_width, image_height = img.size
+            storage = get_storage_helper()
+            dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
+            if dimensions is None:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Image for page {annotation.page_number} not available",
+                )
+            image_width, image_height = dimensions
            # Calculate normalized coordinates
            update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
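The routes above (and the document, dataset, export and inference routes later in this commit) lean on a small set of StorageHelper methods. The helper itself is not part of this excerpt; the interface sketch below is inferred from the call sites only and is not the actual class definition:

    from pathlib import Path

    class StorageHelper:  # hypothetical signatures
        def save_admin_image(self, document_id: str, page_number: int, content: bytes) -> None: ...
        def admin_image_exists(self, document_id: str, page_number: int) -> bool: ...
        def get_admin_image(self, document_id: str, page_number: int) -> bytes: ...
        def get_admin_image_dimensions(self, document_id: str, page_number: int) -> tuple[int, int] | None: ...
        def get_admin_image_local_path(self, document_id: str, page_number: int) -> Path | None:
            """Filesystem path for local backends; None for cloud backends, in which
            case callers stream the bytes instead."""
        def delete_admin_images(self, document_id: str) -> None: ...
        def save_raw_pdf(self, content: bytes, filename: str) -> str: ...

The base-path helpers used elsewhere in the commit (get_datasets_base_path, get_exports_base_path, get_results_base_path, get_uploads_base_path) appear to follow the same convention: a local Path when available, otherwise None.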

View File

@@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from inference.web.config import DEFAULT_DPI, StorageConfig
from inference.web.core.auth import AdminTokenDep, AdminDBDep
from inference.web.services.storage_helpers import get_storage_helper
from inference.web.schemas.admin import (
    AnnotationItem,
    AnnotationSource,
    AutoLabelStatus,
    BoundingBox,
DocumentCategoriesResponse,
    DocumentDetailResponse,
    DocumentItem,
    DocumentListResponse,
    DocumentStatus,
    DocumentStatsResponse,
DocumentUpdateRequest,
    DocumentUploadResponse,
    ModelMetrics,
    TrainingHistoryItem,
@@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None:
def _convert_pdf_to_images(
-    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
+    document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
-    """Convert PDF pages to images for annotation."""
+    """Convert PDF pages to images for annotation using StorageHelper."""
    import fitz

-    doc_images_dir = images_dir / document_id
-    doc_images_dir.mkdir(parents=True, exist_ok=True)
+    storage = get_storage_helper()

    pdf_doc = fitz.open(stream=content, filetype="pdf")
    for page_num in range(page_count):
@@ -60,8 +61,9 @@ def _convert_pdf_to_images(
        mat = fitz.Matrix(dpi / 72, dpi / 72)
        pix = page.get_pixmap(matrix=mat)
-        image_path = doc_images_dir / f"page_{page_num + 1}.png"
-        pix.save(str(image_path))
+        # Save to storage using StorageHelper
+        image_bytes = pix.tobytes("png")
+        storage.save_admin_image(document_id, page_num + 1, image_bytes)

    pdf_doc.close()
@@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        str | None,
        Query(description="Optional group key for document organization", max_length=255),
    ] = None,
category: Annotated[
str,
Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
] = "invoice",
) -> DocumentUploadResponse:
    """Upload a document for labeling."""
    # Validate group_key length
@@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        file_path="",  # Will update after saving
        page_count=page_count,
        group_key=group_key,
category=category,
    )
-    # Save file to admin uploads
-    file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
+    # Save file to storage using StorageHelper
+    storage = get_storage_helper()
+    filename = f"{document_id}{file_ext}"
    try:
-        file_path.write_bytes(content)
+        storage_path = storage.save_raw_pdf(content, filename)
    except Exception as e:
        logger.error(f"Failed to save file: {e}")
        raise HTTPException(status_code=500, detail="Failed to save file")
-    # Update file path in database
+    # Update file path in database (using storage path for reference)
    from inference.data.database import get_session_context
    from inference.data.admin_models import AdminDocument

    with get_session_context() as session:
        doc = session.get(AdminDocument, UUID(document_id))
        if doc:
-            doc.file_path = str(file_path)
+            # Store the storage path (relative path within storage)
+            doc.file_path = storage_path
            session.add(doc)
    # Convert PDF to images for annotation
    if file_ext == ".pdf":
        try:
            _convert_pdf_to_images(
-                document_id, content, page_count,
-                storage_config.admin_images_dir, storage_config.dpi
+                document_id, content, page_count, storage_config.dpi
            )
        except Exception as e:
            logger.error(f"Failed to convert PDF to images: {e}")
@@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
        file_size=len(content),
        page_count=page_count,
        status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
category=category,
        group_key=group_key,
        auto_label_started=auto_label_started,
        message="Document uploaded successfully",
@@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            str | None,
            Query(description="Filter by batch ID"),
        ] = None,
category: Annotated[
str | None,
Query(description="Filter by document category"),
] = None,
        limit: Annotated[
            int,
            Query(ge=1, le=100, description="Page size"),
@@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            has_annotations=has_annotations,
            auto_label_status=auto_label_status,
            batch_id=batch_id,
category=category,
            limit=limit,
            offset=offset,
        )
@@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
                batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
                group_key=doc.group_key if hasattr(doc, 'group_key') else None,
category=doc.category if hasattr(doc, 'category') else "invoice",
                can_annotate=can_annotate,
                created_at=doc.created_at,
                updated_at=doc.updated_at,
@@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
            batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
            group_key=document.group_key if hasattr(document, 'group_key') else None,
category=document.category if hasattr(document, 'category') else "invoice",
            csv_field_values=csv_field_values,
            can_annotate=can_annotate,
            annotation_lock_until=annotation_lock_until,
@@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                detail="Document not found or does not belong to this token",
            )

-        # Delete file
-        file_path = Path(document.file_path)
-        if file_path.exists():
-            file_path.unlink()
-        # Delete images
-        images_dir = ADMIN_IMAGES_DIR / document_id
-        if images_dir.exists():
-            import shutil
-            shutil.rmtree(images_dir)
+        # Delete file using StorageHelper
+        storage = get_storage_helper()
+        # Delete the raw PDF
+        filename = Path(document.file_path).name
+        if filename:
+            try:
+                storage._storage.delete(document.file_path)
+            except Exception as e:
+                logger.warning(f"Failed to delete PDF file: {e}")
+        # Delete admin images
+        try:
+            storage.delete_admin_images(document_id)
+        except Exception as e:
+            logger.warning(f"Failed to delete admin images: {e}")
        # Delete from database
        db.delete_document(document_id)
@@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
            "message": "Document group key updated",
        }
@router.get(
"/categories",
response_model=DocumentCategoriesResponse,
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
},
summary="Get available categories",
description="Get list of all available document categories.",
)
async def get_categories(
admin_token: AdminTokenDep,
db: AdminDBDep,
) -> DocumentCategoriesResponse:
"""Get all available document categories."""
categories = db.get_document_categories()
return DocumentCategoriesResponse(
categories=categories,
total=len(categories),
)
@router.patch(
"/{document_id}/category",
responses={
401: {"model": ErrorResponse, "description": "Invalid token"},
404: {"model": ErrorResponse, "description": "Document not found"},
},
summary="Update document category",
description="Update the category for a document.",
)
async def update_document_category(
document_id: str,
admin_token: AdminTokenDep,
db: AdminDBDep,
request: DocumentUpdateRequest,
) -> dict:
"""Update document category."""
_validate_uuid(document_id, "document_id")
# Verify document exists
document = db.get_document_by_token(document_id, admin_token)
if document is None:
raise HTTPException(
status_code=404,
detail="Document not found or does not belong to this token",
)
# Update category if provided
if request.category is not None:
db.update_document_category(document_id, request.category)
return {
"status": "updated",
"document_id": document_id,
"category": request.category,
"message": "Document category updated",
}
    return router
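A hedged example of calling the two new endpoints. The router mount prefix and the admin auth header are deployment details not visible in this diff, so both values below are placeholders:

    import httpx

    BASE = "http://localhost:8000/api/v1/admin/documents"  # assumed prefix
    HEADERS = {"X-Admin-Token": "..."}                      # assumed auth header
    document_id = "00000000-0000-0000-0000-000000000000"

    # List the categories currently in use
    r = httpx.get(f"{BASE}/categories", headers=HEADERS)
    print(r.json())  # {"categories": ["invoice", "letter", ...], "total": ...}

    # Re-categorize a single document
    r = httpx.patch(f"{BASE}/{document_id}/category", headers=HEADERS,
                    json={"category": "letter"})
    print(r.json())  # {"status": "updated", ...}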

View File

@@ -17,6 +17,7 @@ from inference.web.schemas.admin import (
    TrainingStatus,
    TrainingTaskResponse,
)
from inference.web.services.storage_helpers import get_storage_helper
from ._utils import _validate_uuid
@@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None:
        db: AdminDBDep,
    ) -> DatasetResponse:
        """Create a training dataset from document IDs."""
-        from pathlib import Path
        from inference.web.services.dataset_builder import DatasetBuilder

        # Validate minimum document count for proper train/val/test split
@@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None:
            seed=request.seed,
        )

-        builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
+        # Get storage paths from StorageHelper
storage = get_storage_helper()
datasets_dir = storage.get_datasets_base_path()
admin_images_dir = storage.get_admin_images_base_path()
if datasets_dir is None or admin_images_dir is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
builder = DatasetBuilder(db=db, base_dir=datasets_dir)
        try:
            builder.build_dataset(
                dataset_id=str(dataset.dataset_id),
@@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None:
                train_ratio=request.train_ratio,
                val_ratio=request.val_ratio,
                seed=request.seed,
-                admin_images_dir=Path("data/admin_images"),
+                admin_images_dir=admin_images_dir,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
@@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None:
            name=dataset.name,
            description=dataset.description,
            status=dataset.status,
training_status=dataset.training_status,
active_training_task_id=(
str(dataset.active_training_task_id)
if dataset.active_training_task_id
else None
),
            train_ratio=dataset.train_ratio,
            val_ratio=dataset.val_ratio,
            seed=dataset.seed,

View File

@@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None:
        db: AdminDBDep,
    ) -> ExportResponse:
        """Export annotations for training."""
-        from pathlib import Path
-        import shutil
+        from inference.web.services.storage_helpers import get_storage_helper
+
+        # Get storage helper for reading images and exports directory
+        storage = get_storage_helper()
        if request.format not in ("yolo", "coco", "voc"):
            raise HTTPException(
@@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None:
                detail="No labeled documents available for export",
            )

-        export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
+        # Get exports directory from StorageHelper
exports_base = storage.get_exports_base_path()
if exports_base is None:
raise HTTPException(
status_code=500,
detail="Storage not configured for local access",
)
export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        export_dir.mkdir(parents=True, exist_ok=True)
        (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
@@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None:
                if not page_annotations and not request.include_images:
                    continue

-                src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
-                if not src_image.exists():
+                # Get image from storage
+                doc_id = str(doc.document_id)
+                if not storage.admin_image_exists(doc_id, page_num):
                    continue

+                # Download image and save to export directory
                image_name = f"{doc.document_id}_page{page_num}.png"
                dst_image = export_dir / "images" / split / image_name
-                shutil.copy(src_image, dst_image)
+                image_content = storage.get_admin_image(doc_id, page_num)
+                dst_image.write_bytes(image_content)
                total_images += 1
                label_name = f"{doc.document_id}_page{page_num}.txt"
@@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None:
                        f.write(line)
                        total_annotations += 1

-        from inference.data.admin_models import FIELD_CLASSES
+        from shared.fields import FIELD_CLASSES

        yaml_content = f"""# Auto-generated YOLO dataset config
path: {export_dir.absolute()}

View File

@@ -22,6 +22,7 @@ from inference.web.schemas.inference import (
    InferenceResult,
)
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
    from inference.web.services import InferenceService
@@ -90,8 +91,17 @@ def create_inference_router(
        # Generate document ID
        doc_id = str(uuid.uuid4())[:8]

-        # Save uploaded file
-        upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
+        # Get storage helper and uploads directory
+        storage = get_storage_helper()
uploads_dir = storage.get_uploads_base_path(subfolder="inference")
if uploads_dir is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Storage not configured for local access",
)
# Save uploaded file to temporary location for processing
upload_path = uploads_dir / f"{doc_id}{file_ext}"
        try:
            with open(upload_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
@@ -149,12 +159,13 @@ def create_inference_router(
            # Cleanup uploaded file
            upload_path.unlink(missing_ok=True)

-    @router.get("/results/{filename}")
+    @router.get("/results/{filename}", response_model=None)
    async def get_result_image(filename: str) -> FileResponse:
        """Get visualization result image."""
-        file_path = storage_config.result_dir / filename
-        if not file_path.exists():
+        storage = get_storage_helper()
+        file_path = storage.get_result_local_path(filename)
+        if file_path is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Result file not found: {filename}",
            )
@@ -169,15 +180,15 @@ def create_inference_router(
    @router.delete("/results/{filename}")
    async def delete_result(filename: str) -> dict:
        """Delete a result file."""
-        file_path = storage_config.result_dir / filename
-        if not file_path.exists():
+        storage = get_storage_helper()
+        if not storage.result_exists(filename):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Result file not found: {filename}",
            )
-        file_path.unlink()
+        storage.delete_result(filename)
        return {"status": "deleted", "filename": filename}
    return router

View File

@@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
from inference.data.admin_db import AdminDB
from inference.web.schemas.labeling import PreLabelResponse
from inference.web.schemas.common import ErrorResponse
from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
    from inference.web.services import InferenceService
@@ -23,19 +24,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)

-# Storage directory for pre-label uploads (legacy, now uses storage_config)
-PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")

def _convert_pdf_to_images(
-    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
+    document_id: str, content: bytes, page_count: int, dpi: int
) -> None:
-    """Convert PDF pages to images for annotation."""
+    """Convert PDF pages to images for annotation using StorageHelper."""
    import fitz

-    doc_images_dir = images_dir / document_id
-    doc_images_dir.mkdir(parents=True, exist_ok=True)
+    storage = get_storage_helper()

    pdf_doc = fitz.open(stream=content, filetype="pdf")
    for page_num in range(page_count):
@@ -43,8 +39,9 @@ def _convert_pdf_to_images(
        mat = fitz.Matrix(dpi / 72, dpi / 72)
        pix = page.get_pixmap(matrix=mat)
-        image_path = doc_images_dir / f"page_{page_num + 1}.png"
-        pix.save(str(image_path))
+        # Save to storage using StorageHelper
+        image_bytes = pix.tobytes("png")
+        storage.save_admin_image(document_id, page_num + 1, image_bytes)

    pdf_doc.close()
@@ -70,9 +67,6 @@ def create_labeling_router(
    """
    router = APIRouter(prefix="/api/v1", tags=["labeling"])

-    # Ensure upload directory exists
-    PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    @router.post(
        "/pre-label",
        response_model=PreLabelResponse,
@@ -165,10 +159,11 @@ def create_labeling_router(
            csv_field_values=expected_values,
        )

-        # Save file to admin uploads
-        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
+        # Save file to storage using StorageHelper
+        storage = get_storage_helper()
+        filename = f"{document_id}{file_ext}"
        try:
-            file_path.write_bytes(content)
+            storage_path = storage.save_raw_pdf(content, filename)
        except Exception as e:
            logger.error(f"Failed to save file: {e}")
            raise HTTPException(
@@ -176,15 +171,14 @@ def create_labeling_router(
                detail="Failed to save file",
            )

-        # Update file path in database
-        db.update_document_file_path(document_id, str(file_path))
+        # Update file path in database (using storage path)
+        db.update_document_file_path(document_id, storage_path)

        # Convert PDF to images for annotation UI
        if file_ext == ".pdf":
            try:
                _convert_pdf_to_images(
-                    document_id, content, page_count,
-                    storage_config.admin_images_dir, storage_config.dpi
+                    document_id, content, page_count, storage_config.dpi
                )
            except Exception as e:
                logger.error(f"Failed to convert PDF to images: {e}")

View File

@@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from inference.web.services import InferenceService
+from inference.web.services.storage_helpers import get_storage_helper
# Public API imports
from inference.web.api.v1.public import (
@@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
allow_headers=["*"],
)
-# Mount static files for results
-config.storage.result_dir.mkdir(parents=True, exist_ok=True)
+# Mount static files for results using StorageHelper
+storage = get_storage_helper()
+results_dir = storage.get_results_base_path()
+if results_dir:
app.mount(
"/static/results",
-StaticFiles(directory=str(config.storage.result_dir)),
+StaticFiles(directory=str(results_dir)),
name="results",
)
+else:
+logger.warning("Could not mount static results directory: local storage not available")
# Include public API routes
inference_router = create_inference_router(inference_service, config.storage)

View File

@@ -4,16 +4,49 @@ Web Application Configuration
Centralized configuration for the web application.
"""
+import os
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
-from shared.config import DEFAULT_DPI, PATHS
+from shared.config import DEFAULT_DPI
+if TYPE_CHECKING:
+from shared.storage.base import StorageBackend
+def get_storage_backend(
+config_path: Path | str | None = None,
+) -> "StorageBackend":
+"""Get storage backend for file operations.
+Args:
+config_path: Optional path to storage configuration file.
+If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.
+Returns:
+Configured StorageBackend instance.
+"""
+from shared.storage import get_storage_backend as _get_storage_backend
+# Check for config file path
+if config_path is None:
+config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
+if config_path_str:
+config_path = Path(config_path_str)
+return _get_storage_backend(config_path=config_path)
@dataclass(frozen=True)
class ModelConfig:
-"""YOLO model configuration."""
+"""YOLO model configuration.
+Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
+- Models need to be accessible by inference service on any platform
+- Models may be version-controlled or deployed separately
+- Models are part of the application, not user data
+"""
model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
confidence_threshold: float = 0.5
@@ -33,24 +66,39 @@ class ServerConfig:
@dataclass(frozen=True)
-class StorageConfig:
-"""File storage configuration.
-Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
-directly in raw_pdfs directory. This ensures consistency with CLI autolabel
-and avoids storing duplicate files.
+class FileConfig:
+"""File handling configuration.
+This config holds file handling settings. For file operations,
+use the storage backend with PREFIXES from shared.storage.prefixes.
+Example:
+from shared.storage import PREFIXES, get_storage_backend
+storage = get_storage_backend()
+path = PREFIXES.document_path(document_id)
+storage.upload_bytes(content, path)
+Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
+They are kept for backward compatibility with existing code and tests.
+New code should use the storage backend with PREFIXES instead.
"""
-upload_dir: Path = Path("uploads")
-result_dir: Path = Path("results")
-admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
-admin_images_dir: Path = Path("data/admin_images")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
dpi: int = DEFAULT_DPI
+presigned_url_expiry_seconds: int = 3600
+# Deprecated path fields - kept for backward compatibility
+# New code should use storage backend with PREFIXES instead
+# All paths are now under data/ to match WSL storage layout
+upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
+result_dir: Path = field(default_factory=lambda: Path("data/results"))
+admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
+admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))
def __post_init__(self) -> None:
-"""Create directories if they don't exist."""
+"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
@@ -61,9 +109,17 @@ class StorageConfig:
self.admin_images_dir.mkdir(parents=True, exist_ok=True)
+# Backward compatibility alias
+StorageConfig = FileConfig
@dataclass(frozen=True)
class AsyncConfig:
-"""Async processing configuration."""
+"""Async processing configuration.
+Note: For file paths, use the storage backend with PREFIXES.
+Example: PREFIXES.upload_path(filename, "async")
+"""
# Queue settings
queue_max_size: int = 100
@@ -77,14 +133,17 @@ class AsyncConfig:
# Storage
result_retention_days: int = 7
-temp_upload_dir: Path = Path("uploads/async")
max_file_size_mb: int = 50
+# Deprecated: kept for backward compatibility
+# Path under data/ to match WSL storage layout
+temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
# Cleanup
cleanup_interval_hours: int = 1
def __post_init__(self) -> None:
-"""Create directories if they don't exist."""
+"""Create directories if they don't exist (for backward compatibility)."""
object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
self.temp_upload_dir.mkdir(parents=True, exist_ok=True)
@@ -95,19 +154,41 @@ class AppConfig:
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
-storage: StorageConfig = field(default_factory=StorageConfig)
+file: FileConfig = field(default_factory=FileConfig)
async_processing: AsyncConfig = field(default_factory=AsyncConfig)
+storage_backend: "StorageBackend | None" = None
+@property
+def storage(self) -> FileConfig:
+"""Backward compatibility alias for file config."""
+return self.file
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
+file_config = config_dict.get("file", config_dict.get("storage", {}))
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
-storage=StorageConfig(**config_dict.get("storage", {})),
+file=FileConfig(**file_config),
async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
)
+def create_app_config(
+storage_config_path: Path | str | None = None,
+) -> AppConfig:
+"""Create application configuration with storage backend.
+Args:
+storage_config_path: Optional path to storage configuration file.
+Returns:
+Configured AppConfig instance with storage backend initialized.
+"""
+storage_backend = get_storage_backend(config_path=storage_config_path)
+return AppConfig(storage_backend=storage_backend)
# Default configuration instance
default_config = AppConfig()

View File

@@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import (
get_pending_autolabel_documents,
process_document_autolabel,
)
+from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -36,7 +37,13 @@ class AutoLabelScheduler:
"""
self._check_interval = check_interval_seconds
self._batch_size = batch_size
+# Get output directory from StorageHelper
+if output_dir is None:
+storage = get_storage_helper()
+output_dir = storage.get_autolabel_output_path()
self._output_dir = output_dir or Path("data/autolabel_output")
self._running = False
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any
from inference.data.admin_db import AdminDB
+from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -107,6 +108,14 @@ class TrainingScheduler:
self._db.update_training_task_status(task_id, "running")
self._db.add_training_log(task_id, "INFO", "Training task started")
+# Update dataset training status to running
+if dataset_id:
+self._db.update_dataset_training_status(
+dataset_id,
+training_status="running",
+active_training_task_id=task_id,
+)
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
@@ -192,6 +201,15 @@
)
)
self._db.add_training_log(task_id, "INFO", "Training completed successfully")
+# Update dataset training status to completed and main status to trained
+if dataset_id:
+self._db.update_dataset_training_status(
+dataset_id,
+training_status="completed",
+active_training_task_id=None,
+update_main_status=True, # Set main status to 'trained'
+)
# Auto-create model version for the completed training
self._create_model_version_from_training(
task_id=task_id,
@@ -203,6 +221,13 @@
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
+# Update dataset training status to failed
+if dataset_id:
+self._db.update_dataset_training_status(
+dataset_id,
+training_status="failed",
+active_training_task_id=None,
+)
raise
def _create_model_version_from_training(
@@ -268,9 +293,10 @@
f"Created model version {version} (ID: {model_version.version_id}) "
f"from training task {task_id}"
)
+mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
self._db.add_training_log(
task_id, "INFO",
-f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
+f"Model version {version} created (mAP: {mAP_display})",
)
except Exception as e:
@@ -283,8 +309,11 @@
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
from pathlib import Path
-import shutil
-from inference.data.admin_models import FIELD_CLASSES
+from shared.fields import FIELD_CLASSES
+from inference.web.services.storage_helpers import get_storage_helper
+# Get storage helper for reading images
+storage = get_storage_helper()
# Get all labeled documents
documents = self._db.get_labeled_documents_for_export()
@@ -293,8 +322,12 @@
self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
return None
-# Create export directory
-export_dir = Path("data/training") / task_id
+# Create export directory using StorageHelper
+training_base = storage.get_training_data_path()
+if training_base is None:
+self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
+return None
+export_dir = training_base / task_id
export_dir.mkdir(parents=True, exist_ok=True)
# YOLO format directories
@@ -323,14 +356,16 @@
for page_num in range(1, doc.page_count + 1):
page_annotations = [a for a in annotations if a.page_number == page_num]
-# Copy image
-src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
-if not src_image.exists():
+# Get image from storage
+doc_id = str(doc.document_id)
+if not storage.admin_image_exists(doc_id, page_num):
continue
+# Download image and save to export directory
image_name = f"{doc.document_id}_page{page_num}.png"
dst_image = export_dir / "images" / split / image_name
-shutil.copy(src_image, dst_image)
+image_content = storage.get_admin_image(doc_id, page_num)
+dst_image.write_bytes(image_content)
total_images += 1
# Write YOLO label
@@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())}
self._db.add_training_log(task_id, level, message)
# Create shared training config
+# Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
+# because models need to be accessible by inference service on any platform
# Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
config = SharedTrainingConfig(
model_path=model_name,

View File

@@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
description: str | None = Field(None, description="Optional description")
document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
+category: str | None = Field(None, description="Filter documents by category (optional)")
train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
seed: int = Field(42, description="Random seed for split")
@@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel):
name: str
description: str | None
status: str
+training_status: str | None = None
+active_training_task_id: str | None = None
train_ratio: float
val_ratio: float
seed: int

View File

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
file_size: int = Field(..., ge=0, description="File size in bytes")
page_count: int = Field(..., ge=1, description="Number of pages")
status: DocumentStatus = Field(..., description="Document status")
+category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
group_key: str | None = Field(None, description="User-defined group key")
auto_label_started: bool = Field(
default=False, description="Whether auto-labeling was started"
@@ -44,6 +45,7 @@ class DocumentItem(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
+category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
can_annotate: bool = Field(default=True, description="Whether document can be annotated")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel):
upload_source: str = Field(default="ui", description="Upload source (ui or api)")
batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
group_key: str | None = Field(None, description="User-defined group key")
+category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
csv_field_values: dict[str, str] | None = Field(
None, description="CSV field values if uploaded via batch"
)
@@ -104,3 +107,17 @@
auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
labeled: int = Field(default=0, ge=0, description="Labeled documents")
exported: int = Field(default=0, ge=0, description="Exported documents")
+class DocumentUpdateRequest(BaseModel):
+"""Request for updating document metadata."""
+category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
+group_key: str | None = Field(None, description="User-defined group key")
+class DocumentCategoriesResponse(BaseModel):
+"""Response for available document categories."""
+categories: list[str] = Field(..., description="List of available categories")
+total: int = Field(..., ge=0, description="Total number of categories")

View File

@@ -5,6 +5,7 @@ Manages async request lifecycle and background processing.
"""
import logging
+import re
import shutil
import time
import uuid
@@ -17,6 +18,7 @@ from typing import TYPE_CHECKING
from inference.data.async_request_db import AsyncRequestDB
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.core.rate_limiter import RateLimiter
+from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from inference.web.config import AsyncConfig, StorageConfig
@@ -189,9 +191,7 @@
filename: str,
content: bytes,
) -> Path:
-"""Save uploaded file to temp storage."""
-import re
+"""Save uploaded file to temp storage using StorageHelper."""
# Extract extension from filename
ext = Path(filename).suffix.lower()
@@ -203,9 +203,11 @@
if ext not in self.ALLOWED_EXTENSIONS:
ext = ".pdf"
-# Create async upload directory
-upload_dir = self._async_config.temp_upload_dir
-upload_dir.mkdir(parents=True, exist_ok=True)
+# Get upload directory from StorageHelper
+storage = get_storage_helper()
+upload_dir = storage.get_uploads_base_path(subfolder="async")
+if upload_dir is None:
+raise ValueError("Storage not configured for local access")
# Build file path - request_id is a UUID so it's safe
file_path = upload_dir / f"{request_id}{ext}"
@@ -355,8 +357,9 @@
def _cleanup_orphan_files(self) -> int:
"""Clean up upload files that don't have matching requests."""
-upload_dir = self._async_config.temp_upload_dir
-if not upload_dir.exists():
+storage = get_storage_helper()
+upload_dir = storage.get_uploads_base_path(subfolder="async")
+if upload_dir is None or not upload_dir.exists():
return 0
count = 0

View File

@@ -13,7 +13,7 @@ from PIL import Image
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
-from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
+from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
from shared.matcher.field_matcher import FieldMatcher
from shared.ocr.paddle_ocr import OCREngine, OCRToken

View File

@@ -16,7 +16,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from inference.data.admin_db import AdminDB
-from inference.data.admin_models import CSV_TO_CLASS_MAPPING
+from shared.fields import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)

View File

@@ -12,7 +12,7 @@ from pathlib import Path
import yaml
-from inference.data.admin_models import FIELD_CLASSES
+from shared.fields import FIELD_CLASSES
logger = logging.getLogger(__name__)

View File

@@ -13,9 +13,10 @@ from typing import Any
from shared.config import DEFAULT_DPI
from inference.data.admin_db import AdminDB
-from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
+from shared.fields import CSV_TO_CLASS_MAPPING
+from inference.data.admin_models import AdminDocument
from shared.data.db import DocumentDB
-from inference.web.config import StorageConfig
+from inference.web.services.storage_helpers import get_storage_helper
logger = logging.getLogger(__name__)
@@ -122,6 +123,10 @@
document_id = str(document.document_id)
file_path = Path(document.file_path)
+# Get output directory from StorageHelper
+storage = get_storage_helper()
+if output_dir is None:
+output_dir = storage.get_autolabel_output_path()
if output_dir is None:
output_dir = Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)
@@ -152,10 +157,12 @@
is_scanned = len(tokens) < 10 # Threshold for "no text"
# Build task data
-# Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
+# Use raw_pdfs base path for pdf_path
# This ensures consistency with CLI autolabel for reprocess_failed.py
-storage_config = StorageConfig()
-pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
+raw_pdfs_dir = storage.get_raw_pdfs_base_path()
+if raw_pdfs_dir is None:
+raise ValueError("Storage not configured for local access")
+pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"
task_data = {
"row_dict": row_dict,
@@ -246,8 +253,8 @@
Returns:
Number of annotations saved
"""
-from PIL import Image
-from inference.data.admin_models import FIELD_CLASS_IDS
+from shared.fields import FIELD_CLASS_IDS
+from inference.web.services.storage_helpers import get_storage_helper
# Mapping from CSV field names to internal field names
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
@@ -266,6 +273,9 @@
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
scale = dpi / 72.0
+# Get storage helper for image dimensions
+storage = get_storage_helper()
# Cache for image dimensions per page
image_dimensions: dict[int, tuple[int, int]] = {}
@@ -274,18 +284,11 @@
if page_no in image_dimensions:
return image_dimensions[page_no]
-# Try to load from admin_images
-admin_images_dir = Path("data/admin_images") / document_id
-image_path = admin_images_dir / f"page_{page_no}.png"
-if image_path.exists():
-try:
-with Image.open(image_path) as img:
-dims = img.size # (width, height)
+# Get dimensions from storage helper
+dims = storage.get_admin_image_dimensions(document_id, page_no)
+if dims:
image_dimensions[page_no] = dims
return dims
-except Exception as e:
-logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
return None
@@ -449,10 +452,17 @@
from datetime import datetime
document_id = str(document.document_id)
-storage_config = StorageConfig()
-# Build pdf_path using admin_upload_dir (same as auto-label)
-pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
+# Build pdf_path using raw_pdfs base path (same as auto-label)
+storage = get_storage_helper()
+raw_pdfs_dir = storage.get_raw_pdfs_base_path()
+if raw_pdfs_dir is None:
+return {
+"success": False,
+"document_id": document_id,
+"error": "Storage not configured for local access",
+}
+pdf_path = raw_pdfs_dir / f"{document_id}.pdf"
# Build report dict compatible with DocumentDB.save_document()
field_results = []

View File

@@ -0,0 +1,217 @@
"""
Document Service for storage-backed file operations.
Provides a unified interface for document upload, download, and serving
using the storage abstraction layer.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from uuid import uuid4
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
@dataclass
class DocumentResult:
"""Result of document operation."""
id: str
file_path: str
filename: str | None = None
class DocumentService:
"""Service for document file operations using storage backend.
Provides upload, download, and URL generation for documents and images.
"""
# Storage path prefixes
DOCUMENTS_PREFIX = "documents"
IMAGES_PREFIX = "images"
def __init__(
self,
storage_backend: "StorageBackend",
admin_db: Any | None = None,
) -> None:
"""Initialize document service.
Args:
storage_backend: Storage backend for file operations.
admin_db: Optional AdminDB instance for database operations.
"""
self._storage = storage_backend
self._admin_db = admin_db
def upload_document(
self,
content: bytes,
filename: str,
dataset_id: str | None = None,
document_id: str | None = None,
) -> DocumentResult:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename.
dataset_id: Optional dataset ID for organization.
document_id: Optional document ID (generated if not provided).
Returns:
DocumentResult with ID and storage path.
"""
if document_id is None:
document_id = str(uuid4())
# Extract extension from filename
ext = ""
if "." in filename:
ext = "." + filename.rsplit(".", 1)[-1].lower()
# Build logical path
remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"
# Upload via storage backend
self._storage.upload_bytes(content, remote_path, overwrite=True)
return DocumentResult(
id=document_id,
file_path=remote_path,
filename=filename,
)
def download_document(self, remote_path: str) -> bytes:
"""Download a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
Document content as bytes.
"""
return self._storage.download_bytes(remote_path)
def get_document_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a document.
Args:
remote_path: Logical path to the document.
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for document access.
"""
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def document_exists(self, remote_path: str) -> bool:
"""Check if a document exists in storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document exists.
"""
return self._storage.exists(remote_path)
def delete_document_files(self, remote_path: str) -> bool:
"""Delete a document from storage.
Args:
remote_path: Logical path to the document.
Returns:
True if document was deleted.
"""
return self._storage.delete(remote_path)
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Logical path where image was stored.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
self._storage.upload_bytes(content, remote_path, overwrite=True)
return remote_path
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get a URL for accessing a page image.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
expires_in_seconds: URL validity duration.
Returns:
Pre-signed URL for image access.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image from storage.
Args:
document_id: Document ID.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
return self._storage.download_bytes(remote_path)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document ID.
Returns:
Number of images deleted.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
image_paths = self._storage.list_files(prefix)
deleted_count = 0
for path in image_paths:
if self._storage.delete(path):
deleted_count += 1
return deleted_count
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document ID.
Returns:
List of image paths.
"""
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
return self._storage.list_files(prefix)

View File

@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable
import numpy as np
from PIL import Image
+from inference.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
@@ -303,12 +305,19 @@
"""Save visualization image with detections."""
from ultralytics import YOLO
+# Get storage helper for results directory
+storage = get_storage_helper()
+results_dir = storage.get_results_base_path()
+if results_dir is None:
+logger.warning("Cannot save visualization: local storage not available")
+return None
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
-output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
+output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
@@ -320,19 +329,26 @@
from ultralytics import YOLO
import io
+# Get storage helper for results directory
+storage = get_storage_helper()
+results_dir = storage.get_results_base_path()
+if results_dir is None:
+logger.warning("Cannot save visualization: local storage not available")
+return None
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
-temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
+temp_path = results_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
-output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
+output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))

View File

@@ -0,0 +1,830 @@
"""
Storage helpers for web services.
Provides convenience functions for common storage operations,
wrapping the storage backend with proper path handling using prefixes.
"""
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
from shared.storage import PREFIXES, get_storage_backend
from shared.storage.local import LocalStorageBackend
if TYPE_CHECKING:
from shared.storage.base import StorageBackend
def get_default_storage() -> "StorageBackend":
"""Get the default storage backend.
Returns:
Configured StorageBackend instance.
"""
return get_storage_backend()
class StorageHelper:
"""Helper class for storage operations with prefixes.
Provides high-level operations for document storage, including
upload, download, and URL generation with proper path prefixes.
"""
def __init__(self, storage: "StorageBackend | None" = None) -> None:
"""Initialize storage helper.
Args:
storage: Storage backend to use. If None, creates default.
"""
self._storage = storage or get_default_storage()
@property
def storage(self) -> "StorageBackend":
"""Get the underlying storage backend."""
return self._storage
# Document operations
def upload_document(
self,
content: bytes,
filename: str,
document_id: str | None = None,
) -> tuple[str, str]:
"""Upload a document to storage.
Args:
content: Document content as bytes.
filename: Original filename (used for extension).
document_id: Optional document ID. Generated if not provided.
Returns:
Tuple of (document_id, storage_path).
"""
if document_id is None:
document_id = str(uuid4())
ext = Path(filename).suffix.lower() or ".pdf"
path = PREFIXES.document_path(document_id, ext)
self._storage.upload_bytes(content, path, overwrite=True)
return document_id, path
def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
"""Download a document from storage.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
Document content as bytes.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.download_bytes(path)
def get_document_url(
self,
document_id: str,
extension: str = ".pdf",
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a document.
Args:
document_id: Document identifier.
extension: File extension.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.get_presigned_url(path, expires_in_seconds)
def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
"""Check if a document exists.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document exists.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.exists(path)
def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
"""Delete a document.
Args:
document_id: Document identifier.
extension: File extension.
Returns:
True if document was deleted.
"""
path = PREFIXES.document_path(document_id, extension)
return self._storage.delete(path)
# Image operations
def save_page_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save a page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = PREFIXES.image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_page_image(self, document_id: str, page_num: int) -> bytes:
"""Download a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_page_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for a page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def delete_document_images(self, document_id: str) -> int:
"""Delete all images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def list_document_images(self, document_id: str) -> list[str]:
"""List all images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
# Upload staging operations
def save_upload(
self,
content: bytes,
filename: str,
subfolder: str | None = None,
) -> str:
"""Save a file to upload staging area.
Args:
content: File content as bytes.
filename: Filename to save as.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path where file was saved.
"""
path = PREFIXES.upload_path(filename, subfolder)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
"""Get a file from upload staging area.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
File content as bytes.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.download_bytes(path)
def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
"""Delete a file from upload staging area.
Args:
filename: Filename to delete.
subfolder: Optional subfolder.
Returns:
True if file was deleted.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.delete(path)
# Result operations
def save_result(self, content: bytes, filename: str) -> str:
"""Save a result file.
Args:
content: File content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.result_path(filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_result(self, filename: str) -> bytes:
"""Get a result file.
Args:
filename: Filename to retrieve.
Returns:
File content as bytes.
"""
path = PREFIXES.result_path(filename)
return self._storage.download_bytes(path)
def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
"""Get presigned URL for a result file.
Args:
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.result_path(filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
def result_exists(self, filename: str) -> bool:
"""Check if a result file exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = PREFIXES.result_path(filename)
return self._storage.exists(path)
def delete_result(self, filename: str) -> bool:
"""Delete a result file.
Args:
filename: Filename to delete.
Returns:
True if file was deleted.
"""
path = PREFIXES.result_path(filename)
return self._storage.delete(path)
# Export operations
def save_export(self, content: bytes, export_id: str, filename: str) -> str:
"""Save an export file.
Args:
content: File content as bytes.
export_id: Export identifier.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = PREFIXES.export_path(export_id, filename)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_export_url(
self,
export_id: str,
filename: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an export file.
Args:
export_id: Export identifier.
filename: Filename.
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = PREFIXES.export_path(export_id, filename)
return self._storage.get_presigned_url(path, expires_in_seconds)
# Admin image operations
def get_admin_image_path(self, document_id: str, page_num: int) -> str:
"""Get the storage path for an admin image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Storage path like "admin_images/doc123/page_1.png"
"""
return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"
def save_admin_image(
self,
document_id: str,
page_num: int,
content: bytes,
) -> str:
"""Save an admin page image to storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
content: Image content as bytes.
Returns:
Storage path where image was saved.
"""
path = self.get_admin_image_path(document_id, page_num)
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_admin_image(self, document_id: str, page_num: int) -> bytes:
"""Download an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Image content as bytes.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.download_bytes(path)
def get_admin_image_url(
self,
document_id: str,
page_num: int,
expires_in_seconds: int = 3600,
) -> str:
"""Get presigned URL for an admin page image.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
expires_in_seconds: URL expiration time.
Returns:
Presigned URL string.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.get_presigned_url(path, expires_in_seconds)
def admin_image_exists(self, document_id: str, page_num: int) -> bool:
"""Check if an admin page image exists.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
True if image exists.
"""
path = self.get_admin_image_path(document_id, page_num)
return self._storage.exists(path)
def list_admin_images(self, document_id: str) -> list[str]:
"""List all admin images for a document.
Args:
document_id: Document identifier.
Returns:
List of image paths.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
return self._storage.list_files(prefix)
def delete_admin_images(self, document_id: str) -> int:
"""Delete all admin images for a document.
Args:
document_id: Document identifier.
Returns:
Number of images deleted.
"""
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
images = self._storage.list_files(prefix)
deleted = 0
for img_path in images:
if self._storage.delete(img_path):
deleted += 1
return deleted
def get_admin_image_local_path(
self, document_id: str, page_num: int
) -> Path | None:
"""Get the local filesystem path for an admin image.
This method is useful for serving files via FileResponse.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
# Cloud storage - cannot get local path
return None
remote_path = self.get_admin_image_path(document_id, page_num)
try:
full_path = self._storage._get_full_path(remote_path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_admin_image_dimensions(
self, document_id: str, page_num: int
) -> tuple[int, int] | None:
"""Get the dimensions (width, height) of an admin image.
This method is useful for normalizing bounding box coordinates.
Args:
document_id: Document identifier.
page_num: Page number (1-indexed).
Returns:
Tuple of (width, height) if image exists, None otherwise.
"""
from PIL import Image
# Try local path first for efficiency
local_path = self.get_admin_image_local_path(document_id, page_num)
if local_path is not None:
with Image.open(local_path) as img:
return img.size
# Fall back to downloading for cloud storage
if not self.admin_image_exists(document_id, page_num):
return None
try:
import io
image_bytes = self.get_admin_image(document_id, page_num)
with Image.open(io.BytesIO(image_bytes)) as img:
return img.size
except Exception:
return None
# Raw PDF operations (legacy compatibility)
def save_raw_pdf(self, content: bytes, filename: str) -> str:
"""Save a raw PDF for auto-labeling pipeline.
Args:
content: PDF content as bytes.
filename: Filename to save as.
Returns:
Storage path where file was saved.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
self._storage.upload_bytes(content, path, overwrite=True)
return path
def get_raw_pdf(self, filename: str) -> bytes:
"""Get a raw PDF from storage.
Args:
filename: Filename to retrieve.
Returns:
PDF content as bytes.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.download_bytes(path)
def raw_pdf_exists(self, filename: str) -> bool:
"""Check if a raw PDF exists.
Args:
filename: Filename to check.
Returns:
True if file exists.
"""
path = f"{PREFIXES.RAW_PDFS}/{filename}"
return self._storage.exists(path)
def get_raw_pdf_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a raw PDF.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = f"{PREFIXES.RAW_PDFS}/{filename}"
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_raw_pdf_path(self, filename: str) -> str:
"""Get the storage path for a raw PDF (not the local filesystem path).
Args:
filename: Filename.
Returns:
Storage path like "raw_pdfs/filename.pdf"
"""
return f"{PREFIXES.RAW_PDFS}/{filename}"
# Result local path operations
def get_result_local_path(self, filename: str) -> Path | None:
"""Get the local filesystem path for a result file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.result_path(filename)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_results_base_path(self) -> Path | None:
"""Get the base directory path for results (local storage only).
Used for mounting static file directories.
Returns:
Path to results directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RESULTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Upload local path operations
def get_upload_local_path(
self, filename: str, subfolder: str | None = None
) -> Path | None:
"""Get the local filesystem path for an upload file.
Only works with LocalStorageBackend; returns None for cloud storage.
Args:
filename: Filename to retrieve.
subfolder: Optional subfolder.
Returns:
Path object if using local storage and file exists, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
path = PREFIXES.upload_path(filename, subfolder)
try:
full_path = self._storage._get_full_path(path)
if full_path.exists():
return full_path
return None
except Exception:
return None
def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
"""Get the base directory path for uploads (local storage only).
Args:
subfolder: Optional subfolder (e.g., "async").
Returns:
Path to uploads directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
if subfolder:
base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
else:
base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
"""Check if an upload file exists.
Args:
filename: Filename to check.
subfolder: Optional subfolder.
Returns:
True if file exists.
"""
path = PREFIXES.upload_path(filename, subfolder)
return self._storage.exists(path)
# Dataset operations
def get_datasets_base_path(self) -> Path | None:
"""Get the base directory path for datasets (local storage only).
Returns:
Path to datasets directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.DATASETS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_admin_images_base_path(self) -> Path | None:
"""Get the base directory path for admin images (local storage only).
Returns:
Path to admin_images directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_raw_pdfs_base_path(self) -> Path | None:
"""Get the base directory path for raw PDFs (local storage only).
Returns:
Path to raw_pdfs directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_autolabel_output_path(self) -> Path | None:
"""Get the directory path for autolabel output (local storage only).
Returns:
Path to autolabel_output directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
# Use a subfolder under results for autolabel output
base_path = self._storage._get_full_path("autolabel_output")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_training_data_path(self) -> Path | None:
"""Get the directory path for training data exports (local storage only).
Returns:
Path to training directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path("training")
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
def get_exports_base_path(self) -> Path | None:
"""Get the base directory path for exports (local storage only).
Returns:
Path to exports directory if using local storage, None otherwise.
"""
if not isinstance(self._storage, LocalStorageBackend):
return None
try:
base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
base_path.mkdir(parents=True, exist_ok=True)
return base_path
except Exception:
return None
# Default instance for convenience
_default_helper: StorageHelper | None = None
def get_storage_helper() -> StorageHelper:
"""Get the default storage helper instance.
Creates the helper on first call with default storage backend.
Returns:
Default StorageHelper instance.
"""
global _default_helper
if _default_helper is None:
_default_helper = StorageHelper()
return _default_helper
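A quick sketch of typical consumption, assuming local storage and that `get_storage_helper`/`StorageHelper` live in a `shared.storage_helper` module (the import path and the exact `get_uploads_base_path` signature are assumptions here, not confirmed by this diff):

```python
from shared.storage import PREFIXES
from shared.storage_helper import get_storage_helper  # assumed module path

helper = get_storage_helper()  # lazily creates the shared default instance

# Stage a file under uploads/async/ and verify it is visible via the backend.
base = helper.get_uploads_base_path(subfolder="async")
if base is not None:  # None means a non-local backend is configured
    (base / "invoice.pdf").write_bytes(b"%PDF-1.4 stub")
    assert helper.upload_exists("invoice.pdf", subfolder="async")
    assert PREFIXES.upload_path("invoice.pdf", "async") == "uploads/async/invoice.pdf"
```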

packages/shared/README.md Normal file
View File

@@ -0,0 +1,205 @@
# Shared Package
Shared utilities and abstractions for the Invoice Master system.
## Storage Abstraction Layer
A unified storage abstraction supporting multiple backends:
- **Local filesystem** - Development and testing
- **Azure Blob Storage** - Azure cloud deployments
- **AWS S3** - AWS cloud deployments
### Installation
```bash
# Basic installation (local storage only)
pip install -e packages/shared
# With Azure support
pip install -e "packages/shared[azure]"
# With S3 support
pip install -e "packages/shared[s3]"
# All cloud providers
pip install -e "packages/shared[all]"
```
### Quick Start
```python
from pathlib import Path

from shared.storage import get_storage_backend
# Option 1: From configuration file
storage = get_storage_backend("storage.yaml")
# Option 2: From environment variables
from shared.storage import create_storage_backend_from_env
storage = create_storage_backend_from_env()
# Upload a file
storage.upload(Path("local/file.pdf"), "documents/file.pdf")
# Download a file
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))
# Get pre-signed URL for frontend access
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
```
### Configuration File Format
Create a `storage.yaml` file with environment variable substitution support:
```yaml
# Backend selection: local, azure_blob, or s3
backend: ${STORAGE_BACKEND:-local}
# Default pre-signed URL expiry (seconds)
presigned_url_expiry: 3600
# Local storage configuration
local:
base_path: ${STORAGE_BASE_PATH:-./data/storage}
# Azure Blob Storage configuration
azure:
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
create_container: false
# AWS S3 configuration
s3:
bucket_name: ${AWS_S3_BUCKET}
region_name: ${AWS_REGION:-us-east-1}
access_key_id: ${AWS_ACCESS_KEY_ID}
secret_access_key: ${AWS_SECRET_ACCESS_KEY}
endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services
create_bucket: false
```
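For example, the same file can be parsed and inspected before a backend is constructed:

```python
from shared.storage import create_storage_backend_from_file, load_storage_config

# Parse the YAML (environment variables are substituted during loading)
cfg = load_storage_config("storage.yaml")
print(cfg.backend_type, cfg.presigned_url_expiry)

# Or build the backend from the same file in one step
storage = create_storage_backend_from_file("storage.yaml")
```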
### Environment Variables
| Variable | Backend | Description |
|----------|---------|-------------|
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` |
| `STORAGE_BASE_PATH` | Local | Base directory path |
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
| `AWS_S3_BUCKET` | S3 | Bucket name |
| `AWS_REGION` | S3 | AWS region (default: us-east-1) |
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
### API Reference
#### StorageBackend Interface
```python
class StorageBackend(ABC):
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
"""Upload a file to storage."""
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage."""
def exists(self, remote_path: str) -> bool:
"""Check if a file exists."""
def list_files(self, prefix: str) -> list[str]:
"""List files with given prefix."""
def delete(self, remote_path: str) -> bool:
"""Delete a file."""
def get_url(self, remote_path: str) -> str:
"""Get URL for a file."""
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
"""Generate a pre-signed URL for temporary access (1-604800 seconds)."""
def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
"""Upload bytes directly."""
def download_bytes(self, remote_path: str) -> bytes:
"""Download file as bytes."""
```
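Because every backend implements the same contract, a bytes round-trip works unchanged across local, Azure, and S3 — given any `storage` instance from the Quick Start (the path below is illustrative):

```python
payload = b"hello"
storage.upload_bytes(payload, "scratch/hello.txt", overwrite=True)
assert storage.download_bytes("scratch/hello.txt") == payload
assert "scratch/hello.txt" in storage.list_files("scratch")
storage.delete("scratch/hello.txt")
```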
#### Factory Functions
```python
# Create from configuration file
storage = create_storage_backend_from_file("storage.yaml")
# Create from environment variables
storage = create_storage_backend_from_env()
# Create from StorageConfig object
config = StorageConfig(backend_type="local", base_path=Path("./data"))
storage = create_storage_backend(config)
# Convenience function with fallback chain: config file -> environment variables
storage = get_storage_backend("storage.yaml")  # pass None to fall back to env vars
```
### Pre-signed URLs
Pre-signed URLs provide temporary access to files without exposing credentials:
```python
# Generate URL valid for 1 hour (default)
url = storage.get_presigned_url("documents/invoice.pdf")
# Generate URL valid for 24 hours
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)
# Maximum expiry: 7 days (604800 seconds)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
```
**Note:** Local storage returns `file://` URLs that don't actually expire.
### Error Handling
```python
from shared.storage import (
StorageError,
FileNotFoundStorageError,
PresignedUrlNotSupportedError,
)
try:
storage.download("nonexistent.pdf", Path("local.pdf"))
except FileNotFoundStorageError as e:
print(f"File not found: {e}")
except StorageError as e:
print(f"Storage error: {e}")
```
### Testing with MinIO (S3-compatible)
```bash
# Start MinIO locally
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"
# Configure environment
export STORAGE_BACKEND=s3
export AWS_S3_BUCKET=test-bucket
export AWS_ENDPOINT_URL=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
```
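With those variables set (and the bucket created, e.g. via the MinIO console — the env-based factory does not create buckets), a short smoke test:

```python
from shared.storage import create_storage_backend_from_env

storage = create_storage_backend_from_env()
storage.upload_bytes(b"ping", "healthcheck/ping.txt", overwrite=True)
print(storage.get_presigned_url("healthcheck/ping.txt", expires_in_seconds=60))
```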
### Module Structure
```
shared/storage/
├── __init__.py # Public exports
├── base.py # Abstract interface and exceptions
├── local.py # Local filesystem backend
├── azure.py # Azure Blob Storage backend
├── s3.py # AWS S3 backend
├── config_loader.py # YAML configuration loader
└── factory.py # Backend factory functions
```

View File

@@ -16,4 +16,18 @@ setup(
"pyyaml>=6.0", "pyyaml>=6.0",
"thefuzz>=0.20.0", "thefuzz>=0.20.0",
], ],
extras_require={
"azure": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
],
"s3": [
"boto3>=1.34.0",
],
"all": [
"azure-storage-blob>=12.19.0",
"azure-identity>=1.15.0",
"boto3>=1.34.0",
],
},
) )

View File

@@ -58,23 +58,16 @@ def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}" return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows # Paths Configuration - uses STORAGE_BASE_PATH for consistency
if _is_wsl(): # All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
# WSL: use native Linux filesystem for better I/O performance _storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
PATHS = { PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'), 'csv_dir': f'{_storage_base}/structured_data',
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'), 'pdf_dir': f'{_storage_base}/raw_pdfs',
'output_dir': os.path.expanduser('~/invoice-data/dataset'), 'output_dir': f'{_storage_base}/datasets',
'reports_dir': 'reports', # Keep reports in project directory 'reports_dir': 'reports', # Keep reports in project directory
} }
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Auto-labeling Configuration # Auto-labeling Configuration
AUTOLABEL = { AUTOLABEL = {

View File

@@ -0,0 +1,46 @@
"""
Shared Field Definitions - Single Source of Truth.
This module provides centralized field class definitions used throughout
the invoice extraction system. All field mappings are derived from
FIELD_DEFINITIONS to ensure consistency.
Usage:
from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS
Available exports:
- FieldDefinition: Dataclass for field definition
- FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
- NUM_CLASSES: Total number of field classes (10)
- CLASS_NAMES: List of class names in order [0..9]
- FIELD_CLASSES: dict[int, str] - class_id to class_name
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
"""
from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
from .mappings import (
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
ACCOUNT_FIELD_MAPPING,
)
__all__ = [
"FieldDefinition",
"FIELD_DEFINITIONS",
"NUM_CLASSES",
"CLASS_NAMES",
"FIELD_CLASSES",
"FIELD_CLASS_IDS",
"CLASS_TO_FIELD",
"CSV_TO_CLASS_MAPPING",
"TRAINING_FIELD_CLASSES",
"ACCOUNT_FIELD_MAPPING",
]

View File

@@ -0,0 +1,58 @@
"""
Field Configuration - Single Source of Truth
This module defines all invoice field classes used throughout the system.
The class IDs are verified against the trained YOLO model (best.pt).
IMPORTANT: Do not modify class_id values without retraining the model!
"""
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class FieldDefinition:
"""Immutable field definition for invoice extraction.
Attributes:
class_id: YOLO class ID (0-9), must match trained model
class_name: YOLO class name (lowercase_underscore)
field_name: Business field name used in API responses
csv_name: CSV column name for data import/export
is_derived: True if field is derived from other fields (not in CSV)
"""
class_id: int
class_name: str
field_name: str
csv_name: str
is_derived: bool = False
# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
#
# DO NOT CHANGE THE ORDER - it must match the trained model!
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
FieldDefinition(3, "ocr_number", "OCR", "OCR"),
FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
FieldDefinition(6, "amount", "Amount", "Amount"),
FieldDefinition(
7,
"supplier_org_number",
"supplier_organisation_number",
"supplier_organisation_number",
),
FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
FieldDefinition(
9, "payment_line", "payment_line", "payment_line", is_derived=True
),
)
# Total number of field classes
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)

View File

@@ -0,0 +1,57 @@
"""
Field Mappings - Auto-generated from FIELD_DEFINITIONS.
All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
This ensures consistency across the entire codebase.
DO NOT hardcode field mappings elsewhere - always import from this module.
"""
from typing import Final
from .field_config import FIELD_DEFINITIONS
# List of class names in order (for YOLO classes.txt generation)
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]
# class_id -> class_name mapping
# Example: {0: "invoice_number", 1: "invoice_date", ...}
FIELD_CLASSES: Final[dict[int, str]] = {
fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
}
# class_name -> class_id mapping (reverse of FIELD_CLASSES)
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
FIELD_CLASS_IDS: Final[dict[str, int]] = {
fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# class_name -> field_name mapping (for API responses)
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
CLASS_TO_FIELD: Final[dict[str, str]] = {
fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
}
# field_name -> class_id mapping (for CSV import)
# Excludes derived fields like payment_line
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
}
# field_name -> class_id mapping (for training, includes all fields)
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
"supplier_accounts": {
"BG": "Bankgiro",
"PG": "Plusgiro",
}
}
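A sketch of how these mappings are consumed downstream — writing a YOLO classes.txt and relying on the derived-field split (the output filename is illustrative):

```python
from pathlib import Path

from shared.fields import CLASS_NAMES, CSV_TO_CLASS_MAPPING, TRAINING_FIELD_CLASSES

# classes.txt: one name per line, line index == class_id
Path("classes.txt").write_text("\n".join(CLASS_NAMES) + "\n")

# Derived fields are excluded from CSV import but kept for training
assert "payment_line" not in CSV_TO_CLASS_MAPPING
assert TRAINING_FIELD_CLASSES["payment_line"] == 9
```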

View File

@@ -0,0 +1,59 @@
"""
Storage abstraction layer for training data.
Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
"""
from shared.storage.base import (
FileNotFoundStorageError,
PresignedUrlNotSupportedError,
StorageBackend,
StorageConfig,
StorageError,
)
from shared.storage.factory import (
create_storage_backend,
create_storage_backend_from_env,
create_storage_backend_from_file,
get_default_storage_config,
get_storage_backend,
)
from shared.storage.local import LocalStorageBackend
from shared.storage.prefixes import PREFIXES, StoragePrefixes
__all__ = [
# Base classes and exceptions
"StorageBackend",
"StorageConfig",
"StorageError",
"FileNotFoundStorageError",
"PresignedUrlNotSupportedError",
# Backends
"LocalStorageBackend",
# Factory functions
"create_storage_backend",
"create_storage_backend_from_env",
"create_storage_backend_from_file",
"get_default_storage_config",
"get_storage_backend",
# Path prefixes
"PREFIXES",
"StoragePrefixes",
]
# Lazy imports to avoid dependencies when not using specific backends
def __getattr__(name: str):
if name == "AzureBlobStorageBackend":
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend
if name == "S3StorageBackend":
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend
if name == "load_storage_config":
from shared.storage.config_loader import load_storage_config
return load_storage_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
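The module-level `__getattr__` (PEP 562) keeps `import shared.storage` cheap; the SDK import cost is only paid when a cloud backend is actually touched. A small illustration (the bucket name is a placeholder):

```python
import shared.storage as storage_pkg  # no Azure SDK or boto3 imported yet

# Attribute access triggers the lazy import of shared.storage.s3;
# boto3 itself is only imported inside S3StorageBackend.__init__.
S3StorageBackend = storage_pkg.S3StorageBackend
backend = S3StorageBackend(bucket_name="my-bucket")  # boto3 imported here
```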

View File

@@ -0,0 +1,335 @@
"""
Azure Blob Storage backend.
Provides storage operations using Azure Blob Storage.
"""
from pathlib import Path
from azure.storage.blob import (
BlobSasPermissions,
BlobServiceClient,
ContainerClient,
generate_blob_sas,
)
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class AzureBlobStorageBackend(StorageBackend):
"""Storage backend using Azure Blob Storage.
Files are stored as blobs in an Azure Blob Storage container.
"""
def __init__(
self,
connection_string: str,
container_name: str,
create_container: bool = False,
) -> None:
"""Initialize Azure Blob Storage backend.
Args:
connection_string: Azure Storage connection string.
container_name: Name of the blob container.
create_container: If True, create the container if it doesn't exist.
"""
self._connection_string = connection_string
self._container_name = container_name
self._blob_service = BlobServiceClient.from_connection_string(connection_string)
self._container = self._blob_service.get_container_client(container_name)
# Extract account key from connection string for SAS token generation
self._account_key = self._extract_account_key(connection_string)
if create_container and not self._container.exists():
self._container.create_container()
@staticmethod
def _extract_account_key(connection_string: str) -> str | None:
"""Extract account key from connection string.
Args:
connection_string: Azure Storage connection string.
Returns:
Account key if found, None otherwise.
"""
for part in connection_string.split(";"):
if part.startswith("AccountKey="):
return part[len("AccountKey=") :]
return None
@property
def container_name(self) -> str:
"""Get the container name for this storage backend."""
return self._container_name
@property
def container_client(self) -> ContainerClient:
"""Get the Azure container client."""
return self._container
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to Azure Blob Storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If blob exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
with open(local_path, "rb") as f:
blob_client.upload_blob(f, overwrite=overwrite)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a blob from Azure Blob Storage.
Args:
remote_path: Blob path in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
stream = blob_client.download_blob()
local_path.write_bytes(stream.readall())
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a blob exists in storage.
Args:
remote_path: Blob path to check.
Returns:
True if the blob exists, False otherwise.
"""
blob_client = self._container.get_blob_client(remote_path)
return blob_client.exists()
def list_files(self, prefix: str) -> list[str]:
"""List blobs in storage with given prefix.
Args:
prefix: Blob path prefix to filter.
Returns:
List of blob paths matching the prefix.
"""
if prefix:
blobs = self._container.list_blobs(name_starts_with=prefix)
else:
blobs = self._container.list_blobs()
return [blob.name for blob in blobs]
def delete(self, remote_path: str) -> bool:
"""Delete a blob from storage.
Args:
remote_path: Blob path to delete.
Returns:
True if blob was deleted, False if it didn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
return False
blob_client.delete_blob()
return True
def get_url(self, remote_path: str) -> str:
"""Get the URL for a blob.
Args:
remote_path: Blob path in storage.
Returns:
URL to access the blob.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
return blob_client.url
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to Azure Blob Storage.
Args:
data: Bytes to upload.
remote_path: Destination blob path.
overwrite: If True, overwrite existing blob.
Returns:
The remote path where the data was stored.
"""
blob_client = self._container.get_blob_client(remote_path)
if blob_client.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
blob_client.upload_blob(data, overwrite=overwrite)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a blob as bytes.
Args:
remote_path: Blob path in storage.
Returns:
The blob contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
stream = blob_client.download_blob()
return stream.readall()
def upload_directory(
self, local_dir: Path, remote_prefix: str, overwrite: bool = False
) -> list[str]:
"""Upload all files in a directory to Azure Blob Storage.
Args:
local_dir: Local directory to upload.
remote_prefix: Prefix for remote blob paths.
overwrite: If True, overwrite existing blobs.
Returns:
List of remote paths where files were stored.
"""
uploaded: list[str] = []
for file_path in local_dir.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(local_dir)
remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
self.upload(file_path, remote_path, overwrite=overwrite)
uploaded.append(remote_path)
return uploaded
def download_directory(
self, remote_prefix: str, local_dir: Path
) -> list[Path]:
"""Download all blobs with a prefix to a local directory.
Args:
remote_prefix: Blob path prefix to download.
local_dir: Local directory to download to.
Returns:
List of local paths where files were downloaded.
"""
downloaded: list[Path] = []
blobs = self.list_files(remote_prefix)
for blob_path in blobs:
# Remove prefix to get relative path
if remote_prefix:
relative_path = blob_path[len(remote_prefix):]
if relative_path.startswith("/"):
relative_path = relative_path[1:]
else:
relative_path = blob_path
local_path = local_dir / relative_path
self.download(blob_path, local_path)
downloaded.append(local_path)
return downloaded
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a SAS URL for temporary blob access.
Args:
remote_path: Blob path in storage.
expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).
Returns:
Blob URL with SAS token.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
from datetime import datetime, timedelta, timezone
blob_client = self._container.get_blob_client(remote_path)
if not blob_client.exists():
raise FileNotFoundStorageError(remote_path)
if self._account_key is None:
raise StorageError(
"Cannot generate SAS URL: connection string has no AccountKey"
)
# Generate SAS token
sas_token = generate_blob_sas(
account_name=self._blob_service.account_name,
container_name=self._container_name,
blob_name=remote_path,
account_key=self._account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
)
return f"{blob_client.url}?{sas_token}"

View File

@@ -0,0 +1,229 @@
"""
Base classes and interfaces for storage backends.
Defines the abstract StorageBackend interface and common exceptions.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
class StorageError(Exception):
"""Base exception for storage operations."""
pass
class FileNotFoundStorageError(StorageError):
"""Raised when a file is not found in storage."""
def __init__(self, path: str) -> None:
self.path = path
super().__init__(f"File not found in storage: {path}")
class PresignedUrlNotSupportedError(StorageError):
"""Raised when pre-signed URLs are not supported by a backend."""
def __init__(self, backend_type: str) -> None:
self.backend_type = backend_type
super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")
@dataclass(frozen=True)
class StorageConfig:
"""Configuration for storage backend.
Attributes:
backend_type: Type of storage backend ("local", "azure_blob", or "s3").
connection_string: Azure Blob Storage connection string (for azure_blob).
container_name: Azure Blob Storage container name (for azure_blob).
base_path: Base path for local storage (for local).
bucket_name: S3 bucket name (for s3).
region_name: AWS region name (for s3).
access_key_id: AWS access key ID (for s3).
secret_access_key: AWS secret access key (for s3).
endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
connection_string: str | None = None
container_name: str | None = None
base_path: Path | None = None
bucket_name: str | None = None
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
presigned_url_expiry: int = 3600
class StorageBackend(ABC):
"""Abstract base class for storage backends.
Provides a unified interface for storing and retrieving files
from different storage systems (local filesystem, Azure Blob, etc.).
"""
@abstractmethod
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
pass
@abstractmethod
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
pass
@abstractmethod
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
List of file paths matching the prefix.
"""
pass
@abstractmethod
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
pass
@abstractmethod
def get_url(self, remote_path: str) -> str:
"""Get a URL or path to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
URL or path to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
pass
@abstractmethod
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary access.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: URL validity duration (default 1 hour).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs.
"""
pass
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Default implementation writes to temp file then uploads.
Subclasses may override for more efficient implementation.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
try:
return self.upload(temp_path, remote_path, overwrite=overwrite)
finally:
temp_path.unlink(missing_ok=True)
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Default implementation downloads to temp file then reads.
Subclasses may override for more efficient implementation.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as f:
temp_path = Path(f.name)
try:
self.download(remote_path, temp_path)
return temp_path.read_bytes()
finally:
temp_path.unlink(missing_ok=True)
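Since `upload_bytes`/`download_bytes` have temp-file defaults, a new backend only needs the seven abstract methods. A minimal in-memory sketch for unit tests — the class name and the `memory://` scheme are invented here, not part of the package:

```python
from pathlib import Path

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class MemoryStorageBackend(StorageBackend):
    """Dict-backed backend for tests; hypothetical, not shipped."""

    def __init__(self) -> None:
        self._blobs: dict[str, bytes] = {}

    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))
        return self.upload_bytes(local_path.read_bytes(), remote_path, overwrite)

    def download(self, remote_path: str, local_path: Path) -> Path:
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_bytes(self.download_bytes(remote_path))
        return local_path

    def exists(self, remote_path: str) -> bool:
        return remote_path in self._blobs

    def list_files(self, prefix: str) -> list[str]:
        return sorted(k for k in self._blobs if k.startswith(prefix))

    def delete(self, remote_path: str) -> bool:
        return self._blobs.pop(remote_path, None) is not None

    def get_url(self, remote_path: str) -> str:
        if remote_path not in self._blobs:
            raise FileNotFoundStorageError(remote_path)
        return f"memory://{remote_path}"

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        return self.get_url(remote_path)  # nothing to sign in memory

    # Override the temp-file defaults with direct dict access.
    def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
        if remote_path in self._blobs and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")
        self._blobs[remote_path] = data
        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        if remote_path not in self._blobs:
            raise FileNotFoundStorageError(remote_path)
        return self._blobs[remote_path]
```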

View File

@@ -0,0 +1,242 @@
"""
Configuration file loader for storage backends.
Supports YAML configuration files with environment variable substitution.
"""
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LocalConfig:
"""Local storage backend configuration."""
base_path: Path
@dataclass(frozen=True)
class AzureConfig:
"""Azure Blob Storage configuration."""
connection_string: str
container_name: str
create_container: bool = False
@dataclass(frozen=True)
class S3Config:
"""AWS S3 configuration."""
bucket_name: str
region_name: str | None = None
access_key_id: str | None = None
secret_access_key: str | None = None
endpoint_url: str | None = None
create_bucket: bool = False
@dataclass(frozen=True)
class StorageFileConfig:
"""Extended storage configuration from file.
Attributes:
backend_type: Type of storage backend.
local: Local backend configuration.
azure: Azure Blob configuration.
s3: S3 configuration.
presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
"""
backend_type: str
local: LocalConfig | None = None
azure: AzureConfig | None = None
s3: S3Config | None = None
presigned_url_expiry: int = 3600
def substitute_env_vars(value: str) -> str:
"""Substitute environment variables in a string.
Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.
Args:
value: String potentially containing env var references.
Returns:
String with env vars substituted.
"""
pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"
def replace(match: re.Match[str]) -> str:
var_name = match.group(1)
default = match.group(2)
return os.environ.get(var_name, default or "")
return re.sub(pattern, replace, value)
def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
"""Recursively substitute env vars in a dictionary.
Args:
data: Dictionary to process.
Returns:
New dictionary with substitutions applied.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, str):
result[key] = substitute_env_vars(value)
elif isinstance(value, dict):
result[key] = _substitute_in_dict(value)
elif isinstance(value, list):
result[key] = [
substitute_env_vars(item) if isinstance(item, str) else item
for item in value
]
else:
result[key] = value
return result
def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
"""Parse local configuration section.
Args:
data: Dictionary containing local config.
Returns:
LocalConfig instance.
Raises:
ValueError: If required fields are missing.
"""
base_path = data.get("base_path")
if not base_path:
raise ValueError("local.base_path is required")
return LocalConfig(base_path=Path(base_path))
def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
"""Parse Azure configuration section.
Args:
data: Dictionary containing Azure config.
Returns:
AzureConfig instance.
Raises:
ValueError: If required fields are missing.
"""
connection_string = data.get("connection_string")
container_name = data.get("container_name")
if not connection_string:
raise ValueError("azure.connection_string is required")
if not container_name:
raise ValueError("azure.container_name is required")
return AzureConfig(
connection_string=connection_string,
container_name=container_name,
create_container=data.get("create_container", False),
)
def _parse_s3_config(data: dict[str, Any]) -> S3Config:
"""Parse S3 configuration section.
Args:
data: Dictionary containing S3 config.
Returns:
S3Config instance.
Raises:
ValueError: If required fields are missing.
"""
bucket_name = data.get("bucket_name")
if not bucket_name:
raise ValueError("s3.bucket_name is required")
return S3Config(
bucket_name=bucket_name,
region_name=data.get("region_name"),
access_key_id=data.get("access_key_id"),
secret_access_key=data.get("secret_access_key"),
endpoint_url=data.get("endpoint_url"),
create_bucket=data.get("create_bucket", False),
)
def load_storage_config(config_path: Path | str) -> StorageFileConfig:
"""Load storage configuration from YAML file.
Supports environment variable substitution using ${VAR_NAME} or
${VAR_NAME:-default} syntax.
Args:
config_path: Path to configuration file.
Returns:
Parsed StorageFileConfig.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If config is invalid.
"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
raw_content = config_path.read_text(encoding="utf-8")
data = yaml.safe_load(raw_content)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in config file: {e}") from e
if not isinstance(data, dict):
raise ValueError("Config file must contain a YAML dictionary")
# Substitute environment variables
data = _substitute_in_dict(data)
# Extract backend type
backend_type = data.get("backend")
if not backend_type:
raise ValueError("'backend' field is required in config file")
# Parse presigned URL expiry
presigned_url_expiry = data.get("presigned_url_expiry", 3600)
# Parse backend-specific configurations
local_config = None
azure_config = None
s3_config = None
if "local" in data:
local_config = _parse_local_config(data["local"])
if "azure" in data:
azure_config = _parse_azure_config(data["azure"])
if "s3" in data:
s3_config = _parse_s3_config(data["s3"])
return StorageFileConfig(
backend_type=backend_type,
local=local_config,
azure=azure_config,
s3=s3_config,
presigned_url_expiry=presigned_url_expiry,
)
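The substitution rules in a nutshell (quick illustrative checks):

```python
import os

os.environ["STORAGE_BACKEND"] = "s3"
assert substitute_env_vars("${STORAGE_BACKEND:-local}") == "s3"

del os.environ["STORAGE_BACKEND"]
assert substitute_env_vars("${STORAGE_BACKEND:-local}") == "local"

# Unset variables without a default collapse to the empty string
assert substitute_env_vars("${MISSING_VAR}") == ""
```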

View File

@@ -0,0 +1,296 @@
"""
Factory functions for creating storage backends.
Provides convenient functions for creating storage backends from
configuration or environment variables.
"""
import os
from pathlib import Path
from shared.storage.base import StorageBackend, StorageConfig
def create_storage_backend(config: StorageConfig) -> StorageBackend:
"""Create a storage backend from configuration.
Args:
config: Storage configuration.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
"""
if config.backend_type == "local":
if config.base_path is None:
raise ValueError("base_path is required for local storage backend")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=config.base_path)
elif config.backend_type == "azure_blob":
if config.connection_string is None:
raise ValueError(
"connection_string is required for Azure blob storage backend"
)
if config.container_name is None:
raise ValueError(
"container_name is required for Azure blob storage backend"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=config.connection_string,
container_name=config.container_name,
)
elif config.backend_type == "s3":
if config.bucket_name is None:
raise ValueError("bucket_name is required for S3 storage backend")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=config.bucket_name,
region_name=config.region_name,
access_key_id=config.access_key_id,
secret_access_key=config.secret_access_key,
endpoint_url=config.endpoint_url,
)
else:
raise ValueError(f"Unknown storage backend type: {config.backend_type}")
def get_default_storage_config() -> StorageConfig:
"""Get storage configuration from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
StorageConfig from environment.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local")
if backend_type == "local":
base_path_str = os.environ.get("STORAGE_BASE_PATH")
# Expand ~ to home directory
base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None
return StorageConfig(
backend_type="local",
base_path=base_path,
)
elif backend_type == "azure_blob":
return StorageConfig(
backend_type="azure_blob",
connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
)
elif backend_type == "s3":
return StorageConfig(
backend_type="s3",
bucket_name=os.environ.get("AWS_S3_BUCKET"),
region_name=os.environ.get("AWS_REGION"),
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
)
else:
return StorageConfig(backend_type=backend_type)
def create_storage_backend_from_env() -> StorageBackend:
"""Create a storage backend from environment variables.
Environment variables:
STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
STORAGE_BASE_PATH: Base path for local storage.
AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
AZURE_STORAGE_CONTAINER: Azure container name.
AWS_S3_BUCKET: S3 bucket name.
AWS_REGION: AWS region name.
AWS_ACCESS_KEY_ID: AWS access key ID.
AWS_SECRET_ACCESS_KEY: AWS secret access key.
AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.
Returns:
A configured storage backend.
Raises:
ValueError: If required environment variables are missing or empty.
"""
backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()
if backend_type == "local":
base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
if not base_path:
raise ValueError(
"STORAGE_BASE_PATH environment variable is required and cannot be empty"
)
# Expand ~ to home directory
base_path_expanded = os.path.expanduser(base_path)
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=Path(base_path_expanded))
elif backend_type == "azure_blob":
connection_string = os.environ.get(
"AZURE_STORAGE_CONNECTION_STRING", ""
).strip()
if not connection_string:
raise ValueError(
"AZURE_STORAGE_CONNECTION_STRING environment variable is required "
"and cannot be empty"
)
container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
if not container_name:
raise ValueError(
"AZURE_STORAGE_CONTAINER environment variable is required "
"and cannot be empty"
)
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=connection_string,
container_name=container_name,
)
elif backend_type == "s3":
bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
if not bucket_name:
raise ValueError(
"AWS_S3_BUCKET environment variable is required and cannot be empty"
)
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=bucket_name,
region_name=os.environ.get("AWS_REGION", "").strip() or None,
access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
or None,
endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
)
else:
raise ValueError(f"Unknown storage backend type: {backend_type}")
def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
"""Create a storage backend from a configuration file.
Args:
config_path: Path to YAML configuration file.
Returns:
A configured storage backend.
Raises:
FileNotFoundError: If config file doesn't exist.
ValueError: If configuration is invalid.
"""
from shared.storage.config_loader import load_storage_config
file_config = load_storage_config(config_path)
if file_config.backend_type == "local":
if file_config.local is None:
raise ValueError("local configuration section is required")
from shared.storage.local import LocalStorageBackend
return LocalStorageBackend(base_path=file_config.local.base_path)
elif file_config.backend_type == "azure_blob":
if file_config.azure is None:
raise ValueError("azure configuration section is required")
# Import here to allow lazy loading of Azure SDK
from azure.storage.blob import BlobServiceClient # noqa: F401
from shared.storage.azure import AzureBlobStorageBackend
return AzureBlobStorageBackend(
connection_string=file_config.azure.connection_string,
container_name=file_config.azure.container_name,
create_container=file_config.azure.create_container,
)
elif file_config.backend_type == "s3":
if file_config.s3 is None:
raise ValueError("s3 configuration section is required")
# Import here to allow lazy loading of boto3
import boto3 # noqa: F401
from shared.storage.s3 import S3StorageBackend
return S3StorageBackend(
bucket_name=file_config.s3.bucket_name,
region_name=file_config.s3.region_name,
access_key_id=file_config.s3.access_key_id,
secret_access_key=file_config.s3.secret_access_key,
endpoint_url=file_config.s3.endpoint_url,
create_bucket=file_config.s3.create_bucket,
)
else:
raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")
def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
"""Get storage backend with fallback chain.
Priority:
1. Config file (if provided)
2. Environment variables
Args:
config_path: Optional path to config file.
Returns:
A configured storage backend.
Raises:
ValueError: If configuration is invalid.
FileNotFoundError: If specified config file doesn't exist.
"""
if config_path:
return create_storage_backend_from_file(config_path)
# Fall back to environment variables
return create_storage_backend_from_env()

View File

@@ -0,0 +1,262 @@
"""
Local filesystem storage backend.
Provides storage operations using the local filesystem.
"""
import shutil
from pathlib import Path
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class LocalStorageBackend(StorageBackend):
"""Storage backend using local filesystem.
Files are stored relative to a base path on the local filesystem.
"""
def __init__(self, base_path: str | Path) -> None:
"""Initialize local storage backend.
Args:
base_path: Base directory for all storage operations.
Will be created if it doesn't exist.
"""
self._base_path = Path(base_path)
self._base_path.mkdir(parents=True, exist_ok=True)
@property
def base_path(self) -> Path:
"""Get the base path for this storage backend."""
return self._base_path
def _get_full_path(self, remote_path: str) -> Path:
"""Convert a remote path to a full local path with security validation.
Args:
remote_path: The remote path to resolve.
Returns:
The full local path.
Raises:
StorageError: If the path attempts to escape the base directory.
"""
# Reject absolute paths
if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
raise StorageError(f"Absolute paths not allowed: {remote_path}")
# Resolve to prevent path traversal attacks
full_path = (self._base_path / remote_path).resolve()
base_resolved = self._base_path.resolve()
# Verify the resolved path is within base_path
try:
full_path.relative_to(base_resolved)
except ValueError as e:
raise StorageError(f"Path traversal not allowed: {remote_path}") from e
return full_path
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to local storage.
Args:
local_path: Path to the local file to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If file exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(local_path, dest_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download a file from local storage.
Args:
remote_path: Path to the file in storage.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
source_path = self._get_full_path(remote_path)
if not source_path.exists():
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_path, local_path)
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if a file exists in storage.
Args:
remote_path: Path to check in storage.
Returns:
True if the file exists, False otherwise.
"""
return self._get_full_path(remote_path).exists()
def list_files(self, prefix: str) -> list[str]:
"""List files in storage with given prefix.
Args:
prefix: Path prefix to filter files.
Returns:
Sorted list of file paths matching the prefix.
"""
if prefix:
search_path = self._get_full_path(prefix)
if not search_path.exists():
return []
else:
search_path = self._base_path
files: list[str] = []
if search_path.is_file():
files.append(str(search_path.relative_to(self._base_path)).replace("\\", "/"))
elif search_path.is_dir():
for file_path in search_path.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(self._base_path)
files.append(str(relative_path).replace("\\", "/"))
return sorted(files)
def delete(self, remote_path: str) -> bool:
"""Delete a file from storage.
Args:
remote_path: Path to the file to delete.
Returns:
True if file was deleted, False if it didn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
return False
file_path.unlink()
return True
def get_url(self, remote_path: str) -> str:
"""Get a file:// URL to access a file.
Args:
remote_path: Path to the file in storage.
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to storage.
Args:
data: Bytes to upload.
remote_path: Destination path in storage.
overwrite: If True, overwrite existing file.
Returns:
The remote path where the data was stored.
"""
dest_path = self._get_full_path(remote_path)
if dest_path.exists() and not overwrite:
raise StorageError(f"File already exists: {remote_path}")
dest_path.parent.mkdir(parents=True, exist_ok=True)
dest_path.write_bytes(data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download a file as bytes.
Args:
remote_path: Path to the file in storage.
Returns:
The file contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.read_bytes()
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Get a file:// URL for local file access.
For local storage, this returns a file:// URI.
Note: Local file:// URLs don't actually expire.
Args:
remote_path: Path to the file in storage.
expires_in_seconds: Ignored for local storage (URLs don't expire).
Returns:
file:// URL to access the file.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
file_path = self._get_full_path(remote_path)
if not file_path.exists():
raise FileNotFoundStorageError(remote_path)
return file_path.as_uri()
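The traversal guard in `_get_full_path` is applied on every operation; a quick sketch of the rejected shapes:

```python
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    backend = LocalStorageBackend(tmp)
    backend.upload_bytes(b"ok", "documents/a.txt")
    for bad in ("../escape.txt", "/etc/passwd"):
        try:
            backend.download_bytes(bad)
        except StorageError as exc:
            print(f"rejected: {exc}")
```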

View File

@@ -0,0 +1,158 @@
"""
Storage path prefixes for unified file organization.
Provides standardized path prefixes for organizing files within
the storage backend, ensuring consistent structure across
local, Azure Blob, and S3 storage.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class StoragePrefixes:
"""Standardized storage path prefixes.
All paths are relative to the storage backend root.
These prefixes ensure consistent file organization across
all storage backends (local, Azure, S3).
Usage:
from shared.storage.prefixes import PREFIXES
path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
storage.upload_bytes(content, path)
"""
# Document storage
DOCUMENTS: str = "documents"
"""Original document files (PDFs, etc.)"""
IMAGES: str = "images"
"""Page images extracted from documents"""
# Processing directories
UPLOADS: str = "uploads"
"""Temporary upload staging area"""
RESULTS: str = "results"
"""Inference results and visualizations"""
EXPORTS: str = "exports"
"""Exported datasets and annotations"""
# Training data
DATASETS: str = "datasets"
"""Training dataset files"""
MODELS: str = "models"
"""Trained model weights and checkpoints"""
# Data pipeline directories (legacy compatibility)
RAW_PDFS: str = "raw_pdfs"
"""Raw PDF files for auto-labeling pipeline"""
STRUCTURED_DATA: str = "structured_data"
"""CSV/structured data for matching"""
ADMIN_IMAGES: str = "admin_images"
"""Admin UI page images"""
@staticmethod
def document_path(document_id: str, extension: str = ".pdf") -> str:
"""Get path for a document file.
Args:
document_id: Unique document identifier.
extension: File extension (include leading dot).
Returns:
Storage path like "documents/abc123.pdf"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"
@staticmethod
def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
"""Get path for a page image file.
Args:
document_id: Unique document identifier.
page_num: Page number (1-indexed).
extension: File extension (include leading dot).
Returns:
Storage path like "images/abc123/page_1.png"
"""
ext = extension if extension.startswith(".") else f".{extension}"
return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"
@staticmethod
def upload_path(filename: str, subfolder: str | None = None) -> str:
"""Get path for a temporary upload file.
Args:
filename: Original filename.
subfolder: Optional subfolder (e.g., "async").
Returns:
Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
"""
if subfolder:
return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
return f"{PREFIXES.UPLOADS}/{filename}"
@staticmethod
def result_path(filename: str) -> str:
"""Get path for a result file.
Args:
filename: Result filename.
Returns:
Storage path like "results/filename.json"
"""
return f"{PREFIXES.RESULTS}/{filename}"
@staticmethod
def export_path(export_id: str, filename: str) -> str:
"""Get path for an export file.
Args:
export_id: Unique export identifier.
filename: Export filename.
Returns:
Storage path like "exports/abc123/filename.zip"
"""
return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"
@staticmethod
def dataset_path(dataset_id: str, filename: str) -> str:
"""Get path for a dataset file.
Args:
dataset_id: Unique dataset identifier.
filename: Dataset filename.
Returns:
Storage path like "datasets/abc123/filename.yaml"
"""
return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"
@staticmethod
def model_path(version: str, filename: str) -> str:
"""Get path for a model file.
Args:
version: Model version string.
filename: Model filename.
Returns:
Storage path like "models/v1.0.0/best.pt"
"""
return f"{PREFIXES.MODELS}/{version}/{filename}"
# Default instance for convenient access
PREFIXES = StoragePrefixes()
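Typical use of the path builders (IDs and filenames are illustrative):

```python
from shared.storage import PREFIXES

assert PREFIXES.document_path("abc123") == "documents/abc123.pdf"
assert PREFIXES.image_path("abc123", 1) == "images/abc123/page_1.png"
assert PREFIXES.upload_path("inv.pdf", "async") == "uploads/async/inv.pdf"
assert PREFIXES.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
```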

View File

@@ -0,0 +1,309 @@
"""
AWS S3 Storage backend.
Provides storage operations using AWS S3.
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
from shared.storage.base import (
FileNotFoundStorageError,
StorageBackend,
StorageError,
)
class S3StorageBackend(StorageBackend):
"""Storage backend using AWS S3.
Files are stored as objects in an S3 bucket.
"""
def __init__(
self,
bucket_name: str,
region_name: str | None = None,
access_key_id: str | None = None,
secret_access_key: str | None = None,
endpoint_url: str | None = None,
create_bucket: bool = False,
) -> None:
"""Initialize S3 storage backend.
Args:
bucket_name: Name of the S3 bucket.
region_name: AWS region name (optional, uses default if not set).
access_key_id: AWS access key ID (optional, uses credentials chain).
secret_access_key: AWS secret access key (optional).
endpoint_url: Custom endpoint URL (for S3-compatible services).
create_bucket: If True, create the bucket if it doesn't exist.
"""
import boto3
self._bucket_name = bucket_name
self._region_name = region_name
# Build client kwargs
client_kwargs: dict[str, Any] = {}
if region_name:
client_kwargs["region_name"] = region_name
if endpoint_url:
client_kwargs["endpoint_url"] = endpoint_url
if access_key_id and secret_access_key:
client_kwargs["aws_access_key_id"] = access_key_id
client_kwargs["aws_secret_access_key"] = secret_access_key
self._s3: "S3Client" = boto3.client("s3", **client_kwargs)
if create_bucket:
self._ensure_bucket_exists()
def _ensure_bucket_exists(self) -> None:
"""Create the bucket if it doesn't exist."""
from botocore.exceptions import ClientError
try:
self._s3.head_bucket(Bucket=self._bucket_name)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchBucket"):
# Bucket doesn't exist, create it
create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if self._region_name and self._region_name != "us-east-1":
create_kwargs["CreateBucketConfiguration"] = {
"LocationConstraint": self._region_name
}
self._s3.create_bucket(**create_kwargs)
else:
# Re-raise permission errors, network issues, etc.
raise
def _object_exists(self, key: str) -> bool:
"""Check if an object exists in S3.
Args:
key: Object key to check.
Returns:
True if object exists, False otherwise.
"""
from botocore.exceptions import ClientError
try:
self._s3.head_object(Bucket=self._bucket_name, Key=key)
return True
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
return False
raise
@property
def bucket_name(self) -> str:
"""Get the bucket name for this storage backend."""
return self._bucket_name
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
"""Upload a file to S3.
Args:
local_path: Path to the local file to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the file was stored.
Raises:
FileNotFoundStorageError: If local_path doesn't exist.
StorageError: If object exists and overwrite is False.
"""
if not local_path.exists():
raise FileNotFoundStorageError(str(local_path))
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.upload_file(str(local_path), self._bucket_name, remote_path)
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
"""Download an object from S3.
Args:
remote_path: Object key in S3.
local_path: Local destination path.
Returns:
The local path where the file was downloaded.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
self._s3.download_file(self._bucket_name, remote_path, str(local_path))
return local_path
def exists(self, remote_path: str) -> bool:
"""Check if an object exists in S3.
Args:
remote_path: Object key to check.
Returns:
True if the object exists, False otherwise.
"""
return self._object_exists(remote_path)
def list_files(self, prefix: str) -> list[str]:
"""List objects in S3 with given prefix.
Handles pagination to return all matching objects (S3 returns max 1000 per request).
Args:
prefix: Object key prefix to filter.
Returns:
List of object keys matching the prefix.
"""
kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
if prefix:
kwargs["Prefix"] = prefix
all_keys: list[str] = []
while True:
response = self._s3.list_objects_v2(**kwargs)
contents = response.get("Contents", [])
all_keys.extend(obj["Key"] for obj in contents)
if not response.get("IsTruncated"):
break
kwargs["ContinuationToken"] = response["NextContinuationToken"]
return all_keys
def delete(self, remote_path: str) -> bool:
"""Delete an object from S3.
Args:
remote_path: Object key to delete.
Returns:
True if object was deleted, False if it didn't exist.
"""
if not self._object_exists(remote_path):
return False
self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
return True
def get_url(self, remote_path: str) -> str:
"""Get a URL for an object.
Args:
remote_path: Object key in S3.
Returns:
URL to access the object.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=3600,
)
def get_presigned_url(
self,
remote_path: str,
expires_in_seconds: int = 3600,
) -> str:
"""Generate a pre-signed URL for temporary object access.
Args:
remote_path: Object key in S3.
expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).
Returns:
Pre-signed URL string.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
ValueError: If expires_in_seconds is out of valid range.
"""
if expires_in_seconds < 1 or expires_in_seconds > 604800:
raise ValueError(
"expires_in_seconds must be between 1 and 604800 (7 days)"
)
if not self._object_exists(remote_path):
raise FileNotFoundStorageError(remote_path)
return self._s3.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket_name, "Key": remote_path},
ExpiresIn=expires_in_seconds,
)
def upload_bytes(
self, data: bytes, remote_path: str, overwrite: bool = False
) -> str:
"""Upload bytes directly to S3.
Args:
data: Bytes to upload.
remote_path: Destination object key.
overwrite: If True, overwrite existing object.
Returns:
The remote path where the data was stored.
Raises:
StorageError: If object exists and overwrite is False.
"""
if not overwrite and self._object_exists(remote_path):
raise StorageError(f"File already exists: {remote_path}")
self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)
return remote_path
def download_bytes(self, remote_path: str) -> bytes:
"""Download an object as bytes.
Args:
remote_path: Object key in S3.
Returns:
The object contents as bytes.
Raises:
FileNotFoundStorageError: If remote_path doesn't exist.
"""
from botocore.exceptions import ClientError
try:
response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("404", "NoSuchKey"):
raise FileNotFoundStorageError(remote_path) from e
raise
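
For orientation, a usage sketch of the backend above. This is hypothetical: only the methods are shown in this commit, so the import path and constructor signature below are assumptions.

# Hypothetical usage sketch; the module path and constructor are assumed, not shown in this commit.
from pathlib import Path

from shared.storage.s3 import S3StorageBackend  # assumed import path

backend = S3StorageBackend(bucket_name="training-images")  # assumed signature
backend.upload(Path("local/invoice_001.png"), "images/invoice_001.png")
url = backend.get_presigned_url("images/invoice_001.png", expires_in_seconds=900)
data = backend.download_bytes("images/invoice_001.png")
print(len(data), url)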

View File

@@ -20,7 +20,7 @@ from shared.config import get_db_connection_string
 from shared.normalize import normalize_field
 from shared.matcher import FieldMatcher
 from shared.pdf import is_text_pdf, extract_text_tokens
-from training.yolo.annotation_generator import FIELD_CLASSES
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
 from shared.data.db import DocumentDB

View File

@@ -113,7 +113,7 @@ def process_single_document(args_tuple):
     # Import inside worker to avoid pickling issues
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     start_time = time.time()
@@ -342,7 +342,8 @@ def main():
     from shared.ocr import OCREngine
     from shared.matcher import FieldMatcher
     from shared.normalize import normalize_field
-    from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+    from training.yolo.annotation_generator import AnnotationGenerator

     # Handle comma-separated CSV paths
     csv_input = args.csv

View File

@@ -90,7 +90,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     import shutil
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     row_dict = task_data["row_dict"]
@@ -208,7 +208,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     import shutil
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     row_dict = task_data["row_dict"]

View File

@@ -15,7 +15,8 @@ from training.data.autolabel_report import FieldMatchResult
 from shared.matcher import FieldMatcher
 from shared.normalize import normalize_field
 from shared.ocr.machine_code_parser import MachineCodeParser
-from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+from training.yolo.annotation_generator import AnnotationGenerator

 def match_supplier_accounts(

View File

@@ -9,43 +9,12 @@ from pathlib import Path
 from typing import Any
 import csv

-# Field class mapping for YOLO
-# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
-FIELD_CLASSES = {
-    'InvoiceNumber': 0,
-    'InvoiceDate': 1,
-    'InvoiceDueDate': 2,
-    'OCR': 3,
-    'Bankgiro': 4,
-    'Plusgiro': 5,
-    'Amount': 6,
-    'supplier_organisation_number': 7,
-    'customer_number': 8,
-    'payment_line': 9,  # Machine code payment line at bottom of invoice
-}
-
-# Fields that need matching but map to other YOLO classes
-# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
-ACCOUNT_FIELD_MAPPING = {
-    'supplier_accounts': {
-        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
-        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
-    }
-}
-
-CLASS_NAMES = [
-    'invoice_number',
-    'invoice_date',
-    'invoice_due_date',
-    'ocr_number',
-    'bankgiro',
-    'plusgiro',
-    'amount',
-    'supplier_org_number',
-    'customer_number',
-    'payment_line',  # Machine code payment line at bottom of invoice
-]
+# Import field mappings from single source of truth
+from shared.fields import (
+    TRAINING_FIELD_CLASSES as FIELD_CLASSES,
+    CLASS_NAMES,
+    ACCOUNT_FIELD_MAPPING,
+)

 @dataclass

View File

@@ -101,7 +101,8 @@ class DatasetBuilder:
         Returns:
             DatasetStats with build results
         """
-        from .annotation_generator import AnnotationGenerator, CLASS_NAMES
+        from shared.fields import CLASS_NAMES
+        from .annotation_generator import AnnotationGenerator

         random.seed(seed)

View File

@@ -18,7 +18,8 @@ import numpy as np
 from PIL import Image
 from shared.config import DEFAULT_DPI
-from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+from .annotation_generator import YOLOAnnotation

 logger = logging.getLogger(__name__)

run_migration.py (new file, 73 lines)
View File

@@ -0,0 +1,73 @@
"""Run database migration for training_status fields."""
import psycopg2
import os
# Read password from .env file
password = ""
try:
with open(".env") as f:
for line in f:
if line.startswith("DB_PASSWORD="):
password = line.strip().split("=", 1)[1].strip('"').strip("'")
break
except Exception as e:
print(f"Error reading .env: {e}")
print(f"Password found: {bool(password)}")
conn = psycopg2.connect(
host="192.168.68.31",
port=5432,
database="docmaster",
user="docmaster",
password=password
)
conn.autocommit = True
cur = conn.cursor()
# Add training_status column
try:
cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL")
print("Added training_status column")
except Exception as e:
print(f"training_status: {e}")
# Add active_training_task_id column
try:
cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL")
print("Added active_training_task_id column")
except Exception as e:
print(f"active_training_task_id: {e}")
# Create indexes
try:
cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)")
print("Created training_status index")
except Exception as e:
print(f"index training_status: {e}")
try:
cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)")
print("Created active_training_task_id index")
except Exception as e:
print(f"index active_training_task_id: {e}")
# Update existing datasets that have been used in completed training tasks to trained status
try:
cur.execute("""
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
)
""")
print(f"Updated {cur.rowcount} datasets to trained status")
except Exception as e:
print(f"update status: {e}")
cur.close()
conn.close()
print("Migration complete!")

View File

@@ -17,9 +17,8 @@ from inference.data.admin_models import (
     AdminDocument,
     AdminAnnotation,
     TrainingTask,
-    FIELD_CLASSES,
-    CSV_TO_CLASS_MAPPING,
 )
+from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING

 class TestBatchUpload:
@@ -507,7 +506,10 @@ class TestCSVToClassMapping:
         assert len(CSV_TO_CLASS_MAPPING) > 0

     def test_csv_mapping_values(self):
-        """Test specific CSV column mappings."""
+        """Test specific CSV column mappings.
+
+        Note: customer_number is class 8 (verified from trained model best.pt).
+        """
         assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
         assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1
         assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2
@@ -516,7 +518,7 @@
         assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5
         assert CSV_TO_CLASS_MAPPING["Amount"] == 6
         assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7
-        assert CSV_TO_CLASS_MAPPING["customer_number"] == 9
+        assert CSV_TO_CLASS_MAPPING["customer_number"] == 8  # Fixed: was 9, model uses 8

     def test_csv_mapping_matches_field_classes(self):
         """Test that CSV mapping is consistent with FIELD_CLASSES."""

View File

@@ -0,0 +1 @@
"""Tests for shared.fields module."""

View File

@@ -0,0 +1,200 @@
"""
Tests for field configuration - Single Source of Truth.
These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.
CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""
import pytest
from shared.fields import (
FIELD_DEFINITIONS,
CLASS_NAMES,
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
NUM_CLASSES,
FieldDefinition,
)
class TestFieldDefinitionsIntegrity:
"""Tests to ensure field definitions are complete and consistent."""
def test_exactly_10_field_definitions(self):
"""Verify we have exactly 10 field classes (matching trained model)."""
assert len(FIELD_DEFINITIONS) == 10
assert NUM_CLASSES == 10
def test_class_ids_are_sequential(self):
"""Verify class IDs are 0-9 without gaps."""
class_ids = {fd.class_id for fd in FIELD_DEFINITIONS}
assert class_ids == set(range(10))
def test_class_ids_are_unique(self):
"""Verify no duplicate class IDs."""
class_ids = [fd.class_id for fd in FIELD_DEFINITIONS]
assert len(class_ids) == len(set(class_ids))
def test_class_names_are_unique(self):
"""Verify no duplicate class names."""
class_names = [fd.class_name for fd in FIELD_DEFINITIONS]
assert len(class_names) == len(set(class_names))
def test_field_definition_is_immutable(self):
"""Verify FieldDefinition is frozen (immutable)."""
fd = FIELD_DEFINITIONS[0]
with pytest.raises(AttributeError):
fd.class_id = 99 # type: ignore
class TestModelCompatibility:
"""Tests to verify field definitions match the trained YOLO model.
These exact values are read from runs/train/invoice_fields/weights/best.pt
and MUST NOT be changed without retraining the model.
"""
# Expected model.names from best.pt - DO NOT CHANGE
EXPECTED_MODEL_NAMES = {
0: "invoice_number",
1: "invoice_date",
2: "invoice_due_date",
3: "ocr_number",
4: "bankgiro",
5: "plusgiro",
6: "amount",
7: "supplier_org_number",
8: "customer_number",
9: "payment_line",
}
def test_field_classes_match_model(self):
"""CRITICAL: Verify FIELD_CLASSES matches trained model exactly."""
assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES
def test_class_names_order_matches_model(self):
"""CRITICAL: Verify CLASS_NAMES order matches model class IDs."""
expected_order = [
self.EXPECTED_MODEL_NAMES[i] for i in range(10)
]
assert CLASS_NAMES == expected_order
def test_customer_number_is_class_8(self):
"""CRITICAL: customer_number must be class 8 (not 9)."""
assert FIELD_CLASS_IDS["customer_number"] == 8
assert FIELD_CLASSES[8] == "customer_number"
def test_payment_line_is_class_9(self):
"""CRITICAL: payment_line must be class 9 (not 8)."""
assert FIELD_CLASS_IDS["payment_line"] == 9
assert FIELD_CLASSES[9] == "payment_line"
class TestMappingConsistency:
"""Tests to verify all mappings are consistent with each other."""
def test_field_classes_and_field_class_ids_are_inverses(self):
"""Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
for class_id, class_name in FIELD_CLASSES.items():
assert FIELD_CLASS_IDS[class_name] == class_id
for class_name, class_id in FIELD_CLASS_IDS.items():
assert FIELD_CLASSES[class_id] == class_name
def test_class_names_matches_field_classes_values(self):
"""Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
for i, class_name in enumerate(CLASS_NAMES):
assert FIELD_CLASSES[i] == class_name
def test_class_to_field_has_all_classes(self):
"""Verify CLASS_TO_FIELD has mapping for all class names."""
for class_name in CLASS_NAMES:
assert class_name in CLASS_TO_FIELD
def test_csv_mapping_excludes_derived_fields(self):
"""Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line."""
# payment_line is derived, should not be in CSV mapping
assert "payment_line" not in CSV_TO_CLASS_MAPPING
# All non-derived fields should be in CSV mapping
for fd in FIELD_DEFINITIONS:
if not fd.is_derived:
assert fd.field_name in CSV_TO_CLASS_MAPPING
def test_training_field_classes_includes_all(self):
"""Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
for fd in FIELD_DEFINITIONS:
assert fd.field_name in TRAINING_FIELD_CLASSES
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
class TestSpecificFieldDefinitions:
"""Tests for specific field definitions to catch common mistakes."""
@pytest.mark.parametrize(
"class_id,expected_class_name",
[
(0, "invoice_number"),
(1, "invoice_date"),
(2, "invoice_due_date"),
(3, "ocr_number"),
(4, "bankgiro"),
(5, "plusgiro"),
(6, "amount"),
(7, "supplier_org_number"),
(8, "customer_number"),
(9, "payment_line"),
],
)
def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
"""Verify each class ID maps to the correct class name."""
assert FIELD_CLASSES[class_id] == expected_class_name
def test_payment_line_is_derived(self):
"""Verify payment_line is marked as derived."""
payment_line_def = next(
fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"
)
assert payment_line_def.is_derived is True
def test_other_fields_are_not_derived(self):
"""Verify all fields except payment_line are not derived."""
for fd in FIELD_DEFINITIONS:
if fd.class_name != "payment_line":
assert fd.is_derived is False, f"{fd.class_name} should not be derived"
class TestBackwardCompatibility:
"""Tests to ensure backward compatibility with existing code."""
def test_csv_to_class_mapping_field_names(self):
"""Verify CSV_TO_CLASS_MAPPING uses correct field names."""
# These are the field names used in CSV files
expected_fields = {
"InvoiceNumber": 0,
"InvoiceDate": 1,
"InvoiceDueDate": 2,
"OCR": 3,
"Bankgiro": 4,
"Plusgiro": 5,
"Amount": 6,
"supplier_organisation_number": 7,
"customer_number": 8,
# payment_line (9) is derived, not in CSV
}
assert CSV_TO_CLASS_MAPPING == expected_fields
def test_class_to_field_returns_field_names(self):
"""Verify CLASS_TO_FIELD maps class names to field names correctly."""
# Sample checks for key fields
assert CLASS_TO_FIELD["invoice_number"] == "InvoiceNumber"
assert CLASS_TO_FIELD["invoice_date"] == "InvoiceDate"
assert CLASS_TO_FIELD["ocr_number"] == "OCR"
assert CLASS_TO_FIELD["customer_number"] == "customer_number"
assert CLASS_TO_FIELD["payment_line"] == "payment_line"

View File

@@ -0,0 +1 @@
# Tests for storage module

View File

@@ -0,0 +1,718 @@
"""
Tests for AzureBlobStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
Uses mocking to avoid requiring actual Azure credentials.
"""
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@pytest.fixture
def mock_blob_service_client() -> MagicMock:
"""Create a mock BlobServiceClient."""
return MagicMock()
@pytest.fixture
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
"""Create a mock ContainerClient."""
container_client = MagicMock()
mock_blob_service_client.get_container_client.return_value = container_client
return container_client
@pytest.fixture
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
"""Create a mock BlobClient."""
blob_client = MagicMock()
mock_container_client.get_blob_client.return_value = blob_client
return blob_client
class TestAzureBlobStorageBackendCreation:
"""Tests for AzureBlobStorageBackend instantiation."""
@patch("shared.storage.azure.BlobServiceClient")
def test_create_with_connection_string(
self, mock_service_class: MagicMock
) -> None:
"""Test creating backend with connection string."""
from shared.storage.azure import AzureBlobStorageBackend
connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..."
backend = AzureBlobStorageBackend(
connection_string=connection_string,
container_name="training-images",
)
mock_service_class.from_connection_string.assert_called_once_with(
connection_string
)
assert backend.container_name == "training-images"
@patch("shared.storage.azure.BlobServiceClient")
def test_create_creates_container_if_not_exists(
self, mock_service_class: MagicMock
) -> None:
"""Test that container is created if it doesn't exist."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_container.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="new-container",
create_container=True,
)
mock_container.create_container.assert_called_once()
@patch("shared.storage.azure.BlobServiceClient")
def test_create_does_not_create_container_by_default(
self, mock_service_class: MagicMock
) -> None:
"""Test that container is not created by default."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_container.exists.return_value = True
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="existing-container",
)
mock_container.create_container.assert_not_called()
@patch("shared.storage.azure.BlobServiceClient")
def test_is_storage_backend_subclass(
self, mock_service_class: MagicMock
) -> None:
"""Test that AzureBlobStorageBackend is a StorageBackend."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import StorageBackend
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
assert isinstance(backend, StorageBackend)
class TestAzureBlobStorageBackendUpload:
"""Tests for AzureBlobStorageBackend.upload method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_file(self, mock_service_class: MagicMock) -> None:
"""Test uploading a file."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
f.write(b"Hello, World!")
temp_path = Path(f.name)
try:
result = backend.upload(temp_path, "uploads/sample.txt")
assert result == "uploads/sample.txt"
mock_container.get_blob_client.assert_called_with("uploads/sample.txt")
mock_blob.upload_blob.assert_called_once()
finally:
temp_path.unlink()
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_fails_if_blob_exists_without_overwrite(
self, mock_service_class: MagicMock
) -> None:
"""Test that upload fails if blob exists and overwrite is False."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import StorageError
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
f.write(b"content")
temp_path = Path(f.name)
try:
with pytest.raises(StorageError, match="already exists"):
backend.upload(temp_path, "existing.txt", overwrite=False)
finally:
temp_path.unlink()
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_succeeds_with_overwrite(
self, mock_service_class: MagicMock
) -> None:
"""Test that upload succeeds with overwrite=True."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
f.write(b"content")
temp_path = Path(f.name)
try:
result = backend.upload(temp_path, "existing.txt", overwrite=True)
assert result == "existing.txt"
mock_blob.upload_blob.assert_called_once()
# Check overwrite=True was passed
call_kwargs = mock_blob.upload_blob.call_args[1]
assert call_kwargs.get("overwrite") is True
finally:
temp_path.unlink()
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_nonexistent_file_fails(
self, mock_service_class: MagicMock
) -> None:
"""Test that uploading nonexistent file fails."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import FileNotFoundStorageError
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with pytest.raises(FileNotFoundStorageError):
backend.upload(Path("/nonexistent/file.txt"), "sample.txt")
class TestAzureBlobStorageBackendDownload:
"""Tests for AzureBlobStorageBackend.download method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_download_file(self, mock_service_class: MagicMock) -> None:
"""Test downloading a file."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
# Mock download_blob to return stream
mock_stream = MagicMock()
mock_stream.readall.return_value = b"Hello, World!"
mock_blob.download_blob.return_value = mock_stream
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir) / "downloaded.txt"
result = backend.download("remote/sample.txt", local_path)
assert result == local_path
assert local_path.exists()
assert local_path.read_bytes() == b"Hello, World!"
@patch("shared.storage.azure.BlobServiceClient")
def test_download_creates_parent_directories(
self, mock_service_class: MagicMock
) -> None:
"""Test that download creates parent directories."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
mock_stream = MagicMock()
mock_stream.readall.return_value = b"content"
mock_blob.download_blob.return_value = mock_stream
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt"
result = backend.download("sample.txt", local_path)
assert local_path.exists()
@patch("shared.storage.azure.BlobServiceClient")
def test_download_nonexistent_blob_fails(
self, mock_service_class: MagicMock
) -> None:
"""Test that downloading nonexistent blob fails."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import FileNotFoundStorageError
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
backend.download("nonexistent.txt", Path("/tmp/file.txt"))
class TestAzureBlobStorageBackendExists:
"""Tests for AzureBlobStorageBackend.exists method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_exists_returns_true_for_existing_blob(
self, mock_service_class: MagicMock
) -> None:
"""Test exists returns True for existing blob."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
assert backend.exists("existing.txt") is True
@patch("shared.storage.azure.BlobServiceClient")
def test_exists_returns_false_for_nonexistent_blob(
self, mock_service_class: MagicMock
) -> None:
"""Test exists returns False for nonexistent blob."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
assert backend.exists("nonexistent.txt") is False
class TestAzureBlobStorageBackendListFiles:
"""Tests for AzureBlobStorageBackend.list_files method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_list_files_empty_container(
self, mock_service_class: MagicMock
) -> None:
"""Test listing files in empty container."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_container.list_blobs.return_value = []
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
assert backend.list_files("") == []
@patch("shared.storage.azure.BlobServiceClient")
def test_list_files_returns_all_blobs(
self, mock_service_class: MagicMock
) -> None:
"""Test listing all blobs."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
# Create mock blob items
mock_blob1 = MagicMock()
mock_blob1.name = "file1.txt"
mock_blob2 = MagicMock()
mock_blob2.name = "file2.txt"
mock_blob3 = MagicMock()
mock_blob3.name = "subdir/file3.txt"
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3]
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
files = backend.list_files("")
assert len(files) == 3
assert "file1.txt" in files
assert "file2.txt" in files
assert "subdir/file3.txt" in files
@patch("shared.storage.azure.BlobServiceClient")
def test_list_files_with_prefix(
self, mock_service_class: MagicMock
) -> None:
"""Test listing files with prefix filter."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob1 = MagicMock()
mock_blob1.name = "images/a.png"
mock_blob2 = MagicMock()
mock_blob2.name = "images/b.png"
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2]
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
files = backend.list_files("images/")
mock_container.list_blobs.assert_called_with(name_starts_with="images/")
assert len(files) == 2
class TestAzureBlobStorageBackendDelete:
"""Tests for AzureBlobStorageBackend.delete method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_delete_existing_blob(
self, mock_service_class: MagicMock
) -> None:
"""Test deleting an existing blob."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
result = backend.delete("sample.txt")
assert result is True
mock_blob.delete_blob.assert_called_once()
@patch("shared.storage.azure.BlobServiceClient")
def test_delete_nonexistent_blob_returns_false(
self, mock_service_class: MagicMock
) -> None:
"""Test deleting nonexistent blob returns False."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
result = backend.delete("nonexistent.txt")
assert result is False
mock_blob.delete_blob.assert_not_called()
class TestAzureBlobStorageBackendGetUrl:
"""Tests for AzureBlobStorageBackend.get_url method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_get_url_returns_blob_url(
self, mock_service_class: MagicMock
) -> None:
"""Test get_url returns blob URL."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
mock_blob.url = "https://account.blob.core.windows.net/container/sample.txt"
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
url = backend.get_url("sample.txt")
assert url == "https://account.blob.core.windows.net/container/sample.txt"
@patch("shared.storage.azure.BlobServiceClient")
def test_get_url_nonexistent_blob_fails(
self, mock_service_class: MagicMock
) -> None:
"""Test get_url for nonexistent blob fails."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import FileNotFoundStorageError
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with pytest.raises(FileNotFoundStorageError):
backend.get_url("nonexistent.txt")
class TestAzureBlobStorageBackendUploadBytes:
"""Tests for AzureBlobStorageBackend.upload_bytes method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
"""Test uploading bytes directly."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
data = b"Binary content here"
result = backend.upload_bytes(data, "binary.dat")
assert result == "binary.dat"
mock_blob.upload_blob.assert_called_once()
class TestAzureBlobStorageBackendDownloadBytes:
"""Tests for AzureBlobStorageBackend.download_bytes method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_download_bytes(self, mock_service_class: MagicMock) -> None:
"""Test downloading blob as bytes."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = True
mock_stream = MagicMock()
mock_stream.readall.return_value = b"Hello, World!"
mock_blob.download_blob.return_value = mock_stream
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
data = backend.download_bytes("sample.txt")
assert data == b"Hello, World!"
@patch("shared.storage.azure.BlobServiceClient")
def test_download_bytes_nonexistent(
self, mock_service_class: MagicMock
) -> None:
"""Test downloading nonexistent blob as bytes."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import FileNotFoundStorageError
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with pytest.raises(FileNotFoundStorageError):
backend.download_bytes("nonexistent.txt")
class TestAzureBlobStorageBackendBatchOperations:
"""Tests for batch operations in AzureBlobStorageBackend."""
@patch("shared.storage.azure.BlobServiceClient")
def test_upload_directory(self, mock_service_class: MagicMock) -> None:
"""Test uploading an entire directory."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
mock_blob = MagicMock()
mock_container.get_blob_client.return_value = mock_blob
mock_blob.exists.return_value = False
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
(temp_path / "file1.txt").write_text("content1")
(temp_path / "subdir").mkdir()
(temp_path / "subdir" / "file2.txt").write_text("content2")
results = backend.upload_directory(temp_path, "uploads/")
assert len(results) == 2
assert "uploads/file1.txt" in results
assert "uploads/subdir/file2.txt" in results
@patch("shared.storage.azure.BlobServiceClient")
def test_download_directory(self, mock_service_class: MagicMock) -> None:
"""Test downloading blobs matching a prefix."""
from shared.storage.azure import AzureBlobStorageBackend
mock_service = MagicMock()
mock_service_class.from_connection_string.return_value = mock_service
mock_container = MagicMock()
mock_service.get_container_client.return_value = mock_container
# Mock blob listing
mock_blob1 = MagicMock()
mock_blob1.name = "images/a.png"
mock_blob2 = MagicMock()
mock_blob2.name = "images/b.png"
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2]
# Mock blob clients
mock_blob_client = MagicMock()
mock_container.get_blob_client.return_value = mock_blob_client
mock_blob_client.exists.return_value = True
mock_stream = MagicMock()
mock_stream.readall.return_value = b"image content"
mock_blob_client.download_blob.return_value = mock_stream
backend = AzureBlobStorageBackend(
connection_string="connection_string",
container_name="container",
)
with tempfile.TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir)
results = backend.download_directory("images/", local_path)
assert len(results) == 2
# Files should be created relative to prefix
assert (local_path / "a.png").exists() or (local_path / "images" / "a.png").exists()
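
The service -> container -> blob mock wiring repeats in nearly every test in this file. A small helper like the one below (hypothetical, not part of the file) would reduce each test to its actual assertions:

# Hypothetical helper: wires up the mock chain used throughout this file
# and returns the blob client so tests can set expectations on it.
from unittest.mock import MagicMock

def make_mock_blob(mock_service_class: MagicMock, *, blob_exists: bool) -> MagicMock:
    mock_service = MagicMock()
    mock_service_class.from_connection_string.return_value = mock_service
    mock_container = MagicMock()
    mock_service.get_container_client.return_value = mock_container
    mock_blob = MagicMock()
    mock_container.get_blob_client.return_value = mock_blob
    mock_blob.exists.return_value = blob_exists
    return mock_blob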

View File

@@ -0,0 +1,301 @@
"""
Tests for storage base module.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
from abc import ABC
from pathlib import Path
from typing import BinaryIO
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInterface:
"""Tests for StorageBackend abstract base class."""
def test_cannot_instantiate_directly(self) -> None:
"""Test that StorageBackend cannot be instantiated."""
from shared.storage.base import StorageBackend
with pytest.raises(TypeError):
StorageBackend() # type: ignore
def test_is_abstract_base_class(self) -> None:
"""Test that StorageBackend is an ABC."""
from shared.storage.base import StorageBackend
assert issubclass(StorageBackend, ABC)
def test_subclass_must_implement_upload(self) -> None:
"""Test that subclass must implement upload method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_download(self) -> None:
"""Test that subclass must implement download method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_exists(self) -> None:
"""Test that subclass must implement exists method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_list_files(self) -> None:
"""Test that subclass must implement list_files method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_delete(self) -> None:
"""Test that subclass must implement delete method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_subclass_must_implement_get_url(self) -> None:
"""Test that subclass must implement get_url method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_valid_subclass_can_be_instantiated(self) -> None:
"""Test that a complete subclass can be instantiated."""
from shared.storage.base import StorageBackend
class CompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
def get_presigned_url(
self, remote_path: str, expires_in_seconds: int = 3600
) -> str:
return ""
backend = CompleteBackend()
assert isinstance(backend, StorageBackend)
class TestStorageError:
"""Tests for StorageError exception."""
def test_storage_error_is_exception(self) -> None:
"""Test that StorageError is an Exception."""
from shared.storage.base import StorageError
assert issubclass(StorageError, Exception)
def test_storage_error_with_message(self) -> None:
"""Test StorageError with message."""
from shared.storage.base import StorageError
error = StorageError("Upload failed")
assert str(error) == "Upload failed"
def test_storage_error_can_be_raised(self) -> None:
"""Test that StorageError can be raised and caught."""
from shared.storage.base import StorageError
with pytest.raises(StorageError, match="test error"):
raise StorageError("test error")
class TestFileNotFoundError:
"""Tests for FileNotFoundStorageError exception."""
def test_file_not_found_is_storage_error(self) -> None:
"""Test that FileNotFoundStorageError is a StorageError."""
from shared.storage.base import FileNotFoundStorageError, StorageError
assert issubclass(FileNotFoundStorageError, StorageError)
def test_file_not_found_with_path(self) -> None:
"""Test FileNotFoundStorageError with path."""
from shared.storage.base import FileNotFoundStorageError
error = FileNotFoundStorageError("images/test.png")
assert "images/test.png" in str(error)
class TestStorageConfig:
"""Tests for StorageConfig dataclass."""
def test_storage_config_creation(self) -> None:
"""Test creating StorageConfig."""
from shared.storage.base import StorageConfig
config = StorageConfig(
backend_type="azure_blob",
connection_string="DefaultEndpointsProtocol=https;...",
container_name="training-images",
)
assert config.backend_type == "azure_blob"
assert config.connection_string == "DefaultEndpointsProtocol=https;..."
assert config.container_name == "training-images"
def test_storage_config_defaults(self) -> None:
"""Test StorageConfig with defaults."""
from shared.storage.base import StorageConfig
config = StorageConfig(backend_type="local")
assert config.backend_type == "local"
assert config.connection_string is None
assert config.container_name is None
assert config.base_path is None
def test_storage_config_with_base_path(self) -> None:
"""Test StorageConfig with base_path for local backend."""
from shared.storage.base import StorageConfig
config = StorageConfig(
backend_type="local",
base_path=Path("/data/images"),
)
assert config.backend_type == "local"
assert config.base_path == Path("/data/images")
def test_storage_config_immutable(self) -> None:
"""Test that StorageConfig is immutable (frozen)."""
from shared.storage.base import StorageConfig
config = StorageConfig(backend_type="local")
with pytest.raises(AttributeError):
config.backend_type = "azure_blob" # type: ignore
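
For reference, a sketch of the interface these tests pin down. This is a reconstruction from the assertions above, not the actual shared/storage/base.py; in particular, whether get_presigned_url is abstract cannot be inferred from the tests, so a plain overridable method is assumed.

# Sketch of shared/storage/base.py reconstructed from the tests above.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

class StorageError(Exception):
    """Base error for storage operations."""

class FileNotFoundStorageError(StorageError):
    """Raised when a remote object does not exist."""

@dataclass(frozen=True)
class StorageConfig:
    backend_type: str
    connection_string: str | None = None
    container_name: str | None = None
    base_path: Path | None = None

class StorageBackend(ABC):
    @abstractmethod
    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str: ...

    @abstractmethod
    def download(self, remote_path: str, local_path: Path) -> Path: ...

    @abstractmethod
    def exists(self, remote_path: str) -> bool: ...

    @abstractmethod
    def list_files(self, prefix: str) -> list[str]: ...

    @abstractmethod
    def delete(self, remote_path: str) -> bool: ...

    @abstractmethod
    def get_url(self, remote_path: str) -> str: ...

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        raise NotImplementedError  # assumed default; backends override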

View File

@@ -0,0 +1,348 @@
"""
Tests for storage configuration file loader.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
@pytest.fixture
def temp_dir() -> Path:
"""Create a temporary directory for tests."""
temp_dir = Path(tempfile.mkdtemp())
yield temp_dir
shutil.rmtree(temp_dir, ignore_errors=True)
class TestEnvVarSubstitution:
"""Tests for environment variable substitution in config values."""
def test_substitute_simple_env_var(self) -> None:
"""Test substituting a simple environment variable."""
from shared.storage.config_loader import substitute_env_vars
with patch.dict(os.environ, {"MY_VAR": "my_value"}):
result = substitute_env_vars("${MY_VAR}")
assert result == "my_value"
def test_substitute_env_var_with_default(self) -> None:
"""Test substituting env var with default when var is not set."""
from shared.storage.config_loader import substitute_env_vars
# Ensure var is not set
os.environ.pop("UNSET_VAR", None)
result = substitute_env_vars("${UNSET_VAR:-default_value}")
assert result == "default_value"
def test_substitute_env_var_ignores_default_when_set(self) -> None:
"""Test that default is ignored when env var is set."""
from shared.storage.config_loader import substitute_env_vars
with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
result = substitute_env_vars("${SET_VAR:-default_value}")
assert result == "actual_value"
def test_substitute_multiple_env_vars(self) -> None:
"""Test substituting multiple env vars in one string."""
from shared.storage.config_loader import substitute_env_vars
with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
result = substitute_env_vars("postgres://${HOST}:${PORT}/db")
assert result == "postgres://localhost:5432/db"
def test_substitute_preserves_non_env_text(self) -> None:
"""Test that non-env-var text is preserved."""
from shared.storage.config_loader import substitute_env_vars
with patch.dict(os.environ, {"VAR": "value"}):
result = substitute_env_vars("prefix_${VAR}_suffix")
assert result == "prefix_value_suffix"
def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
"""Test that empty string is returned when var not set and no default."""
from shared.storage.config_loader import substitute_env_vars
os.environ.pop("MISSING_VAR", None)
result = substitute_env_vars("${MISSING_VAR}")
assert result == ""
class TestLoadStorageConfigYaml:
"""Tests for loading storage configuration from YAML files."""
def test_load_local_backend_config(self, temp_dir: Path) -> None:
"""Test loading configuration for local backend."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
backend: local
presigned_url_expiry: 3600
local:
base_path: ./data/storage
""")
config = load_storage_config(config_path)
assert config.backend_type == "local"
assert config.presigned_url_expiry == 3600
assert config.local is not None
assert config.local.base_path == Path("./data/storage")
def test_load_azure_backend_config(self, temp_dir: Path) -> None:
"""Test loading configuration for Azure backend."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
backend: azure_blob
presigned_url_expiry: 7200
azure:
connection_string: DefaultEndpointsProtocol=https;AccountName=test
container_name: documents
create_container: true
""")
config = load_storage_config(config_path)
assert config.backend_type == "azure_blob"
assert config.presigned_url_expiry == 7200
assert config.azure is not None
assert config.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
assert config.azure.container_name == "documents"
assert config.azure.create_container is True
def test_load_s3_backend_config(self, temp_dir: Path) -> None:
"""Test loading configuration for S3 backend."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
backend: s3
presigned_url_expiry: 1800
s3:
bucket_name: my-bucket
region_name: us-west-2
endpoint_url: http://localhost:9000
create_bucket: false
""")
config = load_storage_config(config_path)
assert config.backend_type == "s3"
assert config.presigned_url_expiry == 1800
assert config.s3 is not None
assert config.s3.bucket_name == "my-bucket"
assert config.s3.region_name == "us-west-2"
assert config.s3.endpoint_url == "http://localhost:9000"
assert config.s3.create_bucket is False
def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
"""Test that environment variables are substituted in config."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
backend: ${STORAGE_BACKEND:-local}
local:
base_path: ${STORAGE_PATH:-./default/path}
""")
with patch.dict(os.environ, {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}):
config = load_storage_config(config_path)
assert config.backend_type == "local"
assert config.local is not None
assert config.local.base_path == Path("/custom/path")
def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
"""Test that FileNotFoundError is raised for missing config file."""
from shared.storage.config_loader import load_storage_config
with pytest.raises(FileNotFoundError):
load_storage_config(temp_dir / "nonexistent.yaml")
def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
"""Test that ValueError is raised for invalid YAML."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("invalid: yaml: content: [")
with pytest.raises(ValueError, match="Invalid"):
load_storage_config(config_path)
def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
"""Test that ValueError is raised when backend is missing."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
local:
base_path: ./data
""")
with pytest.raises(ValueError, match="backend"):
load_storage_config(config_path)
def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
"""Test default presigned_url_expiry when not specified."""
from shared.storage.config_loader import load_storage_config
config_path = temp_dir / "storage.yaml"
config_path.write_text("""
backend: local
local:
base_path: ./data
""")
config = load_storage_config(config_path)
assert config.presigned_url_expiry == 3600 # Default value
class TestStorageFileConfig:
"""Tests for StorageFileConfig dataclass."""
def test_storage_file_config_is_immutable(self) -> None:
"""Test that StorageFileConfig is frozen (immutable)."""
from shared.storage.config_loader import StorageFileConfig
config = StorageFileConfig(backend_type="local")
with pytest.raises(AttributeError):
config.backend_type = "azure_blob" # type: ignore
def test_storage_file_config_defaults(self) -> None:
"""Test StorageFileConfig default values."""
from shared.storage.config_loader import StorageFileConfig
config = StorageFileConfig(backend_type="local")
assert config.backend_type == "local"
assert config.local is None
assert config.azure is None
assert config.s3 is None
assert config.presigned_url_expiry == 3600
class TestLocalConfig:
"""Tests for LocalConfig dataclass."""
def test_local_config_creation(self) -> None:
"""Test creating LocalConfig."""
from shared.storage.config_loader import LocalConfig
config = LocalConfig(base_path=Path("/data/storage"))
assert config.base_path == Path("/data/storage")
def test_local_config_is_immutable(self) -> None:
"""Test that LocalConfig is frozen."""
from shared.storage.config_loader import LocalConfig
config = LocalConfig(base_path=Path("/data"))
with pytest.raises(AttributeError):
config.base_path = Path("/other") # type: ignore
class TestAzureConfig:
"""Tests for AzureConfig dataclass."""
def test_azure_config_creation(self) -> None:
"""Test creating AzureConfig."""
from shared.storage.config_loader import AzureConfig
config = AzureConfig(
connection_string="test_connection",
container_name="test_container",
create_container=True,
)
assert config.connection_string == "test_connection"
assert config.container_name == "test_container"
assert config.create_container is True
def test_azure_config_defaults(self) -> None:
"""Test AzureConfig default values."""
from shared.storage.config_loader import AzureConfig
config = AzureConfig(
connection_string="conn",
container_name="container",
)
assert config.create_container is False
def test_azure_config_is_immutable(self) -> None:
"""Test that AzureConfig is frozen."""
from shared.storage.config_loader import AzureConfig
config = AzureConfig(
connection_string="conn",
container_name="container",
)
with pytest.raises(AttributeError):
config.container_name = "other" # type: ignore
class TestS3Config:
"""Tests for S3Config dataclass."""
def test_s3_config_creation(self) -> None:
"""Test creating S3Config."""
from shared.storage.config_loader import S3Config
config = S3Config(
bucket_name="my-bucket",
region_name="us-east-1",
access_key_id="AKIAIOSFODNN7EXAMPLE",
secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
endpoint_url="http://localhost:9000",
create_bucket=True,
)
assert config.bucket_name == "my-bucket"
assert config.region_name == "us-east-1"
assert config.access_key_id == "AKIAIOSFODNN7EXAMPLE"
assert config.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
assert config.endpoint_url == "http://localhost:9000"
assert config.create_bucket is True
def test_s3_config_minimal(self) -> None:
"""Test S3Config with only required fields."""
from shared.storage.config_loader import S3Config
config = S3Config(bucket_name="bucket")
assert config.bucket_name == "bucket"
assert config.region_name is None
assert config.access_key_id is None
assert config.secret_access_key is None
assert config.endpoint_url is None
assert config.create_bucket is False
def test_s3_config_is_immutable(self) -> None:
"""Test that S3Config is frozen."""
from shared.storage.config_loader import S3Config
config = S3Config(bucket_name="bucket")
with pytest.raises(AttributeError):
config.bucket_name = "other" # type: ignore

View File

@@ -0,0 +1,423 @@
"""
Tests for storage factory.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageFactory:
"""Tests for create_storage_backend factory function."""
def test_create_local_backend(self) -> None:
"""Test creating local storage backend."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
config = StorageConfig(
backend_type="local",
base_path=Path(temp_dir),
)
backend = create_storage_backend(config)
assert isinstance(backend, LocalStorageBackend)
assert backend.base_path == Path(temp_dir)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
"""Test creating Azure blob storage backend."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
connection_string="DefaultEndpointsProtocol=https;...",
container_name="training-images",
)
backend = create_storage_backend(config)
assert isinstance(backend, AzureBlobStorageBackend)
def test_create_unknown_backend_raises(self) -> None:
"""Test that unknown backend type raises ValueError."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(backend_type="unknown_backend")
with pytest.raises(ValueError, match="Unknown storage backend"):
create_storage_backend(config)
def test_create_local_requires_base_path(self) -> None:
"""Test that local backend requires base_path."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(backend_type="local")
with pytest.raises(ValueError, match="base_path"):
create_storage_backend(config)
def test_create_azure_requires_connection_string(self) -> None:
"""Test that Azure backend requires connection_string."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
container_name="container",
)
with pytest.raises(ValueError, match="connection_string"):
create_storage_backend(config)
def test_create_azure_requires_container_name(self) -> None:
"""Test that Azure backend requires container_name."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="azure_blob",
connection_string="connection_string",
)
with pytest.raises(ValueError, match="container_name"):
create_storage_backend(config)
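These factory tests fix both the dispatch table and the per-backend validation messages of create_storage_backend. A sketch consistent with them, using only the classes and modules the tests themselves import:

from shared.storage.base import StorageBackend, StorageConfig

def create_storage_backend(config: StorageConfig) -> StorageBackend:
    """Dispatch on backend_type and validate required fields (sketch)."""
    if config.backend_type == "local":
        if config.base_path is None:
            raise ValueError("Local backend requires base_path")
        from shared.storage.local import LocalStorageBackend
        return LocalStorageBackend(base_path=config.base_path)
    if config.backend_type == "azure_blob":
        if not config.connection_string:
            raise ValueError("Azure backend requires connection_string")
        if not config.container_name:
            raise ValueError("Azure backend requires container_name")
        from shared.storage.azure import AzureBlobStorageBackend
        return AzureBlobStorageBackend(
            connection_string=config.connection_string,
            container_name=config.container_name,
        )
    if config.backend_type == "s3":
        if not config.bucket_name:
            raise ValueError("S3 backend requires bucket_name")
        from shared.storage.s3 import S3StorageBackend
        return S3StorageBackend(
            bucket_name=config.bucket_name,
            region_name=config.region_name,
        )
    raise ValueError(f"Unknown storage backend: {config.backend_type}")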
class TestStorageFactoryFromEnv:
"""Tests for create_storage_backend_from_env factory function."""
def test_create_from_env_local(self) -> None:
"""Test creating local backend from environment variables."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": temp_dir,
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, LocalStorageBackend)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
"""Test creating Azure backend from environment variables."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "azure_blob",
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
"AZURE_STORAGE_CONTAINER": "training-images",
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, AzureBlobStorageBackend)
def test_create_from_env_defaults_to_local(self) -> None:
"""Test that factory defaults to local backend."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.local import LocalStorageBackend
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BASE_PATH": temp_dir,
}
# Remove STORAGE_BACKEND if present
with patch.dict(os.environ, env, clear=False):
if "STORAGE_BACKEND" in os.environ:
del os.environ["STORAGE_BACKEND"]
backend = create_storage_backend_from_env()
assert isinstance(backend, LocalStorageBackend)
def test_create_from_env_missing_azure_vars(self) -> None:
"""Test error when Azure env vars are missing."""
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "azure_blob",
# Missing AZURE_STORAGE_CONNECTION_STRING
}
with patch.dict(os.environ, env, clear=False):
# Remove the connection string if present
if "AZURE_STORAGE_CONNECTION_STRING" in os.environ:
del os.environ["AZURE_STORAGE_CONNECTION_STRING"]
with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
create_storage_backend_from_env()
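For the env-driven variant, the tests above imply that STORAGE_BACKEND defaults to "local" and that a missing AZURE_STORAGE_CONNECTION_STRING fails with a message naming the variable. One plausible reading, reusing StorageConfig and create_storage_backend from this module (the S3 branch, exercised further down, would follow the same pattern):

import os
from pathlib import Path

def create_storage_backend_from_env() -> StorageBackend:
    """Build a backend purely from environment variables (illustrative sketch)."""
    backend_type = os.environ.get("STORAGE_BACKEND", "local")
    if backend_type == "azure_blob":
        connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
        if not connection_string:
            raise ValueError("AZURE_STORAGE_CONNECTION_STRING must be set for azure_blob")
        config = StorageConfig(
            backend_type="azure_blob",
            connection_string=connection_string,
            container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
        )
    else:
        base_path = os.environ.get("STORAGE_BASE_PATH")
        config = StorageConfig(
            backend_type="local",
            base_path=Path(base_path) if base_path else None,
        )
    return create_storage_backend(config)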
class TestGetDefaultStorageConfig:
"""Tests for get_default_storage_config function."""
def test_get_default_config_local(self) -> None:
"""Test getting default local config."""
from shared.storage.factory import get_default_storage_config
with tempfile.TemporaryDirectory() as temp_dir:
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": temp_dir,
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "local"
assert config.base_path == Path(temp_dir)
def test_get_default_config_azure(self) -> None:
"""Test getting default Azure config."""
from shared.storage.factory import get_default_storage_config
env = {
"STORAGE_BACKEND": "azure_blob",
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
"AZURE_STORAGE_CONTAINER": "training-images",
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "azure_blob"
assert config.connection_string == "DefaultEndpointsProtocol=https;..."
assert config.container_name == "training-images"
class TestStorageFactoryS3:
"""Tests for S3 backend support in factory."""
@patch("boto3.client")
def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
"""Test creating S3 storage backend."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
from shared.storage.s3 import S3StorageBackend
config = StorageConfig(
backend_type="s3",
bucket_name="test-bucket",
region_name="us-west-2",
)
backend = create_storage_backend(config)
assert isinstance(backend, S3StorageBackend)
def test_create_s3_requires_bucket_name(self) -> None:
"""Test that S3 backend requires bucket_name."""
from shared.storage.base import StorageConfig
from shared.storage.factory import create_storage_backend
config = StorageConfig(
backend_type="s3",
region_name="us-west-2",
)
with pytest.raises(ValueError, match="bucket_name"):
create_storage_backend(config)
@patch("boto3.client")
def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
"""Test creating S3 backend from environment variables."""
from shared.storage.factory import create_storage_backend_from_env
from shared.storage.s3 import S3StorageBackend
env = {
"STORAGE_BACKEND": "s3",
"AWS_S3_BUCKET": "test-bucket",
"AWS_REGION": "us-east-1",
}
with patch.dict(os.environ, env, clear=False):
backend = create_storage_backend_from_env()
assert isinstance(backend, S3StorageBackend)
def test_create_from_env_s3_missing_bucket(self) -> None:
"""Test error when S3 bucket env var is missing."""
from shared.storage.factory import create_storage_backend_from_env
env = {
"STORAGE_BACKEND": "s3",
# Missing AWS_S3_BUCKET
}
with patch.dict(os.environ, env, clear=False):
if "AWS_S3_BUCKET" in os.environ:
del os.environ["AWS_S3_BUCKET"]
with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
create_storage_backend_from_env()
def test_get_default_config_s3(self) -> None:
"""Test getting default S3 config."""
from shared.storage.factory import get_default_storage_config
env = {
"STORAGE_BACKEND": "s3",
"AWS_S3_BUCKET": "test-bucket",
"AWS_REGION": "us-west-2",
"AWS_ENDPOINT_URL": "http://localhost:9000",
}
with patch.dict(os.environ, env, clear=False):
config = get_default_storage_config()
assert config.backend_type == "s3"
assert config.bucket_name == "test-bucket"
assert config.region_name == "us-west-2"
assert config.endpoint_url == "http://localhost:9000"
class TestStorageFactoryFromFile:
"""Tests for create_storage_backend_from_file factory function."""
def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
"""Test creating local backend from YAML config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, LocalStorageBackend)
@patch("shared.storage.azure.BlobServiceClient")
def test_create_from_yaml_file_azure(
self, mock_service_class: MagicMock, tmp_path: Path
) -> None:
"""Test creating Azure backend from YAML config file."""
from shared.storage.azure import AzureBlobStorageBackend
from shared.storage.factory import create_storage_backend_from_file
config_file = tmp_path / "storage.yaml"
config_file.write_text("""
backend: azure_blob
azure:
connection_string: DefaultEndpointsProtocol=https;AccountName=test
container_name: documents
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, AzureBlobStorageBackend)
@patch("boto3.client")
def test_create_from_yaml_file_s3(
self, mock_boto3_client: MagicMock, tmp_path: Path
) -> None:
"""Test creating S3 backend from YAML config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.s3 import S3StorageBackend
config_file = tmp_path / "storage.yaml"
config_file.write_text("""
backend: s3
s3:
bucket_name: my-bucket
region_name: us-east-1
""")
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, S3StorageBackend)
def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
"""Test that env vars are substituted in config file."""
from shared.storage.factory import create_storage_backend_from_file
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text("""
backend: ${STORAGE_BACKEND:-local}
local:
base_path: ${CUSTOM_STORAGE_PATH}
""")
with patch.dict(
os.environ,
{"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)},
):
backend = create_storage_backend_from_file(config_file)
assert isinstance(backend, LocalStorageBackend)
def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
"""Test that FileNotFoundError is raised for missing file."""
from shared.storage.factory import create_storage_backend_from_file
with pytest.raises(FileNotFoundError):
create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
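The ${VAR} / ${VAR:-default} placeholders in these YAML fixtures imply a small substitution pass over the file contents before parsing. A sketch of such a helper, assuming only bash-style defaults need to be supported:

import os
import re

_ENV_PATTERN = re.compile(r"\$\{(?P<name>[A-Z0-9_]+)(?::-(?P<default>[^}]*))?\}")

def substitute_env_vars(text: str) -> str:
    """Replace ${VAR} / ${VAR:-default} with environment values (sketch)."""
    def _replace(match: re.Match[str]) -> str:
        value = os.environ.get(match.group("name"))
        if value is not None:
            return value
        return match.group("default") or ""
    return _ENV_PATTERN.sub(_replace, text)

# e.g. with STORAGE_BACKEND unset:
#   substitute_env_vars("backend: ${STORAGE_BACKEND:-local}") -> "backend: local"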
class TestGetStorageBackend:
"""Tests for get_storage_backend convenience function."""
def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
"""Test getting backend from explicit config file."""
from shared.storage.factory import get_storage_backend
from shared.storage.local import LocalStorageBackend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = get_storage_backend(config_path=config_file)
assert isinstance(backend, LocalStorageBackend)
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
"""Test that get_storage_backend falls back to env vars."""
from shared.storage.factory import get_storage_backend
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(storage_path),
}
with patch.dict(os.environ, env, clear=False):
# No config file provided, should use env vars
backend = get_storage_backend(config_path=None)
assert isinstance(backend, LocalStorageBackend)
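The convenience wrapper these last two tests describe then reduces to a simple fallback between the file-based and env-based factories sketched above:

from pathlib import Path

def get_storage_backend(config_path: Path | None = None) -> StorageBackend:
    """Prefer an explicit YAML config, otherwise fall back to environment variables (sketch)."""
    if config_path is not None:
        return create_storage_backend_from_file(config_path)
    return create_storage_backend_from_env()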

View File

@@ -0,0 +1,712 @@
"""
Tests for LocalStorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
import pytest
@pytest.fixture
def temp_storage_dir() -> Iterator[Path]:
"""Create a temporary directory for storage tests."""
temp_dir = Path(tempfile.mkdtemp())
yield temp_dir
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
"""Create a sample file for testing."""
file_path = temp_storage_dir / "sample.txt"
file_path.write_text("Hello, World!")
return file_path
@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
"""Create a sample PNG file for testing."""
file_path = temp_storage_dir / "sample.png"
# Minimal valid PNG (1x1 transparent pixel)
    png_data = bytes(
        [
            0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
            0x00, 0x00, 0x00, 0x0D,  # IHDR length
            0x49, 0x48, 0x44, 0x52,  # IHDR
            0x00, 0x00, 0x00, 0x01,  # width: 1
            0x00, 0x00, 0x00, 0x01,  # height: 1
            0x08, 0x06, 0x00, 0x00, 0x00,  # 8-bit RGBA
            0x1F, 0x15, 0xC4, 0x89,  # CRC
            0x00, 0x00, 0x00, 0x0A,  # IDAT length
            0x49, 0x44, 0x41, 0x54,  # IDAT
            0x78, 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05, 0x00, 0x01,  # compressed data
            0x0D, 0x0A, 0x2D, 0xB4,  # CRC
            0x00, 0x00, 0x00, 0x00,  # IEND length
            0x49, 0x45, 0x4E, 0x44,  # IEND
            0xAE, 0x42, 0x60, 0x82,  # CRC
        ]
    )
file_path.write_bytes(png_data)
return file_path
class TestLocalStorageBackendCreation:
"""Tests for LocalStorageBackend instantiation."""
def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
"""Test creating backend with base path."""
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
assert backend.base_path == temp_storage_dir
def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
"""Test creating backend with string path."""
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=str(temp_storage_dir))
assert backend.base_path == temp_storage_dir
def test_create_creates_directory_if_not_exists(
self, temp_storage_dir: Path
) -> None:
"""Test that base directory is created if it doesn't exist."""
from shared.storage.local import LocalStorageBackend
new_dir = temp_storage_dir / "new_storage"
assert not new_dir.exists()
backend = LocalStorageBackend(base_path=new_dir)
assert new_dir.exists()
assert backend.base_path == new_dir
def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
"""Test that LocalStorageBackend is a StorageBackend."""
from shared.storage.base import StorageBackend
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
assert isinstance(backend, StorageBackend)
class TestLocalStorageBackendUpload:
"""Tests for LocalStorageBackend.upload method."""
def test_upload_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test uploading a file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
result = backend.upload(sample_file, "uploads/sample.txt")
assert result == "uploads/sample.txt"
assert (storage_dir / "uploads" / "sample.txt").exists()
assert (storage_dir / "uploads" / "sample.txt").read_text() == "Hello, World!"
def test_upload_creates_subdirectories(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that upload creates necessary subdirectories."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
result = backend.upload(sample_file, "deep/nested/path/sample.txt")
assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists()
def test_upload_fails_if_file_exists_without_overwrite(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that upload fails if file exists and overwrite is False."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
# First upload succeeds
backend.upload(sample_file, "sample.txt")
# Second upload should fail
with pytest.raises(StorageError, match="already exists"):
backend.upload(sample_file, "sample.txt", overwrite=False)
def test_upload_succeeds_with_overwrite(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that upload succeeds with overwrite=True."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
# First upload
backend.upload(sample_file, "sample.txt")
# Modify original file
sample_file.write_text("Modified content")
# Second upload with overwrite
result = backend.upload(sample_file, "sample.txt", overwrite=True)
assert result == "sample.txt"
assert (storage_dir / "sample.txt").read_text() == "Modified content"
def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
"""Test that uploading nonexistent file fails."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(FileNotFoundStorageError):
backend.upload(Path("/nonexistent/file.txt"), "sample.txt")
def test_upload_binary_file(
self, temp_storage_dir: Path, sample_image: Path
) -> None:
"""Test uploading a binary file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
result = backend.upload(sample_image, "images/sample.png")
assert result == "images/sample.png"
uploaded_content = (storage_dir / "images" / "sample.png").read_bytes()
assert uploaded_content == sample_image.read_bytes()
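The upload tests above imply a small implementation: map the key under base_path, create parent directories, refuse to overwrite unless asked, and copy. A fragment consistent with them, where StorageError and FileNotFoundStorageError are the exceptions the tests import from shared.storage.base and _resolve is the validation helper sketched after the security tests below:

import shutil
from pathlib import Path
from shared.storage.base import FileNotFoundStorageError, StorageError

class LocalStorageBackend:  # illustrative fragment, not the module under test
    def __init__(self, base_path: Path | str) -> None:
        self.base_path = Path(base_path)
        self.base_path.mkdir(parents=True, exist_ok=True)

    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        """Copy a local file into storage and return the key it was stored under."""
        if not local_path.exists():
            raise FileNotFoundStorageError(f"Local file not found: {local_path}")
        target = self._resolve(remote_path)
        if target.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_path, target)
        return remote_path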
class TestLocalStorageBackendDownload:
"""Tests for LocalStorageBackend.download method."""
def test_download_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test downloading a file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
download_dir = temp_storage_dir / "downloads"
download_dir.mkdir()
backend = LocalStorageBackend(base_path=storage_dir)
# First upload
backend.upload(sample_file, "sample.txt")
# Then download
local_path = download_dir / "downloaded.txt"
result = backend.download("sample.txt", local_path)
assert result == local_path
assert local_path.exists()
assert local_path.read_text() == "Hello, World!"
def test_download_creates_parent_directories(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that download creates parent directories."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
result = backend.download("sample.txt", local_path)
assert local_path.exists()
assert local_path.read_text() == "Hello, World!"
def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
"""Test that downloading nonexistent file fails."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
backend.download("nonexistent.txt", Path("/tmp/file.txt"))
def test_download_nested_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test downloading a file from nested path."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "a/b/c/sample.txt")
local_path = temp_storage_dir / "downloaded.txt"
result = backend.download("a/b/c/sample.txt", local_path)
assert local_path.read_text() == "Hello, World!"
class TestLocalStorageBackendExists:
"""Tests for LocalStorageBackend.exists method."""
def test_exists_returns_true_for_existing_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test exists returns True for existing file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
assert backend.exists("sample.txt") is True
def test_exists_returns_false_for_nonexistent_file(
self, temp_storage_dir: Path
) -> None:
"""Test exists returns False for nonexistent file."""
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
assert backend.exists("nonexistent.txt") is False
def test_exists_with_nested_path(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test exists with nested path."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "a/b/sample.txt")
assert backend.exists("a/b/sample.txt") is True
assert backend.exists("a/b/other.txt") is False
class TestLocalStorageBackendListFiles:
"""Tests for LocalStorageBackend.list_files method."""
def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
"""Test listing files in empty storage."""
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
assert backend.list_files("") == []
def test_list_files_returns_all_files(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test listing all files."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
# Upload multiple files
backend.upload(sample_file, "file1.txt")
backend.upload(sample_file, "file2.txt")
backend.upload(sample_file, "subdir/file3.txt")
files = backend.list_files("")
assert len(files) == 3
assert "file1.txt" in files
assert "file2.txt" in files
assert "subdir/file3.txt" in files
def test_list_files_with_prefix(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test listing files with prefix filter."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "images/a.png")
backend.upload(sample_file, "images/b.png")
backend.upload(sample_file, "labels/a.txt")
files = backend.list_files("images/")
assert len(files) == 2
assert "images/a.png" in files
assert "images/b.png" in files
assert "labels/a.txt" not in files
def test_list_files_returns_sorted(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that list_files returns sorted list."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "c.txt")
backend.upload(sample_file, "a.txt")
backend.upload(sample_file, "b.txt")
files = backend.list_files("")
assert files == ["a.txt", "b.txt", "c.txt"]
class TestLocalStorageBackendDelete:
"""Tests for LocalStorageBackend.delete method."""
def test_delete_existing_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test deleting an existing file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
result = backend.delete("sample.txt")
assert result is True
assert not (storage_dir / "sample.txt").exists()
def test_delete_nonexistent_file_returns_false(
self, temp_storage_dir: Path
) -> None:
"""Test deleting nonexistent file returns False."""
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
result = backend.delete("nonexistent.txt")
assert result is False
def test_delete_nested_file(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test deleting a nested file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "a/b/sample.txt")
result = backend.delete("a/b/sample.txt")
assert result is True
assert not (storage_dir / "a" / "b" / "sample.txt").exists()
class TestLocalStorageBackendGetUrl:
"""Tests for LocalStorageBackend.get_url method."""
def test_get_url_returns_file_path(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test get_url returns file:// URL."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
url = backend.get_url("sample.txt")
# Should return file:// URL or absolute path
assert "sample.txt" in url
# URL should be usable to locate the file
expected_path = storage_dir / "sample.txt"
assert str(expected_path) in url or expected_path.as_uri() == url
def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
"""Test get_url for nonexistent file."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(FileNotFoundStorageError):
backend.get_url("nonexistent.txt")
class TestLocalStorageBackendUploadBytes:
"""Tests for LocalStorageBackend.upload_bytes method."""
def test_upload_bytes(self, temp_storage_dir: Path) -> None:
"""Test uploading bytes directly."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
data = b"Binary content here"
result = backend.upload_bytes(data, "binary.dat")
assert result == "binary.dat"
assert (storage_dir / "binary.dat").read_bytes() == data
def test_upload_bytes_creates_subdirectories(
self, temp_storage_dir: Path
) -> None:
"""Test that upload_bytes creates subdirectories."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
data = b"content"
backend.upload_bytes(data, "a/b/c/file.dat")
assert (storage_dir / "a" / "b" / "c" / "file.dat").exists()
class TestLocalStorageBackendDownloadBytes:
"""Tests for LocalStorageBackend.download_bytes method."""
def test_download_bytes(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test downloading file as bytes."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
data = backend.download_bytes("sample.txt")
assert data == b"Hello, World!"
def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
"""Test downloading nonexistent file as bytes."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(FileNotFoundStorageError):
backend.download_bytes("nonexistent.txt")
class TestLocalStorageBackendSecurity:
"""Security tests for LocalStorageBackend - path traversal prevention."""
def test_path_traversal_with_dotdot_blocked(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that path traversal using ../ is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.upload(sample_file, "../escape.txt")
def test_path_traversal_with_nested_dotdot_blocked(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that nested path traversal is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.upload(sample_file, "subdir/../../escape.txt")
def test_path_traversal_with_many_dotdot_blocked(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that deeply nested path traversal is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.upload(sample_file, "a/b/c/../../../../escape.txt")
def test_absolute_path_unix_blocked(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that absolute Unix paths are blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Absolute paths not allowed"):
backend.upload(sample_file, "/etc/passwd")
def test_absolute_path_windows_blocked(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that absolute Windows paths are blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Absolute paths not allowed"):
backend.upload(sample_file, "C:\\Windows\\System32\\config")
def test_download_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in download is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.download("../escape.txt", Path("/tmp/file.txt"))
def test_exists_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in exists is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.exists("../escape.txt")
def test_delete_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in delete is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.delete("../escape.txt")
def test_get_url_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in get_url is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.get_url("../escape.txt")
def test_upload_bytes_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in upload_bytes is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.upload_bytes(b"content", "../escape.txt")
def test_download_bytes_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in download_bytes is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.download_bytes("../escape.txt")
def test_valid_nested_path_still_works(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test that valid nested paths still work after security fix."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
# Valid nested paths should still work
result = backend.upload(sample_file, "a/b/c/d/file.txt")
assert result == "a/b/c/d/file.txt"
assert (storage_dir / "a" / "b" / "c" / "d" / "file.txt").exists()

View File

@@ -0,0 +1,158 @@
"""Tests for storage prefixes module."""
import pytest
from shared.storage.prefixes import PREFIXES, StoragePrefixes
class TestStoragePrefixes:
"""Tests for StoragePrefixes class."""
def test_prefixes_are_strings(self) -> None:
"""All prefix constants should be strings."""
assert isinstance(PREFIXES.DOCUMENTS, str)
assert isinstance(PREFIXES.IMAGES, str)
assert isinstance(PREFIXES.UPLOADS, str)
assert isinstance(PREFIXES.RESULTS, str)
assert isinstance(PREFIXES.EXPORTS, str)
assert isinstance(PREFIXES.DATASETS, str)
assert isinstance(PREFIXES.MODELS, str)
assert isinstance(PREFIXES.RAW_PDFS, str)
assert isinstance(PREFIXES.STRUCTURED_DATA, str)
assert isinstance(PREFIXES.ADMIN_IMAGES, str)
def test_prefixes_are_non_empty(self) -> None:
"""All prefix constants should be non-empty."""
assert PREFIXES.DOCUMENTS
assert PREFIXES.IMAGES
assert PREFIXES.UPLOADS
assert PREFIXES.RESULTS
assert PREFIXES.EXPORTS
assert PREFIXES.DATASETS
assert PREFIXES.MODELS
assert PREFIXES.RAW_PDFS
assert PREFIXES.STRUCTURED_DATA
assert PREFIXES.ADMIN_IMAGES
def test_prefixes_have_no_leading_slash(self) -> None:
"""Prefixes should not start with a slash for portability."""
assert not PREFIXES.DOCUMENTS.startswith("/")
assert not PREFIXES.IMAGES.startswith("/")
assert not PREFIXES.UPLOADS.startswith("/")
assert not PREFIXES.RESULTS.startswith("/")
def test_prefixes_have_no_trailing_slash(self) -> None:
"""Prefixes should not end with a slash."""
assert not PREFIXES.DOCUMENTS.endswith("/")
assert not PREFIXES.IMAGES.endswith("/")
assert not PREFIXES.UPLOADS.endswith("/")
assert not PREFIXES.RESULTS.endswith("/")
def test_frozen_dataclass(self) -> None:
"""StoragePrefixes should be immutable."""
with pytest.raises(Exception): # FrozenInstanceError
PREFIXES.DOCUMENTS = "new_value" # type: ignore
class TestDocumentPath:
"""Tests for document_path helper."""
def test_document_path_with_extension(self) -> None:
"""Should generate correct document path with extension."""
path = PREFIXES.document_path("abc123", ".pdf")
assert path == "documents/abc123.pdf"
def test_document_path_without_leading_dot(self) -> None:
"""Should handle extension without leading dot."""
path = PREFIXES.document_path("abc123", "pdf")
assert path == "documents/abc123.pdf"
def test_document_path_default_extension(self) -> None:
"""Should use .pdf as default extension."""
path = PREFIXES.document_path("abc123")
assert path == "documents/abc123.pdf"
class TestImagePath:
"""Tests for image_path helper."""
def test_image_path_basic(self) -> None:
"""Should generate correct image path."""
path = PREFIXES.image_path("doc123", 1)
assert path == "images/doc123/page_1.png"
def test_image_path_page_number(self) -> None:
"""Should include page number in path."""
path = PREFIXES.image_path("doc123", 5)
assert path == "images/doc123/page_5.png"
def test_image_path_custom_extension(self) -> None:
"""Should support custom extension."""
path = PREFIXES.image_path("doc123", 1, ".jpg")
assert path == "images/doc123/page_1.jpg"
class TestUploadPath:
"""Tests for upload_path helper."""
def test_upload_path_basic(self) -> None:
"""Should generate correct upload path."""
path = PREFIXES.upload_path("invoice.pdf")
assert path == "uploads/invoice.pdf"
def test_upload_path_with_subfolder(self) -> None:
"""Should include subfolder when provided."""
path = PREFIXES.upload_path("invoice.pdf", "async")
assert path == "uploads/async/invoice.pdf"
class TestResultPath:
"""Tests for result_path helper."""
def test_result_path_basic(self) -> None:
"""Should generate correct result path."""
path = PREFIXES.result_path("output.json")
assert path == "results/output.json"
class TestExportPath:
"""Tests for export_path helper."""
def test_export_path_basic(self) -> None:
"""Should generate correct export path."""
path = PREFIXES.export_path("exp123", "dataset.zip")
assert path == "exports/exp123/dataset.zip"
class TestDatasetPath:
"""Tests for dataset_path helper."""
def test_dataset_path_basic(self) -> None:
"""Should generate correct dataset path."""
path = PREFIXES.dataset_path("ds123", "data.yaml")
assert path == "datasets/ds123/data.yaml"
class TestModelPath:
"""Tests for model_path helper."""
def test_model_path_basic(self) -> None:
"""Should generate correct model path."""
path = PREFIXES.model_path("v1.0.0", "best.pt")
assert path == "models/v1.0.0/best.pt"
class TestExportsFromInit:
"""Tests for exports from storage __init__.py."""
def test_prefixes_exported(self) -> None:
"""PREFIXES should be exported from storage module."""
from shared.storage import PREFIXES as exported_prefixes
assert exported_prefixes is PREFIXES
def test_storage_prefixes_exported(self) -> None:
"""StoragePrefixes should be exported from storage module."""
from shared.storage import StoragePrefixes as exported_class
assert exported_class is StoragePrefixes
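These prefix tests read like a specification for a frozen dataclass of constants plus a handful of f-string helpers; the prefix values below are the ones the path tests require, while constants not exercised by a path helper are placeholders. One way to satisfy them:

from dataclasses import dataclass

@dataclass(frozen=True)
class StoragePrefixes:  # illustrative sketch matching the tests above
    DOCUMENTS: str = "documents"
    IMAGES: str = "images"
    UPLOADS: str = "uploads"
    RESULTS: str = "results"
    EXPORTS: str = "exports"
    DATASETS: str = "datasets"
    MODELS: str = "models"
    RAW_PDFS: str = "raw_pdfs"                 # value not pinned by the tests
    STRUCTURED_DATA: str = "structured_data"   # value not pinned by the tests
    ADMIN_IMAGES: str = "admin_images"         # value not pinned by the tests

    def document_path(self, doc_id: str, extension: str = ".pdf") -> str:
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{self.DOCUMENTS}/{doc_id}{ext}"

    def image_path(self, doc_id: str, page: int, extension: str = ".png") -> str:
        return f"{self.IMAGES}/{doc_id}/page_{page}{extension}"

    def upload_path(self, filename: str, subfolder: str | None = None) -> str:
        return f"{self.UPLOADS}/{subfolder}/{filename}" if subfolder else f"{self.UPLOADS}/{filename}"

    def result_path(self, filename: str) -> str:
        return f"{self.RESULTS}/{filename}"

    def export_path(self, export_id: str, filename: str) -> str:
        return f"{self.EXPORTS}/{export_id}/{filename}"

    def dataset_path(self, dataset_id: str, filename: str) -> str:
        return f"{self.DATASETS}/{dataset_id}/{filename}"

    def model_path(self, version: str, filename: str) -> str:
        return f"{self.MODELS}/{version}/{filename}"

PREFIXES = StoragePrefixes()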

View File

@@ -0,0 +1,264 @@
"""
Tests for pre-signed URL functionality across all storage backends.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def temp_storage_dir() -> Iterator[Path]:
"""Create a temporary directory for storage tests."""
temp_dir = Path(tempfile.mkdtemp())
yield temp_dir
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
"""Create a sample file for testing."""
file_path = temp_storage_dir / "sample.txt"
file_path.write_text("Hello, World!")
return file_path
class TestStorageBackendInterfacePresignedUrl:
"""Tests for get_presigned_url in StorageBackend interface."""
def test_subclass_must_implement_get_presigned_url(self) -> None:
"""Test that subclass must implement get_presigned_url method."""
from shared.storage.base import StorageBackend
class IncompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
with pytest.raises(TypeError):
IncompleteBackend() # type: ignore
def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
"""Test that a complete subclass with get_presigned_url can be instantiated."""
from shared.storage.base import StorageBackend
class CompleteBackend(StorageBackend):
def upload(
self, local_path: Path, remote_path: str, overwrite: bool = False
) -> str:
return remote_path
def download(self, remote_path: str, local_path: Path) -> Path:
return local_path
def exists(self, remote_path: str) -> bool:
return False
def list_files(self, prefix: str) -> list[str]:
return []
def delete(self, remote_path: str) -> bool:
return True
def get_url(self, remote_path: str) -> str:
return ""
def get_presigned_url(
self, remote_path: str, expires_in_seconds: int = 3600
) -> str:
return f"https://example.com/{remote_path}?token=abc"
backend = CompleteBackend()
assert isinstance(backend, StorageBackend)
class TestLocalStorageBackendPresignedUrl:
"""Tests for LocalStorageBackend.get_presigned_url method."""
def test_get_presigned_url_returns_file_uri(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test get_presigned_url returns file:// URI for existing file."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
url = backend.get_presigned_url("sample.txt")
assert url.startswith("file://")
assert "sample.txt" in url
def test_get_presigned_url_with_custom_expiry(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test get_presigned_url accepts expires_in_seconds parameter."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "sample.txt")
# For local storage, expiry is ignored but should not raise error
url = backend.get_presigned_url("sample.txt", expires_in_seconds=7200)
assert url.startswith("file://")
def test_get_presigned_url_nonexistent_file_raises(
self, temp_storage_dir: Path
) -> None:
"""Test get_presigned_url raises FileNotFoundStorageError for missing file."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(FileNotFoundStorageError):
backend.get_presigned_url("nonexistent.txt")
def test_get_presigned_url_path_traversal_blocked(
self, temp_storage_dir: Path
) -> None:
"""Test that path traversal in get_presigned_url is blocked."""
from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend
backend = LocalStorageBackend(base_path=temp_storage_dir)
with pytest.raises(StorageError, match="Path traversal not allowed"):
backend.get_presigned_url("../escape.txt")
def test_get_presigned_url_nested_path(
self, temp_storage_dir: Path, sample_file: Path
) -> None:
"""Test get_presigned_url works with nested paths."""
from shared.storage.local import LocalStorageBackend
storage_dir = temp_storage_dir / "storage"
backend = LocalStorageBackend(base_path=storage_dir)
backend.upload(sample_file, "a/b/c/sample.txt")
url = backend.get_presigned_url("a/b/c/sample.txt")
assert url.startswith("file://")
assert "sample.txt" in url
class TestAzureBlobStorageBackendPresignedUrl:
"""Tests for AzureBlobStorageBackend.get_presigned_url method."""
@patch("shared.storage.azure.BlobServiceClient")
def test_get_presigned_url_generates_sas_url(
self, mock_blob_service_class: MagicMock
) -> None:
"""Test get_presigned_url generates URL with SAS token."""
from shared.storage.azure import AzureBlobStorageBackend
# Setup mocks
mock_blob_service = MagicMock()
mock_blob_service.account_name = "testaccount"
mock_blob_service_class.from_connection_string.return_value = mock_blob_service
mock_container = MagicMock()
mock_container.exists.return_value = True
mock_blob_service.get_container_client.return_value = mock_container
mock_blob_client = MagicMock()
mock_blob_client.exists.return_value = True
mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
mock_container.get_blob_client.return_value = mock_blob_client
backend = AzureBlobStorageBackend(
connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
container_name="container",
)
with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)
assert "https://testaccount.blob.core.windows.net" in url
assert "sv=2021-06-08" in url or "test.txt" in url
@patch("shared.storage.azure.BlobServiceClient")
def test_get_presigned_url_nonexistent_blob_raises(
self, mock_blob_service_class: MagicMock
) -> None:
"""Test get_presigned_url raises for nonexistent blob."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.azure import AzureBlobStorageBackend
mock_blob_service = MagicMock()
mock_blob_service_class.from_connection_string.return_value = mock_blob_service
mock_container = MagicMock()
mock_container.exists.return_value = True
mock_blob_service.get_container_client.return_value = mock_container
mock_blob_client = MagicMock()
mock_blob_client.exists.return_value = False
mock_container.get_blob_client.return_value = mock_blob_client
backend = AzureBlobStorageBackend(
connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
container_name="container",
)
with pytest.raises(FileNotFoundStorageError):
backend.get_presigned_url("nonexistent.txt")
@patch("shared.storage.azure.BlobServiceClient")
def test_get_presigned_url_uses_custom_expiry(
self, mock_blob_service_class: MagicMock
) -> None:
"""Test get_presigned_url uses custom expiry time."""
from shared.storage.azure import AzureBlobStorageBackend
mock_blob_service = MagicMock()
mock_blob_service.account_name = "testaccount"
mock_blob_service_class.from_connection_string.return_value = mock_blob_service
mock_container = MagicMock()
mock_container.exists.return_value = True
mock_blob_service.get_container_client.return_value = mock_container
mock_blob_client = MagicMock()
mock_blob_client.exists.return_value = True
mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
mock_container.get_blob_client.return_value = mock_blob_client
backend = AzureBlobStorageBackend(
connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
container_name="container",
)
with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"
backend.get_presigned_url("test.txt", expires_in_seconds=7200)
# Verify generate_blob_sas was called (expiry is part of the call)
mock_generate_sas.assert_called_once()
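On Azure the same method has to mint a SAS token; the mocks above show the implementation module importing generate_blob_sas and building the URL from the blob client. A sketch under those assumptions (the private attribute names here, such as _container_client and _account_key, are illustrative, not taken from the source):

from datetime import datetime, timedelta, timezone
from azure.storage.blob import BlobSasPermissions, generate_blob_sas
from shared.storage.base import FileNotFoundStorageError

class AzureBlobStorageBackend:  # illustrative fragment
    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        """Return the blob URL with a read-only SAS token appended (sketch)."""
        blob_client = self._container_client.get_blob_client(remote_path)
        if not blob_client.exists():
            raise FileNotFoundStorageError(f"Blob not found: {remote_path}")
        sas_token = generate_blob_sas(
            account_name=self._blob_service.account_name,
            container_name=self._container_name,
            blob_name=remote_path,
            account_key=self._account_key,  # parsed from the connection string
            permission=BlobSasPermissions(read=True),
            expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
        )
        return f"{blob_client.url}?{sas_token}"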

View File

@@ -0,0 +1,520 @@
"""
Tests for S3StorageBackend.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import MagicMock, patch, call
import pytest
@pytest.fixture
def temp_dir() -> Iterator[Path]:
"""Create a temporary directory for tests."""
temp_dir = Path(tempfile.mkdtemp())
yield temp_dir
shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
"""Create a sample file for testing."""
file_path = temp_dir / "sample.txt"
file_path.write_text("Hello, World!")
return file_path
@pytest.fixture
def mock_boto3_client() -> Iterator[MagicMock]:
"""Create a mock boto3 S3 client."""
with patch("boto3.client") as mock_client_func:
mock_client = MagicMock()
mock_client_func.return_value = mock_client
yield mock_client
class TestS3StorageBackendCreation:
"""Tests for S3StorageBackend instantiation."""
def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
"""Test creating backend with bucket name."""
from shared.storage.s3 import S3StorageBackend
backend = S3StorageBackend(bucket_name="test-bucket")
assert backend.bucket_name == "test-bucket"
def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
"""Test creating backend with region."""
from shared.storage.s3 import S3StorageBackend
with patch("boto3.client") as mock_client:
S3StorageBackend(
bucket_name="test-bucket",
region_name="us-west-2",
)
mock_client.assert_called_once()
call_kwargs = mock_client.call_args[1]
assert call_kwargs.get("region_name") == "us-west-2"
def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
"""Test creating backend with explicit credentials."""
from shared.storage.s3 import S3StorageBackend
with patch("boto3.client") as mock_client:
S3StorageBackend(
bucket_name="test-bucket",
access_key_id="AKIATEST",
secret_access_key="secret123",
)
mock_client.assert_called_once()
call_kwargs = mock_client.call_args[1]
assert call_kwargs.get("aws_access_key_id") == "AKIATEST"
assert call_kwargs.get("aws_secret_access_key") == "secret123"
def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
"""Test creating backend with custom endpoint (for S3-compatible services)."""
from shared.storage.s3 import S3StorageBackend
with patch("boto3.client") as mock_client:
S3StorageBackend(
bucket_name="test-bucket",
endpoint_url="http://localhost:9000",
)
mock_client.assert_called_once()
call_kwargs = mock_client.call_args[1]
assert call_kwargs.get("endpoint_url") == "http://localhost:9000"
def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
"""Test that bucket is created when create_bucket=True."""
from botocore.exceptions import ClientError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_bucket.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadBucket"
)
S3StorageBackend(
bucket_name="test-bucket",
create_bucket=True,
)
mock_boto3_client.create_bucket.assert_called_once()
def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
"""Test that S3StorageBackend is a StorageBackend."""
from shared.storage.base import StorageBackend
from shared.storage.s3 import S3StorageBackend
backend = S3StorageBackend(bucket_name="test-bucket")
assert isinstance(backend, StorageBackend)
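The construction tests pin down exactly which keyword arguments reach boto3.client and when the bucket is created. A sketch of __init__ under those constraints (region-specific CreateBucketConfiguration details are ignored here):

import boto3
from botocore.exceptions import ClientError

class S3StorageBackend:  # illustrative fragment
    def __init__(
        self,
        bucket_name: str,
        region_name: str | None = None,
        access_key_id: str | None = None,
        secret_access_key: str | None = None,
        endpoint_url: str | None = None,
        create_bucket: bool = False,
    ) -> None:
        self.bucket_name = bucket_name
        kwargs: dict[str, str] = {}
        if region_name:
            kwargs["region_name"] = region_name
        if access_key_id:
            kwargs["aws_access_key_id"] = access_key_id
        if secret_access_key:
            kwargs["aws_secret_access_key"] = secret_access_key
        if endpoint_url:
            kwargs["endpoint_url"] = endpoint_url
        self._client = boto3.client("s3", **kwargs)
        if create_bucket:
            try:
                self._client.head_bucket(Bucket=bucket_name)
            except ClientError:
                self._client.create_bucket(Bucket=bucket_name)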
class TestS3StorageBackendUpload:
"""Tests for S3StorageBackend.upload method."""
def test_upload_file(
self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
) -> None:
"""Test uploading a file."""
from botocore.exceptions import ClientError
from shared.storage.s3 import S3StorageBackend
# Object does not exist
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
result = backend.upload(sample_file, "uploads/sample.txt")
assert result == "uploads/sample.txt"
mock_boto3_client.upload_file.assert_called_once()
def test_upload_fails_if_exists_without_overwrite(
self, mock_boto3_client: MagicMock, sample_file: Path
) -> None:
"""Test that upload fails if object exists and overwrite is False."""
from shared.storage.base import StorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {} # Object exists
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(StorageError, match="already exists"):
backend.upload(sample_file, "sample.txt", overwrite=False)
def test_upload_succeeds_with_overwrite(
self, mock_boto3_client: MagicMock, sample_file: Path
) -> None:
"""Test that upload succeeds with overwrite=True."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {} # Object exists
backend = S3StorageBackend(bucket_name="test-bucket")
result = backend.upload(sample_file, "sample.txt", overwrite=True)
assert result == "sample.txt"
mock_boto3_client.upload_file.assert_called_once()
def test_upload_nonexistent_file_fails(
self, mock_boto3_client: MagicMock, temp_dir: Path
) -> None:
"""Test that uploading nonexistent file fails."""
from shared.storage.base import FileNotFoundStorageError
from shared.storage.s3 import S3StorageBackend
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(FileNotFoundStorageError):
backend.upload(temp_dir / "nonexistent.txt", "sample.txt")
class TestS3StorageBackendDownload:
"""Tests for S3StorageBackend.download method."""
def test_download_file(
self, mock_boto3_client: MagicMock, temp_dir: Path
) -> None:
"""Test downloading a file."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {} # Object exists
backend = S3StorageBackend(bucket_name="test-bucket")
local_path = temp_dir / "downloaded.txt"
result = backend.download("sample.txt", local_path)
assert result == local_path
mock_boto3_client.download_file.assert_called_once()
def test_download_creates_parent_directories(
self, mock_boto3_client: MagicMock, temp_dir: Path
) -> None:
"""Test that download creates parent directories."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
backend = S3StorageBackend(bucket_name="test-bucket")
local_path = temp_dir / "deep" / "nested" / "downloaded.txt"
backend.download("sample.txt", local_path)
assert local_path.parent.exists()
def test_download_nonexistent_object_fails(
self, mock_boto3_client: MagicMock, temp_dir: Path
) -> None:
"""Test that downloading nonexistent object fails."""
from botocore.exceptions import ClientError
from shared.storage.base import FileNotFoundStorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(FileNotFoundStorageError):
backend.download("nonexistent.txt", temp_dir / "file.txt")
class TestS3StorageBackendExists:
"""Tests for S3StorageBackend.exists method."""
def test_exists_returns_true_for_existing_object(
self, mock_boto3_client: MagicMock
) -> None:
"""Test exists returns True for existing object."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
backend = S3StorageBackend(bucket_name="test-bucket")
assert backend.exists("sample.txt") is True
def test_exists_returns_false_for_nonexistent_object(
self, mock_boto3_client: MagicMock
) -> None:
"""Test exists returns False for nonexistent object."""
from botocore.exceptions import ClientError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
assert backend.exists("nonexistent.txt") is False
class TestS3StorageBackendListFiles:
"""Tests for S3StorageBackend.list_files method."""
def test_list_files_returns_objects(
self, mock_boto3_client: MagicMock
) -> None:
"""Test listing objects."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.list_objects_v2.return_value = {
"Contents": [
{"Key": "file1.txt"},
{"Key": "file2.txt"},
{"Key": "subdir/file3.txt"},
]
}
backend = S3StorageBackend(bucket_name="test-bucket")
files = backend.list_files("")
assert len(files) == 3
assert "file1.txt" in files
assert "file2.txt" in files
assert "subdir/file3.txt" in files
def test_list_files_with_prefix(
self, mock_boto3_client: MagicMock
) -> None:
"""Test listing objects with prefix filter."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.list_objects_v2.return_value = {
"Contents": [
{"Key": "images/a.png"},
{"Key": "images/b.png"},
]
}
backend = S3StorageBackend(bucket_name="test-bucket")
        files = backend.list_files("images/")
        assert len(files) == 2
        mock_boto3_client.list_objects_v2.assert_called_with(
            Bucket="test-bucket", Prefix="images/"
        )
def test_list_files_empty_bucket(
self, mock_boto3_client: MagicMock
) -> None:
"""Test listing files in empty bucket."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.list_objects_v2.return_value = {} # No Contents key
backend = S3StorageBackend(bucket_name="test-bucket")
files = backend.list_files("")
assert files == []
class TestS3StorageBackendDelete:
"""Tests for S3StorageBackend.delete method."""
def test_delete_existing_object(
self, mock_boto3_client: MagicMock
) -> None:
"""Test deleting an existing object."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
backend = S3StorageBackend(bucket_name="test-bucket")
result = backend.delete("sample.txt")
assert result is True
mock_boto3_client.delete_object.assert_called_once()
def test_delete_nonexistent_object_returns_false(
self, mock_boto3_client: MagicMock
) -> None:
"""Test deleting nonexistent object returns False."""
from botocore.exceptions import ClientError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
result = backend.delete("nonexistent.txt")
assert result is False
class TestS3StorageBackendGetUrl:
"""Tests for S3StorageBackend.get_url method."""
def test_get_url_returns_s3_url(
self, mock_boto3_client: MagicMock
) -> None:
"""Test get_url returns S3 URL."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
mock_boto3_client.generate_presigned_url.return_value = (
"https://test-bucket.s3.amazonaws.com/sample.txt"
)
backend = S3StorageBackend(bucket_name="test-bucket")
url = backend.get_url("sample.txt")
assert "sample.txt" in url
def test_get_url_nonexistent_object_raises(
self, mock_boto3_client: MagicMock
) -> None:
"""Test get_url raises for nonexistent object."""
from botocore.exceptions import ClientError
from shared.storage.base import FileNotFoundStorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(FileNotFoundStorageError):
backend.get_url("nonexistent.txt")
class TestS3StorageBackendUploadBytes:
"""Tests for S3StorageBackend.upload_bytes method."""
def test_upload_bytes(
self, mock_boto3_client: MagicMock
) -> None:
"""Test uploading bytes directly."""
from shared.storage.s3 import S3StorageBackend
from botocore.exceptions import ClientError
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
data = b"Binary content here"
result = backend.upload_bytes(data, "binary.dat")
assert result == "binary.dat"
mock_boto3_client.put_object.assert_called_once()
def test_upload_bytes_fails_if_exists_without_overwrite(
self, mock_boto3_client: MagicMock
) -> None:
"""Test upload_bytes fails if object exists and overwrite is False."""
from shared.storage.base import StorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {} # Object exists
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(StorageError, match="already exists"):
backend.upload_bytes(b"content", "sample.txt", overwrite=False)
class TestS3StorageBackendDownloadBytes:
"""Tests for S3StorageBackend.download_bytes method."""
def test_download_bytes(
self, mock_boto3_client: MagicMock
) -> None:
"""Test downloading object as bytes."""
from shared.storage.s3 import S3StorageBackend
mock_response = MagicMock()
mock_response.read.return_value = b"Hello, World!"
mock_boto3_client.get_object.return_value = {"Body": mock_response}
backend = S3StorageBackend(bucket_name="test-bucket")
data = backend.download_bytes("sample.txt")
assert data == b"Hello, World!"
def test_download_bytes_nonexistent_raises(
self, mock_boto3_client: MagicMock
) -> None:
"""Test downloading nonexistent object as bytes."""
from botocore.exceptions import ClientError
from shared.storage.base import FileNotFoundStorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.get_object.side_effect = ClientError(
{"Error": {"Code": "NoSuchKey"}}, "GetObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(FileNotFoundStorageError):
backend.download_bytes("nonexistent.txt")
class TestS3StorageBackendPresignedUrl:
"""Tests for S3StorageBackend.get_presigned_url method."""
def test_get_presigned_url_generates_url(
self, mock_boto3_client: MagicMock
) -> None:
"""Test get_presigned_url generates presigned URL."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
mock_boto3_client.generate_presigned_url.return_value = (
"https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
)
backend = S3StorageBackend(bucket_name="test-bucket")
url = backend.get_presigned_url("sample.txt")
assert "X-Amz-Algorithm" in url or "sample.txt" in url
mock_boto3_client.generate_presigned_url.assert_called_once()
def test_get_presigned_url_with_custom_expiry(
self, mock_boto3_client: MagicMock
) -> None:
"""Test get_presigned_url uses custom expiry."""
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.return_value = {}
mock_boto3_client.generate_presigned_url.return_value = "https://..."
backend = S3StorageBackend(bucket_name="test-bucket")
backend.get_presigned_url("sample.txt", expires_in_seconds=7200)
call_args = mock_boto3_client.generate_presigned_url.call_args
assert call_args[1].get("ExpiresIn") == 7200
def test_get_presigned_url_nonexistent_raises(
self, mock_boto3_client: MagicMock
) -> None:
"""Test get_presigned_url raises for nonexistent object."""
from botocore.exceptions import ClientError
from shared.storage.base import FileNotFoundStorageError
from shared.storage.s3 import S3StorageBackend
mock_boto3_client.head_object.side_effect = ClientError(
{"Error": {"Code": "404"}}, "HeadObject"
)
backend = S3StorageBackend(bucket_name="test-bucket")
with pytest.raises(FileNotFoundStorageError):
backend.get_presigned_url("nonexistent.txt")
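
All of the not-found cases above lean on one boto3 idiom: probe the key with head_object and translate a ClientError carrying a 404 code. A minimal self-contained sketch of that idiom (the real S3StorageBackend may differ in detail):

# Hedged sketch of the existence-check pattern these tests encode; head_object
# fetches metadata only, so it is a cheap probe, and boto3 reports a missing
# key as a ClientError whose error code is "404".
import boto3
from botocore.exceptions import ClientError


class SketchS3Backend:
    def __init__(self, bucket_name: str) -> None:
        self.bucket_name = bucket_name
        self.client = boto3.client("s3")

    def exists(self, key: str) -> bool:
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=key)
            return True
        except ClientError as exc:
            if exc.response["Error"]["Code"] in ("404", "NoSuchKey"):
                return False
            raise  # permissions/throttling errors should not masquerade as "missing"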

View File

@@ -9,7 +9,8 @@ from uuid import UUID
 from fastapi import HTTPException
-from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
+from inference.data.admin_models import AdminAnnotation, AdminDocument
+from shared.fields import FIELD_CLASSES
 from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
 from inference.web.schemas.admin import (
     AnnotationCreate,

View File

@@ -31,6 +31,7 @@ class MockAdminDocument:
         self.batch_id = kwargs.get('batch_id', None)
         self.csv_field_values = kwargs.get('csv_field_values', None)
         self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
+        self.category = kwargs.get('category', 'invoice')
         self.created_at = kwargs.get('created_at', datetime.utcnow())
         self.updated_at = kwargs.get('updated_at', datetime.utcnow())
@@ -67,12 +68,13 @@ class MockAdminDB:
     def get_documents_by_token(
         self,
-        admin_token,
+        admin_token=None,
         status=None,
         upload_source=None,
         has_annotations=None,
         auto_label_status=None,
         batch_id=None,
+        category=None,
         limit=20,
         offset=0
     ):
@@ -95,6 +97,8 @@ class MockAdminDB:
             docs = [d for d in docs if d.auto_label_status == auto_label_status]
         if batch_id:
             docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
+        if category:
+            docs = [d for d in docs if d.category == category]
         total = len(docs)
         return docs[offset:offset+limit], total

View File

@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
     def test_cleanup_orphan_files(self, async_service, mock_db):
         """Test cleanup of orphan files."""
-        # Create an orphan file
+        # Create the async upload directory
         temp_dir = async_service._async_config.temp_upload_dir
+        temp_dir.mkdir(parents=True, exist_ok=True)
         orphan_file = temp_dir / "orphan-request.pdf"
         orphan_file.write_bytes(b"orphan content")
@@ -228,6 +230,12 @@ class TestAsyncProcessingService:
         # Mock database to say file doesn't exist
         mock_db.get_request.return_value = None

+        # Mock the storage helper to return the same directory as the fixture
+        with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
+            mock_helper = MagicMock()
+            mock_helper.get_uploads_base_path.return_value = temp_dir
+            mock_storage.return_value = mock_helper
             count = async_service._cleanup_orphan_files()
         assert count == 1
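
Read together with the mocks, the test implies a sweep that lists the upload directory and deletes any file without a matching request row. A hedged sketch, where only get_uploads_base_path and get_request come from the mocks above and everything else is assumed:

# Hedged sketch of the orphan-file sweep this test drives; not the project's
# actual _cleanup_orphan_files implementation.
from pathlib import Path


def cleanup_orphan_files(db, storage_helper) -> int:
    base: Path = storage_helper.get_uploads_base_path()
    removed = 0
    for upload in base.iterdir():
        if not upload.is_file():
            continue
        request_id = upload.stem  # filenames assumed keyed by request id
        if db.get_request(request_id) is None:  # no DB record -> orphan
            upload.unlink()
            removed += 1
    return removed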

View File

@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
""" """
import pytest import pytest
from unittest.mock import MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
TEST_ADMIN_TOKEN = "test-admin-token-12345"
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
@pytest.fixture
def admin_token() -> str:
"""Provide admin token for testing."""
return TEST_ADMIN_TOKEN
@pytest.fixture
def mock_admin_db() -> MagicMock:
"""Create a mock AdminDB for testing."""
mock = MagicMock()
# Default return values
mock.get_document_by_token.return_value = None
mock.get_dataset.return_value = None
mock.get_augmented_datasets.return_value = ([], 0)
return mock
@pytest.fixture
def admin_client(mock_admin_db: MagicMock) -> TestClient:
"""Create test client with admin authentication."""
app = FastAPI()
# Override dependencies
def get_token_override():
return TEST_ADMIN_TOKEN
def get_db_override():
return mock_admin_db
app.dependency_overrides[validate_admin_token] = get_token_override
app.dependency_overrides[get_admin_db] = get_db_override
# Include router - the router already has /augmentation prefix
# so we add /api/v1/admin to get /api/v1/admin/augmentation
router = create_augmentation_router()
app.include_router(router, prefix="/api/v1/admin")
return TestClient(app)
@pytest.fixture
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
"""Create test client WITHOUT admin authentication override."""
app = FastAPI()
# Only override the database, NOT the token validation
def get_db_override():
return mock_admin_db
app.dependency_overrides[get_admin_db] = get_db_override
router = create_augmentation_router()
app.include_router(router, prefix="/api/v1/admin")
return TestClient(app)
class TestAugmentationTypesEndpoint: class TestAugmentationTypesEndpoint:
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
assert "stage" in aug_type assert "stage" in aug_type
def test_list_augmentation_types_unauthorized( def test_list_augmentation_types_unauthorized(
self, admin_client: TestClient self, unauthenticated_client: TestClient
) -> None: ) -> None:
"""Test that unauthorized request is rejected.""" """Test that unauthorized request is rejected."""
response = admin_client.get("/api/v1/admin/augmentation/types") response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
assert response.status_code == 401 assert response.status_code == 401
@@ -74,8 +142,22 @@ class TestAugmentationPreviewEndpoint:
admin_client: TestClient, admin_client: TestClient,
admin_token: str, admin_token: str,
sample_document_id: str, sample_document_id: str,
mock_admin_db: MagicMock,
) -> None: ) -> None:
"""Test previewing augmentation on a document.""" """Test previewing augmentation on a document."""
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image
response = admin_client.post( response = admin_client.post(
f"/api/v1/admin/augmentation/preview/{sample_document_id}", f"/api/v1/admin/augmentation/preview/{sample_document_id}",
headers={"X-Admin-Token": admin_token}, headers={"X-Admin-Token": admin_token},
@@ -136,8 +218,22 @@ class TestAugmentationPreviewConfigEndpoint:
admin_client: TestClient, admin_client: TestClient,
admin_token: str, admin_token: str,
sample_document_id: str, sample_document_id: str,
mock_admin_db: MagicMock,
) -> None: ) -> None:
"""Test previewing full config on a document.""" """Test previewing full config on a document."""
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image
response = admin_client.post( response = admin_client.post(
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}", f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
headers={"X-Admin-Token": admin_token}, headers={"X-Admin-Token": admin_token},
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
admin_client: TestClient, admin_client: TestClient,
admin_token: str, admin_token: str,
sample_dataset_id: str, sample_dataset_id: str,
mock_admin_db: MagicMock,
) -> None: ) -> None:
"""Test creating augmented dataset.""" """Test creating augmented dataset."""
# Mock dataset exists
mock_dataset = MagicMock()
mock_dataset.total_images = 100
mock_admin_db.get_dataset.return_value = mock_dataset
response = admin_client.post( response = admin_client.post(
"/api/v1/admin/augmentation/batch", "/api/v1/admin/augmentation/batch",
headers={"X-Admin-Token": admin_token}, headers={"X-Admin-Token": admin_token},
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
@pytest.fixture @pytest.fixture
def sample_document_id() -> str: def sample_document_id() -> str:
"""Provide a sample document ID for testing.""" """Provide a sample document ID for testing."""
# This would need to be created in test setup return TEST_DOCUMENT_UUID
return "test-document-id"
@pytest.fixture @pytest.fixture
def sample_dataset_id() -> str: def sample_dataset_id() -> str:
"""Provide a sample dataset ID for testing.""" """Provide a sample dataset ID for testing."""
# This would need to be created in test setup return TEST_DATASET_UUID
return "test-dataset-id"

View File

@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
name="test-dataset", name="test-dataset",
description="Test dataset", description="Test dataset",
status="ready", status="ready",
training_status=None,
active_training_task_id=None,
train_ratio=0.8, train_ratio=0.8,
val_ratio=0.1, val_ratio=0.1,
seed=42, seed=42,
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
mock_db = MagicMock() mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1) mock_db.get_datasets.return_value = ([_make_dataset()], 1)
# Mock the active training tasks lookup to return empty dict
mock_db.get_active_training_tasks_for_datasets.return_value = {}
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0)) result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))

View File

@@ -0,0 +1,363 @@
"""
Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# =============================================================================
# Test Database Model
# =============================================================================
class TestTrainingDatasetModel:
"""Tests for TrainingDataset model fields."""
def test_training_dataset_has_training_status_field(self):
"""TrainingDataset model should have training_status field."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(
name="test-dataset",
training_status="running",
)
assert dataset.training_status == "running"
def test_training_dataset_has_active_training_task_id_field(self):
"""TrainingDataset model should have active_training_task_id field."""
from inference.data.admin_models import TrainingDataset
task_id = uuid4()
dataset = TrainingDataset(
name="test-dataset",
active_training_task_id=task_id,
)
assert dataset.active_training_task_id == task_id
def test_training_dataset_defaults(self):
"""TrainingDataset should have correct defaults for new fields."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test-dataset")
assert dataset.training_status is None
assert dataset.active_training_task_id is None
# =============================================================================
# Test AdminDB Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
"""Tests for AdminDB.update_dataset_training_status method."""
@pytest.fixture
def mock_session(self):
"""Create mock database session."""
session = MagicMock()
return session
def test_update_dataset_training_status_sets_status(self, mock_session):
"""update_dataset_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
)
assert dataset.training_status == "running"
mock_session.add.assert_called_once_with(dataset)
mock_session.commit.assert_called_once()
def test_update_dataset_training_status_sets_task_id(self, mock_session):
"""update_dataset_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
active_training_task_id=str(task_id),
)
assert dataset.active_training_task_id == task_id
def test_update_dataset_training_status_updates_main_status_on_complete(
self, mock_session
):
"""update_dataset_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
update_main_status=True,
)
assert dataset.status == "trained"
assert dataset.training_status == "completed"
def test_update_dataset_training_status_clears_task_id_on_complete(
self, mock_session
):
"""update_dataset_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
training_status="running",
active_training_task_id=task_id,
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
active_training_task_id=None,
)
assert dataset.active_training_task_id is None
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
"""update_dataset_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
# Should not raise
db.update_dataset_training_status(
dataset_id=str(uuid4()),
training_status="running",
)
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
# =============================================================================
# Test API Response
# =============================================================================
class TestDatasetDetailResponseTrainingStatus:
"""Tests for DatasetDetailResponse including training status fields."""
def test_dataset_detail_response_includes_training_status(self):
"""DatasetDetailResponse schema should include training_status field."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status="running",
active_training_task_id=str(uuid4()),
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path="/path/to/dataset",
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status == "running"
assert response.active_training_task_id is not None
def test_dataset_detail_response_allows_null_training_status(self):
"""DatasetDetailResponse should allow null training_status."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status=None,
active_training_task_id=None,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path=None,
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status is None
assert response.active_training_task_id is None
# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================
class TestSchedulerDatasetStatusUpdates:
"""Tests for scheduler updating dataset status during training."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
mock = MagicMock()
mock.get_dataset.return_value = MagicMock(
dataset_id=uuid4(),
name="test-dataset",
dataset_path="/path/to/dataset",
total_images=100,
)
mock.get_pending_training_tasks.return_value = []
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
scheduler = TrainingScheduler()
scheduler._db = mock_db
task_id = str(uuid4())
dataset_id = str(uuid4())
# Execute task (will fail but we check the status update call)
try:
scheduler._execute_task(
task_id=task_id,
config={"model_name": "yolo11n.pt"},
dataset_id=dataset_id,
)
except Exception:
pass # Expected to fail in test environment
# Check that training status was updated to running
mock_db.update_dataset_training_status.assert_called()
first_call = mock_db.update_dataset_training_status.call_args_list[0]
assert first_call.kwargs["training_status"] == "running"
assert first_call.kwargs["active_training_task_id"] == task_id
# =============================================================================
# Test Dataset Status Values
# =============================================================================
class TestDatasetStatusValues:
"""Tests for valid dataset status values."""
def test_dataset_status_building(self):
"""Dataset can have status 'building'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="building")
assert dataset.status == "building"
def test_dataset_status_ready(self):
"""Dataset can have status 'ready'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="ready")
assert dataset.status == "ready"
def test_dataset_status_trained(self):
"""Dataset can have status 'trained'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="trained")
assert dataset.status == "trained"
def test_dataset_status_failed(self):
"""Dataset can have status 'failed'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="failed")
assert dataset.status == "failed"
def test_training_status_values(self):
"""Training status can have various values."""
from inference.data.admin_models import TrainingDataset
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
for status in valid_statuses:
dataset = TrainingDataset(name="test", training_status=status)
assert dataset.training_status == status
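
Taken together, the AdminDB tests above pin down the whole contract of update_dataset_training_status: set the status, set or clear the task id, promote the main status to "trained" on completion, and no-op when the dataset is missing. A hedged sketch consistent with those assertions; the real AdminDB session handling may differ:

# Hedged sketch inferred from the tests above, not the actual AdminDB method.
from uuid import UUID


def update_dataset_training_status(
    session,
    dataset_cls,
    dataset_id: str,
    training_status: str,
    active_training_task_id: str | None = None,
    update_main_status: bool = False,
) -> None:
    dataset = session.get(dataset_cls, UUID(dataset_id))
    if dataset is None:
        return  # missing dataset: silently no-op, as the tests expect
    dataset.training_status = training_status
    dataset.active_training_task_id = (
        UUID(active_training_task_id) if active_training_task_id else None
    )
    if update_main_status and training_status == "completed":
        dataset.status = "trained"
    session.add(dataset)
    session.commit()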

View File

@@ -0,0 +1,207 @@
"""
Tests for Document Category Feature.
TDD tests for adding category field to admin_documents table.
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock
from uuid import UUID, uuid4
from inference.data.admin_models import AdminDocument
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestAdminDocumentCategoryField:
"""Tests for AdminDocument category field."""
def test_document_has_category_field(self):
"""Test AdminDocument model has category field."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
)
assert hasattr(doc, "category")
def test_document_category_defaults_to_invoice(self):
"""Test category defaults to 'invoice' when not specified."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
)
assert doc.category == "invoice"
def test_document_accepts_custom_category(self):
"""Test document accepts custom category values."""
categories = ["invoice", "letter", "receipt", "contract", "custom_type"]
for cat in categories:
doc = AdminDocument(
document_id=uuid4(),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
category=cat,
)
assert doc.category == cat
def test_document_category_is_string_type(self):
"""Test category field is a string type."""
doc = AdminDocument(
document_id=UUID(TEST_DOC_UUID),
filename="test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/path/to/file.pdf",
category="letter",
)
assert isinstance(doc.category, str)
class TestDocumentCategoryInReadModel:
"""Tests for category in response models."""
def test_admin_document_read_has_category(self):
"""Test AdminDocumentRead includes category field."""
from inference.data.admin_models import AdminDocumentRead
# Check the model has category field in its schema
assert "category" in AdminDocumentRead.model_fields
class TestDocumentCategoryAPI:
"""Tests for document category in API endpoints."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
return db
def test_upload_document_with_category(self, mock_admin_db):
"""Test uploading document with category parameter."""
from inference.web.schemas.admin import DocumentUploadResponse
# Verify response schema supports category
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
message="Upload successful",
category="letter",
)
assert response.category == "letter"
def test_list_documents_returns_category(self, mock_admin_db):
"""Test list documents endpoint returns category."""
from inference.web.schemas.admin import DocumentItem
item = DocumentItem(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
annotation_count=0,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
category="invoice",
)
assert item.category == "invoice"
def test_document_detail_includes_category(self, mock_admin_db):
"""Test document detail response includes category."""
from inference.web.schemas.admin import DocumentDetailResponse
# Check schema has category
assert "category" in DocumentDetailResponse.model_fields
class TestDocumentCategoryFiltering:
"""Tests for filtering documents by category."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB with category filtering support."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
# Mock documents with different categories
invoice_doc = MagicMock()
invoice_doc.document_id = uuid4()
invoice_doc.category = "invoice"
letter_doc = MagicMock()
letter_doc.document_id = uuid4()
letter_doc.category = "letter"
db.get_documents_by_category.return_value = [invoice_doc]
return db
def test_filter_documents_by_category(self, mock_admin_db):
"""Test filtering documents by category."""
# This tests the DB method signature
result = mock_admin_db.get_documents_by_category("invoice")
assert len(result) == 1
assert result[0].category == "invoice"
class TestDocumentCategoryUpdate:
"""Tests for updating document category."""
def test_update_document_category_schema(self):
"""Test update document request supports category."""
from inference.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="letter")
assert request.category == "letter"
def test_update_document_category_optional(self):
"""Test category is optional in update request."""
from inference.web.schemas.admin import DocumentUpdateRequest
# Should not raise - category is optional
request = DocumentUpdateRequest()
assert request.category is None
class TestDatasetWithCategory:
"""Tests for dataset creation with category filtering."""
def test_dataset_create_with_category_filter(self):
"""Test creating dataset can filter by document category."""
from inference.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Invoice Training Set",
document_ids=[TEST_DOC_UUID],
category="invoice", # Optional filter
)
assert request.category == "invoice"
def test_dataset_create_category_is_optional(self):
"""Test category filter is optional when creating dataset."""
from inference.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Mixed Training Set",
document_ids=[TEST_DOC_UUID],
)
# category should be optional
assert not hasattr(request, "category") or request.category is None
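
The default-to-'invoice' behavior suggests a plain column default on the model. One hedged way to declare it, assuming a SQLModel-style model like the session-based AdminDB code implies; the real AdminDocument may differ:

# Hedged sketch of the category default; SQLModel is an assumption here.
from sqlmodel import Field, SQLModel


class AdminDocumentSketch(SQLModel):
    filename: str
    category: str = Field(default="invoice")


doc = AdminDocumentSketch(filename="test.pdf")
assert doc.category == "invoice"  # mirrors test_document_category_defaults_to_invoice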

View File

@@ -0,0 +1,165 @@
"""
Tests for Document Category API Endpoints.
TDD tests for category filtering and management in document endpoints.
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# Test constants
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
TEST_TOKEN = "test-admin-token-12345"
class TestGetCategoriesEndpoint:
"""Tests for GET /admin/documents/categories endpoint."""
def test_categories_endpoint_returns_list(self):
"""Test categories endpoint returns list of available categories."""
from inference.web.schemas.admin import DocumentCategoriesResponse
# Test schema exists and works
response = DocumentCategoriesResponse(
categories=["invoice", "letter", "receipt"],
total=3,
)
assert response.categories == ["invoice", "letter", "receipt"]
assert response.total == 3
def test_categories_response_schema(self):
"""Test DocumentCategoriesResponse schema structure."""
from inference.web.schemas.admin import DocumentCategoriesResponse
assert "categories" in DocumentCategoriesResponse.model_fields
assert "total" in DocumentCategoriesResponse.model_fields
class TestDocumentListFilterByCategory:
"""Tests for filtering documents by category."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
# Mock documents with different categories
invoice_doc = MagicMock()
invoice_doc.document_id = uuid4()
invoice_doc.category = "invoice"
invoice_doc.filename = "invoice1.pdf"
letter_doc = MagicMock()
letter_doc.document_id = uuid4()
letter_doc.category = "letter"
letter_doc.filename = "letter1.pdf"
db.get_documents.return_value = ([invoice_doc], 1)
db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
return db
def test_list_documents_accepts_category_filter(self, mock_admin_db):
"""Test list documents endpoint accepts category query parameter."""
# The endpoint should accept ?category=invoice parameter
# This test verifies the schema/query parameter exists
from inference.web.schemas.admin import DocumentListResponse
# Schema should work with category filter applied
assert DocumentListResponse is not None
def test_get_document_categories_from_db(self, mock_admin_db):
"""Test fetching unique categories from database."""
categories = mock_admin_db.get_document_categories()
assert "invoice" in categories
assert "letter" in categories
assert len(categories) == 3
class TestDocumentUploadWithCategory:
"""Tests for uploading documents with category."""
def test_upload_request_accepts_category(self):
"""Test upload request can include category field."""
# When uploading via form data, category should be accepted
# This is typically a form field, not a schema
pass
def test_upload_response_includes_category(self):
"""Test upload response includes the category that was set."""
from inference.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
category="letter", # Custom category
message="Upload successful",
)
assert response.category == "letter"
def test_upload_defaults_to_invoice_category(self):
"""Test upload defaults to 'invoice' if no category specified."""
from inference.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
filename="test.pdf",
file_size=1024,
page_count=1,
status="pending",
message="Upload successful",
# No category specified - should default to "invoice"
)
assert response.category == "invoice"
class TestAdminDBCategoryMethods:
"""Tests for AdminDB category-related methods."""
def test_get_document_categories_method_exists(self):
"""Test AdminDB has get_document_categories method."""
from inference.data.admin_db import AdminDB
db = AdminDB()
assert hasattr(db, "get_document_categories")
def test_get_documents_accepts_category_filter(self):
"""Test get_documents_by_token method accepts category parameter."""
from inference.data.admin_db import AdminDB
import inspect
db = AdminDB()
# Check the method exists and accepts category parameter
method = getattr(db, "get_documents_by_token", None)
assert callable(method)
# Check category is in the method signature
sig = inspect.signature(method)
assert "category" in sig.parameters
class TestUpdateDocumentCategory:
"""Tests for updating document category."""
def test_update_document_category_method_exists(self):
"""Test AdminDB has method to update document category."""
from inference.data.admin_db import AdminDB
db = AdminDB()
assert hasattr(db, "update_document_category")
def test_update_request_schema(self):
"""Test DocumentUpdateRequest can update category."""
from inference.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="receipt")
assert request.category == "receipt"
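
get_document_categories presumably reduces to a DISTINCT query over the category column. A hedged sketch under the same SQLModel assumption; the real AdminDB method may be shaped differently:

# Hedged sketch of a distinct-categories query; not the project's actual code.
from sqlmodel import select


def get_document_categories(session, document_cls) -> list[str]:
    statement = (
        select(document_cls.category)
        .distinct()
        .order_by(document_cls.category)  # sorted for stable output
    )
    return list(session.exec(statement).all())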

View File

@@ -32,10 +32,10 @@ def test_app(tmp_path):
             use_gpu=False,
             dpi=150,
         ),
-        storage=StorageConfig(
+        file=StorageConfig(
             upload_dir=upload_dir,
             result_dir=result_dir,
-            allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
+            allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
             max_file_size_mb=50,
         ),
     )
@@ -252,18 +252,23 @@ class TestResultsEndpoint:
         response = client.get("/api/v1/results/nonexistent.png")
         assert response.status_code == 404

-    def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
+    def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
         """Test that existing result file is returned."""
-        # Get storage config from app
-        storage_config = test_app.extra.get("storage_config")
-        if not storage_config:
-            pytest.skip("Storage config not available in test app")
-        # Create a test result file
-        result_file = storage_config.result_dir / "test_result.png"
+        # Create a test result file in temp directory
+        result_dir = tmp_path / "results"
+        result_dir.mkdir(exist_ok=True)
+        result_file = result_dir / "test_result.png"
         img = Image.new('RGB', (100, 100), color='red')
         img.save(result_file)

+        # Mock the storage helper to return our test file path
+        with patch(
+            "inference.web.api.v1.public.inference.get_storage_helper"
+        ) as mock_storage:
+            mock_helper = Mock()
+            mock_helper.get_result_local_path.return_value = result_file
+            mock_storage.return_value = mock_helper
             # Request the file
             response = client.get("/api/v1/results/test_result.png")
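
The rewritten test implies the handler resolves result files through the storage helper instead of reading a directory off the app config. A hedged sketch of such a handler; the route decorator and FileResponse usage are assumptions beyond what the test pins down:

# Hedged sketch of a results route backed by the storage helper; the test
# patches get_storage_helper where the route module imports it.
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

from inference.web.services.storage_helpers import get_storage_helper

router = APIRouter()


@router.get("/api/v1/results/{filename}")
def get_result_image(filename: str) -> FileResponse:
    local_path = get_storage_helper().get_result_local_path(filename)
    if local_path is None or not local_path.exists():
        raise HTTPException(status_code=404, detail="Result not found")
    return FileResponse(local_path)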

View File

@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
         mock_db = MagicMock()
         mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

-        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+        # Create mock request with app state
+        mock_request = MagicMock()
+        mock_request.app.state.inference_service = None
+
+        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))

         mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
         assert result.status == "active"
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
         mock_db = MagicMock()
         mock_db.activate_model_version.return_value = None

+        # Create mock request with app state
+        mock_request = MagicMock()
+        mock_request.app.state.inference_service = None
+
         from fastapi import HTTPException
         with pytest.raises(HTTPException) as exc_info:
-            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))

         assert exc_info.value.status_code == 404
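
The new request parameter exists so activation can reach the live inference service on app.state and hot-swap the model. A hedged sketch of that flow; reload_model and model_path are assumptions, since the tests only pin request.app.state.inference_service:

# Hedged sketch of the activation route these tests exercise; not the
# project's actual handler.
from fastapi import HTTPException, Request


async def activate_model_version(version_id: str, request: Request, db) -> object:
    version = db.activate_model_version(version_id)
    if version is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    service = request.app.state.inference_service
    if service is not None:  # None in these tests, so no reload is attempted
        service.reload_model(version.model_path)
    return version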

View File

@@ -0,0 +1,828 @@
"""Tests for storage helpers module."""
import pytest
from unittest.mock import MagicMock, patch
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
from shared.storage import PREFIXES
@pytest.fixture
def mock_storage() -> MagicMock:
"""Create a mock storage backend."""
storage = MagicMock()
storage.upload_bytes = MagicMock()
storage.download_bytes = MagicMock(return_value=b"test content")
storage.get_presigned_url = MagicMock(return_value="https://example.com/file")
storage.exists = MagicMock(return_value=True)
storage.delete = MagicMock(return_value=True)
storage.list_files = MagicMock(return_value=[])
return storage
@pytest.fixture
def helper(mock_storage: MagicMock) -> StorageHelper:
"""Create a storage helper with mock backend."""
return StorageHelper(storage=mock_storage)
class TestStorageHelperInit:
"""Tests for StorageHelper initialization."""
def test_init_with_storage(self, mock_storage: MagicMock) -> None:
"""Should use provided storage backend."""
helper = StorageHelper(storage=mock_storage)
assert helper.storage is mock_storage
def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Storage property should return the backend."""
assert helper.storage is mock_storage
class TestDocumentOperations:
"""Tests for document storage operations."""
def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should upload document with correct path."""
doc_id, path = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")
assert doc_id == "doc123"
assert path == "documents/doc123.pdf"
mock_storage.upload_bytes.assert_called_once_with(
b"pdf content", "documents/doc123.pdf", overwrite=True
)
def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
"""Should generate document ID if not provided."""
doc_id, path = helper.upload_document(b"content", "file.pdf")
assert doc_id is not None
assert len(doc_id) > 0
assert path.startswith("documents/")
def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should download document from correct path."""
content = helper.download_document("doc123")
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")
def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for document."""
url = helper.get_document_url("doc123", expires_in_seconds=7200)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"documents/doc123.pdf", 7200
)
def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check document existence."""
exists = helper.document_exists("doc123")
assert exists is True
mock_storage.exists.assert_called_once_with("documents/doc123.pdf")
def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete document."""
result = helper.delete_document("doc123")
assert result is True
mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
class TestImageOperations:
"""Tests for image storage operations."""
def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save page image with correct path."""
path = helper.save_page_image("doc123", 1, b"image data")
assert path == "images/doc123/page_1.png"
mock_storage.upload_bytes.assert_called_once_with(
b"image data", "images/doc123/page_1.png", overwrite=True
)
def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get page image from correct path."""
content = helper.get_page_image("doc123", 2)
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")
def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for page image."""
url = helper.get_page_image_url("doc123", 3)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"images/doc123/page_3.png", 3600
)
def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete all images for a document."""
mock_storage.list_files.return_value = [
"images/doc123/page_1.png",
"images/doc123/page_2.png",
]
deleted = helper.delete_document_images("doc123")
assert deleted == 2
mock_storage.list_files.assert_called_once_with("images/doc123/")
def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should list all images for a document."""
mock_storage.list_files.return_value = ["images/doc123/page_1.png"]
images = helper.list_document_images("doc123")
assert images == ["images/doc123/page_1.png"]
class TestUploadOperations:
"""Tests for upload staging operations."""
def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save upload to correct path."""
path = helper.save_upload(b"content", "file.pdf")
assert path == "uploads/file.pdf"
mock_storage.upload_bytes.assert_called_once()
def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save upload with subfolder."""
path = helper.save_upload(b"content", "file.pdf", "async")
assert path == "uploads/async/file.pdf"
def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get upload from correct path."""
content = helper.get_upload("file.pdf", "async")
mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")
def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete upload."""
result = helper.delete_upload("file.pdf")
assert result is True
mock_storage.delete.assert_called_once_with("uploads/file.pdf")
class TestResultOperations:
"""Tests for result file operations."""
def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save result to correct path."""
path = helper.save_result(b"result data", "output.json")
assert path == "results/output.json"
mock_storage.upload_bytes.assert_called_once()
def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get result from correct path."""
content = helper.get_result("output.json")
mock_storage.download_bytes.assert_called_once_with("results/output.json")
def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for result."""
url = helper.get_result_url("output.json")
mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)
def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check result existence."""
exists = helper.result_exists("output.json")
assert exists is True
mock_storage.exists.assert_called_once_with("results/output.json")
class TestExportOperations:
"""Tests for export file operations."""
def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save export to correct path."""
path = helper.save_export(b"export data", "exp123", "dataset.zip")
assert path == "exports/exp123/dataset.zip"
mock_storage.upload_bytes.assert_called_once()
def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for export."""
url = helper.get_export_url("exp123", "dataset.zip")
mock_storage.get_presigned_url.assert_called_once_with(
"exports/exp123/dataset.zip", 3600
)
class TestRawPdfOperations:
"""Tests for raw PDF operations (legacy compatibility)."""
def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save raw PDF to correct path."""
path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")
assert path == "raw_pdfs/invoice.pdf"
mock_storage.upload_bytes.assert_called_once()
def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get raw PDF from correct path."""
content = helper.get_raw_pdf("invoice.pdf")
mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")
def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check raw PDF existence."""
exists = helper.raw_pdf_exists("invoice.pdf")
assert exists is True
mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
class TestAdminImageOperations:
"""Tests for admin image storage operations."""
def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should save admin image with correct path."""
path = helper.save_admin_image("doc123", 1, b"image data")
assert path == "admin_images/doc123/page_1.png"
mock_storage.upload_bytes.assert_called_once_with(
b"image data", "admin_images/doc123/page_1.png", overwrite=True
)
def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get admin image from correct path."""
content = helper.get_admin_image("doc123", 2)
assert content == b"test content"
mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")
def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should get presigned URL for admin image."""
url = helper.get_admin_image_url("doc123", 3)
assert url == "https://example.com/file"
mock_storage.get_presigned_url.assert_called_once_with(
"admin_images/doc123/page_3.png", 3600
)
def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check admin image existence."""
exists = helper.admin_image_exists("doc123", 1)
assert exists is True
mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")
def test_get_admin_image_path(self, helper: StorageHelper) -> None:
"""Should return correct admin image path."""
path = helper.get_admin_image_path("doc123", 2)
assert path == "admin_images/doc123/page_2.png"
def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should list all admin images for a document."""
mock_storage.list_files.return_value = [
"admin_images/doc123/page_1.png",
"admin_images/doc123/page_2.png",
]
images = helper.list_admin_images("doc123")
assert images == ["admin_images/doc123/page_1.png", "admin_images/doc123/page_2.png"]
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete all admin images for a document."""
mock_storage.list_files.return_value = [
"admin_images/doc123/page_1.png",
"admin_images/doc123/page_2.png",
]
deleted = helper.delete_admin_images("doc123")
assert deleted == 2
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
class TestGetLocalPath:
"""Tests for get_local_path method."""
def test_get_admin_image_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test image
test_path = Path(temp_dir) / "admin_images" / "doc123"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "page_1.png").write_bytes(b"test image")
local_path = helper.get_admin_image_local_path("doc123", 1)
assert local_path is not None
assert local_path.exists()
assert local_path.name == "page_1.png"
def test_get_admin_image_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when storage doesn't support local paths."""
# Mock storage without get_local_path method (simulating cloud storage)
mock_storage.get_local_path = MagicMock(return_value=None)
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_admin_image_local_path("doc123", 1)
assert local_path is None
def test_get_admin_image_local_path_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
local_path = helper.get_admin_image_local_path("nonexistent", 1)
assert local_path is None
class TestGetAdminImageDimensions:
"""Tests for get_admin_image_dimensions method."""
def test_get_dimensions_with_local_storage(self) -> None:
"""Should return image dimensions when using local storage."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
from PIL import Image
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test image with known dimensions
test_path = Path(temp_dir) / "admin_images" / "doc123"
test_path.mkdir(parents=True, exist_ok=True)
img = Image.new("RGB", (800, 600), color="white")
img.save(test_path / "page_1.png")
dimensions = helper.get_admin_image_dimensions("doc123", 1)
assert dimensions == (800, 600)
def test_get_dimensions_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
dimensions = helper.get_admin_image_dimensions("nonexistent", 1)
assert dimensions is None
class TestGetStorageHelper:
"""Tests for get_storage_helper function."""
def test_returns_helper_instance(self) -> None:
"""Should return a StorageHelper instance."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
# Reset the global helper
import inference.web.services.storage_helpers as module
module._default_helper = None
helper = get_storage_helper()
assert isinstance(helper, StorageHelper)
def test_returns_same_instance(self) -> None:
"""Should return the same instance on subsequent calls."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
import inference.web.services.storage_helpers as module
module._default_helper = None
helper1 = get_storage_helper()
helper2 = get_storage_helper()
assert helper1 is helper2
class TestDeleteResult:
"""Tests for delete_result method."""
def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should delete result file."""
result = helper.delete_result("output.json")
assert result is True
mock_storage.delete.assert_called_once_with("results/output.json")
class TestResultLocalPath:
"""Tests for get_result_local_path method."""
def test_get_result_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test result file
test_path = Path(temp_dir) / "results"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "output.json").write_bytes(b"test result")
local_path = helper.get_result_local_path("output.json")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "output.json"
def test_get_result_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when storage doesn't support local paths."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_result_local_path("output.json")
assert local_path is None
def test_get_result_local_path_nonexistent_file(self) -> None:
"""Should return None when file doesn't exist."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
local_path = helper.get_result_local_path("nonexistent.json")
assert local_path is None
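The three cases above (local hit, cloud backend, missing file) all funnel through one resolution step. A sketch of that step; treating a missing or None-returning get_local_path as "not local" is inferred from the tests, not confirmed by them:
from pathlib import Path

def resolve_local_path(storage, key: str) -> Path | None:
    """Map a logical storage key to an existing local file, else None."""
    get_local = getattr(storage, "get_local_path", None)
    if get_local is None:
        return None  # backend exposes no local filesystem (e.g. cloud)
    local = get_local(key)
    if local is None:
        return None  # backend declined to map this key
    path = Path(local)
    return path if path.exists() else None
get_result_local_path would then amount to resolve_local_path(storage, f"results/{filename}").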
class TestResultsBasePath:
"""Tests for get_results_base_path method."""
def test_get_results_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_results_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "results"
def test_get_results_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_results_base_path()
assert base_path is None
class TestUploadLocalPath:
"""Tests for get_upload_local_path method."""
def test_get_upload_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test upload file
test_path = Path(temp_dir) / "uploads"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "file.pdf").write_bytes(b"test upload")
local_path = helper.get_upload_local_path("file.pdf")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "file.pdf"
def test_get_upload_local_path_with_subfolder(self) -> None:
"""Should return local path with subfolder."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test upload file with subfolder
test_path = Path(temp_dir) / "uploads" / "async"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "file.pdf").write_bytes(b"test upload")
local_path = helper.get_upload_local_path("file.pdf", "async")
assert local_path is not None
assert local_path.exists()
def test_get_upload_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_upload_local_path("file.pdf")
assert local_path is None
class TestUploadsBasePath:
"""Tests for get_uploads_base_path method."""
def test_get_uploads_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_uploads_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "uploads"
def test_get_uploads_base_path_with_subfolder(self) -> None:
"""Should return base path with subfolder."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_uploads_base_path("async")
assert base_path is not None
assert base_path.exists()
assert base_path.name == "async"
def test_get_uploads_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_uploads_base_path()
assert base_path is None
class TestUploadExists:
"""Tests for upload_exists method."""
def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
"""Should check upload existence."""
exists = helper.upload_exists("file.pdf")
assert exists is True
mock_storage.exists.assert_called_once_with("uploads/file.pdf")
def test_upload_exists_with_subfolder(
self, helper: StorageHelper, mock_storage: MagicMock
) -> None:
"""Should check upload existence with subfolder."""
helper.upload_exists("file.pdf", "async")
mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
class TestDatasetsBasePath:
"""Tests for get_datasets_base_path method."""
def test_get_datasets_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_datasets_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "datasets"
def test_get_datasets_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_datasets_base_path()
assert base_path is None
class TestAdminImagesBasePath:
"""Tests for get_admin_images_base_path method."""
def test_get_admin_images_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_admin_images_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "admin_images"
def test_get_admin_images_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_admin_images_base_path()
assert base_path is None
class TestRawPdfsBasePath:
"""Tests for get_raw_pdfs_base_path method."""
def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_raw_pdfs_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "raw_pdfs"
def test_get_raw_pdfs_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_raw_pdfs_base_path()
assert base_path is None
class TestRawPdfLocalPath:
"""Tests for get_raw_pdf_local_path method."""
def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
"""Should return local path when using local storage backend."""
from pathlib import Path
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
# Create a test raw PDF
test_path = Path(temp_dir) / "raw_pdfs"
test_path.mkdir(parents=True, exist_ok=True)
(test_path / "invoice.pdf").write_bytes(b"test pdf")
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
assert local_path is not None
assert local_path.exists()
assert local_path.name == "invoice.pdf"
def test_get_raw_pdf_local_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
assert local_path is None
class TestRawPdfPath:
"""Tests for get_raw_pdf_path method."""
def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
"""Should return correct storage path."""
path = helper.get_raw_pdf_path("invoice.pdf")
assert path == "raw_pdfs/invoice.pdf"
class TestAutolabelOutputPath:
"""Tests for get_autolabel_output_path method."""
def test_get_autolabel_output_path_with_local_storage(self) -> None:
"""Should return output path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
output_path = helper.get_autolabel_output_path()
assert output_path is not None
assert output_path.exists()
assert output_path.name == "autolabel_output"
def test_get_autolabel_output_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
output_path = helper.get_autolabel_output_path()
assert output_path is None
class TestTrainingDataPath:
"""Tests for get_training_data_path method."""
def test_get_training_data_path_with_local_storage(self) -> None:
"""Should return training path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
training_path = helper.get_training_data_path()
assert training_path is not None
assert training_path.exists()
assert training_path.name == "training"
def test_get_training_data_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
training_path = helper.get_training_data_path()
assert training_path is None
class TestExportsBasePath:
"""Tests for get_exports_base_path method."""
def test_get_exports_base_path_with_local_storage(self) -> None:
"""Should return base path when using local storage."""
from shared.storage.local import LocalStorageBackend
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = LocalStorageBackend(temp_dir)
helper = StorageHelper(storage=storage)
base_path = helper.get_exports_base_path()
assert base_path is not None
assert base_path.exists()
assert base_path.name == "exports"
def test_get_exports_base_path_returns_none_for_cloud(
self, mock_storage: MagicMock
) -> None:
"""Should return None when not using local storage."""
helper = StorageHelper(storage=mock_storage)
base_path = helper.get_exports_base_path()
assert base_path is None
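Every *_base_path class from TestResultsBasePath through TestExportsBasePath asserts the same shape: a directory that exists (so it is created on demand) under local storage, and None under anything else. That repetition points at one private helper parameterized by subdirectory name; a sketch, where the base_path attribute on LocalStorageBackend is an assumption:
from pathlib import Path
from shared.storage.local import LocalStorageBackend

def _base_dir(storage, name: str) -> Path | None:
    """Return <storage root>/<name>, creating it if needed; None off local disk."""
    if not isinstance(storage, LocalStorageBackend):
        return None
    path = Path(storage.base_path) / name  # base_path attribute is assumed
    path.mkdir(parents=True, exist_ok=True)
    return path
get_results_base_path() would then be _base_dir(storage, "results"), and likewise for uploads, datasets, admin_images, raw_pdfs, autolabel_output, training, and exports.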

View File

@@ -0,0 +1,306 @@
"""
Tests for storage backend integration in web application.
TDD Phase 1: RED - Write tests first, then implement to pass.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
class TestStorageBackendInitialization:
"""Tests for storage backend initialization in web config."""
def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
"""Test that get_storage_backend returns a StorageBackend instance."""
from shared.storage.base import StorageBackend
from inference.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
backend = get_storage_backend()
assert isinstance(backend, StorageBackend)
def test_get_storage_backend_uses_config_file_if_exists(
self, tmp_path: Path
) -> None:
"""Test that storage config file is used when present."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
config_file.write_text(f"""
backend: local
local:
base_path: {storage_path}
""")
backend = get_storage_backend(config_path=config_file)
assert isinstance(backend, LocalStorageBackend)
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
"""Test fallback to environment variables when no config file."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
backend = get_storage_backend(config_path=None)
assert isinstance(backend, LocalStorageBackend)
def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
"""Test that AppConfig can be created with storage backend."""
from shared.storage.base import StorageBackend
from inference.web.config import AppConfig, create_app_config
env = {
"STORAGE_BACKEND": "local",
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
}
with patch.dict(os.environ, env, clear=False):
config = create_app_config()
assert hasattr(config, "storage_backend")
assert isinstance(config.storage_backend, StorageBackend)
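These three tests fix the precedence: an explicit YAML config wins, otherwise STORAGE_BACKEND and STORAGE_BASE_PATH from the environment. A sketch of that resolution order covering only the "local" backend the tests exercise; the yaml parsing and the error branch are assumptions beyond what the tests show:
import os
from pathlib import Path

import yaml

from shared.storage.local import LocalStorageBackend

def load_storage_backend(config_path: Path | None = None) -> LocalStorageBackend:
    """Resolve storage settings from a YAML file, falling back to env vars."""
    if config_path is not None and config_path.exists():
        cfg = yaml.safe_load(config_path.read_text())
        if cfg.get("backend") == "local":
            return LocalStorageBackend(cfg["local"]["base_path"])
    if os.environ.get("STORAGE_BACKEND", "local") == "local":
        return LocalStorageBackend(os.environ["STORAGE_BASE_PATH"])
    raise ValueError("unsupported storage backend")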
class TestStorageBackendInDocumentUpload:
"""Tests for storage backend usage in document upload."""
def test_upload_document_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document upload uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Raw PDF bytes standing in for an uploaded file
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
# Upload should use storage backend
result = service.upload_document(
content=pdf_content,
filename="test.pdf",
dataset_id="dataset-1",
)
assert result is not None
# Verify file was stored via storage backend
assert backend.exists(f"documents/{result.id}.pdf")
def test_upload_document_stores_logical_path(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document stores logical path, not absolute path."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
result = service.upload_document(
content=pdf_content,
filename="test.pdf",
dataset_id="dataset-1",
)
# Path should be logical (relative), not absolute
assert not result.file_path.startswith("/")
assert not result.file_path.startswith("C:")
assert result.file_path.startswith("documents/")
class TestStorageBackendInDocumentDownload:
"""Tests for storage backend usage in document download/serving."""
def test_get_document_url_returns_presigned_url(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document URL uses presigned URL from storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test file
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
url = service.get_document_url(doc_path)
# Should return a URL (file:// for local, https:// for cloud)
assert url is not None
assert "test-doc.pdf" in url
def test_download_document_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document download uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test file
doc_path = "documents/test-doc.pdf"
original_content = b"%PDF-1.4 test content"
backend.upload_bytes(original_content, doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
content = service.download_document(doc_path)
assert content == original_content
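The URL test only requires a string that names the stored object, while the download test requires byte-for-byte round-tripping. A sketch assuming the backend exposes get_presigned_url and download_bytes — neither name appears in these tests, so both are placeholders:
def document_url(backend, logical_path: str) -> str:
    """Hand out a direct URL (file:// locally, https:// on cloud backends)."""
    return backend.get_presigned_url(logical_path)  # assumed backend method

def fetch_document(backend, logical_path: str) -> bytes:
    """Return the stored bytes exactly as uploaded."""
    return backend.download_bytes(logical_path)  # assumed backend method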
class TestStorageBackendInImageServing:
"""Tests for storage backend usage in image serving."""
def test_get_page_image_url_returns_presigned_url(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that page image URL uses presigned URL."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create a test image
image_path = "images/doc-123/page_1.png"
backend.upload_bytes(b"fake png content", image_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
url = service.get_page_image_url("doc-123", 1)
assert url is not None
assert "page_1.png" in url
def test_save_page_image_uses_storage_backend(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that page image saving uses storage backend."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
image_content = b"fake png content"
service.save_page_image("doc-123", 1, image_content)
# Verify image was stored
assert backend.exists("images/doc-123/page_1.png")
class TestStorageBackendInDocumentDeletion:
"""Tests for storage backend usage in document deletion."""
def test_delete_document_removes_from_storage(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document deletion removes file from storage."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create test files
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
service.delete_document_files(doc_path)
assert not backend.exists(doc_path)
def test_delete_document_removes_images(
self, tmp_path: Path, mock_admin_db: MagicMock
) -> None:
"""Test that document deletion removes associated images."""
from shared.storage.local import LocalStorageBackend
storage_path = tmp_path / "storage"
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
# Create test files
doc_id = "test-doc-123"
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
from inference.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
service.delete_document_images(doc_id)
assert not backend.exists(f"images/{doc_id}/page_1.png")
assert not backend.exists(f"images/{doc_id}/page_2.png")
@pytest.fixture
def mock_admin_db() -> MagicMock:
"""Create a mock AdminDB for testing."""
mock = MagicMock()
mock.get_document.return_value = None
mock.create_document.return_value = MagicMock(
id="test-doc-id",
file_path="documents/test-doc-id.pdf",
)
return mock

View File

@@ -103,6 +103,31 @@ class MockAnnotation:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockModelVersion:
"""Mock ModelVersion for testing."""
def __init__(self, **kwargs):
self.version_id = kwargs.get('version_id', uuid4())
self.version = kwargs.get('version', '1.0.0')
self.name = kwargs.get('name', 'Test Model')
self.description = kwargs.get('description', None)
self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
self.status = kwargs.get('status', 'inactive')
self.is_active = kwargs.get('is_active', False)
self.task_id = kwargs.get('task_id', None)
self.dataset_id = kwargs.get('dataset_id', None)
self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
self.metrics_precision = kwargs.get('metrics_precision', 0.92)
self.metrics_recall = kwargs.get('metrics_recall', 0.88)
self.document_count = kwargs.get('document_count', 100)
self.training_config = kwargs.get('training_config', {})
self.file_size = kwargs.get('file_size', 52428800)
self.trained_at = kwargs.get('trained_at', datetime.utcnow())
self.activated_at = kwargs.get('activated_at', None)
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
@@ -111,6 +136,7 @@ class MockAdminDB:
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
def get_documents_for_training(
self,
@@ -174,6 +200,14 @@ class MockAdminDB:
"""Get training task by ID.""" """Get training task by ID."""
return self.training_tasks.get(str(task_id)) return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
models = [m for m in models if m.status == status]
total = len(models)
return models[offset:offset+limit], total
@pytest.fixture
def app():
@@ -241,6 +275,30 @@ def app():
)
mock_db.training_links[str(doc1.document_id)] = [link1]
# Add model versions
model1 = MockModelVersion(
version="1.0.0",
name="Model v1.0.0",
status="inactive",
is_active=False,
metrics_mAP=0.935,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
)
model2 = MockModelVersion(
version="1.1.0",
name="Model v1.1.0",
status="active",
is_active=True,
metrics_mAP=0.951,
metrics_precision=0.94,
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
class TestTrainingModels:
"""Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
def test_get_training_models_success(self, client):
"""Test getting model versions list."""
response = client.get("/admin/training/models")
assert response.status_code == 200
@@ -338,43 +396,44 @@ class TestTrainingModels:
assert len(data["models"]) == 2
def test_get_training_models_includes_metrics(self, client):
"""Test that model versions include metrics."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check first model has metrics fields
model = data["models"][0]
assert "metrics_mAP" in model
assert model["metrics_mAP"] is not None
def test_get_training_models_includes_version_fields(self, client):
"""Test that model versions include version fields."""
response = client.get("/admin/training/models")
assert response.status_code == 200
data = response.json()
# Check model has expected fields
model = data["models"][0]
assert "version_id" in model
assert "version" in model
assert "name" in model
assert "status" in model
assert "is_active" in model
assert "document_count" in model
def test_get_training_models_filter_by_status(self, client):
"""Test filtering model versions by status."""
response = client.get("/admin/training/models?status=active")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
# All returned models should be active
for model in data["models"]:
assert model["status"] == "active"
def test_get_training_models_pagination(self, client):
"""Test pagination for model versions."""
response = client.get("/admin/training/models?limit=1&offset=0")
assert response.status_code == 200