WIP
@@ -90,7 +90,23 @@
|
|||||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
|
||||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
|
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
|
||||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
|
||||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")"
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")",
|
||||||
|
"Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")",
|
||||||
|
"Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")",
|
||||||
|
"Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")",
|
||||||
|
"Bash(npm run dev:*)",
|
||||||
|
"Bash(ping:*)",
|
||||||
|
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")",
|
||||||
|
"Bash(git checkout:*)",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||||
|
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": [],
|
"ask": [],
|
||||||
|
|||||||
17
.env.example
@@ -8,6 +8,23 @@ DB_NAME=docmaster
|
|||||||
DB_USER=docmaster
|
DB_USER=docmaster
|
||||||
DB_PASSWORD=your_password_here
|
DB_PASSWORD=your_password_here
|
||||||
|
|
||||||
|
# Storage Configuration
|
||||||
|
# Backend type: local, azure_blob, or s3
|
||||||
|
# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.)
|
||||||
|
STORAGE_BACKEND=local
|
||||||
|
STORAGE_BASE_PATH=./data
|
||||||
|
|
||||||
|
# Azure Blob Storage (when STORAGE_BACKEND=azure_blob)
|
||||||
|
# AZURE_STORAGE_CONNECTION_STRING=your_connection_string
|
||||||
|
# AZURE_STORAGE_CONTAINER=documents
|
||||||
|
|
||||||
|
# AWS S3 Storage (when STORAGE_BACKEND=s3)
|
||||||
|
# AWS_S3_BUCKET=your_bucket_name
|
||||||
|
# AWS_REGION=us-east-1
|
||||||
|
# AWS_ACCESS_KEY_ID=your_access_key
|
||||||
|
# AWS_SECRET_ACCESS_KEY=your_secret_key
|
||||||
|
# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO
|
||||||
|
|
||||||
# Model Configuration (optional)
|
# Model Configuration (optional)
|
||||||
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
||||||
# CONFIDENCE_THRESHOLD=0.5
|
# CONFIDENCE_THRESHOLD=0.5
|
||||||
|
|||||||
239
README.md
@@ -7,8 +7,9 @@
|
|||||||
本项目实现了一个完整的发票字段自动提取流程:
|
本项目实现了一个完整的发票字段自动提取流程:
|
||||||
|
|
||||||
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
||||||
2. **模型训练**: 使用 YOLOv11 训练字段检测模型
|
2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
|
||||||
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
|
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
|
||||||
|
4. **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
|
||||||
|
|
||||||
### 架构
|
### 架构
|
||||||
|
|
||||||
@@ -16,15 +17,17 @@
|
|||||||
|
|
||||||
```
|
```
|
||||||
packages/
|
packages/
|
||||||
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 工具)
|
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
|
||||||
├── training/ # 训练服务 (GPU, 按需启动)
|
├── training/ # 训练服务 (GPU, 按需启动)
|
||||||
└── inference/ # 推理服务 (常驻运行)
|
└── inference/ # 推理服务 (常驻运行)
|
||||||
|
frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||||
```
|
```
|
||||||
|
|
||||||
| 服务 | 部署目标 | GPU | 生命周期 |
|
| 服务 | 部署目标 | GPU | 生命周期 |
|
||||||
|------|---------|-----|---------|
|
|------|---------|-----|---------|
|
||||||
| **Inference** | Azure App Service | 可选 | 常驻 7x24 |
|
| **Frontend** | Vercel / Nginx | 否 | 常驻 |
|
||||||
| **Training** | Azure ACI | 必需 | 按需启动/销毁 |
|
| **Inference** | Azure App Service / AWS | 可选 | 常驻 7x24 |
|
||||||
|
| **Training** | Azure ACI / AWS ECS | 必需 | 按需启动/销毁 |
|
||||||
|
|
||||||
两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。
|
两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。
|
||||||
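
The paragraph above describes the contract between the two services: the inference API inserts a training task into the shared PostgreSQL database, and the training service picks it up and executes it. Below is a minimal sketch of what that pick-up loop could look like. The `training_tasks` table and the `pending`/`running` statuses come from this README; the column names (`task_id`, `config`, `created_at`) and the `run_training()` stub are illustrative assumptions, not the project's actual code.

```python
# Minimal sketch of the training-side polling loop (illustrative only).
import os
import time

import psycopg2


def run_training(task_id, config):
    """Placeholder for the real YOLO training entry point."""
    print(f"training task {task_id} with config {config}")


def claim_pending_task(conn):
    """Atomically claim one pending task and mark it as running."""
    with conn.cursor() as cur:
        cur.execute(
            """
            UPDATE training_tasks
               SET status = 'running'
             WHERE task_id = (SELECT task_id FROM training_tasks
                               WHERE status = 'pending'
                               ORDER BY created_at
                               LIMIT 1
                               FOR UPDATE SKIP LOCKED)
            RETURNING task_id, config
            """
        )
        row = cur.fetchone()
    conn.commit()
    return row  # None when there is nothing to do


def main():
    conn = psycopg2.connect(
        host=os.getenv("DB_HOST", "localhost"),
        dbname=os.getenv("DB_NAME", "docmaster"),
        user=os.getenv("DB_USER", "docmaster"),
        password=os.environ["DB_PASSWORD"],
    )
    while True:
        task = claim_pending_task(conn)
        if task is None:
            time.sleep(10)  # idle, poll again later
            continue
        run_training(*task)


if __name__ == "__main__":
    main()
```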
|
|
||||||
@@ -34,7 +37,8 @@ packages/
|
|||||||
|------|------|
|
|------|------|
|
||||||
| **已标注文档** | 9,738 (9,709 成功) |
|
| **已标注文档** | 9,738 (9,709 成功) |
|
||||||
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
||||||
| **测试** | 922 passed |
|
| **测试** | 1,601 passed |
|
||||||
|
| **测试覆盖率** | 28% |
|
||||||
| **模型 mAP@0.5** | 93.5% |
|
| **模型 mAP@0.5** | 93.5% |
|
||||||
|
|
||||||
**各字段匹配率:**
|
**各字段匹配率:**
|
||||||
@@ -97,6 +101,9 @@ invoice-master-poc-v2/
|
|||||||
│ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析
|
│ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析
|
||||||
│ │ ├── normalize/ # 字段规范化 (10 种 normalizer)
|
│ │ ├── normalize/ # 字段规范化 (10 种 normalizer)
|
||||||
│ │ ├── matcher/ # 字段匹配 (精确/子串/模糊)
|
│ │ ├── matcher/ # 字段匹配 (精确/子串/模糊)
|
||||||
|
│ │ ├── storage/ # 存储抽象层 (Local/Azure/S3)
|
||||||
|
│ │ ├── training/ # 共享训练组件 (YOLOTrainer)
|
||||||
|
│ │ ├── augmentation/ # 数据增强 (DatasetAugmenter)
|
||||||
│ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配)
|
│ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配)
|
||||||
│ │ ├── data/ # DocumentDB, CSVLoader
|
│ │ ├── data/ # DocumentDB, CSVLoader
|
||||||
│ │ ├── config.py # 全局配置 (数据库, 路径, DPI)
|
│ │ ├── config.py # 全局配置 (数据库, 路径, DPI)
|
||||||
@@ -129,12 +136,29 @@ invoice-master-poc-v2/
|
|||||||
│ ├── data/ # AdminDB, AsyncRequestDB, Models
|
│ ├── data/ # AdminDB, AsyncRequestDB, Models
|
||||||
│ └── azure/ # ACI 训练触发器
|
│ └── azure/ # ACI 训练触发器
|
||||||
│
|
│
|
||||||
├── migrations/ # 数据库迁移
|
├── frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||||
│ ├── 001_async_tables.sql
|
│ ├── src/
|
||||||
│ ├── 002_nullable_admin_token.sql
|
│ │ ├── api/ # API 客户端 (axios + react-query)
|
||||||
│ └── 003_training_tasks.sql
|
│ │ ├── components/ # UI 组件
|
||||||
├── frontend/ # React 前端 (Vite + TypeScript)
|
│ │ │ ├── Dashboard.tsx # 文档管理面板
|
||||||
├── tests/ # 测试 (922 tests)
|
│ │ │ ├── Training.tsx # 训练管理 (数据集/任务)
|
||||||
|
│ │ │ ├── Models.tsx # 模型版本管理
|
||||||
|
│ │ │ ├── DatasetDetail.tsx # 数据集详情
|
||||||
|
│ │ │ └── InferenceDemo.tsx # 推理演示
|
||||||
|
│ │ └── hooks/ # React Query hooks
|
||||||
|
│ └── package.json
|
||||||
|
│
|
||||||
|
├── migrations/ # 数据库迁移 (SQL)
|
||||||
|
│ ├── 003_training_tasks.sql
|
||||||
|
│ ├── 004_training_datasets.sql
|
||||||
|
│ ├── 005_add_group_key.sql
|
||||||
|
│ ├── 006_model_versions.sql
|
||||||
|
│ ├── 007_training_tasks_extra_columns.sql
|
||||||
|
│ ├── 008_fix_model_versions_fk.sql
|
||||||
|
│ ├── 009_add_document_category.sql
|
||||||
|
│ └── 010_add_dataset_training_status.sql
|
||||||
|
│
|
||||||
|
├── tests/ # 测试 (1,601 tests)
|
||||||
├── docker-compose.yml # 本地开发 (postgres + inference + training)
|
├── docker-compose.yml # 本地开发 (postgres + inference + training)
|
||||||
├── run_server.py # 快捷启动脚本
|
├── run_server.py # 快捷启动脚本
|
||||||
└── runs/train/ # 训练输出 (weights, curves)
|
└── runs/train/ # 训练输出 (weights, curves)
|
||||||
@@ -270,9 +294,32 @@ Inference API PostgreSQL Training (ACI)
|
|||||||
| POST | `/api/v1/admin/documents/upload` | 上传 PDF |
|
| POST | `/api/v1/admin/documents/upload` | 上传 PDF |
|
||||||
| GET | `/api/v1/admin/documents/{id}` | 文档详情 |
|
| GET | `/api/v1/admin/documents/{id}` | 文档详情 |
|
||||||
| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 |
|
| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 |
|
||||||
|
| PATCH | `/api/v1/admin/documents/{id}/category` | 更新文档分类 |
|
||||||
|
| GET | `/api/v1/admin/documents/categories` | 获取分类列表 |
|
||||||
| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 |
|
| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 |
|
||||||
| POST | `/api/v1/admin/training/trigger` | 触发训练任务 |
|
|
||||||
| GET | `/api/v1/admin/training/{id}/status` | 查询训练状态 |
|
**Training API:**
|
||||||
|
|
||||||
|
| 方法 | 端点 | 描述 |
|
||||||
|
|------|------|------|
|
||||||
|
| POST | `/api/v1/admin/training/datasets` | 创建数据集 |
|
||||||
|
| GET | `/api/v1/admin/training/datasets` | 数据集列表 |
|
||||||
|
| GET | `/api/v1/admin/training/datasets/{id}` | 数据集详情 |
|
||||||
|
| DELETE | `/api/v1/admin/training/datasets/{id}` | 删除数据集 |
|
||||||
|
| POST | `/api/v1/admin/training/tasks` | 创建训练任务 |
|
||||||
|
| GET | `/api/v1/admin/training/tasks` | 任务列表 |
|
||||||
|
| GET | `/api/v1/admin/training/tasks/{id}` | 任务详情 |
|
||||||
|
| GET | `/api/v1/admin/training/tasks/{id}/logs` | 训练日志 |
|
||||||
|
|
||||||
|
**Model Versions API:**
|
||||||
|
|
||||||
|
| 方法 | 端点 | 描述 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | `/api/v1/admin/models` | 模型版本列表 |
|
||||||
|
| GET | `/api/v1/admin/models/{id}` | 模型详情 |
|
||||||
|
| POST | `/api/v1/admin/models/{id}/activate` | 激活模型 |
|
||||||
|
| POST | `/api/v1/admin/models/{id}/archive` | 归档模型 |
|
||||||
|
| DELETE | `/api/v1/admin/models/{id}` | 删除模型 |
|
||||||
|
|
||||||
## Python API
|
## Python API
|
||||||
|
|
||||||
@@ -332,8 +379,41 @@ print(f"Customer Number: {result}") # "UMJ 436-R"
|
|||||||
|
|
||||||
| 数据库 | 用途 | 存储内容 |
|
| 数据库 | 用途 | 存储内容 |
|
||||||
|--------|------|----------|
|
|--------|------|----------|
|
||||||
| **PostgreSQL** | 标注结果 | `documents`, `field_results`, `training_tasks` |
|
| **PostgreSQL** | 主数据库 | 文档、标注、训练任务、数据集、模型版本 |
|
||||||
| **SQLite** (AdminDB) | Web 应用 | 文档管理, 标注编辑, 用户认证 |
|
|
||||||
|
### 主要表
|
||||||
|
|
||||||
|
| 表名 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `admin_documents` | 文档管理 (PDF 元数据, 状态, 分类) |
|
||||||
|
| `admin_annotations` | 标注数据 (YOLO 格式边界框) |
|
||||||
|
| `training_tasks` | 训练任务 (状态, 配置, 指标) |
|
||||||
|
| `training_datasets` | 数据集 (train/val/test 分割) |
|
||||||
|
| `dataset_documents` | 数据集-文档关联 |
|
||||||
|
| `model_versions` | 模型版本管理 (激活/归档) |
|
||||||
|
| `admin_tokens` | 管理员认证令牌 |
|
||||||
|
| `async_requests` | 异步推理请求 |
|
||||||
|
|
||||||
|
### 数据集状态
|
||||||
|
|
||||||
|
| 状态 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `building` | 正在构建数据集 |
|
||||||
|
| `ready` | 数据集就绪,可开始训练 |
|
||||||
|
| `trained` | 已完成训练 |
|
||||||
|
| `failed` | 构建失败 |
|
||||||
|
| `archived` | 已归档 |
|
||||||
|
|
||||||
|
### 训练状态
|
||||||
|
|
||||||
|
| 状态 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `pending` | 等待执行 |
|
||||||
|
| `scheduled` | 已计划 |
|
||||||
|
| `running` | 正在训练 |
|
||||||
|
| `completed` | 训练完成 |
|
||||||
|
| `failed` | 训练失败 |
|
||||||
|
| `cancelled` | 已取消 |
|
||||||
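
For reference, the two status sets above map naturally onto string enums. This is only an illustrative sketch; the project's actual SQLModel models may spell these differently.

```python
# Illustrative enums for the status values listed above (not the project's actual models).
from enum import Enum


class DatasetStatus(str, Enum):
    BUILDING = "building"
    READY = "ready"
    TRAINED = "trained"
    FAILED = "failed"
    ARCHIVED = "archived"


class TrainingStatus(str, Enum):
    PENDING = "pending"
    SCHEDULED = "scheduled"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
```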
|
|
||||||
## 测试
|
## 测试
|
||||||
|
|
||||||
@@ -347,8 +427,114 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
|
|||||||
|
|
||||||
| 指标 | 数值 |
|
| 指标 | 数值 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| **测试总数** | 922 |
|
| **测试总数** | 1,601 |
|
||||||
| **通过率** | 100% |
|
| **通过率** | 100% |
|
||||||
|
| **覆盖率** | 28% |
|
||||||
|
|
||||||
|
## 存储抽象层
|
||||||
|
|
||||||
|
统一的文件存储接口,支持多后端切换:
|
||||||
|
|
||||||
|
| 后端 | 用途 | 安装 |
|
||||||
|
|------|------|------|
|
||||||
|
| **Local** | 本地开发/测试 | 默认 |
|
||||||
|
| **Azure Blob** | Azure 云部署 | `pip install -e "packages/shared[azure]"` |
|
||||||
|
| **AWS S3** | AWS 云部署 | `pip install -e "packages/shared[s3]"` |
|
||||||
|
|
||||||
|
### 配置文件 (storage.yaml)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
backend: ${STORAGE_BACKEND:-local}
|
||||||
|
presigned_url_expiry: 3600
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: ${STORAGE_BASE_PATH:-./data/storage}
|
||||||
|
|
||||||
|
azure:
|
||||||
|
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
|
||||||
|
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
|
||||||
|
|
||||||
|
s3:
|
||||||
|
bucket_name: ${AWS_S3_BUCKET}
|
||||||
|
region_name: ${AWS_REGION:-us-east-1}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
from shared.storage import get_storage_backend
|
||||||
|
|
||||||
|
# 从配置文件加载
|
||||||
|
storage = get_storage_backend("storage.yaml")
|
||||||
|
|
||||||
|
# 上传文件
|
||||||
|
storage.upload(Path("local.pdf"), "documents/invoice.pdf")
|
||||||
|
|
||||||
|
# 获取预签名 URL (前端访问)
|
||||||
|
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=3600)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 环境变量
|
||||||
|
|
||||||
|
| 变量 | 后端 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `STORAGE_BACKEND` | 全部 | `local`, `azure_blob`, `s3` |
|
||||||
|
| `STORAGE_BASE_PATH` | Local | 本地存储路径 |
|
||||||
|
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | 连接字符串 |
|
||||||
|
| `AZURE_STORAGE_CONTAINER` | Azure | 容器名称 |
|
||||||
|
| `AWS_S3_BUCKET` | S3 | 存储桶名称 |
|
||||||
|
| `AWS_REGION` | S3 | 区域 (默认: us-east-1) |
|
||||||
|
|
||||||
|
## 数据增强
|
||||||
|
|
||||||
|
训练时支持多种数据增强策略:
|
||||||
|
|
||||||
|
| 增强类型 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| `perspective_warp` | 透视变换 (模拟扫描角度) |
|
||||||
|
| `wrinkle` | 皱纹效果 |
|
||||||
|
| `edge_damage` | 边缘损坏 |
|
||||||
|
| `stain` | 污渍效果 |
|
||||||
|
| `lighting_variation` | 光照变化 |
|
||||||
|
| `shadow` | 阴影效果 |
|
||||||
|
| `gaussian_blur` | 高斯模糊 |
|
||||||
|
| `motion_blur` | 运动模糊 |
|
||||||
|
| `gaussian_noise` | 高斯噪声 |
|
||||||
|
| `salt_pepper` | 椒盐噪声 |
|
||||||
|
| `paper_texture` | 纸张纹理 |
|
||||||
|
| `scanner_artifacts` | 扫描伪影 |
|
||||||
|
|
||||||
|
增强配置示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"augmentation": {
|
||||||
|
"gaussian_blur": { "enabled": true, "kernel_size": 5 },
|
||||||
|
"perspective_warp": { "enabled": true, "intensity": 0.1 }
|
||||||
|
},
|
||||||
|
"augmentation_multiplier": 2
|
||||||
|
}
|
||||||
|
```
|
||||||
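
To make the table above concrete, here is a hedged sketch of two of the listed effects (`gaussian_blur` and `perspective_warp`) implemented with OpenCV. The project's own `DatasetAugmenter` API is not documented here, so the function signatures and parameter names below are illustrative only; in a real pipeline the YOLO boxes must also be transformed with the same perspective matrix.

```python
# Illustrative implementations of two augmentations from the table above.
import cv2
import numpy as np


def gaussian_blur(image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Blur with an odd kernel size, matching the kernel_size in the JSON config."""
    return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)


def perspective_warp(image: np.ndarray, intensity: float = 0.1) -> np.ndarray:
    """Randomly shift the four corners by up to `intensity` of the image size
    to simulate a skewed scan. Labels must be warped with the same matrix."""
    h, w = image.shape[:2]
    src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    jitter = np.random.uniform(-intensity, intensity, size=(4, 2)) * [w, h]
    dst = (src + jitter).astype(np.float32)
    matrix = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(image, matrix, (w, h), borderValue=(255, 255, 255))


if __name__ == "__main__":
    img = cv2.imread("page.png")  # any rendered invoice page
    cv2.imwrite("page_aug.png", perspective_warp(gaussian_blur(img)))
```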
|
|
||||||
|
## 前端功能
|
||||||
|
|
||||||
|
React 前端提供以下功能模块:
|
||||||
|
|
||||||
|
| 模块 | 功能 |
|
||||||
|
|------|------|
|
||||||
|
| **Dashboard** | 文档列表、上传、标注状态管理、分类筛选 |
|
||||||
|
| **Training** | 数据集创建/管理、训练任务配置、增强设置 |
|
||||||
|
| **Models** | 模型版本管理、激活/归档、指标查看 |
|
||||||
|
| **Inference Demo** | 实时推理演示、结果可视化 |
|
||||||
|
|
||||||
|
### 启动前端
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
npm run dev
|
||||||
|
# 访问 http://localhost:5173
|
||||||
|
```
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
@@ -357,10 +543,27 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
|
|||||||
| **目标检测** | YOLOv11 (Ultralytics) |
|
| **目标检测** | YOLOv11 (Ultralytics) |
|
||||||
| **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) |
|
| **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) |
|
||||||
| **PDF 处理** | PyMuPDF (fitz) |
|
| **PDF 处理** | PyMuPDF (fitz) |
|
||||||
| **数据库** | PostgreSQL + psycopg2 |
|
| **数据库** | PostgreSQL + SQLModel |
|
||||||
| **Web 框架** | FastAPI + Uvicorn |
|
| **Web 框架** | FastAPI + Uvicorn |
|
||||||
|
| **前端** | React + TypeScript + Vite + TailwindCSS |
|
||||||
|
| **状态管理** | React Query (TanStack Query) |
|
||||||
| **深度学习** | PyTorch + CUDA 12.x |
|
| **深度学习** | PyTorch + CUDA 12.x |
|
||||||
| **部署** | Docker + Azure ACI (训练) / App Service (推理) |
|
| **部署** | Docker + Azure/AWS (训练) / App Service (推理) |
|
||||||
|
|
||||||
|
## 环境变量
|
||||||
|
|
||||||
|
| 变量 | 必需 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `DB_PASSWORD` | 是 | PostgreSQL 密码 |
|
||||||
|
| `DB_HOST` | 否 | 数据库主机 (默认: localhost) |
|
||||||
|
| `DB_PORT` | 否 | 数据库端口 (默认: 5432) |
|
||||||
|
| `DB_NAME` | 否 | 数据库名 (默认: docmaster) |
|
||||||
|
| `DB_USER` | 否 | 数据库用户 (默认: docmaster) |
|
||||||
|
| `STORAGE_BASE_PATH` | 否 | 存储路径 (默认: ~/invoice-data/data) |
|
||||||
|
| `MODEL_PATH` | 否 | 模型路径 |
|
||||||
|
| `CONFIDENCE_THRESHOLD` | 否 | 置信度阈值 (默认: 0.5) |
|
||||||
|
| `SERVER_HOST` | 否 | 服务器主机 (默认: 0.0.0.0) |
|
||||||
|
| `SERVER_PORT` | 否 | 服务器端口 (默认: 8000) |
|
||||||
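
A minimal sketch of reading these variables at startup, using the defaults from the table above; the project's real `shared/config.py` may be organized differently.

```python
# Reads the environment variables from the table above with their documented defaults.
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    db_password: str  # required, no default
    db_host: str = os.getenv("DB_HOST", "localhost")
    db_port: int = int(os.getenv("DB_PORT", "5432"))
    db_name: str = os.getenv("DB_NAME", "docmaster")
    db_user: str = os.getenv("DB_USER", "docmaster")
    storage_base_path: str = os.getenv("STORAGE_BASE_PATH", "~/invoice-data/data")
    model_path: str | None = os.getenv("MODEL_PATH")
    confidence_threshold: float = float(os.getenv("CONFIDENCE_THRESHOLD", "0.5"))
    server_host: str = os.getenv("SERVER_HOST", "0.0.0.0")
    server_port: int = int(os.getenv("SERVER_PORT", "8000"))


settings = Settings(db_password=os.environ["DB_PASSWORD"])
```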
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
|
|||||||
772
docs/aws-deployment-guide.md
Normal file
@@ -0,0 +1,772 @@
|
|||||||
|
# AWS 部署方案完整指南
|
||||||
|
|
||||||
|
## 目录
|
||||||
|
- [核心问题](#核心问题)
|
||||||
|
- [存储方案](#存储方案)
|
||||||
|
- [训练方案](#训练方案)
|
||||||
|
- [推理方案](#推理方案)
|
||||||
|
- [价格对比](#价格对比)
|
||||||
|
- [推荐架构](#推荐架构)
|
||||||
|
- [实施步骤](#实施步骤)
|
||||||
|
- [AWS vs Azure 对比](#aws-vs-azure-对比)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 核心问题
|
||||||
|
|
||||||
|
| 问题 | 答案 |
|
||||||
|
|------|------|
|
||||||
|
| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 |
|
||||||
|
| 能实时从 S3 读取训练吗? | 可以,SageMaker 支持 Pipe Mode 流式读取 |
|
||||||
|
| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone |
|
||||||
|
| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 |
|
||||||
|
| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda |
|
||||||
|
| 推理服务用什么? | Lambda (Serverless) 或 ECS/Fargate (容器) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 存储方案
|
||||||
|
|
||||||
|
### Amazon S3(推荐)
|
||||||
|
|
||||||
|
S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建 S3 桶
|
||||||
|
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||||
|
|
||||||
|
# 上传训练数据
|
||||||
|
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||||
|
|
||||||
|
# 创建目录结构
|
||||||
|
aws s3api put-object --bucket invoice-training-data --key datasets/
|
||||||
|
aws s3api put-object --bucket invoice-training-data --key models/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mountpoint for Amazon S3
|
||||||
|
|
||||||
|
AWS 官方的 S3 挂载客户端,性能优于 s3fs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 Mountpoint
|
||||||
|
wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
|
||||||
|
sudo dpkg -i mount-s3.deb
|
||||||
|
|
||||||
|
# 挂载 S3
|
||||||
|
mkdir -p /mnt/s3-data
|
||||||
|
mount-s3 invoice-training-data /mnt/s3-data --region us-east-1
|
||||||
|
|
||||||
|
# 配置缓存(推荐)
|
||||||
|
mount-s3 invoice-training-data /mnt/s3-data \
|
||||||
|
--region us-east-1 \
|
||||||
|
--cache /tmp/s3-cache \
|
||||||
|
--metadata-ttl 60
|
||||||
|
```
|
||||||
|
|
||||||
|
### 本地开发挂载
|
||||||
|
|
||||||
|
**Linux/Mac (s3fs-fuse):**
|
||||||
|
```bash
|
||||||
|
# 安装
|
||||||
|
sudo apt-get install s3fs
|
||||||
|
|
||||||
|
# 配置凭证
|
||||||
|
echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs
|
||||||
|
chmod 600 ~/.passwd-s3fs
|
||||||
|
|
||||||
|
# 挂载
|
||||||
|
s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs
|
||||||
|
```
|
||||||
|
|
||||||
|
**Windows (Rclone):**
|
||||||
|
```powershell
|
||||||
|
# 安装
|
||||||
|
winget install Rclone.Rclone
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
rclone config # 选择 s3
|
||||||
|
|
||||||
|
# 挂载
|
||||||
|
rclone mount aws:invoice-training-data Z: --vfs-cache-mode full
|
||||||
|
```
|
||||||
|
|
||||||
|
### 存储费用
|
||||||
|
|
||||||
|
| 层级 | 价格 | 适用场景 |
|
||||||
|
|------|------|---------|
|
||||||
|
| S3 Standard | $0.023/GB/月 | 频繁访问 |
|
||||||
|
| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 |
|
||||||
|
| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 |
|
||||||
|
| S3 Glacier | $0.004/GB/月 | 长期存档 |
|
||||||
|
|
||||||
|
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月**
|
||||||
|
|
||||||
|
### SageMaker 数据输入模式
|
||||||
|
|
||||||
|
| 模式 | 说明 | 适用场景 |
|
||||||
|
|------|------|---------|
|
||||||
|
| File Mode | 下载到本地再训练 | 小数据集 |
|
||||||
|
| Pipe Mode | 流式读取,不占本地空间 | 大数据集 |
|
||||||
|
| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 |
|
||||||
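
The input mode is selected per channel via `TrainingInput` in the SageMaker Python SDK. The snippet below requests FastFile mode for the bucket used throughout this guide; the channel names match the `estimator.fit()` call shown later in the training section.

```python
# Selecting FastFile mode for the training channels (SageMaker Python SDK).
from sagemaker.inputs import TrainingInput

train_input = TrainingInput(
    s3_data="s3://invoice-training-data/datasets/train/",
    input_mode="FastFile",  # stream objects on demand instead of downloading everything
)
val_input = TrainingInput(
    s3_data="s3://invoice-training-data/datasets/val/",
    input_mode="FastFile",
)

# `estimator` is the PyTorch estimator defined later in this guide:
# estimator.fit({"training": train_input, "validation": val_input})
```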
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 训练方案
|
||||||
|
|
||||||
|
### 方案总览
|
||||||
|
|
||||||
|
| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 |
|
||||||
|
|------|---------|---------|--------|----------|
|
||||||
|
| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 |
|
||||||
|
| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 |
|
||||||
|
| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 |
|
||||||
|
|
||||||
|
### EC2 vs SageMaker
|
||||||
|
|
||||||
|
| 特性 | EC2 | SageMaker |
|
||||||
|
|------|-----|-----------|
|
||||||
|
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||||
|
| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) |
|
||||||
|
| 管理开销 | 需自己配置 | 全托管 |
|
||||||
|
| Spot 折扣 | 最高 90% | 最高 90% |
|
||||||
|
| 实验跟踪 | 无 | 内置 |
|
||||||
|
| 自动关机 | 无 | 任务完成自动停止 |
|
||||||
|
|
||||||
|
### GPU 实例价格 (2025 年 6 月降价后)
|
||||||
|
|
||||||
|
| 实例 | GPU | 显存 | On-Demand | Spot 价格 |
|
||||||
|
|------|-----|------|-----------|----------|
|
||||||
|
| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr |
|
||||||
|
| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr |
|
||||||
|
| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr |
|
||||||
|
| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr |
|
||||||
|
| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr |
|
||||||
|
|
||||||
|
**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。
|
||||||
|
|
||||||
|
### Spot 实例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# EC2 Spot 请求
|
||||||
|
aws ec2 request-spot-instances \
|
||||||
|
--instance-count 1 \
|
||||||
|
--type "one-time" \
|
||||||
|
--launch-specification '{
|
||||||
|
"ImageId": "ami-0123456789abcdef0",
|
||||||
|
"InstanceType": "p3.2xlarge",
|
||||||
|
"KeyName": "my-key"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### SageMaker Managed Spot Training
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sagemaker.pytorch import PyTorch
|
||||||
|
|
||||||
|
estimator = PyTorch(
|
||||||
|
entry_point="train.py",
|
||||||
|
source_dir="./src",
|
||||||
|
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||||
|
instance_count=1,
|
||||||
|
instance_type="ml.p3.2xlarge",
|
||||||
|
framework_version="2.0",
|
||||||
|
py_version="py310",
|
||||||
|
|
||||||
|
# 启用 Spot 实例
|
||||||
|
use_spot_instances=True,
|
||||||
|
max_run=3600, # 最长运行 1 小时
|
||||||
|
max_wait=7200, # 最长等待 2 小时
|
||||||
|
|
||||||
|
# 检查点配置(Spot 中断恢复)
|
||||||
|
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||||
|
checkpoint_local_path="/opt/ml/checkpoints",
|
||||||
|
|
||||||
|
hyperparameters={
|
||||||
|
"epochs": 100,
|
||||||
|
"batch-size": 16,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
estimator.fit({
|
||||||
|
"training": "s3://invoice-training-data/datasets/train/",
|
||||||
|
"validation": "s3://invoice-training-data/datasets/val/"
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 推理方案
|
||||||
|
|
||||||
|
### 方案对比
|
||||||
|
|
||||||
|
| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 |
|
||||||
|
|------|---------|--------|--------|------|---------|
|
||||||
|
| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 |
|
||||||
|
| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 |
|
||||||
|
| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 |
|
||||||
|
| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 |
|
||||||
|
| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 |
|
||||||
|
| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 |
|
||||||
|
|
||||||
|
### 推荐方案 1: AWS Lambda (低流量)
|
||||||
|
|
||||||
|
对于 YOLO CPU 推理,Lambda 最经济:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# lambda_function.py
|
||||||
|
import json
|
||||||
|
import boto3
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# 模型在 Lambda Layer 或 /tmp 加载
|
||||||
|
model = None
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
global model
|
||||||
|
if model is None:
|
||||||
|
# 从 S3 下载模型到 /tmp
|
||||||
|
s3 = boto3.client('s3')
|
||||||
|
s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt')
|
||||||
|
model = YOLO('/tmp/best.pt')
|
||||||
|
return model
|
||||||
|
|
||||||
|
def lambda_handler(event, context):
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
|
# 从 S3 获取图片
|
||||||
|
s3 = boto3.client('s3')
|
||||||
|
bucket = event['bucket']
|
||||||
|
key = event['key']
|
||||||
|
|
||||||
|
local_path = f'/tmp/{key.split("/")[-1]}'
|
||||||
|
s3.download_file(bucket, key, local_path)
|
||||||
|
|
||||||
|
# 执行推理
|
||||||
|
results = model.predict(local_path, conf=0.5)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'statusCode': 200,
|
||||||
|
'body': json.dumps({
|
||||||
|
'fields': extract_fields(results),
|
||||||
|
'confidence': get_confidence(results)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Lambda 配置:**
|
||||||
|
```yaml
|
||||||
|
# serverless.yml
|
||||||
|
service: invoice-inference
|
||||||
|
|
||||||
|
provider:
|
||||||
|
name: aws
|
||||||
|
runtime: python3.11
|
||||||
|
timeout: 30
|
||||||
|
memorySize: 4096 # 4GB 内存
|
||||||
|
|
||||||
|
functions:
|
||||||
|
infer:
|
||||||
|
handler: lambda_function.lambda_handler
|
||||||
|
events:
|
||||||
|
- http:
|
||||||
|
path: /infer
|
||||||
|
method: post
|
||||||
|
layers:
|
||||||
|
- arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||||
|
```
|
||||||
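
Besides the API Gateway route, the deployed function can be invoked directly with boto3. The event keys match what `lambda_handler` above reads (`bucket`/`key`); the object key here is a placeholder.

```python
# Invoke the inference Lambda directly; the event shape matches lambda_handler above.
import json

import boto3

client = boto3.client("lambda", region_name="us-east-1")

response = client.invoke(
    FunctionName="invoice-inference",
    InvocationType="RequestResponse",  # synchronous call
    Payload=json.dumps({"bucket": "invoice-uploads", "key": "incoming/invoice.png"}),
)
result = json.loads(response["Payload"].read())
print(result["body"])  # {"fields": ..., "confidence": ...}
```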
|
|
||||||
|
### 推荐方案 2: ECS Fargate (中流量)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# task-definition.json
|
||||||
|
{
|
||||||
|
"family": "invoice-inference",
|
||||||
|
"networkMode": "awsvpc",
|
||||||
|
"requiresCompatibilities": ["FARGATE"],
|
||||||
|
"cpu": "2048",
|
||||||
|
"memory": "4096",
|
||||||
|
"containerDefinitions": [
|
||||||
|
{
|
||||||
|
"name": "inference",
|
||||||
|
"image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest",
|
||||||
|
"portMappings": [
|
||||||
|
{
|
||||||
|
"containerPort": 8000,
|
||||||
|
"protocol": "tcp"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"environment": [
|
||||||
|
{"name": "MODEL_PATH", "value": "/app/models/best.pt"}
|
||||||
|
],
|
||||||
|
"logConfiguration": {
|
||||||
|
"logDriver": "awslogs",
|
||||||
|
"options": {
|
||||||
|
"awslogs-group": "/ecs/invoice-inference",
|
||||||
|
"awslogs-region": "us-east-1",
|
||||||
|
"awslogs-stream-prefix": "ecs"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Auto Scaling 配置:**
|
||||||
|
```bash
|
||||||
|
# 创建 Auto Scaling Target
|
||||||
|
aws application-autoscaling register-scalable-target \
|
||||||
|
--service-namespace ecs \
|
||||||
|
--resource-id service/invoice-cluster/invoice-service \
|
||||||
|
--scalable-dimension ecs:service:DesiredCount \
|
||||||
|
--min-capacity 1 \
|
||||||
|
--max-capacity 10
|
||||||
|
|
||||||
|
# 基于 CPU 使用率扩缩容
|
||||||
|
aws application-autoscaling put-scaling-policy \
|
||||||
|
--service-namespace ecs \
|
||||||
|
--resource-id service/invoice-cluster/invoice-service \
|
||||||
|
--scalable-dimension ecs:service:DesiredCount \
|
||||||
|
--policy-name cpu-scaling \
|
||||||
|
--policy-type TargetTrackingScaling \
|
||||||
|
--target-tracking-scaling-policy-configuration '{
|
||||||
|
"TargetValue": 70,
|
||||||
|
"PredefinedMetricSpecification": {
|
||||||
|
"PredefinedMetricType": "ECSServiceAverageCPUUtilization"
|
||||||
|
},
|
||||||
|
"ScaleOutCooldown": 60,
|
||||||
|
"ScaleInCooldown": 120
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方案 3: SageMaker Serverless Inference
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sagemaker.serverless import ServerlessInferenceConfig
|
||||||
|
from sagemaker.pytorch import PyTorchModel
|
||||||
|
|
||||||
|
model = PyTorchModel(
|
||||||
|
model_data="s3://invoice-models/model.tar.gz",
|
||||||
|
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||||
|
entry_point="inference.py",
|
||||||
|
framework_version="2.0",
|
||||||
|
py_version="py310"
|
||||||
|
)
|
||||||
|
|
||||||
|
serverless_config = ServerlessInferenceConfig(
|
||||||
|
memory_size_in_mb=4096,
|
||||||
|
max_concurrency=10
|
||||||
|
)
|
||||||
|
|
||||||
|
predictor = model.deploy(
|
||||||
|
serverless_inference_config=serverless_config,
|
||||||
|
endpoint_name="invoice-inference-serverless"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 推理性能对比
|
||||||
|
|
||||||
|
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|
||||||
|
|------|------------|---------|---------|
|
||||||
|
| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) |
|
||||||
|
| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 |
|
||||||
|
| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 |
|
||||||
|
| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 价格对比
|
||||||
|
|
||||||
|
### 训练成本对比(假设每天训练 2 小时)
|
||||||
|
|
||||||
|
| 方案 | 计算方式 | 月费 |
|
||||||
|
|------|---------|------|
|
||||||
|
| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
|
||||||
|
| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 |
|
||||||
|
| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
|
||||||
|
| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 |
|
||||||
|
| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 |
|
||||||
|
|
||||||
|
### 本项目完整成本估算
|
||||||
|
|
||||||
|
| 组件 | 推荐方案 | 月费 |
|
||||||
|
|------|---------|------|
|
||||||
|
| 数据存储 | S3 Standard (5GB) | ~$0.12 |
|
||||||
|
| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 |
|
||||||
|
| 推理服务 | Lambda (10K 请求/月) | ~$15 |
|
||||||
|
| 推理服务 (替代) | ECS Fargate | ~$30 |
|
||||||
|
| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 |
|
||||||
|
| ECR (镜像存储) | 基本使用 | ~$1 |
|
||||||
|
| **总计 (Lambda)** | | **~$35/月** + 训练费 |
|
||||||
|
| **总计 (Fargate)** | | **~$50/月** + 训练费 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 推荐架构
|
||||||
|
|
||||||
|
### 整体架构图
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ Amazon S3 │
|
||||||
|
│ ├── training-images/ │
|
||||||
|
│ ├── datasets/ │
|
||||||
|
│ ├── models/ │
|
||||||
|
│ └── checkpoints/ │
|
||||||
|
└─────────────────┬───────────────────┘
|
||||||
|
│
|
||||||
|
┌─────────────────────────────────┼─────────────────────────────────┐
|
||||||
|
│ │ │
|
||||||
|
▼ ▼ ▼
|
||||||
|
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ 推理服务 │ │ 训练服务 │ │ API Gateway │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ 方案 A: Lambda │ │ SageMaker │ │ REST API │
|
||||||
|
│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │
|
||||||
|
│ │ │ ~$2-5/次训练 │ │ │
|
||||||
|
│ 方案 B: ECS Fargate │ │ │ │ │
|
||||||
|
│ ~$30/月 │ │ - 自动启动 │ │ │
|
||||||
|
│ │ │ - 训练完成自动停止 │ │ │
|
||||||
|
│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │
|
||||||
|
│ │ FastAPI + YOLO │ │ │ │ │ │
|
||||||
|
│ │ CPU 推理 │ │ │ │ │ │
|
||||||
|
│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘
|
||||||
|
└───────────┬───────────┘ │
|
||||||
|
│ │
|
||||||
|
└───────────────────────────────┼───────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────┐
|
||||||
|
│ Amazon RDS │
|
||||||
|
│ PostgreSQL │
|
||||||
|
│ db.t3.micro │
|
||||||
|
│ ~$15/月 │
|
||||||
|
└───────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lambda 推理配置
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# SAM template
|
||||||
|
AWSTemplateFormatVersion: '2010-09-09'
|
||||||
|
Transform: AWS::Serverless-2016-10-31
|
||||||
|
|
||||||
|
Resources:
|
||||||
|
InferenceFunction:
|
||||||
|
Type: AWS::Serverless::Function
|
||||||
|
Properties:
|
||||||
|
Handler: app.lambda_handler
|
||||||
|
Runtime: python3.11
|
||||||
|
MemorySize: 4096
|
||||||
|
Timeout: 30
|
||||||
|
Environment:
|
||||||
|
Variables:
|
||||||
|
MODEL_BUCKET: invoice-models
|
||||||
|
MODEL_KEY: best.pt
|
||||||
|
Policies:
|
||||||
|
- S3ReadPolicy:
|
||||||
|
BucketName: invoice-models
|
||||||
|
- S3ReadPolicy:
|
||||||
|
BucketName: invoice-uploads
|
||||||
|
Events:
|
||||||
|
InferApi:
|
||||||
|
Type: Api
|
||||||
|
Properties:
|
||||||
|
Path: /infer
|
||||||
|
Method: post
|
||||||
|
```
|
||||||
|
|
||||||
|
### SageMaker 训练配置
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sagemaker.pytorch import PyTorch
|
||||||
|
|
||||||
|
estimator = PyTorch(
|
||||||
|
entry_point="train.py",
|
||||||
|
source_dir="./src",
|
||||||
|
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||||
|
instance_count=1,
|
||||||
|
instance_type="ml.g4dn.xlarge", # T4 GPU
|
||||||
|
framework_version="2.0",
|
||||||
|
py_version="py310",
|
||||||
|
|
||||||
|
# Spot 实例配置
|
||||||
|
use_spot_instances=True,
|
||||||
|
max_run=7200,
|
||||||
|
max_wait=14400,
|
||||||
|
|
||||||
|
# 检查点
|
||||||
|
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||||
|
|
||||||
|
hyperparameters={
|
||||||
|
"epochs": 100,
|
||||||
|
"batch-size": 16,
|
||||||
|
"model": "yolo11n.pt"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实施步骤
|
||||||
|
|
||||||
|
### 阶段 1: 存储设置
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建 S3 桶
|
||||||
|
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||||
|
aws s3 mb s3://invoice-models --region us-east-1
|
||||||
|
|
||||||
|
# 上传训练数据
|
||||||
|
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||||
|
|
||||||
|
# 配置生命周期(可选,自动转冷存储)
|
||||||
|
aws s3api put-bucket-lifecycle-configuration \
|
||||||
|
--bucket invoice-training-data \
|
||||||
|
--lifecycle-configuration '{
|
||||||
|
"Rules": [{
|
||||||
|
"ID": "MoveToIA",
|
||||||
|
"Status": "Enabled",
|
||||||
|
"Transitions": [{
|
||||||
|
"Days": 30,
|
||||||
|
"StorageClass": "STANDARD_IA"
|
||||||
|
}]
|
||||||
|
}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 阶段 2: 数据库设置
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建 RDS PostgreSQL
|
||||||
|
aws rds create-db-instance \
|
||||||
|
--db-instance-identifier invoice-db \
|
||||||
|
--db-instance-class db.t3.micro \
|
||||||
|
--engine postgres \
|
||||||
|
--engine-version 15 \
|
||||||
|
--master-username docmaster \
|
||||||
|
--master-user-password YOUR_PASSWORD \
|
||||||
|
--allocated-storage 20
|
||||||
|
|
||||||
|
# 配置安全组
|
||||||
|
aws ec2 authorize-security-group-ingress \
|
||||||
|
--group-id sg-xxx \
|
||||||
|
--protocol tcp \
|
||||||
|
--port 5432 \
|
||||||
|
--source-group sg-yyy
|
||||||
|
```
|
||||||
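
Once the RDS instance is reachable, the repository's SQL migrations (the `migrations/*.sql` files listed in the README) can be applied with psycopg2. The endpoint hostname below is a placeholder, and the `docmaster` database is assumed to have been created already.

```python
# Apply the repository's SQL migrations to the new RDS instance (illustrative).
import os
from pathlib import Path

import psycopg2

conn = psycopg2.connect(
    host="invoice-db.xxxxxxxx.us-east-1.rds.amazonaws.com",  # placeholder RDS endpoint
    port=5432,
    dbname="docmaster",  # create this database first if it does not exist
    user="docmaster",
    password=os.environ["DB_PASSWORD"],
)
conn.autocommit = True

with conn.cursor() as cur:
    for sql_file in sorted(Path("migrations").glob("*.sql")):
        print(f"applying {sql_file.name}")
        cur.execute(sql_file.read_text())

conn.close()
```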
|
|
||||||
|
### 阶段 3: 推理服务部署
|
||||||
|
|
||||||
|
**方案 A: Lambda**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建 Lambda Layer (依赖)
|
||||||
|
cd lambda-layer
|
||||||
|
pip install ultralytics opencv-python-headless -t python/
|
||||||
|
zip -r layer.zip python/
|
||||||
|
aws lambda publish-layer-version \
|
||||||
|
--layer-name yolo-deps \
|
||||||
|
--zip-file fileb://layer.zip \
|
||||||
|
--compatible-runtimes python3.11
|
||||||
|
|
||||||
|
# 部署 Lambda 函数
|
||||||
|
cd ../lambda
|
||||||
|
zip function.zip lambda_function.py
|
||||||
|
aws lambda create-function \
|
||||||
|
--function-name invoice-inference \
|
||||||
|
--runtime python3.11 \
|
||||||
|
--handler lambda_function.lambda_handler \
|
||||||
|
--role arn:aws:iam::123456789012:role/LambdaRole \
|
||||||
|
--zip-file fileb://function.zip \
|
||||||
|
--memory-size 4096 \
|
||||||
|
--timeout 30 \
|
||||||
|
--layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||||
|
|
||||||
|
# 创建 API Gateway
|
||||||
|
aws apigatewayv2 create-api \
|
||||||
|
--name invoice-api \
|
||||||
|
--protocol-type HTTP \
|
||||||
|
--target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference
|
||||||
|
```
|
||||||
|
|
||||||
|
**方案 B: ECS Fargate**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建 ECR 仓库
|
||||||
|
aws ecr create-repository --repository-name invoice-inference
|
||||||
|
|
||||||
|
# 构建并推送镜像
|
||||||
|
aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com
|
||||||
|
docker build -t invoice-inference .
|
||||||
|
docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||||
|
docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||||
|
|
||||||
|
# 创建 ECS 集群
|
||||||
|
aws ecs create-cluster --cluster-name invoice-cluster
|
||||||
|
|
||||||
|
# 注册任务定义
|
||||||
|
aws ecs register-task-definition --cli-input-json file://task-definition.json
|
||||||
|
|
||||||
|
# 创建服务
|
||||||
|
aws ecs create-service \
|
||||||
|
--cluster invoice-cluster \
|
||||||
|
--service-name invoice-service \
|
||||||
|
--task-definition invoice-inference \
|
||||||
|
--desired-count 1 \
|
||||||
|
--launch-type FARGATE \
|
||||||
|
--network-configuration '{
|
||||||
|
"awsvpcConfiguration": {
|
||||||
|
"subnets": ["subnet-xxx"],
|
||||||
|
"securityGroups": ["sg-xxx"],
|
||||||
|
"assignPublicIp": "ENABLED"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 阶段 4: 训练服务设置
|
||||||
|
|
||||||
|
```python
|
||||||
|
# setup_sagemaker.py
|
||||||
|
import boto3
|
||||||
|
import sagemaker
|
||||||
|
from sagemaker.pytorch import PyTorch
|
||||||
|
|
||||||
|
# 创建 SageMaker 执行角色
|
||||||
|
iam = boto3.client('iam')
|
||||||
|
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
|
||||||
|
|
||||||
|
# 配置训练任务
|
||||||
|
estimator = PyTorch(
|
||||||
|
entry_point="train.py",
|
||||||
|
source_dir="./src/training",
|
||||||
|
role=role_arn,
|
||||||
|
instance_count=1,
|
||||||
|
instance_type="ml.g4dn.xlarge",
|
||||||
|
framework_version="2.0",
|
||||||
|
py_version="py310",
|
||||||
|
use_spot_instances=True,
|
||||||
|
max_run=7200,
|
||||||
|
max_wait=14400,
|
||||||
|
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存配置供后续使用
|
||||||
|
estimator.save("training_config.json")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 阶段 5: 集成训练触发 API
|
||||||
|
|
||||||
|
```python
|
||||||
|
# lambda_trigger_training.py
|
||||||
|
import boto3
|
||||||
|
import sagemaker
|
||||||
|
from sagemaker.pytorch import PyTorch
|
||||||
|
|
||||||
|
def lambda_handler(event, context):
|
||||||
|
"""触发 SageMaker 训练任务"""
|
||||||
|
|
||||||
|
epochs = event.get('epochs', 100)
|
||||||
|
|
||||||
|
estimator = PyTorch(
|
||||||
|
entry_point="train.py",
|
||||||
|
source_dir="s3://invoice-training-data/code/",
|
||||||
|
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||||
|
instance_count=1,
|
||||||
|
instance_type="ml.g4dn.xlarge",
|
||||||
|
framework_version="2.0",
|
||||||
|
py_version="py310",
|
||||||
|
use_spot_instances=True,
|
||||||
|
max_run=7200,
|
||||||
|
max_wait=14400,
|
||||||
|
hyperparameters={
|
||||||
|
"epochs": epochs,
|
||||||
|
"batch-size": 16,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
estimator.fit(
|
||||||
|
inputs={
|
||||||
|
"training": "s3://invoice-training-data/datasets/train/",
|
||||||
|
"validation": "s3://invoice-training-data/datasets/val/"
|
||||||
|
},
|
||||||
|
wait=False # 异步执行
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'statusCode': 200,
|
||||||
|
'body': {
|
||||||
|
'training_job_name': estimator.latest_training_job.name,
|
||||||
|
'status': 'Started'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AWS vs Azure 对比
|
||||||
|
|
||||||
|
### 服务对应关系
|
||||||
|
|
||||||
|
| 功能 | AWS | Azure |
|
||||||
|
|------|-----|-------|
|
||||||
|
| 对象存储 | S3 | Blob Storage |
|
||||||
|
| 挂载工具 | Mountpoint for S3 | BlobFuse2 |
|
||||||
|
| ML 平台 | SageMaker | Azure ML |
|
||||||
|
| 容器服务 | ECS/Fargate | Container Apps |
|
||||||
|
| Serverless | Lambda | Functions |
|
||||||
|
| GPU VM | EC2 P3/G4dn | NC/ND 系列 |
|
||||||
|
| 容器注册 | ECR | ACR |
|
||||||
|
| 数据库 | RDS PostgreSQL | PostgreSQL Flexible |
|
||||||
|
|
||||||
|
### 价格对比
|
||||||
|
|
||||||
|
| 组件 | AWS | Azure |
|
||||||
|
|------|-----|-------|
|
||||||
|
| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 |
|
||||||
|
| 数据库 | ~$15/月 | ~$25/月 |
|
||||||
|
| 推理 (Serverless) | ~$15/月 | ~$30/月 |
|
||||||
|
| 推理 (容器) | ~$30/月 | ~$30/月 |
|
||||||
|
| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 |
|
||||||
|
| **总计** | **~$35-50/月** | **~$65/月** |
|
||||||
|
|
||||||
|
### 优劣对比
|
||||||
|
|
||||||
|
| 方面 | AWS 优势 | Azure 优势 |
|
||||||
|
|------|---------|-----------|
|
||||||
|
| 价格 | Lambda 更便宜 | GPU Spot 更便宜 |
|
||||||
|
| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 |
|
||||||
|
| Serverless GPU | 无原生支持 | Container Apps GPU |
|
||||||
|
| 文档 | 更丰富 | 中文文档更好 |
|
||||||
|
| 生态 | 更大 | Office 365 集成 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
### 推荐配置
|
||||||
|
|
||||||
|
| 组件 | 推荐方案 | 月费估算 |
|
||||||
|
|------|---------|---------|
|
||||||
|
| 数据存储 | S3 Standard | ~$0.12 |
|
||||||
|
| 数据库 | RDS db.t3.micro | ~$15 |
|
||||||
|
| 推理服务 | Lambda 4GB | ~$15 |
|
||||||
|
| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 |
|
||||||
|
| ECR | 基本使用 | ~$1 |
|
||||||
|
| **总计** | | **~$35/月** + 训练费 |
|
||||||
|
|
||||||
|
### 关键决策
|
||||||
|
|
||||||
|
| 场景 | 选择 |
|
||||||
|
|------|------|
|
||||||
|
| 最低成本 | Lambda + SageMaker Spot |
|
||||||
|
| 稳定推理 | ECS Fargate |
|
||||||
|
| GPU 推理 | ECS + EC2 GPU |
|
||||||
|
| MLOps 集成 | SageMaker 全家桶 |
|
||||||
|
|
||||||
|
### 注意事项
|
||||||
|
|
||||||
|
1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决(boto3 配置示例见本节末尾)
|
||||||
|
2. **Spot 中断**: 配置检查点,SageMaker 自动恢复
|
||||||
|
3. **S3 传输**: 同区域免费,跨区域收费
|
||||||
|
4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2
|
||||||
|
5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本
|
||||||
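
For note 1 above, Provisioned Concurrency can also be configured programmatically; the function name matches the Lambda created in the deployment steps, while the alias used as `Qualifier` is an assumption (it must be a published version or alias).

```python
# Keep N warm execution environments for the inference Lambda (reduces cold starts).
import boto3

client = boto3.client("lambda", region_name="us-east-1")

client.put_provisioned_concurrency_config(
    FunctionName="invoice-inference",
    Qualifier="prod",                   # alias or published version (assumption)
    ProvisionedConcurrentExecutions=2,  # number of pre-warmed environments
)
```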
567
docs/azure-deployment-guide.md
Normal file
@@ -0,0 +1,567 @@
|
|||||||
|
# Azure 部署方案完整指南
|
||||||
|
|
||||||
|
## 目录
|
||||||
|
- [核心问题](#核心问题)
|
||||||
|
- [存储方案](#存储方案)
|
||||||
|
- [训练方案](#训练方案)
|
||||||
|
- [推理方案](#推理方案)
|
||||||
|
- [价格对比](#价格对比)
|
||||||
|
- [推荐架构](#推荐架构)
|
||||||
|
- [实施步骤](#实施步骤)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 核心问题
|
||||||
|
|
||||||
|
| 问题 | 答案 |
|
||||||
|
|------|------|
|
||||||
|
| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 |
|
||||||
|
| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 |
|
||||||
|
| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) |
|
||||||
|
| VM 空闲时收费吗? | 收费,只要开机就按小时计费 |
|
||||||
|
| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster |
|
||||||
|
| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 存储方案
|
||||||
|
|
||||||
|
### Azure Blob Storage + BlobFuse2(推荐)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 BlobFuse2
|
||||||
|
sudo apt-get install blobfuse2
|
||||||
|
|
||||||
|
# 配置文件
|
||||||
|
cat > ~/blobfuse-config.yaml << 'EOF'
|
||||||
|
logging:
|
||||||
|
type: syslog
|
||||||
|
level: log_warning
|
||||||
|
|
||||||
|
components:
|
||||||
|
- libfuse
|
||||||
|
- file_cache
|
||||||
|
- azstorage
|
||||||
|
|
||||||
|
file_cache:
|
||||||
|
path: /tmp/blobfuse2
|
||||||
|
timeout-sec: 120
|
||||||
|
max-size-mb: 4096
|
||||||
|
|
||||||
|
azstorage:
|
||||||
|
type: block
|
||||||
|
account-name: YOUR_ACCOUNT
|
||||||
|
account-key: YOUR_KEY
|
||||||
|
container: training-images
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# 挂载
|
||||||
|
mkdir -p /mnt/azure-blob
|
||||||
|
blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 本地开发(Windows)
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
# 安装
|
||||||
|
winget install WinFsp.WinFsp
|
||||||
|
winget install Rclone.Rclone
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
rclone config # 选择 azureblob
|
||||||
|
|
||||||
|
# 挂载为 Z: 盘
|
||||||
|
rclone mount azure:training-images Z: --vfs-cache-mode full
|
||||||
|
```
|
||||||
|
|
||||||
|
### 存储费用
|
||||||
|
|
||||||
|
| 层级 | 价格 | 适用场景 |
|
||||||
|
|------|------|---------|
|
||||||
|
| Hot | $0.018/GB/月 | 频繁访问 |
|
||||||
|
| Cool | $0.01/GB/月 | 偶尔访问 |
|
||||||
|
| Archive | $0.002/GB/月 | 长期存档 |
|
||||||
|
|
||||||
|
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 训练方案
|
||||||
|
|
||||||
|
### 方案总览
|
||||||
|
|
||||||
|
| 方案 | 适用场景 | 空闲费用 | 复杂度 |
|
||||||
|
|------|---------|---------|--------|
|
||||||
|
| Azure VM | 简单直接 | 24/7 收费 | 低 |
|
||||||
|
| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 |
|
||||||
|
| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 |
|
||||||
|
| Container Apps GPU | Serverless | 自动缩到 0 | 中 |
|
||||||
|
|
||||||
|
### Azure VM vs Azure ML
|
||||||
|
|
||||||
|
| 特性 | Azure VM | Azure ML |
|
||||||
|
|------|----------|----------|
|
||||||
|
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||||
|
| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) |
|
||||||
|
| 附加费用 | ~$5/月 | ~$20-30/月 |
|
||||||
|
| 实验跟踪 | 无 | 内置 |
|
||||||
|
| 自动扩缩 | 无 | 支持 min=0 |
|
||||||
|
| 适用人群 | DevOps | 数据科学家 |
|
||||||
|
|
||||||
|
### Azure ML 附加费用明细
|
||||||
|
|
||||||
|
| 服务 | 用途 | 费用 |
|
||||||
|
|------|------|------|
|
||||||
|
| Container Registry | Docker 镜像 | ~$5-20/月 |
|
||||||
|
| Blob Storage | 日志、模型 | ~$0.10/月 |
|
||||||
|
| Application Insights | 监控 | ~$0-10/月 |
|
||||||
|
| Key Vault | 密钥管理 | <$1/月 |
|
||||||
|
|
||||||
|
### Spot 实例
|
||||||
|
|
||||||
|
两种平台都支持 Spot/低优先级实例,最高节省 90%:
|
||||||
|
|
||||||
|
| 类型 | 正常价格 | Spot 价格 | 节省 |
|
||||||
|
|------|---------|----------|------|
|
||||||
|
| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% |
|
||||||
|
| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% |
|
||||||
|
|
||||||
|
### GPU 实例价格
|
||||||
|
|
||||||
|
| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 |
|
||||||
|
|------|-----|------|---------|----------|
|
||||||
|
| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 |
|
||||||
|
| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 |
|
||||||
|
| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 |
|
||||||
|
| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 推理方案
|
||||||
|
|
||||||
|
### 方案对比
|
||||||
|
|
||||||
|
| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 |
|
||||||
|
|------|---------|--------|------|---------|
|
||||||
|
| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) |
|
||||||
|
| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 |
|
||||||
|
| Azure App Service | 否 | 手动/自动 | ~$50/月 | 简单部署 |
|
||||||
|
| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 |
|
||||||
|
| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 |
|
||||||
|
|
||||||
|
### 推荐: Container Apps (CPU)
|
||||||
|
|
||||||
|
对于 YOLO 推理,**CPU 足够**,不需要 GPU:
|
||||||
|
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
|
||||||
|
- 比 GPU 便宜很多,适合中低流量
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Container Apps 配置
|
||||||
|
name: invoice-inference
|
||||||
|
image: myacr.azurecr.io/invoice-inference:v1
|
||||||
|
resources:
|
||||||
|
cpu: 2.0
|
||||||
|
memory: 4Gi
|
||||||
|
scale:
|
||||||
|
minReplicas: 1 # 最少 1 个实例保持响应
|
||||||
|
maxReplicas: 10 # 最多扩展到 10 个
|
||||||
|
rules:
|
||||||
|
- name: http-scaling
|
||||||
|
http:
|
||||||
|
metadata:
|
||||||
|
concurrentRequests: "50" # 每实例 50 并发时扩容
|
||||||
|
```
|
||||||
|
|
||||||
|
### 推理服务代码示例
|
||||||
|
|
||||||
|
```dockerfile
|
||||||
|
# Dockerfile
|
||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 复制代码和模型
|
||||||
|
COPY src/ ./src/
|
||||||
|
COPY models/best.pt ./models/
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# src/web/app.py
|
||||||
|
from fastapi import FastAPI, UploadFile, File
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
model = YOLO("models/best.pt")
|
||||||
|
|
||||||
|
@app.post("/api/v1/infer")
|
||||||
|
async def infer(file: UploadFile = File(...)):
|
||||||
|
# 保存上传文件
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||||
|
content = await file.read()
|
||||||
|
tmp.write(content)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
# 执行推理
|
||||||
|
results = model.predict(tmp_path, conf=0.5)
|
||||||
|
|
||||||
|
# 返回结果
|
||||||
|
return {
|
||||||
|
"fields": extract_fields(results),
|
||||||
|
"confidence": get_confidence(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "healthy"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 部署命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. 创建 Container Registry
|
||||||
|
az acr create --name invoiceacr --resource-group myRG --sku Basic
|
||||||
|
|
||||||
|
# 2. 构建并推送镜像
|
||||||
|
az acr build --registry invoiceacr --image invoice-inference:v1 .
|
||||||
|
|
||||||
|
# 3. 创建 Container Apps 环境
|
||||||
|
az containerapp env create \
|
||||||
|
--name invoice-env \
|
||||||
|
--resource-group myRG \
|
||||||
|
--location eastus
|
||||||
|
|
||||||
|
# 4. 部署应用
|
||||||
|
az containerapp create \
|
||||||
|
--name invoice-inference \
|
||||||
|
--resource-group myRG \
|
||||||
|
--environment invoice-env \
|
||||||
|
--image invoiceacr.azurecr.io/invoice-inference:v1 \
|
||||||
|
--registry-server invoiceacr.azurecr.io \
|
||||||
|
--cpu 2 --memory 4Gi \
|
||||||
|
--min-replicas 1 --max-replicas 10 \
|
||||||
|
--ingress external --target-port 8000
|
||||||
|
|
||||||
|
# 5. 获取 URL
|
||||||
|
az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn
|
||||||
|
```
|
||||||
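
With the FQDN returned by step 5, the `/api/v1/infer` endpoint shown earlier in this guide can be exercised as follows; the hostname is a placeholder for whatever `az containerapp show` printed.

```python
# Call the deployed inference endpoint (replace the host with the FQDN from step 5).
import requests

FQDN = "invoice-inference.<region>.azurecontainerapps.io"  # placeholder

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        f"https://{FQDN}/api/v1/infer",
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )

resp.raise_for_status()
print(resp.json())  # {"fields": ..., "confidence": ...}
```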
|
|
||||||
|
### 高吞吐场景: Serverless GPU
|
||||||
|
|
||||||
|
如果需要 GPU 加速推理(高并发、低延迟):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 请求 GPU 配额
|
||||||
|
az containerapp env workload-profile add \
|
||||||
|
--name invoice-env \
|
||||||
|
--resource-group myRG \
|
||||||
|
--workload-profile-name gpu \
|
||||||
|
--workload-profile-type Consumption-GPU-T4
|
||||||
|
|
||||||
|
# 部署 GPU 版本
|
||||||
|
az containerapp create \
|
||||||
|
--name invoice-inference-gpu \
|
||||||
|
--resource-group myRG \
|
||||||
|
--environment invoice-env \
|
||||||
|
--image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \
|
||||||
|
--workload-profile-name gpu \
|
||||||
|
--cpu 4 --memory 8Gi \
|
||||||
|
--min-replicas 0 --max-replicas 5 \
|
||||||
|
--ingress external --target-port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Inference Performance Comparison

| Configuration | Latency per request | Throughput | Estimated monthly cost |
|------|------------|---------|---------|
| CPU, 2 cores / 4 GB | ~300-500 ms | ~50 QPS | ~$30 |
| CPU, 4 cores / 8 GB | ~200-300 ms | ~100 QPS | ~$60 |
| GPU T4 | ~50-100 ms | ~200 QPS | billed per second |
| GPU A100 | ~20-50 ms | ~500 QPS | billed per second |

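As a rough sizing aid, the table above translates into a simple capacity estimate; the target load in this sketch is an assumed number, not a benchmark.

```python
import math

target_qps = 120        # assumed peak load
per_replica_qps = 50    # "CPU, 2 cores / 4 GB" row above
cost_per_replica = 30   # ~$30 per replica per month

replicas = math.ceil(target_qps / per_replica_qps)
print(replicas, replicas * cost_per_replica)  # 3 replicas, ~$90/month
```
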
---

## Price Comparison

### Monthly Cost Comparison (assuming 2 hours of training per day)

| Option | Calculation | Monthly cost |
|------|---------|------|
| VM running 24/7 | 24 h × 30 days × $3.06 | ~$2,200 |
| VM started on demand | 2 h × 30 days × $3.06 | ~$184 |
| VM Spot, on demand | 2 h × 30 days × $0.92 | ~$55 |
| Serverless GPU | 2 h × 30 days × ~$3.50 | ~$210 |
| Azure ML (min=0) | 2 h × 30 days × $3.06 | ~$184 |

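The arithmetic behind the table can be reproduced directly; the hourly rates are the assumed prices used above.

```python
def monthly(rate_per_hour: float, hours_per_day: float, days: int = 30) -> float:
    """Monthly cost for a given hourly rate and daily usage."""
    return rate_per_hour * hours_per_day * days

print(monthly(3.06, 24))  # VM 24/7           -> ~2203
print(monthly(3.06, 2))   # VM on demand      -> ~184
print(monthly(0.92, 2))   # VM Spot on demand -> ~55
print(monthly(3.50, 2))   # Serverless GPU    -> ~210
```
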
### Full Cost Estimate for This Project

| Component | Recommended option | Monthly cost |
|------|---------|------|
| Image storage | Blob Storage (Hot) | ~$0.10 |
| Database | PostgreSQL Flexible Server (Burstable B1ms) | ~$25 |
| Inference service | Container Apps CPU (2 cores / 4 GB) | ~$30 |
| Training service | Azure ML Spot (on demand) | ~$1-5 per run |
| Container Registry | Basic | ~$5 |
| **Total** | | **~$65/month** + training runs |

---

## Recommended Architecture

### Overall Architecture Diagram

```
                    ┌─────────────────────────────────────┐
                    │          Azure Blob Storage         │
                    │   ├── training-images/              │
                    │   ├── datasets/                     │
                    │   └── models/                       │
                    └──────────────────┬──────────────────┘
                                       │
            ┌──────────────────────────┼──────────────────────────┐
            │                          │                          │
            ▼                          ▼                          ▼
┌───────────────────────┐  ┌───────────────────────┐  ┌───────────────────────┐
│ Inference (24/7)      │  │ Training (on demand)  │  │ Web UI (optional)     │
│ Container Apps        │  │ Azure ML Compute      │  │ Static Web Apps       │
│ CPU 2 cores / 4 GB    │  │ min=0, Spot           │  │ ~$0 (free tier)       │
│ ~$30/month            │  │ ~$1-5 per run         │  │                       │
│                       │  │                       │  │                       │
│ ┌───────────────────┐ │  │ ┌───────────────────┐ │  │ ┌───────────────────┐ │
│ │ FastAPI + YOLO    │ │  │ │ YOLOv11 Training  │ │  │ │ React/Vue frontend│ │
│ │ /api/v1/infer     │ │  │ │ 100 epochs        │ │  │ │ Invoice upload UI │ │
│ └───────────────────┘ │  │ └───────────────────┘ │  │ └───────────────────┘ │
└───────────┬───────────┘  └───────────┬───────────┘  └───────────┬───────────┘
            │                          │                          │
            └──────────────────────────┼──────────────────────────┘
                                       │
                                       ▼
                           ┌───────────────────────┐
                           │      PostgreSQL       │
                           │   Flexible Server     │
                           │   Burstable B1ms      │
                           │      ~$25/month       │
                           └───────────────────────┘
```

### Inference Service Configuration

```yaml
# Container Apps - CPU (runs 24/7)
name: invoice-inference
resources:
  cpu: 2
  memory: 4Gi
scale:
  minReplicas: 1
  maxReplicas: 10
env:
  - name: MODEL_PATH
    value: /app/models/best.pt
  - name: DB_HOST
    secretRef: db-host
  - name: DB_PASSWORD
    secretRef: db-password
```

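Inside the container the application reads these values from the environment; a minimal sketch (the variable names match the `env` entries above, everything else is illustrative):

```python
import os

MODEL_PATH = os.environ.get("MODEL_PATH", "/app/models/best.pt")
DB_HOST = os.environ["DB_HOST"]          # injected from the db-host secret
DB_PASSWORD = os.environ["DB_PASSWORD"]  # injected from the db-password secret
```
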
### Training Service Configuration

**Option A: Azure ML Compute (recommended)**

```python
from azure.ai.ml.entities import AmlCompute

gpu_cluster = AmlCompute(
    name="gpu-cluster",
    size="Standard_NC6s_v3",
    min_instances=0,               # shut down when idle
    max_instances=1,
    tier="LowPriority",            # Spot instances
    idle_time_before_scale_down=120
)
```

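Assuming an `MLClient` configured as in Phase 5 below, the cluster definition can then be provisioned through the SDK; a hedged sketch:

```python
# Provision (or update) the cluster defined above; returns a long-running poller.
poller = ml_client.compute.begin_create_or_update(gpu_cluster)
print(poller.result().provisioning_state)
```
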
**Option B: Container Apps Serverless GPU**

```yaml
name: invoice-training
resources:
  gpu: 1
  gpuType: A100
scale:
  minReplicas: 0
  maxReplicas: 1
```

---

## Implementation Steps

### Phase 1: Storage Setup

```bash
# Create the Storage Account
az storage account create \
  --name invoicestorage \
  --resource-group myRG \
  --sku Standard_LRS

# Create containers
az storage container create --name training-images --account-name invoicestorage
az storage container create --name datasets --account-name invoicestorage
az storage container create --name models --account-name invoicestorage

# Upload training data
az storage blob upload-batch \
  --destination training-images \
  --source ./data/dataset/temp \
  --account-name invoicestorage
```

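At start-up the inference container needs `models/best.pt` locally; one possible sketch for pulling it from the `models` container created above (the connection-string environment variable name is an assumption):

```python
import os
from azure.storage.blob import BlobClient

# Download the trained weights from Blob Storage before loading YOLO
blob = BlobClient.from_connection_string(
    conn_str=os.environ["AZURE_STORAGE_CONNECTION_STRING"],  # assumed env var
    container_name="models",
    blob_name="best.pt",
)
os.makedirs("models", exist_ok=True)
with open("models/best.pt", "wb") as f:
    f.write(blob.download_blob().readall())
```
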
### Phase 2: Database Setup

```bash
# Create PostgreSQL
az postgres flexible-server create \
  --name invoice-db \
  --resource-group myRG \
  --sku-name Standard_B1ms \
  --storage-size 32 \
  --admin-user docmaster \
  --admin-password YOUR_PASSWORD

# Configure the firewall
az postgres flexible-server firewall-rule create \
  --name allow-azure \
  --resource-group myRG \
  --server-name invoice-db \
  --start-ip-address 0.0.0.0 \
  --end-ip-address 0.0.0.0
```

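A quick connectivity check against the new server might look like the following sketch; Azure enforces TLS, hence `sslmode=require`, and the database name here is a placeholder:

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "postgresql+psycopg2://docmaster:YOUR_PASSWORD"
    "@invoice-db.postgres.database.azure.com:5432/postgres?sslmode=require"
)
with engine.connect() as conn:
    print(conn.execute(text("SELECT version()")).scalar())
```
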
### Phase 3: Inference Service Deployment

```bash
# Create the Container Registry
az acr create --name invoiceacr --resource-group myRG --sku Basic

# Build the image
az acr build --registry invoiceacr --image invoice-inference:v1 .

# Create the environment
az containerapp env create \
  --name invoice-env \
  --resource-group myRG \
  --location eastus

# Deploy the inference service
az containerapp create \
  --name invoice-inference \
  --resource-group myRG \
  --environment invoice-env \
  --image invoiceacr.azurecr.io/invoice-inference:v1 \
  --registry-server invoiceacr.azurecr.io \
  --cpu 2 --memory 4Gi \
  --min-replicas 1 --max-replicas 10 \
  --ingress external --target-port 8000 \
  --env-vars \
    DB_HOST=invoice-db.postgres.database.azure.com \
    DB_NAME=docmaster \
    DB_USER=docmaster \
  --secrets db-password=YOUR_PASSWORD
```

### Phase 4: Training Service Setup

```bash
# Create the Azure ML workspace
az ml workspace create --name invoice-ml --resource-group myRG

# Create the compute cluster
az ml compute create --name gpu-cluster \
  --type AmlCompute \
  --size Standard_NC6s_v3 \
  --min-instances 0 \
  --max-instances 1 \
  --tier low_priority
```

### Phase 5: Integrate the Training Trigger API

```python
# src/web/routes/training.py
from fastapi import APIRouter
from azure.ai.ml import MLClient, command
from azure.identity import DefaultAzureCredential

router = APIRouter()

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="your-subscription-id",
    resource_group_name="myRG",
    workspace_name="invoice-ml"
)

@router.post("/api/v1/train")
async def trigger_training(request: TrainingRequest):
    """Trigger an Azure ML training job."""
    # TrainingRequest: request schema with an `epochs` field (defined elsewhere)
    training_job = command(
        code="./training",
        command=f"python train.py --epochs {request.epochs}",
        environment="AzureML-pytorch-2.0-cuda11.8@latest",
        compute="gpu-cluster",
    )
    job = ml_client.jobs.create_or_update(training_job)
    return {
        "job_id": job.name,
        "status": job.status,
        "studio_url": job.studio_url
    }

@router.get("/api/v1/train/{job_id}/status")
async def get_training_status(job_id: str):
    """Query training status."""
    job = ml_client.jobs.get(job_id)
    return {"status": job.status}
```

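A hypothetical client-side polling loop against these two endpoints; the host name and the terminal status strings are assumptions, not confirmed values.

```python
import time
import requests

BASE = "https://invoice-inference.<region>.azurecontainerapps.io"  # placeholder

job = requests.post(f"{BASE}/api/v1/train", json={"epochs": 100}, timeout=30).json()
while True:
    status = requests.get(
        f"{BASE}/api/v1/train/{job['job_id']}/status", timeout=30
    ).json()["status"]
    print(status)
    if status in ("Completed", "Failed", "Canceled"):  # assumed terminal states
        break
    time.sleep(60)
```
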
---

## Summary

### Recommended Configuration

| Component | Recommended option | Estimated monthly cost |
|------|---------|---------|
| Image storage | Blob Storage (Hot) | ~$0.10 |
| Database | PostgreSQL Flexible Server | ~$25 |
| Inference service | Container Apps CPU | ~$30 |
| Training service | Azure ML (min=0, Spot) | on demand, ~$1-5 per run |
| Container Registry | Basic | ~$5 |
| **Total** | | **~$65/month** + training runs |

### Key Decisions

| Scenario | Choice |
|------|------|
| Occasional training, simple needs | Azure VM Spot, started and stopped manually |
| MLOps and team collaboration needed | Azure ML Compute |
| Lowest possible idle cost | Container Apps Serverless GPU |
| Production inference | Container Apps CPU |
| High-concurrency inference | Container Apps Serverless GPU |

### Caveats

1. **Cold start**: a Serverless GPU instance takes 3-8 minutes to start.
2. **Spot interruptions**: instances can be preempted, so training needs a checkpointing mechanism (see the sketch below).
3. **Network latency**: a mounted Blob Storage share is slower than local SSD; enable caching.
4. **Region selection**: pick a region that has GPU quota (East US, West Europe, etc.).
5. **Inference optimization**: CPU inference is sufficient for this YOLO model; no GPU is required.

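For caveat 2, a minimal checkpoint/resume sketch with Ultralytics (the dataset path is an assumption): the trainer writes `last.pt` every epoch, so a preempted Spot run can continue from its last checkpoint.

```python
from pathlib import Path
from ultralytics import YOLO

last = Path("runs/detect/train/weights/last.pt")
if last.exists():
    # Resume the interrupted run from the saved checkpoint
    model = YOLO(str(last))
    model.train(resume=True)
else:
    # Fresh run; dataset.yaml is a placeholder for the project's data config
    model = YOLO("yolo11n.pt")
    model.train(data="dataset.yaml", epochs=100)
```
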
@@ -4,11 +4,13 @@ import type {
   DocumentDetailResponse,
   DocumentItem,
   UploadDocumentResponse,
+  DocumentCategoriesResponse,
 } from '../types'
 
 export const documentsApi = {
   list: async (params?: {
     status?: string
+    category?: string
     limit?: number
     offset?: number
   }): Promise<DocumentListResponse> => {
@@ -16,18 +18,29 @@ export const documentsApi = {
     return data
   },
 
+  getCategories: async (): Promise<DocumentCategoriesResponse> => {
+    const { data } = await apiClient.get('/api/v1/admin/documents/categories')
+    return data
+  },
+
   getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
     const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
     return data
   },
 
-  upload: async (file: File, groupKey?: string): Promise<UploadDocumentResponse> => {
+  upload: async (
+    file: File,
+    options?: { groupKey?: string; category?: string }
+  ): Promise<UploadDocumentResponse> => {
     const formData = new FormData()
     formData.append('file', file)
 
     const params: Record<string, string> = {}
-    if (groupKey) {
-      params.group_key = groupKey
+    if (options?.groupKey) {
+      params.group_key = options.groupKey
+    }
+    if (options?.category) {
+      params.category = options.category
     }
 
     const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
@@ -95,4 +108,15 @@ export const documentsApi = {
     )
     return data
   },
+
+  updateCategory: async (
+    documentId: string,
+    category: string
+  ): Promise<{ status: string; document_id: string; category: string; message: string }> => {
+    const { data } = await apiClient.patch(
+      `/api/v1/admin/documents/${documentId}/category`,
+      { category }
+    )
+    return data
+  },
 }

@@ -9,6 +9,7 @@ export interface DocumentItem {
   auto_label_error: string | null
   upload_source: string
   group_key: string | null
+  category: string
   created_at: string
   updated_at: string
   annotation_count?: number
@@ -61,6 +62,7 @@ export interface DocumentDetailResponse {
   upload_source: string
   batch_id: string | null
   group_key: string | null
+  category: string
   csv_field_values: Record<string, string> | null
   can_annotate: boolean
   annotation_lock_until: string | null
@@ -101,8 +103,21 @@ export interface TrainingTask {
   updated_at: string
 }
 
+export interface ModelVersionItem {
+  version_id: string
+  version: string
+  name: string
+  status: string
+  is_active: boolean
+  metrics_mAP: number | null
+  document_count: number
+  trained_at: string | null
+  activated_at: string | null
+  created_at: string
+}
+
 export interface TrainingModelsResponse {
-  models: TrainingTask[]
+  models: ModelVersionItem[]
   total: number
   limit: number
   offset: number
@@ -118,11 +133,17 @@ export interface UploadDocumentResponse {
   file_size: number
   page_count: number
   status: string
+  category: string
   group_key: string | null
   auto_label_started: boolean
   message: string
 }
 
+export interface DocumentCategoriesResponse {
+  categories: string[]
+  total: number
+}
+
 export interface CreateAnnotationRequest {
   page_number: number
   class_id: number
@@ -228,6 +249,8 @@ export interface DatasetDetailResponse {
   name: string
   description: string | null
   status: string
+  training_status: string | null
+  active_training_task_id: string | null
   train_ratio: number
   val_ratio: number
   seed: number

@@ -3,7 +3,7 @@ import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
 import { Badge } from './Badge'
 import { Button } from './Button'
 import { UploadModal } from './UploadModal'
-import { useDocuments } from '../hooks/useDocuments'
+import { useDocuments, useCategories } from '../hooks/useDocuments'
 import type { DocumentItem } from '../api/types'
 
 interface DashboardProps {
@@ -34,11 +34,15 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
   const [isUploadOpen, setIsUploadOpen] = useState(false)
   const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
   const [statusFilter, setStatusFilter] = useState<string>('')
+  const [categoryFilter, setCategoryFilter] = useState<string>('')
   const [limit] = useState(20)
   const [offset] = useState(0)
 
+  const { categories } = useCategories()
+
   const { documents, total, isLoading, error, refetch } = useDocuments({
     status: statusFilter || undefined,
+    category: categoryFilter || undefined,
     limit,
     offset,
   })
@@ -102,6 +106,24 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
         </div>
 
         <div className="flex gap-3">
+          <div className="relative">
+            <select
+              value={categoryFilter}
+              onChange={(e) => setCategoryFilter(e.target.value)}
+              className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
+            >
+              <option value="">All Categories</option>
+              {categories.map((cat) => (
+                <option key={cat} value={cat}>
+                  {cat.charAt(0).toUpperCase() + cat.slice(1)}
+                </option>
+              ))}
+            </select>
+            <ChevronDown
+              className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
+              size={14}
+            />
+          </div>
           <div className="relative">
             <select
               value={statusFilter}
@@ -144,6 +166,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
                 <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                   Annotations
                 </th>
+                <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
+                  Category
+                </th>
                 <th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
                   Group
                 </th>
@@ -156,13 +181,13 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
             <tbody>
               {isLoading ? (
                 <tr>
-                  <td colSpan={8} className="py-8 text-center text-warm-text-muted">
+                  <td colSpan={9} className="py-8 text-center text-warm-text-muted">
                     Loading documents...
                   </td>
                 </tr>
               ) : documents.length === 0 ? (
                 <tr>
-                  <td colSpan={8} className="py-8 text-center text-warm-text-muted">
+                  <td colSpan={9} className="py-8 text-center text-warm-text-muted">
                     No documents found. Upload your first document to get started.
                   </td>
                 </tr>
@@ -216,6 +241,9 @@ export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
                 <td className="py-4 px-4 text-sm text-warm-text-secondary">
                   {doc.annotation_count || 0} annotations
                 </td>
+                <td className="py-4 px-4 text-sm text-warm-text-secondary capitalize">
+                  {doc.category || 'invoice'}
+                </td>
                 <td className="py-4 px-4 text-sm text-warm-text-muted">
                   {doc.group_key || '-'}
                 </td>

@@ -1,5 +1,5 @@
 import React from 'react'
-import { ArrowLeft, Loader2, Play, AlertCircle, Check } from 'lucide-react'
+import { ArrowLeft, Loader2, Play, AlertCircle, Check, Award } from 'lucide-react'
 import { Button } from './Button'
 import { useDatasetDetail } from '../hooks/useDatasets'
 
@@ -14,6 +14,23 @@ const SPLIT_STYLES: Record<string, string> = {
   test: 'bg-warm-state-success/10 text-warm-state-success',
 }
 
+const STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
+  building: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Building' },
+  ready: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Ready' },
+  trained: { bg: 'bg-purple-100', text: 'text-purple-700', label: 'Trained' },
+  failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
+  archived: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Archived' },
+}
+
+const TRAINING_STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
+  pending: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Pending' },
+  scheduled: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Scheduled' },
+  running: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Training' },
+  completed: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Completed' },
+  failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
+  cancelled: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Cancelled' },
+}
+
 export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
   const { dataset, isLoading, error } = useDatasetDetail(datasetId)
 
@@ -36,11 +53,25 @@ export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack
     )
   }
 
-  const statusIcon = dataset.status === 'ready'
+  const statusConfig = STATUS_STYLES[dataset.status] || STATUS_STYLES.ready
+  const trainingStatusConfig = dataset.training_status
+    ? TRAINING_STATUS_STYLES[dataset.training_status]
+    : null
+
+  // Determine if training button should be shown and enabled
+  const isTrainingInProgress = dataset.training_status === 'running' || dataset.training_status === 'pending'
+  const canStartTraining = dataset.status === 'ready' && !isTrainingInProgress
+
+  // Determine status icon
+  const statusIcon = dataset.status === 'trained'
+    ? <Award size={14} className="text-purple-700" />
+    : dataset.status === 'ready'
     ? <Check size={14} className="text-warm-state-success" />
     : dataset.status === 'failed'
     ? <AlertCircle size={14} className="text-warm-state-error" />
-    : <Loader2 size={14} className="animate-spin text-warm-state-info" />
+    : dataset.status === 'building'
+    ? <Loader2 size={14} className="animate-spin text-warm-state-info" />
+    : null
 
   return (
     <div className="p-8 max-w-7xl mx-auto">
@@ -51,15 +82,38 @@ export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack
 
       <div className="flex items-center justify-between mb-6">
         <div>
-          <h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
-            {dataset.name} {statusIcon}
-          </h2>
+          <div className="flex items-center gap-3 mb-1">
+            <h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
+              {dataset.name} {statusIcon}
+            </h2>
+            {/* Status Badge */}
+            <span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${statusConfig.bg} ${statusConfig.text}`}>
+              {statusConfig.label}
+            </span>
+            {/* Training Status Badge */}
+            {trainingStatusConfig && (
+              <span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${trainingStatusConfig.bg} ${trainingStatusConfig.text}`}>
+                {isTrainingInProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
+                {trainingStatusConfig.label}
+              </span>
+            )}
+          </div>
           {dataset.description && (
            <p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
           )}
         </div>
-        {dataset.status === 'ready' && (
-          <Button><Play size={14} className="mr-1" />Start Training</Button>
+        {/* Training Button */}
+        {(dataset.status === 'ready' || dataset.status === 'trained') && (
+          <Button
+            disabled={isTrainingInProgress}
+            className={isTrainingInProgress ? 'opacity-50 cursor-not-allowed' : ''}
+          >
+            {isTrainingInProgress ? (
+              <><Loader2 size={14} className="mr-1 animate-spin" />Training...</>
+            ) : (
+              <><Play size={14} className="mr-1" />Start Training</>
+            )}
+          </Button>
         )}
       </div>
 

@@ -72,12 +72,13 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
   const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
   const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
 
-  // Fetch available trained models
+  // Fetch available trained models (active or inactive, not archived)
   const { data: modelsData } = useQuery({
-    queryKey: ['training', 'models', 'completed'],
-    queryFn: () => trainingApi.getModels({ status: 'completed' }),
+    queryKey: ['training', 'models', 'available'],
+    queryFn: () => trainingApi.getModels(),
   })
-  const completedModels = modelsData?.models ?? []
+  // Filter out archived models - only show active/inactive models for base model selection
+  const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
 
   const handleSubmit = () => {
     onSubmit({
@@ -128,9 +129,9 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
             className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
           >
             <option value="pretrained">yolo11n.pt (Pretrained)</option>
-            {completedModels.map(m => (
-              <option key={m.task_id} value={m.task_id}>
-                {m.name} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
+            {availableModels.map(m => (
+              <option key={m.version_id} value={m.version_id}>
+                {m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
               </option>
             ))}
           </select>
@@ -293,8 +294,12 @@ const DatasetList: React.FC<{
           </button>
         )}
         <button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
-          disabled={isDeleting}
-          className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors">
+          disabled={isDeleting || ds.status === 'pending' || ds.status === 'building'}
+          className={`p-1.5 rounded transition-colors ${
+            ds.status === 'pending' || ds.status === 'building'
+              ? 'text-warm-text-muted/40 cursor-not-allowed'
+              : 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
+          }`}>
           <Trash2 size={14} />
         </button>
       </div>

@@ -1,7 +1,7 @@
 import React, { useState, useRef } from 'react'
-import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react'
+import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react'
 import { Button } from './Button'
-import { useDocuments } from '../hooks/useDocuments'
+import { useDocuments, useCategories } from '../hooks/useDocuments'
 
 interface UploadModalProps {
   isOpen: boolean
@@ -12,11 +12,13 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
   const [isDragging, setIsDragging] = useState(false)
   const [selectedFiles, setSelectedFiles] = useState<File[]>([])
   const [groupKey, setGroupKey] = useState('')
+  const [category, setCategory] = useState('invoice')
   const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
   const [errorMessage, setErrorMessage] = useState('')
   const fileInputRef = useRef<HTMLInputElement>(null)
 
   const { uploadDocument, isUploading } = useDocuments({})
+  const { categories } = useCategories()
 
   if (!isOpen) return null
 
@@ -63,7 +65,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
       for (const file of selectedFiles) {
        await new Promise<void>((resolve, reject) => {
          uploadDocument(
-            { file, groupKey: groupKey || undefined },
+            { file, groupKey: groupKey || undefined, category: category || 'invoice' },
            {
              onSuccess: () => resolve(),
              onError: (error: Error) => reject(error),
@@ -77,6 +79,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
         onClose()
         setSelectedFiles([])
         setGroupKey('')
+        setCategory('invoice')
         setUploadStatus('idle')
       }, 1500)
     } catch (error) {
@@ -91,6 +94,7 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
     }
     setSelectedFiles([])
     setGroupKey('')
+    setCategory('invoice')
     setUploadStatus('idle')
     setErrorMessage('')
     onClose()
@@ -179,6 +183,42 @@ export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) =>
           </div>
         )}
 
+        {/* Category Select */}
+        {selectedFiles.length > 0 && (
+          <div className="mb-6">
+            <label className="block text-sm font-medium text-warm-text-secondary mb-2">
+              Category
+            </label>
+            <div className="relative">
+              <select
+                value={category}
+                onChange={(e) => setCategory(e.target.value)}
+                className="w-full h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none cursor-pointer"
+                disabled={uploadStatus === 'uploading'}
+              >
+                <option value="invoice">Invoice</option>
+                <option value="letter">Letter</option>
+                <option value="receipt">Receipt</option>
+                <option value="contract">Contract</option>
+                {categories
+                  .filter((cat) => !['invoice', 'letter', 'receipt', 'contract'].includes(cat))
+                  .map((cat) => (
+                    <option key={cat} value={cat}>
+                      {cat.charAt(0).toUpperCase() + cat.slice(1)}
+                    </option>
+                  ))}
+              </select>
+              <ChevronDown
+                className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
+                size={14}
+              />
+            </div>
+            <p className="text-xs text-warm-text-muted mt-1">
+              Select document type for training different models
+            </p>
+          </div>
+        )}
+
         {/* Group Key Input */}
         {selectedFiles.length > 0 && (
           <div className="mb-6">

@@ -1,4 +1,4 @@
-export { useDocuments } from './useDocuments'
+export { useDocuments, useCategories } from './useDocuments'
 export { useDocumentDetail } from './useDocumentDetail'
 export { useAnnotations } from './useAnnotations'
 export { useTraining, useTrainingDocuments } from './useTraining'

@@ -1,9 +1,10 @@
 import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
 import { documentsApi } from '../api/endpoints'
-import type { DocumentListResponse, UploadDocumentResponse } from '../api/types'
+import type { DocumentListResponse, DocumentCategoriesResponse } from '../api/types'
 
 interface UseDocumentsParams {
   status?: string
+  category?: string
   limit?: number
   offset?: number
 }
@@ -18,10 +19,11 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
   })
 
   const uploadMutation = useMutation({
-    mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) =>
-      documentsApi.upload(file, groupKey),
+    mutationFn: ({ file, groupKey, category }: { file: File; groupKey?: string; category?: string }) =>
+      documentsApi.upload(file, { groupKey, category }),
     onSuccess: () => {
       queryClient.invalidateQueries({ queryKey: ['documents'] })
+      queryClient.invalidateQueries({ queryKey: ['categories'] })
     },
   })
 
@@ -63,6 +65,15 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
     },
   })
 
+  const updateCategoryMutation = useMutation({
+    mutationFn: ({ documentId, category }: { documentId: string; category: string }) =>
+      documentsApi.updateCategory(documentId, category),
+    onSuccess: () => {
+      queryClient.invalidateQueries({ queryKey: ['documents'] })
+      queryClient.invalidateQueries({ queryKey: ['categories'] })
+    },
+  })
+
   return {
     documents: data?.documents || [],
     total: data?.total || 0,
@@ -86,5 +97,24 @@ export const useDocuments = (params: UseDocumentsParams = {}) => {
     updateGroupKey: updateGroupKeyMutation.mutate,
     updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
     isUpdatingGroupKey: updateGroupKeyMutation.isPending,
+    updateCategory: updateCategoryMutation.mutate,
+    updateCategoryAsync: updateCategoryMutation.mutateAsync,
+    isUpdatingCategory: updateCategoryMutation.isPending,
+  }
+}
+
+export const useCategories = () => {
+  const { data, isLoading, error, refetch } = useQuery<DocumentCategoriesResponse>({
+    queryKey: ['categories'],
+    queryFn: () => documentsApi.getCategories(),
+    staleTime: 60000,
+  })
+
+  return {
+    categories: data?.categories || [],
+    total: data?.total || 0,
+    isLoading,
+    error,
+    refetch,
   }
 }

migrations/009_add_document_category.sql (new file, 13 lines)
@@ -0,0 +1,13 @@
+-- Add category column to admin_documents table
+-- Allows categorizing documents for training different models (e.g., invoice, letter, receipt)
+
+ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
+
+-- Update existing NULL values to default
+UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
+
+-- Make it NOT NULL after setting defaults
+ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
+
+-- Create index for category filtering
+CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);

migrations/010_add_dataset_training_status.sql (new file, 28 lines)
@@ -0,0 +1,28 @@
+-- Migration: Add training_status and active_training_task_id to training_datasets
+-- Description: Track training status separately from dataset build status
+
+-- Add training_status column
+ALTER TABLE training_datasets
+ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
+
+-- Add active_training_task_id column
+ALTER TABLE training_datasets
+ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
+
+-- Create index for training_status
+CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status
+ON training_datasets(training_status);
+
+-- Create index for active_training_task_id
+CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id
+ON training_datasets(active_training_task_id);
+
+-- Update existing datasets that have been used in completed training tasks to 'trained' status
+UPDATE training_datasets d
+SET status = 'trained'
+WHERE d.status = 'ready'
+  AND EXISTS (
+    SELECT 1 FROM training_tasks t
+    WHERE t.dataset_id = d.dataset_id
+      AND t.status = 'completed'
+  );

@@ -120,7 +120,7 @@ def main() -> None:
     logger.info("=" * 60)
 
     # Create config
-    from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
+    from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig
 
     config = AppConfig(
         model=ModelConfig(
@@ -136,7 +136,7 @@ def main() -> None:
             reload=args.reload,
             workers=args.workers,
         ),
-        storage=StorageConfig(),
+        file=FileConfig(),
     )
 
     # Create and run app

@@ -112,6 +112,7 @@ class AdminDB:
         upload_source: str = "ui",
         csv_field_values: dict[str, Any] | None = None,
         group_key: str | None = None,
+        category: str = "invoice",
         admin_token: str | None = None,  # Deprecated, kept for compatibility
     ) -> str:
         """Create a new document record."""
@@ -125,6 +126,7 @@ class AdminDB:
                 upload_source=upload_source,
                 csv_field_values=csv_field_values,
                 group_key=group_key,
+                category=category,
             )
             session.add(document)
             session.flush()
@@ -154,6 +156,7 @@ class AdminDB:
         has_annotations: bool | None = None,
         auto_label_status: str | None = None,
         batch_id: str | None = None,
+        category: str | None = None,
         limit: int = 20,
         offset: int = 0,
     ) -> tuple[list[AdminDocument], int]:
@@ -171,6 +174,8 @@ class AdminDB:
             where_clauses.append(AdminDocument.auto_label_status == auto_label_status)
         if batch_id:
             where_clauses.append(AdminDocument.batch_id == UUID(batch_id))
+        if category:
+            where_clauses.append(AdminDocument.category == category)
 
         # Count query
         count_stmt = select(func.count()).select_from(AdminDocument)
@@ -283,6 +288,32 @@ class AdminDB:
                 return True
             return False
 
+    def get_document_categories(self) -> list[str]:
+        """Get list of unique document categories."""
+        with get_session_context() as session:
+            statement = (
+                select(AdminDocument.category)
+                .distinct()
+                .order_by(AdminDocument.category)
+            )
+            categories = session.exec(statement).all()
+            return [c for c in categories if c is not None]
+
+    def update_document_category(
+        self, document_id: str, category: str
+    ) -> AdminDocument | None:
+        """Update document category."""
+        with get_session_context() as session:
+            document = session.get(AdminDocument, UUID(document_id))
+            if document:
+                document.category = category
+                document.updated_at = datetime.utcnow()
+                session.add(document)
+                session.commit()
+                session.refresh(document)
+                return document
+            return None
+
     # ==========================================================================
     # Annotation Operations
     # ==========================================================================
@@ -1292,6 +1323,36 @@ class AdminDB:
             session.add(dataset)
             session.commit()
 
+    def update_dataset_training_status(
+        self,
+        dataset_id: str | UUID,
+        training_status: str | None,
+        active_training_task_id: str | UUID | None = None,
+        update_main_status: bool = False,
+    ) -> None:
+        """Update dataset training status and optionally the main status.
+
+        Args:
+            dataset_id: Dataset UUID
+            training_status: Training status (pending, running, completed, failed, cancelled)
+            active_training_task_id: Currently active training task ID
+            update_main_status: If True and training_status is 'completed', set main status to 'trained'
+        """
+        with get_session_context() as session:
+            dataset = session.get(TrainingDataset, UUID(str(dataset_id)))
+            if not dataset:
+                return
+            dataset.training_status = training_status
+            dataset.active_training_task_id = (
+                UUID(str(active_training_task_id)) if active_training_task_id else None
+            )
+            dataset.updated_at = datetime.utcnow()
+            # Update main status to 'trained' when training completes
+            if update_main_status and training_status == "completed":
+                dataset.status = "trained"
+            session.add(dataset)
+            session.commit()
+
     def add_dataset_documents(
         self,
         dataset_id: str | UUID,

@@ -11,23 +11,8 @@ from uuid import UUID, uuid4
 
 from sqlmodel import Field, SQLModel, Column, JSON
 
-# =============================================================================
-# CSV to Field Class Mapping
-# =============================================================================
-
-CSV_TO_CLASS_MAPPING: dict[str, int] = {
-    "InvoiceNumber": 0,  # invoice_number
-    "InvoiceDate": 1,  # invoice_date
-    "InvoiceDueDate": 2,  # invoice_due_date
-    "OCR": 3,  # ocr_number
-    "Bankgiro": 4,  # bankgiro
-    "Plusgiro": 5,  # plusgiro
-    "Amount": 6,  # amount
-    "supplier_organisation_number": 7,  # supplier_organisation_number
-    # 8: payment_line (derived from OCR/Bankgiro/Amount)
-    "customer_number": 9,  # customer_number
-}
-
+# Import field mappings from single source of truth
+from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS
 
 # =============================================================================
@@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True):
     # Link to batch upload (if uploaded via ZIP)
     group_key: str | None = Field(default=None, max_length=255, index=True)
     # User-defined grouping key for document organization
+    category: str = Field(default="invoice", max_length=100, index=True)
+    # Document category for training different models (e.g., invoice, letter, receipt)
     csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON))
     # Original CSV values for reference
     auto_label_queued_at: datetime | None = Field(default=None)
@@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True):
     name: str = Field(max_length=255)
     description: str | None = Field(default=None)
     status: str = Field(default="building", max_length=20, index=True)
-    # Status: building, ready, training, archived, failed
+    # Status: building, ready, trained, archived, failed
+    training_status: str | None = Field(default=None, max_length=20, index=True)
+    # Training status: pending, scheduled, running, completed, failed, cancelled
+    active_training_task_id: UUID | None = Field(default=None, index=True)
     train_ratio: float = Field(default=0.8)
     val_ratio: float = Field(default=0.1)
     seed: int = Field(default=42)
@@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True):
     created_at: datetime = Field(default_factory=datetime.utcnow, index=True)
 
 
-# Field class mapping (same as src/cli/train.py)
-FIELD_CLASSES = {
-    0: "invoice_number",
-    1: "invoice_date",
-    2: "invoice_due_date",
-    3: "ocr_number",
-    4: "bankgiro",
-    5: "plusgiro",
-    6: "amount",
-    7: "supplier_organisation_number",
-    8: "payment_line",
-    9: "customer_number",
-}
-
-FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()}
+# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields
+# This ensures consistency with the trained YOLO model
 
 
 # Read-only models for API responses
@@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel):
     status: str
     auto_label_status: str | None
     auto_label_error: str | None
+    category: str = "invoice"
     created_at: datetime
     updated_at: datetime
 

@@ -141,6 +141,40 @@ def run_migrations() -> None:
             CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id);
             """,
         ),
+        # Migration 009: Add category to admin_documents
+        (
+            "admin_documents_category",
+            """
+            ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
+            UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
+            ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
+            CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
+            """,
+        ),
+        # Migration 010: Add training_status and active_training_task_id to training_datasets
+        (
+            "training_datasets_training_status",
+            """
+            ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
+            ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
+            CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status);
+            CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id);
+            """,
+        ),
+        # Migration 010b: Update existing datasets with completed training to 'trained' status
+        (
+            "training_datasets_update_trained_status",
+            """
+            UPDATE training_datasets d
+            SET status = 'trained'
+            WHERE d.status = 'ready'
+              AND EXISTS (
+                SELECT 1 FROM training_tasks t
+                WHERE t.dataset_id = d.dataset_id
+                  AND t.status = 'completed'
+              );
+            """,
+        ),
     ]
 
     with engine.connect() as conn:

@@ -21,7 +21,8 @@ import re
 import numpy as np
 from PIL import Image
 
-from .yolo_detector import Detection, CLASS_TO_FIELD
+from shared.fields import CLASS_TO_FIELD
+from .yolo_detector import Detection
 
 # Import shared utilities for text cleaning and validation
 from shared.utils.text_cleaner import TextCleaner

@@ -10,7 +10,8 @@ from typing import Any
 import time
 import re
 
-from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
+from shared.fields import CLASS_TO_FIELD
+from .yolo_detector import YOLODetector, Detection
 from .field_extractor import FieldExtractor, ExtractedField
 from .payment_line_parser import PaymentLineParser
 

@@ -9,6 +9,9 @@ from pathlib import Path
 from typing import Any
 import numpy as np
 
+# Import field mappings from single source of truth
+from shared.fields import CLASS_NAMES, CLASS_TO_FIELD
+
 
 @dataclass
 class Detection:
@@ -72,33 +75,8 @@ class Detection:
         return (x0, y0, x1, y1)
 
 
-# Class names (must match training configuration)
-CLASS_NAMES = [
-    'invoice_number',
-    'invoice_date',
-    'invoice_due_date',
-    'ocr_number',
-    'bankgiro',
-    'plusgiro',
-    'amount',
-    'supplier_org_number',  # Matches training class name
-    'customer_number',
-    'payment_line',  # Machine code payment line at bottom of invoice
-]
-
-# Mapping from class name to field name
-CLASS_TO_FIELD = {
-    'invoice_number': 'InvoiceNumber',
-    'invoice_date': 'InvoiceDate',
-    'invoice_due_date': 'InvoiceDueDate',
-    'ocr_number': 'OCR',
-    'bankgiro': 'Bankgiro',
-    'plusgiro': 'Plusgiro',
-    'amount': 'Amount',
-    'supplier_org_number': 'supplier_org_number',
-    'customer_number': 'customer_number',
-    'payment_line': 'payment_line',
-}
+# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields
+# This ensures consistency with the trained YOLO model
 
 
 class YOLODetector:

@@ -4,18 +4,19 @@ Admin Annotation API Routes
 FastAPI endpoints for annotation management.
 """

+import io
 import logging
-from pathlib import Path
 from typing import Annotated
 from uuid import UUID

 from fastapi import APIRouter, HTTPException, Query
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, StreamingResponse

 from inference.data.admin_db import AdminDB
-from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS
+from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS
 from inference.web.core.auth import AdminTokenDep, AdminDBDep
 from inference.web.services.autolabel import get_auto_label_service
+from inference.web.services.storage_helpers import get_storage_helper
 from inference.web.schemas.admin import (
     AnnotationCreate,
     AnnotationItem,
@@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse

 logger = logging.getLogger(__name__)

-# Image storage directory
-ADMIN_IMAGES_DIR = Path("data/admin_images")
-

 def _validate_uuid(value: str, name: str = "ID") -> None:
     """Validate UUID format."""
@@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter:

     @router.get(
         "/{document_id}/images/{page_number}",
+        response_model=None,
         responses={
+            200: {"content": {"image/png": {}}, "description": "Page image"},
             401: {"model": ErrorResponse, "description": "Invalid token"},
             404: {"model": ErrorResponse, "description": "Not found"},
         },
@@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter:
         page_number: int,
         admin_token: AdminTokenDep,
         db: AdminDBDep,
-    ) -> FileResponse:
+    ) -> FileResponse | StreamingResponse:
         """Get page image."""
         _validate_uuid(document_id, "document_id")

@@ -91,18 +91,33 @@ def create_annotation_router() -> APIRouter:
                 detail=f"Page {page_number} not found. Document has {document.page_count} pages.",
             )

-        # Find image file
-        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png"
-        if not image_path.exists():
+        # Get storage helper
+        storage = get_storage_helper()
+
+        # Check if image exists
+        if not storage.admin_image_exists(document_id, page_number):
             raise HTTPException(
                 status_code=404,
                 detail=f"Image for page {page_number} not found",
             )

-        return FileResponse(
-            path=str(image_path),
+        # Try to get local path for efficient file serving
+        local_path = storage.get_admin_image_local_path(document_id, page_number)
+        if local_path is not None:
+            return FileResponse(
+                path=str(local_path),
+                media_type="image/png",
+                filename=f"{document.filename}_page_{page_number}.png",
+            )
+
+        # Fall back to streaming for cloud storage
+        image_content = storage.get_admin_image(document_id, page_number)
+        return StreamingResponse(
+            io.BytesIO(image_content),
             media_type="image/png",
-            filename=f"{document.filename}_page_{page_number}.png",
+            headers={
+                "Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"'
+            },
         )

     # =========================================================================
@@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter:
         )

         # Get image dimensions for normalization
-        image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png"
-        if not image_path.exists():
+        storage = get_storage_helper()
+        dimensions = storage.get_admin_image_dimensions(document_id, request.page_number)
+        if dimensions is None:
             raise HTTPException(
                 status_code=400,
                 detail=f"Image for page {request.page_number} not available",
             )
-        from PIL import Image
-        with Image.open(image_path) as img:
-            image_width, image_height = img.size
+        image_width, image_height = dimensions

         # Calculate normalized coordinates
         x_center = (request.bbox.x + request.bbox.width / 2) / image_width
@@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter:

         if request.bbox is not None:
             # Get image dimensions
-            image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png"
-            from PIL import Image
-            with Image.open(image_path) as img:
-                image_width, image_height = img.size
+            storage = get_storage_helper()
+            dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number)
+            if dimensions is None:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Image for page {annotation.page_number} not available",
+                )
+            image_width, image_height = dimensions

             # Calculate normalized coordinates
             update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width
@@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile

 from inference.web.config import DEFAULT_DPI, StorageConfig
 from inference.web.core.auth import AdminTokenDep, AdminDBDep
+from inference.web.services.storage_helpers import get_storage_helper
 from inference.web.schemas.admin import (
     AnnotationItem,
     AnnotationSource,
     AutoLabelStatus,
     BoundingBox,
+    DocumentCategoriesResponse,
     DocumentDetailResponse,
     DocumentItem,
     DocumentListResponse,
     DocumentStatus,
     DocumentStatsResponse,
+    DocumentUpdateRequest,
     DocumentUploadResponse,
     ModelMetrics,
     TrainingHistoryItem,
@@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None:


 def _convert_pdf_to_images(
-    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
+    document_id: str, content: bytes, page_count: int, dpi: int
 ) -> None:
-    """Convert PDF pages to images for annotation."""
+    """Convert PDF pages to images for annotation using StorageHelper."""
     import fitz

-    doc_images_dir = images_dir / document_id
-    doc_images_dir.mkdir(parents=True, exist_ok=True)
+    storage = get_storage_helper()

     pdf_doc = fitz.open(stream=content, filetype="pdf")

     for page_num in range(page_count):
@@ -60,8 +61,9 @@ def _convert_pdf_to_images(
         mat = fitz.Matrix(dpi / 72, dpi / 72)
         pix = page.get_pixmap(matrix=mat)

-        image_path = doc_images_dir / f"page_{page_num + 1}.png"
-        pix.save(str(image_path))
+        # Save to storage using StorageHelper
+        image_bytes = pix.tobytes("png")
+        storage.save_admin_image(document_id, page_num + 1, image_bytes)

     pdf_doc.close()

@@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             str | None,
             Query(description="Optional group key for document organization", max_length=255),
         ] = None,
+        category: Annotated[
+            str,
+            Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100),
+        ] = "invoice",
     ) -> DocumentUploadResponse:
         """Upload a document for labeling."""
         # Validate group_key length
@@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             file_path="",  # Will update after saving
             page_count=page_count,
             group_key=group_key,
+            category=category,
         )

-        # Save file to admin uploads
-        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
+        # Save file to storage using StorageHelper
+        storage = get_storage_helper()
+        filename = f"{document_id}{file_ext}"
         try:
-            file_path.write_bytes(content)
+            storage_path = storage.save_raw_pdf(content, filename)
         except Exception as e:
             logger.error(f"Failed to save file: {e}")
             raise HTTPException(status_code=500, detail="Failed to save file")

-        # Update file path in database
+        # Update file path in database (using storage path for reference)
         from inference.data.database import get_session_context
         from inference.data.admin_models import AdminDocument
         with get_session_context() as session:
             doc = session.get(AdminDocument, UUID(document_id))
             if doc:
-                doc.file_path = str(file_path)
+                # Store the storage path (relative path within storage)
+                doc.file_path = storage_path
                 session.add(doc)

         # Convert PDF to images for annotation
         if file_ext == ".pdf":
             try:
                 _convert_pdf_to_images(
-                    document_id, content, page_count,
-                    storage_config.admin_images_dir, storage_config.dpi
+                    document_id, content, page_count, storage_config.dpi
                 )
             except Exception as e:
                 logger.error(f"Failed to convert PDF to images: {e}")
@@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             file_size=len(content),
             page_count=page_count,
             status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING,
+            category=category,
             group_key=group_key,
             auto_label_started=auto_label_started,
             message="Document uploaded successfully",
@@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             str | None,
             Query(description="Filter by batch ID"),
         ] = None,
+        category: Annotated[
+            str | None,
+            Query(description="Filter by document category"),
+        ] = None,
         limit: Annotated[
             int,
             Query(ge=1, le=100, description="Page size"),
@@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             has_annotations=has_annotations,
             auto_label_status=auto_label_status,
             batch_id=batch_id,
+            category=category,
             limit=limit,
             offset=offset,
         )
@@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                 upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui",
                 batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None,
                 group_key=doc.group_key if hasattr(doc, 'group_key') else None,
+                category=doc.category if hasattr(doc, 'category') else "invoice",
                 can_annotate=can_annotate,
                 created_at=doc.created_at,
                 updated_at=doc.updated_at,
@@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui",
             batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None,
             group_key=document.group_key if hasattr(document, 'group_key') else None,
+            category=document.category if hasattr(document, 'category') else "invoice",
             csv_field_values=csv_field_values,
             can_annotate=can_annotate,
             annotation_lock_until=annotation_lock_until,
@@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
                 detail="Document not found or does not belong to this token",
             )

-        # Delete file
-        file_path = Path(document.file_path)
-        if file_path.exists():
-            file_path.unlink()
+        # Delete file using StorageHelper
+        storage = get_storage_helper()

-        # Delete images
-        images_dir = ADMIN_IMAGES_DIR / document_id
-        if images_dir.exists():
-            import shutil
-            shutil.rmtree(images_dir)
+        # Delete the raw PDF
+        filename = Path(document.file_path).name
+        if filename:
+            try:
+                storage._storage.delete(document.file_path)
+            except Exception as e:
+                logger.warning(f"Failed to delete PDF file: {e}")
+
+        # Delete admin images
+        try:
+            storage.delete_admin_images(document_id)
+        except Exception as e:
+            logger.warning(f"Failed to delete admin images: {e}")

         # Delete from database
         db.delete_document(document_id)
@@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
             "message": "Document group key updated",
         }

+    @router.get(
+        "/categories",
+        response_model=DocumentCategoriesResponse,
+        responses={
+            401: {"model": ErrorResponse, "description": "Invalid token"},
+        },
+        summary="Get available categories",
+        description="Get list of all available document categories.",
+    )
+    async def get_categories(
+        admin_token: AdminTokenDep,
+        db: AdminDBDep,
+    ) -> DocumentCategoriesResponse:
+        """Get all available document categories."""
+        categories = db.get_document_categories()
+        return DocumentCategoriesResponse(
+            categories=categories,
+            total=len(categories),
+        )
+
+    @router.patch(
+        "/{document_id}/category",
+        responses={
+            401: {"model": ErrorResponse, "description": "Invalid token"},
+            404: {"model": ErrorResponse, "description": "Document not found"},
+        },
+        summary="Update document category",
+        description="Update the category for a document.",
+    )
+    async def update_document_category(
+        document_id: str,
+        admin_token: AdminTokenDep,
+        db: AdminDBDep,
+        request: DocumentUpdateRequest,
+    ) -> dict:
+        """Update document category."""
+        _validate_uuid(document_id, "document_id")
+
+        # Verify document exists
+        document = db.get_document_by_token(document_id, admin_token)
+        if document is None:
+            raise HTTPException(
+                status_code=404,
+                detail="Document not found or does not belong to this token",
+            )
+
+        # Update category if provided
+        if request.category is not None:
+            db.update_document_category(document_id, request.category)
+
+        return {
+            "status": "updated",
+            "document_id": document_id,
+            "category": request.category,
+            "message": "Document category updated",
+        }
+
     return router
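For reviewers who want to poke at the two endpoints added above, a rough client-side sketch follows. It is illustrative only: the base URL, the router mount prefix, and the auth header name are assumptions not defined in this diff, and httpx is used purely as an example HTTP client.

    import httpx  # example client only; not implied as a project dependency by this diff

    BASE = "http://localhost:8000/api/admin/documents"  # assumed mount point
    HEADERS = {"X-Admin-Token": "<token>"}  # assumed header name

    # List available categories
    resp = httpx.get(f"{BASE}/categories", headers=HEADERS)
    print(resp.json())  # expected shape: {"categories": [...], "total": N}

    # Re-categorize one document (placeholder UUID)
    doc_id = "00000000-0000-0000-0000-000000000000"
    resp = httpx.patch(
        f"{BASE}/{doc_id}/category",
        headers=HEADERS,
        json={"category": "receipt"},
    )
    print(resp.json())  # expected shape: {"status": "updated", ...}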
@@ -17,6 +17,7 @@ from inference.web.schemas.admin import (
     TrainingStatus,
     TrainingTaskResponse,
 )
+from inference.web.services.storage_helpers import get_storage_helper

 from ._utils import _validate_uuid

@@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None:
         db: AdminDBDep,
     ) -> DatasetResponse:
         """Create a training dataset from document IDs."""
-        from pathlib import Path
         from inference.web.services.dataset_builder import DatasetBuilder

         # Validate minimum document count for proper train/val/test split
@@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None:
             seed=request.seed,
         )

-        builder = DatasetBuilder(db=db, base_dir=Path("data/datasets"))
+        # Get storage paths from StorageHelper
+        storage = get_storage_helper()
+        datasets_dir = storage.get_datasets_base_path()
+        admin_images_dir = storage.get_admin_images_base_path()
+
+        if datasets_dir is None or admin_images_dir is None:
+            raise HTTPException(
+                status_code=500,
+                detail="Storage not configured for local access",
+            )
+
+        builder = DatasetBuilder(db=db, base_dir=datasets_dir)
         try:
             builder.build_dataset(
                 dataset_id=str(dataset.dataset_id),
@@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None:
                 train_ratio=request.train_ratio,
                 val_ratio=request.val_ratio,
                 seed=request.seed,
-                admin_images_dir=Path("data/admin_images"),
+                admin_images_dir=admin_images_dir,
             )
         except ValueError as e:
             raise HTTPException(status_code=400, detail=str(e))
@@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None:
             name=dataset.name,
             description=dataset.description,
             status=dataset.status,
+            training_status=dataset.training_status,
+            active_training_task_id=(
+                str(dataset.active_training_task_id)
+                if dataset.active_training_task_id
+                else None
+            ),
             train_ratio=dataset.train_ratio,
             val_ratio=dataset.val_ratio,
             seed=dataset.seed,
@@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None:
         db: AdminDBDep,
     ) -> ExportResponse:
         """Export annotations for training."""
-        from pathlib import Path
-        import shutil
+        from inference.web.services.storage_helpers import get_storage_helper
+
+        # Get storage helper for reading images and exports directory
+        storage = get_storage_helper()

         if request.format not in ("yolo", "coco", "voc"):
             raise HTTPException(
@@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None:
                 detail="No labeled documents available for export",
             )

-        export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
+        # Get exports directory from StorageHelper
+        exports_base = storage.get_exports_base_path()
+        if exports_base is None:
+            raise HTTPException(
+                status_code=500,
+                detail="Storage not configured for local access",
+            )
+        export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
         export_dir.mkdir(parents=True, exist_ok=True)

         (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
@@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None:
                 if not page_annotations and not request.include_images:
                     continue

-                src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
-                if not src_image.exists():
+                # Get image from storage
+                doc_id = str(doc.document_id)
+                if not storage.admin_image_exists(doc_id, page_num):
                     continue

+                # Download image and save to export directory
                 image_name = f"{doc.document_id}_page{page_num}.png"
                 dst_image = export_dir / "images" / split / image_name
-                shutil.copy(src_image, dst_image)
+                image_content = storage.get_admin_image(doc_id, page_num)
+                dst_image.write_bytes(image_content)
                 total_images += 1

                 label_name = f"{doc.document_id}_page{page_num}.txt"
@@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None:
                         f.write(line)
                         total_annotations += 1

-        from inference.data.admin_models import FIELD_CLASSES
+        from shared.fields import FIELD_CLASSES

         yaml_content = f"""# Auto-generated YOLO dataset config
 path: {export_dir.absolute()}
@@ -22,6 +22,7 @@ from inference.web.schemas.inference import (
     InferenceResult,
 )
 from inference.web.schemas.common import ErrorResponse
+from inference.web.services.storage_helpers import get_storage_helper

 if TYPE_CHECKING:
     from inference.web.services import InferenceService
@@ -90,8 +91,17 @@ def create_inference_router(
         # Generate document ID
         doc_id = str(uuid.uuid4())[:8]

-        # Save uploaded file
-        upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
+        # Get storage helper and uploads directory
+        storage = get_storage_helper()
+        uploads_dir = storage.get_uploads_base_path(subfolder="inference")
+        if uploads_dir is None:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Storage not configured for local access",
+            )
+
+        # Save uploaded file to temporary location for processing
+        upload_path = uploads_dir / f"{doc_id}{file_ext}"
         try:
             with open(upload_path, "wb") as f:
                 shutil.copyfileobj(file.file, f)
@@ -149,12 +159,13 @@ def create_inference_router(
         # Cleanup uploaded file
         upload_path.unlink(missing_ok=True)

-    @router.get("/results/{filename}")
+    @router.get("/results/{filename}", response_model=None)
     async def get_result_image(filename: str) -> FileResponse:
         """Get visualization result image."""
-        file_path = storage_config.result_dir / filename
+        storage = get_storage_helper()
+        file_path = storage.get_result_local_path(filename)

-        if not file_path.exists():
+        if file_path is None:
             raise HTTPException(
                 status_code=status.HTTP_404_NOT_FOUND,
                 detail=f"Result file not found: {filename}",
@@ -169,15 +180,15 @@ def create_inference_router(
     @router.delete("/results/{filename}")
     async def delete_result(filename: str) -> dict:
         """Delete a result file."""
-        file_path = storage_config.result_dir / filename
+        storage = get_storage_helper()

-        if not file_path.exists():
+        if not storage.result_exists(filename):
             raise HTTPException(
                 status_code=status.HTTP_404_NOT_FOUND,
                 detail=f"Result file not found: {filename}",
             )

-        file_path.unlink()
+        storage.delete_result(filename)
         return {"status": "deleted", "filename": filename}

     return router
@@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
 from inference.data.admin_db import AdminDB
 from inference.web.schemas.labeling import PreLabelResponse
 from inference.web.schemas.common import ErrorResponse
+from inference.web.services.storage_helpers import get_storage_helper

 if TYPE_CHECKING:
     from inference.web.services import InferenceService
@@ -23,19 +24,14 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

-# Storage directory for pre-label uploads (legacy, now uses storage_config)
-PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads")
-
-
 def _convert_pdf_to_images(
-    document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int
+    document_id: str, content: bytes, page_count: int, dpi: int
 ) -> None:
-    """Convert PDF pages to images for annotation."""
+    """Convert PDF pages to images for annotation using StorageHelper."""
     import fitz

-    doc_images_dir = images_dir / document_id
-    doc_images_dir.mkdir(parents=True, exist_ok=True)
+    storage = get_storage_helper()

     pdf_doc = fitz.open(stream=content, filetype="pdf")

     for page_num in range(page_count):
@@ -43,8 +39,9 @@ def _convert_pdf_to_images(
         mat = fitz.Matrix(dpi / 72, dpi / 72)
         pix = page.get_pixmap(matrix=mat)

-        image_path = doc_images_dir / f"page_{page_num + 1}.png"
-        pix.save(str(image_path))
+        # Save to storage using StorageHelper
+        image_bytes = pix.tobytes("png")
+        storage.save_admin_image(document_id, page_num + 1, image_bytes)

     pdf_doc.close()

@@ -70,9 +67,6 @@ def create_labeling_router(
     """
     router = APIRouter(prefix="/api/v1", tags=["labeling"])

-    # Ensure upload directory exists
-    PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
-
     @router.post(
         "/pre-label",
         response_model=PreLabelResponse,
@@ -165,10 +159,11 @@ def create_labeling_router(
             csv_field_values=expected_values,
         )

-        # Save file to admin uploads
-        file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}"
+        # Save file to storage using StorageHelper
+        storage = get_storage_helper()
+        filename = f"{document_id}{file_ext}"
         try:
-            file_path.write_bytes(content)
+            storage_path = storage.save_raw_pdf(content, filename)
         except Exception as e:
             logger.error(f"Failed to save file: {e}")
             raise HTTPException(
@@ -176,15 +171,14 @@ def create_labeling_router(
                 detail="Failed to save file",
             )

-        # Update file path in database
-        db.update_document_file_path(document_id, str(file_path))
+        # Update file path in database (using storage path)
+        db.update_document_file_path(document_id, storage_path)

         # Convert PDF to images for annotation UI
         if file_ext == ".pdf":
             try:
                 _convert_pdf_to_images(
-                    document_id, content, page_count,
-                    storage_config.admin_images_dir, storage_config.dpi
+                    document_id, content, page_count, storage_config.dpi
                 )
             except Exception as e:
                 logger.error(f"Failed to convert PDF to images: {e}")
@@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse

 from .config import AppConfig, default_config
 from inference.web.services import InferenceService
+from inference.web.services.storage_helpers import get_storage_helper

 # Public API imports
 from inference.web.api.v1.public import (
@@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
         allow_headers=["*"],
     )

-    # Mount static files for results
-    config.storage.result_dir.mkdir(parents=True, exist_ok=True)
-    app.mount(
-        "/static/results",
-        StaticFiles(directory=str(config.storage.result_dir)),
-        name="results",
-    )
+    # Mount static files for results using StorageHelper
+    storage = get_storage_helper()
+    results_dir = storage.get_results_base_path()
+    if results_dir:
+        app.mount(
+            "/static/results",
+            StaticFiles(directory=str(results_dir)),
+            name="results",
+        )
+    else:
+        logger.warning("Could not mount static results directory: local storage not available")

     # Include public API routes
     inference_router = create_inference_router(inference_service, config.storage)
@@ -4,16 +4,49 @@ Web Application Configuration
 Centralized configuration for the web application.
 """

+import os
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any

-from shared.config import DEFAULT_DPI, PATHS
+from shared.config import DEFAULT_DPI
+
+if TYPE_CHECKING:
+    from shared.storage.base import StorageBackend
+
+
+def get_storage_backend(
+    config_path: Path | str | None = None,
+) -> "StorageBackend":
+    """Get storage backend for file operations.
+
+    Args:
+        config_path: Optional path to storage configuration file.
+            If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars.
+
+    Returns:
+        Configured StorageBackend instance.
+    """
+    from shared.storage import get_storage_backend as _get_storage_backend
+
+    # Check for config file path
+    if config_path is None:
+        config_path_str = os.environ.get("STORAGE_CONFIG_PATH")
+        if config_path_str:
+            config_path = Path(config_path_str)
+
+    return _get_storage_backend(config_path=config_path)


 @dataclass(frozen=True)
 class ModelConfig:
-    """YOLO model configuration."""
+    """YOLO model configuration.
+
+    Note: Model files are stored locally (not in STORAGE_BASE_PATH) because:
+    - Models need to be accessible by inference service on any platform
+    - Models may be version-controlled or deployed separately
+    - Models are part of the application, not user data
+    """

     model_path: Path = Path("runs/train/invoice_fields/weights/best.pt")
     confidence_threshold: float = 0.5
@@ -33,24 +66,39 @@ class ServerConfig:


 @dataclass(frozen=True)
-class StorageConfig:
-    """File storage configuration.
+class FileConfig:
+    """File handling configuration.

-    Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored
-    directly in raw_pdfs directory. This ensures consistency with CLI autolabel
-    and avoids storing duplicate files.
+    This config holds file handling settings. For file operations,
+    use the storage backend with PREFIXES from shared.storage.prefixes.
+
+    Example:
+        from shared.storage import PREFIXES, get_storage_backend
+
+        storage = get_storage_backend()
+        path = PREFIXES.document_path(document_id)
+        storage.upload_bytes(content, path)
+
+    Note: The path fields (upload_dir, result_dir, etc.) are deprecated.
+    They are kept for backward compatibility with existing code and tests.
+    New code should use the storage backend with PREFIXES instead.
     """

-    upload_dir: Path = Path("uploads")
-    result_dir: Path = Path("results")
-    admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"]))
-    admin_images_dir: Path = Path("data/admin_images")
     max_file_size_mb: int = 50
     allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
     dpi: int = DEFAULT_DPI
+    presigned_url_expiry_seconds: int = 3600
+
+    # Deprecated path fields - kept for backward compatibility
+    # New code should use storage backend with PREFIXES instead
+    # All paths are now under data/ to match WSL storage layout
+    upload_dir: Path = field(default_factory=lambda: Path("data/uploads"))
+    result_dir: Path = field(default_factory=lambda: Path("data/results"))
+    admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs"))
+    admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images"))

     def __post_init__(self) -> None:
-        """Create directories if they don't exist."""
+        """Create directories if they don't exist (for backward compatibility)."""
         object.__setattr__(self, "upload_dir", Path(self.upload_dir))
         object.__setattr__(self, "result_dir", Path(self.result_dir))
         object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir))
@@ -61,9 +109,17 @@ class StorageConfig:
         self.admin_images_dir.mkdir(parents=True, exist_ok=True)


+# Backward compatibility alias
+StorageConfig = FileConfig
+
+
 @dataclass(frozen=True)
 class AsyncConfig:
-    """Async processing configuration."""
+    """Async processing configuration.
+
+    Note: For file paths, use the storage backend with PREFIXES.
+    Example: PREFIXES.upload_path(filename, "async")
+    """

     # Queue settings
     queue_max_size: int = 100
@@ -77,14 +133,17 @@ class AsyncConfig:

     # Storage
     result_retention_days: int = 7
-    temp_upload_dir: Path = Path("uploads/async")
     max_file_size_mb: int = 50

+    # Deprecated: kept for backward compatibility
+    # Path under data/ to match WSL storage layout
+    temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async"))
+
     # Cleanup
     cleanup_interval_hours: int = 1

     def __post_init__(self) -> None:
-        """Create directories if they don't exist."""
+        """Create directories if they don't exist (for backward compatibility)."""
         object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir))
         self.temp_upload_dir.mkdir(parents=True, exist_ok=True)

@@ -95,19 +154,41 @@ class AppConfig:

     model: ModelConfig = field(default_factory=ModelConfig)
     server: ServerConfig = field(default_factory=ServerConfig)
-    storage: StorageConfig = field(default_factory=StorageConfig)
+    file: FileConfig = field(default_factory=FileConfig)
     async_processing: AsyncConfig = field(default_factory=AsyncConfig)
+    storage_backend: "StorageBackend | None" = None
+
+    @property
+    def storage(self) -> FileConfig:
+        """Backward compatibility alias for file config."""
+        return self.file

     @classmethod
     def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
         """Create config from dictionary."""
+        file_config = config_dict.get("file", config_dict.get("storage", {}))
         return cls(
             model=ModelConfig(**config_dict.get("model", {})),
             server=ServerConfig(**config_dict.get("server", {})),
-            storage=StorageConfig(**config_dict.get("storage", {})),
+            file=FileConfig(**file_config),
             async_processing=AsyncConfig(**config_dict.get("async_processing", {})),
         )

+
+def create_app_config(
+    storage_config_path: Path | str | None = None,
+) -> AppConfig:
+    """Create application configuration with storage backend.
+
+    Args:
+        storage_config_path: Optional path to storage configuration file.
+
+    Returns:
+        Configured AppConfig instance with storage backend initialized.
+    """
+    storage_backend = get_storage_backend(config_path=storage_config_path)
+    return AppConfig(storage_backend=storage_backend)
+

 # Default configuration instance
 default_config = AppConfig()

@@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import (
     get_pending_autolabel_documents,
     process_document_autolabel,
 )
+from inference.web.services.storage_helpers import get_storage_helper

 logger = logging.getLogger(__name__)

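A quick orientation for the config changes above (the FileConfig rename, the StorageConfig alias, and create_app_config): the sketch below shows how they are meant to compose at call sites. It is illustrative only; the storage.yaml path is an invented example, and it assumes the inference.web.config module from this diff is importable.

    from pathlib import Path

    from inference.web.config import AppConfig, StorageConfig, create_app_config

    # Old call sites keep working: StorageConfig is now an alias for FileConfig,
    # and AppConfig.storage is a property returning the same object as .file.
    config = AppConfig()
    assert config.storage is config.file
    assert isinstance(config.storage, StorageConfig)

    # New code can build the config together with a storage backend.
    # The config path below is a hypothetical example, not taken from this diff.
    app_config = create_app_config(storage_config_path=Path("config/storage.yaml"))
    print(app_config.file.max_file_size_mb, app_config.file.dpi)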
@@ -36,7 +37,13 @@ class AutoLabelScheduler:
         """
         self._check_interval = check_interval_seconds
         self._batch_size = batch_size
+
+        # Get output directory from StorageHelper
+        if output_dir is None:
+            storage = get_storage_helper()
+            output_dir = storage.get_autolabel_output_path()
         self._output_dir = output_dir or Path("data/autolabel_output")

         self._running = False
         self._thread: threading.Thread | None = None
         self._stop_event = threading.Event()

@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Any

 from inference.data.admin_db import AdminDB
+from inference.web.services.storage_helpers import get_storage_helper

 logger = logging.getLogger(__name__)

@@ -107,6 +108,14 @@ class TrainingScheduler:
         self._db.update_training_task_status(task_id, "running")
         self._db.add_training_log(task_id, "INFO", "Training task started")

+        # Update dataset training status to running
+        if dataset_id:
+            self._db.update_dataset_training_status(
+                dataset_id,
+                training_status="running",
+                active_training_task_id=task_id,
+            )
+
         try:
             # Get training configuration
             model_name = config.get("model_name", "yolo11n.pt")
@@ -192,6 +201,15 @@ class TrainingScheduler:
             )
             self._db.add_training_log(task_id, "INFO", "Training completed successfully")

+            # Update dataset training status to completed and main status to trained
+            if dataset_id:
+                self._db.update_dataset_training_status(
+                    dataset_id,
+                    training_status="completed",
+                    active_training_task_id=None,
+                    update_main_status=True,  # Set main status to 'trained'
+                )
+
             # Auto-create model version for the completed training
             self._create_model_version_from_training(
                 task_id=task_id,
@@ -203,6 +221,13 @@ class TrainingScheduler:
         except Exception as e:
             logger.error(f"Training task {task_id} failed: {e}")
             self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}")
+            # Update dataset training status to failed
+            if dataset_id:
+                self._db.update_dataset_training_status(
+                    dataset_id,
+                    training_status="failed",
+                    active_training_task_id=None,
+                )
             raise

     def _create_model_version_from_training(
@@ -268,9 +293,10 @@ class TrainingScheduler:
                 f"Created model version {version} (ID: {model_version.version_id}) "
                 f"from training task {task_id}"
             )
+            mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
             self._db.add_training_log(
                 task_id, "INFO",
-                f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})",
+                f"Model version {version} created (mAP: {mAP_display})",
             )

         except Exception as e:
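Worth noting for reviewers: the log line replaced above put the conditional inside the f-string's format spec, so everything after the colon was treated as a (bogus) format specifier and the string failed at runtime. A minimal standalone illustration of the problem and the fixed pattern (not project code):

    # Minimal illustration only.
    metrics_mAP = 0.8123

    # Broken form: the whole ".3f if metrics_mAP else 'N/A'" is parsed as the
    # format spec, so evaluating the f-string raises "ValueError: Invalid format specifier".
    # f"mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'}"

    # Fixed pattern from the diff: build the display value first, then interpolate.
    mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A"
    print(f"mAP: {mAP_display}")  # -> mAP: 0.812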
@@ -283,8 +309,11 @@ class TrainingScheduler:
     def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
         """Export training data for a task."""
         from pathlib import Path
-        import shutil
-        from inference.data.admin_models import FIELD_CLASSES
+        from shared.fields import FIELD_CLASSES
+        from inference.web.services.storage_helpers import get_storage_helper
+
+        # Get storage helper for reading images
+        storage = get_storage_helper()

         # Get all labeled documents
         documents = self._db.get_labeled_documents_for_export()
@@ -293,8 +322,12 @@ class TrainingScheduler:
             self._db.add_training_log(task_id, "ERROR", "No labeled documents available")
             return None

-        # Create export directory
-        export_dir = Path("data/training") / task_id
+        # Create export directory using StorageHelper
+        training_base = storage.get_training_data_path()
+        if training_base is None:
+            self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access")
+            return None
+        export_dir = training_base / task_id
         export_dir.mkdir(parents=True, exist_ok=True)

         # YOLO format directories
@@ -323,14 +356,16 @@ class TrainingScheduler:
             for page_num in range(1, doc.page_count + 1):
                 page_annotations = [a for a in annotations if a.page_number == page_num]

-                # Copy image
-                src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png"
-                if not src_image.exists():
+                # Get image from storage
+                doc_id = str(doc.document_id)
+                if not storage.admin_image_exists(doc_id, page_num):
                     continue

+                # Download image and save to export directory
                 image_name = f"{doc.document_id}_page{page_num}.png"
                 dst_image = export_dir / "images" / split / image_name
-                shutil.copy(src_image, dst_image)
+                image_content = storage.get_admin_image(doc_id, page_num)
+                dst_image.write_bytes(image_content)
                 total_images += 1

                 # Write YOLO label
@@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())}
             self._db.add_training_log(task_id, level, message)

         # Create shared training config
+        # Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH)
+        # because models need to be accessible by inference service on any platform
         # Note: workers=0 to avoid multiprocessing issues when running in scheduler thread
         config = SharedTrainingConfig(
             model_path=model_name,
@@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel):
     name: str = Field(..., min_length=1, max_length=255, description="Dataset name")
     description: str | None = Field(None, description="Optional description")
     document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include")
+    category: str | None = Field(None, description="Filter documents by category (optional)")
     train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio")
     val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio")
     seed: int = Field(42, description="Random seed for split")
@@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel):
     name: str
     description: str | None
     status: str
+    training_status: str | None = None
+    active_training_task_id: str | None = None
     train_ratio: float
     val_ratio: float
     seed: int

@@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel):
     file_size: int = Field(..., ge=0, description="File size in bytes")
     page_count: int = Field(..., ge=1, description="Number of pages")
     status: DocumentStatus = Field(..., description="Document status")
+    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
     group_key: str | None = Field(None, description="User-defined group key")
     auto_label_started: bool = Field(
         default=False, description="Whether auto-labeling was started"
@@ -44,6 +45,7 @@ class DocumentItem(BaseModel):
     upload_source: str = Field(default="ui", description="Upload source (ui or api)")
     batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
     group_key: str | None = Field(None, description="User-defined group key")
+    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
     can_annotate: bool = Field(default=True, description="Whether document can be annotated")
     created_at: datetime = Field(..., description="Creation timestamp")
     updated_at: datetime = Field(..., description="Last update timestamp")
@@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel):
     upload_source: str = Field(default="ui", description="Upload source (ui or api)")
     batch_id: str | None = Field(None, description="Batch ID if uploaded via batch")
     group_key: str | None = Field(None, description="User-defined group key")
+    category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)")
     csv_field_values: dict[str, str] | None = Field(
         None, description="CSV field values if uploaded via batch"
     )
@@ -104,3 +107,17 @@ class DocumentStatsResponse(BaseModel):
     auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents")
     labeled: int = Field(default=0, ge=0, description="Labeled documents")
     exported: int = Field(default=0, ge=0, description="Exported documents")
+
+
+class DocumentUpdateRequest(BaseModel):
+    """Request for updating document metadata."""
+
+    category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)")
+    group_key: str | None = Field(None, description="User-defined group key")
+
+
+class DocumentCategoriesResponse(BaseModel):
+    """Response for available document categories."""
+
+    categories: list[str] = Field(..., description="List of available categories")
+    total: int = Field(..., ge=0, description="Total number of categories")
@@ -5,6 +5,7 @@ Manages async request lifecycle and background processing.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -17,6 +18,7 @@ from typing import TYPE_CHECKING
|
|||||||
from inference.data.async_request_db import AsyncRequestDB
|
from inference.data.async_request_db import AsyncRequestDB
|
||||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||||
from inference.web.core.rate_limiter import RateLimiter
|
from inference.web.core.rate_limiter import RateLimiter
|
||||||
|
from inference.web.services.storage_helpers import get_storage_helper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from inference.web.config import AsyncConfig, StorageConfig
|
from inference.web.config import AsyncConfig, StorageConfig
|
||||||
@@ -189,9 +191,7 @@ class AsyncProcessingService:
|
|||||||
filename: str,
|
filename: str,
|
||||||
content: bytes,
|
content: bytes,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Save uploaded file to temp storage."""
|
"""Save uploaded file to temp storage using StorageHelper."""
|
||||||
import re
|
|
||||||
|
|
||||||
# Extract extension from filename
|
# Extract extension from filename
|
||||||
ext = Path(filename).suffix.lower()
|
ext = Path(filename).suffix.lower()
|
||||||
|
|
||||||
@@ -203,9 +203,11 @@ class AsyncProcessingService:
|
|||||||
if ext not in self.ALLOWED_EXTENSIONS:
|
if ext not in self.ALLOWED_EXTENSIONS:
|
||||||
ext = ".pdf"
|
ext = ".pdf"
|
||||||
|
|
||||||
# Create async upload directory
|
# Get upload directory from StorageHelper
|
||||||
upload_dir = self._async_config.temp_upload_dir
|
storage = get_storage_helper()
|
||||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
upload_dir = storage.get_uploads_base_path(subfolder="async")
|
||||||
|
if upload_dir is None:
|
||||||
|
raise ValueError("Storage not configured for local access")
|
||||||
|
|
||||||
# Build file path - request_id is a UUID so it's safe
|
# Build file path - request_id is a UUID so it's safe
|
||||||
file_path = upload_dir / f"{request_id}{ext}"
|
file_path = upload_dir / f"{request_id}{ext}"
|
||||||
@@ -355,8 +357,9 @@ class AsyncProcessingService:
|
|||||||
|
|
||||||
def _cleanup_orphan_files(self) -> int:
|
def _cleanup_orphan_files(self) -> int:
|
||||||
"""Clean up upload files that don't have matching requests."""
|
"""Clean up upload files that don't have matching requests."""
|
||||||
upload_dir = self._async_config.temp_upload_dir
|
storage = get_storage_helper()
|
||||||
if not upload_dir.exists():
|
upload_dir = storage.get_uploads_base_path(subfolder="async")
|
||||||
|
if upload_dir is None or not upload_dir.exists():
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from shared.config import DEFAULT_DPI
|
from shared.config import DEFAULT_DPI
|
||||||
from inference.data.admin_db import AdminDB
|
from inference.data.admin_db import AdminDB
|
||||||
from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES
|
from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES
|
||||||
from shared.matcher.field_matcher import FieldMatcher
|
from shared.matcher.field_matcher import FieldMatcher
|
||||||
from shared.ocr.paddle_ocr import OCREngine, OCRToken
|
from shared.ocr.paddle_ocr import OCREngine, OCRToken
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from uuid import UUID
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from inference.data.admin_db import AdminDB
|
from inference.data.admin_db import AdminDB
|
||||||
from inference.data.admin_models import CSV_TO_CLASS_MAPPING
|
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from inference.data.admin_models import FIELD_CLASSES
|
from shared.fields import FIELD_CLASSES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,10 @@ from typing import Any
|
|||||||
|
|
||||||
from shared.config import DEFAULT_DPI
|
from shared.config import DEFAULT_DPI
|
||||||
from inference.data.admin_db import AdminDB
|
from inference.data.admin_db import AdminDB
|
||||||
from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING
|
from shared.fields import CSV_TO_CLASS_MAPPING
|
||||||
|
from inference.data.admin_models import AdminDocument
|
||||||
from shared.data.db import DocumentDB
|
from shared.data.db import DocumentDB
|
||||||
from inference.web.config import StorageConfig
|
from inference.web.services.storage_helpers import get_storage_helper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -122,8 +123,12 @@ def process_document_autolabel(
|
|||||||
document_id = str(document.document_id)
|
document_id = str(document.document_id)
|
||||||
file_path = Path(document.file_path)
|
file_path = Path(document.file_path)
|
||||||
|
|
||||||
|
# Get output directory from StorageHelper
|
||||||
|
storage = get_storage_helper()
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = Path("data/autolabel_output")
|
output_dir = storage.get_autolabel_output_path()
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = Path("data/autolabel_output")
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Mark as processing
|
# Mark as processing
|
||||||
@@ -152,10 +157,12 @@ def process_document_autolabel(
|
|||||||
is_scanned = len(tokens) < 10 # Threshold for "no text"
|
is_scanned = len(tokens) < 10 # Threshold for "no text"
|
||||||
|
|
||||||
# Build task data
|
# Build task data
|
||||||
# Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path
|
# Use raw_pdfs base path for pdf_path
|
||||||
# This ensures consistency with CLI autolabel for reprocess_failed.py
|
# This ensures consistency with CLI autolabel for reprocess_failed.py
|
||||||
storage_config = StorageConfig()
|
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
|
||||||
pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf"
|
if raw_pdfs_dir is None:
|
||||||
|
raise ValueError("Storage not configured for local access")
|
||||||
|
pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf"
|
||||||
|
|
||||||
task_data = {
|
task_data = {
|
||||||
"row_dict": row_dict,
|
"row_dict": row_dict,
|
||||||
@@ -246,8 +253,8 @@ def _save_annotations_to_db(
|
|||||||
Returns:
|
Returns:
|
||||||
Number of annotations saved
|
Number of annotations saved
|
||||||
"""
|
"""
|
||||||
from PIL import Image
|
from shared.fields import FIELD_CLASS_IDS
|
||||||
from inference.data.admin_models import FIELD_CLASS_IDS
|
from inference.web.services.storage_helpers import get_storage_helper
|
||||||
|
|
||||||
# Mapping from CSV field names to internal field names
|
# Mapping from CSV field names to internal field names
|
||||||
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
|
CSV_TO_INTERNAL_FIELD: dict[str, str] = {
|
||||||
@@ -266,6 +273,9 @@ def _save_annotations_to_db(
|
|||||||
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
|
# Scale factor: PDF points (72 DPI) -> pixels (at configured DPI)
|
||||||
scale = dpi / 72.0
|
scale = dpi / 72.0
|
||||||
|
|
||||||
|
# Get storage helper for image dimensions
|
||||||
|
storage = get_storage_helper()
|
||||||
|
|
||||||
# Cache for image dimensions per page
|
# Cache for image dimensions per page
|
||||||
image_dimensions: dict[int, tuple[int, int]] = {}
|
image_dimensions: dict[int, tuple[int, int]] = {}
|
||||||
|
|
||||||
@@ -274,18 +284,11 @@ def _save_annotations_to_db(
|
|||||||
if page_no in image_dimensions:
|
if page_no in image_dimensions:
|
||||||
return image_dimensions[page_no]
|
return image_dimensions[page_no]
|
||||||
|
|
||||||
# Try to load from admin_images
|
# Get dimensions from storage helper
|
||||||
admin_images_dir = Path("data/admin_images") / document_id
|
dims = storage.get_admin_image_dimensions(document_id, page_no)
|
||||||
image_path = admin_images_dir / f"page_{page_no}.png"
|
if dims:
|
||||||
|
image_dimensions[page_no] = dims
|
||||||
if image_path.exists():
|
return dims
|
||||||
try:
|
|
||||||
with Image.open(image_path) as img:
|
|
||||||
dims = img.size # (width, height)
|
|
||||||
image_dimensions[page_no] = dims
|
|
||||||
return dims
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to read image dimensions from {image_path}: {e}")
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -449,10 +452,17 @@ def save_manual_annotations_to_document_db(
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
document_id = str(document.document_id)
|
document_id = str(document.document_id)
|
||||||
storage_config = StorageConfig()
|
|
||||||
|
|
||||||
# Build pdf_path using admin_upload_dir (same as auto-label)
|
# Build pdf_path using raw_pdfs base path (same as auto-label)
|
||||||
pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf"
|
storage = get_storage_helper()
|
||||||
|
raw_pdfs_dir = storage.get_raw_pdfs_base_path()
|
||||||
|
if raw_pdfs_dir is None:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"document_id": document_id,
|
||||||
|
"error": "Storage not configured for local access",
|
||||||
|
}
|
||||||
|
pdf_path = raw_pdfs_dir / f"{document_id}.pdf"
|
||||||
|
|
||||||
# Build report dict compatible with DocumentDB.save_document()
|
# Build report dict compatible with DocumentDB.save_document()
|
||||||
field_results = []
|
field_results = []
|
||||||
|
|||||||
217
packages/inference/inference/web/services/document_service.py
Normal file
217
packages/inference/inference/web/services/document_service.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""
|
||||||
|
Document Service for storage-backed file operations.
|
||||||
|
|
||||||
|
Provides a unified interface for document upload, download, and serving
|
||||||
|
using the storage abstraction layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocumentResult:
|
||||||
|
"""Result of document operation."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
file_path: str
|
||||||
|
filename: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentService:
|
||||||
|
"""Service for document file operations using storage backend.
|
||||||
|
|
||||||
|
Provides upload, download, and URL generation for documents and images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Storage path prefixes
|
||||||
|
DOCUMENTS_PREFIX = "documents"
|
||||||
|
IMAGES_PREFIX = "images"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_backend: "StorageBackend",
|
||||||
|
admin_db: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize document service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_backend: Storage backend for file operations.
|
||||||
|
admin_db: Optional AdminDB instance for database operations.
|
||||||
|
"""
|
||||||
|
self._storage = storage_backend
|
||||||
|
self._admin_db = admin_db
|
||||||
|
|
||||||
|
def upload_document(
|
||||||
|
self,
|
||||||
|
content: bytes,
|
||||||
|
filename: str,
|
||||||
|
dataset_id: str | None = None,
|
||||||
|
document_id: str | None = None,
|
||||||
|
) -> DocumentResult:
|
||||||
|
"""Upload a document to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Document content as bytes.
|
||||||
|
filename: Original filename.
|
||||||
|
dataset_id: Optional dataset ID for organization.
|
||||||
|
document_id: Optional document ID (generated if not provided).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DocumentResult with ID and storage path.
|
||||||
|
"""
|
||||||
|
if document_id is None:
|
||||||
|
document_id = str(uuid4())
|
||||||
|
|
||||||
|
# Extract extension from filename
|
||||||
|
ext = ""
|
||||||
|
if "." in filename:
|
||||||
|
ext = "." + filename.rsplit(".", 1)[-1].lower()
|
||||||
|
|
||||||
|
# Build logical path
|
||||||
|
remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}"
|
||||||
|
|
||||||
|
# Upload via storage backend
|
||||||
|
self._storage.upload_bytes(content, remote_path, overwrite=True)
|
||||||
|
|
||||||
|
return DocumentResult(
|
||||||
|
id=document_id,
|
||||||
|
file_path=remote_path,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
def download_document(self, remote_path: str) -> bytes:
|
||||||
|
"""Download a document from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Logical path to the document.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document content as bytes.
|
||||||
|
"""
|
||||||
|
return self._storage.download_bytes(remote_path)
|
||||||
|
|
||||||
|
def get_document_url(
|
||||||
|
self,
|
||||||
|
remote_path: str,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get a URL for accessing a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Logical path to the document.
|
||||||
|
expires_in_seconds: URL validity duration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pre-signed URL for document access.
|
||||||
|
"""
|
||||||
|
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
|
||||||
|
|
||||||
|
def document_exists(self, remote_path: str) -> bool:
|
||||||
|
"""Check if a document exists in storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Logical path to the document.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document exists.
|
||||||
|
"""
|
||||||
|
return self._storage.exists(remote_path)
|
||||||
|
|
||||||
|
def delete_document_files(self, remote_path: str) -> bool:
|
||||||
|
"""Delete a document from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Logical path to the document.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document was deleted.
|
||||||
|
"""
|
||||||
|
return self._storage.delete(remote_path)
|
||||||
|
|
||||||
|
def save_page_image(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Save a page image to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document ID.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
content: Image content as bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Logical path where image was stored.
|
||||||
|
"""
|
||||||
|
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||||
|
self._storage.upload_bytes(content, remote_path, overwrite=True)
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def get_page_image_url(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get a URL for accessing a page image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document ID.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
expires_in_seconds: URL validity duration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pre-signed URL for image access.
|
||||||
|
"""
|
||||||
|
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||||
|
return self._storage.get_presigned_url(remote_path, expires_in_seconds)
|
||||||
|
|
||||||
|
def get_page_image(self, document_id: str, page_num: int) -> bytes:
|
||||||
|
"""Download a page image from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document ID.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image content as bytes.
|
||||||
|
"""
|
||||||
|
remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png"
|
||||||
|
return self._storage.download_bytes(remote_path)
|
||||||
|
|
||||||
|
def delete_document_images(self, document_id: str) -> int:
|
||||||
|
"""Delete all images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of images deleted.
|
||||||
|
"""
|
||||||
|
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
|
||||||
|
image_paths = self._storage.list_files(prefix)
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
for path in image_paths:
|
||||||
|
if self._storage.delete(path):
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
def list_document_images(self, document_id: str) -> list[str]:
|
||||||
|
"""List all images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of image paths.
|
||||||
|
"""
|
||||||
|
prefix = f"{self.IMAGES_PREFIX}/{document_id}/"
|
||||||
|
return self._storage.list_files(prefix)
|
||||||
@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from inference.web.services.storage_helpers import get_storage_helper
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .config import ModelConfig, StorageConfig
|
from .config import ModelConfig, StorageConfig
|
||||||
|
|
||||||
@@ -303,12 +305,19 @@ class InferenceService:
|
|||||||
"""Save visualization image with detections."""
|
"""Save visualization image with detections."""
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Get storage helper for results directory
|
||||||
|
storage = get_storage_helper()
|
||||||
|
results_dir = storage.get_results_base_path()
|
||||||
|
if results_dir is None:
|
||||||
|
logger.warning("Cannot save visualization: local storage not available")
|
||||||
|
return None
|
||||||
|
|
||||||
# Load model and run prediction with visualization
|
# Load model and run prediction with visualization
|
||||||
model = YOLO(str(self.model_config.model_path))
|
model = YOLO(str(self.model_config.model_path))
|
||||||
results = model.predict(str(image_path), verbose=False)
|
results = model.predict(str(image_path), verbose=False)
|
||||||
|
|
||||||
# Save annotated image
|
# Save annotated image
|
||||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
output_path = results_dir / f"{doc_id}_result.png"
|
||||||
for r in results:
|
for r in results:
|
||||||
r.save(filename=str(output_path))
|
r.save(filename=str(output_path))
|
||||||
|
|
||||||
@@ -320,19 +329,26 @@ class InferenceService:
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
# Get storage helper for results directory
|
||||||
|
storage = get_storage_helper()
|
||||||
|
results_dir = storage.get_results_base_path()
|
||||||
|
if results_dir is None:
|
||||||
|
logger.warning("Cannot save visualization: local storage not available")
|
||||||
|
return None
|
||||||
|
|
||||||
# Render first page
|
# Render first page
|
||||||
for page_no, image_bytes in render_pdf_to_images(
|
for page_no, image_bytes in render_pdf_to_images(
|
||||||
pdf_path, dpi=self.model_config.dpi
|
pdf_path, dpi=self.model_config.dpi
|
||||||
):
|
):
|
||||||
image = Image.open(io.BytesIO(image_bytes))
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
|
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||||
image.save(temp_path)
|
image.save(temp_path)
|
||||||
|
|
||||||
# Run YOLO and save visualization
|
# Run YOLO and save visualization
|
||||||
model = YOLO(str(self.model_config.model_path))
|
model = YOLO(str(self.model_config.model_path))
|
||||||
results = model.predict(str(temp_path), verbose=False)
|
results = model.predict(str(temp_path), verbose=False)
|
||||||
|
|
||||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
output_path = results_dir / f"{doc_id}_result.png"
|
||||||
for r in results:
|
for r in results:
|
||||||
r.save(filename=str(output_path))
|
r.save(filename=str(output_path))
|
||||||
|
|
||||||
|
|||||||
830
packages/inference/inference/web/services/storage_helpers.py
Normal file
830
packages/inference/inference/web/services/storage_helpers.py
Normal file
@@ -0,0 +1,830 @@
|
|||||||
|
"""
|
||||||
|
Storage helpers for web services.
|
||||||
|
|
||||||
|
Provides convenience functions for common storage operations,
|
||||||
|
wrapping the storage backend with proper path handling using prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.storage import PREFIXES, get_storage_backend
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_storage() -> "StorageBackend":
|
||||||
|
"""Get the default storage backend.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured StorageBackend instance.
|
||||||
|
"""
|
||||||
|
return get_storage_backend()
|
||||||
|
|
||||||
|
|
||||||
|
class StorageHelper:
|
||||||
|
"""Helper class for storage operations with prefixes.
|
||||||
|
|
||||||
|
Provides high-level operations for document storage, including
|
||||||
|
upload, download, and URL generation with proper path prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storage: "StorageBackend | None" = None) -> None:
|
||||||
|
"""Initialize storage helper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage: Storage backend to use. If None, creates default.
|
||||||
|
"""
|
||||||
|
self._storage = storage or get_default_storage()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage(self) -> "StorageBackend":
|
||||||
|
"""Get the underlying storage backend."""
|
||||||
|
return self._storage
|
||||||
|
|
||||||
|
# Document operations
|
||||||
|
|
||||||
|
def upload_document(
|
||||||
|
self,
|
||||||
|
content: bytes,
|
||||||
|
filename: str,
|
||||||
|
document_id: str | None = None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""Upload a document to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Document content as bytes.
|
||||||
|
filename: Original filename (used for extension).
|
||||||
|
document_id: Optional document ID. Generated if not provided.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (document_id, storage_path).
|
||||||
|
"""
|
||||||
|
if document_id is None:
|
||||||
|
document_id = str(uuid4())
|
||||||
|
|
||||||
|
ext = Path(filename).suffix.lower() or ".pdf"
|
||||||
|
path = PREFIXES.document_path(document_id, ext)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
|
||||||
|
return document_id, path
|
||||||
|
|
||||||
|
def download_document(self, document_id: str, extension: str = ".pdf") -> bytes:
|
||||||
|
"""Download a document from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
extension: File extension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document content as bytes.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.document_path(document_id, extension)
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def get_document_url(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
extension: str = ".pdf",
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get presigned URL for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
extension: File extension.
|
||||||
|
expires_in_seconds: URL expiration time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.document_path(document_id, extension)
|
||||||
|
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||||
|
|
||||||
|
def document_exists(self, document_id: str, extension: str = ".pdf") -> bool:
|
||||||
|
"""Check if a document exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
extension: File extension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document exists.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.document_path(document_id, extension)
|
||||||
|
return self._storage.exists(path)
|
||||||
|
|
||||||
|
def delete_document(self, document_id: str, extension: str = ".pdf") -> bool:
|
||||||
|
"""Delete a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
extension: File extension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document was deleted.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.document_path(document_id, extension)
|
||||||
|
return self._storage.delete(path)
|
||||||
|
|
||||||
|
# Image operations
|
||||||
|
|
||||||
|
def save_page_image(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Save a page image to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
content: Image content as bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where image was saved.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.image_path(document_id, page_num)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_page_image(self, document_id: str, page_num: int) -> bytes:
|
||||||
|
"""Download a page image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image content as bytes.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.image_path(document_id, page_num)
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def get_page_image_url(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get presigned URL for a page image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
expires_in_seconds: URL expiration time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.image_path(document_id, page_num)
|
||||||
|
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||||
|
|
||||||
|
def delete_document_images(self, document_id: str) -> int:
|
||||||
|
"""Delete all images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of images deleted.
|
||||||
|
"""
|
||||||
|
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
|
||||||
|
images = self._storage.list_files(prefix)
|
||||||
|
deleted = 0
|
||||||
|
for img_path in images:
|
||||||
|
if self._storage.delete(img_path):
|
||||||
|
deleted += 1
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
def list_document_images(self, document_id: str) -> list[str]:
|
||||||
|
"""List all images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of image paths.
|
||||||
|
"""
|
||||||
|
prefix = f"{PREFIXES.IMAGES}/{document_id}/"
|
||||||
|
return self._storage.list_files(prefix)
|
||||||
|
|
||||||
|
# Upload staging operations
|
||||||
|
|
||||||
|
def save_upload(
|
||||||
|
self,
|
||||||
|
content: bytes,
|
||||||
|
filename: str,
|
||||||
|
subfolder: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Save a file to upload staging area.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content as bytes.
|
||||||
|
filename: Filename to save as.
|
||||||
|
subfolder: Optional subfolder (e.g., "async").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where file was saved.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.upload_path(filename, subfolder)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_upload(self, filename: str, subfolder: str | None = None) -> bytes:
|
||||||
|
"""Get a file from upload staging area.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
subfolder: Optional subfolder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.upload_path(filename, subfolder)
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def delete_upload(self, filename: str, subfolder: str | None = None) -> bool:
|
||||||
|
"""Delete a file from upload staging area.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to delete.
|
||||||
|
subfolder: Optional subfolder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file was deleted.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.upload_path(filename, subfolder)
|
||||||
|
return self._storage.delete(path)
|
||||||
|
|
||||||
|
# Result operations
|
||||||
|
|
||||||
|
def save_result(self, content: bytes, filename: str) -> str:
|
||||||
|
"""Save a result file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content as bytes.
|
||||||
|
filename: Filename to save as.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where file was saved.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_result(self, filename: str) -> bytes:
|
||||||
|
"""Get a result file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str:
|
||||||
|
"""Get presigned URL for a result file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename.
|
||||||
|
expires_in_seconds: URL expiration time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||||
|
|
||||||
|
def result_exists(self, filename: str) -> bool:
|
||||||
|
"""Check if a result file exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file exists.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
return self._storage.exists(path)
|
||||||
|
|
||||||
|
def delete_result(self, filename: str) -> bool:
|
||||||
|
"""Delete a result file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file was deleted.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
return self._storage.delete(path)
|
||||||
|
|
||||||
|
# Export operations
|
||||||
|
|
||||||
|
def save_export(self, content: bytes, export_id: str, filename: str) -> str:
|
||||||
|
"""Save an export file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content as bytes.
|
||||||
|
export_id: Export identifier.
|
||||||
|
filename: Filename to save as.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where file was saved.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.export_path(export_id, filename)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_export_url(
|
||||||
|
self,
|
||||||
|
export_id: str,
|
||||||
|
filename: str,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get presigned URL for an export file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_id: Export identifier.
|
||||||
|
filename: Filename.
|
||||||
|
expires_in_seconds: URL expiration time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.export_path(export_id, filename)
|
||||||
|
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||||
|
|
||||||
|
# Admin image operations
|
||||||
|
|
||||||
|
def get_admin_image_path(self, document_id: str, page_num: int) -> str:
|
||||||
|
"""Get the storage path for an admin image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path like "admin_images/doc123/page_1.png"
|
||||||
|
"""
|
||||||
|
return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png"
|
||||||
|
|
||||||
|
def save_admin_image(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Save an admin page image to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
content: Image content as bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where image was saved.
|
||||||
|
"""
|
||||||
|
path = self.get_admin_image_path(document_id, page_num)
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_admin_image(self, document_id: str, page_num: int) -> bytes:
|
||||||
|
"""Download an admin page image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image content as bytes.
|
||||||
|
"""
|
||||||
|
path = self.get_admin_image_path(document_id, page_num)
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def get_admin_image_url(
|
||||||
|
self,
|
||||||
|
document_id: str,
|
||||||
|
page_num: int,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get presigned URL for an admin page image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
expires_in_seconds: URL expiration time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
path = self.get_admin_image_path(document_id, page_num)
|
||||||
|
return self._storage.get_presigned_url(path, expires_in_seconds)
|
||||||
|
|
||||||
|
def admin_image_exists(self, document_id: str, page_num: int) -> bool:
|
||||||
|
"""Check if an admin page image exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if image exists.
|
||||||
|
"""
|
||||||
|
path = self.get_admin_image_path(document_id, page_num)
|
||||||
|
return self._storage.exists(path)
|
||||||
|
|
||||||
|
def list_admin_images(self, document_id: str) -> list[str]:
|
||||||
|
"""List all admin images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of image paths.
|
||||||
|
"""
|
||||||
|
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
|
||||||
|
return self._storage.list_files(prefix)
|
||||||
|
|
||||||
|
def delete_admin_images(self, document_id: str) -> int:
|
||||||
|
"""Delete all admin images for a document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of images deleted.
|
||||||
|
"""
|
||||||
|
prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/"
|
||||||
|
images = self._storage.list_files(prefix)
|
||||||
|
deleted = 0
|
||||||
|
for img_path in images:
|
||||||
|
if self._storage.delete(img_path):
|
||||||
|
deleted += 1
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
def get_admin_image_local_path(
|
||||||
|
self, document_id: str, page_num: int
|
||||||
|
) -> Path | None:
|
||||||
|
"""Get the local filesystem path for an admin image.
|
||||||
|
|
||||||
|
This method is useful for serving files via FileResponse.
|
||||||
|
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path object if using local storage and file exists, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
# Cloud storage - cannot get local path
|
||||||
|
return None
|
||||||
|
|
||||||
|
remote_path = self.get_admin_image_path(document_id, page_num)
|
||||||
|
try:
|
||||||
|
full_path = self._storage._get_full_path(remote_path)
|
||||||
|
if full_path.exists():
|
||||||
|
return full_path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_admin_image_dimensions(
|
||||||
|
self, document_id: str, page_num: int
|
||||||
|
) -> tuple[int, int] | None:
|
||||||
|
"""Get the dimensions (width, height) of an admin image.
|
||||||
|
|
||||||
|
This method is useful for normalizing bounding box coordinates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id: Document identifier.
|
||||||
|
page_num: Page number (1-indexed).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (width, height) if image exists, None otherwise.
|
||||||
|
"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Try local path first for efficiency
|
||||||
|
local_path = self.get_admin_image_local_path(document_id, page_num)
|
||||||
|
if local_path is not None:
|
||||||
|
with Image.open(local_path) as img:
|
||||||
|
return img.size
|
||||||
|
|
||||||
|
# Fall back to downloading for cloud storage
|
||||||
|
if not self.admin_image_exists(document_id, page_num):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import io
|
||||||
|
image_bytes = self.get_admin_image(document_id, page_num)
|
||||||
|
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||||
|
return img.size
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Raw PDF operations (legacy compatibility)
|
||||||
|
|
||||||
|
def save_raw_pdf(self, content: bytes, filename: str) -> str:
|
||||||
|
"""Save a raw PDF for auto-labeling pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: PDF content as bytes.
|
||||||
|
filename: Filename to save as.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path where file was saved.
|
||||||
|
"""
|
||||||
|
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||||
|
self._storage.upload_bytes(content, path, overwrite=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_raw_pdf(self, filename: str) -> bytes:
|
||||||
|
"""Get a raw PDF from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PDF content as bytes.
|
||||||
|
"""
|
||||||
|
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||||
|
return self._storage.download_bytes(path)
|
||||||
|
|
||||||
|
def raw_pdf_exists(self, filename: str) -> bool:
|
||||||
|
"""Check if a raw PDF exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file exists.
|
||||||
|
"""
|
||||||
|
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||||
|
return self._storage.exists(path)
|
||||||
|
|
||||||
|
def get_raw_pdf_local_path(self, filename: str) -> Path | None:
|
||||||
|
"""Get the local filesystem path for a raw PDF.
|
||||||
|
|
||||||
|
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path object if using local storage and file exists, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
path = f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||||
|
try:
|
||||||
|
full_path = self._storage._get_full_path(path)
|
||||||
|
if full_path.exists():
|
||||||
|
return full_path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_raw_pdf_path(self, filename: str) -> str:
|
||||||
|
"""Get the storage path for a raw PDF (not the local filesystem path).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path like "raw_pdfs/filename.pdf"
|
||||||
|
"""
|
||||||
|
return f"{PREFIXES.RAW_PDFS}/{filename}"
|
||||||
|
|
||||||
|
# Result local path operations
|
||||||
|
|
||||||
|
def get_result_local_path(self, filename: str) -> Path | None:
|
||||||
|
"""Get the local filesystem path for a result file.
|
||||||
|
|
||||||
|
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path object if using local storage and file exists, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
path = PREFIXES.result_path(filename)
|
||||||
|
try:
|
||||||
|
full_path = self._storage._get_full_path(path)
|
||||||
|
if full_path.exists():
|
||||||
|
return full_path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_results_base_path(self) -> Path | None:
|
||||||
|
"""Get the base directory path for results (local storage only).
|
||||||
|
|
||||||
|
Used for mounting static file directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to results directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.RESULTS)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Upload local path operations
|
||||||
|
|
||||||
|
def get_upload_local_path(
|
||||||
|
self, filename: str, subfolder: str | None = None
|
||||||
|
) -> Path | None:
|
||||||
|
"""Get the local filesystem path for an upload file.
|
||||||
|
|
||||||
|
Only works with LocalStorageBackend; returns None for cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to retrieve.
|
||||||
|
subfolder: Optional subfolder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path object if using local storage and file exists, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
path = PREFIXES.upload_path(filename, subfolder)
|
||||||
|
try:
|
||||||
|
full_path = self._storage._get_full_path(path)
|
||||||
|
if full_path.exists():
|
||||||
|
return full_path
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None:
|
||||||
|
"""Get the base directory path for uploads (local storage only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subfolder: Optional subfolder (e.g., "async").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to uploads directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if subfolder:
|
||||||
|
base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}")
|
||||||
|
else:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.UPLOADS)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_exists(self, filename: str, subfolder: str | None = None) -> bool:
|
||||||
|
"""Check if an upload file exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to check.
|
||||||
|
subfolder: Optional subfolder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file exists.
|
||||||
|
"""
|
||||||
|
path = PREFIXES.upload_path(filename, subfolder)
|
||||||
|
return self._storage.exists(path)
|
||||||
|
|
||||||
|
# Dataset operations
|
||||||
|
|
||||||
|
def get_datasets_base_path(self) -> Path | None:
|
||||||
|
"""Get the base directory path for datasets (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to datasets directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.DATASETS)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_admin_images_base_path(self) -> Path | None:
|
||||||
|
"""Get the base directory path for admin images (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to admin_images directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_raw_pdfs_base_path(self) -> Path | None:
|
||||||
|
"""Get the base directory path for raw PDFs (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to raw_pdfs directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_autolabel_output_path(self) -> Path | None:
|
||||||
|
"""Get the directory path for autolabel output (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to autolabel_output directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use a subfolder under results for autolabel output
|
||||||
|
base_path = self._storage._get_full_path("autolabel_output")
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_training_data_path(self) -> Path | None:
|
||||||
|
"""Get the directory path for training data exports (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to training directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path("training")
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_exports_base_path(self) -> Path | None:
|
||||||
|
"""Get the base directory path for exports (local storage only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to exports directory if using local storage, None otherwise.
|
||||||
|
"""
|
||||||
|
if not isinstance(self._storage, LocalStorageBackend):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_path = self._storage._get_full_path(PREFIXES.EXPORTS)
|
||||||
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
return base_path
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Default instance for convenience
|
||||||
|
_default_helper: StorageHelper | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_helper() -> StorageHelper:
|
||||||
|
"""Get the default storage helper instance.
|
||||||
|
|
||||||
|
Creates the helper on first call with default storage backend.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Default StorageHelper instance.
|
||||||
|
"""
|
||||||
|
global _default_helper
|
||||||
|
if _default_helper is None:
|
||||||
|
_default_helper = StorageHelper()
|
||||||
|
return _default_helper
|
||||||
205
packages/shared/README.md
Normal file
205
packages/shared/README.md
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# Shared Package
|
||||||
|
|
||||||
|
Shared utilities and abstractions for the Invoice Master system.
|
||||||
|
|
||||||
|
## Storage Abstraction Layer
|
||||||
|
|
||||||
|
A unified storage abstraction supporting multiple backends:
|
||||||
|
- **Local filesystem** - Development and testing
|
||||||
|
- **Azure Blob Storage** - Azure cloud deployments
|
||||||
|
- **AWS S3** - AWS cloud deployments
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic installation (local storage only)
|
||||||
|
pip install -e packages/shared
|
||||||
|
|
||||||
|
# With Azure support
|
||||||
|
pip install -e "packages/shared[azure]"
|
||||||
|
|
||||||
|
# With S3 support
|
||||||
|
pip install -e "packages/shared[s3]"
|
||||||
|
|
||||||
|
# All cloud providers
|
||||||
|
pip install -e "packages/shared[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from shared.storage import get_storage_backend
|
||||||
|
|
||||||
|
# Option 1: From configuration file
|
||||||
|
storage = get_storage_backend("storage.yaml")
|
||||||
|
|
||||||
|
# Option 2: From environment variables
|
||||||
|
from shared.storage import create_storage_backend_from_env
|
||||||
|
storage = create_storage_backend_from_env()
|
||||||
|
|
||||||
|
# Upload a file
|
||||||
|
storage.upload(Path("local/file.pdf"), "documents/file.pdf")
|
||||||
|
|
||||||
|
# Download a file
|
||||||
|
storage.download("documents/file.pdf", Path("local/downloaded.pdf"))
|
||||||
|
|
||||||
|
# Get pre-signed URL for frontend access
|
||||||
|
url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration File Format
|
||||||
|
|
||||||
|
Create a `storage.yaml` file with environment variable substitution support:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Backend selection: local, azure_blob, or s3
|
||||||
|
backend: ${STORAGE_BACKEND:-local}
|
||||||
|
|
||||||
|
# Default pre-signed URL expiry (seconds)
|
||||||
|
presigned_url_expiry: 3600
|
||||||
|
|
||||||
|
# Local storage configuration
|
||||||
|
local:
|
||||||
|
base_path: ${STORAGE_BASE_PATH:-./data/storage}
|
||||||
|
|
||||||
|
# Azure Blob Storage configuration
|
||||||
|
azure:
|
||||||
|
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
|
||||||
|
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
|
||||||
|
create_container: false
|
||||||
|
|
||||||
|
# AWS S3 configuration
|
||||||
|
s3:
|
||||||
|
bucket_name: ${AWS_S3_BUCKET}
|
||||||
|
region_name: ${AWS_REGION:-us-east-1}
|
||||||
|
access_key_id: ${AWS_ACCESS_KEY_ID}
|
||||||
|
secret_access_key: ${AWS_SECRET_ACCESS_KEY}
|
||||||
|
endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services
|
||||||
|
create_bucket: false
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
| Variable | Backend | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` |
|
||||||
|
| `STORAGE_BASE_PATH` | Local | Base directory path |
|
||||||
|
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string |
|
||||||
|
| `AZURE_STORAGE_CONTAINER` | Azure | Container name |
|
||||||
|
| `AWS_S3_BUCKET` | S3 | Bucket name |
|
||||||
|
| `AWS_REGION` | S3 | AWS region (default: us-east-1) |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) |
|
||||||
|
| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services |
|
||||||
|
|
||||||
|

### API Reference

#### StorageBackend Interface

```python
class StorageBackend(ABC):
    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        """Upload a file to storage."""

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a file from storage."""

    def exists(self, remote_path: str) -> bool:
        """Check if a file exists."""

    def list_files(self, prefix: str) -> list[str]:
        """List files with given prefix."""

    def delete(self, remote_path: str) -> bool:
        """Delete a file."""

    def get_url(self, remote_path: str) -> str:
        """Get URL for a file."""

    def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
        """Generate a pre-signed URL for temporary access (1-604800 seconds)."""

    def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str:
        """Upload bytes directly."""

    def download_bytes(self, remote_path: str) -> bytes:
        """Download file as bytes."""
```

#### Factory Functions

```python
# Create from configuration file
storage = create_storage_backend_from_file("storage.yaml")

# Create from environment variables
storage = create_storage_backend_from_env()

# Create from StorageConfig object
config = StorageConfig(backend_type="local", base_path=Path("./data"))
storage = create_storage_backend(config)

# Convenience function with fallback chain: config file -> env vars -> local default
storage = get_storage_backend("storage.yaml")  # or None for env-only
```

### Pre-signed URLs

Pre-signed URLs provide temporary access to files without exposing credentials:

```python
# Generate URL valid for 1 hour (default)
url = storage.get_presigned_url("documents/invoice.pdf")

# Generate URL valid for 24 hours
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400)

# Maximum expiry: 7 days (604800 seconds)
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800)
```

**Note:** Local storage returns `file://` URLs that don't actually expire.

### Error Handling

```python
from shared.storage import (
    StorageError,
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
)

try:
    storage.download("nonexistent.pdf", Path("local.pdf"))
except FileNotFoundStorageError as e:
    print(f"File not found: {e}")
except StorageError as e:
    print(f"Storage error: {e}")
```

### Testing with MinIO (S3-compatible)

```bash
# Start MinIO locally
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"

# Configure environment
export STORAGE_BACKEND=s3
export AWS_S3_BUCKET=test-bucket
export AWS_ENDPOINT_URL=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
```

### Module Structure

```
shared/storage/
├── __init__.py       # Public exports
├── base.py           # Abstract interface and exceptions
├── local.py          # Local filesystem backend
├── azure.py          # Azure Blob Storage backend
├── s3.py             # AWS S3 backend
├── config_loader.py  # YAML configuration loader
└── factory.py        # Backend factory functions
```

@@ -16,4 +16,18 @@ setup(
        "pyyaml>=6.0",
        "thefuzz>=0.20.0",
    ],
    extras_require={
        "azure": [
            "azure-storage-blob>=12.19.0",
            "azure-identity>=1.15.0",
        ],
        "s3": [
            "boto3>=1.34.0",
        ],
        "all": [
            "azure-storage-blob>=12.19.0",
            "azure-identity>=1.15.0",
            "boto3>=1.34.0",
        ],
    },
)

@@ -58,23 +58,16 @@ def get_db_connection_string():
     return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"


-# Paths Configuration - auto-detect WSL vs Windows
-if _is_wsl():
-    # WSL: use native Linux filesystem for better I/O performance
-    PATHS = {
-        'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
-        'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
-        'output_dir': os.path.expanduser('~/invoice-data/dataset'),
-        'reports_dir': 'reports',  # Keep reports in project directory
-    }
-else:
-    # Windows or native Linux: use relative paths
-    PATHS = {
-        'csv_dir': 'data/structured_data',
-        'pdf_dir': 'data/raw_pdfs',
-        'output_dir': 'data/dataset',
-        'reports_dir': 'reports',
-    }
+# Paths Configuration - uses STORAGE_BASE_PATH for consistency
+# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data)
+_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
+PATHS = {
+    'csv_dir': f'{_storage_base}/structured_data',
+    'pdf_dir': f'{_storage_base}/raw_pdfs',
+    'output_dir': f'{_storage_base}/datasets',
+    'reports_dir': 'reports',  # Keep reports in project directory
+}

 # Auto-labeling Configuration
 AUTOLABEL = {
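
With the new scheme the resolved locations depend only on `STORAGE_BASE_PATH`; a minimal sketch of how the paths come out (the default and the example output are assumptions based on the hunk above):

```python
import os

# Mirrors the config change: an unset STORAGE_BASE_PATH falls back to ~/invoice-data/data.
_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data'))
print(f"{_storage_base}/structured_data")  # e.g. /home/<user>/invoice-data/data/structured_data
print(f"{_storage_base}/raw_pdfs")
print(f"{_storage_base}/datasets")
```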

packages/shared/shared/fields/__init__.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""
Shared Field Definitions - Single Source of Truth.

This module provides centralized field class definitions used throughout
the invoice extraction system. All field mappings are derived from
FIELD_DEFINITIONS to ensure consistency.

Usage:
    from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS

Available exports:
- FieldDefinition: Dataclass for field definition
- FIELD_DEFINITIONS: Tuple of all field definitions (immutable)
- NUM_CLASSES: Total number of field classes (10)
- CLASS_NAMES: List of class names in order [0..9]
- FIELD_CLASSES: dict[int, str] - class_id to class_name
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
"""

from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES
from .mappings import (
    CLASS_NAMES,
    FIELD_CLASSES,
    FIELD_CLASS_IDS,
    CLASS_TO_FIELD,
    CSV_TO_CLASS_MAPPING,
    TRAINING_FIELD_CLASSES,
    ACCOUNT_FIELD_MAPPING,
)

__all__ = [
    "FieldDefinition",
    "FIELD_DEFINITIONS",
    "NUM_CLASSES",
    "CLASS_NAMES",
    "FIELD_CLASSES",
    "FIELD_CLASS_IDS",
    "CLASS_TO_FIELD",
    "CSV_TO_CLASS_MAPPING",
    "TRAINING_FIELD_CLASSES",
    "ACCOUNT_FIELD_MAPPING",
]

packages/shared/shared/fields/field_config.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
Field Configuration - Single Source of Truth

This module defines all invoice field classes used throughout the system.
The class IDs are verified against the trained YOLO model (best.pt).

IMPORTANT: Do not modify class_id values without retraining the model!
"""

from dataclasses import dataclass
from typing import Final


@dataclass(frozen=True)
class FieldDefinition:
    """Immutable field definition for invoice extraction.

    Attributes:
        class_id: YOLO class ID (0-9), must match trained model
        class_name: YOLO class name (lowercase_underscore)
        field_name: Business field name used in API responses
        csv_name: CSV column name for data import/export
        is_derived: True if field is derived from other fields (not in CSV)
    """

    class_id: int
    class_name: str
    field_name: str
    csv_name: str
    is_derived: bool = False


# Verified from model weights (runs/train/invoice_fields/weights/best.pt)
# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'}
#
# DO NOT CHANGE THE ORDER - it must match the trained model!
FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = (
    FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"),
    FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"),
    FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"),
    FieldDefinition(3, "ocr_number", "OCR", "OCR"),
    FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"),
    FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"),
    FieldDefinition(6, "amount", "Amount", "Amount"),
    FieldDefinition(
        7,
        "supplier_org_number",
        "supplier_organisation_number",
        "supplier_organisation_number",
    ),
    FieldDefinition(8, "customer_number", "customer_number", "customer_number"),
    FieldDefinition(
        9, "payment_line", "payment_line", "payment_line", is_derived=True
    ),
)

# Total number of field classes
NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS)

packages/shared/shared/fields/mappings.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""
Field Mappings - Auto-generated from FIELD_DEFINITIONS.

All mappings in this file are derived from field_config.FIELD_DEFINITIONS.
This ensures consistency across the entire codebase.

DO NOT hardcode field mappings elsewhere - always import from this module.
"""

from typing import Final

from .field_config import FIELD_DEFINITIONS


# List of class names in order (for YOLO classes.txt generation)
# Index matches class_id: CLASS_NAMES[0] = "invoice_number"
CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS]

# class_id -> class_name mapping
# Example: {0: "invoice_number", 1: "invoice_date", ...}
FIELD_CLASSES: Final[dict[int, str]] = {
    fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS
}

# class_name -> class_id mapping (reverse of FIELD_CLASSES)
# Example: {"invoice_number": 0, "invoice_date": 1, ...}
FIELD_CLASS_IDS: Final[dict[str, int]] = {
    fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS
}

# class_name -> field_name mapping (for API responses)
# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...}
CLASS_TO_FIELD: Final[dict[str, str]] = {
    fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS
}

# field_name -> class_id mapping (for CSV import)
# Excludes derived fields like payment_line
# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...}
CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = {
    fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived
}

# field_name -> class_id mapping (for training, includes all fields)
# Example: {"InvoiceNumber": 0, ..., "payment_line": 9}
TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
    fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}

# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
    "supplier_accounts": {
        "BG": "Bankgiro",
        "PG": "Plusgiro",
    }
}
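
As a reference for how the derived mappings line up, a quick check in a Python shell might look like this (a sketch; the concrete values follow from the FIELD_DEFINITIONS in field_config.py above):

```python
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS, CSV_TO_CLASS_MAPPING

assert CLASS_NAMES[0] == "invoice_number"           # list index matches class_id
assert FIELD_CLASS_IDS["payment_line"] == 9          # reverse lookup by class name
assert "payment_line" not in CSV_TO_CLASS_MAPPING    # derived fields excluded from CSV import
```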

packages/shared/shared/storage/__init__.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""
Storage abstraction layer for training data.

Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3.
"""

from shared.storage.base import (
    FileNotFoundStorageError,
    PresignedUrlNotSupportedError,
    StorageBackend,
    StorageConfig,
    StorageError,
)
from shared.storage.factory import (
    create_storage_backend,
    create_storage_backend_from_env,
    create_storage_backend_from_file,
    get_default_storage_config,
    get_storage_backend,
)
from shared.storage.local import LocalStorageBackend
from shared.storage.prefixes import PREFIXES, StoragePrefixes

__all__ = [
    # Base classes and exceptions
    "StorageBackend",
    "StorageConfig",
    "StorageError",
    "FileNotFoundStorageError",
    "PresignedUrlNotSupportedError",
    # Backends
    "LocalStorageBackend",
    # Factory functions
    "create_storage_backend",
    "create_storage_backend_from_env",
    "create_storage_backend_from_file",
    "get_default_storage_config",
    "get_storage_backend",
    # Path prefixes
    "PREFIXES",
    "StoragePrefixes",
]


# Lazy imports to avoid dependencies when not using specific backends
def __getattr__(name: str):
    if name == "AzureBlobStorageBackend":
        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend
    if name == "S3StorageBackend":
        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend
    if name == "load_storage_config":
        from shared.storage.config_loader import load_storage_config

        return load_storage_config
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
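
The `__getattr__` hook above keeps the optional SDKs out of the import path until a cloud backend is actually touched; a small sketch of what that looks like from calling code (assumes the relevant extra, e.g. `[azure]`, is installed before the attribute is accessed):

```python
# Importing the package itself does not pull in azure-storage-blob or boto3.
import shared.storage as storage_pkg

# Attribute access triggers the lazy import defined in __getattr__.
AzureBackend = storage_pkg.AzureBlobStorageBackend
```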

packages/shared/shared/storage/azure.py (new file, 335 lines)
@@ -0,0 +1,335 @@
"""
Azure Blob Storage backend.

Provides storage operations using Azure Blob Storage.
"""

from pathlib import Path

from azure.storage.blob import (
    BlobSasPermissions,
    BlobServiceClient,
    ContainerClient,
    generate_blob_sas,
)

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class AzureBlobStorageBackend(StorageBackend):
    """Storage backend using Azure Blob Storage.

    Files are stored as blobs in an Azure Blob Storage container.
    """

    def __init__(
        self,
        connection_string: str,
        container_name: str,
        create_container: bool = False,
    ) -> None:
        """Initialize Azure Blob Storage backend.

        Args:
            connection_string: Azure Storage connection string.
            container_name: Name of the blob container.
            create_container: If True, create the container if it doesn't exist.
        """
        self._connection_string = connection_string
        self._container_name = container_name

        self._blob_service = BlobServiceClient.from_connection_string(connection_string)
        self._container = self._blob_service.get_container_client(container_name)

        # Extract account key from connection string for SAS token generation
        self._account_key = self._extract_account_key(connection_string)

        if create_container and not self._container.exists():
            self._container.create_container()

    @staticmethod
    def _extract_account_key(connection_string: str) -> str | None:
        """Extract account key from connection string.

        Args:
            connection_string: Azure Storage connection string.

        Returns:
            Account key if found, None otherwise.
        """
        for part in connection_string.split(";"):
            if part.startswith("AccountKey="):
                return part[len("AccountKey=") :]
        return None

    @property
    def container_name(self) -> str:
        """Get the container name for this storage backend."""
        return self._container_name

    @property
    def container_client(self) -> ContainerClient:
        """Get the Azure container client."""
        return self._container

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to Azure Blob Storage.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination blob path.
            overwrite: If True, overwrite existing blob.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If blob exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))

        blob_client = self._container.get_blob_client(remote_path)

        if blob_client.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        with open(local_path, "rb") as f:
            blob_client.upload_blob(f, overwrite=overwrite)

        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a blob from Azure Blob Storage.

        Args:
            remote_path: Blob path in storage.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)

        stream = blob_client.download_blob()
        local_path.write_bytes(stream.readall())

        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if a blob exists in storage.

        Args:
            remote_path: Blob path to check.

        Returns:
            True if the blob exists, False otherwise.
        """
        blob_client = self._container.get_blob_client(remote_path)
        return blob_client.exists()

    def list_files(self, prefix: str) -> list[str]:
        """List blobs in storage with given prefix.

        Args:
            prefix: Blob path prefix to filter.

        Returns:
            List of blob paths matching the prefix.
        """
        if prefix:
            blobs = self._container.list_blobs(name_starts_with=prefix)
        else:
            blobs = self._container.list_blobs()

        return [blob.name for blob in blobs]

    def delete(self, remote_path: str) -> bool:
        """Delete a blob from storage.

        Args:
            remote_path: Blob path to delete.

        Returns:
            True if blob was deleted, False if it didn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            return False

        blob_client.delete_blob()
        return True

    def get_url(self, remote_path: str) -> str:
        """Get the URL for a blob.

        Args:
            remote_path: Blob path in storage.

        Returns:
            URL to access the blob.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        return blob_client.url

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to Azure Blob Storage.

        Args:
            data: Bytes to upload.
            remote_path: Destination blob path.
            overwrite: If True, overwrite existing blob.

        Returns:
            The remote path where the data was stored.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if blob_client.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        blob_client.upload_blob(data, overwrite=overwrite)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download a blob as bytes.

        Args:
            remote_path: Blob path in storage.

        Returns:
            The blob contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        stream = blob_client.download_blob()
        return stream.readall()

    def upload_directory(
        self, local_dir: Path, remote_prefix: str, overwrite: bool = False
    ) -> list[str]:
        """Upload all files in a directory to Azure Blob Storage.

        Args:
            local_dir: Local directory to upload.
            remote_prefix: Prefix for remote blob paths.
            overwrite: If True, overwrite existing blobs.

        Returns:
            List of remote paths where files were stored.
        """
        uploaded: list[str] = []

        for file_path in local_dir.rglob("*"):
            if file_path.is_file():
                relative_path = file_path.relative_to(local_dir)
                remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/")
                self.upload(file_path, remote_path, overwrite=overwrite)
                uploaded.append(remote_path)

        return uploaded

    def download_directory(
        self, remote_prefix: str, local_dir: Path
    ) -> list[Path]:
        """Download all blobs with a prefix to a local directory.

        Args:
            remote_prefix: Blob path prefix to download.
            local_dir: Local directory to download to.

        Returns:
            List of local paths where files were downloaded.
        """
        downloaded: list[Path] = []

        blobs = self.list_files(remote_prefix)

        for blob_path in blobs:
            # Remove prefix to get relative path
            if remote_prefix:
                relative_path = blob_path[len(remote_prefix):]
                if relative_path.startswith("/"):
                    relative_path = relative_path[1:]
            else:
                relative_path = blob_path

            local_path = local_dir / relative_path
            self.download(blob_path, local_path)
            downloaded.append(local_path)

        return downloaded

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a SAS URL for temporary blob access.

        Args:
            remote_path: Blob path in storage.
            expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days).

        Returns:
            Blob URL with SAS token.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )

        from datetime import datetime, timedelta, timezone

        blob_client = self._container.get_blob_client(remote_path)

        if not blob_client.exists():
            raise FileNotFoundStorageError(remote_path)

        # Generate SAS token
        sas_token = generate_blob_sas(
            account_name=self._blob_service.account_name,
            container_name=self._container_name,
            blob_name=remote_path,
            account_key=self._account_key,
            permission=BlobSasPermissions(read=True),
            expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds),
        )

        return f"{blob_client.url}?{sas_token}"
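
For local development this backend can be pointed at the Azurite emulator; a hedged sketch (the development-storage shorthand is an assumption about your local setup, and SAS URLs need a full connection string with an `AccountKey`, which the shorthand does not provide):

```python
from pathlib import Path
from shared.storage.azure import AzureBlobStorageBackend

# Azurite's well-known development shorthand; swap in a real connection string in production.
backend = AzureBlobStorageBackend(
    connection_string="UseDevelopmentStorage=true",
    container_name="documents",
    create_container=True,
)
backend.upload(Path("local/file.pdf"), "documents/file.pdf", overwrite=True)
print(backend.get_url("documents/file.pdf"))
```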

packages/shared/shared/storage/base.py (new file, 229 lines)
@@ -0,0 +1,229 @@
"""
Base classes and interfaces for storage backends.

Defines the abstract StorageBackend interface and common exceptions.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path


class StorageError(Exception):
    """Base exception for storage operations."""

    pass


class FileNotFoundStorageError(StorageError):
    """Raised when a file is not found in storage."""

    def __init__(self, path: str) -> None:
        self.path = path
        super().__init__(f"File not found in storage: {path}")


class PresignedUrlNotSupportedError(StorageError):
    """Raised when pre-signed URLs are not supported by a backend."""

    def __init__(self, backend_type: str) -> None:
        self.backend_type = backend_type
        super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}")


@dataclass(frozen=True)
class StorageConfig:
    """Configuration for storage backend.

    Attributes:
        backend_type: Type of storage backend ("local", "azure_blob", or "s3").
        connection_string: Azure Blob Storage connection string (for azure_blob).
        container_name: Azure Blob Storage container name (for azure_blob).
        base_path: Base path for local storage (for local).
        bucket_name: S3 bucket name (for s3).
        region_name: AWS region name (for s3).
        access_key_id: AWS access key ID (for s3).
        secret_access_key: AWS secret access key (for s3).
        endpoint_url: Custom endpoint URL for S3-compatible services (for s3).
        presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
    """

    backend_type: str
    connection_string: str | None = None
    container_name: str | None = None
    base_path: Path | None = None
    bucket_name: str | None = None
    region_name: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    endpoint_url: str | None = None
    presigned_url_expiry: int = 3600


class StorageBackend(ABC):
    """Abstract base class for storage backends.

    Provides a unified interface for storing and retrieving files
    from different storage systems (local filesystem, Azure Blob, etc.).
    """

    @abstractmethod
    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to storage.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite existing file.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If file exists and overwrite is False.
        """
        pass

    @abstractmethod
    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a file from storage.

        Args:
            remote_path: Path to the file in storage.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        pass

    @abstractmethod
    def exists(self, remote_path: str) -> bool:
        """Check if a file exists in storage.

        Args:
            remote_path: Path to check in storage.

        Returns:
            True if the file exists, False otherwise.
        """
        pass

    @abstractmethod
    def list_files(self, prefix: str) -> list[str]:
        """List files in storage with given prefix.

        Args:
            prefix: Path prefix to filter files.

        Returns:
            List of file paths matching the prefix.
        """
        pass

    @abstractmethod
    def delete(self, remote_path: str) -> bool:
        """Delete a file from storage.

        Args:
            remote_path: Path to the file to delete.

        Returns:
            True if file was deleted, False if it didn't exist.
        """
        pass

    @abstractmethod
    def get_url(self, remote_path: str) -> str:
        """Get a URL or path to access a file.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            URL or path to access the file.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        pass

    @abstractmethod
    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary access.

        Args:
            remote_path: Path to the file in storage.
            expires_in_seconds: URL validity duration (default 1 hour).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs.
        """
        pass

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to storage.

        Default implementation writes to temp file then uploads.
        Subclasses may override for more efficient implementation.

        Args:
            data: Bytes to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite existing file.

        Returns:
            The remote path where the data was stored.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(data)
            temp_path = Path(f.name)

        try:
            return self.upload(temp_path, remote_path, overwrite=overwrite)
        finally:
            temp_path.unlink(missing_ok=True)

    def download_bytes(self, remote_path: str) -> bytes:
        """Download a file as bytes.

        Default implementation downloads to temp file then reads.
        Subclasses may override for more efficient implementation.

        Args:
            remote_path: Path to the file in storage.

        Returns:
            The file contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False) as f:
            temp_path = Path(f.name)

        try:
            self.download(remote_path, temp_path)
            return temp_path.read_bytes()
        finally:
            temp_path.unlink(missing_ok=True)
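
Because `upload_bytes` and `download_bytes` have temp-file defaults, any concrete backend gets a bytes round-trip for free; a minimal sketch using the local backend (a throwaway temporary directory is assumed):

```python
import tempfile
from shared.storage.local import LocalStorageBackend

with tempfile.TemporaryDirectory() as tmp:
    backend = LocalStorageBackend(base_path=tmp)
    backend.upload_bytes(b"hello", "notes/hello.txt")   # base-class default via a temp file
    assert backend.download_bytes("notes/hello.txt") == b"hello"
```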

packages/shared/shared/storage/config_loader.py (new file, 242 lines)
@@ -0,0 +1,242 @@
"""
Configuration file loader for storage backends.

Supports YAML configuration files with environment variable substitution.
"""

import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml


@dataclass(frozen=True)
class LocalConfig:
    """Local storage backend configuration."""

    base_path: Path


@dataclass(frozen=True)
class AzureConfig:
    """Azure Blob Storage configuration."""

    connection_string: str
    container_name: str
    create_container: bool = False


@dataclass(frozen=True)
class S3Config:
    """AWS S3 configuration."""

    bucket_name: str
    region_name: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    endpoint_url: str | None = None
    create_bucket: bool = False


@dataclass(frozen=True)
class StorageFileConfig:
    """Extended storage configuration from file.

    Attributes:
        backend_type: Type of storage backend.
        local: Local backend configuration.
        azure: Azure Blob configuration.
        s3: S3 configuration.
        presigned_url_expiry: Default expiry for pre-signed URLs in seconds.
    """

    backend_type: str
    local: LocalConfig | None = None
    azure: AzureConfig | None = None
    s3: S3Config | None = None
    presigned_url_expiry: int = 3600


def substitute_env_vars(value: str) -> str:
    """Substitute environment variables in a string.

    Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax.

    Args:
        value: String potentially containing env var references.

    Returns:
        String with env vars substituted.
    """
    pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}"

    def replace(match: re.Match[str]) -> str:
        var_name = match.group(1)
        default = match.group(2)
        return os.environ.get(var_name, default or "")

    return re.sub(pattern, replace, value)


def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]:
    """Recursively substitute env vars in a dictionary.

    Args:
        data: Dictionary to process.

    Returns:
        New dictionary with substitutions applied.
    """
    result: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, str):
            result[key] = substitute_env_vars(value)
        elif isinstance(value, dict):
            result[key] = _substitute_in_dict(value)
        elif isinstance(value, list):
            result[key] = [
                substitute_env_vars(item) if isinstance(item, str) else item
                for item in value
            ]
        else:
            result[key] = value
    return result


def _parse_local_config(data: dict[str, Any]) -> LocalConfig:
    """Parse local configuration section.

    Args:
        data: Dictionary containing local config.

    Returns:
        LocalConfig instance.

    Raises:
        ValueError: If required fields are missing.
    """
    base_path = data.get("base_path")
    if not base_path:
        raise ValueError("local.base_path is required")
    return LocalConfig(base_path=Path(base_path))


def _parse_azure_config(data: dict[str, Any]) -> AzureConfig:
    """Parse Azure configuration section.

    Args:
        data: Dictionary containing Azure config.

    Returns:
        AzureConfig instance.

    Raises:
        ValueError: If required fields are missing.
    """
    connection_string = data.get("connection_string")
    container_name = data.get("container_name")

    if not connection_string:
        raise ValueError("azure.connection_string is required")
    if not container_name:
        raise ValueError("azure.container_name is required")

    return AzureConfig(
        connection_string=connection_string,
        container_name=container_name,
        create_container=data.get("create_container", False),
    )


def _parse_s3_config(data: dict[str, Any]) -> S3Config:
    """Parse S3 configuration section.

    Args:
        data: Dictionary containing S3 config.

    Returns:
        S3Config instance.

    Raises:
        ValueError: If required fields are missing.
    """
    bucket_name = data.get("bucket_name")

    if not bucket_name:
        raise ValueError("s3.bucket_name is required")

    return S3Config(
        bucket_name=bucket_name,
        region_name=data.get("region_name"),
        access_key_id=data.get("access_key_id"),
        secret_access_key=data.get("secret_access_key"),
        endpoint_url=data.get("endpoint_url"),
        create_bucket=data.get("create_bucket", False),
    )


def load_storage_config(config_path: Path | str) -> StorageFileConfig:
    """Load storage configuration from YAML file.

    Supports environment variable substitution using ${VAR_NAME} or
    ${VAR_NAME:-default} syntax.

    Args:
        config_path: Path to configuration file.

    Returns:
        Parsed StorageFileConfig.

    Raises:
        FileNotFoundError: If config file doesn't exist.
        ValueError: If config is invalid.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    try:
        raw_content = config_path.read_text(encoding="utf-8")
        data = yaml.safe_load(raw_content)
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in config file: {e}") from e

    if not isinstance(data, dict):
        raise ValueError("Config file must contain a YAML dictionary")

    # Substitute environment variables
    data = _substitute_in_dict(data)

    # Extract backend type
    backend_type = data.get("backend")
    if not backend_type:
        raise ValueError("'backend' field is required in config file")

    # Parse presigned URL expiry
    presigned_url_expiry = data.get("presigned_url_expiry", 3600)

    # Parse backend-specific configurations
    local_config = None
    azure_config = None
    s3_config = None

    if "local" in data:
        local_config = _parse_local_config(data["local"])

    if "azure" in data:
        azure_config = _parse_azure_config(data["azure"])

    if "s3" in data:
        s3_config = _parse_s3_config(data["s3"])

    return StorageFileConfig(
        backend_type=backend_type,
        local=local_config,
        azure=azure_config,
        s3=s3_config,
        presigned_url_expiry=presigned_url_expiry,
    )
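
To illustrate the substitution syntax end to end, a small sketch (the file path and the variable values are illustrative only):

```python
import os
from pathlib import Path
from shared.storage.config_loader import load_storage_config

os.environ["STORAGE_BACKEND"] = "local"          # consumed by ${STORAGE_BACKEND:-local}
Path("storage.yaml").write_text(
    "backend: ${STORAGE_BACKEND:-local}\n"
    "local:\n"
    "  base_path: ${STORAGE_BASE_PATH:-./data/storage}\n"
)

cfg = load_storage_config("storage.yaml")
print(cfg.backend_type)      # local
print(cfg.local.base_path)   # data/storage (default applied because STORAGE_BASE_PATH is unset)
```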

packages/shared/shared/storage/factory.py (new file, 296 lines)
@@ -0,0 +1,296 @@
"""
Factory functions for creating storage backends.

Provides convenient functions for creating storage backends from
configuration or environment variables.
"""

import os
from pathlib import Path

from shared.storage.base import StorageBackend, StorageConfig


def create_storage_backend(config: StorageConfig) -> StorageBackend:
    """Create a storage backend from configuration.

    Args:
        config: Storage configuration.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If configuration is invalid.
    """
    if config.backend_type == "local":
        if config.base_path is None:
            raise ValueError("base_path is required for local storage backend")

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=config.base_path)

    elif config.backend_type == "azure_blob":
        if config.connection_string is None:
            raise ValueError(
                "connection_string is required for Azure blob storage backend"
            )
        if config.container_name is None:
            raise ValueError(
                "container_name is required for Azure blob storage backend"
            )

        # Import here to allow lazy loading of Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=config.connection_string,
            container_name=config.container_name,
        )

    elif config.backend_type == "s3":
        if config.bucket_name is None:
            raise ValueError("bucket_name is required for S3 storage backend")

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=config.bucket_name,
            region_name=config.region_name,
            access_key_id=config.access_key_id,
            secret_access_key=config.secret_access_key,
            endpoint_url=config.endpoint_url,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {config.backend_type}")


def get_default_storage_config() -> StorageConfig:
    """Get storage configuration from environment variables.

    Environment variables:
        STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
        STORAGE_BASE_PATH: Base path for local storage.
        AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
        AZURE_STORAGE_CONTAINER: Azure container name.
        AWS_S3_BUCKET: S3 bucket name.
        AWS_REGION: AWS region name.
        AWS_ACCESS_KEY_ID: AWS access key ID.
        AWS_SECRET_ACCESS_KEY: AWS secret access key.
        AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.

    Returns:
        StorageConfig from environment.
    """
    backend_type = os.environ.get("STORAGE_BACKEND", "local")

    if backend_type == "local":
        base_path_str = os.environ.get("STORAGE_BASE_PATH")
        # Expand ~ to home directory
        base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None

        return StorageConfig(
            backend_type="local",
            base_path=base_path,
        )

    elif backend_type == "azure_blob":
        return StorageConfig(
            backend_type="azure_blob",
            connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"),
            container_name=os.environ.get("AZURE_STORAGE_CONTAINER"),
        )

    elif backend_type == "s3":
        return StorageConfig(
            backend_type="s3",
            bucket_name=os.environ.get("AWS_S3_BUCKET"),
            region_name=os.environ.get("AWS_REGION"),
            access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
            secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
            endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
        )

    else:
        return StorageConfig(backend_type=backend_type)


def create_storage_backend_from_env() -> StorageBackend:
    """Create a storage backend from environment variables.

    Environment variables:
        STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local".
        STORAGE_BASE_PATH: Base path for local storage.
        AZURE_STORAGE_CONNECTION_STRING: Azure connection string.
        AZURE_STORAGE_CONTAINER: Azure container name.
        AWS_S3_BUCKET: S3 bucket name.
        AWS_REGION: AWS region name.
        AWS_ACCESS_KEY_ID: AWS access key ID.
        AWS_SECRET_ACCESS_KEY: AWS secret access key.
        AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If required environment variables are missing or empty.
    """
    backend_type = os.environ.get("STORAGE_BACKEND", "local").strip()

    if backend_type == "local":
        base_path = os.environ.get("STORAGE_BASE_PATH", "").strip()
        if not base_path:
            raise ValueError(
                "STORAGE_BASE_PATH environment variable is required and cannot be empty"
            )

        # Expand ~ to home directory
        base_path_expanded = os.path.expanduser(base_path)

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=Path(base_path_expanded))

    elif backend_type == "azure_blob":
        connection_string = os.environ.get(
            "AZURE_STORAGE_CONNECTION_STRING", ""
        ).strip()
        if not connection_string:
            raise ValueError(
                "AZURE_STORAGE_CONNECTION_STRING environment variable is required "
                "and cannot be empty"
            )

        container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip()
        if not container_name:
            raise ValueError(
                "AZURE_STORAGE_CONTAINER environment variable is required "
                "and cannot be empty"
            )

        # Import here to allow lazy loading of Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name=container_name,
        )

    elif backend_type == "s3":
        bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip()
        if not bucket_name:
            raise ValueError(
                "AWS_S3_BUCKET environment variable is required and cannot be empty"
            )

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=bucket_name,
            region_name=os.environ.get("AWS_REGION", "").strip() or None,
            access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None,
            secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip()
            or None,
            endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {backend_type}")


def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend:
    """Create a storage backend from a configuration file.

    Args:
        config_path: Path to YAML configuration file.

    Returns:
        A configured storage backend.

    Raises:
        FileNotFoundError: If config file doesn't exist.
        ValueError: If configuration is invalid.
    """
    from shared.storage.config_loader import load_storage_config

    file_config = load_storage_config(config_path)

    if file_config.backend_type == "local":
        if file_config.local is None:
            raise ValueError("local configuration section is required")

        from shared.storage.local import LocalStorageBackend

        return LocalStorageBackend(base_path=file_config.local.base_path)

    elif file_config.backend_type == "azure_blob":
        if file_config.azure is None:
            raise ValueError("azure configuration section is required")

        # Import here to allow lazy loading of Azure SDK
        from azure.storage.blob import BlobServiceClient  # noqa: F401

        from shared.storage.azure import AzureBlobStorageBackend

        return AzureBlobStorageBackend(
            connection_string=file_config.azure.connection_string,
            container_name=file_config.azure.container_name,
            create_container=file_config.azure.create_container,
        )

    elif file_config.backend_type == "s3":
        if file_config.s3 is None:
            raise ValueError("s3 configuration section is required")

        # Import here to allow lazy loading of boto3
        import boto3  # noqa: F401

        from shared.storage.s3 import S3StorageBackend

        return S3StorageBackend(
            bucket_name=file_config.s3.bucket_name,
            region_name=file_config.s3.region_name,
            access_key_id=file_config.s3.access_key_id,
            secret_access_key=file_config.s3.secret_access_key,
            endpoint_url=file_config.s3.endpoint_url,
            create_bucket=file_config.s3.create_bucket,
        )

    else:
        raise ValueError(f"Unknown storage backend type: {file_config.backend_type}")


def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend:
    """Get storage backend with fallback chain.

    Priority:
    1. Config file (if provided)
    2. Environment variables

    Args:
        config_path: Optional path to config file.

    Returns:
        A configured storage backend.

    Raises:
        ValueError: If configuration is invalid.
        FileNotFoundError: If specified config file doesn't exist.
    """
    if config_path:
        return create_storage_backend_from_file(config_path)

    # Fall back to environment variables
    return create_storage_backend_from_env()
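
Tying the factory functions together, the fallback chain behaves roughly as sketched below (the path and variables are illustrative):

```python
from shared.storage.factory import get_storage_backend

# With a config file: the file's settings win.
storage = get_storage_backend("storage.yaml")

# Without one: environment variables decide (STORAGE_BACKEND defaults to "local",
# which then requires STORAGE_BASE_PATH to be set).
storage = get_storage_backend()
```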

packages/shared/shared/storage/local.py (new file, 262 lines)
@@ -0,0 +1,262 @@
"""
Local filesystem storage backend.

Provides storage operations using the local filesystem.
"""

import shutil
from pathlib import Path

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class LocalStorageBackend(StorageBackend):
    """Storage backend using local filesystem.

    Files are stored relative to a base path on the local filesystem.
    """

    def __init__(self, base_path: str | Path) -> None:
        """Initialize local storage backend.

        Args:
            base_path: Base directory for all storage operations.
                Will be created if it doesn't exist.
        """
        self._base_path = Path(base_path)
        self._base_path.mkdir(parents=True, exist_ok=True)

    @property
    def base_path(self) -> Path:
        """Get the base path for this storage backend."""
        return self._base_path

    def _get_full_path(self, remote_path: str) -> Path:
        """Convert a remote path to a full local path with security validation.

        Args:
            remote_path: The remote path to resolve.

        Returns:
            The full local path.

        Raises:
            StorageError: If the path attempts to escape the base directory.
        """
        # Reject absolute paths
        if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
            raise StorageError(f"Absolute paths not allowed: {remote_path}")

        # Resolve to prevent path traversal attacks
        full_path = (self._base_path / remote_path).resolve()
        base_resolved = self._base_path.resolve()

        # Verify the resolved path is within base_path
        try:
            full_path.relative_to(base_resolved)
        except ValueError:
            raise StorageError(f"Path traversal not allowed: {remote_path}")

        return full_path

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to local storage.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination path in storage.
            overwrite: If True, overwrite existing file.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If file exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))

        dest_path = self._get_full_path(remote_path)

        if dest_path.exists() and not overwrite:
            raise StorageError(f"File already exists: {remote_path}")

        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_path, dest_path)

        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download a file from local storage.

        Args:
            remote_path: Path to the file in storage.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        source_path = self._get_full_path(remote_path)

        if not source_path.exists():
            raise FileNotFoundStorageError(remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source_path, local_path)

        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if a file exists in storage.

        Args:
            remote_path: Path to check in storage.

        Returns:
            True if the file exists, False otherwise.
        """
        return self._get_full_path(remote_path).exists()

    def list_files(self, prefix: str) -> list[str]:
        """List files in storage with given prefix.

        Args:
            prefix: Path prefix to filter files.

        Returns:
            Sorted list of file paths matching the prefix.
        """
        if prefix:
            search_path = self._get_full_path(prefix)
            if not search_path.exists():
                return []
            base_for_relative = self._base_path
        else:
            search_path = self._base_path
            base_for_relative = self._base_path

        files: list[str] = []
|
||||||
|
if search_path.is_file():
|
||||||
|
files.append(str(search_path.relative_to(self._base_path)))
|
||||||
|
elif search_path.is_dir():
|
||||||
|
for file_path in search_path.rglob("*"):
|
||||||
|
if file_path.is_file():
|
||||||
|
relative_path = file_path.relative_to(self._base_path)
|
||||||
|
files.append(str(relative_path).replace("\\", "/"))
|
||||||
|
|
||||||
|
return sorted(files)
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
"""Delete a file from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Path to the file to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file was deleted, False if it didn't exist.
|
||||||
|
"""
|
||||||
|
file_path = self._get_full_path(remote_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
file_path.unlink()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
"""Get a file:// URL to access a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Path to the file in storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
file:// URL to access the file.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||||
|
"""
|
||||||
|
file_path = self._get_full_path(remote_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundStorageError(remote_path)
|
||||||
|
|
||||||
|
return file_path.as_uri()
|
||||||
|
|
||||||
|
def upload_bytes(
|
||||||
|
self, data: bytes, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""Upload bytes directly to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Bytes to upload.
|
||||||
|
remote_path: Destination path in storage.
|
||||||
|
overwrite: If True, overwrite existing file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The remote path where the data was stored.
|
||||||
|
"""
|
||||||
|
dest_path = self._get_full_path(remote_path)
|
||||||
|
|
||||||
|
if dest_path.exists() and not overwrite:
|
||||||
|
raise StorageError(f"File already exists: {remote_path}")
|
||||||
|
|
||||||
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
dest_path.write_bytes(data)
|
||||||
|
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download_bytes(self, remote_path: str) -> bytes:
|
||||||
|
"""Download a file as bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Path to the file in storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The file contents as bytes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||||
|
"""
|
||||||
|
file_path = self._get_full_path(remote_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundStorageError(remote_path)
|
||||||
|
|
||||||
|
return file_path.read_bytes()
|
||||||
|
|
||||||
|
def get_presigned_url(
|
||||||
|
self,
|
||||||
|
remote_path: str,
|
||||||
|
expires_in_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
"""Get a file:// URL for local file access.
|
||||||
|
|
||||||
|
For local storage, this returns a file:// URI.
|
||||||
|
Note: Local file:// URLs don't actually expire.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_path: Path to the file in storage.
|
||||||
|
expires_in_seconds: Ignored for local storage (URLs don't expire).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
file:// URL to access the file.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundStorageError: If remote_path doesn't exist.
|
||||||
|
"""
|
||||||
|
file_path = self._get_full_path(remote_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundStorageError(remote_path)
|
||||||
|
|
||||||
|
return file_path.as_uri()
|
||||||
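
A short usage sketch of LocalStorageBackend as defined above, including the path-traversal guard in _get_full_path; the paths and payload are illustrative only.

import tempfile

from shared.storage.base import StorageError
from shared.storage.local import LocalStorageBackend

with tempfile.TemporaryDirectory() as root:
    backend = LocalStorageBackend(root)
    backend.upload_bytes(b"invoice data", "documents/abc123.pdf")
    assert backend.exists("documents/abc123.pdf")
    assert backend.list_files("documents") == ["documents/abc123.pdf"]
    try:
        backend.upload_bytes(b"x", "../outside.txt")  # resolves outside base_path
    except StorageError:
        pass  # rejected by the traversal check, as intended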
158  packages/shared/shared/storage/prefixes.py  Normal file
@@ -0,0 +1,158 @@
"""
Storage path prefixes for unified file organization.

Provides standardized path prefixes for organizing files within
the storage backend, ensuring consistent structure across
local, Azure Blob, and S3 storage.
"""

from dataclasses import dataclass


@dataclass(frozen=True)
class StoragePrefixes:
    """Standardized storage path prefixes.

    All paths are relative to the storage backend root.
    These prefixes ensure consistent file organization across
    all storage backends (local, Azure, S3).

    Usage:
        from shared.storage.prefixes import PREFIXES

        path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf"
        storage.upload_bytes(content, path)
    """

    # Document storage
    DOCUMENTS: str = "documents"
    """Original document files (PDFs, etc.)"""

    IMAGES: str = "images"
    """Page images extracted from documents"""

    # Processing directories
    UPLOADS: str = "uploads"
    """Temporary upload staging area"""

    RESULTS: str = "results"
    """Inference results and visualizations"""

    EXPORTS: str = "exports"
    """Exported datasets and annotations"""

    # Training data
    DATASETS: str = "datasets"
    """Training dataset files"""

    MODELS: str = "models"
    """Trained model weights and checkpoints"""

    # Data pipeline directories (legacy compatibility)
    RAW_PDFS: str = "raw_pdfs"
    """Raw PDF files for auto-labeling pipeline"""

    STRUCTURED_DATA: str = "structured_data"
    """CSV/structured data for matching"""

    ADMIN_IMAGES: str = "admin_images"
    """Admin UI page images"""

    @staticmethod
    def document_path(document_id: str, extension: str = ".pdf") -> str:
        """Get path for a document file.

        Args:
            document_id: Unique document identifier.
            extension: File extension (include leading dot).

        Returns:
            Storage path like "documents/abc123.pdf"
        """
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}"

    @staticmethod
    def image_path(document_id: str, page_num: int, extension: str = ".png") -> str:
        """Get path for a page image file.

        Args:
            document_id: Unique document identifier.
            page_num: Page number (1-indexed).
            extension: File extension (include leading dot).

        Returns:
            Storage path like "images/abc123/page_1.png"
        """
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}"

    @staticmethod
    def upload_path(filename: str, subfolder: str | None = None) -> str:
        """Get path for a temporary upload file.

        Args:
            filename: Original filename.
            subfolder: Optional subfolder (e.g., "async").

        Returns:
            Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf"
        """
        if subfolder:
            return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}"
        return f"{PREFIXES.UPLOADS}/{filename}"

    @staticmethod
    def result_path(filename: str) -> str:
        """Get path for a result file.

        Args:
            filename: Result filename.

        Returns:
            Storage path like "results/filename.json"
        """
        return f"{PREFIXES.RESULTS}/{filename}"

    @staticmethod
    def export_path(export_id: str, filename: str) -> str:
        """Get path for an export file.

        Args:
            export_id: Unique export identifier.
            filename: Export filename.

        Returns:
            Storage path like "exports/abc123/filename.zip"
        """
        return f"{PREFIXES.EXPORTS}/{export_id}/{filename}"

    @staticmethod
    def dataset_path(dataset_id: str, filename: str) -> str:
        """Get path for a dataset file.

        Args:
            dataset_id: Unique dataset identifier.
            filename: Dataset filename.

        Returns:
            Storage path like "datasets/abc123/filename.yaml"
        """
        return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}"

    @staticmethod
    def model_path(version: str, filename: str) -> str:
        """Get path for a model file.

        Args:
            version: Model version string.
            filename: Model filename.

        Returns:
            Storage path like "models/v1.0.0/best.pt"
        """
        return f"{PREFIXES.MODELS}/{version}/{filename}"


# Default instance for convenient access
PREFIXES = StoragePrefixes()
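
A quick sketch of the key-building helpers above; the document IDs and filenames are made up for illustration.

from shared.storage.prefixes import PREFIXES, StoragePrefixes

assert StoragePrefixes.document_path("abc123") == "documents/abc123.pdf"
assert StoragePrefixes.image_path("abc123", 1) == "images/abc123/page_1.png"
assert StoragePrefixes.upload_path("inv.pdf", subfolder="async") == "uploads/async/inv.pdf"
assert StoragePrefixes.model_path("v1.0.0", "best.pt") == "models/v1.0.0/best.pt"
# The same helpers are also reachable through the shared instance, e.g. PREFIXES.document_path(...).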
309  packages/shared/shared/storage/s3.py  Normal file
@@ -0,0 +1,309 @@
"""
AWS S3 Storage backend.

Provides storage operations using AWS S3.
"""

from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from mypy_boto3_s3 import S3Client

from shared.storage.base import (
    FileNotFoundStorageError,
    StorageBackend,
    StorageError,
)


class S3StorageBackend(StorageBackend):
    """Storage backend using AWS S3.

    Files are stored as objects in an S3 bucket.
    """

    def __init__(
        self,
        bucket_name: str,
        region_name: str | None = None,
        access_key_id: str | None = None,
        secret_access_key: str | None = None,
        endpoint_url: str | None = None,
        create_bucket: bool = False,
    ) -> None:
        """Initialize S3 storage backend.

        Args:
            bucket_name: Name of the S3 bucket.
            region_name: AWS region name (optional, uses default if not set).
            access_key_id: AWS access key ID (optional, uses credentials chain).
            secret_access_key: AWS secret access key (optional).
            endpoint_url: Custom endpoint URL (for S3-compatible services).
            create_bucket: If True, create the bucket if it doesn't exist.
        """
        import boto3

        self._bucket_name = bucket_name
        self._region_name = region_name

        # Build client kwargs
        client_kwargs: dict[str, Any] = {}
        if region_name:
            client_kwargs["region_name"] = region_name
        if endpoint_url:
            client_kwargs["endpoint_url"] = endpoint_url
        if access_key_id and secret_access_key:
            client_kwargs["aws_access_key_id"] = access_key_id
            client_kwargs["aws_secret_access_key"] = secret_access_key

        self._s3: "S3Client" = boto3.client("s3", **client_kwargs)

        if create_bucket:
            self._ensure_bucket_exists()

    def _ensure_bucket_exists(self) -> None:
        """Create the bucket if it doesn't exist."""
        from botocore.exceptions import ClientError

        try:
            self._s3.head_bucket(Bucket=self._bucket_name)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchBucket"):
                # Bucket doesn't exist, create it
                create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
                if self._region_name and self._region_name != "us-east-1":
                    create_kwargs["CreateBucketConfiguration"] = {
                        "LocationConstraint": self._region_name
                    }
                self._s3.create_bucket(**create_kwargs)
            else:
                # Re-raise permission errors, network issues, etc.
                raise

    def _object_exists(self, key: str) -> bool:
        """Check if an object exists in S3.

        Args:
            key: Object key to check.

        Returns:
            True if object exists, False otherwise.
        """
        from botocore.exceptions import ClientError

        try:
            self._s3.head_object(Bucket=self._bucket_name, Key=key)
            return True
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                return False
            raise

    @property
    def bucket_name(self) -> str:
        """Get the bucket name for this storage backend."""
        return self._bucket_name

    def upload(
        self, local_path: Path, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload a file to S3.

        Args:
            local_path: Path to the local file to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite existing object.

        Returns:
            The remote path where the file was stored.

        Raises:
            FileNotFoundStorageError: If local_path doesn't exist.
            StorageError: If object exists and overwrite is False.
        """
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))

        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")

        self._s3.upload_file(str(local_path), self._bucket_name, remote_path)

        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        """Download an object from S3.

        Args:
            remote_path: Object key in S3.
            local_path: Local destination path.

        Returns:
            The local path where the file was downloaded.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        local_path.parent.mkdir(parents=True, exist_ok=True)

        self._s3.download_file(self._bucket_name, remote_path, str(local_path))

        return local_path

    def exists(self, remote_path: str) -> bool:
        """Check if an object exists in S3.

        Args:
            remote_path: Object key to check.

        Returns:
            True if the object exists, False otherwise.
        """
        return self._object_exists(remote_path)

    def list_files(self, prefix: str) -> list[str]:
        """List objects in S3 with given prefix.

        Handles pagination to return all matching objects (S3 returns max 1000 per request).

        Args:
            prefix: Object key prefix to filter.

        Returns:
            List of object keys matching the prefix.
        """
        kwargs: dict[str, Any] = {"Bucket": self._bucket_name}
        if prefix:
            kwargs["Prefix"] = prefix

        all_keys: list[str] = []
        while True:
            response = self._s3.list_objects_v2(**kwargs)
            contents = response.get("Contents", [])
            all_keys.extend(obj["Key"] for obj in contents)

            if not response.get("IsTruncated"):
                break
            kwargs["ContinuationToken"] = response["NextContinuationToken"]

        return all_keys

    def delete(self, remote_path: str) -> bool:
        """Delete an object from S3.

        Args:
            remote_path: Object key to delete.

        Returns:
            True if object was deleted, False if it didn't exist.
        """
        if not self._object_exists(remote_path):
            return False

        self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path)
        return True

    def get_url(self, remote_path: str) -> str:
        """Get a URL for an object.

        Args:
            remote_path: Object key in S3.

        Returns:
            URL to access the object.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=3600,
        )

    def get_presigned_url(
        self,
        remote_path: str,
        expires_in_seconds: int = 3600,
    ) -> str:
        """Generate a pre-signed URL for temporary object access.

        Args:
            remote_path: Object key in S3.
            expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days).

        Returns:
            Pre-signed URL string.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
            ValueError: If expires_in_seconds is out of valid range.
        """
        if expires_in_seconds < 1 or expires_in_seconds > 604800:
            raise ValueError(
                "expires_in_seconds must be between 1 and 604800 (7 days)"
            )

        if not self._object_exists(remote_path):
            raise FileNotFoundStorageError(remote_path)

        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket_name, "Key": remote_path},
            ExpiresIn=expires_in_seconds,
        )

    def upload_bytes(
        self, data: bytes, remote_path: str, overwrite: bool = False
    ) -> str:
        """Upload bytes directly to S3.

        Args:
            data: Bytes to upload.
            remote_path: Destination object key.
            overwrite: If True, overwrite existing object.

        Returns:
            The remote path where the data was stored.

        Raises:
            StorageError: If object exists and overwrite is False.
        """
        if not overwrite and self._object_exists(remote_path):
            raise StorageError(f"File already exists: {remote_path}")

        self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data)

        return remote_path

    def download_bytes(self, remote_path: str) -> bytes:
        """Download an object as bytes.

        Args:
            remote_path: Object key in S3.

        Returns:
            The object contents as bytes.

        Raises:
            FileNotFoundStorageError: If remote_path doesn't exist.
        """
        from botocore.exceptions import ClientError

        try:
            response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path)
            return response["Body"].read()
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code in ("404", "NoSuchKey"):
                raise FileNotFoundStorageError(remote_path) from e
            raise
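
A sketch of wiring S3StorageBackend (above) to an S3-compatible endpoint such as MinIO. The endpoint, bucket name, and credentials are placeholders, not values from this repository.

from shared.storage.s3 import S3StorageBackend

backend = S3StorageBackend(
    bucket_name="invoice-master",            # placeholder
    region_name="eu-north-1",                # placeholder
    access_key_id="PLACEHOLDER_KEY",
    secret_access_key="PLACEHOLDER_SECRET",
    endpoint_url="http://localhost:9000",    # MinIO / S3-compatible service
    create_bucket=True,
)
backend.upload_bytes(b"report", "results/report.json", overwrite=True)
# Time-limited link, capped at 7 days by the validation in get_presigned_url.
url = backend.get_presigned_url("results/report.json", expires_in_seconds=900)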
@@ -20,7 +20,7 @@ from shared.config import get_db_connection_string
 from shared.normalize import normalize_field
 from shared.matcher import FieldMatcher
 from shared.pdf import is_text_pdf, extract_text_tokens
-from training.yolo.annotation_generator import FIELD_CLASSES
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
 from shared.data.db import DocumentDB


@@ -113,7 +113,7 @@ def process_single_document(args_tuple):
     # Import inside worker to avoid pickling issues
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     start_time = time.time()

@@ -342,7 +342,8 @@ def main():
     from shared.ocr import OCREngine
     from shared.matcher import FieldMatcher
     from shared.normalize import normalize_field
-    from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+    from training.yolo.annotation_generator import AnnotationGenerator

     # Handle comma-separated CSV paths
     csv_input = args.csv

@@ -90,7 +90,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     import shutil
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     row_dict = task_data["row_dict"]

@@ -208,7 +208,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
     import shutil
     from training.data.autolabel_report import AutoLabelReport
     from shared.pdf import PDFDocument
-    from training.yolo.annotation_generator import FIELD_CLASSES
+    from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
     from training.processing.document_processor import process_page, record_unmatched_fields

     row_dict = task_data["row_dict"]

@@ -15,7 +15,8 @@ from training.data.autolabel_report import FieldMatchResult
 from shared.matcher import FieldMatcher
 from shared.normalize import normalize_field
 from shared.ocr.machine_code_parser import MachineCodeParser
-from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+from training.yolo.annotation_generator import AnnotationGenerator


 def match_supplier_accounts(

@@ -9,43 +9,12 @@ from pathlib import Path
 from typing import Any
 import csv

-# Field class mapping for YOLO
-# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
-FIELD_CLASSES = {
-    'InvoiceNumber': 0,
-    'InvoiceDate': 1,
-    'InvoiceDueDate': 2,
-    'OCR': 3,
-    'Bankgiro': 4,
-    'Plusgiro': 5,
-    'Amount': 6,
-    'supplier_organisation_number': 7,
-    'customer_number': 8,
-    'payment_line': 9,  # Machine code payment line at bottom of invoice
-}
-
-# Fields that need matching but map to other YOLO classes
-# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
-ACCOUNT_FIELD_MAPPING = {
-    'supplier_accounts': {
-        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
-        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
-    }
-}
-
-CLASS_NAMES = [
-    'invoice_number',
-    'invoice_date',
-    'invoice_due_date',
-    'ocr_number',
-    'bankgiro',
-    'plusgiro',
-    'amount',
-    'supplier_org_number',
-    'customer_number',
-    'payment_line',  # Machine code payment line at bottom of invoice
-]
+# Import field mappings from single source of truth
+from shared.fields import (
+    TRAINING_FIELD_CLASSES as FIELD_CLASSES,
+    CLASS_NAMES,
+    ACCOUNT_FIELD_MAPPING,
+)


 @dataclass

@@ -101,7 +101,8 @@ class DatasetBuilder:
         Returns:
             DatasetStats with build results
         """
-        from .annotation_generator import AnnotationGenerator, CLASS_NAMES
+        from shared.fields import CLASS_NAMES
+        from .annotation_generator import AnnotationGenerator

         random.seed(seed)

@@ -18,7 +18,8 @@ import numpy as np
 from PIL import Image

 from shared.config import DEFAULT_DPI
-from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
+from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
+from .annotation_generator import YOLOAnnotation

 logger = logging.getLogger(__name__)
73  run_migration.py  Normal file
@@ -0,0 +1,73 @@
"""Run database migration for training_status fields."""
import psycopg2
import os

# Read password from .env file
password = ""
try:
    with open(".env") as f:
        for line in f:
            if line.startswith("DB_PASSWORD="):
                password = line.strip().split("=", 1)[1].strip('"').strip("'")
                break
except Exception as e:
    print(f"Error reading .env: {e}")

print(f"Password found: {bool(password)}")

conn = psycopg2.connect(
    host="192.168.68.31",
    port=5432,
    database="docmaster",
    user="docmaster",
    password=password
)
conn.autocommit = True
cur = conn.cursor()

# Add training_status column
try:
    cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL")
    print("Added training_status column")
except Exception as e:
    print(f"training_status: {e}")

# Add active_training_task_id column
try:
    cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL")
    print("Added active_training_task_id column")
except Exception as e:
    print(f"active_training_task_id: {e}")

# Create indexes
try:
    cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)")
    print("Created training_status index")
except Exception as e:
    print(f"index training_status: {e}")

try:
    cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)")
    print("Created active_training_task_id index")
except Exception as e:
    print(f"index active_training_task_id: {e}")

# Update existing datasets that have been used in completed training tasks to trained status
try:
    cur.execute("""
        UPDATE training_datasets d
        SET status = 'trained'
        WHERE d.status = 'ready'
        AND EXISTS (
            SELECT 1 FROM training_tasks t
            WHERE t.dataset_id = d.dataset_id
            AND t.status = 'completed'
        )
    """)
    print(f"Updated {cur.rowcount} datasets to trained status")
except Exception as e:
    print(f"update status: {e}")

cur.close()
conn.close()
print("Migration complete!")
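
An optional variant, not part of the diff: PostgreSQL's ADD COLUMN IF NOT EXISTS makes reruns of the script above silent instead of relying on the try/except blocks around the ALTER TABLE calls.

# Sketch only; assumes the same `cur` cursor as in run_migration.py.
cur.execute(
    "ALTER TABLE training_datasets "
    "ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL"
)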
@@ -17,9 +17,8 @@ from inference.data.admin_models import (
     AdminDocument,
     AdminAnnotation,
     TrainingTask,
-    FIELD_CLASSES,
-    CSV_TO_CLASS_MAPPING,
 )
+from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING


 class TestBatchUpload:

@@ -507,7 +506,10 @@ class TestCSVToClassMapping:
         assert len(CSV_TO_CLASS_MAPPING) > 0

     def test_csv_mapping_values(self):
-        """Test specific CSV column mappings."""
+        """Test specific CSV column mappings.
+
+        Note: customer_number is class 8 (verified from trained model best.pt).
+        """
         assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0
         assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1
         assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2

@@ -516,7 +518,7 @@ class TestCSVToClassMapping:
         assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5
         assert CSV_TO_CLASS_MAPPING["Amount"] == 6
         assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7
-        assert CSV_TO_CLASS_MAPPING["customer_number"] == 9
+        assert CSV_TO_CLASS_MAPPING["customer_number"] == 8  # Fixed: was 9, model uses 8

     def test_csv_mapping_matches_field_classes(self):
         """Test that CSV mapping is consistent with FIELD_CLASSES."""
1  tests/shared/fields/__init__.py  Normal file
@@ -0,0 +1 @@
"""Tests for shared.fields module."""
200  tests/shared/fields/test_field_config.py  Normal file
@@ -0,0 +1,200 @@
"""
Tests for field configuration - Single Source of Truth.

These tests ensure consistency across all field definitions and prevent
accidental changes that could break model inference.

CRITICAL: These tests verify that field definitions match the trained YOLO model.
If these tests fail, it likely means someone modified field IDs incorrectly.
"""

import pytest

from shared.fields import (
    FIELD_DEFINITIONS,
    CLASS_NAMES,
    FIELD_CLASSES,
    FIELD_CLASS_IDS,
    CLASS_TO_FIELD,
    CSV_TO_CLASS_MAPPING,
    TRAINING_FIELD_CLASSES,
    NUM_CLASSES,
    FieldDefinition,
)


class TestFieldDefinitionsIntegrity:
    """Tests to ensure field definitions are complete and consistent."""

    def test_exactly_10_field_definitions(self):
        """Verify we have exactly 10 field classes (matching trained model)."""
        assert len(FIELD_DEFINITIONS) == 10
        assert NUM_CLASSES == 10

    def test_class_ids_are_sequential(self):
        """Verify class IDs are 0-9 without gaps."""
        class_ids = {fd.class_id for fd in FIELD_DEFINITIONS}
        assert class_ids == set(range(10))

    def test_class_ids_are_unique(self):
        """Verify no duplicate class IDs."""
        class_ids = [fd.class_id for fd in FIELD_DEFINITIONS]
        assert len(class_ids) == len(set(class_ids))

    def test_class_names_are_unique(self):
        """Verify no duplicate class names."""
        class_names = [fd.class_name for fd in FIELD_DEFINITIONS]
        assert len(class_names) == len(set(class_names))

    def test_field_definition_is_immutable(self):
        """Verify FieldDefinition is frozen (immutable)."""
        fd = FIELD_DEFINITIONS[0]
        with pytest.raises(AttributeError):
            fd.class_id = 99  # type: ignore


class TestModelCompatibility:
    """Tests to verify field definitions match the trained YOLO model.

    These exact values are read from runs/train/invoice_fields/weights/best.pt
    and MUST NOT be changed without retraining the model.
    """

    # Expected model.names from best.pt - DO NOT CHANGE
    EXPECTED_MODEL_NAMES = {
        0: "invoice_number",
        1: "invoice_date",
        2: "invoice_due_date",
        3: "ocr_number",
        4: "bankgiro",
        5: "plusgiro",
        6: "amount",
        7: "supplier_org_number",
        8: "customer_number",
        9: "payment_line",
    }

    def test_field_classes_match_model(self):
        """CRITICAL: Verify FIELD_CLASSES matches trained model exactly."""
        assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES

    def test_class_names_order_matches_model(self):
        """CRITICAL: Verify CLASS_NAMES order matches model class IDs."""
        expected_order = [
            self.EXPECTED_MODEL_NAMES[i] for i in range(10)
        ]
        assert CLASS_NAMES == expected_order

    def test_customer_number_is_class_8(self):
        """CRITICAL: customer_number must be class 8 (not 9)."""
        assert FIELD_CLASS_IDS["customer_number"] == 8
        assert FIELD_CLASSES[8] == "customer_number"

    def test_payment_line_is_class_9(self):
        """CRITICAL: payment_line must be class 9 (not 8)."""
        assert FIELD_CLASS_IDS["payment_line"] == 9
        assert FIELD_CLASSES[9] == "payment_line"


class TestMappingConsistency:
    """Tests to verify all mappings are consistent with each other."""

    def test_field_classes_and_field_class_ids_are_inverses(self):
        """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses."""
        for class_id, class_name in FIELD_CLASSES.items():
            assert FIELD_CLASS_IDS[class_name] == class_id

        for class_name, class_id in FIELD_CLASS_IDS.items():
            assert FIELD_CLASSES[class_id] == class_name

    def test_class_names_matches_field_classes_values(self):
        """Verify CLASS_NAMES list matches FIELD_CLASSES values in order."""
        for i, class_name in enumerate(CLASS_NAMES):
            assert FIELD_CLASSES[i] == class_name

    def test_class_to_field_has_all_classes(self):
        """Verify CLASS_TO_FIELD has mapping for all class names."""
        for class_name in CLASS_NAMES:
            assert class_name in CLASS_TO_FIELD

    def test_csv_mapping_excludes_derived_fields(self):
        """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line."""
        # payment_line is derived, should not be in CSV mapping
        assert "payment_line" not in CSV_TO_CLASS_MAPPING

        # All non-derived fields should be in CSV mapping
        for fd in FIELD_DEFINITIONS:
            if not fd.is_derived:
                assert fd.field_name in CSV_TO_CLASS_MAPPING

    def test_training_field_classes_includes_all(self):
        """Verify TRAINING_FIELD_CLASSES includes all fields including derived."""
        for fd in FIELD_DEFINITIONS:
            assert fd.field_name in TRAINING_FIELD_CLASSES
            assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id


class TestSpecificFieldDefinitions:
    """Tests for specific field definitions to catch common mistakes."""

    @pytest.mark.parametrize(
        "class_id,expected_class_name",
        [
            (0, "invoice_number"),
            (1, "invoice_date"),
            (2, "invoice_due_date"),
            (3, "ocr_number"),
            (4, "bankgiro"),
            (5, "plusgiro"),
            (6, "amount"),
            (7, "supplier_org_number"),
            (8, "customer_number"),
            (9, "payment_line"),
        ],
    )
    def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str):
        """Verify each class ID maps to the correct class name."""
        assert FIELD_CLASSES[class_id] == expected_class_name

    def test_payment_line_is_derived(self):
        """Verify payment_line is marked as derived."""
        payment_line_def = next(
            fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line"
        )
        assert payment_line_def.is_derived is True

    def test_other_fields_are_not_derived(self):
        """Verify all fields except payment_line are not derived."""
        for fd in FIELD_DEFINITIONS:
            if fd.class_name != "payment_line":
                assert fd.is_derived is False, f"{fd.class_name} should not be derived"


class TestBackwardCompatibility:
    """Tests to ensure backward compatibility with existing code."""

    def test_csv_to_class_mapping_field_names(self):
        """Verify CSV_TO_CLASS_MAPPING uses correct field names."""
        # These are the field names used in CSV files
        expected_fields = {
            "InvoiceNumber": 0,
            "InvoiceDate": 1,
            "InvoiceDueDate": 2,
            "OCR": 3,
            "Bankgiro": 4,
            "Plusgiro": 5,
            "Amount": 6,
            "supplier_organisation_number": 7,
            "customer_number": 8,
            # payment_line (9) is derived, not in CSV
        }
        assert CSV_TO_CLASS_MAPPING == expected_fields

    def test_class_to_field_returns_field_names(self):
        """Verify CLASS_TO_FIELD maps class names to field names correctly."""
        # Sample checks for key fields
        assert CLASS_TO_FIELD["invoice_number"] == "InvoiceNumber"
        assert CLASS_TO_FIELD["invoice_date"] == "InvoiceDate"
        assert CLASS_TO_FIELD["ocr_number"] == "OCR"
        assert CLASS_TO_FIELD["customer_number"] == "customer_number"
        assert CLASS_TO_FIELD["payment_line"] == "payment_line"
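
A sketch of how EXPECTED_MODEL_NAMES above could be cross-checked against the checkpoint named in the docstring. It assumes the ultralytics package is installed and that runs/train/invoice_fields/weights/best.pt exists locally.

from ultralytics import YOLO

model = YOLO("runs/train/invoice_fields/weights/best.pt")
# Prints the class-id -> name mapping stored in the trained model,
# e.g. {0: 'invoice_number', ..., 8: 'customer_number', 9: 'payment_line'}.
print(model.names)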
1  tests/shared/storage/__init__.py  Normal file
@@ -0,0 +1 @@
# Tests for storage module
718  tests/shared/storage/test_azure.py  Normal file
@@ -0,0 +1,718 @@
"""
Tests for AzureBlobStorageBackend.

TDD Phase 1: RED - Write tests first, then implement to pass.
Uses mocking to avoid requiring actual Azure credentials.
"""

import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch

import pytest


@pytest.fixture
def mock_blob_service_client() -> MagicMock:
    """Create a mock BlobServiceClient."""
    return MagicMock()


@pytest.fixture
def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock:
    """Create a mock ContainerClient."""
    container_client = MagicMock()
    mock_blob_service_client.get_container_client.return_value = container_client
    return container_client


@pytest.fixture
def mock_blob_client(mock_container_client: MagicMock) -> MagicMock:
    """Create a mock BlobClient."""
    blob_client = MagicMock()
    mock_container_client.get_blob_client.return_value = blob_client
    return blob_client


class TestAzureBlobStorageBackendCreation:
    """Tests for AzureBlobStorageBackend instantiation."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_with_connection_string(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test creating backend with connection string."""
        from shared.storage.azure import AzureBlobStorageBackend

        connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..."
        backend = AzureBlobStorageBackend(
            connection_string=connection_string,
            container_name="training-images",
        )

        mock_service_class.from_connection_string.assert_called_once_with(
            connection_string
        )
        assert backend.container_name == "training-images"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_creates_container_if_not_exists(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that container is created if it doesn't exist."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_container.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="new-container",
            create_container=True,
        )

        mock_container.create_container.assert_called_once()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_create_does_not_create_container_by_default(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that container is not created by default."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_container.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="existing-container",
        )

        mock_container.create_container.assert_not_called()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_is_storage_backend_subclass(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that AzureBlobStorageBackend is a StorageBackend."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageBackend

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        assert isinstance(backend, StorageBackend)


class TestAzureBlobStorageBackendUpload:
    """Tests for AzureBlobStorageBackend.upload method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_file(self, mock_service_class: MagicMock) -> None:
        """Test uploading a file."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
            f.write(b"Hello, World!")
            temp_path = Path(f.name)

        try:
            result = backend.upload(temp_path, "uploads/sample.txt")

            assert result == "uploads/sample.txt"
            mock_container.get_blob_client.assert_called_with("uploads/sample.txt")
            mock_blob.upload_blob.assert_called_once()
        finally:
            temp_path.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_fails_if_blob_exists_without_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that upload fails if blob exists and overwrite is False."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import StorageError

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
            f.write(b"content")
            temp_path = Path(f.name)

        try:
            with pytest.raises(StorageError, match="already exists"):
                backend.upload(temp_path, "existing.txt", overwrite=False)
        finally:
            temp_path.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_succeeds_with_overwrite(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that upload succeeds with overwrite=True."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
            f.write(b"content")
            temp_path = Path(f.name)

        try:
            result = backend.upload(temp_path, "existing.txt", overwrite=True)

            assert result == "existing.txt"
            mock_blob.upload_blob.assert_called_once()
            # Check overwrite=True was passed
            call_kwargs = mock_blob.upload_blob.call_args[1]
            assert call_kwargs.get("overwrite") is True
        finally:
            temp_path.unlink()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_upload_nonexistent_file_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that uploading nonexistent file fails."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")


class TestAzureBlobStorageBackendDownload:
    """Tests for AzureBlobStorageBackend.download method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_file(self, mock_service_class: MagicMock) -> None:
        """Test downloading a file."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = True

        # Mock download_blob to return stream
        mock_stream = MagicMock()
        mock_stream.readall.return_value = b"Hello, World!"
        mock_blob.download_blob.return_value = mock_stream

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "downloaded.txt"
            result = backend.download("remote/sample.txt", local_path)

            assert result == local_path
            assert local_path.exists()
            assert local_path.read_bytes() == b"Hello, World!"

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_creates_parent_directories(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that download creates parent directories."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = True

        mock_stream = MagicMock()
        mock_stream.readall.return_value = b"content"
        mock_blob.download_blob.return_value = mock_stream

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt"
            result = backend.download("sample.txt", local_path)

            assert local_path.exists()

    @patch("shared.storage.azure.BlobServiceClient")
    def test_download_nonexistent_blob_fails(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test that downloading nonexistent blob fails."""
        from shared.storage.azure import AzureBlobStorageBackend
        from shared.storage.base import FileNotFoundStorageError

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = False

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", Path("/tmp/file.txt"))


class TestAzureBlobStorageBackendExists:
    """Tests for AzureBlobStorageBackend.exists method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_true_for_existing_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test exists returns True for existing blob."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = True

        backend = AzureBlobStorageBackend(
            connection_string="connection_string",
            container_name="container",
        )

        assert backend.exists("existing.txt") is True

    @patch("shared.storage.azure.BlobServiceClient")
    def test_exists_returns_false_for_nonexistent_blob(
        self, mock_service_class: MagicMock
    ) -> None:
        """Test exists returns False for nonexistent blob."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_service = MagicMock()
        mock_service_class.from_connection_string.return_value = mock_service
        mock_container = MagicMock()
        mock_service.get_container_client.return_value = mock_container
        mock_blob = MagicMock()
        mock_container.get_blob_client.return_value = mock_blob
        mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert backend.exists("nonexistent.txt") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendListFiles:
|
||||||
|
"""Tests for AzureBlobStorageBackend.list_files method."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_list_files_empty_container(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing files in empty container."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_container.list_blobs.return_value = []
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert backend.list_files("") == []
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_list_files_returns_all_blobs(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing all blobs."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
|
||||||
|
# Create mock blob items
|
||||||
|
mock_blob1 = MagicMock()
|
||||||
|
mock_blob1.name = "file1.txt"
|
||||||
|
mock_blob2 = MagicMock()
|
||||||
|
mock_blob2.name = "file2.txt"
|
||||||
|
mock_blob3 = MagicMock()
|
||||||
|
mock_blob3.name = "subdir/file3.txt"
|
||||||
|
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3]
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
files = backend.list_files("")
|
||||||
|
|
||||||
|
assert len(files) == 3
|
||||||
|
assert "file1.txt" in files
|
||||||
|
assert "file2.txt" in files
|
||||||
|
assert "subdir/file3.txt" in files
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_list_files_with_prefix(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing files with prefix filter."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
|
||||||
|
mock_blob1 = MagicMock()
|
||||||
|
mock_blob1.name = "images/a.png"
|
||||||
|
mock_blob2 = MagicMock()
|
||||||
|
mock_blob2.name = "images/b.png"
|
||||||
|
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2]
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
files = backend.list_files("images/")
|
||||||
|
|
||||||
|
mock_container.list_blobs.assert_called_with(name_starts_with="images/")
|
||||||
|
assert len(files) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendDelete:
|
||||||
|
"""Tests for AzureBlobStorageBackend.delete method."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_delete_existing_blob(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting an existing blob."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = True
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = backend.delete("sample.txt")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_blob.delete_blob.assert_called_once()
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_delete_nonexistent_blob_returns_false(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting nonexistent blob returns False."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = backend.delete("nonexistent.txt")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_blob.delete_blob.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendGetUrl:
|
||||||
|
"""Tests for AzureBlobStorageBackend.get_url method."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_get_url_returns_blob_url(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_url returns blob URL."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = True
|
||||||
|
mock_blob.url = "https://account.blob.core.windows.net/container/sample.txt"
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
url = backend.get_url("sample.txt")
|
||||||
|
|
||||||
|
assert url == "https://account.blob.core.windows.net/container/sample.txt"
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_get_url_nonexistent_blob_fails(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_url for nonexistent blob fails."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.get_url("nonexistent.txt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendUploadBytes:
|
||||||
|
"""Tests for AzureBlobStorageBackend.upload_bytes method."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_upload_bytes(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test uploading bytes directly."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = b"Binary content here"
|
||||||
|
result = backend.upload_bytes(data, "binary.dat")
|
||||||
|
|
||||||
|
assert result == "binary.dat"
|
||||||
|
mock_blob.upload_blob.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendDownloadBytes:
|
||||||
|
"""Tests for AzureBlobStorageBackend.download_bytes method."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_download_bytes(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test downloading blob as bytes."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = True
|
||||||
|
|
||||||
|
mock_stream = MagicMock()
|
||||||
|
mock_stream.readall.return_value = b"Hello, World!"
|
||||||
|
mock_blob.download_blob.return_value = mock_stream
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = backend.download_bytes("sample.txt")
|
||||||
|
|
||||||
|
assert data == b"Hello, World!"
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_download_bytes_nonexistent(
|
||||||
|
self, mock_service_class: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading nonexistent blob as bytes."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.download_bytes("nonexistent.txt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureBlobStorageBackendBatchOperations:
|
||||||
|
"""Tests for batch operations in AzureBlobStorageBackend."""
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_upload_directory(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test uploading an entire directory."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
mock_blob = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob
|
||||||
|
mock_blob.exists.return_value = False
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_path = Path(temp_dir)
|
||||||
|
(temp_path / "file1.txt").write_text("content1")
|
||||||
|
(temp_path / "subdir").mkdir()
|
||||||
|
(temp_path / "subdir" / "file2.txt").write_text("content2")
|
||||||
|
|
||||||
|
results = backend.upload_directory(temp_path, "uploads/")
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert "uploads/file1.txt" in results
|
||||||
|
assert "uploads/subdir/file2.txt" in results
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_download_directory(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test downloading blobs matching a prefix."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service_class.from_connection_string.return_value = mock_service
|
||||||
|
mock_container = MagicMock()
|
||||||
|
mock_service.get_container_client.return_value = mock_container
|
||||||
|
|
||||||
|
# Mock blob listing
|
||||||
|
mock_blob1 = MagicMock()
|
||||||
|
mock_blob1.name = "images/a.png"
|
||||||
|
mock_blob2 = MagicMock()
|
||||||
|
mock_blob2.name = "images/b.png"
|
||||||
|
mock_container.list_blobs.return_value = [mock_blob1, mock_blob2]
|
||||||
|
|
||||||
|
# Mock blob clients
|
||||||
|
mock_blob_client = MagicMock()
|
||||||
|
mock_container.get_blob_client.return_value = mock_blob_client
|
||||||
|
mock_blob_client.exists.return_value = True
|
||||||
|
mock_stream = MagicMock()
|
||||||
|
mock_stream.readall.return_value = b"image content"
|
||||||
|
mock_blob_client.download_blob.return_value = mock_stream
|
||||||
|
|
||||||
|
backend = AzureBlobStorageBackend(
|
||||||
|
connection_string="connection_string",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
local_path = Path(temp_dir)
|
||||||
|
results = backend.download_directory("images/", local_path)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
# Files should be created relative to prefix
|
||||||
|
assert (local_path / "a.png").exists() or (local_path / "images" / "a.png").exists()
|
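For orientation, the sketch below shows one way shared/storage/azure.py could satisfy the mocks used above. It only mirrors the SDK calls the tests patch (from_connection_string, get_container_client, get_blob_client, exists, upload_blob, download_blob, delete_blob, list_blobs, url); the byte, directory and presigned-URL helpers exercised in the tests are omitted, and the conflict behaviour on upload without overwrite is an assumption. Treat it as an illustration, not the repository's committed implementation.

from pathlib import Path

from azure.storage.blob import BlobServiceClient

from shared.storage.base import FileNotFoundStorageError, StorageBackend, StorageError


class AzureBlobStorageBackend(StorageBackend):
    """Sketch only: upload_bytes/download_bytes/*_directory helpers omitted."""

    def __init__(self, connection_string: str, container_name: str) -> None:
        # The tests patch BlobServiceClient and drive everything through one container client.
        service = BlobServiceClient.from_connection_string(connection_string)
        self._container = service.get_container_client(container_name)

    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))
        blob = self._container.get_blob_client(remote_path)
        if blob.exists() and not overwrite:
            # Assumption: plain StorageError on conflict; the suite only checks overwrite=True is forwarded.
            raise StorageError(f"Blob already exists: {remote_path}")
        with local_path.open("rb") as handle:
            blob.upload_blob(handle, overwrite=overwrite)
        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            raise FileNotFoundStorageError(remote_path)
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_bytes(blob.download_blob().readall())
        return local_path

    def exists(self, remote_path: str) -> bool:
        return bool(self._container.get_blob_client(remote_path).exists())

    def list_files(self, prefix: str) -> list[str]:
        return [blob.name for blob in self._container.list_blobs(name_starts_with=prefix)]

    def delete(self, remote_path: str) -> bool:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            return False
        blob.delete_blob()
        return True

    def get_url(self, remote_path: str) -> str:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            raise FileNotFoundStorageError(remote_path)
        return blob.url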
301 tests/shared/storage/test_base.py Normal file
@@ -0,0 +1,301 @@
"""
|
||||||
|
Tests for storage base module.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import BinaryIO
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInterface:
|
||||||
|
"""Tests for StorageBackend abstract base class."""
|
||||||
|
|
||||||
|
def test_cannot_instantiate_directly(self) -> None:
|
||||||
|
"""Test that StorageBackend cannot be instantiated."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
StorageBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_is_abstract_base_class(self) -> None:
|
||||||
|
"""Test that StorageBackend is an ABC."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
assert issubclass(StorageBackend, ABC)
|
||||||
|
|
||||||
|
def test_subclass_must_implement_upload(self) -> None:
|
||||||
|
"""Test that subclass must implement upload method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_download(self) -> None:
|
||||||
|
"""Test that subclass must implement download method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_exists(self) -> None:
|
||||||
|
"""Test that subclass must implement exists method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_list_files(self) -> None:
|
||||||
|
"""Test that subclass must implement list_files method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_delete(self) -> None:
|
||||||
|
"""Test that subclass must implement delete method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_subclass_must_implement_get_url(self) -> None:
|
||||||
|
"""Test that subclass must implement get_url method."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class IncompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
IncompleteBackend() # type: ignore
|
||||||
|
|
||||||
|
def test_valid_subclass_can_be_instantiated(self) -> None:
|
||||||
|
"""Test that a complete subclass can be instantiated."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
class CompleteBackend(StorageBackend):
|
||||||
|
def upload(
|
||||||
|
self, local_path: Path, remote_path: str, overwrite: bool = False
|
||||||
|
) -> str:
|
||||||
|
return remote_path
|
||||||
|
|
||||||
|
def download(self, remote_path: str, local_path: Path) -> Path:
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
def exists(self, remote_path: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_files(self, prefix: str) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self, remote_path: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_url(self, remote_path: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_presigned_url(
|
||||||
|
self, remote_path: str, expires_in_seconds: int = 3600
|
||||||
|
) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
backend = CompleteBackend()
|
||||||
|
assert isinstance(backend, StorageBackend)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageError:
|
||||||
|
"""Tests for StorageError exception."""
|
||||||
|
|
||||||
|
def test_storage_error_is_exception(self) -> None:
|
||||||
|
"""Test that StorageError is an Exception."""
|
||||||
|
from shared.storage.base import StorageError
|
||||||
|
|
||||||
|
assert issubclass(StorageError, Exception)
|
||||||
|
|
||||||
|
def test_storage_error_with_message(self) -> None:
|
||||||
|
"""Test StorageError with message."""
|
||||||
|
from shared.storage.base import StorageError
|
||||||
|
|
||||||
|
error = StorageError("Upload failed")
|
||||||
|
assert str(error) == "Upload failed"
|
||||||
|
|
||||||
|
def test_storage_error_can_be_raised(self) -> None:
|
||||||
|
"""Test that StorageError can be raised and caught."""
|
||||||
|
from shared.storage.base import StorageError
|
||||||
|
|
||||||
|
with pytest.raises(StorageError, match="test error"):
|
||||||
|
raise StorageError("test error")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileNotFoundError:
|
||||||
|
"""Tests for FileNotFoundStorageError exception."""
|
||||||
|
|
||||||
|
def test_file_not_found_is_storage_error(self) -> None:
|
||||||
|
"""Test that FileNotFoundStorageError is a StorageError."""
|
||||||
|
from shared.storage.base import FileNotFoundStorageError, StorageError
|
||||||
|
|
||||||
|
assert issubclass(FileNotFoundStorageError, StorageError)
|
||||||
|
|
||||||
|
def test_file_not_found_with_path(self) -> None:
|
||||||
|
"""Test FileNotFoundStorageError with path."""
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
|
||||||
|
error = FileNotFoundStorageError("images/test.png")
|
||||||
|
assert "images/test.png" in str(error)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageConfig:
|
||||||
|
"""Tests for StorageConfig dataclass."""
|
||||||
|
|
||||||
|
def test_storage_config_creation(self) -> None:
|
||||||
|
"""Test creating StorageConfig."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="azure_blob",
|
||||||
|
connection_string="DefaultEndpointsProtocol=https;...",
|
||||||
|
container_name="training-images",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.backend_type == "azure_blob"
|
||||||
|
assert config.connection_string == "DefaultEndpointsProtocol=https;..."
|
||||||
|
assert config.container_name == "training-images"
|
||||||
|
|
||||||
|
def test_storage_config_defaults(self) -> None:
|
||||||
|
"""Test StorageConfig with defaults."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
|
||||||
|
config = StorageConfig(backend_type="local")
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.connection_string is None
|
||||||
|
assert config.container_name is None
|
||||||
|
assert config.base_path is None
|
||||||
|
|
||||||
|
def test_storage_config_with_base_path(self) -> None:
|
||||||
|
"""Test StorageConfig with base_path for local backend."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="local",
|
||||||
|
base_path=Path("/data/images"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.base_path == Path("/data/images")
|
||||||
|
|
||||||
|
def test_storage_config_immutable(self) -> None:
|
||||||
|
"""Test that StorageConfig is immutable (frozen)."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
|
||||||
|
config = StorageConfig(backend_type="local")
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
config.backend_type = "azure_blob" # type: ignore
|
||||||
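The base-module contract these tests pin down can be summarised in a short sketch: an ABC with six abstract operations, a StorageError hierarchy, and a frozen StorageConfig. The abstract method set and the config fields follow the assertions above (the S3-related fields come from the factory tests later in this diff); this is an illustration, not necessarily the committed shared/storage/base.py.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path


class StorageError(Exception):
    """Base error for storage operations."""


class FileNotFoundStorageError(StorageError):
    """Raised when a requested local or remote file does not exist."""


@dataclass(frozen=True)
class StorageConfig:
    # Only backend_type is required; the remaining fields are used by one backend each.
    backend_type: str
    connection_string: str | None = None
    container_name: str | None = None
    base_path: Path | None = None
    bucket_name: str | None = None
    region_name: str | None = None
    endpoint_url: str | None = None


class StorageBackend(ABC):
    @abstractmethod
    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str: ...

    @abstractmethod
    def download(self, remote_path: str, local_path: Path) -> Path: ...

    @abstractmethod
    def exists(self, remote_path: str) -> bool: ...

    @abstractmethod
    def list_files(self, prefix: str) -> list[str]: ...

    @abstractmethod
    def delete(self, remote_path: str) -> bool: ...

    @abstractmethod
    def get_url(self, remote_path: str) -> str: ...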
348 tests/shared/storage/test_config_loader.py Normal file
@@ -0,0 +1,348 @@
"""
|
||||||
|
Tests for storage configuration file loader.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir() -> Path:
|
||||||
|
"""Create a temporary directory for tests."""
|
||||||
|
temp_dir = Path(tempfile.mkdtemp())
|
||||||
|
yield temp_dir
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvVarSubstitution:
|
||||||
|
"""Tests for environment variable substitution in config values."""
|
||||||
|
|
||||||
|
def test_substitute_simple_env_var(self) -> None:
|
||||||
|
"""Test substituting a simple environment variable."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"MY_VAR": "my_value"}):
|
||||||
|
result = substitute_env_vars("${MY_VAR}")
|
||||||
|
assert result == "my_value"
|
||||||
|
|
||||||
|
def test_substitute_env_var_with_default(self) -> None:
|
||||||
|
"""Test substituting env var with default when var is not set."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
# Ensure var is not set
|
||||||
|
os.environ.pop("UNSET_VAR", None)
|
||||||
|
|
||||||
|
result = substitute_env_vars("${UNSET_VAR:-default_value}")
|
||||||
|
assert result == "default_value"
|
||||||
|
|
||||||
|
def test_substitute_env_var_ignores_default_when_set(self) -> None:
|
||||||
|
"""Test that default is ignored when env var is set."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"SET_VAR": "actual_value"}):
|
||||||
|
result = substitute_env_vars("${SET_VAR:-default_value}")
|
||||||
|
assert result == "actual_value"
|
||||||
|
|
||||||
|
def test_substitute_multiple_env_vars(self) -> None:
|
||||||
|
"""Test substituting multiple env vars in one string."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}):
|
||||||
|
result = substitute_env_vars("postgres://${HOST}:${PORT}/db")
|
||||||
|
assert result == "postgres://localhost:5432/db"
|
||||||
|
|
||||||
|
def test_substitute_preserves_non_env_text(self) -> None:
|
||||||
|
"""Test that non-env-var text is preserved."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"VAR": "value"}):
|
||||||
|
result = substitute_env_vars("prefix_${VAR}_suffix")
|
||||||
|
assert result == "prefix_value_suffix"
|
||||||
|
|
||||||
|
def test_substitute_empty_string_when_not_set_and_no_default(self) -> None:
|
||||||
|
"""Test that empty string is returned when var not set and no default."""
|
||||||
|
from shared.storage.config_loader import substitute_env_vars
|
||||||
|
|
||||||
|
os.environ.pop("MISSING_VAR", None)
|
||||||
|
|
||||||
|
result = substitute_env_vars("${MISSING_VAR}")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadStorageConfigYaml:
|
||||||
|
"""Tests for loading storage configuration from YAML files."""
|
||||||
|
|
||||||
|
def test_load_local_backend_config(self, temp_dir: Path) -> None:
|
||||||
|
"""Test loading configuration for local backend."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
backend: local
|
||||||
|
presigned_url_expiry: 3600
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: ./data/storage
|
||||||
|
""")
|
||||||
|
|
||||||
|
config = load_storage_config(config_path)
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.presigned_url_expiry == 3600
|
||||||
|
assert config.local is not None
|
||||||
|
assert config.local.base_path == Path("./data/storage")
|
||||||
|
|
||||||
|
def test_load_azure_backend_config(self, temp_dir: Path) -> None:
|
||||||
|
"""Test loading configuration for Azure backend."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
backend: azure_blob
|
||||||
|
presigned_url_expiry: 7200
|
||||||
|
|
||||||
|
azure:
|
||||||
|
connection_string: DefaultEndpointsProtocol=https;AccountName=test
|
||||||
|
container_name: documents
|
||||||
|
create_container: true
|
||||||
|
""")
|
||||||
|
|
||||||
|
config = load_storage_config(config_path)
|
||||||
|
|
||||||
|
assert config.backend_type == "azure_blob"
|
||||||
|
assert config.presigned_url_expiry == 7200
|
||||||
|
assert config.azure is not None
|
||||||
|
assert config.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test"
|
||||||
|
assert config.azure.container_name == "documents"
|
||||||
|
assert config.azure.create_container is True
|
||||||
|
|
||||||
|
def test_load_s3_backend_config(self, temp_dir: Path) -> None:
|
||||||
|
"""Test loading configuration for S3 backend."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
backend: s3
|
||||||
|
presigned_url_expiry: 1800
|
||||||
|
|
||||||
|
s3:
|
||||||
|
bucket_name: my-bucket
|
||||||
|
region_name: us-west-2
|
||||||
|
endpoint_url: http://localhost:9000
|
||||||
|
create_bucket: false
|
||||||
|
""")
|
||||||
|
|
||||||
|
config = load_storage_config(config_path)
|
||||||
|
|
||||||
|
assert config.backend_type == "s3"
|
||||||
|
assert config.presigned_url_expiry == 1800
|
||||||
|
assert config.s3 is not None
|
||||||
|
assert config.s3.bucket_name == "my-bucket"
|
||||||
|
assert config.s3.region_name == "us-west-2"
|
||||||
|
assert config.s3.endpoint_url == "http://localhost:9000"
|
||||||
|
assert config.s3.create_bucket is False
|
||||||
|
|
||||||
|
def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None:
|
||||||
|
"""Test that environment variables are substituted in config."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
backend: ${STORAGE_BACKEND:-local}
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: ${STORAGE_PATH:-./default/path}
|
||||||
|
""")
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}):
|
||||||
|
config = load_storage_config(config_path)
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.local is not None
|
||||||
|
assert config.local.base_path == Path("/custom/path")
|
||||||
|
|
||||||
|
def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None:
|
||||||
|
"""Test that FileNotFoundError is raised for missing config file."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
load_storage_config(temp_dir / "nonexistent.yaml")
|
||||||
|
|
||||||
|
def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None:
|
||||||
|
"""Test that ValueError is raised for invalid YAML."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("invalid: yaml: content: [")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid"):
|
||||||
|
load_storage_config(config_path)
|
||||||
|
|
||||||
|
def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None:
|
||||||
|
"""Test that ValueError is raised when backend is missing."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
local:
|
||||||
|
base_path: ./data
|
||||||
|
""")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="backend"):
|
||||||
|
load_storage_config(config_path)
|
||||||
|
|
||||||
|
def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None:
|
||||||
|
"""Test default presigned_url_expiry when not specified."""
|
||||||
|
from shared.storage.config_loader import load_storage_config
|
||||||
|
|
||||||
|
config_path = temp_dir / "storage.yaml"
|
||||||
|
config_path.write_text("""
|
||||||
|
backend: local
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: ./data
|
||||||
|
""")
|
||||||
|
|
||||||
|
config = load_storage_config(config_path)
|
||||||
|
|
||||||
|
assert config.presigned_url_expiry == 3600 # Default value
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageFileConfig:
|
||||||
|
"""Tests for StorageFileConfig dataclass."""
|
||||||
|
|
||||||
|
def test_storage_file_config_is_immutable(self) -> None:
|
||||||
|
"""Test that StorageFileConfig is frozen (immutable)."""
|
||||||
|
from shared.storage.config_loader import StorageFileConfig
|
||||||
|
|
||||||
|
config = StorageFileConfig(backend_type="local")
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
config.backend_type = "azure_blob" # type: ignore
|
||||||
|
|
||||||
|
def test_storage_file_config_defaults(self) -> None:
|
||||||
|
"""Test StorageFileConfig default values."""
|
||||||
|
from shared.storage.config_loader import StorageFileConfig
|
||||||
|
|
||||||
|
config = StorageFileConfig(backend_type="local")
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.local is None
|
||||||
|
assert config.azure is None
|
||||||
|
assert config.s3 is None
|
||||||
|
assert config.presigned_url_expiry == 3600
|
||||||
|
|
||||||
|
|
||||||
|
class TestLocalConfig:
|
||||||
|
"""Tests for LocalConfig dataclass."""
|
||||||
|
|
||||||
|
def test_local_config_creation(self) -> None:
|
||||||
|
"""Test creating LocalConfig."""
|
||||||
|
from shared.storage.config_loader import LocalConfig
|
||||||
|
|
||||||
|
config = LocalConfig(base_path=Path("/data/storage"))
|
||||||
|
|
||||||
|
assert config.base_path == Path("/data/storage")
|
||||||
|
|
||||||
|
def test_local_config_is_immutable(self) -> None:
|
||||||
|
"""Test that LocalConfig is frozen."""
|
||||||
|
from shared.storage.config_loader import LocalConfig
|
||||||
|
|
||||||
|
config = LocalConfig(base_path=Path("/data"))
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
config.base_path = Path("/other") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureConfig:
|
||||||
|
"""Tests for AzureConfig dataclass."""
|
||||||
|
|
||||||
|
def test_azure_config_creation(self) -> None:
|
||||||
|
"""Test creating AzureConfig."""
|
||||||
|
from shared.storage.config_loader import AzureConfig
|
||||||
|
|
||||||
|
config = AzureConfig(
|
||||||
|
connection_string="test_connection",
|
||||||
|
container_name="test_container",
|
||||||
|
create_container=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.connection_string == "test_connection"
|
||||||
|
assert config.container_name == "test_container"
|
||||||
|
assert config.create_container is True
|
||||||
|
|
||||||
|
def test_azure_config_defaults(self) -> None:
|
||||||
|
"""Test AzureConfig default values."""
|
||||||
|
from shared.storage.config_loader import AzureConfig
|
||||||
|
|
||||||
|
config = AzureConfig(
|
||||||
|
connection_string="conn",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.create_container is False
|
||||||
|
|
||||||
|
def test_azure_config_is_immutable(self) -> None:
|
||||||
|
"""Test that AzureConfig is frozen."""
|
||||||
|
from shared.storage.config_loader import AzureConfig
|
||||||
|
|
||||||
|
config = AzureConfig(
|
||||||
|
connection_string="conn",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
config.container_name = "other" # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3Config:
|
||||||
|
"""Tests for S3Config dataclass."""
|
||||||
|
|
||||||
|
def test_s3_config_creation(self) -> None:
|
||||||
|
"""Test creating S3Config."""
|
||||||
|
from shared.storage.config_loader import S3Config
|
||||||
|
|
||||||
|
config = S3Config(
|
||||||
|
bucket_name="my-bucket",
|
||||||
|
region_name="us-east-1",
|
||||||
|
access_key_id="AKIAIOSFODNN7EXAMPLE",
|
||||||
|
secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||||
|
endpoint_url="http://localhost:9000",
|
||||||
|
create_bucket=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.bucket_name == "my-bucket"
|
||||||
|
assert config.region_name == "us-east-1"
|
||||||
|
assert config.access_key_id == "AKIAIOSFODNN7EXAMPLE"
|
||||||
|
assert config.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
|
||||||
|
assert config.endpoint_url == "http://localhost:9000"
|
||||||
|
assert config.create_bucket is True
|
||||||
|
|
||||||
|
def test_s3_config_minimal(self) -> None:
|
||||||
|
"""Test S3Config with only required fields."""
|
||||||
|
from shared.storage.config_loader import S3Config
|
||||||
|
|
||||||
|
config = S3Config(bucket_name="bucket")
|
||||||
|
|
||||||
|
assert config.bucket_name == "bucket"
|
||||||
|
assert config.region_name is None
|
||||||
|
assert config.access_key_id is None
|
||||||
|
assert config.secret_access_key is None
|
||||||
|
assert config.endpoint_url is None
|
||||||
|
assert config.create_bucket is False
|
||||||
|
|
||||||
|
def test_s3_config_is_immutable(self) -> None:
|
||||||
|
"""Test that S3Config is frozen."""
|
||||||
|
from shared.storage.config_loader import S3Config
|
||||||
|
|
||||||
|
config = S3Config(bucket_name="bucket")
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
config.bucket_name = "other" # type: ignore
|
||||||
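The substitution behaviour the TestEnvVarSubstitution cases specify (${VAR}, ${VAR:-default}, empty string when unset, surrounding text preserved) can be covered with a single regex. The sketch below is one such implementation under those assumptions, not necessarily the code in config_loader.py.

import os
import re

# Matches ${NAME} and ${NAME:-default}; the default may be empty.
_ENV_PATTERN = re.compile(r"\$\{(?P<name>[A-Za-z_][A-Za-z0-9_]*)(?::-(?P<default>[^}]*))?\}")


def substitute_env_vars(value: str) -> str:
    """Replace ${VAR} and ${VAR:-default} occurrences with environment values."""

    def _replace(match: re.Match) -> str:
        name = match.group("name")
        default = match.group("default")
        # Unset variable with no default collapses to an empty string.
        return os.environ.get(name, default if default is not None else "")

    return _ENV_PATTERN.sub(_replace, value)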
423 tests/shared/storage/test_factory.py Normal file
@@ -0,0 +1,423 @@
"""
|
||||||
|
Tests for storage factory.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageFactory:
|
||||||
|
"""Tests for create_storage_backend factory function."""
|
||||||
|
|
||||||
|
def test_create_local_backend(self) -> None:
|
||||||
|
"""Test creating local storage backend."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="local",
|
||||||
|
base_path=Path(temp_dir),
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = create_storage_backend(config)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
assert backend.base_path == Path(temp_dir)
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_create_azure_backend(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test creating Azure blob storage backend."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="azure_blob",
|
||||||
|
connection_string="DefaultEndpointsProtocol=https;...",
|
||||||
|
container_name="training-images",
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = create_storage_backend(config)
|
||||||
|
|
||||||
|
assert isinstance(backend, AzureBlobStorageBackend)
|
||||||
|
|
||||||
|
def test_create_unknown_backend_raises(self) -> None:
|
||||||
|
"""Test that unknown backend type raises ValueError."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(backend_type="unknown_backend")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown storage backend"):
|
||||||
|
create_storage_backend(config)
|
||||||
|
|
||||||
|
def test_create_local_requires_base_path(self) -> None:
|
||||||
|
"""Test that local backend requires base_path."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(backend_type="local")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="base_path"):
|
||||||
|
create_storage_backend(config)
|
||||||
|
|
||||||
|
def test_create_azure_requires_connection_string(self) -> None:
|
||||||
|
"""Test that Azure backend requires connection_string."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="azure_blob",
|
||||||
|
container_name="container",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="connection_string"):
|
||||||
|
create_storage_backend(config)
|
||||||
|
|
||||||
|
def test_create_azure_requires_container_name(self) -> None:
|
||||||
|
"""Test that Azure backend requires container_name."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="azure_blob",
|
||||||
|
connection_string="connection_string",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="container_name"):
|
||||||
|
create_storage_backend(config)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageFactoryFromEnv:
|
||||||
|
"""Tests for create_storage_backend_from_env factory function."""
|
||||||
|
|
||||||
|
def test_create_from_env_local(self) -> None:
|
||||||
|
"""Test creating local backend from environment variables."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": temp_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
backend = create_storage_backend_from_env()
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None:
|
||||||
|
"""Test creating Azure backend from environment variables."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "azure_blob",
|
||||||
|
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
|
||||||
|
"AZURE_STORAGE_CONTAINER": "training-images",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
backend = create_storage_backend_from_env()
|
||||||
|
|
||||||
|
assert isinstance(backend, AzureBlobStorageBackend)
|
||||||
|
|
||||||
|
def test_create_from_env_defaults_to_local(self) -> None:
|
||||||
|
"""Test that factory defaults to local backend."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
env = {
|
||||||
|
"STORAGE_BASE_PATH": temp_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove STORAGE_BACKEND if present
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
if "STORAGE_BACKEND" in os.environ:
|
||||||
|
del os.environ["STORAGE_BACKEND"]
|
||||||
|
backend = create_storage_backend_from_env()
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
def test_create_from_env_missing_azure_vars(self) -> None:
|
||||||
|
"""Test error when Azure env vars are missing."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "azure_blob",
|
||||||
|
# Missing AZURE_STORAGE_CONNECTION_STRING
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
# Remove the connection string if present
|
||||||
|
if "AZURE_STORAGE_CONNECTION_STRING" in os.environ:
|
||||||
|
del os.environ["AZURE_STORAGE_CONNECTION_STRING"]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"):
|
||||||
|
create_storage_backend_from_env()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDefaultStorageConfig:
|
||||||
|
"""Tests for get_default_storage_config function."""
|
||||||
|
|
||||||
|
def test_get_default_config_local(self) -> None:
|
||||||
|
"""Test getting default local config."""
|
||||||
|
from shared.storage.factory import get_default_storage_config
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": temp_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
config = get_default_storage_config()
|
||||||
|
|
||||||
|
assert config.backend_type == "local"
|
||||||
|
assert config.base_path == Path(temp_dir)
|
||||||
|
|
||||||
|
def test_get_default_config_azure(self) -> None:
|
||||||
|
"""Test getting default Azure config."""
|
||||||
|
from shared.storage.factory import get_default_storage_config
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "azure_blob",
|
||||||
|
"AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...",
|
||||||
|
"AZURE_STORAGE_CONTAINER": "training-images",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
config = get_default_storage_config()
|
||||||
|
|
||||||
|
assert config.backend_type == "azure_blob"
|
||||||
|
assert config.connection_string == "DefaultEndpointsProtocol=https;..."
|
||||||
|
assert config.container_name == "training-images"
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageFactoryS3:
|
||||||
|
"""Tests for S3 backend support in factory."""
|
||||||
|
|
||||||
|
@patch("boto3.client")
|
||||||
|
def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None:
|
||||||
|
"""Test creating S3 storage backend."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="s3",
|
||||||
|
bucket_name="test-bucket",
|
||||||
|
region_name="us-west-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = create_storage_backend(config)
|
||||||
|
|
||||||
|
assert isinstance(backend, S3StorageBackend)
|
||||||
|
|
||||||
|
def test_create_s3_requires_bucket_name(self) -> None:
|
||||||
|
"""Test that S3 backend requires bucket_name."""
|
||||||
|
from shared.storage.base import StorageConfig
|
||||||
|
from shared.storage.factory import create_storage_backend
|
||||||
|
|
||||||
|
config = StorageConfig(
|
||||||
|
backend_type="s3",
|
||||||
|
region_name="us-west-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="bucket_name"):
|
||||||
|
create_storage_backend(config)
|
||||||
|
|
||||||
|
@patch("boto3.client")
|
||||||
|
def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None:
|
||||||
|
"""Test creating S3 backend from environment variables."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "s3",
|
||||||
|
"AWS_S3_BUCKET": "test-bucket",
|
||||||
|
"AWS_REGION": "us-east-1",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
backend = create_storage_backend_from_env()
|
||||||
|
|
||||||
|
assert isinstance(backend, S3StorageBackend)
|
||||||
|
|
||||||
|
def test_create_from_env_s3_missing_bucket(self) -> None:
|
||||||
|
"""Test error when S3 bucket env var is missing."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_env
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "s3",
|
||||||
|
# Missing AWS_S3_BUCKET
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
if "AWS_S3_BUCKET" in os.environ:
|
||||||
|
del os.environ["AWS_S3_BUCKET"]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="AWS_S3_BUCKET"):
|
||||||
|
create_storage_backend_from_env()
|
||||||
|
|
||||||
|
def test_get_default_config_s3(self) -> None:
|
||||||
|
"""Test getting default S3 config."""
|
||||||
|
from shared.storage.factory import get_default_storage_config
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "s3",
|
||||||
|
"AWS_S3_BUCKET": "test-bucket",
|
||||||
|
"AWS_REGION": "us-west-2",
|
||||||
|
"AWS_ENDPOINT_URL": "http://localhost:9000",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
config = get_default_storage_config()
|
||||||
|
|
||||||
|
assert config.backend_type == "s3"
|
||||||
|
assert config.bucket_name == "test-bucket"
|
||||||
|
assert config.region_name == "us-west-2"
|
||||||
|
assert config.endpoint_url == "http://localhost:9000"
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageFactoryFromFile:
|
||||||
|
"""Tests for create_storage_backend_from_file factory function."""
|
||||||
|
|
||||||
|
def test_create_from_yaml_file_local(self, tmp_path: Path) -> None:
|
||||||
|
"""Test creating local backend from YAML config file."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_file
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
config_file.write_text(f"""
|
||||||
|
backend: local
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: {storage_path}
|
||||||
|
""")
|
||||||
|
|
||||||
|
backend = create_storage_backend_from_file(config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
@patch("shared.storage.azure.BlobServiceClient")
|
||||||
|
def test_create_from_yaml_file_azure(
|
||||||
|
self, mock_service_class: MagicMock, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Test creating Azure backend from YAML config file."""
|
||||||
|
from shared.storage.azure import AzureBlobStorageBackend
|
||||||
|
from shared.storage.factory import create_storage_backend_from_file
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
config_file.write_text("""
|
||||||
|
backend: azure_blob
|
||||||
|
|
||||||
|
azure:
|
||||||
|
connection_string: DefaultEndpointsProtocol=https;AccountName=test
|
||||||
|
container_name: documents
|
||||||
|
""")
|
||||||
|
|
||||||
|
backend = create_storage_backend_from_file(config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, AzureBlobStorageBackend)
|
||||||
|
|
||||||
|
@patch("boto3.client")
|
||||||
|
def test_create_from_yaml_file_s3(
|
||||||
|
self, mock_boto3_client: MagicMock, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Test creating S3 backend from YAML config file."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_file
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
config_file.write_text("""
|
||||||
|
backend: s3
|
||||||
|
|
||||||
|
s3:
|
||||||
|
bucket_name: my-bucket
|
||||||
|
region_name: us-east-1
|
||||||
|
""")
|
||||||
|
|
||||||
|
backend = create_storage_backend_from_file(config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, S3StorageBackend)
|
||||||
|
|
||||||
|
def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None:
|
||||||
|
"""Test that env vars are substituted in config file."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_file
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
config_file.write_text("""
|
||||||
|
backend: ${STORAGE_BACKEND:-local}
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: ${CUSTOM_STORAGE_PATH}
|
||||||
|
""")
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)},
|
||||||
|
):
|
||||||
|
backend = create_storage_backend_from_file(config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None:
|
||||||
|
"""Test that FileNotFoundError is raised for missing file."""
|
||||||
|
from shared.storage.factory import create_storage_backend_from_file
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
create_storage_backend_from_file(tmp_path / "nonexistent.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetStorageBackend:
|
||||||
|
"""Tests for get_storage_backend convenience function."""
|
||||||
|
|
||||||
|
def test_get_storage_backend_from_file(self, tmp_path: Path) -> None:
|
||||||
|
"""Test getting backend from explicit config file."""
|
||||||
|
from shared.storage.factory import get_storage_backend
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
config_file.write_text(f"""
|
||||||
|
backend: local
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: {storage_path}
|
||||||
|
""")
|
||||||
|
|
||||||
|
backend = get_storage_backend(config_path=config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
|
||||||
|
"""Test that get_storage_backend falls back to env vars."""
|
||||||
|
from shared.storage.factory import get_storage_backend
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": str(storage_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
# No config file provided, should use env vars
|
||||||
|
backend = get_storage_backend(config_path=None)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
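
Note: the ${VAR} / ${VAR:-default} tokens exercised by the env-substitution test above imply that the factory expands environment variables in the YAML text before parsing it. The sketch below only illustrates that idea; the helper name substitute_env and the regex are assumptions, not the actual shared.storage.factory code.

import os
import re

# Matches ${NAME} and ${NAME:-default} tokens (assumed token syntax).
_ENV_TOKEN = re.compile(r"\$\{(?P<name>[A-Za-z_][A-Za-z0-9_]*)(?::-(?P<default>[^}]*))?\}")


def substitute_env(text: str) -> str:
    """Replace ${VAR} and ${VAR:-default} tokens with environment values."""

    def _replace(match: re.Match) -> str:
        value = os.environ.get(match.group("name"))
        if value is None:
            value = match.group("default") or ""
        return value

    return _ENV_TOKEN.sub(_replace, text)


# Example: substitute_env("backend: ${STORAGE_BACKEND:-local}") returns "backend: local"
# when STORAGE_BACKEND is unset.
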
712  tests/shared/storage/test_local.py  Normal file
@@ -0,0 +1,712 @@
"""
Tests for LocalStorageBackend.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import shutil
import tempfile
from pathlib import Path

import pytest


@pytest.fixture
def temp_storage_dir() -> Path:
    """Create a temporary directory for storage tests."""
    temp_dir = Path(tempfile.mkdtemp())
    yield temp_dir
    shutil.rmtree(temp_dir, ignore_errors=True)


@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Create a sample file for testing."""
    file_path = temp_storage_dir / "sample.txt"
    file_path.write_text("Hello, World!")
    return file_path


@pytest.fixture
def sample_image(temp_storage_dir: Path) -> Path:
    """Create a sample PNG file for testing."""
    file_path = temp_storage_dir / "sample.png"
    # Minimal valid PNG (1x1 transparent pixel)
    png_data = bytes(
        [
            0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
            0x00, 0x00, 0x00, 0x0D,  # IHDR length
            0x49, 0x48, 0x44, 0x52,  # IHDR
            0x00, 0x00, 0x00, 0x01,  # width: 1
            0x00, 0x00, 0x00, 0x01,  # height: 1
            0x08, 0x06, 0x00, 0x00, 0x00,  # 8-bit RGBA
            0x1F, 0x15, 0xC4, 0x89,  # CRC
            0x00, 0x00, 0x00, 0x0A,  # IDAT length
            0x49, 0x44, 0x41, 0x54,  # IDAT
            0x78, 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05, 0x00, 0x01,  # compressed data
            0x0D, 0x0A, 0x2D, 0xB4,  # CRC
            0x00, 0x00, 0x00, 0x00,  # IEND length
            0x49, 0x45, 0x4E, 0x44,  # IEND
            0xAE, 0x42, 0x60, 0x82,  # CRC
        ]
    )
    file_path.write_bytes(png_data)
    return file_path


class TestLocalStorageBackendCreation:
    """Tests for LocalStorageBackend instantiation."""

    def test_create_with_base_path(self, temp_storage_dir: Path) -> None:
        """Test creating backend with base path."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.base_path == temp_storage_dir

    def test_create_with_string_path(self, temp_storage_dir: Path) -> None:
        """Test creating backend with string path."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=str(temp_storage_dir))

        assert backend.base_path == temp_storage_dir

    def test_create_creates_directory_if_not_exists(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that base directory is created if it doesn't exist."""
        from shared.storage.local import LocalStorageBackend

        new_dir = temp_storage_dir / "new_storage"
        assert not new_dir.exists()

        backend = LocalStorageBackend(base_path=new_dir)

        assert new_dir.exists()
        assert backend.base_path == new_dir

    def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None:
        """Test that LocalStorageBackend is a StorageBackend."""
        from shared.storage.base import StorageBackend
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert isinstance(backend, StorageBackend)


class TestLocalStorageBackendUpload:
    """Tests for LocalStorageBackend.upload method."""

    def test_upload_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test uploading a file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "uploads/sample.txt")

        assert result == "uploads/sample.txt"
        assert (storage_dir / "uploads" / "sample.txt").exists()
        assert (storage_dir / "uploads" / "sample.txt").read_text() == "Hello, World!"

    def test_upload_creates_subdirectories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that upload creates necessary subdirectories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_file, "deep/nested/path/sample.txt")

        assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists()

    def test_upload_fails_if_file_exists_without_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that upload fails if file exists and overwrite is False."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload succeeds
        backend.upload(sample_file, "sample.txt")

        # Second upload should fail
        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that upload succeeds with overwrite=True."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Modify original file
        sample_file.write_text("Modified content")

        # Second upload with overwrite
        result = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert result == "sample.txt"
        assert (storage_dir / "sample.txt").read_text() == "Modified content"

    def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Test that uploading nonexistent file fails."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(Path("/nonexistent/file.txt"), "sample.txt")

    def test_upload_binary_file(
        self, temp_storage_dir: Path, sample_image: Path
    ) -> None:
        """Test uploading a binary file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        result = backend.upload(sample_image, "images/sample.png")

        assert result == "images/sample.png"
        uploaded_content = (storage_dir / "images" / "sample.png").read_bytes()
        assert uploaded_content == sample_image.read_bytes()


class TestLocalStorageBackendDownload:
    """Tests for LocalStorageBackend.download method."""

    def test_download_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test downloading a file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        download_dir = temp_storage_dir / "downloads"
        download_dir.mkdir()
        backend = LocalStorageBackend(base_path=storage_dir)

        # First upload
        backend.upload(sample_file, "sample.txt")

        # Then download
        local_path = download_dir / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        assert result == local_path
        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_creates_parent_directories(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that download creates parent directories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt"
        result = backend.download("sample.txt", local_path)

        assert local_path.exists()
        assert local_path.read_text() == "Hello, World!"

    def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None:
        """Test that downloading nonexistent file fails."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"):
            backend.download("nonexistent.txt", Path("/tmp/file.txt"))

    def test_download_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test downloading a file from nested path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/c/sample.txt")

        local_path = temp_storage_dir / "downloaded.txt"
        result = backend.download("a/b/c/sample.txt", local_path)

        assert local_path.read_text() == "Hello, World!"


class TestLocalStorageBackendExists:
    """Tests for LocalStorageBackend.exists method."""

    def test_exists_returns_true_for_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test exists returns True for existing file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        assert backend.exists("sample.txt") is True

    def test_exists_returns_false_for_nonexistent_file(
        self, temp_storage_dir: Path
    ) -> None:
        """Test exists returns False for nonexistent file."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.exists("nonexistent.txt") is False

    def test_exists_with_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test exists with nested path."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/sample.txt")

        assert backend.exists("a/b/sample.txt") is True
        assert backend.exists("a/b/other.txt") is False


class TestLocalStorageBackendListFiles:
    """Tests for LocalStorageBackend.list_files method."""

    def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None:
        """Test listing files in empty storage."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        assert backend.list_files("") == []

    def test_list_files_returns_all_files(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test listing all files."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # Upload multiple files
        backend.upload(sample_file, "file1.txt")
        backend.upload(sample_file, "file2.txt")
        backend.upload(sample_file, "subdir/file3.txt")

        files = backend.list_files("")

        assert len(files) == 3
        assert "file1.txt" in files
        assert "file2.txt" in files
        assert "subdir/file3.txt" in files

    def test_list_files_with_prefix(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test listing files with prefix filter."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        backend.upload(sample_file, "images/a.png")
        backend.upload(sample_file, "images/b.png")
        backend.upload(sample_file, "labels/a.txt")

        files = backend.list_files("images/")

        assert len(files) == 2
        assert "images/a.png" in files
        assert "images/b.png" in files
        assert "labels/a.txt" not in files

    def test_list_files_returns_sorted(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that list_files returns sorted list."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        backend.upload(sample_file, "c.txt")
        backend.upload(sample_file, "a.txt")
        backend.upload(sample_file, "b.txt")

        files = backend.list_files("")

        assert files == ["a.txt", "b.txt", "c.txt"]


class TestLocalStorageBackendDelete:
    """Tests for LocalStorageBackend.delete method."""

    def test_delete_existing_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test deleting an existing file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        result = backend.delete("sample.txt")

        assert result is True
        assert not (storage_dir / "sample.txt").exists()

    def test_delete_nonexistent_file_returns_false(
        self, temp_storage_dir: Path
    ) -> None:
        """Test deleting nonexistent file returns False."""
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        result = backend.delete("nonexistent.txt")

        assert result is False

    def test_delete_nested_file(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test deleting a nested file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/sample.txt")

        result = backend.delete("a/b/sample.txt")

        assert result is True
        assert not (storage_dir / "a" / "b" / "sample.txt").exists()


class TestLocalStorageBackendGetUrl:
    """Tests for LocalStorageBackend.get_url method."""

    def test_get_url_returns_file_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test get_url returns file:// URL."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        url = backend.get_url("sample.txt")

        # Should return file:// URL or absolute path
        assert "sample.txt" in url
        # URL should be usable to locate the file
        expected_path = storage_dir / "sample.txt"
        assert str(expected_path) in url or expected_path.as_uri() == url

    def test_get_url_nonexistent_file(self, temp_storage_dir: Path) -> None:
        """Test get_url for nonexistent file."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.get_url("nonexistent.txt")


class TestLocalStorageBackendUploadBytes:
    """Tests for LocalStorageBackend.upload_bytes method."""

    def test_upload_bytes(self, temp_storage_dir: Path) -> None:
        """Test uploading bytes directly."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        data = b"Binary content here"
        result = backend.upload_bytes(data, "binary.dat")

        assert result == "binary.dat"
        assert (storage_dir / "binary.dat").read_bytes() == data

    def test_upload_bytes_creates_subdirectories(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that upload_bytes creates subdirectories."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        data = b"content"
        backend.upload_bytes(data, "a/b/c/file.dat")

        assert (storage_dir / "a" / "b" / "c" / "file.dat").exists()


class TestLocalStorageBackendDownloadBytes:
    """Tests for LocalStorageBackend.download_bytes method."""

    def test_download_bytes(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test downloading file as bytes."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        data = backend.download_bytes("sample.txt")

        assert data == b"Hello, World!"

    def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None:
        """Test downloading nonexistent file as bytes."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.download_bytes("nonexistent.txt")


class TestLocalStorageBackendSecurity:
    """Security tests for LocalStorageBackend - path traversal prevention."""

    def test_path_traversal_with_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that path traversal using ../ is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "../escape.txt")

    def test_path_traversal_with_nested_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that nested path traversal is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "subdir/../../escape.txt")

    def test_path_traversal_with_many_dotdot_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that deeply nested path traversal is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload(sample_file, "a/b/c/../../../../escape.txt")

    def test_absolute_path_unix_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that absolute Unix paths are blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            backend.upload(sample_file, "/etc/passwd")

    def test_absolute_path_windows_blocked(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that absolute Windows paths are blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Absolute paths not allowed"):
            backend.upload(sample_file, "C:\\Windows\\System32\\config")

    def test_download_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in download is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.download("../escape.txt", Path("/tmp/file.txt"))

    def test_exists_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in exists is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.exists("../escape.txt")

    def test_delete_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in delete is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.delete("../escape.txt")

    def test_get_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in get_url is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.get_url("../escape.txt")

    def test_upload_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in upload_bytes is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.upload_bytes(b"content", "../escape.txt")

    def test_download_bytes_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in download_bytes is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.download_bytes("../escape.txt")

    def test_valid_nested_path_still_works(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test that valid nested paths still work after security fix."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)

        # Valid nested paths should still work
        result = backend.upload(sample_file, "a/b/c/d/file.txt")

        assert result == "a/b/c/d/file.txt"
        assert (storage_dir / "a" / "b" / "c" / "d" / "file.txt").exists()
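
Note: the security tests above fix an error contract (StorageError with "Path traversal not allowed" / "Absolute paths not allowed") without showing the implementation. Below is a minimal sketch of a guard that would satisfy that contract; resolve_safe and the local StorageError stub are hypothetical names, not the code in shared.storage.local.

from pathlib import Path, PureWindowsPath


class StorageError(Exception):
    """Stand-in for the shared.storage.base.StorageError the tests import."""


def resolve_safe(base_path: Path, remote_path: str) -> Path:
    """Resolve remote_path under base_path, rejecting absolute paths and escapes."""
    # Reject both POSIX and Windows absolute paths up front.
    if remote_path.startswith("/") or PureWindowsPath(remote_path).is_absolute():
        raise StorageError("Absolute paths not allowed")
    # Reject any ".." component before touching the filesystem.
    if ".." in Path(remote_path).parts:
        raise StorageError("Path traversal not allowed")
    # Belt and braces: the resolved target must stay inside base_path.
    candidate = (base_path / remote_path).resolve()
    if not candidate.is_relative_to(base_path.resolve()):
        raise StorageError("Path traversal not allowed")
    return candidate

Every public method (upload, download, exists, delete, get_url, upload_bytes, download_bytes) would route its remote_path through such a guard, which is why the same match string appears in each test above.
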
158  tests/shared/storage/test_prefixes.py  Normal file
@@ -0,0 +1,158 @@
"""Tests for storage prefixes module."""

import pytest

from shared.storage.prefixes import PREFIXES, StoragePrefixes


class TestStoragePrefixes:
    """Tests for StoragePrefixes class."""

    def test_prefixes_are_strings(self) -> None:
        """All prefix constants should be strings."""
        assert isinstance(PREFIXES.DOCUMENTS, str)
        assert isinstance(PREFIXES.IMAGES, str)
        assert isinstance(PREFIXES.UPLOADS, str)
        assert isinstance(PREFIXES.RESULTS, str)
        assert isinstance(PREFIXES.EXPORTS, str)
        assert isinstance(PREFIXES.DATASETS, str)
        assert isinstance(PREFIXES.MODELS, str)
        assert isinstance(PREFIXES.RAW_PDFS, str)
        assert isinstance(PREFIXES.STRUCTURED_DATA, str)
        assert isinstance(PREFIXES.ADMIN_IMAGES, str)

    def test_prefixes_are_non_empty(self) -> None:
        """All prefix constants should be non-empty."""
        assert PREFIXES.DOCUMENTS
        assert PREFIXES.IMAGES
        assert PREFIXES.UPLOADS
        assert PREFIXES.RESULTS
        assert PREFIXES.EXPORTS
        assert PREFIXES.DATASETS
        assert PREFIXES.MODELS
        assert PREFIXES.RAW_PDFS
        assert PREFIXES.STRUCTURED_DATA
        assert PREFIXES.ADMIN_IMAGES

    def test_prefixes_have_no_leading_slash(self) -> None:
        """Prefixes should not start with a slash for portability."""
        assert not PREFIXES.DOCUMENTS.startswith("/")
        assert not PREFIXES.IMAGES.startswith("/")
        assert not PREFIXES.UPLOADS.startswith("/")
        assert not PREFIXES.RESULTS.startswith("/")

    def test_prefixes_have_no_trailing_slash(self) -> None:
        """Prefixes should not end with a slash."""
        assert not PREFIXES.DOCUMENTS.endswith("/")
        assert not PREFIXES.IMAGES.endswith("/")
        assert not PREFIXES.UPLOADS.endswith("/")
        assert not PREFIXES.RESULTS.endswith("/")

    def test_frozen_dataclass(self) -> None:
        """StoragePrefixes should be immutable."""
        with pytest.raises(Exception):  # FrozenInstanceError
            PREFIXES.DOCUMENTS = "new_value"  # type: ignore


class TestDocumentPath:
    """Tests for document_path helper."""

    def test_document_path_with_extension(self) -> None:
        """Should generate correct document path with extension."""
        path = PREFIXES.document_path("abc123", ".pdf")
        assert path == "documents/abc123.pdf"

    def test_document_path_without_leading_dot(self) -> None:
        """Should handle extension without leading dot."""
        path = PREFIXES.document_path("abc123", "pdf")
        assert path == "documents/abc123.pdf"

    def test_document_path_default_extension(self) -> None:
        """Should use .pdf as default extension."""
        path = PREFIXES.document_path("abc123")
        assert path == "documents/abc123.pdf"


class TestImagePath:
    """Tests for image_path helper."""

    def test_image_path_basic(self) -> None:
        """Should generate correct image path."""
        path = PREFIXES.image_path("doc123", 1)
        assert path == "images/doc123/page_1.png"

    def test_image_path_page_number(self) -> None:
        """Should include page number in path."""
        path = PREFIXES.image_path("doc123", 5)
        assert path == "images/doc123/page_5.png"

    def test_image_path_custom_extension(self) -> None:
        """Should support custom extension."""
        path = PREFIXES.image_path("doc123", 1, ".jpg")
        assert path == "images/doc123/page_1.jpg"


class TestUploadPath:
    """Tests for upload_path helper."""

    def test_upload_path_basic(self) -> None:
        """Should generate correct upload path."""
        path = PREFIXES.upload_path("invoice.pdf")
        assert path == "uploads/invoice.pdf"

    def test_upload_path_with_subfolder(self) -> None:
        """Should include subfolder when provided."""
        path = PREFIXES.upload_path("invoice.pdf", "async")
        assert path == "uploads/async/invoice.pdf"


class TestResultPath:
    """Tests for result_path helper."""

    def test_result_path_basic(self) -> None:
        """Should generate correct result path."""
        path = PREFIXES.result_path("output.json")
        assert path == "results/output.json"


class TestExportPath:
    """Tests for export_path helper."""

    def test_export_path_basic(self) -> None:
        """Should generate correct export path."""
        path = PREFIXES.export_path("exp123", "dataset.zip")
        assert path == "exports/exp123/dataset.zip"


class TestDatasetPath:
    """Tests for dataset_path helper."""

    def test_dataset_path_basic(self) -> None:
        """Should generate correct dataset path."""
        path = PREFIXES.dataset_path("ds123", "data.yaml")
        assert path == "datasets/ds123/data.yaml"


class TestModelPath:
    """Tests for model_path helper."""

    def test_model_path_basic(self) -> None:
        """Should generate correct model path."""
        path = PREFIXES.model_path("v1.0.0", "best.pt")
        assert path == "models/v1.0.0/best.pt"


class TestExportsFromInit:
    """Tests for exports from storage __init__.py."""

    def test_prefixes_exported(self) -> None:
        """PREFIXES should be exported from storage module."""
        from shared.storage import PREFIXES as exported_prefixes

        assert exported_prefixes is PREFIXES

    def test_storage_prefixes_exported(self) -> None:
        """StoragePrefixes should be exported from storage module."""
        from shared.storage import StoragePrefixes as exported_class

        assert exported_class is StoragePrefixes
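
Note: for orientation, here is a minimal sketch of a frozen prefix registry consistent with the assertions above. Only a subset of the helpers is shown, and the values chosen for RAW_PDFS, STRUCTURED_DATA and ADMIN_IMAGES are guesses; the real shared.storage.prefixes module may differ.

from dataclasses import dataclass


@dataclass(frozen=True)
class StoragePrefixes:
    DOCUMENTS: str = "documents"
    IMAGES: str = "images"
    UPLOADS: str = "uploads"
    RESULTS: str = "results"
    EXPORTS: str = "exports"
    DATASETS: str = "datasets"
    MODELS: str = "models"
    RAW_PDFS: str = "raw_pdfs"  # assumed value
    STRUCTURED_DATA: str = "structured_data"  # assumed value
    ADMIN_IMAGES: str = "admin_images"  # assumed value

    def document_path(self, doc_id: str, extension: str = ".pdf") -> str:
        # Normalise "pdf" and ".pdf" to the same result, as the tests require.
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{self.DOCUMENTS}/{doc_id}{ext}"

    def image_path(self, doc_id: str, page: int, extension: str = ".png") -> str:
        ext = extension if extension.startswith(".") else f".{extension}"
        return f"{self.IMAGES}/{doc_id}/page_{page}{ext}"

    def upload_path(self, filename: str, subfolder: str | None = None) -> str:
        if subfolder:
            return f"{self.UPLOADS}/{subfolder}/{filename}"
        return f"{self.UPLOADS}/{filename}"


PREFIXES = StoragePrefixes()

The frozen dataclass is what makes test_frozen_dataclass pass: assigning to PREFIXES.DOCUMENTS raises dataclasses.FrozenInstanceError.
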
264  tests/shared/storage/test_presigned_urls.py  Normal file
@@ -0,0 +1,264 @@
"""
Tests for pre-signed URL functionality across all storage backends.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def temp_storage_dir() -> Path:
    """Create a temporary directory for storage tests."""
    temp_dir = Path(tempfile.mkdtemp())
    yield temp_dir
    shutil.rmtree(temp_dir, ignore_errors=True)


@pytest.fixture
def sample_file(temp_storage_dir: Path) -> Path:
    """Create a sample file for testing."""
    file_path = temp_storage_dir / "sample.txt"
    file_path.write_text("Hello, World!")
    return file_path


class TestStorageBackendInterfacePresignedUrl:
    """Tests for get_presigned_url in StorageBackend interface."""

    def test_subclass_must_implement_get_presigned_url(self) -> None:
        """Test that subclass must implement get_presigned_url method."""
        from shared.storage.base import StorageBackend

        class IncompleteBackend(StorageBackend):
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def delete(self, remote_path: str) -> bool:
                return True

            def get_url(self, remote_path: str) -> str:
                return ""

        with pytest.raises(TypeError):
            IncompleteBackend()  # type: ignore

    def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None:
        """Test that a complete subclass with get_presigned_url can be instantiated."""
        from shared.storage.base import StorageBackend

        class CompleteBackend(StorageBackend):
            def upload(
                self, local_path: Path, remote_path: str, overwrite: bool = False
            ) -> str:
                return remote_path

            def download(self, remote_path: str, local_path: Path) -> Path:
                return local_path

            def exists(self, remote_path: str) -> bool:
                return False

            def list_files(self, prefix: str) -> list[str]:
                return []

            def delete(self, remote_path: str) -> bool:
                return True

            def get_url(self, remote_path: str) -> str:
                return ""

            def get_presigned_url(
                self, remote_path: str, expires_in_seconds: int = 3600
            ) -> str:
                return f"https://example.com/{remote_path}?token=abc"

        backend = CompleteBackend()
        assert isinstance(backend, StorageBackend)


class TestLocalStorageBackendPresignedUrl:
    """Tests for LocalStorageBackend.get_presigned_url method."""

    def test_get_presigned_url_returns_file_uri(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test get_presigned_url returns file:// URI for existing file."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        url = backend.get_presigned_url("sample.txt")

        assert url.startswith("file://")
        assert "sample.txt" in url

    def test_get_presigned_url_with_custom_expiry(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test get_presigned_url accepts expires_in_seconds parameter."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "sample.txt")

        # For local storage, expiry is ignored but should not raise error
        url = backend.get_presigned_url("sample.txt", expires_in_seconds=7200)

        assert url.startswith("file://")

    def test_get_presigned_url_nonexistent_file_raises(
        self, temp_storage_dir: Path
    ) -> None:
        """Test get_presigned_url raises FileNotFoundStorageError for missing file."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    def test_get_presigned_url_path_traversal_blocked(
        self, temp_storage_dir: Path
    ) -> None:
        """Test that path traversal in get_presigned_url is blocked."""
        from shared.storage.base import StorageError
        from shared.storage.local import LocalStorageBackend

        backend = LocalStorageBackend(base_path=temp_storage_dir)

        with pytest.raises(StorageError, match="Path traversal not allowed"):
            backend.get_presigned_url("../escape.txt")

    def test_get_presigned_url_nested_path(
        self, temp_storage_dir: Path, sample_file: Path
    ) -> None:
        """Test get_presigned_url works with nested paths."""
        from shared.storage.local import LocalStorageBackend

        storage_dir = temp_storage_dir / "storage"
        backend = LocalStorageBackend(base_path=storage_dir)
        backend.upload(sample_file, "a/b/c/sample.txt")

        url = backend.get_presigned_url("a/b/c/sample.txt")

        assert url.startswith("file://")
        assert "sample.txt" in url


class TestAzureBlobStorageBackendPresignedUrl:
    """Tests for AzureBlobStorageBackend.get_presigned_url method."""

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_generates_sas_url(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url generates URL with SAS token."""
        from shared.storage.azure import AzureBlobStorageBackend

        # Setup mocks
        mock_blob_service = MagicMock()
        mock_blob_service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service

        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container

        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = True
        mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        mock_container.get_blob_client.return_value = mock_blob_client

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            url = backend.get_presigned_url("test.txt", expires_in_seconds=3600)

        assert "https://testaccount.blob.core.windows.net" in url
        assert "sv=2021-06-08" in url or "test.txt" in url

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_nonexistent_blob_raises(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url raises for nonexistent blob."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.azure import AzureBlobStorageBackend

        mock_blob_service = MagicMock()
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service

        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container

        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = False
        mock_container.get_blob_client.return_value = mock_blob_client

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with pytest.raises(FileNotFoundStorageError):
            backend.get_presigned_url("nonexistent.txt")

    @patch("shared.storage.azure.BlobServiceClient")
    def test_get_presigned_url_uses_custom_expiry(
        self, mock_blob_service_class: MagicMock
    ) -> None:
        """Test get_presigned_url uses custom expiry time."""
        from shared.storage.azure import AzureBlobStorageBackend

        mock_blob_service = MagicMock()
        mock_blob_service.account_name = "testaccount"
        mock_blob_service_class.from_connection_string.return_value = mock_blob_service

        mock_container = MagicMock()
        mock_container.exists.return_value = True
        mock_blob_service.get_container_client.return_value = mock_container

        mock_blob_client = MagicMock()
        mock_blob_client.exists.return_value = True
        mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt"
        mock_container.get_blob_client.return_value = mock_blob_client

        backend = AzureBlobStorageBackend(
            connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net",
            container_name="container",
        )

        with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas:
            mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123"

            backend.get_presigned_url("test.txt", expires_in_seconds=7200)

            # Verify generate_blob_sas was called (expiry is part of the call)
            mock_generate_sas.assert_called_once()
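
Note: the local-backend tests above treat a pre-signed URL as a plain file:// URI, with expires_in_seconds accepted but ignored. A standalone sketch of that behaviour follows, using the built-in FileNotFoundError in place of the FileNotFoundStorageError the tests import; the function name is hypothetical.

from pathlib import Path


def local_presigned_url(base_path: Path, remote_path: str, expires_in_seconds: int = 3600) -> str:
    """Return a file:// URI for an existing object; the expiry has no effect locally."""
    target = (base_path / remote_path).resolve()
    if not target.is_file():
        raise FileNotFoundError(remote_path)
    return target.as_uri()
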
520  tests/shared/storage/test_s3.py  Normal file
@@ -0,0 +1,520 @@
"""
Tests for S3StorageBackend.

TDD Phase 1: RED - Write tests first, then implement to pass.
"""

import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch, call

import pytest


@pytest.fixture
def temp_dir() -> Path:
    """Create a temporary directory for tests."""
    temp_dir = Path(tempfile.mkdtemp())
    yield temp_dir
    shutil.rmtree(temp_dir, ignore_errors=True)


@pytest.fixture
def sample_file(temp_dir: Path) -> Path:
    """Create a sample file for testing."""
    file_path = temp_dir / "sample.txt"
    file_path.write_text("Hello, World!")
    return file_path


@pytest.fixture
def mock_boto3_client():
    """Create a mock boto3 S3 client."""
    with patch("boto3.client") as mock_client_func:
        mock_client = MagicMock()
        mock_client_func.return_value = mock_client
        yield mock_client


class TestS3StorageBackendCreation:
    """Tests for S3StorageBackend instantiation."""

    def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None:
        """Test creating backend with bucket name."""
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert backend.bucket_name == "test-bucket"

    def test_create_with_region(self, mock_boto3_client: MagicMock) -> None:
        """Test creating backend with region."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as mock_client:
            S3StorageBackend(
                bucket_name="test-bucket",
                region_name="us-west-2",
            )

            mock_client.assert_called_once()
            call_kwargs = mock_client.call_args[1]
            assert call_kwargs.get("region_name") == "us-west-2"

    def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None:
        """Test creating backend with explicit credentials."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as mock_client:
            S3StorageBackend(
                bucket_name="test-bucket",
                access_key_id="AKIATEST",
                secret_access_key="secret123",
            )

            mock_client.assert_called_once()
            call_kwargs = mock_client.call_args[1]
            assert call_kwargs.get("aws_access_key_id") == "AKIATEST"
            assert call_kwargs.get("aws_secret_access_key") == "secret123"

    def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None:
        """Test creating backend with custom endpoint (for S3-compatible services)."""
        from shared.storage.s3 import S3StorageBackend

        with patch("boto3.client") as mock_client:
            S3StorageBackend(
                bucket_name="test-bucket",
                endpoint_url="http://localhost:9000",
            )

            mock_client.assert_called_once()
            call_kwargs = mock_client.call_args[1]
            assert call_kwargs.get("endpoint_url") == "http://localhost:9000"

    def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None:
        """Test that bucket is created when create_bucket=True."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_bucket.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadBucket"
        )

        S3StorageBackend(
            bucket_name="test-bucket",
            create_bucket=True,
        )

        mock_boto3_client.create_bucket.assert_called_once()

    def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None:
        """Test that S3StorageBackend is a StorageBackend."""
        from shared.storage.base import StorageBackend
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        assert isinstance(backend, StorageBackend)


class TestS3StorageBackendUpload:
    """Tests for S3StorageBackend.upload method."""

    def test_upload_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path
    ) -> None:
        """Test uploading a file."""
        from botocore.exceptions import ClientError
        from shared.storage.s3 import S3StorageBackend

        # Object does not exist
        mock_boto3_client.head_object.side_effect = ClientError(
            {"Error": {"Code": "404"}}, "HeadObject"
        )

        backend = S3StorageBackend(bucket_name="test-bucket")

        result = backend.upload(sample_file, "uploads/sample.txt")

        assert result == "uploads/sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_fails_if_exists_without_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """Test that upload fails if object exists and overwrite is False."""
        from shared.storage.base import StorageError
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # Object exists

        backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(StorageError, match="already exists"):
            backend.upload(sample_file, "sample.txt", overwrite=False)

    def test_upload_succeeds_with_overwrite(
        self, mock_boto3_client: MagicMock, sample_file: Path
    ) -> None:
        """Test that upload succeeds with overwrite=True."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # Object exists

        backend = S3StorageBackend(bucket_name="test-bucket")
        result = backend.upload(sample_file, "sample.txt", overwrite=True)

        assert result == "sample.txt"
        mock_boto3_client.upload_file.assert_called_once()

    def test_upload_nonexistent_file_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Test that uploading nonexistent file fails."""
        from shared.storage.base import FileNotFoundStorageError
        from shared.storage.s3 import S3StorageBackend

        backend = S3StorageBackend(bucket_name="test-bucket")

        with pytest.raises(FileNotFoundStorageError):
            backend.upload(temp_dir / "nonexistent.txt", "sample.txt")


class TestS3StorageBackendDownload:
    """Tests for S3StorageBackend.download method."""

    def test_download_file(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Test downloading a file."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}  # Object exists

        backend = S3StorageBackend(bucket_name="test-bucket")
        local_path = temp_dir / "downloaded.txt"

        result = backend.download("sample.txt", local_path)

        assert result == local_path
        mock_boto3_client.download_file.assert_called_once()

    def test_download_creates_parent_directories(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Test that download creates parent directories."""
        from shared.storage.s3 import S3StorageBackend

        mock_boto3_client.head_object.return_value = {}

        backend = S3StorageBackend(bucket_name="test-bucket")
        local_path = temp_dir / "deep" / "nested" / "downloaded.txt"

        backend.download("sample.txt", local_path)

        assert local_path.parent.exists()

    def test_download_nonexistent_object_fails(
        self, mock_boto3_client: MagicMock, temp_dir: Path
    ) -> None:
        """Test that downloading nonexistent object fails."""
        from botocore.exceptions import ClientError
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.download("nonexistent.txt", temp_dir / "file.txt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendExists:
|
||||||
|
"""Tests for S3StorageBackend.exists method."""
|
||||||
|
|
||||||
|
def test_exists_returns_true_for_existing_object(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test exists returns True for existing object."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {}
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
assert backend.exists("sample.txt") is True
|
||||||
|
|
||||||
|
def test_exists_returns_false_for_nonexistent_object(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test exists returns False for nonexistent object."""
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
assert backend.exists("nonexistent.txt") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendListFiles:
|
||||||
|
"""Tests for S3StorageBackend.list_files method."""
|
||||||
|
|
||||||
|
def test_list_files_returns_objects(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing objects."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.list_objects_v2.return_value = {
|
||||||
|
"Contents": [
|
||||||
|
{"Key": "file1.txt"},
|
||||||
|
{"Key": "file2.txt"},
|
||||||
|
{"Key": "subdir/file3.txt"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
files = backend.list_files("")
|
||||||
|
|
||||||
|
assert len(files) == 3
|
||||||
|
assert "file1.txt" in files
|
||||||
|
assert "file2.txt" in files
|
||||||
|
assert "subdir/file3.txt" in files
|
||||||
|
|
||||||
|
def test_list_files_with_prefix(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing objects with prefix filter."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.list_objects_v2.return_value = {
|
||||||
|
"Contents": [
|
||||||
|
{"Key": "images/a.png"},
|
||||||
|
{"Key": "images/b.png"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
files = backend.list_files("images/")
|
||||||
|
|
||||||
|
mock_boto3_client.list_objects_v2.assert_called_with(
|
||||||
|
Bucket="test-bucket", Prefix="images/"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_list_files_empty_bucket(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test listing files in empty bucket."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.list_objects_v2.return_value = {} # No Contents key
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
files = backend.list_files("")
|
||||||
|
|
||||||
|
assert files == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendDelete:
|
||||||
|
"""Tests for S3StorageBackend.delete method."""
|
||||||
|
|
||||||
|
def test_delete_existing_object(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting an existing object."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {}
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
result = backend.delete("sample.txt")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_boto3_client.delete_object.assert_called_once()
|
||||||
|
|
||||||
|
def test_delete_nonexistent_object_returns_false(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting nonexistent object returns False."""
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
result = backend.delete("nonexistent.txt")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendGetUrl:
|
||||||
|
"""Tests for S3StorageBackend.get_url method."""
|
||||||
|
|
||||||
|
def test_get_url_returns_s3_url(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_url returns S3 URL."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {}
|
||||||
|
mock_boto3_client.generate_presigned_url.return_value = (
|
||||||
|
"https://test-bucket.s3.amazonaws.com/sample.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
url = backend.get_url("sample.txt")
|
||||||
|
|
||||||
|
assert "sample.txt" in url
|
||||||
|
|
||||||
|
def test_get_url_nonexistent_object_raises(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_url raises for nonexistent object."""
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.get_url("nonexistent.txt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendUploadBytes:
|
||||||
|
"""Tests for S3StorageBackend.upload_bytes method."""
|
||||||
|
|
||||||
|
def test_upload_bytes(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test uploading bytes directly."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
data = b"Binary content here"
|
||||||
|
|
||||||
|
result = backend.upload_bytes(data, "binary.dat")
|
||||||
|
|
||||||
|
assert result == "binary.dat"
|
||||||
|
mock_boto3_client.put_object.assert_called_once()
|
||||||
|
|
||||||
|
def test_upload_bytes_fails_if_exists_without_overwrite(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test upload_bytes fails if object exists and overwrite is False."""
|
||||||
|
from shared.storage.base import StorageError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {} # Object exists
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
with pytest.raises(StorageError, match="already exists"):
|
||||||
|
backend.upload_bytes(b"content", "sample.txt", overwrite=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendDownloadBytes:
|
||||||
|
"""Tests for S3StorageBackend.download_bytes method."""
|
||||||
|
|
||||||
|
def test_download_bytes(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading object as bytes."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.read.return_value = b"Hello, World!"
|
||||||
|
mock_boto3_client.get_object.return_value = {"Body": mock_response}
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
data = backend.download_bytes("sample.txt")
|
||||||
|
|
||||||
|
assert data == b"Hello, World!"
|
||||||
|
|
||||||
|
def test_download_bytes_nonexistent_raises(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading nonexistent object as bytes."""
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.get_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "NoSuchKey"}}, "GetObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.download_bytes("nonexistent.txt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestS3StorageBackendPresignedUrl:
|
||||||
|
"""Tests for S3StorageBackend.get_presigned_url method."""
|
||||||
|
|
||||||
|
def test_get_presigned_url_generates_url(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_presigned_url generates presigned URL."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {}
|
||||||
|
mock_boto3_client.generate_presigned_url.return_value = (
|
||||||
|
"https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..."
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
url = backend.get_presigned_url("sample.txt")
|
||||||
|
|
||||||
|
assert "X-Amz-Algorithm" in url or "sample.txt" in url
|
||||||
|
mock_boto3_client.generate_presigned_url.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_presigned_url_with_custom_expiry(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_presigned_url uses custom expiry."""
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.return_value = {}
|
||||||
|
mock_boto3_client.generate_presigned_url.return_value = "https://..."
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
backend.get_presigned_url("sample.txt", expires_in_seconds=7200)
|
||||||
|
|
||||||
|
call_args = mock_boto3_client.generate_presigned_url.call_args
|
||||||
|
assert call_args[1].get("ExpiresIn") == 7200
|
||||||
|
|
||||||
|
def test_get_presigned_url_nonexistent_raises(
|
||||||
|
self, mock_boto3_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test get_presigned_url raises for nonexistent object."""
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from shared.storage.base import FileNotFoundStorageError
|
||||||
|
from shared.storage.s3 import S3StorageBackend
|
||||||
|
|
||||||
|
mock_boto3_client.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "404"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
|
||||||
|
backend = S3StorageBackend(bucket_name="test-bucket")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundStorageError):
|
||||||
|
backend.get_presigned_url("nonexistent.txt")
|
||||||
@@ -9,7 +9,8 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||||
|
from shared.fields import FIELD_CLASSES
|
||||||
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||||
from inference.web.schemas.admin import (
|
from inference.web.schemas.admin import (
|
||||||
AnnotationCreate,
|
AnnotationCreate,
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class MockAdminDocument:
|
|||||||
self.batch_id = kwargs.get('batch_id', None)
|
self.batch_id = kwargs.get('batch_id', None)
|
||||||
self.csv_field_values = kwargs.get('csv_field_values', None)
|
self.csv_field_values = kwargs.get('csv_field_values', None)
|
||||||
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
|
self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
|
||||||
|
self.category = kwargs.get('category', 'invoice')
|
||||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||||
|
|
||||||
@@ -67,12 +68,13 @@ class MockAdminDB:
|
|||||||
|
|
||||||
def get_documents_by_token(
|
def get_documents_by_token(
|
||||||
self,
|
self,
|
||||||
admin_token,
|
admin_token=None,
|
||||||
status=None,
|
status=None,
|
||||||
upload_source=None,
|
upload_source=None,
|
||||||
has_annotations=None,
|
has_annotations=None,
|
||||||
auto_label_status=None,
|
auto_label_status=None,
|
||||||
batch_id=None,
|
batch_id=None,
|
||||||
|
category=None,
|
||||||
limit=20,
|
limit=20,
|
||||||
offset=0
|
offset=0
|
||||||
):
|
):
|
||||||
@@ -95,6 +97,8 @@ class MockAdminDB:
|
|||||||
docs = [d for d in docs if d.auto_label_status == auto_label_status]
|
docs = [d for d in docs if d.auto_label_status == auto_label_status]
|
||||||
if batch_id:
|
if batch_id:
|
||||||
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
|
docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
|
||||||
|
if category:
|
||||||
|
docs = [d for d in docs if d.category == category]
|
||||||
|
|
||||||
total = len(docs)
|
total = len(docs)
|
||||||
return docs[offset:offset+limit], total
|
return docs[offset:offset+limit], total
|
||||||
|
|||||||
@@ -215,8 +215,10 @@ class TestAsyncProcessingService:
|
|||||||
|
|
||||||
def test_cleanup_orphan_files(self, async_service, mock_db):
|
def test_cleanup_orphan_files(self, async_service, mock_db):
|
||||||
"""Test cleanup of orphan files."""
|
"""Test cleanup of orphan files."""
|
||||||
# Create an orphan file
|
# Create the async upload directory
|
||||||
temp_dir = async_service._async_config.temp_upload_dir
|
temp_dir = async_service._async_config.temp_upload_dir
|
||||||
|
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
orphan_file = temp_dir / "orphan-request.pdf"
|
orphan_file = temp_dir / "orphan-request.pdf"
|
||||||
orphan_file.write_bytes(b"orphan content")
|
orphan_file.write_bytes(b"orphan content")
|
||||||
|
|
||||||
@@ -228,7 +230,13 @@ class TestAsyncProcessingService:
|
|||||||
# Mock database to say file doesn't exist
|
# Mock database to say file doesn't exist
|
||||||
mock_db.get_request.return_value = None
|
mock_db.get_request.return_value = None
|
||||||
|
|
||||||
count = async_service._cleanup_orphan_files()
|
# Mock the storage helper to return the same directory as the fixture
|
||||||
|
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
|
||||||
|
mock_helper = MagicMock()
|
||||||
|
mock_helper.get_uploads_base_path.return_value = temp_dir
|
||||||
|
mock_storage.return_value = mock_helper
|
||||||
|
|
||||||
|
count = async_service._cleanup_orphan_files()
|
||||||
|
|
||||||
assert count == 1
|
assert count == 1
|
||||||
assert not orphan_file.exists()
|
assert not orphan_file.exists()
|
||||||
|
|||||||
@@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||||
|
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||||
|
|
||||||
|
|
||||||
|
TEST_ADMIN_TOKEN = "test-admin-token-12345"
|
||||||
|
TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001"
|
||||||
|
TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_token() -> str:
|
||||||
|
"""Provide admin token for testing."""
|
||||||
|
return TEST_ADMIN_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_admin_db() -> MagicMock:
|
||||||
|
"""Create a mock AdminDB for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
# Default return values
|
||||||
|
mock.get_document_by_token.return_value = None
|
||||||
|
mock.get_dataset.return_value = None
|
||||||
|
mock.get_augmented_datasets.return_value = ([], 0)
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||||
|
"""Create test client with admin authentication."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
def get_token_override():
|
||||||
|
return TEST_ADMIN_TOKEN
|
||||||
|
|
||||||
|
def get_db_override():
|
||||||
|
return mock_admin_db
|
||||||
|
|
||||||
|
app.dependency_overrides[validate_admin_token] = get_token_override
|
||||||
|
app.dependency_overrides[get_admin_db] = get_db_override
|
||||||
|
|
||||||
|
# Include router - the router already has /augmentation prefix
|
||||||
|
# so we add /api/v1/admin to get /api/v1/admin/augmentation
|
||||||
|
router = create_augmentation_router()
|
||||||
|
app.include_router(router, prefix="/api/v1/admin")
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
|
||||||
|
"""Create test client WITHOUT admin authentication override."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Only override the database, NOT the token validation
|
||||||
|
def get_db_override():
|
||||||
|
return mock_admin_db
|
||||||
|
|
||||||
|
app.dependency_overrides[get_admin_db] = get_db_override
|
||||||
|
|
||||||
|
router = create_augmentation_router()
|
||||||
|
app.include_router(router, prefix="/api/v1/admin")
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
class TestAugmentationTypesEndpoint:
|
class TestAugmentationTypesEndpoint:
|
||||||
@@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint:
|
|||||||
assert "stage" in aug_type
|
assert "stage" in aug_type
|
||||||
|
|
||||||
def test_list_augmentation_types_unauthorized(
|
def test_list_augmentation_types_unauthorized(
|
||||||
self, admin_client: TestClient
|
self, unauthenticated_client: TestClient
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that unauthorized request is rejected."""
|
"""Test that unauthorized request is rejected."""
|
||||||
response = admin_client.get("/api/v1/admin/augmentation/types")
|
response = unauthenticated_client.get("/api/v1/admin/augmentation/types")
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint:
|
|||||||
admin_client: TestClient,
|
admin_client: TestClient,
|
||||||
admin_token: str,
|
admin_token: str,
|
||||||
sample_document_id: str,
|
sample_document_id: str,
|
||||||
|
mock_admin_db: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test previewing augmentation on a document."""
|
"""Test previewing augmentation on a document."""
|
||||||
response = admin_client.post(
|
# Mock document exists
|
||||||
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
mock_document = MagicMock()
|
||||||
headers={"X-Admin-Token": admin_token},
|
mock_document.images_dir = "/fake/path"
|
||||||
json={
|
mock_admin_db.get_document.return_value = mock_document
|
||||||
"augmentation_type": "gaussian_noise",
|
|
||||||
"params": {"std": 15},
|
# Create a fake image (100x100 RGB)
|
||||||
},
|
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||||
)
|
|
||||||
|
with patch(
|
||||||
|
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||||
|
) as mock_load:
|
||||||
|
mock_load.return_value = fake_image
|
||||||
|
|
||||||
|
response = admin_client.post(
|
||||||
|
f"/api/v1/admin/augmentation/preview/{sample_document_id}",
|
||||||
|
headers={"X-Admin-Token": admin_token},
|
||||||
|
json={
|
||||||
|
"augmentation_type": "gaussian_noise",
|
||||||
|
"params": {"std": 15},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint:
|
|||||||
admin_client: TestClient,
|
admin_client: TestClient,
|
||||||
admin_token: str,
|
admin_token: str,
|
||||||
sample_document_id: str,
|
sample_document_id: str,
|
||||||
|
mock_admin_db: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test previewing full config on a document."""
|
"""Test previewing full config on a document."""
|
||||||
response = admin_client.post(
|
# Mock document exists
|
||||||
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
mock_document = MagicMock()
|
||||||
headers={"X-Admin-Token": admin_token},
|
mock_document.images_dir = "/fake/path"
|
||||||
json={
|
mock_admin_db.get_document.return_value = mock_document
|
||||||
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
|
||||||
"lighting_variation": {"enabled": True, "probability": 1.0},
|
# Create a fake image (100x100 RGB)
|
||||||
"preserve_bboxes": True,
|
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||||
"seed": 42,
|
|
||||||
},
|
with patch(
|
||||||
)
|
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||||
|
) as mock_load:
|
||||||
|
mock_load.return_value = fake_image
|
||||||
|
|
||||||
|
response = admin_client.post(
|
||||||
|
f"/api/v1/admin/augmentation/preview-config/{sample_document_id}",
|
||||||
|
headers={"X-Admin-Token": admin_token},
|
||||||
|
json={
|
||||||
|
"gaussian_noise": {"enabled": True, "probability": 1.0},
|
||||||
|
"lighting_variation": {"enabled": True, "probability": 1.0},
|
||||||
|
"preserve_bboxes": True,
|
||||||
|
"seed": 42,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint:
|
|||||||
admin_client: TestClient,
|
admin_client: TestClient,
|
||||||
admin_token: str,
|
admin_token: str,
|
||||||
sample_dataset_id: str,
|
sample_dataset_id: str,
|
||||||
|
mock_admin_db: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating augmented dataset."""
|
"""Test creating augmented dataset."""
|
||||||
|
# Mock dataset exists
|
||||||
|
mock_dataset = MagicMock()
|
||||||
|
mock_dataset.total_images = 100
|
||||||
|
mock_admin_db.get_dataset.return_value = mock_dataset
|
||||||
|
|
||||||
response = admin_client.post(
|
response = admin_client.post(
|
||||||
"/api/v1/admin/augmentation/batch",
|
"/api/v1/admin/augmentation/batch",
|
||||||
headers={"X-Admin-Token": admin_token},
|
headers={"X-Admin-Token": admin_token},
|
||||||
@@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_document_id() -> str:
|
def sample_document_id() -> str:
|
||||||
"""Provide a sample document ID for testing."""
|
"""Provide a sample document ID for testing."""
|
||||||
# This would need to be created in test setup
|
return TEST_DOCUMENT_UUID
|
||||||
return "test-document-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_dataset_id() -> str:
|
def sample_dataset_id() -> str:
|
||||||
"""Provide a sample dataset ID for testing."""
|
"""Provide a sample dataset ID for testing."""
|
||||||
# This would need to be created in test setup
|
return TEST_DATASET_UUID
|
||||||
return "test-dataset-id"
|
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock:
|
|||||||
name="test-dataset",
|
name="test-dataset",
|
||||||
description="Test dataset",
|
description="Test dataset",
|
||||||
status="ready",
|
status="ready",
|
||||||
|
training_status=None,
|
||||||
|
active_training_task_id=None,
|
||||||
train_ratio=0.8,
|
train_ratio=0.8,
|
||||||
val_ratio=0.1,
|
val_ratio=0.1,
|
||||||
seed=42,
|
seed=42,
|
||||||
@@ -183,6 +185,8 @@ class TestListDatasetsRoute:
|
|||||||
|
|
||||||
mock_db = MagicMock()
|
mock_db = MagicMock()
|
||||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||||
|
# Mock the active training tasks lookup to return empty dict
|
||||||
|
mock_db.get_active_training_tasks_for_datasets.return_value = {}
|
||||||
|
|
||||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||||
|
|
||||||
|
|||||||
363
tests/web/test_dataset_training_status.py
Normal file
363
tests/web/test_dataset_training_status.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
"""
|
||||||
|
Tests for dataset training status feature.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
1. Database model fields (training_status, active_training_task_id)
|
||||||
|
2. AdminDB update_dataset_training_status method
|
||||||
|
3. API response includes training status fields
|
||||||
|
4. Scheduler updates dataset status during training lifecycle
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test Database Model
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingDatasetModel:
|
||||||
|
"""Tests for TrainingDataset model fields."""
|
||||||
|
|
||||||
|
def test_training_dataset_has_training_status_field(self):
|
||||||
|
"""TrainingDataset model should have training_status field."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
name="test-dataset",
|
||||||
|
training_status="running",
|
||||||
|
)
|
||||||
|
assert dataset.training_status == "running"
|
||||||
|
|
||||||
|
def test_training_dataset_has_active_training_task_id_field(self):
|
||||||
|
"""TrainingDataset model should have active_training_task_id field."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
task_id = uuid4()
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
name="test-dataset",
|
||||||
|
active_training_task_id=task_id,
|
||||||
|
)
|
||||||
|
assert dataset.active_training_task_id == task_id
|
||||||
|
|
||||||
|
def test_training_dataset_defaults(self):
|
||||||
|
"""TrainingDataset should have correct defaults for new fields."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(name="test-dataset")
|
||||||
|
assert dataset.training_status is None
|
||||||
|
assert dataset.active_training_task_id is None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test AdminDB Methods
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminDBDatasetTrainingStatus:
|
||||||
|
"""Tests for AdminDB.update_dataset_training_status method."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
"""Create mock database session."""
|
||||||
|
session = MagicMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
def test_update_dataset_training_status_sets_status(self, mock_session):
|
||||||
|
"""update_dataset_training_status should set training_status."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset_id = uuid4()
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
name="test-dataset",
|
||||||
|
status="ready",
|
||||||
|
)
|
||||||
|
mock_session.get.return_value = dataset
|
||||||
|
|
||||||
|
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||||
|
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
|
||||||
|
db = AdminDB()
|
||||||
|
db.update_dataset_training_status(
|
||||||
|
dataset_id=str(dataset_id),
|
||||||
|
training_status="running",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dataset.training_status == "running"
|
||||||
|
mock_session.add.assert_called_once_with(dataset)
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
def test_update_dataset_training_status_sets_task_id(self, mock_session):
|
||||||
|
"""update_dataset_training_status should set active_training_task_id."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset_id = uuid4()
|
||||||
|
task_id = uuid4()
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
name="test-dataset",
|
||||||
|
status="ready",
|
||||||
|
)
|
||||||
|
mock_session.get.return_value = dataset
|
||||||
|
|
||||||
|
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||||
|
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
|
||||||
|
db = AdminDB()
|
||||||
|
db.update_dataset_training_status(
|
||||||
|
dataset_id=str(dataset_id),
|
||||||
|
training_status="running",
|
||||||
|
active_training_task_id=str(task_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dataset.active_training_task_id == task_id
|
||||||
|
|
||||||
|
def test_update_dataset_training_status_updates_main_status_on_complete(
|
||||||
|
self, mock_session
|
||||||
|
):
|
||||||
|
"""update_dataset_training_status should update main status to 'trained' when completed."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset_id = uuid4()
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
name="test-dataset",
|
||||||
|
status="ready",
|
||||||
|
)
|
||||||
|
mock_session.get.return_value = dataset
|
||||||
|
|
||||||
|
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||||
|
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
|
||||||
|
db = AdminDB()
|
||||||
|
db.update_dataset_training_status(
|
||||||
|
dataset_id=str(dataset_id),
|
||||||
|
training_status="completed",
|
||||||
|
update_main_status=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dataset.status == "trained"
|
||||||
|
assert dataset.training_status == "completed"
|
||||||
|
|
||||||
|
def test_update_dataset_training_status_clears_task_id_on_complete(
|
||||||
|
self, mock_session
|
||||||
|
):
|
||||||
|
"""update_dataset_training_status should clear task_id when training completes."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset_id = uuid4()
|
||||||
|
task_id = uuid4()
|
||||||
|
dataset = TrainingDataset(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
name="test-dataset",
|
||||||
|
status="ready",
|
||||||
|
training_status="running",
|
||||||
|
active_training_task_id=task_id,
|
||||||
|
)
|
||||||
|
mock_session.get.return_value = dataset
|
||||||
|
|
||||||
|
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||||
|
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
|
||||||
|
db = AdminDB()
|
||||||
|
db.update_dataset_training_status(
|
||||||
|
dataset_id=str(dataset_id),
|
||||||
|
training_status="completed",
|
||||||
|
active_training_task_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dataset.active_training_task_id is None
|
||||||
|
|
||||||
|
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
|
||||||
|
"""update_dataset_training_status should handle missing dataset gracefully."""
|
||||||
|
mock_session.get.return_value = None
|
||||||
|
|
||||||
|
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||||
|
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
from inference.data.admin_db import AdminDB
|
||||||
|
|
||||||
|
db = AdminDB()
|
||||||
|
# Should not raise
|
||||||
|
db.update_dataset_training_status(
|
||||||
|
dataset_id=str(uuid4()),
|
||||||
|
training_status="running",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session.add.assert_not_called()
|
||||||
|
mock_session.commit.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test API Response
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetDetailResponseTrainingStatus:
|
||||||
|
"""Tests for DatasetDetailResponse including training status fields."""
|
||||||
|
|
||||||
|
def test_dataset_detail_response_includes_training_status(self):
|
||||||
|
"""DatasetDetailResponse schema should include training_status field."""
|
||||||
|
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||||
|
|
||||||
|
response = DatasetDetailResponse(
|
||||||
|
dataset_id=str(uuid4()),
|
||||||
|
name="test-dataset",
|
||||||
|
description=None,
|
||||||
|
status="ready",
|
||||||
|
training_status="running",
|
||||||
|
active_training_task_id=str(uuid4()),
|
||||||
|
train_ratio=0.8,
|
||||||
|
val_ratio=0.1,
|
||||||
|
seed=42,
|
||||||
|
total_documents=10,
|
||||||
|
total_images=15,
|
||||||
|
total_annotations=100,
|
||||||
|
dataset_path="/path/to/dataset",
|
||||||
|
error_message=None,
|
||||||
|
documents=[],
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.training_status == "running"
|
||||||
|
assert response.active_training_task_id is not None
|
||||||
|
|
||||||
|
def test_dataset_detail_response_allows_null_training_status(self):
|
||||||
|
"""DatasetDetailResponse should allow null training_status."""
|
||||||
|
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||||
|
|
||||||
|
response = DatasetDetailResponse(
|
||||||
|
dataset_id=str(uuid4()),
|
||||||
|
name="test-dataset",
|
||||||
|
description=None,
|
||||||
|
status="ready",
|
||||||
|
training_status=None,
|
||||||
|
active_training_task_id=None,
|
||||||
|
train_ratio=0.8,
|
||||||
|
val_ratio=0.1,
|
||||||
|
seed=42,
|
||||||
|
total_documents=10,
|
||||||
|
total_images=15,
|
||||||
|
total_annotations=100,
|
||||||
|
dataset_path=None,
|
||||||
|
error_message=None,
|
||||||
|
documents=[],
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.training_status is None
|
||||||
|
assert response.active_training_task_id is None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test Scheduler Training Status Updates
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchedulerDatasetStatusUpdates:
|
||||||
|
"""Tests for scheduler updating dataset status during training."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db(self):
|
||||||
|
"""Create mock AdminDB."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.get_dataset.return_value = MagicMock(
|
||||||
|
dataset_id=uuid4(),
|
||||||
|
name="test-dataset",
|
||||||
|
dataset_path="/path/to/dataset",
|
||||||
|
total_images=100,
|
||||||
|
)
|
||||||
|
mock.get_pending_training_tasks.return_value = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
|
||||||
|
"""Scheduler should set dataset training_status to 'running' when task starts."""
|
||||||
|
from inference.web.core.scheduler import TrainingScheduler
|
||||||
|
|
||||||
|
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
|
||||||
|
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
|
||||||
|
|
||||||
|
scheduler = TrainingScheduler()
|
||||||
|
scheduler._db = mock_db
|
||||||
|
|
||||||
|
task_id = str(uuid4())
|
||||||
|
dataset_id = str(uuid4())
|
||||||
|
|
||||||
|
# Execute task (will fail but we check the status update call)
|
||||||
|
try:
|
||||||
|
scheduler._execute_task(
|
||||||
|
task_id=task_id,
|
||||||
|
config={"model_name": "yolo11n.pt"},
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Expected to fail in test environment
|
||||||
|
|
||||||
|
# Check that training status was updated to running
|
||||||
|
mock_db.update_dataset_training_status.assert_called()
|
||||||
|
first_call = mock_db.update_dataset_training_status.call_args_list[0]
|
||||||
|
assert first_call.kwargs["training_status"] == "running"
|
||||||
|
assert first_call.kwargs["active_training_task_id"] == task_id
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test Dataset Status Values
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetStatusValues:
|
||||||
|
"""Tests for valid dataset status values."""
|
||||||
|
|
||||||
|
def test_dataset_status_building(self):
|
||||||
|
"""Dataset can have status 'building'."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(name="test", status="building")
|
||||||
|
assert dataset.status == "building"
|
||||||
|
|
||||||
|
def test_dataset_status_ready(self):
|
||||||
|
"""Dataset can have status 'ready'."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(name="test", status="ready")
|
||||||
|
assert dataset.status == "ready"
|
||||||
|
|
||||||
|
def test_dataset_status_trained(self):
|
||||||
|
"""Dataset can have status 'trained'."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(name="test", status="trained")
|
||||||
|
assert dataset.status == "trained"
|
||||||
|
|
||||||
|
def test_dataset_status_failed(self):
|
||||||
|
"""Dataset can have status 'failed'."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
dataset = TrainingDataset(name="test", status="failed")
|
||||||
|
assert dataset.status == "failed"
|
||||||
|
|
||||||
|
def test_training_status_values(self):
|
||||||
|
"""Training status can have various values."""
|
||||||
|
from inference.data.admin_models import TrainingDataset
|
||||||
|
|
||||||
|
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
|
||||||
|
for status in valid_statuses:
|
||||||
|
dataset = TrainingDataset(name="test", training_status=status)
|
||||||
|
assert dataset.training_status == status
|
||||||
207
tests/web/test_document_category.py
Normal file
207
tests/web/test_document_category.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
Tests for Document Category Feature.
|
||||||
|
|
||||||
|
TDD tests for adding category field to admin_documents table.
|
||||||
|
Documents can be categorized (e.g., invoice, letter, receipt) for training different models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from inference.data.admin_models import AdminDocument
|
||||||
|
|
||||||
|
|
||||||
|
# Test constants
|
||||||
|
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
TEST_TOKEN = "test-admin-token-12345"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminDocumentCategoryField:
|
||||||
|
"""Tests for AdminDocument category field."""
|
||||||
|
|
||||||
|
def test_document_has_category_field(self):
|
||||||
|
"""Test AdminDocument model has category field."""
|
||||||
|
doc = AdminDocument(
|
||||||
|
document_id=UUID(TEST_DOC_UUID),
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
content_type="application/pdf",
|
||||||
|
file_path="/path/to/file.pdf",
|
||||||
|
)
|
||||||
|
assert hasattr(doc, "category")
|
||||||
|
|
||||||
|
def test_document_category_defaults_to_invoice(self):
|
||||||
|
"""Test category defaults to 'invoice' when not specified."""
|
||||||
|
doc = AdminDocument(
|
||||||
|
document_id=UUID(TEST_DOC_UUID),
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
content_type="application/pdf",
|
||||||
|
file_path="/path/to/file.pdf",
|
||||||
|
)
|
||||||
|
assert doc.category == "invoice"
|
||||||
|
|
||||||
|
def test_document_accepts_custom_category(self):
|
||||||
|
"""Test document accepts custom category values."""
|
||||||
|
categories = ["invoice", "letter", "receipt", "contract", "custom_type"]
|
||||||
|
|
||||||
|
for cat in categories:
|
||||||
|
doc = AdminDocument(
|
||||||
|
document_id=uuid4(),
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
content_type="application/pdf",
|
||||||
|
file_path="/path/to/file.pdf",
|
||||||
|
category=cat,
|
||||||
|
)
|
||||||
|
assert doc.category == cat
|
||||||
|
|
||||||
|
def test_document_category_is_string_type(self):
|
||||||
|
"""Test category field is a string type."""
|
||||||
|
doc = AdminDocument(
|
||||||
|
document_id=UUID(TEST_DOC_UUID),
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
content_type="application/pdf",
|
||||||
|
file_path="/path/to/file.pdf",
|
||||||
|
category="letter",
|
||||||
|
)
|
||||||
|
assert isinstance(doc.category, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentCategoryInReadModel:
|
||||||
|
"""Tests for category in response models."""
|
||||||
|
|
||||||
|
def test_admin_document_read_has_category(self):
|
||||||
|
"""Test AdminDocumentRead includes category field."""
|
||||||
|
from inference.data.admin_models import AdminDocumentRead
|
||||||
|
|
||||||
|
# Check the model has category field in its schema
|
||||||
|
assert "category" in AdminDocumentRead.model_fields
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentCategoryAPI:
|
||||||
|
"""Tests for document category in API endpoints."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_admin_db(self):
|
||||||
|
"""Create mock AdminDB."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.is_valid_admin_token.return_value = True
|
||||||
|
return db
|
||||||
|
|
||||||
|
def test_upload_document_with_category(self, mock_admin_db):
|
||||||
|
"""Test uploading document with category parameter."""
|
||||||
|
from inference.web.schemas.admin import DocumentUploadResponse
|
||||||
|
|
||||||
|
# Verify response schema supports category
|
||||||
|
response = DocumentUploadResponse(
|
||||||
|
document_id=TEST_DOC_UUID,
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
page_count=1,
|
||||||
|
status="pending",
|
||||||
|
message="Upload successful",
|
||||||
|
category="letter",
|
||||||
|
)
|
||||||
|
assert response.category == "letter"
|
||||||
|
|
||||||
|
def test_list_documents_returns_category(self, mock_admin_db):
|
||||||
|
"""Test list documents endpoint returns category."""
|
||||||
|
from inference.web.schemas.admin import DocumentItem
|
||||||
|
|
||||||
|
item = DocumentItem(
|
||||||
|
document_id=TEST_DOC_UUID,
|
||||||
|
filename="test.pdf",
|
||||||
|
file_size=1024,
|
||||||
|
page_count=1,
|
||||||
|
status="pending",
|
||||||
|
annotation_count=0,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
updated_at=datetime.utcnow(),
|
||||||
|
category="invoice",
|
||||||
|
)
|
||||||
|
assert item.category == "invoice"
|
||||||
|
|
||||||
|
def test_document_detail_includes_category(self, mock_admin_db):
|
||||||
|
"""Test document detail response includes category."""
|
||||||
|
from inference.web.schemas.admin import DocumentDetailResponse
|
||||||
|
|
||||||
|
# Check schema has category
|
||||||
|
assert "category" in DocumentDetailResponse.model_fields
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentCategoryFiltering:
|
||||||
|
"""Tests for filtering documents by category."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_admin_db(self):
|
||||||
|
"""Create mock AdminDB with category filtering support."""
|
||||||
|
db = MagicMock()
|
||||||
|
db.is_valid_admin_token.return_value = True
|
||||||
|
|
||||||
|
# Mock documents with different categories
|
||||||
|
invoice_doc = MagicMock()
|
||||||
|
invoice_doc.document_id = uuid4()
|
||||||
|
invoice_doc.category = "invoice"
|
||||||
|
|
||||||
|
letter_doc = MagicMock()
|
||||||
|
letter_doc.document_id = uuid4()
|
||||||
|
letter_doc.category = "letter"
|
||||||
|
|
||||||
|
db.get_documents_by_category.return_value = [invoice_doc]
|
||||||
|
return db
|
||||||
|
|
||||||
|
def test_filter_documents_by_category(self, mock_admin_db):
|
||||||
|
"""Test filtering documents by category."""
|
||||||
|
# This tests the DB method signature
|
||||||
|
result = mock_admin_db.get_documents_by_category("invoice")
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].category == "invoice"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentCategoryUpdate:
|
||||||
|
"""Tests for updating document category."""
|
||||||
|
|
||||||
|
def test_update_document_category_schema(self):
|
||||||
|
"""Test update document request supports category."""
|
||||||
|
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||||
|
|
||||||
|
request = DocumentUpdateRequest(category="letter")
|
||||||
|
assert request.category == "letter"
|
||||||
|
|
||||||
|
def test_update_document_category_optional(self):
|
||||||
|
"""Test category is optional in update request."""
|
||||||
|
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||||
|
|
||||||
|
# Should not raise - category is optional
|
||||||
|
request = DocumentUpdateRequest()
|
||||||
|
assert request.category is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetWithCategory:
|
||||||
|
"""Tests for dataset creation with category filtering."""
|
||||||
|
|
||||||
|
def test_dataset_create_with_category_filter(self):
|
||||||
|
"""Test creating dataset can filter by document category."""
|
||||||
|
from inference.web.schemas.admin import DatasetCreateRequest
|
||||||
|
|
||||||
|
request = DatasetCreateRequest(
|
||||||
|
name="Invoice Training Set",
|
||||||
|
document_ids=[TEST_DOC_UUID],
|
||||||
|
category="invoice", # Optional filter
|
||||||
|
)
|
||||||
|
assert request.category == "invoice"
|
||||||
|
|
||||||
|
def test_dataset_create_category_is_optional(self):
|
||||||
|
"""Test category filter is optional when creating dataset."""
|
||||||
|
from inference.web.schemas.admin import DatasetCreateRequest
|
||||||
|
|
||||||
|
request = DatasetCreateRequest(
|
||||||
|
name="Mixed Training Set",
|
||||||
|
document_ids=[TEST_DOC_UUID],
|
||||||
|
)
|
||||||
|
# category should be optional
|
||||||
|
assert not hasattr(request, "category") or request.category is None
|
||||||
165
tests/web/test_document_category_api.py
Normal file
165
tests/web/test_document_category_api.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Tests for Document Category API Endpoints.
|
||||||
|
|
||||||
|
TDD tests for category filtering and management in document endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
# Test constants
|
||||||
|
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
TEST_TOKEN = "test-admin-token-12345"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCategoriesEndpoint:
|
||||||
|
"""Tests for GET /admin/documents/categories endpoint."""
|
||||||
|
|
||||||
|
def test_categories_endpoint_returns_list(self):
|
||||||
|
"""Test categories endpoint returns list of available categories."""
|
||||||
|
        from inference.web.schemas.admin import DocumentCategoriesResponse

        # Test schema exists and works
        response = DocumentCategoriesResponse(
            categories=["invoice", "letter", "receipt"],
            total=3,
        )
        assert response.categories == ["invoice", "letter", "receipt"]
        assert response.total == 3

    def test_categories_response_schema(self):
        """Test DocumentCategoriesResponse schema structure."""
        from inference.web.schemas.admin import DocumentCategoriesResponse

        assert "categories" in DocumentCategoriesResponse.model_fields
        assert "total" in DocumentCategoriesResponse.model_fields


class TestDocumentListFilterByCategory:
    """Tests for filtering documents by category."""

    @pytest.fixture
    def mock_admin_db(self):
        """Create mock AdminDB."""
        db = MagicMock()
        db.is_valid_admin_token.return_value = True

        # Mock documents with different categories
        invoice_doc = MagicMock()
        invoice_doc.document_id = uuid4()
        invoice_doc.category = "invoice"
        invoice_doc.filename = "invoice1.pdf"

        letter_doc = MagicMock()
        letter_doc.document_id = uuid4()
        letter_doc.category = "letter"
        letter_doc.filename = "letter1.pdf"

        db.get_documents.return_value = ([invoice_doc], 1)
        db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
        return db

    def test_list_documents_accepts_category_filter(self, mock_admin_db):
        """Test list documents endpoint accepts category query parameter."""
        # The endpoint should accept ?category=invoice parameter
        # This test verifies the schema/query parameter exists
        from inference.web.schemas.admin import DocumentListResponse

        # Schema should work with category filter applied
        assert DocumentListResponse is not None

    def test_get_document_categories_from_db(self, mock_admin_db):
        """Test fetching unique categories from database."""
        categories = mock_admin_db.get_document_categories()
        assert "invoice" in categories
        assert "letter" in categories
        assert len(categories) == 3


class TestDocumentUploadWithCategory:
    """Tests for uploading documents with category."""

    def test_upload_request_accepts_category(self):
        """Test upload request can include category field."""
        # When uploading via form data, category should be accepted
        # This is typically a form field, not a schema
        pass

    def test_upload_response_includes_category(self):
        """Test upload response includes the category that was set."""
        from inference.web.schemas.admin import DocumentUploadResponse

        response = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            category="letter",  # Custom category
            message="Upload successful",
        )
        assert response.category == "letter"

    def test_upload_defaults_to_invoice_category(self):
        """Test upload defaults to 'invoice' if no category specified."""
        from inference.web.schemas.admin import DocumentUploadResponse

        response = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            # No category specified - should default to "invoice"
        )
        assert response.category == "invoice"


class TestAdminDBCategoryMethods:
    """Tests for AdminDB category-related methods."""

    def test_get_document_categories_method_exists(self):
        """Test AdminDB has get_document_categories method."""
        from inference.data.admin_db import AdminDB

        db = AdminDB()
        assert hasattr(db, "get_document_categories")

    def test_get_documents_accepts_category_filter(self):
        """Test get_documents_by_token method accepts category parameter."""
        from inference.data.admin_db import AdminDB
        import inspect

        db = AdminDB()
        # Check the method exists and accepts category parameter
        method = getattr(db, "get_documents_by_token", None)
        assert callable(method)

        # Check category is in the method signature
        sig = inspect.signature(method)
        assert "category" in sig.parameters


class TestUpdateDocumentCategory:
    """Tests for updating document category."""

    def test_update_document_category_method_exists(self):
        """Test AdminDB has method to update document category."""
        from inference.data.admin_db import AdminDB

        db = AdminDB()
        assert hasattr(db, "update_document_category")

    def test_update_request_schema(self):
        """Test DocumentUpdateRequest can update category."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        request = DocumentUpdateRequest(category="receipt")
        assert request.category == "receipt"
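For reviewers, a minimal sketch of the response/request shapes these category tests assert on. This is a hypothetical reconstruction, not the actual inference.web.schemas.admin module; field names and the "invoice" default come from the assertions above, everything else (Pydantic, optionality of category on the update request) is an assumption.

# Hypothetical sketch only; the real schemas live in inference.web.schemas.admin.
from uuid import UUID
from pydantic import BaseModel


class DocumentCategoriesResponse(BaseModel):
    categories: list[str]
    total: int


class DocumentUploadResponse(BaseModel):
    document_id: UUID
    filename: str
    file_size: int
    page_count: int
    status: str
    message: str
    category: str = "invoice"  # tests expect "invoice" when no category is given


class DocumentUpdateRequest(BaseModel):
    category: str | None = None  # optionality assumed


if __name__ == "__main__":
    resp = DocumentCategoriesResponse(categories=["invoice", "letter"], total=2)
    assert resp.total == 2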
@@ -32,10 +32,10 @@ def test_app(tmp_path):
             use_gpu=False,
             dpi=150,
         ),
-        storage=StorageConfig(
+        file=StorageConfig(
            upload_dir=upload_dir,
            result_dir=result_dir,
-            allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
+            allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
            max_file_size_mb=50,
         ),
     )
@@ -252,20 +252,25 @@ class TestResultsEndpoint:
         response = client.get("/api/v1/results/nonexistent.png")
         assert response.status_code == 404

-    def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
+    def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
         """Test that existing result file is returned."""
-        # Get storage config from app
-        storage_config = test_app.extra.get("storage_config")
-        if not storage_config:
-            pytest.skip("Storage config not available in test app")
-
-        # Create a test result file
-        result_file = storage_config.result_dir / "test_result.png"
+        # Create a test result file in temp directory
+        result_dir = tmp_path / "results"
+        result_dir.mkdir(exist_ok=True)
+        result_file = result_dir / "test_result.png"

         img = Image.new('RGB', (100, 100), color='red')
         img.save(result_file)

-        # Request the file
-        response = client.get("/api/v1/results/test_result.png")
+        # Mock the storage helper to return our test file path
+        with patch(
+            "inference.web.api.v1.public.inference.get_storage_helper"
+        ) as mock_storage:
+            mock_helper = Mock()
+            mock_helper.get_result_local_path.return_value = result_file
+            mock_storage.return_value = mock_helper
+
+            # Request the file
+            response = client.get("/api/v1/results/test_result.png")

         assert response.status_code == 200
         assert response.headers["content-type"] == "image/png"
@@ -266,7 +266,11 @@ class TestActivateModelVersionRoute:
         mock_db = MagicMock()
         mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)

-        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+        # Create mock request with app state
+        mock_request = MagicMock()
+        mock_request.app.state.inference_service = None
+
+        result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))

         mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
         assert result.status == "active"
@@ -278,10 +282,14 @@ class TestActivateModelVersionRoute:
         mock_db = MagicMock()
         mock_db.activate_model_version.return_value = None

+        # Create mock request with app state
+        mock_request = MagicMock()
+        mock_request.app.state.inference_service = None
+
         from fastapi import HTTPException

         with pytest.raises(HTTPException) as exc_info:
-            asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
+            asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
         assert exc_info.value.status_code == 404

tests/web/test_storage_helpers.py (new file, 828 lines)
@@ -0,0 +1,828 @@
"""Tests for storage helpers module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
|
||||||
|
from shared.storage import PREFIXES
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_storage() -> MagicMock:
|
||||||
|
"""Create a mock storage backend."""
|
||||||
|
storage = MagicMock()
|
||||||
|
storage.upload_bytes = MagicMock()
|
||||||
|
storage.download_bytes = MagicMock(return_value=b"test content")
|
||||||
|
storage.get_presigned_url = MagicMock(return_value="https://example.com/file")
|
||||||
|
storage.exists = MagicMock(return_value=True)
|
||||||
|
storage.delete = MagicMock(return_value=True)
|
||||||
|
storage.list_files = MagicMock(return_value=[])
|
||||||
|
return storage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def helper(mock_storage: MagicMock) -> StorageHelper:
|
||||||
|
"""Create a storage helper with mock backend."""
|
||||||
|
return StorageHelper(storage=mock_storage)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageHelperInit:
|
||||||
|
"""Tests for StorageHelper initialization."""
|
||||||
|
|
||||||
|
def test_init_with_storage(self, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should use provided storage backend."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
assert helper.storage is mock_storage
|
||||||
|
|
||||||
|
def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Storage property should return the backend."""
|
||||||
|
assert helper.storage is mock_storage
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentOperations:
|
||||||
|
"""Tests for document storage operations."""
|
||||||
|
|
||||||
|
def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should upload document with correct path."""
|
||||||
|
doc_id, path = helper.upload_document(b"pdf content", "invoice.pdf", "doc123")
|
||||||
|
|
||||||
|
assert doc_id == "doc123"
|
||||||
|
assert path == "documents/doc123.pdf"
|
||||||
|
mock_storage.upload_bytes.assert_called_once_with(
|
||||||
|
b"pdf content", "documents/doc123.pdf", overwrite=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_upload_document_generates_id(self, helper: StorageHelper) -> None:
|
||||||
|
"""Should generate document ID if not provided."""
|
||||||
|
doc_id, path = helper.upload_document(b"content", "file.pdf")
|
||||||
|
|
||||||
|
assert doc_id is not None
|
||||||
|
assert len(doc_id) > 0
|
||||||
|
assert path.startswith("documents/")
|
||||||
|
|
||||||
|
def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should download document from correct path."""
|
||||||
|
content = helper.download_document("doc123")
|
||||||
|
|
||||||
|
assert content == b"test content"
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf")
|
||||||
|
|
||||||
|
def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get presigned URL for document."""
|
||||||
|
url = helper.get_document_url("doc123", expires_in_seconds=7200)
|
||||||
|
|
||||||
|
assert url == "https://example.com/file"
|
||||||
|
mock_storage.get_presigned_url.assert_called_once_with(
|
||||||
|
"documents/doc123.pdf", 7200
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should check document existence."""
|
||||||
|
exists = helper.document_exists("doc123")
|
||||||
|
|
||||||
|
assert exists is True
|
||||||
|
mock_storage.exists.assert_called_once_with("documents/doc123.pdf")
|
||||||
|
|
||||||
|
def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should delete document."""
|
||||||
|
result = helper.delete_document("doc123")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_storage.delete.assert_called_once_with("documents/doc123.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
class TestImageOperations:
|
||||||
|
"""Tests for image storage operations."""
|
||||||
|
|
||||||
|
def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save page image with correct path."""
|
||||||
|
path = helper.save_page_image("doc123", 1, b"image data")
|
||||||
|
|
||||||
|
assert path == "images/doc123/page_1.png"
|
||||||
|
mock_storage.upload_bytes.assert_called_once_with(
|
||||||
|
b"image data", "images/doc123/page_1.png", overwrite=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get page image from correct path."""
|
||||||
|
content = helper.get_page_image("doc123", 2)
|
||||||
|
|
||||||
|
assert content == b"test content"
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png")
|
||||||
|
|
||||||
|
def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get presigned URL for page image."""
|
||||||
|
url = helper.get_page_image_url("doc123", 3)
|
||||||
|
|
||||||
|
assert url == "https://example.com/file"
|
||||||
|
mock_storage.get_presigned_url.assert_called_once_with(
|
||||||
|
"images/doc123/page_3.png", 3600
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should delete all images for a document."""
|
||||||
|
mock_storage.list_files.return_value = [
|
||||||
|
"images/doc123/page_1.png",
|
||||||
|
"images/doc123/page_2.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
deleted = helper.delete_document_images("doc123")
|
||||||
|
|
||||||
|
assert deleted == 2
|
||||||
|
mock_storage.list_files.assert_called_once_with("images/doc123/")
|
||||||
|
|
||||||
|
def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should list all images for a document."""
|
||||||
|
mock_storage.list_files.return_value = ["images/doc123/page_1.png"]
|
||||||
|
|
||||||
|
images = helper.list_document_images("doc123")
|
||||||
|
|
||||||
|
assert images == ["images/doc123/page_1.png"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadOperations:
|
||||||
|
"""Tests for upload staging operations."""
|
||||||
|
|
||||||
|
def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save upload to correct path."""
|
||||||
|
path = helper.save_upload(b"content", "file.pdf")
|
||||||
|
|
||||||
|
assert path == "uploads/file.pdf"
|
||||||
|
mock_storage.upload_bytes.assert_called_once()
|
||||||
|
|
||||||
|
def test_save_upload_with_subfolder(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save upload with subfolder."""
|
||||||
|
path = helper.save_upload(b"content", "file.pdf", "async")
|
||||||
|
|
||||||
|
assert path == "uploads/async/file.pdf"
|
||||||
|
|
||||||
|
def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get upload from correct path."""
|
||||||
|
content = helper.get_upload("file.pdf", "async")
|
||||||
|
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf")
|
||||||
|
|
||||||
|
def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should delete upload."""
|
||||||
|
result = helper.delete_upload("file.pdf")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_storage.delete.assert_called_once_with("uploads/file.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultOperations:
|
||||||
|
"""Tests for result file operations."""
|
||||||
|
|
||||||
|
def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save result to correct path."""
|
||||||
|
path = helper.save_result(b"result data", "output.json")
|
||||||
|
|
||||||
|
assert path == "results/output.json"
|
||||||
|
mock_storage.upload_bytes.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get result from correct path."""
|
||||||
|
content = helper.get_result("output.json")
|
||||||
|
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("results/output.json")
|
||||||
|
|
||||||
|
def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get presigned URL for result."""
|
||||||
|
url = helper.get_result_url("output.json")
|
||||||
|
|
||||||
|
mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600)
|
||||||
|
|
||||||
|
def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should check result existence."""
|
||||||
|
exists = helper.result_exists("output.json")
|
||||||
|
|
||||||
|
assert exists is True
|
||||||
|
mock_storage.exists.assert_called_once_with("results/output.json")
|
||||||
|
|
||||||
|
|
||||||
|
class TestExportOperations:
|
||||||
|
"""Tests for export file operations."""
|
||||||
|
|
||||||
|
def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save export to correct path."""
|
||||||
|
path = helper.save_export(b"export data", "exp123", "dataset.zip")
|
||||||
|
|
||||||
|
assert path == "exports/exp123/dataset.zip"
|
||||||
|
mock_storage.upload_bytes.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get presigned URL for export."""
|
||||||
|
url = helper.get_export_url("exp123", "dataset.zip")
|
||||||
|
|
||||||
|
mock_storage.get_presigned_url.assert_called_once_with(
|
||||||
|
"exports/exp123/dataset.zip", 3600
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawPdfOperations:
|
||||||
|
"""Tests for raw PDF operations (legacy compatibility)."""
|
||||||
|
|
||||||
|
def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save raw PDF to correct path."""
|
||||||
|
path = helper.save_raw_pdf(b"pdf data", "invoice.pdf")
|
||||||
|
|
||||||
|
assert path == "raw_pdfs/invoice.pdf"
|
||||||
|
mock_storage.upload_bytes.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get raw PDF from correct path."""
|
||||||
|
content = helper.get_raw_pdf("invoice.pdf")
|
||||||
|
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf")
|
||||||
|
|
||||||
|
def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should check raw PDF existence."""
|
||||||
|
exists = helper.raw_pdf_exists("invoice.pdf")
|
||||||
|
|
||||||
|
assert exists is True
|
||||||
|
mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminImageOperations:
|
||||||
|
"""Tests for admin image storage operations."""
|
||||||
|
|
||||||
|
def test_save_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should save admin image with correct path."""
|
||||||
|
path = helper.save_admin_image("doc123", 1, b"image data")
|
||||||
|
|
||||||
|
assert path == "admin_images/doc123/page_1.png"
|
||||||
|
mock_storage.upload_bytes.assert_called_once_with(
|
||||||
|
b"image data", "admin_images/doc123/page_1.png", overwrite=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get admin image from correct path."""
|
||||||
|
content = helper.get_admin_image("doc123", 2)
|
||||||
|
|
||||||
|
assert content == b"test content"
|
||||||
|
mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png")
|
||||||
|
|
||||||
|
def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should get presigned URL for admin image."""
|
||||||
|
url = helper.get_admin_image_url("doc123", 3)
|
||||||
|
|
||||||
|
assert url == "https://example.com/file"
|
||||||
|
mock_storage.get_presigned_url.assert_called_once_with(
|
||||||
|
"admin_images/doc123/page_3.png", 3600
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should check admin image existence."""
|
||||||
|
exists = helper.admin_image_exists("doc123", 1)
|
||||||
|
|
||||||
|
assert exists is True
|
||||||
|
mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png")
|
||||||
|
|
||||||
|
def test_get_admin_image_path(self, helper: StorageHelper) -> None:
|
||||||
|
"""Should return correct admin image path."""
|
||||||
|
path = helper.get_admin_image_path("doc123", 2)
|
||||||
|
|
||||||
|
assert path == "admin_images/doc123/page_2.png"
|
||||||
|
|
||||||
|
def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should list all admin images for a document."""
|
||||||
|
mock_storage.list_files.return_value = [
|
||||||
|
"admin_images/doc123/page_1.png",
|
||||||
|
"admin_images/doc123/page_2.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
images = helper.list_admin_images("doc123")
|
||||||
|
|
||||||
|
assert images == ["admin_images/doc123/page_1.png", "admin_images/doc123/page_2.png"]
|
||||||
|
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
|
||||||
|
|
||||||
|
def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should delete all admin images for a document."""
|
||||||
|
mock_storage.list_files.return_value = [
|
||||||
|
"admin_images/doc123/page_1.png",
|
||||||
|
"admin_images/doc123/page_2.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
deleted = helper.delete_admin_images("doc123")
|
||||||
|
|
||||||
|
assert deleted == 2
|
||||||
|
mock_storage.list_files.assert_called_once_with("admin_images/doc123/")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLocalPath:
|
||||||
|
"""Tests for get_local_path method."""
|
||||||
|
|
||||||
|
def test_get_admin_image_local_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return local path when using local storage backend."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test image
|
||||||
|
test_path = Path(temp_dir) / "admin_images" / "doc123"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
(test_path / "page_1.png").write_bytes(b"test image")
|
||||||
|
|
||||||
|
local_path = helper.get_admin_image_local_path("doc123", 1)
|
||||||
|
|
||||||
|
assert local_path is not None
|
||||||
|
assert local_path.exists()
|
||||||
|
assert local_path.name == "page_1.png"
|
||||||
|
|
||||||
|
def test_get_admin_image_local_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when storage doesn't support local paths."""
|
||||||
|
# Mock storage without get_local_path method (simulating cloud storage)
|
||||||
|
mock_storage.get_local_path = MagicMock(return_value=None)
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
|
||||||
|
local_path = helper.get_admin_image_local_path("doc123", 1)
|
||||||
|
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
def test_get_admin_image_local_path_nonexistent_file(self) -> None:
|
||||||
|
"""Should return None when file doesn't exist."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
local_path = helper.get_admin_image_local_path("nonexistent", 1)
|
||||||
|
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAdminImageDimensions:
|
||||||
|
"""Tests for get_admin_image_dimensions method."""
|
||||||
|
|
||||||
|
def test_get_dimensions_with_local_storage(self) -> None:
|
||||||
|
"""Should return image dimensions when using local storage."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
from PIL import Image
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test image with known dimensions
|
||||||
|
test_path = Path(temp_dir) / "admin_images" / "doc123"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
img = Image.new("RGB", (800, 600), color="white")
|
||||||
|
img.save(test_path / "page_1.png")
|
||||||
|
|
||||||
|
dimensions = helper.get_admin_image_dimensions("doc123", 1)
|
||||||
|
|
||||||
|
assert dimensions == (800, 600)
|
||||||
|
|
||||||
|
def test_get_dimensions_nonexistent_file(self) -> None:
|
||||||
|
"""Should return None when file doesn't exist."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
dimensions = helper.get_admin_image_dimensions("nonexistent", 1)
|
||||||
|
|
||||||
|
assert dimensions is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetStorageHelper:
|
||||||
|
"""Tests for get_storage_helper function."""
|
||||||
|
|
||||||
|
def test_returns_helper_instance(self) -> None:
|
||||||
|
"""Should return a StorageHelper instance."""
|
||||||
|
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||||
|
mock_get.return_value = MagicMock()
|
||||||
|
# Reset the global helper
|
||||||
|
import inference.web.services.storage_helpers as module
|
||||||
|
module._default_helper = None
|
||||||
|
|
||||||
|
helper = get_storage_helper()
|
||||||
|
|
||||||
|
assert isinstance(helper, StorageHelper)
|
||||||
|
|
||||||
|
def test_returns_same_instance(self) -> None:
|
||||||
|
"""Should return the same instance on subsequent calls."""
|
||||||
|
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||||
|
mock_get.return_value = MagicMock()
|
||||||
|
import inference.web.services.storage_helpers as module
|
||||||
|
module._default_helper = None
|
||||||
|
|
||||||
|
helper1 = get_storage_helper()
|
||||||
|
helper2 = get_storage_helper()
|
||||||
|
|
||||||
|
assert helper1 is helper2
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteResult:
|
||||||
|
"""Tests for delete_result method."""
|
||||||
|
|
||||||
|
def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should delete result file."""
|
||||||
|
result = helper.delete_result("output.json")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_storage.delete.assert_called_once_with("results/output.json")
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultLocalPath:
|
||||||
|
"""Tests for get_result_local_path method."""
|
||||||
|
|
||||||
|
def test_get_result_local_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return local path when using local storage backend."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test result file
|
||||||
|
test_path = Path(temp_dir) / "results"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
(test_path / "output.json").write_bytes(b"test result")
|
||||||
|
|
||||||
|
local_path = helper.get_result_local_path("output.json")
|
||||||
|
|
||||||
|
assert local_path is not None
|
||||||
|
assert local_path.exists()
|
||||||
|
assert local_path.name == "output.json"
|
||||||
|
|
||||||
|
def test_get_result_local_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when storage doesn't support local paths."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
local_path = helper.get_result_local_path("output.json")
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
def test_get_result_local_path_nonexistent_file(self) -> None:
|
||||||
|
"""Should return None when file doesn't exist."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
local_path = helper.get_result_local_path("nonexistent.json")
|
||||||
|
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultsBasePath:
|
||||||
|
"""Tests for get_results_base_path method."""
|
||||||
|
|
||||||
|
def test_get_results_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_results_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "results"
|
||||||
|
|
||||||
|
def test_get_results_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_results_base_path()
|
||||||
|
assert base_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadLocalPath:
|
||||||
|
"""Tests for get_upload_local_path method."""
|
||||||
|
|
||||||
|
def test_get_upload_local_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return local path when using local storage backend."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test upload file
|
||||||
|
test_path = Path(temp_dir) / "uploads"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
(test_path / "file.pdf").write_bytes(b"test upload")
|
||||||
|
|
||||||
|
local_path = helper.get_upload_local_path("file.pdf")
|
||||||
|
|
||||||
|
assert local_path is not None
|
||||||
|
assert local_path.exists()
|
||||||
|
assert local_path.name == "file.pdf"
|
||||||
|
|
||||||
|
def test_get_upload_local_path_with_subfolder(self) -> None:
|
||||||
|
"""Should return local path with subfolder."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test upload file with subfolder
|
||||||
|
test_path = Path(temp_dir) / "uploads" / "async"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
(test_path / "file.pdf").write_bytes(b"test upload")
|
||||||
|
|
||||||
|
local_path = helper.get_upload_local_path("file.pdf", "async")
|
||||||
|
|
||||||
|
assert local_path is not None
|
||||||
|
assert local_path.exists()
|
||||||
|
|
||||||
|
def test_get_upload_local_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
local_path = helper.get_upload_local_path("file.pdf")
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadsBasePath:
|
||||||
|
"""Tests for get_uploads_base_path method."""
|
||||||
|
|
||||||
|
def test_get_uploads_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_uploads_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "uploads"
|
||||||
|
|
||||||
|
def test_get_uploads_base_path_with_subfolder(self) -> None:
|
||||||
|
"""Should return base path with subfolder."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_uploads_base_path("async")
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "async"
|
||||||
|
|
||||||
|
def test_get_uploads_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_uploads_base_path()
|
||||||
|
assert base_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadExists:
|
||||||
|
"""Tests for upload_exists method."""
|
||||||
|
|
||||||
|
def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None:
|
||||||
|
"""Should check upload existence."""
|
||||||
|
exists = helper.upload_exists("file.pdf")
|
||||||
|
|
||||||
|
assert exists is True
|
||||||
|
mock_storage.exists.assert_called_once_with("uploads/file.pdf")
|
||||||
|
|
||||||
|
def test_upload_exists_with_subfolder(
|
||||||
|
self, helper: StorageHelper, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should check upload existence with subfolder."""
|
||||||
|
helper.upload_exists("file.pdf", "async")
|
||||||
|
|
||||||
|
mock_storage.exists.assert_called_once_with("uploads/async/file.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetsBasePath:
|
||||||
|
"""Tests for get_datasets_base_path method."""
|
||||||
|
|
||||||
|
def test_get_datasets_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_datasets_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "datasets"
|
||||||
|
|
||||||
|
def test_get_datasets_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_datasets_base_path()
|
||||||
|
assert base_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminImagesBasePath:
|
||||||
|
"""Tests for get_admin_images_base_path method."""
|
||||||
|
|
||||||
|
def test_get_admin_images_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_admin_images_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "admin_images"
|
||||||
|
|
||||||
|
def test_get_admin_images_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_admin_images_base_path()
|
||||||
|
assert base_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawPdfsBasePath:
|
||||||
|
"""Tests for get_raw_pdfs_base_path method."""
|
||||||
|
|
||||||
|
def test_get_raw_pdfs_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_raw_pdfs_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "raw_pdfs"
|
||||||
|
|
||||||
|
def test_get_raw_pdfs_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_raw_pdfs_base_path()
|
||||||
|
assert base_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawPdfLocalPath:
|
||||||
|
"""Tests for get_raw_pdf_local_path method."""
|
||||||
|
|
||||||
|
def test_get_raw_pdf_local_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return local path when using local storage backend."""
|
||||||
|
from pathlib import Path
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
# Create a test raw PDF
|
||||||
|
test_path = Path(temp_dir) / "raw_pdfs"
|
||||||
|
test_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
(test_path / "invoice.pdf").write_bytes(b"test pdf")
|
||||||
|
|
||||||
|
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||||
|
|
||||||
|
assert local_path is not None
|
||||||
|
assert local_path.exists()
|
||||||
|
assert local_path.name == "invoice.pdf"
|
||||||
|
|
||||||
|
def test_get_raw_pdf_local_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
local_path = helper.get_raw_pdf_local_path("invoice.pdf")
|
||||||
|
assert local_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawPdfPath:
|
||||||
|
"""Tests for get_raw_pdf_path method."""
|
||||||
|
|
||||||
|
def test_get_raw_pdf_path(self, helper: StorageHelper) -> None:
|
||||||
|
"""Should return correct storage path."""
|
||||||
|
path = helper.get_raw_pdf_path("invoice.pdf")
|
||||||
|
assert path == "raw_pdfs/invoice.pdf"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutolabelOutputPath:
|
||||||
|
"""Tests for get_autolabel_output_path method."""
|
||||||
|
|
||||||
|
def test_get_autolabel_output_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return output path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
output_path = helper.get_autolabel_output_path()
|
||||||
|
|
||||||
|
assert output_path is not None
|
||||||
|
assert output_path.exists()
|
||||||
|
assert output_path.name == "autolabel_output"
|
||||||
|
|
||||||
|
def test_get_autolabel_output_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
output_path = helper.get_autolabel_output_path()
|
||||||
|
assert output_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingDataPath:
|
||||||
|
"""Tests for get_training_data_path method."""
|
||||||
|
|
||||||
|
def test_get_training_data_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return training path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
training_path = helper.get_training_data_path()
|
||||||
|
|
||||||
|
assert training_path is not None
|
||||||
|
assert training_path.exists()
|
||||||
|
assert training_path.name == "training"
|
||||||
|
|
||||||
|
def test_get_training_data_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
training_path = helper.get_training_data_path()
|
||||||
|
assert training_path is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExportsBasePath:
|
||||||
|
"""Tests for get_exports_base_path method."""
|
||||||
|
|
||||||
|
def test_get_exports_base_path_with_local_storage(self) -> None:
|
||||||
|
"""Should return base path when using local storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
storage = LocalStorageBackend(temp_dir)
|
||||||
|
helper = StorageHelper(storage=storage)
|
||||||
|
|
||||||
|
base_path = helper.get_exports_base_path()
|
||||||
|
|
||||||
|
assert base_path is not None
|
||||||
|
assert base_path.exists()
|
||||||
|
assert base_path.name == "exports"
|
||||||
|
|
||||||
|
def test_get_exports_base_path_returns_none_for_cloud(
|
||||||
|
self, mock_storage: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Should return None when not using local storage."""
|
||||||
|
helper = StorageHelper(storage=mock_storage)
|
||||||
|
base_path = helper.get_exports_base_path()
|
||||||
|
assert base_path is None
|
||||||
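For review context: the tests above pin down a logical-path convention (documents/<id>.pdf, images/<id>/page_<n>.png, uploads/[subfolder/]<name>, and so on). Below is an illustrative sketch of those conventions only; it is not the project's storage_helpers implementation, and the class and method names are invented for the example.

# Hypothetical sketch of the path conventions exercised by the tests above.
import uuid
from pathlib import PurePosixPath


class PathOnlyStorageHelper:
    """Builds logical storage keys such as 'documents/<id>.pdf'."""

    def document_path(self, doc_id: str) -> str:
        return f"documents/{doc_id}.pdf"

    def page_image_path(self, doc_id: str, page: int) -> str:
        return f"images/{doc_id}/page_{page}.png"

    def admin_image_path(self, doc_id: str, page: int) -> str:
        return f"admin_images/{doc_id}/page_{page}.png"

    def upload_path(self, filename: str, subfolder: str | None = None) -> str:
        # Subfolder is optional, e.g. "async" for asynchronous uploads.
        parts = ["uploads"] + ([subfolder] if subfolder else []) + [filename]
        return str(PurePosixPath(*parts))

    def new_document_id(self) -> str:
        # The real helper may format IDs differently; a hex UUID is one option.
        return uuid.uuid4().hex


if __name__ == "__main__":
    h = PathOnlyStorageHelper()
    assert h.document_path("doc123") == "documents/doc123.pdf"
    assert h.upload_path("file.pdf", "async") == "uploads/async/file.pdf"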
tests/web/test_storage_integration.py (new file, 306 lines)
@@ -0,0 +1,306 @@
"""
|
||||||
|
Tests for storage backend integration in web application.
|
||||||
|
|
||||||
|
TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInitialization:
|
||||||
|
"""Tests for storage backend initialization in web config."""
|
||||||
|
|
||||||
|
def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None:
|
||||||
|
"""Test that get_storage_backend returns a StorageBackend instance."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
from inference.web.config import get_storage_backend
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
backend = get_storage_backend()
|
||||||
|
|
||||||
|
assert isinstance(backend, StorageBackend)
|
||||||
|
|
||||||
|
def test_get_storage_backend_uses_config_file_if_exists(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Test that storage config file is used when present."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
from inference.web.config import get_storage_backend
|
||||||
|
|
||||||
|
config_file = tmp_path / "storage.yaml"
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
config_file.write_text(f"""
|
||||||
|
backend: local
|
||||||
|
|
||||||
|
local:
|
||||||
|
base_path: {storage_path}
|
||||||
|
""")
|
||||||
|
|
||||||
|
backend = get_storage_backend(config_path=config_file)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None:
|
||||||
|
"""Test fallback to environment variables when no config file."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
from inference.web.config import get_storage_backend
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
backend = get_storage_backend(config_path=None)
|
||||||
|
|
||||||
|
assert isinstance(backend, LocalStorageBackend)
|
||||||
|
|
||||||
|
def test_app_config_has_storage_backend(self, tmp_path: Path) -> None:
|
||||||
|
"""Test that AppConfig can be created with storage backend."""
|
||||||
|
from shared.storage.base import StorageBackend
|
||||||
|
|
||||||
|
from inference.web.config import AppConfig, create_app_config
|
||||||
|
|
||||||
|
env = {
|
||||||
|
"STORAGE_BACKEND": "local",
|
||||||
|
"STORAGE_BASE_PATH": str(tmp_path / "storage"),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
config = create_app_config()
|
||||||
|
|
||||||
|
assert hasattr(config, "storage_backend")
|
||||||
|
assert isinstance(config.storage_backend, StorageBackend)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInDocumentUpload:
|
||||||
|
"""Tests for storage backend usage in document upload."""
|
||||||
|
|
||||||
|
def test_upload_document_uses_storage_backend(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document upload uses storage backend."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create a mock upload file
|
||||||
|
pdf_content = b"%PDF-1.4 test content"
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
# Upload should use storage backend
|
||||||
|
result = service.upload_document(
|
||||||
|
content=pdf_content,
|
||||||
|
filename="test.pdf",
|
||||||
|
dataset_id="dataset-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
# Verify file was stored via storage backend
|
||||||
|
assert backend.exists(f"documents/{result.id}.pdf")
|
||||||
|
|
||||||
|
def test_upload_document_stores_logical_path(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document stores logical path, not absolute path."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
pdf_content = b"%PDF-1.4 test content"
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
result = service.upload_document(
|
||||||
|
content=pdf_content,
|
||||||
|
filename="test.pdf",
|
||||||
|
dataset_id="dataset-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Path should be logical (relative), not absolute
|
||||||
|
assert not result.file_path.startswith("/")
|
||||||
|
assert not result.file_path.startswith("C:")
|
||||||
|
assert result.file_path.startswith("documents/")
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInDocumentDownload:
|
||||||
|
"""Tests for storage backend usage in document download/serving."""
|
||||||
|
|
||||||
|
def test_get_document_url_returns_presigned_url(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document URL uses presigned URL from storage backend."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create a test file
|
||||||
|
doc_path = "documents/test-doc.pdf"
|
||||||
|
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
url = service.get_document_url(doc_path)
|
||||||
|
|
||||||
|
# Should return a URL (file:// for local, https:// for cloud)
|
||||||
|
assert url is not None
|
||||||
|
assert "test-doc.pdf" in url
|
||||||
|
|
||||||
|
def test_download_document_uses_storage_backend(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document download uses storage backend."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create a test file
|
||||||
|
doc_path = "documents/test-doc.pdf"
|
||||||
|
original_content = b"%PDF-1.4 test content"
|
||||||
|
backend.upload_bytes(original_content, doc_path)
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
content = service.download_document(doc_path)
|
||||||
|
|
||||||
|
assert content == original_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInImageServing:
|
||||||
|
"""Tests for storage backend usage in image serving."""
|
||||||
|
|
||||||
|
def test_get_page_image_url_returns_presigned_url(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that page image URL uses presigned URL."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create a test image
|
||||||
|
image_path = "images/doc-123/page_1.png"
|
||||||
|
backend.upload_bytes(b"fake png content", image_path)
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
url = service.get_page_image_url("doc-123", 1)
|
||||||
|
|
||||||
|
assert url is not None
|
||||||
|
assert "page_1.png" in url
|
||||||
|
|
||||||
|
def test_save_page_image_uses_storage_backend(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that page image saving uses storage backend."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
image_content = b"fake png content"
|
||||||
|
service.save_page_image("doc-123", 1, image_content)
|
||||||
|
|
||||||
|
# Verify image was stored
|
||||||
|
assert backend.exists("images/doc-123/page_1.png")
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageBackendInDocumentDeletion:
|
||||||
|
"""Tests for storage backend usage in document deletion."""
|
||||||
|
|
||||||
|
def test_delete_document_removes_from_storage(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document deletion removes file from storage."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create test files
|
||||||
|
doc_path = "documents/test-doc.pdf"
|
||||||
|
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
service.delete_document_files(doc_path)
|
||||||
|
|
||||||
|
assert not backend.exists(doc_path)
|
||||||
|
|
||||||
|
def test_delete_document_removes_images(
|
||||||
|
self, tmp_path: Path, mock_admin_db: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that document deletion removes associated images."""
|
||||||
|
from shared.storage.local import LocalStorageBackend
|
||||||
|
|
||||||
|
storage_path = tmp_path / "storage"
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
backend = LocalStorageBackend(str(storage_path))
|
||||||
|
|
||||||
|
# Create test files
|
||||||
|
doc_id = "test-doc-123"
|
||||||
|
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
|
||||||
|
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
|
||||||
|
|
||||||
|
from inference.web.services.document_service import DocumentService
|
||||||
|
|
||||||
|
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||||
|
|
||||||
|
service.delete_document_images(doc_id)
|
||||||
|
|
||||||
|
assert not backend.exists(f"images/{doc_id}/page_1.png")
|
||||||
|
assert not backend.exists(f"images/{doc_id}/page_2.png")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_admin_db() -> MagicMock:
|
||||||
|
"""Create a mock AdminDB for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.get_document.return_value = None
|
||||||
|
mock.create_document.return_value = MagicMock(
|
||||||
|
id="test-doc-id",
|
||||||
|
file_path="documents/test-doc-id.pdf",
|
||||||
|
)
|
||||||
|
return mock
|
||||||
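These integration tests assume get_storage_backend prefers a YAML config file and falls back to STORAGE_BACKEND / STORAGE_BASE_PATH environment variables. The sketch below shows one plausible resolution order under that assumption; it is not the inference.web.config implementation, and resolve_storage_settings plus the PyYAML dependency are illustrative choices.

# Hedged sketch of config-file-then-env resolution, assuming PyYAML is available.
import os
from pathlib import Path

import yaml


def resolve_storage_settings(config_path: Path | None = None) -> dict:
    """Prefer the config file when it exists, otherwise fall back to env vars."""
    if config_path is not None and config_path.exists():
        cfg = yaml.safe_load(config_path.read_text()) or {}
        backend = cfg.get("backend", "local")
        base_path = cfg.get(backend, {}).get("base_path", "./storage")
    else:
        backend = os.environ.get("STORAGE_BACKEND", "local")
        base_path = os.environ.get("STORAGE_BASE_PATH", "./storage")
    return {"backend": backend, "base_path": str(base_path)}


if __name__ == "__main__":
    print(resolve_storage_settings())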
@@ -103,6 +103,31 @@ class MockAnnotation:
         self.updated_at = kwargs.get('updated_at', datetime.utcnow())


+class MockModelVersion:
+    """Mock ModelVersion for testing."""
+
+    def __init__(self, **kwargs):
+        self.version_id = kwargs.get('version_id', uuid4())
+        self.version = kwargs.get('version', '1.0.0')
+        self.name = kwargs.get('name', 'Test Model')
+        self.description = kwargs.get('description', None)
+        self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt')
+        self.status = kwargs.get('status', 'inactive')
+        self.is_active = kwargs.get('is_active', False)
+        self.task_id = kwargs.get('task_id', None)
+        self.dataset_id = kwargs.get('dataset_id', None)
+        self.metrics_mAP = kwargs.get('metrics_mAP', 0.935)
+        self.metrics_precision = kwargs.get('metrics_precision', 0.92)
+        self.metrics_recall = kwargs.get('metrics_recall', 0.88)
+        self.document_count = kwargs.get('document_count', 100)
+        self.training_config = kwargs.get('training_config', {})
+        self.file_size = kwargs.get('file_size', 52428800)
+        self.trained_at = kwargs.get('trained_at', datetime.utcnow())
+        self.activated_at = kwargs.get('activated_at', None)
+        self.created_at = kwargs.get('created_at', datetime.utcnow())
+        self.updated_at = kwargs.get('updated_at', datetime.utcnow())
+
+
 class MockAdminDB:
     """Mock AdminDB for testing Phase 4."""

@@ -111,6 +136,7 @@ class MockAdminDB:
         self.annotations = {}
         self.training_tasks = {}
         self.training_links = {}
+        self.model_versions = {}

     def get_documents_for_training(
         self,
@@ -174,6 +200,14 @@ class MockAdminDB:
         """Get training task by ID."""
         return self.training_tasks.get(str(task_id))
 
+    def get_model_versions(self, status=None, limit=20, offset=0):
+        """Get model versions with optional filtering."""
+        models = list(self.model_versions.values())
+        if status:
+            models = [m for m in models if m.status == status]
+        total = len(models)
+        return models[offset:offset+limit], total
+
 
 @pytest.fixture
 def app():
@@ -241,6 +275,30 @@ def app():
     )
     mock_db.training_links[str(doc1.document_id)] = [link1]
 
+    # Add model versions
+    model1 = MockModelVersion(
+        version="1.0.0",
+        name="Model v1.0.0",
+        status="inactive",
+        is_active=False,
+        metrics_mAP=0.935,
+        metrics_precision=0.92,
+        metrics_recall=0.88,
+        document_count=500,
+    )
+    model2 = MockModelVersion(
+        version="1.1.0",
+        name="Model v1.1.0",
+        status="active",
+        is_active=True,
+        metrics_mAP=0.951,
+        metrics_precision=0.94,
+        metrics_recall=0.92,
+        document_count=600,
+    )
+    mock_db.model_versions[str(model1.version_id)] = model1
+    mock_db.model_versions[str(model2.version_id)] = model2
+
     # Override dependencies
     app.dependency_overrides[validate_admin_token] = lambda: "test-token"
     app.dependency_overrides[get_admin_db] = lambda: mock_db
@@ -324,10 +382,10 @@ class TestTrainingDocuments:
 
 
 class TestTrainingModels:
-    """Tests for GET /admin/training/models endpoint."""
+    """Tests for GET /admin/training/models endpoint (ModelVersionListResponse)."""
 
     def test_get_training_models_success(self, client):
-        """Test getting trained models list."""
+        """Test getting model versions list."""
         response = client.get("/admin/training/models")
 
         assert response.status_code == 200
@@ -338,43 +396,44 @@ class TestTrainingModels:
         assert len(data["models"]) == 2
 
     def test_get_training_models_includes_metrics(self, client):
-        """Test that models include metrics."""
+        """Test that model versions include metrics."""
         response = client.get("/admin/training/models")
 
         assert response.status_code == 200
         data = response.json()
-        # Check first model has metrics
+        # Check first model has metrics fields
         model = data["models"][0]
-        assert "metrics" in model
-        assert "mAP" in model["metrics"]
-        assert model["metrics"]["mAP"] is not None
-        assert "precision" in model["metrics"]
-        assert "recall" in model["metrics"]
+        assert "metrics_mAP" in model
+        assert model["metrics_mAP"] is not None
 
-    def test_get_training_models_includes_download_url(self, client):
-        """Test that completed models have download URLs."""
+    def test_get_training_models_includes_version_fields(self, client):
+        """Test that model versions include version fields."""
         response = client.get("/admin/training/models")
 
         assert response.status_code == 200
         data = response.json()
-        # Check completed models have download URLs
-        for model in data["models"]:
-            if model["status"] == "completed":
-                assert "download_url" in model
-                assert model["download_url"] is not None
+        # Check model has expected fields
+        model = data["models"][0]
+        assert "version_id" in model
+        assert "version" in model
+        assert "name" in model
+        assert "status" in model
+        assert "is_active" in model
+        assert "document_count" in model
 
     def test_get_training_models_filter_by_status(self, client):
-        """Test filtering models by status."""
-        response = client.get("/admin/training/models?status=completed")
+        """Test filtering model versions by status."""
+        response = client.get("/admin/training/models?status=active")
 
         assert response.status_code == 200
         data = response.json()
-        # All returned models should be completed
+        assert data["total"] == 1
+        # All returned models should be active
         for model in data["models"]:
-            assert model["status"] == "completed"
+            assert model["status"] == "active"
 
     def test_get_training_models_pagination(self, client):
-        """Test pagination for models."""
+        """Test pagination for model versions."""
        response = client.get("/admin/training/models?limit=1&offset=0")
 
         assert response.status_code == 200
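
For reference, a minimal standalone sketch (not part of the commit above) of the filter-and-pagination behaviour that TestTrainingModels exercises. The get_model_versions body mirrors the MockAdminDB method added in this diff; the SimpleNamespace fixtures and the __main__ driver are illustrative assumptions, not code from the repository.

from types import SimpleNamespace


def get_model_versions(model_versions, status=None, limit=20, offset=0):
    """Return (page, total) of model versions, optionally filtered by status."""
    models = list(model_versions.values())  # same logic as the mock method in the diff
    if status:
        models = [m for m in models if m.status == status]
    total = len(models)
    return models[offset:offset + limit], total


if __name__ == "__main__":
    # Two fake versions shaped like the model1/model2 fixtures above (hypothetical stand-ins).
    versions = {
        "v1": SimpleNamespace(version="1.0.0", status="inactive"),
        "v2": SimpleNamespace(version="1.1.0", status="active"),
    }
    page, total = get_model_versions(versions, status="active")
    assert total == 1 and page[0].version == "1.1.0"  # as in test_get_training_models_filter_by_status
    page, total = get_model_versions(versions, limit=1, offset=0)
    assert total == 2 and len(page) == 1  # as in test_get_training_models_pagination
    print("filter and pagination behave as the tests expect")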