diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 2c51df3..8a71fb4 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -90,7 +90,23 @@ "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")", "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")", "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")", - "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")" + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")", + "Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")", + "Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")", + "Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")", + "Bash(npm run dev:*)", + "Bash(ping:*)", + "Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")", + "Bash(git checkout:*)", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && 
PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM 
model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")", + "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")" ], "deny": [], "ask": [], diff --git a/.coverage b/.coverage index 932eb87..e5ab665 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.env.example b/.env.example index 657852b..dde83cd 100644 --- a/.env.example +++ b/.env.example @@ -8,6 +8,23 @@ DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=your_password_here +# Storage Configuration +# Backend type: local, azure_blob, or s3 +# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.) +STORAGE_BACKEND=local +STORAGE_BASE_PATH=./data + +# Azure Blob Storage (when STORAGE_BACKEND=azure_blob) +# AZURE_STORAGE_CONNECTION_STRING=your_connection_string +# AZURE_STORAGE_CONTAINER=documents + +# AWS S3 Storage (when STORAGE_BACKEND=s3) +# AWS_S3_BUCKET=your_bucket_name +# AWS_REGION=us-east-1 +# AWS_ACCESS_KEY_ID=your_access_key +# AWS_SECRET_ACCESS_KEY=your_secret_key +# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO + # Model Configuration (optional) # MODEL_PATH=runs/train/invoice_fields/weights/best.pt # CONFIDENCE_THRESHOLD=0.5 diff --git a/README.md b/README.md index 2d23c67..3262c49 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,9 @@ 本项目实现了一个完整的发票字段自动提取流程: 1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注 -2. **模型训练**: 使用 YOLOv11 训练字段检测模型 +2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强 3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化 +4. 
**Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理 ### 架构 @@ -16,15 +17,17 @@ ``` packages/ -├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 工具) +├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练) ├── training/ # 训练服务 (GPU, 按需启动) └── inference/ # 推理服务 (常驻运行) +frontend/ # React 前端 (Vite + TypeScript + TailwindCSS) ``` | 服务 | 部署目标 | GPU | 生命周期 | |------|---------|-----|---------| -| **Inference** | Azure App Service | 可选 | 常驻 7x24 | -| **Training** | Azure ACI | 必需 | 按需启动/销毁 | +| **Frontend** | Vercel / Nginx | 否 | 常驻 | +| **Inference** | Azure App Service / AWS | 可选 | 常驻 7x24 | +| **Training** | Azure ACI / AWS ECS | 必需 | 按需启动/销毁 | 两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。 @@ -34,7 +37,8 @@ packages/ |------|------| | **已标注文档** | 9,738 (9,709 成功) | | **总体字段匹配率** | 94.8% (82,604/87,121) | -| **测试** | 922 passed | +| **测试** | 1,601 passed | +| **测试覆盖率** | 28% | | **模型 mAP@0.5** | 93.5% | **各字段匹配率:** @@ -97,6 +101,9 @@ invoice-master-poc-v2/ │ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析 │ │ ├── normalize/ # 字段规范化 (10 种 normalizer) │ │ ├── matcher/ # 字段匹配 (精确/子串/模糊) +│ │ ├── storage/ # 存储抽象层 (Local/Azure/S3) +│ │ ├── training/ # 共享训练组件 (YOLOTrainer) +│ │ ├── augmentation/ # 数据增强 (DatasetAugmenter) │ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配) │ │ ├── data/ # DocumentDB, CSVLoader │ │ ├── config.py # 全局配置 (数据库, 路径, DPI) @@ -129,12 +136,29 @@ invoice-master-poc-v2/ │ ├── data/ # AdminDB, AsyncRequestDB, Models │ └── azure/ # ACI 训练触发器 │ -├── migrations/ # 数据库迁移 -│ ├── 001_async_tables.sql -│ ├── 002_nullable_admin_token.sql -│ └── 003_training_tasks.sql -├── frontend/ # React 前端 (Vite + TypeScript) -├── tests/ # 测试 (922 tests) +├── frontend/ # React 前端 (Vite + TypeScript + TailwindCSS) +│ ├── src/ +│ │ ├── api/ # API 客户端 (axios + react-query) +│ │ ├── components/ # UI 组件 +│ │ │ ├── Dashboard.tsx # 文档管理面板 +│ │ │ ├── Training.tsx # 训练管理 (数据集/任务) +│ │ │ ├── Models.tsx # 模型版本管理 +│ │ │ ├── DatasetDetail.tsx # 数据集详情 +│ │ │ └── InferenceDemo.tsx # 推理演示 +│ │ └── hooks/ # React Query hooks +│ └── package.json +│ +├── migrations/ # 数据库迁移 (SQL) +│ ├── 003_training_tasks.sql +│ ├── 004_training_datasets.sql +│ ├── 005_add_group_key.sql +│ ├── 006_model_versions.sql +│ ├── 007_training_tasks_extra_columns.sql +│ ├── 008_fix_model_versions_fk.sql +│ ├── 009_add_document_category.sql +│ └── 010_add_dataset_training_status.sql +│ +├── tests/ # 测试 (1,601 tests) ├── docker-compose.yml # 本地开发 (postgres + inference + training) ├── run_server.py # 快捷启动脚本 └── runs/train/ # 训练输出 (weights, curves) @@ -270,9 +294,32 @@ Inference API PostgreSQL Training (ACI) | POST | `/api/v1/admin/documents/upload` | 上传 PDF | | GET | `/api/v1/admin/documents/{id}` | 文档详情 | | PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 | +| PATCH | `/api/v1/admin/documents/{id}/category` | 更新文档分类 | +| GET | `/api/v1/admin/documents/categories` | 获取分类列表 | | POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 | -| POST | `/api/v1/admin/training/trigger` | 触发训练任务 | -| GET | `/api/v1/admin/training/{id}/status` | 查询训练状态 | + +**Training API:** + +| 方法 | 端点 | 描述 | +|------|------|------| +| POST | `/api/v1/admin/training/datasets` | 创建数据集 | +| GET | `/api/v1/admin/training/datasets` | 数据集列表 | +| GET | `/api/v1/admin/training/datasets/{id}` | 数据集详情 | +| DELETE | `/api/v1/admin/training/datasets/{id}` | 删除数据集 | +| POST | `/api/v1/admin/training/tasks` | 创建训练任务 | +| GET | `/api/v1/admin/training/tasks` | 任务列表 | +| GET | `/api/v1/admin/training/tasks/{id}` | 任务详情 | +| GET | `/api/v1/admin/training/tasks/{id}/logs` | 训练日志 | + +**Model Versions API:** + +| 方法 | 
端点 | 描述 | +|------|------|------| +| GET | `/api/v1/admin/models` | 模型版本列表 | +| GET | `/api/v1/admin/models/{id}` | 模型详情 | +| POST | `/api/v1/admin/models/{id}/activate` | 激活模型 | +| POST | `/api/v1/admin/models/{id}/archive` | 归档模型 | +| DELETE | `/api/v1/admin/models/{id}` | 删除模型 | ## Python API @@ -332,8 +379,41 @@ print(f"Customer Number: {result}") # "UMJ 436-R" | 数据库 | 用途 | 存储内容 | |--------|------|----------| -| **PostgreSQL** | 标注结果 | `documents`, `field_results`, `training_tasks` | -| **SQLite** (AdminDB) | Web 应用 | 文档管理, 标注编辑, 用户认证 | +| **PostgreSQL** | 主数据库 | 文档、标注、训练任务、数据集、模型版本 | + +### 主要表 + +| 表名 | 说明 | +|------|------| +| `admin_documents` | 文档管理 (PDF 元数据, 状态, 分类) | +| `admin_annotations` | 标注数据 (YOLO 格式边界框) | +| `training_tasks` | 训练任务 (状态, 配置, 指标) | +| `training_datasets` | 数据集 (train/val/test 分割) | +| `dataset_documents` | 数据集-文档关联 | +| `model_versions` | 模型版本管理 (激活/归档) | +| `admin_tokens` | 管理员认证令牌 | +| `async_requests` | 异步推理请求 | + +### 数据集状态 + +| 状态 | 说明 | +|------|------| +| `building` | 正在构建数据集 | +| `ready` | 数据集就绪,可开始训练 | +| `trained` | 已完成训练 | +| `failed` | 构建失败 | +| `archived` | 已归档 | + +### 训练状态 + +| 状态 | 说明 | +|------|------| +| `pending` | 等待执行 | +| `scheduled` | 已计划 | +| `running` | 正在训练 | +| `completed` | 训练完成 | +| `failed` | 训练失败 | +| `cancelled` | 已取消 | ## 测试 @@ -347,8 +427,114 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing | 指标 | 数值 | |------|------| -| **测试总数** | 922 | +| **测试总数** | 1,601 | | **通过率** | 100% | +| **覆盖率** | 28% | + +## 存储抽象层 + +统一的文件存储接口,支持多后端切换: + +| 后端 | 用途 | 安装 | +|------|------|------| +| **Local** | 本地开发/测试 | 默认 | +| **Azure Blob** | Azure 云部署 | `pip install -e "packages/shared[azure]"` | +| **AWS S3** | AWS 云部署 | `pip install -e "packages/shared[s3]"` | + +### 配置文件 (storage.yaml) + +```yaml +backend: ${STORAGE_BACKEND:-local} +presigned_url_expiry: 3600 + +local: + base_path: ${STORAGE_BASE_PATH:-./data/storage} + +azure: + connection_string: ${AZURE_STORAGE_CONNECTION_STRING} + container_name: ${AZURE_STORAGE_CONTAINER:-documents} + +s3: + bucket_name: ${AWS_S3_BUCKET} + region_name: ${AWS_REGION:-us-east-1} +``` + +### 使用示例 + +```python +from shared.storage import get_storage_backend + +# 从配置文件加载 +storage = get_storage_backend("storage.yaml") + +# 上传文件 +storage.upload(Path("local.pdf"), "documents/invoice.pdf") + +# 获取预签名 URL (前端访问) +url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=3600) +``` + +### 环境变量 + +| 变量 | 后端 | 说明 | +|------|------|------| +| `STORAGE_BACKEND` | 全部 | `local`, `azure_blob`, `s3` | +| `STORAGE_BASE_PATH` | Local | 本地存储路径 | +| `AZURE_STORAGE_CONNECTION_STRING` | Azure | 连接字符串 | +| `AZURE_STORAGE_CONTAINER` | Azure | 容器名称 | +| `AWS_S3_BUCKET` | S3 | 存储桶名称 | +| `AWS_REGION` | S3 | 区域 (默认: us-east-1) | + +## 数据增强 + +训练时支持多种数据增强策略: + +| 增强类型 | 说明 | +|----------|------| +| `perspective_warp` | 透视变换 (模拟扫描角度) | +| `wrinkle` | 皱纹效果 | +| `edge_damage` | 边缘损坏 | +| `stain` | 污渍效果 | +| `lighting_variation` | 光照变化 | +| `shadow` | 阴影效果 | +| `gaussian_blur` | 高斯模糊 | +| `motion_blur` | 运动模糊 | +| `gaussian_noise` | 高斯噪声 | +| `salt_pepper` | 椒盐噪声 | +| `paper_texture` | 纸张纹理 | +| `scanner_artifacts` | 扫描伪影 | + +增强配置示例: + +```json +{ + "augmentation": { + "gaussian_blur": { "enabled": true, "kernel_size": 5 }, + "perspective_warp": { "enabled": true, "intensity": 0.1 } + }, + "augmentation_multiplier": 2 +} +``` + +## 前端功能 + +React 前端提供以下功能模块: + +| 模块 | 功能 | +|------|------| +| **Dashboard** | 文档列表、上传、标注状态管理、分类筛选 | +| **Training** | 数据集创建/管理、训练任务配置、增强设置 | +| **Models** | 
模型版本管理、激活/归档、指标查看 | +| **Inference Demo** | 实时推理演示、结果可视化 | + +### 启动前端 + +```bash +cd frontend +npm install +npm run dev +# 访问 http://localhost:5173 +``` ## 技术栈 @@ -357,10 +543,27 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing | **目标检测** | YOLOv11 (Ultralytics) | | **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) | | **PDF 处理** | PyMuPDF (fitz) | -| **数据库** | PostgreSQL + psycopg2 | +| **数据库** | PostgreSQL + SQLModel | | **Web 框架** | FastAPI + Uvicorn | +| **前端** | React + TypeScript + Vite + TailwindCSS | +| **状态管理** | React Query (TanStack Query) | | **深度学习** | PyTorch + CUDA 12.x | -| **部署** | Docker + Azure ACI (训练) / App Service (推理) | +| **部署** | Docker + Azure/AWS (训练) / App Service (推理) | + +## 环境变量 + +| 变量 | 必需 | 说明 | +|------|------|------| +| `DB_PASSWORD` | 是 | PostgreSQL 密码 | +| `DB_HOST` | 否 | 数据库主机 (默认: localhost) | +| `DB_PORT` | 否 | 数据库端口 (默认: 5432) | +| `DB_NAME` | 否 | 数据库名 (默认: docmaster) | +| `DB_USER` | 否 | 数据库用户 (默认: docmaster) | +| `STORAGE_BASE_PATH` | 否 | 存储路径 (默认: ~/invoice-data/data) | +| `MODEL_PATH` | 否 | 模型路径 | +| `CONFIDENCE_THRESHOLD` | 否 | 置信度阈值 (默认: 0.5) | +| `SERVER_HOST` | 否 | 服务器主机 (默认: 0.0.0.0) | +| `SERVER_PORT` | 否 | 服务器端口 (默认: 8000) | ## 许可证 diff --git a/docs/aws-deployment-guide.md b/docs/aws-deployment-guide.md new file mode 100644 index 0000000..79f1842 --- /dev/null +++ b/docs/aws-deployment-guide.md @@ -0,0 +1,772 @@ +# AWS 部署方案完整指南 + +## 目录 +- [核心问题](#核心问题) +- [存储方案](#存储方案) +- [训练方案](#训练方案) +- [推理方案](#推理方案) +- [价格对比](#价格对比) +- [推荐架构](#推荐架构) +- [实施步骤](#实施步骤) +- [AWS vs Azure 对比](#aws-vs-azure-对比) + +--- + +## 核心问题 + +| 问题 | 答案 | +|------|------| +| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 | +| 能实时从 S3 读取训练吗? | 可以,SageMaker 支持 Pipe Mode 流式读取 | +| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone | +| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 | +| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda | +| 推理服务用什么? 
| Lambda (Serverless) 或 ECS/Fargate (容器) | + +--- + +## 存储方案 + +### Amazon S3(推荐) + +S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。 + +```bash +# 创建 S3 桶 +aws s3 mb s3://invoice-training-data --region us-east-1 + +# 上传训练数据 +aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/ + +# 创建目录结构 +aws s3api put-object --bucket invoice-training-data --key datasets/ +aws s3api put-object --bucket invoice-training-data --key models/ +``` + +### Mountpoint for Amazon S3 + +AWS 官方的 S3 挂载客户端,性能优于 s3fs: + +```bash +# 安装 Mountpoint +wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb +sudo dpkg -i mount-s3.deb + +# 挂载 S3 +mkdir -p /mnt/s3-data +mount-s3 invoice-training-data /mnt/s3-data --region us-east-1 + +# 配置缓存(推荐) +mount-s3 invoice-training-data /mnt/s3-data \ + --region us-east-1 \ + --cache /tmp/s3-cache \ + --metadata-ttl 60 +``` + +### 本地开发挂载 + +**Linux/Mac (s3fs-fuse):** +```bash +# 安装 +sudo apt-get install s3fs + +# 配置凭证 +echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs +chmod 600 ~/.passwd-s3fs + +# 挂载 +s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs +``` + +**Windows (Rclone):** +```powershell +# 安装 +winget install Rclone.Rclone + +# 配置 +rclone config # 选择 s3 + +# 挂载 +rclone mount aws:invoice-training-data Z: --vfs-cache-mode full +``` + +### 存储费用 + +| 层级 | 价格 | 适用场景 | +|------|------|---------| +| S3 Standard | $0.023/GB/月 | 频繁访问 | +| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 | +| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 | +| S3 Glacier | $0.004/GB/月 | 长期存档 | + +**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月** + +### SageMaker 数据输入模式 + +| 模式 | 说明 | 适用场景 | +|------|------|---------| +| File Mode | 下载到本地再训练 | 小数据集 | +| Pipe Mode | 流式读取,不占本地空间 | 大数据集 | +| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 | + +--- + +## 训练方案 + +### 方案总览 + +| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 | +|------|---------|---------|--------|----------| +| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 | +| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 | +| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 | + +### EC2 vs SageMaker + +| 特性 | EC2 | SageMaker | +|------|-----|-----------| +| 本质 | 虚拟机 | 托管 ML 平台 | +| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) | +| 管理开销 | 需自己配置 | 全托管 | +| Spot 折扣 | 最高 90% | 最高 90% | +| 实验跟踪 | 无 | 内置 | +| 自动关机 | 无 | 任务完成自动停止 | + +### GPU 实例价格 (2025 年 6 月降价后) + +| 实例 | GPU | 显存 | On-Demand | Spot 价格 | +|------|-----|------|-----------|----------| +| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr | +| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr | +| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr | +| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr | +| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr | + +**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。 + +### Spot 实例 + +```bash +# EC2 Spot 请求 +aws ec2 request-spot-instances \ + --instance-count 1 \ + --type "one-time" \ + --launch-specification '{ + "ImageId": "ami-0123456789abcdef0", + "InstanceType": "p3.2xlarge", + "KeyName": "my-key" + }' +``` + +### SageMaker Managed Spot Training + +```python +from sagemaker.pytorch import PyTorch + +estimator = PyTorch( + entry_point="train.py", + source_dir="./src", + role="arn:aws:iam::123456789012:role/SageMakerRole", + instance_count=1, + instance_type="ml.p3.2xlarge", + framework_version="2.0", + py_version="py310", + + # 启用 Spot 实例 + use_spot_instances=True, + max_run=3600, # 最长运行 1 小时 + max_wait=7200, # 最长等待 2 小时 + + # 检查点配置(Spot 中断恢复) + checkpoint_s3_uri="s3://invoice-training-data/checkpoints/", + checkpoint_local_path="/opt/ml/checkpoints", + + 
hyperparameters={ + "epochs": 100, + "batch-size": 16, + } +) + +estimator.fit({ + "training": "s3://invoice-training-data/datasets/train/", + "validation": "s3://invoice-training-data/datasets/val/" +}) +``` + +--- + +## 推理方案 + +### 方案对比 + +| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 | +|------|---------|--------|--------|------|---------| +| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 | +| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 | +| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 | +| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 | +| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 | +| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 | + +### 推荐方案 1: AWS Lambda (低流量) + +对于 YOLO CPU 推理,Lambda 最经济: + +```python +# lambda_function.py +import json +import boto3 +from ultralytics import YOLO + +# 模型在 Lambda Layer 或 /tmp 加载 +model = None + +def load_model(): + global model + if model is None: + # 从 S3 下载模型到 /tmp + s3 = boto3.client('s3') + s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt') + model = YOLO('/tmp/best.pt') + return model + +def lambda_handler(event, context): + model = load_model() + + # 从 S3 获取图片 + s3 = boto3.client('s3') + bucket = event['bucket'] + key = event['key'] + + local_path = f'/tmp/{key.split("/")[-1]}' + s3.download_file(bucket, key, local_path) + + # 执行推理 + results = model.predict(local_path, conf=0.5) + + return { + 'statusCode': 200, + 'body': json.dumps({ + 'fields': extract_fields(results), + 'confidence': get_confidence(results) + }) + } +``` + +**Lambda 配置:** +```yaml +# serverless.yml +service: invoice-inference + +provider: + name: aws + runtime: python3.11 + timeout: 30 + memorySize: 4096 # 4GB 内存 + +functions: + infer: + handler: lambda_function.lambda_handler + events: + - http: + path: /infer + method: post + layers: + - arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1 +``` + +### 推荐方案 2: ECS Fargate (中流量) + +```yaml +# task-definition.json +{ + "family": "invoice-inference", + "networkMode": "awsvpc", + "requiresCompatibilities": ["FARGATE"], + "cpu": "2048", + "memory": "4096", + "containerDefinitions": [ + { + "name": "inference", + "image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest", + "portMappings": [ + { + "containerPort": 8000, + "protocol": "tcp" + } + ], + "environment": [ + {"name": "MODEL_PATH", "value": "/app/models/best.pt"} + ], + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": "/ecs/invoice-inference", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "ecs" + } + } + } + ] +} +``` + +**Auto Scaling 配置:** +```bash +# 创建 Auto Scaling Target +aws application-autoscaling register-scalable-target \ + --service-namespace ecs \ + --resource-id service/invoice-cluster/invoice-service \ + --scalable-dimension ecs:service:DesiredCount \ + --min-capacity 1 \ + --max-capacity 10 + +# 基于 CPU 使用率扩缩容 +aws application-autoscaling put-scaling-policy \ + --service-namespace ecs \ + --resource-id service/invoice-cluster/invoice-service \ + --scalable-dimension ecs:service:DesiredCount \ + --policy-name cpu-scaling \ + --policy-type TargetTrackingScaling \ + --target-tracking-scaling-policy-configuration '{ + "TargetValue": 70, + "PredefinedMetricSpecification": { + "PredefinedMetricType": "ECSServiceAverageCPUUtilization" + }, + "ScaleOutCooldown": 60, + "ScaleInCooldown": 120 + }' +``` + +### 方案 3: SageMaker Serverless Inference + +```python +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.pytorch import PyTorchModel + +model = 
PyTorchModel( + model_data="s3://invoice-models/model.tar.gz", + role="arn:aws:iam::123456789012:role/SageMakerRole", + entry_point="inference.py", + framework_version="2.0", + py_version="py310" +) + +serverless_config = ServerlessInferenceConfig( + memory_size_in_mb=4096, + max_concurrency=10 +) + +predictor = model.deploy( + serverless_inference_config=serverless_config, + endpoint_name="invoice-inference-serverless" +) +``` + +### 推理性能对比 + +| 配置 | 单次推理时间 | 并发能力 | 月费估算 | +|------|------------|---------|---------| +| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) | +| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 | +| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 | +| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 | + +--- + +## 价格对比 + +### 训练成本对比(假设每天训练 2 小时) + +| 方案 | 计算方式 | 月费 | +|------|---------|------| +| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 | +| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 | +| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 | +| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 | +| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 | + +### 本项目完整成本估算 + +| 组件 | 推荐方案 | 月费 | +|------|---------|------| +| 数据存储 | S3 Standard (5GB) | ~$0.12 | +| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 | +| 推理服务 | Lambda (10K 请求/月) | ~$15 | +| 推理服务 (替代) | ECS Fargate | ~$30 | +| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 | +| ECR (镜像存储) | 基本使用 | ~$1 | +| **总计 (Lambda)** | | **~$35/月** + 训练费 | +| **总计 (Fargate)** | | **~$50/月** + 训练费 | + +--- + +## 推荐架构 + +### 整体架构图 + +``` + ┌─────────────────────────────────────┐ + │ Amazon S3 │ + │ ├── training-images/ │ + │ ├── datasets/ │ + │ ├── models/ │ + │ └── checkpoints/ │ + └─────────────────┬───────────────────┘ + │ + ┌─────────────────────────────────┼─────────────────────────────────┐ + │ │ │ + ▼ ▼ ▼ +┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐ +│ 推理服务 │ │ 训练服务 │ │ API Gateway │ +│ │ │ │ │ │ +│ 方案 A: Lambda │ │ SageMaker │ │ REST API │ +│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │ +│ │ │ ~$2-5/次训练 │ │ │ +│ 方案 B: ECS Fargate │ │ │ │ │ +│ ~$30/月 │ │ - 自动启动 │ │ │ +│ │ │ - 训练完成自动停止 │ │ │ +│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │ +│ │ FastAPI + YOLO │ │ │ │ │ │ +│ │ CPU 推理 │ │ │ │ │ │ +│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘ +└───────────┬───────────┘ │ + │ │ + └───────────────────────────────┼───────────────────────────────────────────┘ + │ + ▼ + ┌───────────────────────┐ + │ Amazon RDS │ + │ PostgreSQL │ + │ db.t3.micro │ + │ ~$15/月 │ + └───────────────────────┘ +``` + +### Lambda 推理配置 + +```yaml +# SAM template +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + InferenceFunction: + Type: AWS::Serverless::Function + Properties: + Handler: app.lambda_handler + Runtime: python3.11 + MemorySize: 4096 + Timeout: 30 + Environment: + Variables: + MODEL_BUCKET: invoice-models + MODEL_KEY: best.pt + Policies: + - S3ReadPolicy: + BucketName: invoice-models + - S3ReadPolicy: + BucketName: invoice-uploads + Events: + InferApi: + Type: Api + Properties: + Path: /infer + Method: post +``` + +### SageMaker 训练配置 + +```python +from sagemaker.pytorch import PyTorch + +estimator = PyTorch( + entry_point="train.py", + source_dir="./src", + role="arn:aws:iam::123456789012:role/SageMakerRole", + instance_count=1, + instance_type="ml.g4dn.xlarge", # T4 GPU + framework_version="2.0", + py_version="py310", + + # Spot 实例配置 + use_spot_instances=True, + max_run=7200, + max_wait=14400, + + # 检查点 + checkpoint_s3_uri="s3://invoice-training-data/checkpoints/", + + 
hyperparameters={ + "epochs": 100, + "batch-size": 16, + "model": "yolo11n.pt" + } +) +``` + +--- + +## 实施步骤 + +### 阶段 1: 存储设置 + +```bash +# 创建 S3 桶 +aws s3 mb s3://invoice-training-data --region us-east-1 +aws s3 mb s3://invoice-models --region us-east-1 + +# 上传训练数据 +aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/ + +# 配置生命周期(可选,自动转冷存储) +aws s3api put-bucket-lifecycle-configuration \ + --bucket invoice-training-data \ + --lifecycle-configuration '{ + "Rules": [{ + "ID": "MoveToIA", + "Status": "Enabled", + "Transitions": [{ + "Days": 30, + "StorageClass": "STANDARD_IA" + }] + }] + }' +``` + +### 阶段 2: 数据库设置 + +```bash +# 创建 RDS PostgreSQL +aws rds create-db-instance \ + --db-instance-identifier invoice-db \ + --db-instance-class db.t3.micro \ + --engine postgres \ + --engine-version 15 \ + --master-username docmaster \ + --master-user-password YOUR_PASSWORD \ + --allocated-storage 20 + +# 配置安全组 +aws ec2 authorize-security-group-ingress \ + --group-id sg-xxx \ + --protocol tcp \ + --port 5432 \ + --source-group sg-yyy +``` + +### 阶段 3: 推理服务部署 + +**方案 A: Lambda** + +```bash +# 创建 Lambda Layer (依赖) +cd lambda-layer +pip install ultralytics opencv-python-headless -t python/ +zip -r layer.zip python/ +aws lambda publish-layer-version \ + --layer-name yolo-deps \ + --zip-file fileb://layer.zip \ + --compatible-runtimes python3.11 + +# 部署 Lambda 函数 +cd ../lambda +zip function.zip lambda_function.py +aws lambda create-function \ + --function-name invoice-inference \ + --runtime python3.11 \ + --handler lambda_function.lambda_handler \ + --role arn:aws:iam::123456789012:role/LambdaRole \ + --zip-file fileb://function.zip \ + --memory-size 4096 \ + --timeout 30 \ + --layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1 + +# 创建 API Gateway +aws apigatewayv2 create-api \ + --name invoice-api \ + --protocol-type HTTP \ + --target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference +``` + +**方案 B: ECS Fargate** + +```bash +# 创建 ECR 仓库 +aws ecr create-repository --repository-name invoice-inference + +# 构建并推送镜像 +aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com +docker build -t invoice-inference . 
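+# 将本地镜像标记为 ECR 仓库地址并推送(123456789012 / us-east-1 为示例账户与区域,需替换为实际值)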
+docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest +docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest + +# 创建 ECS 集群 +aws ecs create-cluster --cluster-name invoice-cluster + +# 注册任务定义 +aws ecs register-task-definition --cli-input-json file://task-definition.json + +# 创建服务 +aws ecs create-service \ + --cluster invoice-cluster \ + --service-name invoice-service \ + --task-definition invoice-inference \ + --desired-count 1 \ + --launch-type FARGATE \ + --network-configuration '{ + "awsvpcConfiguration": { + "subnets": ["subnet-xxx"], + "securityGroups": ["sg-xxx"], + "assignPublicIp": "ENABLED" + } + }' +``` + +### 阶段 4: 训练服务设置 + +```python +# setup_sagemaker.py +import boto3 +import sagemaker +from sagemaker.pytorch import PyTorch + +# 创建 SageMaker 执行角色 +iam = boto3.client('iam') +role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole" + +# 配置训练任务 +estimator = PyTorch( + entry_point="train.py", + source_dir="./src/training", + role=role_arn, + instance_count=1, + instance_type="ml.g4dn.xlarge", + framework_version="2.0", + py_version="py310", + use_spot_instances=True, + max_run=7200, + max_wait=14400, + checkpoint_s3_uri="s3://invoice-training-data/checkpoints/", +) + +# 保存配置供后续使用 +estimator.save("training_config.json") +``` + +### 阶段 5: 集成训练触发 API + +```python +# lambda_trigger_training.py +import boto3 +import sagemaker +from sagemaker.pytorch import PyTorch + +def lambda_handler(event, context): + """触发 SageMaker 训练任务""" + + epochs = event.get('epochs', 100) + + estimator = PyTorch( + entry_point="train.py", + source_dir="s3://invoice-training-data/code/", + role="arn:aws:iam::123456789012:role/SageMakerRole", + instance_count=1, + instance_type="ml.g4dn.xlarge", + framework_version="2.0", + py_version="py310", + use_spot_instances=True, + max_run=7200, + max_wait=14400, + hyperparameters={ + "epochs": epochs, + "batch-size": 16, + } + ) + + estimator.fit( + inputs={ + "training": "s3://invoice-training-data/datasets/train/", + "validation": "s3://invoice-training-data/datasets/val/" + }, + wait=False # 异步执行 + ) + + return { + 'statusCode': 200, + 'body': { + 'training_job_name': estimator.latest_training_job.name, + 'status': 'Started' + } + } +``` + +--- + +## AWS vs Azure 对比 + +### 服务对应关系 + +| 功能 | AWS | Azure | +|------|-----|-------| +| 对象存储 | S3 | Blob Storage | +| 挂载工具 | Mountpoint for S3 | BlobFuse2 | +| ML 平台 | SageMaker | Azure ML | +| 容器服务 | ECS/Fargate | Container Apps | +| Serverless | Lambda | Functions | +| GPU VM | EC2 P3/G4dn | NC/ND 系列 | +| 容器注册 | ECR | ACR | +| 数据库 | RDS PostgreSQL | PostgreSQL Flexible | + +### 价格对比 + +| 组件 | AWS | Azure | +|------|-----|-------| +| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 | +| 数据库 | ~$15/月 | ~$25/月 | +| 推理 (Serverless) | ~$15/月 | ~$30/月 | +| 推理 (容器) | ~$30/月 | ~$30/月 | +| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 | +| **总计** | **~$35-50/月** | **~$65/月** | + +### 优劣对比 + +| 方面 | AWS 优势 | Azure 优势 | +|------|---------|-----------| +| 价格 | Lambda 更便宜 | GPU Spot 更便宜 | +| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 | +| Serverless GPU | 无原生支持 | Container Apps GPU | +| 文档 | 更丰富 | 中文文档更好 | +| 生态 | 更大 | Office 365 集成 | + +--- + +## 总结 + +### 推荐配置 + +| 组件 | 推荐方案 | 月费估算 | +|------|---------|---------| +| 数据存储 | S3 Standard | ~$0.12 | +| 数据库 | RDS db.t3.micro | ~$15 | +| 推理服务 | Lambda 4GB | ~$15 | +| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 | +| ECR | 基本使用 | ~$1 | +| **总计** | | **~$35/月** + 训练费 | + +### 关键决策 + +| 场景 | 选择 | +|------|------| +| 最低成本 | Lambda + SageMaker Spot | +| 
稳定推理 | ECS Fargate | +| GPU 推理 | ECS + EC2 GPU | +| MLOps 集成 | SageMaker 全家桶 | + +### 注意事项 + +1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决 +2. **Spot 中断**: 配置检查点,SageMaker 自动恢复 +3. **S3 传输**: 同区域免费,跨区域收费 +4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2 +5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本 diff --git a/docs/azure-deployment-guide.md b/docs/azure-deployment-guide.md new file mode 100644 index 0000000..f2dec60 --- /dev/null +++ b/docs/azure-deployment-guide.md @@ -0,0 +1,567 @@ +# Azure 部署方案完整指南 + +## 目录 +- [核心问题](#核心问题) +- [存储方案](#存储方案) +- [训练方案](#训练方案) +- [推理方案](#推理方案) +- [价格对比](#价格对比) +- [推荐架构](#推荐架构) +- [实施步骤](#实施步骤) + +--- + +## 核心问题 + +| 问题 | 答案 | +|------|------| +| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 | +| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 | +| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) | +| VM 空闲时收费吗? | 收费,只要开机就按小时计费 | +| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster | +| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU | + +--- + +## 存储方案 + +### Azure Blob Storage + BlobFuse2(推荐) + +```bash +# 安装 BlobFuse2 +sudo apt-get install blobfuse2 + +# 配置文件 +cat > ~/blobfuse-config.yaml << 'EOF' +logging: + type: syslog + level: log_warning + +components: + - libfuse + - file_cache + - azstorage + +file_cache: + path: /tmp/blobfuse2 + timeout-sec: 120 + max-size-mb: 4096 + +azstorage: + type: block + account-name: YOUR_ACCOUNT + account-key: YOUR_KEY + container: training-images +EOF + +# 挂载 +mkdir -p /mnt/azure-blob +blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml +``` + +### 本地开发(Windows) + +```powershell +# 安装 +winget install WinFsp.WinFsp +winget install Rclone.Rclone + +# 配置 +rclone config # 选择 azureblob + +# 挂载为 Z: 盘 +rclone mount azure:training-images Z: --vfs-cache-mode full +``` + +### 存储费用 + +| 层级 | 价格 | 适用场景 | +|------|------|---------| +| Hot | $0.018/GB/月 | 频繁访问 | +| Cool | $0.01/GB/月 | 偶尔访问 | +| Archive | $0.002/GB/月 | 长期存档 | + +**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月** + +--- + +## 训练方案 + +### 方案总览 + +| 方案 | 适用场景 | 空闲费用 | 复杂度 | +|------|---------|---------|--------| +| Azure VM | 简单直接 | 24/7 收费 | 低 | +| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 | +| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 | +| Container Apps GPU | Serverless | 自动缩到 0 | 中 | + +### Azure VM vs Azure ML + +| 特性 | Azure VM | Azure ML | +|------|----------|----------| +| 本质 | 虚拟机 | 托管 ML 平台 | +| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) | +| 附加费用 | ~$5/月 | ~$20-30/月 | +| 实验跟踪 | 无 | 内置 | +| 自动扩缩 | 无 | 支持 min=0 | +| 适用人群 | DevOps | 数据科学家 | + +### Azure ML 附加费用明细 + +| 服务 | 用途 | 费用 | +|------|------|------| +| Container Registry | Docker 镜像 | ~$5-20/月 | +| Blob Storage | 日志、模型 | ~$0.10/月 | +| Application Insights | 监控 | ~$0-10/月 | +| Key Vault | 密钥管理 | <$1/月 | + +### Spot 实例 + +两种平台都支持 Spot/低优先级实例,最高节省 90%: + +| 类型 | 正常价格 | Spot 价格 | 节省 | +|------|---------|----------|------| +| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% | +| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% | + +### GPU 实例价格 + +| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 | +|------|-----|------|---------|----------| +| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 | +| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 | +| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 | +| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 | + +--- + +## 推理方案 + +### 方案对比 + +| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 | +|------|---------|--------|------|---------| +| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) | +| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 | +| Azure App Service | 否 | 手动/自动 | ~$50/月 | 
简单部署 | +| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 | +| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 | + +### 推荐: Container Apps (CPU) + +对于 YOLO 推理,**CPU 足够**,不需要 GPU: +- YOLOv11n 在 CPU 上推理时间 ~200-500ms +- 比 GPU 便宜很多,适合中低流量 + +```yaml +# Container Apps 配置 +name: invoice-inference +image: myacr.azurecr.io/invoice-inference:v1 +resources: + cpu: 2.0 + memory: 4Gi +scale: + minReplicas: 1 # 最少 1 个实例保持响应 + maxReplicas: 10 # 最多扩展到 10 个 + rules: + - name: http-scaling + http: + metadata: + concurrentRequests: "50" # 每实例 50 并发时扩容 +``` + +### 推理服务代码示例 + +```python +# Dockerfile +FROM python:3.11-slim + +WORKDIR /app + +# 安装依赖 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制代码和模型 +COPY src/ ./src/ +COPY models/best.pt ./models/ + +# 启动服务 +CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"] +``` + +```python +# src/web/app.py +from fastapi import FastAPI, UploadFile, File +from ultralytics import YOLO +import tempfile + +app = FastAPI() +model = YOLO("models/best.pt") + +@app.post("/api/v1/infer") +async def infer(file: UploadFile = File(...)): + # 保存上传文件 + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: + content = await file.read() + tmp.write(content) + tmp_path = tmp.name + + # 执行推理 + results = model.predict(tmp_path, conf=0.5) + + # 返回结果 + return { + "fields": extract_fields(results), + "confidence": get_confidence(results) + } + +@app.get("/health") +async def health(): + return {"status": "healthy"} +``` + +### 部署命令 + +```bash +# 1. 创建 Container Registry +az acr create --name invoiceacr --resource-group myRG --sku Basic + +# 2. 构建并推送镜像 +az acr build --registry invoiceacr --image invoice-inference:v1 . + +# 3. 创建 Container Apps 环境 +az containerapp env create \ + --name invoice-env \ + --resource-group myRG \ + --location eastus + +# 4. 部署应用 +az containerapp create \ + --name invoice-inference \ + --resource-group myRG \ + --environment invoice-env \ + --image invoiceacr.azurecr.io/invoice-inference:v1 \ + --registry-server invoiceacr.azurecr.io \ + --cpu 2 --memory 4Gi \ + --min-replicas 1 --max-replicas 10 \ + --ingress external --target-port 8000 + +# 5. 
获取 URL +az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn +``` + +### 高吞吐场景: Serverless GPU + +如果需要 GPU 加速推理(高并发、低延迟): + +```bash +# 请求 GPU 配额 +az containerapp env workload-profile add \ + --name invoice-env \ + --resource-group myRG \ + --workload-profile-name gpu \ + --workload-profile-type Consumption-GPU-T4 + +# 部署 GPU 版本 +az containerapp create \ + --name invoice-inference-gpu \ + --resource-group myRG \ + --environment invoice-env \ + --image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \ + --workload-profile-name gpu \ + --cpu 4 --memory 8Gi \ + --min-replicas 0 --max-replicas 5 \ + --ingress external --target-port 8000 +``` + +### 推理性能对比 + +| 配置 | 单次推理时间 | 并发能力 | 月费估算 | +|------|------------|---------|---------| +| CPU 2核 4GB | ~300-500ms | ~50 QPS | ~$30 | +| CPU 4核 8GB | ~200-300ms | ~100 QPS | ~$60 | +| GPU T4 | ~50-100ms | ~200 QPS | 按秒计费 | +| GPU A100 | ~20-50ms | ~500 QPS | 按秒计费 | + +--- + +## 价格对比 + +### 月度成本对比(假设每天训练 2 小时) + +| 方案 | 计算方式 | 月费 | +|------|---------|------| +| VM 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 | +| VM 按需启停 | 2h × 30天 × $3.06 | ~$184 | +| VM Spot 按需 | 2h × 30天 × $0.92 | ~$55 | +| Serverless GPU | 2h × 30天 × ~$3.50 | ~$210 | +| Azure ML (min=0) | 2h × 30天 × $3.06 | ~$184 | + +### 本项目完整成本估算 + +| 组件 | 推荐方案 | 月费 | +|------|---------|------| +| 图片存储 | Blob Storage (Hot) | ~$0.10 | +| 数据库 | PostgreSQL Flexible (Burstable B1ms) | ~$25 | +| 推理服务 | Container Apps CPU (2核4GB) | ~$30 | +| 训练服务 | Azure ML Spot (按需) | ~$1-5/次 | +| Container Registry | Basic | ~$5 | +| **总计** | | **~$65/月** + 训练费 | + +--- + +## 推荐架构 + +### 整体架构图 + +``` + ┌─────────────────────────────────────┐ + │ Azure Blob Storage │ + │ ├── training-images/ │ + │ ├── datasets/ │ + │ └── models/ │ + └─────────────────┬───────────────────┘ + │ + ┌─────────────────────────────────┼─────────────────────────────────┐ + │ │ │ + ▼ ▼ ▼ +┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐ +│ 推理服务 (24/7) │ │ 训练服务 (按需) │ │ Web UI (可选) │ +│ Container Apps │ │ Azure ML Compute │ │ Static Web Apps │ +│ CPU 2核 4GB │ │ min=0, Spot │ │ ~$0 (免费层) │ +│ ~$30/月 │ │ ~$1-5/次训练 │ │ │ +│ │ │ │ │ │ +│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │ +│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │ +│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │ +│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │ +└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘ + │ │ │ + └───────────────────────────────┼───────────────────────────────┘ + │ + ▼ + ┌───────────────────────┐ + │ PostgreSQL │ + │ Flexible Server │ + │ Burstable B1ms │ + │ ~$25/月 │ + └───────────────────────┘ +``` + +### 推理服务配置 + +```yaml +# Container Apps - CPU (24/7 运行) +name: invoice-inference +resources: + cpu: 2 + memory: 4Gi +scale: + minReplicas: 1 + maxReplicas: 10 +env: + - name: MODEL_PATH + value: /app/models/best.pt + - name: DB_HOST + secretRef: db-host + - name: DB_PASSWORD + secretRef: db-password +``` + +### 训练服务配置 + +**方案 A: Azure ML Compute(推荐)** + +```python +from azure.ai.ml.entities import AmlCompute + +gpu_cluster = AmlCompute( + name="gpu-cluster", + size="Standard_NC6s_v3", + min_instances=0, # 空闲时关机 + max_instances=1, + tier="LowPriority", # Spot 实例 + idle_time_before_scale_down=120 +) +``` + +**方案 B: Container Apps Serverless GPU** + +```yaml +name: invoice-training +resources: + gpu: 1 + gpuType: A100 +scale: + minReplicas: 0 + maxReplicas: 1 +``` + 
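+方案 A 中的 `AmlCompute` 对象只是集群定义,还需要通过 `MLClient` 提交才会真正创建。下面是一个最小示例(订阅 ID、资源组与工作区名称均为假设值,需替换为实际配置):
+
+```python
+from azure.ai.ml import MLClient
+from azure.ai.ml.entities import AmlCompute
+from azure.identity import DefaultAzureCredential
+
+# 假设的订阅 / 资源组 / 工作区名称,仅作演示
+ml_client = MLClient(
+    credential=DefaultAzureCredential(),
+    subscription_id="your-subscription-id",
+    resource_group_name="myRG",
+    workspace_name="invoice-ml",
+)
+
+gpu_cluster = AmlCompute(
+    name="gpu-cluster",
+    size="Standard_NC6s_v3",
+    min_instances=0,                 # 空闲时缩到 0,不产生计算费用
+    max_instances=1,
+    tier="low_priority",             # Spot / 低优先级实例
+    idle_time_before_scale_down=120,
+)
+
+# 创建(或更新)计算集群;begin_* 返回 LROPoller,result() 等待完成
+ml_client.compute.begin_create_or_update(gpu_cluster).result()
+```
+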
+--- + +## 实施步骤 + +### 阶段 1: 存储设置 + +```bash +# 创建 Storage Account +az storage account create \ + --name invoicestorage \ + --resource-group myRG \ + --sku Standard_LRS + +# 创建容器 +az storage container create --name training-images --account-name invoicestorage +az storage container create --name datasets --account-name invoicestorage +az storage container create --name models --account-name invoicestorage + +# 上传训练数据 +az storage blob upload-batch \ + --destination training-images \ + --source ./data/dataset/temp \ + --account-name invoicestorage +``` + +### 阶段 2: 数据库设置 + +```bash +# 创建 PostgreSQL +az postgres flexible-server create \ + --name invoice-db \ + --resource-group myRG \ + --sku-name Standard_B1ms \ + --storage-size 32 \ + --admin-user docmaster \ + --admin-password YOUR_PASSWORD + +# 配置防火墙 +az postgres flexible-server firewall-rule create \ + --name allow-azure \ + --resource-group myRG \ + --server-name invoice-db \ + --start-ip-address 0.0.0.0 \ + --end-ip-address 0.0.0.0 +``` + +### 阶段 3: 推理服务部署 + +```bash +# 创建 Container Registry +az acr create --name invoiceacr --resource-group myRG --sku Basic + +# 构建镜像 +az acr build --registry invoiceacr --image invoice-inference:v1 . + +# 创建环境 +az containerapp env create \ + --name invoice-env \ + --resource-group myRG \ + --location eastus + +# 部署推理服务 +az containerapp create \ + --name invoice-inference \ + --resource-group myRG \ + --environment invoice-env \ + --image invoiceacr.azurecr.io/invoice-inference:v1 \ + --registry-server invoiceacr.azurecr.io \ + --cpu 2 --memory 4Gi \ + --min-replicas 1 --max-replicas 10 \ + --ingress external --target-port 8000 \ + --env-vars \ + DB_HOST=invoice-db.postgres.database.azure.com \ + DB_NAME=docmaster \ + DB_USER=docmaster \ + --secrets db-password=YOUR_PASSWORD +``` + +### 阶段 4: 训练服务设置 + +```bash +# 创建 Azure ML Workspace +az ml workspace create --name invoice-ml --resource-group myRG + +# 创建 Compute Cluster +az ml compute create --name gpu-cluster \ + --type AmlCompute \ + --size Standard_NC6s_v3 \ + --min-instances 0 \ + --max-instances 1 \ + --tier low_priority +``` + +### 阶段 5: 集成训练触发 API + +```python +# src/web/routes/training.py +from fastapi import APIRouter +from azure.ai.ml import MLClient, command +from azure.identity import DefaultAzureCredential + +router = APIRouter() + +ml_client = MLClient( + credential=DefaultAzureCredential(), + subscription_id="your-subscription-id", + resource_group_name="myRG", + workspace_name="invoice-ml" +) + +@router.post("/api/v1/train") +async def trigger_training(request: TrainingRequest): + """触发 Azure ML 训练任务""" + training_job = command( + code="./training", + command=f"python train.py --epochs {request.epochs}", + environment="AzureML-pytorch-2.0-cuda11.8@latest", + compute="gpu-cluster", + ) + job = ml_client.jobs.create_or_update(training_job) + return { + "job_id": job.name, + "status": job.status, + "studio_url": job.studio_url + } + +@router.get("/api/v1/train/{job_id}/status") +async def get_training_status(job_id: str): + """查询训练状态""" + job = ml_client.jobs.get(job_id) + return {"status": job.status} +``` + +--- + +## 总结 + +### 推荐配置 + +| 组件 | 推荐方案 | 月费估算 | +|------|---------|---------| +| 图片存储 | Blob Storage (Hot) | ~$0.10 | +| 数据库 | PostgreSQL Flexible | ~$25 | +| 推理服务 | Container Apps CPU | ~$30 | +| 训练服务 | Azure ML (min=0, Spot) | 按需 ~$1-5/次 | +| Container Registry | Basic | ~$5 | +| **总计** | | **~$65/月** + 训练费 | + +### 关键决策 + +| 场景 | 选择 | +|------|------| +| 偶尔训练,简单需求 | Azure VM Spot + 手动启停 | +| 需要 MLOps,团队协作 | Azure ML Compute | +| 
追求最低空闲成本 | Container Apps Serverless GPU | +| 生产环境推理 | Container Apps CPU | +| 高并发推理 | Container Apps Serverless GPU | + +### 注意事项 + +1. **冷启动**: Serverless GPU 启动需要 3-8 分钟 +2. **Spot 中断**: 可能被抢占,需要检查点机制 +3. **网络延迟**: Blob Storage 挂载比本地 SSD 慢,建议开启缓存 +4. **区域选择**: 选择有 GPU 配额的区域 (East US, West Europe 等) +5. **推理优化**: CPU 推理对于 YOLO 已经足够,无需 GPU diff --git a/frontend/src/api/endpoints/documents.ts b/frontend/src/api/endpoints/documents.ts index 75367ed..9db85ac 100644 --- a/frontend/src/api/endpoints/documents.ts +++ b/frontend/src/api/endpoints/documents.ts @@ -4,11 +4,13 @@ import type { DocumentDetailResponse, DocumentItem, UploadDocumentResponse, + DocumentCategoriesResponse, } from '../types' export const documentsApi = { list: async (params?: { status?: string + category?: string limit?: number offset?: number }): Promise => { @@ -16,18 +18,29 @@ export const documentsApi = { return data }, + getCategories: async (): Promise => { + const { data } = await apiClient.get('/api/v1/admin/documents/categories') + return data + }, + getDetail: async (documentId: string): Promise => { const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`) return data }, - upload: async (file: File, groupKey?: string): Promise => { + upload: async ( + file: File, + options?: { groupKey?: string; category?: string } + ): Promise => { const formData = new FormData() formData.append('file', file) const params: Record = {} - if (groupKey) { - params.group_key = groupKey + if (options?.groupKey) { + params.group_key = options.groupKey + } + if (options?.category) { + params.category = options.category } const { data } = await apiClient.post('/api/v1/admin/documents', formData, { @@ -95,4 +108,15 @@ export const documentsApi = { ) return data }, + + updateCategory: async ( + documentId: string, + category: string + ): Promise<{ status: string; document_id: string; category: string; message: string }> => { + const { data } = await apiClient.patch( + `/api/v1/admin/documents/${documentId}/category`, + { category } + ) + return data + }, } diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts index 73908ca..7ceda95 100644 --- a/frontend/src/api/types.ts +++ b/frontend/src/api/types.ts @@ -9,6 +9,7 @@ export interface DocumentItem { auto_label_error: string | null upload_source: string group_key: string | null + category: string created_at: string updated_at: string annotation_count?: number @@ -61,6 +62,7 @@ export interface DocumentDetailResponse { upload_source: string batch_id: string | null group_key: string | null + category: string csv_field_values: Record | null can_annotate: boolean annotation_lock_until: string | null @@ -101,8 +103,21 @@ export interface TrainingTask { updated_at: string } +export interface ModelVersionItem { + version_id: string + version: string + name: string + status: string + is_active: boolean + metrics_mAP: number | null + document_count: number + trained_at: string | null + activated_at: string | null + created_at: string +} + export interface TrainingModelsResponse { - models: TrainingTask[] + models: ModelVersionItem[] total: number limit: number offset: number @@ -118,11 +133,17 @@ export interface UploadDocumentResponse { file_size: number page_count: number status: string + category: string group_key: string | null auto_label_started: boolean message: string } +export interface DocumentCategoriesResponse { + categories: string[] + total: number +} + export interface CreateAnnotationRequest { page_number: number class_id: number @@ -228,6 +249,8 @@ 
export interface DatasetDetailResponse { name: string description: string | null status: string + training_status: string | null + active_training_task_id: string | null train_ratio: number val_ratio: number seed: number diff --git a/frontend/src/components/Dashboard.tsx b/frontend/src/components/Dashboard.tsx index 5572734..601a949 100644 --- a/frontend/src/components/Dashboard.tsx +++ b/frontend/src/components/Dashboard.tsx @@ -3,7 +3,7 @@ import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react' import { Badge } from './Badge' import { Button } from './Button' import { UploadModal } from './UploadModal' -import { useDocuments } from '../hooks/useDocuments' +import { useDocuments, useCategories } from '../hooks/useDocuments' import type { DocumentItem } from '../api/types' interface DashboardProps { @@ -34,11 +34,15 @@ export const Dashboard: React.FC = ({ onNavigate }) => { const [isUploadOpen, setIsUploadOpen] = useState(false) const [selectedDocs, setSelectedDocs] = useState>(new Set()) const [statusFilter, setStatusFilter] = useState('') + const [categoryFilter, setCategoryFilter] = useState('') const [limit] = useState(20) const [offset] = useState(0) + const { categories } = useCategories() + const { documents, total, isLoading, error, refetch } = useDocuments({ status: statusFilter || undefined, + category: categoryFilter || undefined, limit, offset, }) @@ -102,6 +106,24 @@ export const Dashboard: React.FC = ({ onNavigate }) => {
+
+ + +
@@ -293,8 +294,12 @@ const DatasetList: React.FC<{ )}
diff --git a/frontend/src/components/UploadModal.tsx b/frontend/src/components/UploadModal.tsx index f76df93..11658fb 100644 --- a/frontend/src/components/UploadModal.tsx +++ b/frontend/src/components/UploadModal.tsx @@ -1,7 +1,7 @@ import React, { useState, useRef } from 'react' -import { X, UploadCloud, File, CheckCircle, AlertCircle } from 'lucide-react' +import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react' import { Button } from './Button' -import { useDocuments } from '../hooks/useDocuments' +import { useDocuments, useCategories } from '../hooks/useDocuments' interface UploadModalProps { isOpen: boolean @@ -12,11 +12,13 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) => const [isDragging, setIsDragging] = useState(false) const [selectedFiles, setSelectedFiles] = useState([]) const [groupKey, setGroupKey] = useState('') + const [category, setCategory] = useState('invoice') const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle') const [errorMessage, setErrorMessage] = useState('') const fileInputRef = useRef(null) const { uploadDocument, isUploading } = useDocuments({}) + const { categories } = useCategories() if (!isOpen) return null @@ -63,7 +65,7 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) => for (const file of selectedFiles) { await new Promise((resolve, reject) => { uploadDocument( - { file, groupKey: groupKey || undefined }, + { file, groupKey: groupKey || undefined, category: category || 'invoice' }, { onSuccess: () => resolve(), onError: (error: Error) => reject(error), @@ -77,6 +79,7 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) => onClose() setSelectedFiles([]) setGroupKey('') + setCategory('invoice') setUploadStatus('idle') }, 1500) } catch (error) { @@ -91,6 +94,7 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) => } setSelectedFiles([]) setGroupKey('') + setCategory('invoice') setUploadStatus('idle') setErrorMessage('') onClose() @@ -179,6 +183,42 @@ export const UploadModal: React.FC = ({ isOpen, onClose }) =>
)} + {/* Category Select */} + {selectedFiles.length > 0 && ( +
+ +
+ + +
+

+ Select document type for training different models +

+
+ )} + {/* Group Key Input */} {selectedFiles.length > 0 && (
diff --git a/frontend/src/hooks/index.ts b/frontend/src/hooks/index.ts index d1b5a7b..fb72d26 100644 --- a/frontend/src/hooks/index.ts +++ b/frontend/src/hooks/index.ts @@ -1,4 +1,4 @@ -export { useDocuments } from './useDocuments' +export { useDocuments, useCategories } from './useDocuments' export { useDocumentDetail } from './useDocumentDetail' export { useAnnotations } from './useAnnotations' export { useTraining, useTrainingDocuments } from './useTraining' diff --git a/frontend/src/hooks/useDocuments.ts b/frontend/src/hooks/useDocuments.ts index b75a126..1fe81ce 100644 --- a/frontend/src/hooks/useDocuments.ts +++ b/frontend/src/hooks/useDocuments.ts @@ -1,9 +1,10 @@ import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query' import { documentsApi } from '../api/endpoints' -import type { DocumentListResponse, UploadDocumentResponse } from '../api/types' +import type { DocumentListResponse, DocumentCategoriesResponse } from '../api/types' interface UseDocumentsParams { status?: string + category?: string limit?: number offset?: number } @@ -18,10 +19,11 @@ export const useDocuments = (params: UseDocumentsParams = {}) => { }) const uploadMutation = useMutation({ - mutationFn: ({ file, groupKey }: { file: File; groupKey?: string }) => - documentsApi.upload(file, groupKey), + mutationFn: ({ file, groupKey, category }: { file: File; groupKey?: string; category?: string }) => + documentsApi.upload(file, { groupKey, category }), onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['documents'] }) + queryClient.invalidateQueries({ queryKey: ['categories'] }) }, }) @@ -63,6 +65,15 @@ export const useDocuments = (params: UseDocumentsParams = {}) => { }, }) + const updateCategoryMutation = useMutation({ + mutationFn: ({ documentId, category }: { documentId: string; category: string }) => + documentsApi.updateCategory(documentId, category), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['documents'] }) + queryClient.invalidateQueries({ queryKey: ['categories'] }) + }, + }) + return { documents: data?.documents || [], total: data?.total || 0, @@ -86,5 +97,24 @@ export const useDocuments = (params: UseDocumentsParams = {}) => { updateGroupKey: updateGroupKeyMutation.mutate, updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync, isUpdatingGroupKey: updateGroupKeyMutation.isPending, + updateCategory: updateCategoryMutation.mutate, + updateCategoryAsync: updateCategoryMutation.mutateAsync, + isUpdatingCategory: updateCategoryMutation.isPending, + } +} + +export const useCategories = () => { + const { data, isLoading, error, refetch } = useQuery({ + queryKey: ['categories'], + queryFn: () => documentsApi.getCategories(), + staleTime: 60000, + }) + + return { + categories: data?.categories || [], + total: data?.total || 0, + isLoading, + error, + refetch, } } diff --git a/migrations/009_add_document_category.sql b/migrations/009_add_document_category.sql new file mode 100644 index 0000000..62cbd50 --- /dev/null +++ b/migrations/009_add_document_category.sql @@ -0,0 +1,13 @@ +-- Add category column to admin_documents table +-- Allows categorizing documents for training different models (e.g., invoice, letter, receipt) + +ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice'; + +-- Update existing NULL values to default +UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL; + +-- Make it NOT NULL after setting defaults +ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL; + +-- Create index 
for category filtering +CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category); diff --git a/migrations/010_add_dataset_training_status.sql b/migrations/010_add_dataset_training_status.sql new file mode 100644 index 0000000..7752569 --- /dev/null +++ b/migrations/010_add_dataset_training_status.sql @@ -0,0 +1,28 @@ +-- Migration: Add training_status and active_training_task_id to training_datasets +-- Description: Track training status separately from dataset build status + +-- Add training_status column +ALTER TABLE training_datasets +ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL; + +-- Add active_training_task_id column +ALTER TABLE training_datasets +ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL; + +-- Create index for training_status +CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status +ON training_datasets(training_status); + +-- Create index for active_training_task_id +CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id +ON training_datasets(active_training_task_id); + +-- Update existing datasets that have been used in completed training tasks to 'trained' status +UPDATE training_datasets d +SET status = 'trained' +WHERE d.status = 'ready' +AND EXISTS ( + SELECT 1 FROM training_tasks t + WHERE t.dataset_id = d.dataset_id + AND t.status = 'completed' +); diff --git a/packages/inference/inference/cli/serve.py b/packages/inference/inference/cli/serve.py index cb71ce4..ddc572c 100644 --- a/packages/inference/inference/cli/serve.py +++ b/packages/inference/inference/cli/serve.py @@ -120,7 +120,7 @@ def main() -> None: logger.info("=" * 60) # Create config - from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig + from inference.web.config import AppConfig, ModelConfig, ServerConfig, FileConfig config = AppConfig( model=ModelConfig( @@ -136,7 +136,7 @@ def main() -> None: reload=args.reload, workers=args.workers, ), - storage=StorageConfig(), + file=FileConfig(), ) # Create and run app diff --git a/packages/inference/inference/data/admin_db.py b/packages/inference/inference/data/admin_db.py index 9f9765d..62cb5f8 100644 --- a/packages/inference/inference/data/admin_db.py +++ b/packages/inference/inference/data/admin_db.py @@ -112,6 +112,7 @@ class AdminDB: upload_source: str = "ui", csv_field_values: dict[str, Any] | None = None, group_key: str | None = None, + category: str = "invoice", admin_token: str | None = None, # Deprecated, kept for compatibility ) -> str: """Create a new document record.""" @@ -125,6 +126,7 @@ class AdminDB: upload_source=upload_source, csv_field_values=csv_field_values, group_key=group_key, + category=category, ) session.add(document) session.flush() @@ -154,6 +156,7 @@ class AdminDB: has_annotations: bool | None = None, auto_label_status: str | None = None, batch_id: str | None = None, + category: str | None = None, limit: int = 20, offset: int = 0, ) -> tuple[list[AdminDocument], int]: @@ -171,6 +174,8 @@ class AdminDB: where_clauses.append(AdminDocument.auto_label_status == auto_label_status) if batch_id: where_clauses.append(AdminDocument.batch_id == UUID(batch_id)) + if category: + where_clauses.append(AdminDocument.category == category) # Count query count_stmt = select(func.count()).select_from(AdminDocument) @@ -283,6 +288,32 @@ class AdminDB: return True return False + def get_document_categories(self) -> list[str]: + """Get list of unique document categories.""" + with get_session_context() as session: + statement = ( 
+ select(AdminDocument.category) + .distinct() + .order_by(AdminDocument.category) + ) + categories = session.exec(statement).all() + return [c for c in categories if c is not None] + + def update_document_category( + self, document_id: str, category: str + ) -> AdminDocument | None: + """Update document category.""" + with get_session_context() as session: + document = session.get(AdminDocument, UUID(document_id)) + if document: + document.category = category + document.updated_at = datetime.utcnow() + session.add(document) + session.commit() + session.refresh(document) + return document + return None + # ========================================================================== # Annotation Operations # ========================================================================== @@ -1292,6 +1323,36 @@ class AdminDB: session.add(dataset) session.commit() + def update_dataset_training_status( + self, + dataset_id: str | UUID, + training_status: str | None, + active_training_task_id: str | UUID | None = None, + update_main_status: bool = False, + ) -> None: + """Update dataset training status and optionally the main status. + + Args: + dataset_id: Dataset UUID + training_status: Training status (pending, running, completed, failed, cancelled) + active_training_task_id: Currently active training task ID + update_main_status: If True and training_status is 'completed', set main status to 'trained' + """ + with get_session_context() as session: + dataset = session.get(TrainingDataset, UUID(str(dataset_id))) + if not dataset: + return + dataset.training_status = training_status + dataset.active_training_task_id = ( + UUID(str(active_training_task_id)) if active_training_task_id else None + ) + dataset.updated_at = datetime.utcnow() + # Update main status to 'trained' when training completes + if update_main_status and training_status == "completed": + dataset.status = "trained" + session.add(dataset) + session.commit() + def add_dataset_documents( self, dataset_id: str | UUID, diff --git a/packages/inference/inference/data/admin_models.py b/packages/inference/inference/data/admin_models.py index ca7e5f2..2639680 100644 --- a/packages/inference/inference/data/admin_models.py +++ b/packages/inference/inference/data/admin_models.py @@ -11,23 +11,8 @@ from uuid import UUID, uuid4 from sqlmodel import Field, SQLModel, Column, JSON - -# ============================================================================= -# CSV to Field Class Mapping -# ============================================================================= - -CSV_TO_CLASS_MAPPING: dict[str, int] = { - "InvoiceNumber": 0, # invoice_number - "InvoiceDate": 1, # invoice_date - "InvoiceDueDate": 2, # invoice_due_date - "OCR": 3, # ocr_number - "Bankgiro": 4, # bankgiro - "Plusgiro": 5, # plusgiro - "Amount": 6, # amount - "supplier_organisation_number": 7, # supplier_organisation_number - # 8: payment_line (derived from OCR/Bankgiro/Amount) - "customer_number": 9, # customer_number -} +# Import field mappings from single source of truth +from shared.fields import CSV_TO_CLASS_MAPPING, FIELD_CLASSES, FIELD_CLASS_IDS # ============================================================================= @@ -72,6 +57,8 @@ class AdminDocument(SQLModel, table=True): # Link to batch upload (if uploaded via ZIP) group_key: str | None = Field(default=None, max_length=255, index=True) # User-defined grouping key for document organization + category: str = Field(default="invoice", max_length=100, index=True) + # Document category for training different 
models (e.g., invoice, letter, receipt) csv_field_values: dict[str, Any] | None = Field(default=None, sa_column=Column(JSON)) # Original CSV values for reference auto_label_queued_at: datetime | None = Field(default=None) @@ -237,7 +224,10 @@ class TrainingDataset(SQLModel, table=True): name: str = Field(max_length=255) description: str | None = Field(default=None) status: str = Field(default="building", max_length=20, index=True) - # Status: building, ready, training, archived, failed + # Status: building, ready, trained, archived, failed + training_status: str | None = Field(default=None, max_length=20, index=True) + # Training status: pending, scheduled, running, completed, failed, cancelled + active_training_task_id: UUID | None = Field(default=None, index=True) train_ratio: float = Field(default=0.8) val_ratio: float = Field(default=0.1) seed: int = Field(default=42) @@ -354,21 +344,8 @@ class AnnotationHistory(SQLModel, table=True): created_at: datetime = Field(default_factory=datetime.utcnow, index=True) -# Field class mapping (same as src/cli/train.py) -FIELD_CLASSES = { - 0: "invoice_number", - 1: "invoice_date", - 2: "invoice_due_date", - 3: "ocr_number", - 4: "bankgiro", - 5: "plusgiro", - 6: "amount", - 7: "supplier_organisation_number", - 8: "payment_line", - 9: "customer_number", -} - -FIELD_CLASS_IDS = {v: k for k, v in FIELD_CLASSES.items()} +# FIELD_CLASSES and FIELD_CLASS_IDS are now imported from shared.fields +# This ensures consistency with the trained YOLO model # Read-only models for API responses @@ -383,6 +360,7 @@ class AdminDocumentRead(SQLModel): status: str auto_label_status: str | None auto_label_error: str | None + category: str = "invoice" created_at: datetime updated_at: datetime diff --git a/packages/inference/inference/data/database.py b/packages/inference/inference/data/database.py index 15b4c14..656636e 100644 --- a/packages/inference/inference/data/database.py +++ b/packages/inference/inference/data/database.py @@ -141,6 +141,40 @@ def run_migrations() -> None: CREATE INDEX IF NOT EXISTS ix_model_versions_dataset_id ON model_versions(dataset_id); """, ), + # Migration 009: Add category to admin_documents + ( + "admin_documents_category", + """ + ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice'; + UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL; + ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL; + CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category); + """, + ), + # Migration 010: Add training_status and active_training_task_id to training_datasets + ( + "training_datasets_training_status", + """ + ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL; + ALTER TABLE training_datasets ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL; + CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status); + CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id); + """, + ), + # Migration 010b: Update existing datasets with completed training to 'trained' status + ( + "training_datasets_update_trained_status", + """ + UPDATE training_datasets d + SET status = 'trained' + WHERE d.status = 'ready' + AND EXISTS ( + SELECT 1 FROM training_tasks t + WHERE t.dataset_id = d.dataset_id + AND t.status = 'completed' + ); + """, + ), ] with engine.connect() as conn: diff --git 
a/packages/inference/inference/pipeline/field_extractor.py b/packages/inference/inference/pipeline/field_extractor.py index 4846f1c..2db644f 100644 --- a/packages/inference/inference/pipeline/field_extractor.py +++ b/packages/inference/inference/pipeline/field_extractor.py @@ -21,7 +21,8 @@ import re import numpy as np from PIL import Image -from .yolo_detector import Detection, CLASS_TO_FIELD +from shared.fields import CLASS_TO_FIELD +from .yolo_detector import Detection # Import shared utilities for text cleaning and validation from shared.utils.text_cleaner import TextCleaner diff --git a/packages/inference/inference/pipeline/pipeline.py b/packages/inference/inference/pipeline/pipeline.py index c9ade47..9c968e0 100644 --- a/packages/inference/inference/pipeline/pipeline.py +++ b/packages/inference/inference/pipeline/pipeline.py @@ -10,7 +10,8 @@ from typing import Any import time import re -from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD +from shared.fields import CLASS_TO_FIELD +from .yolo_detector import YOLODetector, Detection from .field_extractor import FieldExtractor, ExtractedField from .payment_line_parser import PaymentLineParser diff --git a/packages/inference/inference/pipeline/yolo_detector.py b/packages/inference/inference/pipeline/yolo_detector.py index 395dc7c..3b1f7d3 100644 --- a/packages/inference/inference/pipeline/yolo_detector.py +++ b/packages/inference/inference/pipeline/yolo_detector.py @@ -9,6 +9,9 @@ from pathlib import Path from typing import Any import numpy as np +# Import field mappings from single source of truth +from shared.fields import CLASS_NAMES, CLASS_TO_FIELD + @dataclass class Detection: @@ -72,33 +75,8 @@ class Detection: return (x0, y0, x1, y1) -# Class names (must match training configuration) -CLASS_NAMES = [ - 'invoice_number', - 'invoice_date', - 'invoice_due_date', - 'ocr_number', - 'bankgiro', - 'plusgiro', - 'amount', - 'supplier_org_number', # Matches training class name - 'customer_number', - 'payment_line', # Machine code payment line at bottom of invoice -] - -# Mapping from class name to field name -CLASS_TO_FIELD = { - 'invoice_number': 'InvoiceNumber', - 'invoice_date': 'InvoiceDate', - 'invoice_due_date': 'InvoiceDueDate', - 'ocr_number': 'OCR', - 'bankgiro': 'Bankgiro', - 'plusgiro': 'Plusgiro', - 'amount': 'Amount', - 'supplier_org_number': 'supplier_org_number', - 'customer_number': 'customer_number', - 'payment_line': 'payment_line', -} +# CLASS_NAMES and CLASS_TO_FIELD are now imported from shared.fields +# This ensures consistency with the trained YOLO model class YOLODetector: diff --git a/packages/inference/inference/web/api/v1/admin/annotations.py b/packages/inference/inference/web/api/v1/admin/annotations.py index 609db93..751cdc0 100644 --- a/packages/inference/inference/web/api/v1/admin/annotations.py +++ b/packages/inference/inference/web/api/v1/admin/annotations.py @@ -4,18 +4,19 @@ Admin Annotation API Routes FastAPI endpoints for annotation management. 
""" +import io import logging -from pathlib import Path from typing import Annotated from uuid import UUID from fastapi import APIRouter, HTTPException, Query -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, StreamingResponse from inference.data.admin_db import AdminDB -from inference.data.admin_models import FIELD_CLASSES, FIELD_CLASS_IDS +from shared.fields import FIELD_CLASSES, FIELD_CLASS_IDS from inference.web.core.auth import AdminTokenDep, AdminDBDep from inference.web.services.autolabel import get_auto_label_service +from inference.web.services.storage_helpers import get_storage_helper from inference.web.schemas.admin import ( AnnotationCreate, AnnotationItem, @@ -35,9 +36,6 @@ from inference.web.schemas.common import ErrorResponse logger = logging.getLogger(__name__) -# Image storage directory -ADMIN_IMAGES_DIR = Path("data/admin_images") - def _validate_uuid(value: str, name: str = "ID") -> None: """Validate UUID format.""" @@ -60,7 +58,9 @@ def create_annotation_router() -> APIRouter: @router.get( "/{document_id}/images/{page_number}", + response_model=None, responses={ + 200: {"content": {"image/png": {}}, "description": "Page image"}, 401: {"model": ErrorResponse, "description": "Invalid token"}, 404: {"model": ErrorResponse, "description": "Not found"}, }, @@ -72,7 +72,7 @@ def create_annotation_router() -> APIRouter: page_number: int, admin_token: AdminTokenDep, db: AdminDBDep, - ) -> FileResponse: + ) -> FileResponse | StreamingResponse: """Get page image.""" _validate_uuid(document_id, "document_id") @@ -91,18 +91,33 @@ def create_annotation_router() -> APIRouter: detail=f"Page {page_number} not found. Document has {document.page_count} pages.", ) - # Find image file - image_path = ADMIN_IMAGES_DIR / document_id / f"page_{page_number}.png" - if not image_path.exists(): + # Get storage helper + storage = get_storage_helper() + + # Check if image exists + if not storage.admin_image_exists(document_id, page_number): raise HTTPException( status_code=404, detail=f"Image for page {page_number} not found", ) - return FileResponse( - path=str(image_path), + # Try to get local path for efficient file serving + local_path = storage.get_admin_image_local_path(document_id, page_number) + if local_path is not None: + return FileResponse( + path=str(local_path), + media_type="image/png", + filename=f"{document.filename}_page_{page_number}.png", + ) + + # Fall back to streaming for cloud storage + image_content = storage.get_admin_image(document_id, page_number) + return StreamingResponse( + io.BytesIO(image_content), media_type="image/png", - filename=f"{document.filename}_page_{page_number}.png", + headers={ + "Content-Disposition": f'inline; filename="{document.filename}_page_{page_number}.png"' + }, ) # ========================================================================= @@ -210,16 +225,14 @@ def create_annotation_router() -> APIRouter: ) # Get image dimensions for normalization - image_path = ADMIN_IMAGES_DIR / document_id / f"page_{request.page_number}.png" - if not image_path.exists(): + storage = get_storage_helper() + dimensions = storage.get_admin_image_dimensions(document_id, request.page_number) + if dimensions is None: raise HTTPException( status_code=400, detail=f"Image for page {request.page_number} not available", ) - - from PIL import Image - with Image.open(image_path) as img: - image_width, image_height = img.size + image_width, image_height = dimensions # Calculate normalized coordinates x_center = (request.bbox.x + 
request.bbox.width / 2) / image_width @@ -315,10 +328,14 @@ def create_annotation_router() -> APIRouter: if request.bbox is not None: # Get image dimensions - image_path = ADMIN_IMAGES_DIR / document_id / f"page_{annotation.page_number}.png" - from PIL import Image - with Image.open(image_path) as img: - image_width, image_height = img.size + storage = get_storage_helper() + dimensions = storage.get_admin_image_dimensions(document_id, annotation.page_number) + if dimensions is None: + raise HTTPException( + status_code=400, + detail=f"Image for page {annotation.page_number} not available", + ) + image_width, image_height = dimensions # Calculate normalized coordinates update_kwargs["x_center"] = (request.bbox.x + request.bbox.width / 2) / image_width diff --git a/packages/inference/inference/web/api/v1/admin/documents.py b/packages/inference/inference/web/api/v1/admin/documents.py index f78db66..3f147e0 100644 --- a/packages/inference/inference/web/api/v1/admin/documents.py +++ b/packages/inference/inference/web/api/v1/admin/documents.py @@ -13,16 +13,19 @@ from fastapi import APIRouter, File, HTTPException, Query, UploadFile from inference.web.config import DEFAULT_DPI, StorageConfig from inference.web.core.auth import AdminTokenDep, AdminDBDep +from inference.web.services.storage_helpers import get_storage_helper from inference.web.schemas.admin import ( AnnotationItem, AnnotationSource, AutoLabelStatus, BoundingBox, + DocumentCategoriesResponse, DocumentDetailResponse, DocumentItem, DocumentListResponse, DocumentStatus, DocumentStatsResponse, + DocumentUpdateRequest, DocumentUploadResponse, ModelMetrics, TrainingHistoryItem, @@ -44,14 +47,12 @@ def _validate_uuid(value: str, name: str = "ID") -> None: def _convert_pdf_to_images( - document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int + document_id: str, content: bytes, page_count: int, dpi: int ) -> None: - """Convert PDF pages to images for annotation.""" + """Convert PDF pages to images for annotation using StorageHelper.""" import fitz - doc_images_dir = images_dir / document_id - doc_images_dir.mkdir(parents=True, exist_ok=True) - + storage = get_storage_helper() pdf_doc = fitz.open(stream=content, filetype="pdf") for page_num in range(page_count): @@ -60,8 +61,9 @@ def _convert_pdf_to_images( mat = fitz.Matrix(dpi / 72, dpi / 72) pix = page.get_pixmap(matrix=mat) - image_path = doc_images_dir / f"page_{page_num + 1}.png" - pix.save(str(image_path)) + # Save to storage using StorageHelper + image_bytes = pix.tobytes("png") + storage.save_admin_image(document_id, page_num + 1, image_bytes) pdf_doc.close() @@ -95,6 +97,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: str | None, Query(description="Optional group key for document organization", max_length=255), ] = None, + category: Annotated[ + str, + Query(description="Document category (e.g., invoice, letter, receipt)", max_length=100), + ] = "invoice", ) -> DocumentUploadResponse: """Upload a document for labeling.""" # Validate group_key length @@ -143,31 +149,33 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: file_path="", # Will update after saving page_count=page_count, group_key=group_key, + category=category, ) - # Save file to admin uploads - file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}" + # Save file to storage using StorageHelper + storage = get_storage_helper() + filename = f"{document_id}{file_ext}" try: - file_path.write_bytes(content) + storage_path = 
storage.save_raw_pdf(content, filename) except Exception as e: logger.error(f"Failed to save file: {e}") raise HTTPException(status_code=500, detail="Failed to save file") - # Update file path in database + # Update file path in database (using storage path for reference) from inference.data.database import get_session_context from inference.data.admin_models import AdminDocument with get_session_context() as session: doc = session.get(AdminDocument, UUID(document_id)) if doc: - doc.file_path = str(file_path) + # Store the storage path (relative path within storage) + doc.file_path = storage_path session.add(doc) # Convert PDF to images for annotation if file_ext == ".pdf": try: _convert_pdf_to_images( - document_id, content, page_count, - storage_config.admin_images_dir, storage_config.dpi + document_id, content, page_count, storage_config.dpi ) except Exception as e: logger.error(f"Failed to convert PDF to images: {e}") @@ -189,6 +197,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: file_size=len(content), page_count=page_count, status=DocumentStatus.AUTO_LABELING if auto_label_started else DocumentStatus.PENDING, + category=category, group_key=group_key, auto_label_started=auto_label_started, message="Document uploaded successfully", @@ -226,6 +235,10 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: str | None, Query(description="Filter by batch ID"), ] = None, + category: Annotated[ + str | None, + Query(description="Filter by document category"), + ] = None, limit: Annotated[ int, Query(ge=1, le=100, description="Page size"), @@ -264,6 +277,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: has_annotations=has_annotations, auto_label_status=auto_label_status, batch_id=batch_id, + category=category, limit=limit, offset=offset, ) @@ -291,6 +305,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: upload_source=doc.upload_source if hasattr(doc, 'upload_source') else "ui", batch_id=str(doc.batch_id) if hasattr(doc, 'batch_id') and doc.batch_id else None, group_key=doc.group_key if hasattr(doc, 'group_key') else None, + category=doc.category if hasattr(doc, 'category') else "invoice", can_annotate=can_annotate, created_at=doc.created_at, updated_at=doc.updated_at, @@ -436,6 +451,7 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: upload_source=document.upload_source if hasattr(document, 'upload_source') else "ui", batch_id=str(document.batch_id) if hasattr(document, 'batch_id') and document.batch_id else None, group_key=document.group_key if hasattr(document, 'group_key') else None, + category=document.category if hasattr(document, 'category') else "invoice", csv_field_values=csv_field_values, can_annotate=can_annotate, annotation_lock_until=annotation_lock_until, @@ -471,16 +487,22 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: detail="Document not found or does not belong to this token", ) - # Delete file - file_path = Path(document.file_path) - if file_path.exists(): - file_path.unlink() + # Delete file using StorageHelper + storage = get_storage_helper() - # Delete images - images_dir = ADMIN_IMAGES_DIR / document_id - if images_dir.exists(): - import shutil - shutil.rmtree(images_dir) + # Delete the raw PDF + filename = Path(document.file_path).name + if filename: + try: + storage._storage.delete(document.file_path) + except Exception as e: + logger.warning(f"Failed to delete PDF file: {e}") + + # Delete admin images + try: 
+ storage.delete_admin_images(document_id) + except Exception as e: + logger.warning(f"Failed to delete admin images: {e}") # Delete from database db.delete_document(document_id) @@ -609,4 +631,61 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter: "message": "Document group key updated", } + @router.get( + "/categories", + response_model=DocumentCategoriesResponse, + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + }, + summary="Get available categories", + description="Get list of all available document categories.", + ) + async def get_categories( + admin_token: AdminTokenDep, + db: AdminDBDep, + ) -> DocumentCategoriesResponse: + """Get all available document categories.""" + categories = db.get_document_categories() + return DocumentCategoriesResponse( + categories=categories, + total=len(categories), + ) + + @router.patch( + "/{document_id}/category", + responses={ + 401: {"model": ErrorResponse, "description": "Invalid token"}, + 404: {"model": ErrorResponse, "description": "Document not found"}, + }, + summary="Update document category", + description="Update the category for a document.", + ) + async def update_document_category( + document_id: str, + admin_token: AdminTokenDep, + db: AdminDBDep, + request: DocumentUpdateRequest, + ) -> dict: + """Update document category.""" + _validate_uuid(document_id, "document_id") + + # Verify document exists + document = db.get_document_by_token(document_id, admin_token) + if document is None: + raise HTTPException( + status_code=404, + detail="Document not found or does not belong to this token", + ) + + # Update category if provided + if request.category is not None: + db.update_document_category(document_id, request.category) + + return { + "status": "updated", + "document_id": document_id, + "category": request.category, + "message": "Document category updated", + } + return router diff --git a/packages/inference/inference/web/api/v1/admin/training/datasets.py b/packages/inference/inference/web/api/v1/admin/training/datasets.py index bf93239..0c70287 100644 --- a/packages/inference/inference/web/api/v1/admin/training/datasets.py +++ b/packages/inference/inference/web/api/v1/admin/training/datasets.py @@ -17,6 +17,7 @@ from inference.web.schemas.admin import ( TrainingStatus, TrainingTaskResponse, ) +from inference.web.services.storage_helpers import get_storage_helper from ._utils import _validate_uuid @@ -38,7 +39,6 @@ def register_dataset_routes(router: APIRouter) -> None: db: AdminDBDep, ) -> DatasetResponse: """Create a training dataset from document IDs.""" - from pathlib import Path from inference.web.services.dataset_builder import DatasetBuilder # Validate minimum document count for proper train/val/test split @@ -56,7 +56,18 @@ def register_dataset_routes(router: APIRouter) -> None: seed=request.seed, ) - builder = DatasetBuilder(db=db, base_dir=Path("data/datasets")) + # Get storage paths from StorageHelper + storage = get_storage_helper() + datasets_dir = storage.get_datasets_base_path() + admin_images_dir = storage.get_admin_images_base_path() + + if datasets_dir is None or admin_images_dir is None: + raise HTTPException( + status_code=500, + detail="Storage not configured for local access", + ) + + builder = DatasetBuilder(db=db, base_dir=datasets_dir) try: builder.build_dataset( dataset_id=str(dataset.dataset_id), @@ -64,7 +75,7 @@ def register_dataset_routes(router: APIRouter) -> None: train_ratio=request.train_ratio, val_ratio=request.val_ratio, seed=request.seed, - 
admin_images_dir=Path("data/admin_images"), + admin_images_dir=admin_images_dir, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -142,6 +153,12 @@ def register_dataset_routes(router: APIRouter) -> None: name=dataset.name, description=dataset.description, status=dataset.status, + training_status=dataset.training_status, + active_training_task_id=( + str(dataset.active_training_task_id) + if dataset.active_training_task_id + else None + ), train_ratio=dataset.train_ratio, val_ratio=dataset.val_ratio, seed=dataset.seed, diff --git a/packages/inference/inference/web/api/v1/admin/training/export.py b/packages/inference/inference/web/api/v1/admin/training/export.py index 6ce2cc3..7c881fb 100644 --- a/packages/inference/inference/web/api/v1/admin/training/export.py +++ b/packages/inference/inference/web/api/v1/admin/training/export.py @@ -34,8 +34,10 @@ def register_export_routes(router: APIRouter) -> None: db: AdminDBDep, ) -> ExportResponse: """Export annotations for training.""" - from pathlib import Path - import shutil + from inference.web.services.storage_helpers import get_storage_helper + + # Get storage helper for reading images and exports directory + storage = get_storage_helper() if request.format not in ("yolo", "coco", "voc"): raise HTTPException( @@ -51,7 +53,14 @@ def register_export_routes(router: APIRouter) -> None: detail="No labeled documents available for export", ) - export_dir = Path("data/exports") / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}" + # Get exports directory from StorageHelper + exports_base = storage.get_exports_base_path() + if exports_base is None: + raise HTTPException( + status_code=500, + detail="Storage not configured for local access", + ) + export_dir = exports_base / f"export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}" export_dir.mkdir(parents=True, exist_ok=True) (export_dir / "images" / "train").mkdir(parents=True, exist_ok=True) @@ -80,13 +89,16 @@ def register_export_routes(router: APIRouter) -> None: if not page_annotations and not request.include_images: continue - src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png" - if not src_image.exists(): + # Get image from storage + doc_id = str(doc.document_id) + if not storage.admin_image_exists(doc_id, page_num): continue + # Download image and save to export directory image_name = f"{doc.document_id}_page{page_num}.png" dst_image = export_dir / "images" / split / image_name - shutil.copy(src_image, dst_image) + image_content = storage.get_admin_image(doc_id, page_num) + dst_image.write_bytes(image_content) total_images += 1 label_name = f"{doc.document_id}_page{page_num}.txt" @@ -98,7 +110,7 @@ def register_export_routes(router: APIRouter) -> None: f.write(line) total_annotations += 1 - from inference.data.admin_models import FIELD_CLASSES + from shared.fields import FIELD_CLASSES yaml_content = f"""# Auto-generated YOLO dataset config path: {export_dir.absolute()} diff --git a/packages/inference/inference/web/api/v1/public/inference.py b/packages/inference/inference/web/api/v1/public/inference.py index 4861b66..2a93349 100644 --- a/packages/inference/inference/web/api/v1/public/inference.py +++ b/packages/inference/inference/web/api/v1/public/inference.py @@ -22,6 +22,7 @@ from inference.web.schemas.inference import ( InferenceResult, ) from inference.web.schemas.common import ErrorResponse +from inference.web.services.storage_helpers import get_storage_helper if TYPE_CHECKING: from inference.web.services import 
InferenceService @@ -90,8 +91,17 @@ def create_inference_router( # Generate document ID doc_id = str(uuid.uuid4())[:8] - # Save uploaded file - upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}" + # Get storage helper and uploads directory + storage = get_storage_helper() + uploads_dir = storage.get_uploads_base_path(subfolder="inference") + if uploads_dir is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Storage not configured for local access", + ) + + # Save uploaded file to temporary location for processing + upload_path = uploads_dir / f"{doc_id}{file_ext}" try: with open(upload_path, "wb") as f: shutil.copyfileobj(file.file, f) @@ -149,12 +159,13 @@ def create_inference_router( # Cleanup uploaded file upload_path.unlink(missing_ok=True) - @router.get("/results/{filename}") + @router.get("/results/{filename}", response_model=None) async def get_result_image(filename: str) -> FileResponse: """Get visualization result image.""" - file_path = storage_config.result_dir / filename + storage = get_storage_helper() + file_path = storage.get_result_local_path(filename) - if not file_path.exists(): + if file_path is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Result file not found: {filename}", @@ -169,15 +180,15 @@ def create_inference_router( @router.delete("/results/{filename}") async def delete_result(filename: str) -> dict: """Delete a result file.""" - file_path = storage_config.result_dir / filename + storage = get_storage_helper() - if not file_path.exists(): + if not storage.result_exists(filename): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Result file not found: {filename}", ) - file_path.unlink() + storage.delete_result(filename) return {"status": "deleted", "filename": filename} return router diff --git a/packages/inference/inference/web/api/v1/public/labeling.py b/packages/inference/inference/web/api/v1/public/labeling.py index f029036..8d43de2 100644 --- a/packages/inference/inference/web/api/v1/public/labeling.py +++ b/packages/inference/inference/web/api/v1/public/labeling.py @@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s from inference.data.admin_db import AdminDB from inference.web.schemas.labeling import PreLabelResponse from inference.web.schemas.common import ErrorResponse +from inference.web.services.storage_helpers import get_storage_helper if TYPE_CHECKING: from inference.web.services import InferenceService @@ -23,19 +24,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Storage directory for pre-label uploads (legacy, now uses storage_config) -PRE_LABEL_UPLOAD_DIR = Path("data/pre_label_uploads") - def _convert_pdf_to_images( - document_id: str, content: bytes, page_count: int, images_dir: Path, dpi: int + document_id: str, content: bytes, page_count: int, dpi: int ) -> None: - """Convert PDF pages to images for annotation.""" + """Convert PDF pages to images for annotation using StorageHelper.""" import fitz - doc_images_dir = images_dir / document_id - doc_images_dir.mkdir(parents=True, exist_ok=True) - + storage = get_storage_helper() pdf_doc = fitz.open(stream=content, filetype="pdf") for page_num in range(page_count): @@ -43,8 +39,9 @@ def _convert_pdf_to_images( mat = fitz.Matrix(dpi / 72, dpi / 72) pix = page.get_pixmap(matrix=mat) - image_path = doc_images_dir / f"page_{page_num + 1}.png" - pix.save(str(image_path)) + # Save to storage using StorageHelper + image_bytes = 
pix.tobytes("png") + storage.save_admin_image(document_id, page_num + 1, image_bytes) pdf_doc.close() @@ -70,9 +67,6 @@ def create_labeling_router( """ router = APIRouter(prefix="/api/v1", tags=["labeling"]) - # Ensure upload directory exists - PRE_LABEL_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) - @router.post( "/pre-label", response_model=PreLabelResponse, @@ -165,10 +159,11 @@ def create_labeling_router( csv_field_values=expected_values, ) - # Save file to admin uploads - file_path = storage_config.admin_upload_dir / f"{document_id}{file_ext}" + # Save file to storage using StorageHelper + storage = get_storage_helper() + filename = f"{document_id}{file_ext}" try: - file_path.write_bytes(content) + storage_path = storage.save_raw_pdf(content, filename) except Exception as e: logger.error(f"Failed to save file: {e}") raise HTTPException( @@ -176,15 +171,14 @@ def create_labeling_router( detail="Failed to save file", ) - # Update file path in database - db.update_document_file_path(document_id, str(file_path)) + # Update file path in database (using storage path) + db.update_document_file_path(document_id, storage_path) # Convert PDF to images for annotation UI if file_ext == ".pdf": try: _convert_pdf_to_images( - document_id, content, page_count, - storage_config.admin_images_dir, storage_config.dpi + document_id, content, page_count, storage_config.dpi ) except Exception as e: logger.error(f"Failed to convert PDF to images: {e}") diff --git a/packages/inference/inference/web/app.py b/packages/inference/inference/web/app.py index f14e259..94c714d 100644 --- a/packages/inference/inference/web/app.py +++ b/packages/inference/inference/web/app.py @@ -18,6 +18,7 @@ from fastapi.responses import HTMLResponse from .config import AppConfig, default_config from inference.web.services import InferenceService +from inference.web.services.storage_helpers import get_storage_helper # Public API imports from inference.web.api.v1.public import ( @@ -238,13 +239,17 @@ def create_app(config: AppConfig | None = None) -> FastAPI: allow_headers=["*"], ) - # Mount static files for results - config.storage.result_dir.mkdir(parents=True, exist_ok=True) - app.mount( - "/static/results", - StaticFiles(directory=str(config.storage.result_dir)), - name="results", - ) + # Mount static files for results using StorageHelper + storage = get_storage_helper() + results_dir = storage.get_results_base_path() + if results_dir: + app.mount( + "/static/results", + StaticFiles(directory=str(results_dir)), + name="results", + ) + else: + logger.warning("Could not mount static results directory: local storage not available") # Include public API routes inference_router = create_inference_router(inference_service, config.storage) diff --git a/packages/inference/inference/web/config.py b/packages/inference/inference/web/config.py index 4eab8e3..de65fbc 100644 --- a/packages/inference/inference/web/config.py +++ b/packages/inference/inference/web/config.py @@ -4,16 +4,49 @@ Web Application Configuration Centralized configuration for the web application. """ +import os from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any -from shared.config import DEFAULT_DPI, PATHS +from shared.config import DEFAULT_DPI + +if TYPE_CHECKING: + from shared.storage.base import StorageBackend + + +def get_storage_backend( + config_path: Path | str | None = None, +) -> "StorageBackend": + """Get storage backend for file operations. 
+ + Args: + config_path: Optional path to storage configuration file. + If not provided, uses STORAGE_CONFIG_PATH env var or falls back to env vars. + + Returns: + Configured StorageBackend instance. + """ + from shared.storage import get_storage_backend as _get_storage_backend + + # Check for config file path + if config_path is None: + config_path_str = os.environ.get("STORAGE_CONFIG_PATH") + if config_path_str: + config_path = Path(config_path_str) + + return _get_storage_backend(config_path=config_path) @dataclass(frozen=True) class ModelConfig: - """YOLO model configuration.""" + """YOLO model configuration. + + Note: Model files are stored locally (not in STORAGE_BASE_PATH) because: + - Models need to be accessible by inference service on any platform + - Models may be version-controlled or deployed separately + - Models are part of the application, not user data + """ model_path: Path = Path("runs/train/invoice_fields/weights/best.pt") confidence_threshold: float = 0.5 @@ -33,24 +66,39 @@ class ServerConfig: @dataclass(frozen=True) -class StorageConfig: - """File storage configuration. +class FileConfig: + """File handling configuration. - Note: admin_upload_dir uses PATHS['pdf_dir'] so uploaded PDFs are stored - directly in raw_pdfs directory. This ensures consistency with CLI autolabel - and avoids storing duplicate files. + This config holds file handling settings. For file operations, + use the storage backend with PREFIXES from shared.storage.prefixes. + + Example: + from shared.storage import PREFIXES, get_storage_backend + + storage = get_storage_backend() + path = PREFIXES.document_path(document_id) + storage.upload_bytes(content, path) + + Note: The path fields (upload_dir, result_dir, etc.) are deprecated. + They are kept for backward compatibility with existing code and tests. + New code should use the storage backend with PREFIXES instead. """ - upload_dir: Path = Path("uploads") - result_dir: Path = Path("results") - admin_upload_dir: Path = field(default_factory=lambda: Path(PATHS["pdf_dir"])) - admin_images_dir: Path = Path("data/admin_images") max_file_size_mb: int = 50 allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg") dpi: int = DEFAULT_DPI + presigned_url_expiry_seconds: int = 3600 + + # Deprecated path fields - kept for backward compatibility + # New code should use storage backend with PREFIXES instead + # All paths are now under data/ to match WSL storage layout + upload_dir: Path = field(default_factory=lambda: Path("data/uploads")) + result_dir: Path = field(default_factory=lambda: Path("data/results")) + admin_upload_dir: Path = field(default_factory=lambda: Path("data/raw_pdfs")) + admin_images_dir: Path = field(default_factory=lambda: Path("data/admin_images")) def __post_init__(self) -> None: - """Create directories if they don't exist.""" + """Create directories if they don't exist (for backward compatibility).""" object.__setattr__(self, "upload_dir", Path(self.upload_dir)) object.__setattr__(self, "result_dir", Path(self.result_dir)) object.__setattr__(self, "admin_upload_dir", Path(self.admin_upload_dir)) @@ -61,9 +109,17 @@ class StorageConfig: self.admin_images_dir.mkdir(parents=True, exist_ok=True) +# Backward compatibility alias +StorageConfig = FileConfig + + @dataclass(frozen=True) class AsyncConfig: - """Async processing configuration.""" + """Async processing configuration. + + Note: For file paths, use the storage backend with PREFIXES. 
+ Example: PREFIXES.upload_path(filename, "async") + """ # Queue settings queue_max_size: int = 100 @@ -77,14 +133,17 @@ class AsyncConfig: # Storage result_retention_days: int = 7 - temp_upload_dir: Path = Path("uploads/async") max_file_size_mb: int = 50 + # Deprecated: kept for backward compatibility + # Path under data/ to match WSL storage layout + temp_upload_dir: Path = field(default_factory=lambda: Path("data/uploads/async")) + # Cleanup cleanup_interval_hours: int = 1 def __post_init__(self) -> None: - """Create directories if they don't exist.""" + """Create directories if they don't exist (for backward compatibility).""" object.__setattr__(self, "temp_upload_dir", Path(self.temp_upload_dir)) self.temp_upload_dir.mkdir(parents=True, exist_ok=True) @@ -95,19 +154,41 @@ class AppConfig: model: ModelConfig = field(default_factory=ModelConfig) server: ServerConfig = field(default_factory=ServerConfig) - storage: StorageConfig = field(default_factory=StorageConfig) + file: FileConfig = field(default_factory=FileConfig) async_processing: AsyncConfig = field(default_factory=AsyncConfig) + storage_backend: "StorageBackend | None" = None + + @property + def storage(self) -> FileConfig: + """Backward compatibility alias for file config.""" + return self.file @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig": """Create config from dictionary.""" + file_config = config_dict.get("file", config_dict.get("storage", {})) return cls( model=ModelConfig(**config_dict.get("model", {})), server=ServerConfig(**config_dict.get("server", {})), - storage=StorageConfig(**config_dict.get("storage", {})), + file=FileConfig(**file_config), async_processing=AsyncConfig(**config_dict.get("async_processing", {})), ) +def create_app_config( + storage_config_path: Path | str | None = None, +) -> AppConfig: + """Create application configuration with storage backend. + + Args: + storage_config_path: Optional path to storage configuration file. + + Returns: + Configured AppConfig instance with storage backend initialized. 
+ """ + storage_backend = get_storage_backend(config_path=storage_config_path) + return AppConfig(storage_backend=storage_backend) + + # Default configuration instance default_config = AppConfig() diff --git a/packages/inference/inference/web/core/autolabel_scheduler.py b/packages/inference/inference/web/core/autolabel_scheduler.py index ded452b..e1b137d 100644 --- a/packages/inference/inference/web/core/autolabel_scheduler.py +++ b/packages/inference/inference/web/core/autolabel_scheduler.py @@ -13,6 +13,7 @@ from inference.web.services.db_autolabel import ( get_pending_autolabel_documents, process_document_autolabel, ) +from inference.web.services.storage_helpers import get_storage_helper logger = logging.getLogger(__name__) @@ -36,7 +37,13 @@ class AutoLabelScheduler: """ self._check_interval = check_interval_seconds self._batch_size = batch_size + + # Get output directory from StorageHelper + if output_dir is None: + storage = get_storage_helper() + output_dir = storage.get_autolabel_output_path() self._output_dir = output_dir or Path("data/autolabel_output") + self._running = False self._thread: threading.Thread | None = None self._stop_event = threading.Event() diff --git a/packages/inference/inference/web/core/scheduler.py b/packages/inference/inference/web/core/scheduler.py index 7ece72b..a22c0af 100644 --- a/packages/inference/inference/web/core/scheduler.py +++ b/packages/inference/inference/web/core/scheduler.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import Any from inference.data.admin_db import AdminDB +from inference.web.services.storage_helpers import get_storage_helper logger = logging.getLogger(__name__) @@ -107,6 +108,14 @@ class TrainingScheduler: self._db.update_training_task_status(task_id, "running") self._db.add_training_log(task_id, "INFO", "Training task started") + # Update dataset training status to running + if dataset_id: + self._db.update_dataset_training_status( + dataset_id, + training_status="running", + active_training_task_id=task_id, + ) + try: # Get training configuration model_name = config.get("model_name", "yolo11n.pt") @@ -192,6 +201,15 @@ class TrainingScheduler: ) self._db.add_training_log(task_id, "INFO", "Training completed successfully") + # Update dataset training status to completed and main status to trained + if dataset_id: + self._db.update_dataset_training_status( + dataset_id, + training_status="completed", + active_training_task_id=None, + update_main_status=True, # Set main status to 'trained' + ) + # Auto-create model version for the completed training self._create_model_version_from_training( task_id=task_id, @@ -203,6 +221,13 @@ class TrainingScheduler: except Exception as e: logger.error(f"Training task {task_id} failed: {e}") self._db.add_training_log(task_id, "ERROR", f"Training failed: {e}") + # Update dataset training status to failed + if dataset_id: + self._db.update_dataset_training_status( + dataset_id, + training_status="failed", + active_training_task_id=None, + ) raise def _create_model_version_from_training( @@ -268,9 +293,10 @@ class TrainingScheduler: f"Created model version {version} (ID: {model_version.version_id}) " f"from training task {task_id}" ) + mAP_display = f"{metrics_mAP:.3f}" if metrics_mAP else "N/A" self._db.add_training_log( task_id, "INFO", - f"Model version {version} created (mAP: {metrics_mAP:.3f if metrics_mAP else 'N/A'})", + f"Model version {version} created (mAP: {mAP_display})", ) except Exception as e: @@ -283,8 +309,11 @@ class TrainingScheduler: def 
_export_training_data(self, task_id: str) -> dict[str, Any] | None: """Export training data for a task.""" from pathlib import Path - import shutil - from inference.data.admin_models import FIELD_CLASSES + from shared.fields import FIELD_CLASSES + from inference.web.services.storage_helpers import get_storage_helper + + # Get storage helper for reading images + storage = get_storage_helper() # Get all labeled documents documents = self._db.get_labeled_documents_for_export() @@ -293,8 +322,12 @@ class TrainingScheduler: self._db.add_training_log(task_id, "ERROR", "No labeled documents available") return None - # Create export directory - export_dir = Path("data/training") / task_id + # Create export directory using StorageHelper + training_base = storage.get_training_data_path() + if training_base is None: + self._db.add_training_log(task_id, "ERROR", "Storage not configured for local access") + return None + export_dir = training_base / task_id export_dir.mkdir(parents=True, exist_ok=True) # YOLO format directories @@ -323,14 +356,16 @@ class TrainingScheduler: for page_num in range(1, doc.page_count + 1): page_annotations = [a for a in annotations if a.page_number == page_num] - # Copy image - src_image = Path("data/admin_images") / str(doc.document_id) / f"page_{page_num}.png" - if not src_image.exists(): + # Get image from storage + doc_id = str(doc.document_id) + if not storage.admin_image_exists(doc_id, page_num): continue + # Download image and save to export directory image_name = f"{doc.document_id}_page{page_num}.png" dst_image = export_dir / "images" / split / image_name - shutil.copy(src_image, dst_image) + image_content = storage.get_admin_image(doc_id, page_num) + dst_image.write_bytes(image_content) total_images += 1 # Write YOLO label @@ -380,6 +415,8 @@ names: {list(FIELD_CLASSES.values())} self._db.add_training_log(task_id, level, message) # Create shared training config + # Note: Model outputs go to local runs/train directory (not STORAGE_BASE_PATH) + # because models need to be accessible by inference service on any platform # Note: workers=0 to avoid multiprocessing issues when running in scheduler thread config = SharedTrainingConfig( model_path=model_name, diff --git a/packages/inference/inference/web/schemas/admin/datasets.py b/packages/inference/inference/web/schemas/admin/datasets.py index e1c9420..b7490d7 100644 --- a/packages/inference/inference/web/schemas/admin/datasets.py +++ b/packages/inference/inference/web/schemas/admin/datasets.py @@ -13,6 +13,7 @@ class DatasetCreateRequest(BaseModel): name: str = Field(..., min_length=1, max_length=255, description="Dataset name") description: str | None = Field(None, description="Optional description") document_ids: list[str] = Field(..., min_length=1, description="Document UUIDs to include") + category: str | None = Field(None, description="Filter documents by category (optional)") train_ratio: float = Field(0.8, ge=0.1, le=0.95, description="Training split ratio") val_ratio: float = Field(0.1, ge=0.05, le=0.5, description="Validation split ratio") seed: int = Field(42, description="Random seed for split") @@ -43,6 +44,8 @@ class DatasetDetailResponse(BaseModel): name: str description: str | None status: str + training_status: str | None = None + active_training_task_id: str | None = None train_ratio: float val_ratio: float seed: int diff --git a/packages/inference/inference/web/schemas/admin/documents.py b/packages/inference/inference/web/schemas/admin/documents.py index cf3ea82..743a3c3 100644 --- 
a/packages/inference/inference/web/schemas/admin/documents.py +++ b/packages/inference/inference/web/schemas/admin/documents.py @@ -22,6 +22,7 @@ class DocumentUploadResponse(BaseModel): file_size: int = Field(..., ge=0, description="File size in bytes") page_count: int = Field(..., ge=1, description="Number of pages") status: DocumentStatus = Field(..., description="Document status") + category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)") group_key: str | None = Field(None, description="User-defined group key") auto_label_started: bool = Field( default=False, description="Whether auto-labeling was started" @@ -44,6 +45,7 @@ class DocumentItem(BaseModel): upload_source: str = Field(default="ui", description="Upload source (ui or api)") batch_id: str | None = Field(None, description="Batch ID if uploaded via batch") group_key: str | None = Field(None, description="User-defined group key") + category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)") can_annotate: bool = Field(default=True, description="Whether document can be annotated") created_at: datetime = Field(..., description="Creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") @@ -76,6 +78,7 @@ class DocumentDetailResponse(BaseModel): upload_source: str = Field(default="ui", description="Upload source (ui or api)") batch_id: str | None = Field(None, description="Batch ID if uploaded via batch") group_key: str | None = Field(None, description="User-defined group key") + category: str = Field(default="invoice", description="Document category (e.g., invoice, letter, receipt)") csv_field_values: dict[str, str] | None = Field( None, description="CSV field values if uploaded via batch" ) @@ -104,3 +107,17 @@ class DocumentStatsResponse(BaseModel): auto_labeling: int = Field(default=0, ge=0, description="Auto-labeling documents") labeled: int = Field(default=0, ge=0, description="Labeled documents") exported: int = Field(default=0, ge=0, description="Exported documents") + + +class DocumentUpdateRequest(BaseModel): + """Request for updating document metadata.""" + + category: str | None = Field(None, description="Document category (e.g., invoice, letter, receipt)") + group_key: str | None = Field(None, description="User-defined group key") + + +class DocumentCategoriesResponse(BaseModel): + """Response for available document categories.""" + + categories: list[str] = Field(..., description="List of available categories") + total: int = Field(..., ge=0, description="Total number of categories") diff --git a/packages/inference/inference/web/services/async_processing.py b/packages/inference/inference/web/services/async_processing.py index 11a0173..32428d9 100644 --- a/packages/inference/inference/web/services/async_processing.py +++ b/packages/inference/inference/web/services/async_processing.py @@ -5,6 +5,7 @@ Manages async request lifecycle and background processing. 
""" import logging +import re import shutil import time import uuid @@ -17,6 +18,7 @@ from typing import TYPE_CHECKING from inference.data.async_request_db import AsyncRequestDB from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue from inference.web.core.rate_limiter import RateLimiter +from inference.web.services.storage_helpers import get_storage_helper if TYPE_CHECKING: from inference.web.config import AsyncConfig, StorageConfig @@ -189,9 +191,7 @@ class AsyncProcessingService: filename: str, content: bytes, ) -> Path: - """Save uploaded file to temp storage.""" - import re - + """Save uploaded file to temp storage using StorageHelper.""" # Extract extension from filename ext = Path(filename).suffix.lower() @@ -203,9 +203,11 @@ class AsyncProcessingService: if ext not in self.ALLOWED_EXTENSIONS: ext = ".pdf" - # Create async upload directory - upload_dir = self._async_config.temp_upload_dir - upload_dir.mkdir(parents=True, exist_ok=True) + # Get upload directory from StorageHelper + storage = get_storage_helper() + upload_dir = storage.get_uploads_base_path(subfolder="async") + if upload_dir is None: + raise ValueError("Storage not configured for local access") # Build file path - request_id is a UUID so it's safe file_path = upload_dir / f"{request_id}{ext}" @@ -355,8 +357,9 @@ class AsyncProcessingService: def _cleanup_orphan_files(self) -> int: """Clean up upload files that don't have matching requests.""" - upload_dir = self._async_config.temp_upload_dir - if not upload_dir.exists(): + storage = get_storage_helper() + upload_dir = storage.get_uploads_base_path(subfolder="async") + if upload_dir is None or not upload_dir.exists(): return 0 count = 0 diff --git a/packages/inference/inference/web/services/autolabel.py b/packages/inference/inference/web/services/autolabel.py index a2f3728..242243b 100644 --- a/packages/inference/inference/web/services/autolabel.py +++ b/packages/inference/inference/web/services/autolabel.py @@ -13,7 +13,7 @@ from PIL import Image from shared.config import DEFAULT_DPI from inference.data.admin_db import AdminDB -from inference.data.admin_models import FIELD_CLASS_IDS, FIELD_CLASSES +from shared.fields import FIELD_CLASS_IDS, FIELD_CLASSES from shared.matcher.field_matcher import FieldMatcher from shared.ocr.paddle_ocr import OCREngine, OCRToken diff --git a/packages/inference/inference/web/services/batch_upload.py b/packages/inference/inference/web/services/batch_upload.py index 5ac903b..3b2b178 100644 --- a/packages/inference/inference/web/services/batch_upload.py +++ b/packages/inference/inference/web/services/batch_upload.py @@ -16,7 +16,7 @@ from uuid import UUID from pydantic import BaseModel, Field, field_validator from inference.data.admin_db import AdminDB -from inference.data.admin_models import CSV_TO_CLASS_MAPPING +from shared.fields import CSV_TO_CLASS_MAPPING logger = logging.getLogger(__name__) diff --git a/packages/inference/inference/web/services/dataset_builder.py b/packages/inference/inference/web/services/dataset_builder.py index fc383b4..c19f463 100644 --- a/packages/inference/inference/web/services/dataset_builder.py +++ b/packages/inference/inference/web/services/dataset_builder.py @@ -12,7 +12,7 @@ from pathlib import Path import yaml -from inference.data.admin_models import FIELD_CLASSES +from shared.fields import FIELD_CLASSES logger = logging.getLogger(__name__) diff --git a/packages/inference/inference/web/services/db_autolabel.py b/packages/inference/inference/web/services/db_autolabel.py index 
44da968..5495e81 100644 --- a/packages/inference/inference/web/services/db_autolabel.py +++ b/packages/inference/inference/web/services/db_autolabel.py @@ -13,9 +13,10 @@ from typing import Any from shared.config import DEFAULT_DPI from inference.data.admin_db import AdminDB -from inference.data.admin_models import AdminDocument, CSV_TO_CLASS_MAPPING +from shared.fields import CSV_TO_CLASS_MAPPING +from inference.data.admin_models import AdminDocument from shared.data.db import DocumentDB -from inference.web.config import StorageConfig +from inference.web.services.storage_helpers import get_storage_helper logger = logging.getLogger(__name__) @@ -122,8 +123,12 @@ def process_document_autolabel( document_id = str(document.document_id) file_path = Path(document.file_path) + # Get output directory from StorageHelper + storage = get_storage_helper() if output_dir is None: - output_dir = Path("data/autolabel_output") + output_dir = storage.get_autolabel_output_path() + if output_dir is None: + output_dir = Path("data/autolabel_output") output_dir.mkdir(parents=True, exist_ok=True) # Mark as processing @@ -152,10 +157,12 @@ def process_document_autolabel( is_scanned = len(tokens) < 10 # Threshold for "no text" # Build task data - # Use admin_upload_dir (which is PATHS['pdf_dir']) for pdf_path + # Use raw_pdfs base path for pdf_path # This ensures consistency with CLI autolabel for reprocess_failed.py - storage_config = StorageConfig() - pdf_path_for_report = storage_config.admin_upload_dir / f"{document_id}.pdf" + raw_pdfs_dir = storage.get_raw_pdfs_base_path() + if raw_pdfs_dir is None: + raise ValueError("Storage not configured for local access") + pdf_path_for_report = raw_pdfs_dir / f"{document_id}.pdf" task_data = { "row_dict": row_dict, @@ -246,8 +253,8 @@ def _save_annotations_to_db( Returns: Number of annotations saved """ - from PIL import Image - from inference.data.admin_models import FIELD_CLASS_IDS + from shared.fields import FIELD_CLASS_IDS + from inference.web.services.storage_helpers import get_storage_helper # Mapping from CSV field names to internal field names CSV_TO_INTERNAL_FIELD: dict[str, str] = { @@ -266,6 +273,9 @@ def _save_annotations_to_db( # Scale factor: PDF points (72 DPI) -> pixels (at configured DPI) scale = dpi / 72.0 + # Get storage helper for image dimensions + storage = get_storage_helper() + # Cache for image dimensions per page image_dimensions: dict[int, tuple[int, int]] = {} @@ -274,18 +284,11 @@ def _save_annotations_to_db( if page_no in image_dimensions: return image_dimensions[page_no] - # Try to load from admin_images - admin_images_dir = Path("data/admin_images") / document_id - image_path = admin_images_dir / f"page_{page_no}.png" - - if image_path.exists(): - try: - with Image.open(image_path) as img: - dims = img.size # (width, height) - image_dimensions[page_no] = dims - return dims - except Exception as e: - logger.warning(f"Failed to read image dimensions from {image_path}: {e}") + # Get dimensions from storage helper + dims = storage.get_admin_image_dimensions(document_id, page_no) + if dims: + image_dimensions[page_no] = dims + return dims return None @@ -449,10 +452,17 @@ def save_manual_annotations_to_document_db( from datetime import datetime document_id = str(document.document_id) - storage_config = StorageConfig() - # Build pdf_path using admin_upload_dir (same as auto-label) - pdf_path = storage_config.admin_upload_dir / f"{document_id}.pdf" + # Build pdf_path using raw_pdfs base path (same as auto-label) + storage = get_storage_helper() 
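`dataset_builder.py` now takes its class list from `shared.fields` instead of `inference.data.admin_models`; a hedged sketch of what writing a YOLO `data.yaml` from `FIELD_CLASSES` can look like (the split layout and file name are assumptions, not read from the service):

```python
from pathlib import Path

import yaml

from shared.fields import FIELD_CLASSES


def write_data_yaml(dataset_dir: Path) -> Path:
    """Write a minimal YOLO data.yaml whose names come from the shared field definitions."""
    config = {
        "path": str(dataset_dir),
        "train": "images/train",  # assumed split layout
        "val": "images/val",
        "names": {class_id: name for class_id, name in sorted(FIELD_CLASSES.items())},
    }
    dataset_dir.mkdir(parents=True, exist_ok=True)
    out = dataset_dir / "data.yaml"
    out.write_text(yaml.safe_dump(config, sort_keys=False))
    return out


print(write_data_yaml(Path("/tmp/demo-dataset")).read_text())
```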
+ raw_pdfs_dir = storage.get_raw_pdfs_base_path() + if raw_pdfs_dir is None: + return { + "success": False, + "document_id": document_id, + "error": "Storage not configured for local access", + } + pdf_path = raw_pdfs_dir / f"{document_id}.pdf" # Build report dict compatible with DocumentDB.save_document() field_results = [] diff --git a/packages/inference/inference/web/services/document_service.py b/packages/inference/inference/web/services/document_service.py new file mode 100644 index 0000000..1e4288b --- /dev/null +++ b/packages/inference/inference/web/services/document_service.py @@ -0,0 +1,217 @@ +""" +Document Service for storage-backed file operations. + +Provides a unified interface for document upload, download, and serving +using the storage abstraction layer. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +if TYPE_CHECKING: + from shared.storage.base import StorageBackend + + +@dataclass +class DocumentResult: + """Result of document operation.""" + + id: str + file_path: str + filename: str | None = None + + +class DocumentService: + """Service for document file operations using storage backend. + + Provides upload, download, and URL generation for documents and images. + """ + + # Storage path prefixes + DOCUMENTS_PREFIX = "documents" + IMAGES_PREFIX = "images" + + def __init__( + self, + storage_backend: "StorageBackend", + admin_db: Any | None = None, + ) -> None: + """Initialize document service. + + Args: + storage_backend: Storage backend for file operations. + admin_db: Optional AdminDB instance for database operations. + """ + self._storage = storage_backend + self._admin_db = admin_db + + def upload_document( + self, + content: bytes, + filename: str, + dataset_id: str | None = None, + document_id: str | None = None, + ) -> DocumentResult: + """Upload a document to storage. + + Args: + content: Document content as bytes. + filename: Original filename. + dataset_id: Optional dataset ID for organization. + document_id: Optional document ID (generated if not provided). + + Returns: + DocumentResult with ID and storage path. + """ + if document_id is None: + document_id = str(uuid4()) + + # Extract extension from filename + ext = "" + if "." in filename: + ext = "." + filename.rsplit(".", 1)[-1].lower() + + # Build logical path + remote_path = f"{self.DOCUMENTS_PREFIX}/{document_id}{ext}" + + # Upload via storage backend + self._storage.upload_bytes(content, remote_path, overwrite=True) + + return DocumentResult( + id=document_id, + file_path=remote_path, + filename=filename, + ) + + def download_document(self, remote_path: str) -> bytes: + """Download a document from storage. + + Args: + remote_path: Logical path to the document. + + Returns: + Document content as bytes. + """ + return self._storage.download_bytes(remote_path) + + def get_document_url( + self, + remote_path: str, + expires_in_seconds: int = 3600, + ) -> str: + """Get a URL for accessing a document. + + Args: + remote_path: Logical path to the document. + expires_in_seconds: URL validity duration. + + Returns: + Pre-signed URL for document access. + """ + return self._storage.get_presigned_url(remote_path, expires_in_seconds) + + def document_exists(self, remote_path: str) -> bool: + """Check if a document exists in storage. + + Args: + remote_path: Logical path to the document. + + Returns: + True if document exists. 
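Both db_autolabel hunks above follow the same pattern: ask the `StorageHelper` for a local directory, then either fall back to the legacy relative path or fail outright when the backend is not local. A condensed sketch of that pattern:

```python
from pathlib import Path

from inference.web.services.storage_helpers import get_storage_helper

storage = get_storage_helper()

# Auto-label output: prefer the storage-managed directory, fall back to the legacy relative path.
output_dir = storage.get_autolabel_output_path() or Path("data/autolabel_output")
output_dir.mkdir(parents=True, exist_ok=True)


def resolve_report_pdf_path(document_id: str) -> Path:
    """PDF path recorded in reports; only resolvable with a local backend."""
    raw_pdfs_dir = storage.get_raw_pdfs_base_path()  # None for Azure/S3 backends
    if raw_pdfs_dir is None:
        raise ValueError("Storage not configured for local access")
    return raw_pdfs_dir / f"{document_id}.pdf"
```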
+ """ + return self._storage.exists(remote_path) + + def delete_document_files(self, remote_path: str) -> bool: + """Delete a document from storage. + + Args: + remote_path: Logical path to the document. + + Returns: + True if document was deleted. + """ + return self._storage.delete(remote_path) + + def save_page_image( + self, + document_id: str, + page_num: int, + content: bytes, + ) -> str: + """Save a page image to storage. + + Args: + document_id: Document ID. + page_num: Page number (1-indexed). + content: Image content as bytes. + + Returns: + Logical path where image was stored. + """ + remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png" + self._storage.upload_bytes(content, remote_path, overwrite=True) + return remote_path + + def get_page_image_url( + self, + document_id: str, + page_num: int, + expires_in_seconds: int = 3600, + ) -> str: + """Get a URL for accessing a page image. + + Args: + document_id: Document ID. + page_num: Page number (1-indexed). + expires_in_seconds: URL validity duration. + + Returns: + Pre-signed URL for image access. + """ + remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png" + return self._storage.get_presigned_url(remote_path, expires_in_seconds) + + def get_page_image(self, document_id: str, page_num: int) -> bytes: + """Download a page image from storage. + + Args: + document_id: Document ID. + page_num: Page number (1-indexed). + + Returns: + Image content as bytes. + """ + remote_path = f"{self.IMAGES_PREFIX}/{document_id}/page_{page_num}.png" + return self._storage.download_bytes(remote_path) + + def delete_document_images(self, document_id: str) -> int: + """Delete all images for a document. + + Args: + document_id: Document ID. + + Returns: + Number of images deleted. + """ + prefix = f"{self.IMAGES_PREFIX}/{document_id}/" + image_paths = self._storage.list_files(prefix) + + deleted_count = 0 + for path in image_paths: + if self._storage.delete(path): + deleted_count += 1 + + return deleted_count + + def list_document_images(self, document_id: str) -> list[str]: + """List all images for a document. + + Args: + document_id: Document ID. + + Returns: + List of image paths. 
+ """ + prefix = f"{self.IMAGES_PREFIX}/{document_id}/" + return self._storage.list_files(prefix) diff --git a/packages/inference/inference/web/services/inference.py b/packages/inference/inference/web/services/inference.py index 84d4028..002bf56 100644 --- a/packages/inference/inference/web/services/inference.py +++ b/packages/inference/inference/web/services/inference.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Callable import numpy as np from PIL import Image +from inference.web.services.storage_helpers import get_storage_helper + if TYPE_CHECKING: from .config import ModelConfig, StorageConfig @@ -303,12 +305,19 @@ class InferenceService: """Save visualization image with detections.""" from ultralytics import YOLO + # Get storage helper for results directory + storage = get_storage_helper() + results_dir = storage.get_results_base_path() + if results_dir is None: + logger.warning("Cannot save visualization: local storage not available") + return None + # Load model and run prediction with visualization model = YOLO(str(self.model_config.model_path)) results = model.predict(str(image_path), verbose=False) # Save annotated image - output_path = self.storage_config.result_dir / f"{doc_id}_result.png" + output_path = results_dir / f"{doc_id}_result.png" for r in results: r.save(filename=str(output_path)) @@ -320,19 +329,26 @@ class InferenceService: from ultralytics import YOLO import io + # Get storage helper for results directory + storage = get_storage_helper() + results_dir = storage.get_results_base_path() + if results_dir is None: + logger.warning("Cannot save visualization: local storage not available") + return None + # Render first page for page_no, image_bytes in render_pdf_to_images( pdf_path, dpi=self.model_config.dpi ): image = Image.open(io.BytesIO(image_bytes)) - temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png" + temp_path = results_dir / f"{doc_id}_temp.png" image.save(temp_path) # Run YOLO and save visualization model = YOLO(str(self.model_config.model_path)) results = model.predict(str(temp_path), verbose=False) - output_path = self.storage_config.result_dir / f"{doc_id}_result.png" + output_path = results_dir / f"{doc_id}_result.png" for r in results: r.save(filename=str(output_path)) diff --git a/packages/inference/inference/web/services/storage_helpers.py b/packages/inference/inference/web/services/storage_helpers.py new file mode 100644 index 0000000..0c62a5d --- /dev/null +++ b/packages/inference/inference/web/services/storage_helpers.py @@ -0,0 +1,830 @@ +""" +Storage helpers for web services. + +Provides convenience functions for common storage operations, +wrapping the storage backend with proper path handling using prefixes. +""" + +from pathlib import Path +from typing import TYPE_CHECKING +from uuid import uuid4 + +from shared.storage import PREFIXES, get_storage_backend +from shared.storage.local import LocalStorageBackend + +if TYPE_CHECKING: + from shared.storage.base import StorageBackend + + +def get_default_storage() -> "StorageBackend": + """Get the default storage backend. + + Returns: + Configured StorageBackend instance. + """ + return get_storage_backend() + + +class StorageHelper: + """Helper class for storage operations with prefixes. + + Provides high-level operations for document storage, including + upload, download, and URL generation with proper path prefixes. + """ + + def __init__(self, storage: "StorageBackend | None" = None) -> None: + """Initialize storage helper. + + Args: + storage: Storage backend to use. 
If None, creates default. + """ + self._storage = storage or get_default_storage() + + @property + def storage(self) -> "StorageBackend": + """Get the underlying storage backend.""" + return self._storage + + # Document operations + + def upload_document( + self, + content: bytes, + filename: str, + document_id: str | None = None, + ) -> tuple[str, str]: + """Upload a document to storage. + + Args: + content: Document content as bytes. + filename: Original filename (used for extension). + document_id: Optional document ID. Generated if not provided. + + Returns: + Tuple of (document_id, storage_path). + """ + if document_id is None: + document_id = str(uuid4()) + + ext = Path(filename).suffix.lower() or ".pdf" + path = PREFIXES.document_path(document_id, ext) + self._storage.upload_bytes(content, path, overwrite=True) + + return document_id, path + + def download_document(self, document_id: str, extension: str = ".pdf") -> bytes: + """Download a document from storage. + + Args: + document_id: Document identifier. + extension: File extension. + + Returns: + Document content as bytes. + """ + path = PREFIXES.document_path(document_id, extension) + return self._storage.download_bytes(path) + + def get_document_url( + self, + document_id: str, + extension: str = ".pdf", + expires_in_seconds: int = 3600, + ) -> str: + """Get presigned URL for a document. + + Args: + document_id: Document identifier. + extension: File extension. + expires_in_seconds: URL expiration time. + + Returns: + Presigned URL string. + """ + path = PREFIXES.document_path(document_id, extension) + return self._storage.get_presigned_url(path, expires_in_seconds) + + def document_exists(self, document_id: str, extension: str = ".pdf") -> bool: + """Check if a document exists. + + Args: + document_id: Document identifier. + extension: File extension. + + Returns: + True if document exists. + """ + path = PREFIXES.document_path(document_id, extension) + return self._storage.exists(path) + + def delete_document(self, document_id: str, extension: str = ".pdf") -> bool: + """Delete a document. + + Args: + document_id: Document identifier. + extension: File extension. + + Returns: + True if document was deleted. + """ + path = PREFIXES.document_path(document_id, extension) + return self._storage.delete(path) + + # Image operations + + def save_page_image( + self, + document_id: str, + page_num: int, + content: bytes, + ) -> str: + """Save a page image to storage. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + content: Image content as bytes. + + Returns: + Storage path where image was saved. + """ + path = PREFIXES.image_path(document_id, page_num) + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_page_image(self, document_id: str, page_num: int) -> bytes: + """Download a page image. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + Image content as bytes. + """ + path = PREFIXES.image_path(document_id, page_num) + return self._storage.download_bytes(path) + + def get_page_image_url( + self, + document_id: str, + page_num: int, + expires_in_seconds: int = 3600, + ) -> str: + """Get presigned URL for a page image. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + expires_in_seconds: URL expiration time. + + Returns: + Presigned URL string. 
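A minimal round-trip through the `StorageHelper` document operations, assuming a local backend configured via the environment variables listed in the shared README (the values below are illustrative only):

```python
import os

# Must be set before the default backend is created; illustrative values.
os.environ.setdefault("STORAGE_BACKEND", "local")
os.environ.setdefault("STORAGE_BASE_PATH", "/tmp/invoice-storage")

from inference.web.services.storage_helpers import StorageHelper

helper = StorageHelper()  # wraps get_storage_backend() and the env/config fallback chain

# Upload: the helper derives the extension and builds the prefixed path.
doc_id, path = helper.upload_document(b"%PDF-1.4 ...", "invoice.pdf")
print(doc_id, path)  # e.g. "<uuid>", "documents/<uuid>.pdf"

# Existence check, presigned URL (a file:// URL for local storage), download, cleanup.
assert helper.document_exists(doc_id)
print(helper.get_document_url(doc_id, expires_in_seconds=600))
assert helper.download_document(doc_id) == b"%PDF-1.4 ..."
helper.delete_document(doc_id)
```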
+ """ + path = PREFIXES.image_path(document_id, page_num) + return self._storage.get_presigned_url(path, expires_in_seconds) + + def delete_document_images(self, document_id: str) -> int: + """Delete all images for a document. + + Args: + document_id: Document identifier. + + Returns: + Number of images deleted. + """ + prefix = f"{PREFIXES.IMAGES}/{document_id}/" + images = self._storage.list_files(prefix) + deleted = 0 + for img_path in images: + if self._storage.delete(img_path): + deleted += 1 + return deleted + + def list_document_images(self, document_id: str) -> list[str]: + """List all images for a document. + + Args: + document_id: Document identifier. + + Returns: + List of image paths. + """ + prefix = f"{PREFIXES.IMAGES}/{document_id}/" + return self._storage.list_files(prefix) + + # Upload staging operations + + def save_upload( + self, + content: bytes, + filename: str, + subfolder: str | None = None, + ) -> str: + """Save a file to upload staging area. + + Args: + content: File content as bytes. + filename: Filename to save as. + subfolder: Optional subfolder (e.g., "async"). + + Returns: + Storage path where file was saved. + """ + path = PREFIXES.upload_path(filename, subfolder) + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_upload(self, filename: str, subfolder: str | None = None) -> bytes: + """Get a file from upload staging area. + + Args: + filename: Filename to retrieve. + subfolder: Optional subfolder. + + Returns: + File content as bytes. + """ + path = PREFIXES.upload_path(filename, subfolder) + return self._storage.download_bytes(path) + + def delete_upload(self, filename: str, subfolder: str | None = None) -> bool: + """Delete a file from upload staging area. + + Args: + filename: Filename to delete. + subfolder: Optional subfolder. + + Returns: + True if file was deleted. + """ + path = PREFIXES.upload_path(filename, subfolder) + return self._storage.delete(path) + + # Result operations + + def save_result(self, content: bytes, filename: str) -> str: + """Save a result file. + + Args: + content: File content as bytes. + filename: Filename to save as. + + Returns: + Storage path where file was saved. + """ + path = PREFIXES.result_path(filename) + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_result(self, filename: str) -> bytes: + """Get a result file. + + Args: + filename: Filename to retrieve. + + Returns: + File content as bytes. + """ + path = PREFIXES.result_path(filename) + return self._storage.download_bytes(path) + + def get_result_url(self, filename: str, expires_in_seconds: int = 3600) -> str: + """Get presigned URL for a result file. + + Args: + filename: Filename. + expires_in_seconds: URL expiration time. + + Returns: + Presigned URL string. + """ + path = PREFIXES.result_path(filename) + return self._storage.get_presigned_url(path, expires_in_seconds) + + def result_exists(self, filename: str) -> bool: + """Check if a result file exists. + + Args: + filename: Filename to check. + + Returns: + True if file exists. + """ + path = PREFIXES.result_path(filename) + return self._storage.exists(path) + + def delete_result(self, filename: str) -> bool: + """Delete a result file. + + Args: + filename: Filename to delete. + + Returns: + True if file was deleted. + """ + path = PREFIXES.result_path(filename) + return self._storage.delete(path) + + # Export operations + + def save_export(self, content: bytes, export_id: str, filename: str) -> str: + """Save an export file. 
+ + Args: + content: File content as bytes. + export_id: Export identifier. + filename: Filename to save as. + + Returns: + Storage path where file was saved. + """ + path = PREFIXES.export_path(export_id, filename) + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_export_url( + self, + export_id: str, + filename: str, + expires_in_seconds: int = 3600, + ) -> str: + """Get presigned URL for an export file. + + Args: + export_id: Export identifier. + filename: Filename. + expires_in_seconds: URL expiration time. + + Returns: + Presigned URL string. + """ + path = PREFIXES.export_path(export_id, filename) + return self._storage.get_presigned_url(path, expires_in_seconds) + + # Admin image operations + + def get_admin_image_path(self, document_id: str, page_num: int) -> str: + """Get the storage path for an admin image. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + Storage path like "admin_images/doc123/page_1.png" + """ + return f"{PREFIXES.ADMIN_IMAGES}/{document_id}/page_{page_num}.png" + + def save_admin_image( + self, + document_id: str, + page_num: int, + content: bytes, + ) -> str: + """Save an admin page image to storage. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + content: Image content as bytes. + + Returns: + Storage path where image was saved. + """ + path = self.get_admin_image_path(document_id, page_num) + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_admin_image(self, document_id: str, page_num: int) -> bytes: + """Download an admin page image. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + Image content as bytes. + """ + path = self.get_admin_image_path(document_id, page_num) + return self._storage.download_bytes(path) + + def get_admin_image_url( + self, + document_id: str, + page_num: int, + expires_in_seconds: int = 3600, + ) -> str: + """Get presigned URL for an admin page image. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + expires_in_seconds: URL expiration time. + + Returns: + Presigned URL string. + """ + path = self.get_admin_image_path(document_id, page_num) + return self._storage.get_presigned_url(path, expires_in_seconds) + + def admin_image_exists(self, document_id: str, page_num: int) -> bool: + """Check if an admin page image exists. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + True if image exists. + """ + path = self.get_admin_image_path(document_id, page_num) + return self._storage.exists(path) + + def list_admin_images(self, document_id: str) -> list[str]: + """List all admin images for a document. + + Args: + document_id: Document identifier. + + Returns: + List of image paths. + """ + prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/" + return self._storage.list_files(prefix) + + def delete_admin_images(self, document_id: str) -> int: + """Delete all admin images for a document. + + Args: + document_id: Document identifier. + + Returns: + Number of images deleted. + """ + prefix = f"{PREFIXES.ADMIN_IMAGES}/{document_id}/" + images = self._storage.list_files(prefix) + deleted = 0 + for img_path in images: + if self._storage.delete(img_path): + deleted += 1 + return deleted + + def get_admin_image_local_path( + self, document_id: str, page_num: int + ) -> Path | None: + """Get the local filesystem path for an admin image. 
+ + This method is useful for serving files via FileResponse. + Only works with LocalStorageBackend; returns None for cloud storage. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + Path object if using local storage and file exists, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + # Cloud storage - cannot get local path + return None + + remote_path = self.get_admin_image_path(document_id, page_num) + try: + full_path = self._storage._get_full_path(remote_path) + if full_path.exists(): + return full_path + return None + except Exception: + return None + + def get_admin_image_dimensions( + self, document_id: str, page_num: int + ) -> tuple[int, int] | None: + """Get the dimensions (width, height) of an admin image. + + This method is useful for normalizing bounding box coordinates. + + Args: + document_id: Document identifier. + page_num: Page number (1-indexed). + + Returns: + Tuple of (width, height) if image exists, None otherwise. + """ + from PIL import Image + + # Try local path first for efficiency + local_path = self.get_admin_image_local_path(document_id, page_num) + if local_path is not None: + with Image.open(local_path) as img: + return img.size + + # Fall back to downloading for cloud storage + if not self.admin_image_exists(document_id, page_num): + return None + + try: + import io + image_bytes = self.get_admin_image(document_id, page_num) + with Image.open(io.BytesIO(image_bytes)) as img: + return img.size + except Exception: + return None + + # Raw PDF operations (legacy compatibility) + + def save_raw_pdf(self, content: bytes, filename: str) -> str: + """Save a raw PDF for auto-labeling pipeline. + + Args: + content: PDF content as bytes. + filename: Filename to save as. + + Returns: + Storage path where file was saved. + """ + path = f"{PREFIXES.RAW_PDFS}/{filename}" + self._storage.upload_bytes(content, path, overwrite=True) + return path + + def get_raw_pdf(self, filename: str) -> bytes: + """Get a raw PDF from storage. + + Args: + filename: Filename to retrieve. + + Returns: + PDF content as bytes. + """ + path = f"{PREFIXES.RAW_PDFS}/{filename}" + return self._storage.download_bytes(path) + + def raw_pdf_exists(self, filename: str) -> bool: + """Check if a raw PDF exists. + + Args: + filename: Filename to check. + + Returns: + True if file exists. + """ + path = f"{PREFIXES.RAW_PDFS}/{filename}" + return self._storage.exists(path) + + def get_raw_pdf_local_path(self, filename: str) -> Path | None: + """Get the local filesystem path for a raw PDF. + + Only works with LocalStorageBackend; returns None for cloud storage. + + Args: + filename: Filename to retrieve. + + Returns: + Path object if using local storage and file exists, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + path = f"{PREFIXES.RAW_PDFS}/{filename}" + try: + full_path = self._storage._get_full_path(path) + if full_path.exists(): + return full_path + return None + except Exception: + return None + + def get_raw_pdf_path(self, filename: str) -> str: + """Get the storage path for a raw PDF (not the local filesystem path). + + Args: + filename: Filename. + + Returns: + Storage path like "raw_pdfs/filename.pdf" + """ + return f"{PREFIXES.RAW_PDFS}/{filename}" + + # Result local path operations + + def get_result_local_path(self, filename: str) -> Path | None: + """Get the local filesystem path for a result file. 
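`get_admin_image_dimensions()` is what the db_autolabel hunk earlier in this diff now calls; a small sketch of the per-page cache kept around it while one document is processed (names mirror that hunk):

```python
from inference.web.services.storage_helpers import get_storage_helper

storage = get_storage_helper()
image_dimensions: dict[int, tuple[int, int]] = {}  # cache reused while processing one document


def get_image_dimensions(document_id: str, page_no: int) -> tuple[int, int] | None:
    """Return (width, height) for an admin page image, caching per page number."""
    if page_no in image_dimensions:
        return image_dimensions[page_no]
    # Local backend: size is read straight from disk; cloud backend: bytes are downloaded first.
    dims = storage.get_admin_image_dimensions(document_id, page_no)
    if dims:
        image_dimensions[page_no] = dims
    return dims
```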
+ + Only works with LocalStorageBackend; returns None for cloud storage. + + Args: + filename: Filename to retrieve. + + Returns: + Path object if using local storage and file exists, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + path = PREFIXES.result_path(filename) + try: + full_path = self._storage._get_full_path(path) + if full_path.exists(): + return full_path + return None + except Exception: + return None + + def get_results_base_path(self) -> Path | None: + """Get the base directory path for results (local storage only). + + Used for mounting static file directories. + + Returns: + Path to results directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path(PREFIXES.RESULTS) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + # Upload local path operations + + def get_upload_local_path( + self, filename: str, subfolder: str | None = None + ) -> Path | None: + """Get the local filesystem path for an upload file. + + Only works with LocalStorageBackend; returns None for cloud storage. + + Args: + filename: Filename to retrieve. + subfolder: Optional subfolder. + + Returns: + Path object if using local storage and file exists, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + path = PREFIXES.upload_path(filename, subfolder) + try: + full_path = self._storage._get_full_path(path) + if full_path.exists(): + return full_path + return None + except Exception: + return None + + def get_uploads_base_path(self, subfolder: str | None = None) -> Path | None: + """Get the base directory path for uploads (local storage only). + + Args: + subfolder: Optional subfolder (e.g., "async"). + + Returns: + Path to uploads directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + if subfolder: + base_path = self._storage._get_full_path(f"{PREFIXES.UPLOADS}/{subfolder}") + else: + base_path = self._storage._get_full_path(PREFIXES.UPLOADS) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def upload_exists(self, filename: str, subfolder: str | None = None) -> bool: + """Check if an upload file exists. + + Args: + filename: Filename to check. + subfolder: Optional subfolder. + + Returns: + True if file exists. + """ + path = PREFIXES.upload_path(filename, subfolder) + return self._storage.exists(path) + + # Dataset operations + + def get_datasets_base_path(self) -> Path | None: + """Get the base directory path for datasets (local storage only). + + Returns: + Path to datasets directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path(PREFIXES.DATASETS) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def get_admin_images_base_path(self) -> Path | None: + """Get the base directory path for admin images (local storage only). + + Returns: + Path to admin_images directory if using local storage, None otherwise. 
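`get_results_base_path()` exists mainly so a web app can mount the results directory as static files when running on local storage; a hedged sketch of that wiring (the `/results` mount point and the FastAPI app object are assumptions, not taken from this diff):

```python
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from inference.web.services.storage_helpers import get_storage_helper

app = FastAPI()
storage = get_storage_helper()

# Only available with LocalStorageBackend; cloud deployments get None here and
# should serve result files via presigned URLs (helper.get_result_url(...)) instead.
results_dir = storage.get_results_base_path()
if results_dir is not None:
    app.mount("/results", StaticFiles(directory=str(results_dir)), name="results")
```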
+ """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path(PREFIXES.ADMIN_IMAGES) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def get_raw_pdfs_base_path(self) -> Path | None: + """Get the base directory path for raw PDFs (local storage only). + + Returns: + Path to raw_pdfs directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path(PREFIXES.RAW_PDFS) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def get_autolabel_output_path(self) -> Path | None: + """Get the directory path for autolabel output (local storage only). + + Returns: + Path to autolabel_output directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + # Use a subfolder under results for autolabel output + base_path = self._storage._get_full_path("autolabel_output") + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def get_training_data_path(self) -> Path | None: + """Get the directory path for training data exports (local storage only). + + Returns: + Path to training directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path("training") + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + def get_exports_base_path(self) -> Path | None: + """Get the base directory path for exports (local storage only). + + Returns: + Path to exports directory if using local storage, None otherwise. + """ + if not isinstance(self._storage, LocalStorageBackend): + return None + + try: + base_path = self._storage._get_full_path(PREFIXES.EXPORTS) + base_path.mkdir(parents=True, exist_ok=True) + return base_path + except Exception: + return None + + +# Default instance for convenience +_default_helper: StorageHelper | None = None + + +def get_storage_helper() -> StorageHelper: + """Get the default storage helper instance. + + Creates the helper on first call with default storage backend. + + Returns: + Default StorageHelper instance. + """ + global _default_helper + if _default_helper is None: + _default_helper = StorageHelper() + return _default_helper diff --git a/packages/shared/README.md b/packages/shared/README.md new file mode 100644 index 0000000..39cc783 --- /dev/null +++ b/packages/shared/README.md @@ -0,0 +1,205 @@ +# Shared Package + +Shared utilities and abstractions for the Invoice Master system. 
+ +## Storage Abstraction Layer + +A unified storage abstraction supporting multiple backends: +- **Local filesystem** - Development and testing +- **Azure Blob Storage** - Azure cloud deployments +- **AWS S3** - AWS cloud deployments + +### Installation + +```bash +# Basic installation (local storage only) +pip install -e packages/shared + +# With Azure support +pip install -e "packages/shared[azure]" + +# With S3 support +pip install -e "packages/shared[s3]" + +# All cloud providers +pip install -e "packages/shared[all]" +``` + +### Quick Start + +```python +from shared.storage import get_storage_backend + +# Option 1: From configuration file +storage = get_storage_backend("storage.yaml") + +# Option 2: From environment variables +from shared.storage import create_storage_backend_from_env +storage = create_storage_backend_from_env() + +# Upload a file +storage.upload(Path("local/file.pdf"), "documents/file.pdf") + +# Download a file +storage.download("documents/file.pdf", Path("local/downloaded.pdf")) + +# Get pre-signed URL for frontend access +url = storage.get_presigned_url("documents/file.pdf", expires_in_seconds=3600) +``` + +### Configuration File Format + +Create a `storage.yaml` file with environment variable substitution support: + +```yaml +# Backend selection: local, azure_blob, or s3 +backend: ${STORAGE_BACKEND:-local} + +# Default pre-signed URL expiry (seconds) +presigned_url_expiry: 3600 + +# Local storage configuration +local: + base_path: ${STORAGE_BASE_PATH:-./data/storage} + +# Azure Blob Storage configuration +azure: + connection_string: ${AZURE_STORAGE_CONNECTION_STRING} + container_name: ${AZURE_STORAGE_CONTAINER:-documents} + create_container: false + +# AWS S3 configuration +s3: + bucket_name: ${AWS_S3_BUCKET} + region_name: ${AWS_REGION:-us-east-1} + access_key_id: ${AWS_ACCESS_KEY_ID} + secret_access_key: ${AWS_SECRET_ACCESS_KEY} + endpoint_url: ${AWS_ENDPOINT_URL} # Optional, for S3-compatible services + create_bucket: false +``` + +### Environment Variables + +| Variable | Backend | Description | +|----------|---------|-------------| +| `STORAGE_BACKEND` | All | Backend type: `local`, `azure_blob`, `s3` | +| `STORAGE_BASE_PATH` | Local | Base directory path | +| `AZURE_STORAGE_CONNECTION_STRING` | Azure | Connection string | +| `AZURE_STORAGE_CONTAINER` | Azure | Container name | +| `AWS_S3_BUCKET` | S3 | Bucket name | +| `AWS_REGION` | S3 | AWS region (default: us-east-1) | +| `AWS_ACCESS_KEY_ID` | S3 | Access key (optional, uses credential chain) | +| `AWS_SECRET_ACCESS_KEY` | S3 | Secret key (optional) | +| `AWS_ENDPOINT_URL` | S3 | Custom endpoint for S3-compatible services | + +### API Reference + +#### StorageBackend Interface + +```python +class StorageBackend(ABC): + def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str: + """Upload a file to storage.""" + + def download(self, remote_path: str, local_path: Path) -> Path: + """Download a file from storage.""" + + def exists(self, remote_path: str) -> bool: + """Check if a file exists.""" + + def list_files(self, prefix: str) -> list[str]: + """List files with given prefix.""" + + def delete(self, remote_path: str) -> bool: + """Delete a file.""" + + def get_url(self, remote_path: str) -> str: + """Get URL for a file.""" + + def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str: + """Generate a pre-signed URL for temporary access (1-604800 seconds).""" + + def upload_bytes(self, data: bytes, remote_path: str, overwrite: bool = False) -> str: + 
"""Upload bytes directly.""" + + def download_bytes(self, remote_path: str) -> bytes: + """Download file as bytes.""" +``` + +#### Factory Functions + +```python +# Create from configuration file +storage = create_storage_backend_from_file("storage.yaml") + +# Create from environment variables +storage = create_storage_backend_from_env() + +# Create from StorageConfig object +config = StorageConfig(backend_type="local", base_path=Path("./data")) +storage = create_storage_backend(config) + +# Convenience function with fallback chain: config file -> env vars -> local default +storage = get_storage_backend("storage.yaml") # or None for env-only +``` + +### Pre-signed URLs + +Pre-signed URLs provide temporary access to files without exposing credentials: + +```python +# Generate URL valid for 1 hour (default) +url = storage.get_presigned_url("documents/invoice.pdf") + +# Generate URL valid for 24 hours +url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=86400) + +# Maximum expiry: 7 days (604800 seconds) +url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=604800) +``` + +**Note:** Local storage returns `file://` URLs that don't actually expire. + +### Error Handling + +```python +from shared.storage import ( + StorageError, + FileNotFoundStorageError, + PresignedUrlNotSupportedError, +) + +try: + storage.download("nonexistent.pdf", Path("local.pdf")) +except FileNotFoundStorageError as e: + print(f"File not found: {e}") +except StorageError as e: + print(f"Storage error: {e}") +``` + +### Testing with MinIO (S3-compatible) + +```bash +# Start MinIO locally +docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001" + +# Configure environment +export STORAGE_BACKEND=s3 +export AWS_S3_BUCKET=test-bucket +export AWS_ENDPOINT_URL=http://localhost:9000 +export AWS_ACCESS_KEY_ID=minioadmin +export AWS_SECRET_ACCESS_KEY=minioadmin +``` + +### Module Structure + +``` +shared/storage/ +├── __init__.py # Public exports +├── base.py # Abstract interface and exceptions +├── local.py # Local filesystem backend +├── azure.py # Azure Blob Storage backend +├── s3.py # AWS S3 backend +├── config_loader.py # YAML configuration loader +└── factory.py # Backend factory functions +``` diff --git a/packages/shared/setup.py b/packages/shared/setup.py index 018e6c6..2250877 100644 --- a/packages/shared/setup.py +++ b/packages/shared/setup.py @@ -16,4 +16,18 @@ setup( "pyyaml>=6.0", "thefuzz>=0.20.0", ], + extras_require={ + "azure": [ + "azure-storage-blob>=12.19.0", + "azure-identity>=1.15.0", + ], + "s3": [ + "boto3>=1.34.0", + ], + "all": [ + "azure-storage-blob>=12.19.0", + "azure-identity>=1.15.0", + "boto3>=1.34.0", + ], + }, ) diff --git a/packages/shared/shared/config.py b/packages/shared/shared/config.py index 425f0e2..98a44b1 100644 --- a/packages/shared/shared/config.py +++ b/packages/shared/shared/config.py @@ -58,23 +58,16 @@ def get_db_connection_string(): return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}" -# Paths Configuration - auto-detect WSL vs Windows -if _is_wsl(): - # WSL: use native Linux filesystem for better I/O performance - PATHS = { - 'csv_dir': os.path.expanduser('~/invoice-data/structured_data'), - 'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'), - 'output_dir': os.path.expanduser('~/invoice-data/dataset'), - 'reports_dir': 'reports', # Keep reports in project directory - } -else: - # Windows or native Linux: use relative paths 
- PATHS = { - 'csv_dir': 'data/structured_data', - 'pdf_dir': 'data/raw_pdfs', - 'output_dir': 'data/dataset', - 'reports_dir': 'reports', - } +# Paths Configuration - uses STORAGE_BASE_PATH for consistency +# All paths are relative to STORAGE_BASE_PATH (defaults to ~/invoice-data/data) +_storage_base = os.path.expanduser(os.getenv('STORAGE_BASE_PATH', '~/invoice-data/data')) + +PATHS = { + 'csv_dir': f'{_storage_base}/structured_data', + 'pdf_dir': f'{_storage_base}/raw_pdfs', + 'output_dir': f'{_storage_base}/datasets', + 'reports_dir': 'reports', # Keep reports in project directory +} # Auto-labeling Configuration AUTOLABEL = { diff --git a/packages/shared/shared/fields/__init__.py b/packages/shared/shared/fields/__init__.py new file mode 100644 index 0000000..07c66d9 --- /dev/null +++ b/packages/shared/shared/fields/__init__.py @@ -0,0 +1,46 @@ +""" +Shared Field Definitions - Single Source of Truth. + +This module provides centralized field class definitions used throughout +the invoice extraction system. All field mappings are derived from +FIELD_DEFINITIONS to ensure consistency. + +Usage: + from shared.fields import FIELD_CLASSES, CLASS_NAMES, FIELD_CLASS_IDS + +Available exports: + - FieldDefinition: Dataclass for field definition + - FIELD_DEFINITIONS: Tuple of all field definitions (immutable) + - NUM_CLASSES: Total number of field classes (10) + - CLASS_NAMES: List of class names in order [0..9] + - FIELD_CLASSES: dict[int, str] - class_id to class_name + - FIELD_CLASS_IDS: dict[str, int] - class_name to class_id + - CLASS_TO_FIELD: dict[str, str] - class_name to field_name + - CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived) + - TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields) + - ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling +""" + +from .field_config import FieldDefinition, FIELD_DEFINITIONS, NUM_CLASSES +from .mappings import ( + CLASS_NAMES, + FIELD_CLASSES, + FIELD_CLASS_IDS, + CLASS_TO_FIELD, + CSV_TO_CLASS_MAPPING, + TRAINING_FIELD_CLASSES, + ACCOUNT_FIELD_MAPPING, +) + +__all__ = [ + "FieldDefinition", + "FIELD_DEFINITIONS", + "NUM_CLASSES", + "CLASS_NAMES", + "FIELD_CLASSES", + "FIELD_CLASS_IDS", + "CLASS_TO_FIELD", + "CSV_TO_CLASS_MAPPING", + "TRAINING_FIELD_CLASSES", + "ACCOUNT_FIELD_MAPPING", +] diff --git a/packages/shared/shared/fields/field_config.py b/packages/shared/shared/fields/field_config.py new file mode 100644 index 0000000..de550e3 --- /dev/null +++ b/packages/shared/shared/fields/field_config.py @@ -0,0 +1,58 @@ +""" +Field Configuration - Single Source of Truth + +This module defines all invoice field classes used throughout the system. +The class IDs are verified against the trained YOLO model (best.pt). + +IMPORTANT: Do not modify class_id values without retraining the model! +""" + +from dataclasses import dataclass +from typing import Final + + +@dataclass(frozen=True) +class FieldDefinition: + """Immutable field definition for invoice extraction. 
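A quick sketch of how the new `PATHS` resolution behaves at import time (values are illustrative; the real module reads `STORAGE_BASE_PATH` once when `shared.config` is imported):

```python
import os

# Same resolution logic as shared/config.py: one base directory, fixed subfolders.
storage_base = os.path.expanduser(os.getenv("STORAGE_BASE_PATH", "~/invoice-data/data"))

paths = {
    "csv_dir": f"{storage_base}/structured_data",
    "pdf_dir": f"{storage_base}/raw_pdfs",
    "output_dir": f"{storage_base}/datasets",
    "reports_dir": "reports",  # reports stay relative to the project directory
}

# With STORAGE_BASE_PATH=/srv/invoice-data, pdf_dir becomes /srv/invoice-data/raw_pdfs;
# unset, it falls back to ~/invoice-data/data/raw_pdfs on every platform (no WSL special case).
print(paths["pdf_dir"])
```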
+ + Attributes: + class_id: YOLO class ID (0-9), must match trained model + class_name: YOLO class name (lowercase_underscore) + field_name: Business field name used in API responses + csv_name: CSV column name for data import/export + is_derived: True if field is derived from other fields (not in CSV) + """ + + class_id: int + class_name: str + field_name: str + csv_name: str + is_derived: bool = False + + +# Verified from model weights (runs/train/invoice_fields/weights/best.pt) +# model.names = {0: 'invoice_number', 1: 'invoice_date', ..., 8: 'customer_number', 9: 'payment_line'} +# +# DO NOT CHANGE THE ORDER - it must match the trained model! +FIELD_DEFINITIONS: Final[tuple[FieldDefinition, ...]] = ( + FieldDefinition(0, "invoice_number", "InvoiceNumber", "InvoiceNumber"), + FieldDefinition(1, "invoice_date", "InvoiceDate", "InvoiceDate"), + FieldDefinition(2, "invoice_due_date", "InvoiceDueDate", "InvoiceDueDate"), + FieldDefinition(3, "ocr_number", "OCR", "OCR"), + FieldDefinition(4, "bankgiro", "Bankgiro", "Bankgiro"), + FieldDefinition(5, "plusgiro", "Plusgiro", "Plusgiro"), + FieldDefinition(6, "amount", "Amount", "Amount"), + FieldDefinition( + 7, + "supplier_org_number", + "supplier_organisation_number", + "supplier_organisation_number", + ), + FieldDefinition(8, "customer_number", "customer_number", "customer_number"), + FieldDefinition( + 9, "payment_line", "payment_line", "payment_line", is_derived=True + ), +) + +# Total number of field classes +NUM_CLASSES: Final[int] = len(FIELD_DEFINITIONS) diff --git a/packages/shared/shared/fields/mappings.py b/packages/shared/shared/fields/mappings.py new file mode 100644 index 0000000..18013dd --- /dev/null +++ b/packages/shared/shared/fields/mappings.py @@ -0,0 +1,57 @@ +""" +Field Mappings - Auto-generated from FIELD_DEFINITIONS. + +All mappings in this file are derived from field_config.FIELD_DEFINITIONS. +This ensures consistency across the entire codebase. + +DO NOT hardcode field mappings elsewhere - always import from this module. 
+""" + +from typing import Final + +from .field_config import FIELD_DEFINITIONS + + +# List of class names in order (for YOLO classes.txt generation) +# Index matches class_id: CLASS_NAMES[0] = "invoice_number" +CLASS_NAMES: Final[list[str]] = [fd.class_name for fd in FIELD_DEFINITIONS] + +# class_id -> class_name mapping +# Example: {0: "invoice_number", 1: "invoice_date", ...} +FIELD_CLASSES: Final[dict[int, str]] = { + fd.class_id: fd.class_name for fd in FIELD_DEFINITIONS +} + +# class_name -> class_id mapping (reverse of FIELD_CLASSES) +# Example: {"invoice_number": 0, "invoice_date": 1, ...} +FIELD_CLASS_IDS: Final[dict[str, int]] = { + fd.class_name: fd.class_id for fd in FIELD_DEFINITIONS +} + +# class_name -> field_name mapping (for API responses) +# Example: {"invoice_number": "InvoiceNumber", "ocr_number": "OCR", ...} +CLASS_TO_FIELD: Final[dict[str, str]] = { + fd.class_name: fd.field_name for fd in FIELD_DEFINITIONS +} + +# field_name -> class_id mapping (for CSV import) +# Excludes derived fields like payment_line +# Example: {"InvoiceNumber": 0, "InvoiceDate": 1, ...} +CSV_TO_CLASS_MAPPING: Final[dict[str, int]] = { + fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS if not fd.is_derived +} + +# field_name -> class_id mapping (for training, includes all fields) +# Example: {"InvoiceNumber": 0, ..., "payment_line": 9} +TRAINING_FIELD_CLASSES: Final[dict[str, int]] = { + fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS +} + +# Account field mapping for supplier_accounts special handling +# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro +ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = { + "supplier_accounts": { + "BG": "Bankgiro", + "PG": "Plusgiro", + } +} diff --git a/packages/shared/shared/storage/__init__.py b/packages/shared/shared/storage/__init__.py new file mode 100644 index 0000000..b44f7cd --- /dev/null +++ b/packages/shared/shared/storage/__init__.py @@ -0,0 +1,59 @@ +""" +Storage abstraction layer for training data. + +Provides a unified interface for local filesystem, Azure Blob Storage, and AWS S3. 
+""" + +from shared.storage.base import ( + FileNotFoundStorageError, + PresignedUrlNotSupportedError, + StorageBackend, + StorageConfig, + StorageError, +) +from shared.storage.factory import ( + create_storage_backend, + create_storage_backend_from_env, + create_storage_backend_from_file, + get_default_storage_config, + get_storage_backend, +) +from shared.storage.local import LocalStorageBackend +from shared.storage.prefixes import PREFIXES, StoragePrefixes + +__all__ = [ + # Base classes and exceptions + "StorageBackend", + "StorageConfig", + "StorageError", + "FileNotFoundStorageError", + "PresignedUrlNotSupportedError", + # Backends + "LocalStorageBackend", + # Factory functions + "create_storage_backend", + "create_storage_backend_from_env", + "create_storage_backend_from_file", + "get_default_storage_config", + "get_storage_backend", + # Path prefixes + "PREFIXES", + "StoragePrefixes", +] + + +# Lazy imports to avoid dependencies when not using specific backends +def __getattr__(name: str): + if name == "AzureBlobStorageBackend": + from shared.storage.azure import AzureBlobStorageBackend + + return AzureBlobStorageBackend + if name == "S3StorageBackend": + from shared.storage.s3 import S3StorageBackend + + return S3StorageBackend + if name == "load_storage_config": + from shared.storage.config_loader import load_storage_config + + return load_storage_config + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/packages/shared/shared/storage/azure.py b/packages/shared/shared/storage/azure.py new file mode 100644 index 0000000..0ae374a --- /dev/null +++ b/packages/shared/shared/storage/azure.py @@ -0,0 +1,335 @@ +""" +Azure Blob Storage backend. + +Provides storage operations using Azure Blob Storage. +""" + +from pathlib import Path + +from azure.storage.blob import ( + BlobSasPermissions, + BlobServiceClient, + ContainerClient, + generate_blob_sas, +) + +from shared.storage.base import ( + FileNotFoundStorageError, + StorageBackend, + StorageError, +) + + +class AzureBlobStorageBackend(StorageBackend): + """Storage backend using Azure Blob Storage. + + Files are stored as blobs in an Azure Blob Storage container. + """ + + def __init__( + self, + connection_string: str, + container_name: str, + create_container: bool = False, + ) -> None: + """Initialize Azure Blob Storage backend. + + Args: + connection_string: Azure Storage connection string. + container_name: Name of the blob container. + create_container: If True, create the container if it doesn't exist. + """ + self._connection_string = connection_string + self._container_name = container_name + + self._blob_service = BlobServiceClient.from_connection_string(connection_string) + self._container = self._blob_service.get_container_client(container_name) + + # Extract account key from connection string for SAS token generation + self._account_key = self._extract_account_key(connection_string) + + if create_container and not self._container.exists(): + self._container.create_container() + + @staticmethod + def _extract_account_key(connection_string: str) -> str | None: + """Extract account key from connection string. + + Args: + connection_string: Azure Storage connection string. + + Returns: + Account key if found, None otherwise. 
+ """ + for part in connection_string.split(";"): + if part.startswith("AccountKey="): + return part[len("AccountKey=") :] + return None + + @property + def container_name(self) -> str: + """Get the container name for this storage backend.""" + return self._container_name + + @property + def container_client(self) -> ContainerClient: + """Get the Azure container client.""" + return self._container + + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + """Upload a file to Azure Blob Storage. + + Args: + local_path: Path to the local file to upload. + remote_path: Destination blob path. + overwrite: If True, overwrite existing blob. + + Returns: + The remote path where the file was stored. + + Raises: + FileNotFoundStorageError: If local_path doesn't exist. + StorageError: If blob exists and overwrite is False. + """ + if not local_path.exists(): + raise FileNotFoundStorageError(str(local_path)) + + blob_client = self._container.get_blob_client(remote_path) + + if blob_client.exists() and not overwrite: + raise StorageError(f"File already exists: {remote_path}") + + with open(local_path, "rb") as f: + blob_client.upload_blob(f, overwrite=overwrite) + + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + """Download a blob from Azure Blob Storage. + + Args: + remote_path: Blob path in storage. + local_path: Local destination path. + + Returns: + The local path where the file was downloaded. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + blob_client = self._container.get_blob_client(remote_path) + + if not blob_client.exists(): + raise FileNotFoundStorageError(remote_path) + + local_path.parent.mkdir(parents=True, exist_ok=True) + + stream = blob_client.download_blob() + local_path.write_bytes(stream.readall()) + + return local_path + + def exists(self, remote_path: str) -> bool: + """Check if a blob exists in storage. + + Args: + remote_path: Blob path to check. + + Returns: + True if the blob exists, False otherwise. + """ + blob_client = self._container.get_blob_client(remote_path) + return blob_client.exists() + + def list_files(self, prefix: str) -> list[str]: + """List blobs in storage with given prefix. + + Args: + prefix: Blob path prefix to filter. + + Returns: + List of blob paths matching the prefix. + """ + if prefix: + blobs = self._container.list_blobs(name_starts_with=prefix) + else: + blobs = self._container.list_blobs() + + return [blob.name for blob in blobs] + + def delete(self, remote_path: str) -> bool: + """Delete a blob from storage. + + Args: + remote_path: Blob path to delete. + + Returns: + True if blob was deleted, False if it didn't exist. + """ + blob_client = self._container.get_blob_client(remote_path) + + if not blob_client.exists(): + return False + + blob_client.delete_blob() + return True + + def get_url(self, remote_path: str) -> str: + """Get the URL for a blob. + + Args: + remote_path: Blob path in storage. + + Returns: + URL to access the blob. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + blob_client = self._container.get_blob_client(remote_path) + + if not blob_client.exists(): + raise FileNotFoundStorageError(remote_path) + + return blob_client.url + + def upload_bytes( + self, data: bytes, remote_path: str, overwrite: bool = False + ) -> str: + """Upload bytes directly to Azure Blob Storage. + + Args: + data: Bytes to upload. + remote_path: Destination blob path. + overwrite: If True, overwrite existing blob. 
+ + Returns: + The remote path where the data was stored. + """ + blob_client = self._container.get_blob_client(remote_path) + + if blob_client.exists() and not overwrite: + raise StorageError(f"File already exists: {remote_path}") + + blob_client.upload_blob(data, overwrite=overwrite) + + return remote_path + + def download_bytes(self, remote_path: str) -> bytes: + """Download a blob as bytes. + + Args: + remote_path: Blob path in storage. + + Returns: + The blob contents as bytes. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + blob_client = self._container.get_blob_client(remote_path) + + if not blob_client.exists(): + raise FileNotFoundStorageError(remote_path) + + stream = blob_client.download_blob() + return stream.readall() + + def upload_directory( + self, local_dir: Path, remote_prefix: str, overwrite: bool = False + ) -> list[str]: + """Upload all files in a directory to Azure Blob Storage. + + Args: + local_dir: Local directory to upload. + remote_prefix: Prefix for remote blob paths. + overwrite: If True, overwrite existing blobs. + + Returns: + List of remote paths where files were stored. + """ + uploaded: list[str] = [] + + for file_path in local_dir.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(local_dir) + remote_path = f"{remote_prefix}{relative_path}".replace("\\", "/") + self.upload(file_path, remote_path, overwrite=overwrite) + uploaded.append(remote_path) + + return uploaded + + def download_directory( + self, remote_prefix: str, local_dir: Path + ) -> list[Path]: + """Download all blobs with a prefix to a local directory. + + Args: + remote_prefix: Blob path prefix to download. + local_dir: Local directory to download to. + + Returns: + List of local paths where files were downloaded. + """ + downloaded: list[Path] = [] + + blobs = self.list_files(remote_prefix) + + for blob_path in blobs: + # Remove prefix to get relative path + if remote_prefix: + relative_path = blob_path[len(remote_prefix):] + if relative_path.startswith("/"): + relative_path = relative_path[1:] + else: + relative_path = blob_path + + local_path = local_dir / relative_path + self.download(blob_path, local_path) + downloaded.append(local_path) + + return downloaded + + def get_presigned_url( + self, + remote_path: str, + expires_in_seconds: int = 3600, + ) -> str: + """Generate a SAS URL for temporary blob access. + + Args: + remote_path: Blob path in storage. + expires_in_seconds: SAS token validity duration (1 to 604800 seconds / 7 days). + + Returns: + Blob URL with SAS token. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + ValueError: If expires_in_seconds is out of valid range. 
+ """ + if expires_in_seconds < 1 or expires_in_seconds > 604800: + raise ValueError( + "expires_in_seconds must be between 1 and 604800 (7 days)" + ) + + from datetime import datetime, timedelta, timezone + + blob_client = self._container.get_blob_client(remote_path) + + if not blob_client.exists(): + raise FileNotFoundStorageError(remote_path) + + # Generate SAS token + sas_token = generate_blob_sas( + account_name=self._blob_service.account_name, + container_name=self._container_name, + blob_name=remote_path, + account_key=self._account_key, + permission=BlobSasPermissions(read=True), + expiry=datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds), + ) + + return f"{blob_client.url}?{sas_token}" diff --git a/packages/shared/shared/storage/base.py b/packages/shared/shared/storage/base.py new file mode 100644 index 0000000..f51f951 --- /dev/null +++ b/packages/shared/shared/storage/base.py @@ -0,0 +1,229 @@ +""" +Base classes and interfaces for storage backends. + +Defines the abstract StorageBackend interface and common exceptions. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path + + +class StorageError(Exception): + """Base exception for storage operations.""" + + pass + + +class FileNotFoundStorageError(StorageError): + """Raised when a file is not found in storage.""" + + def __init__(self, path: str) -> None: + self.path = path + super().__init__(f"File not found in storage: {path}") + + +class PresignedUrlNotSupportedError(StorageError): + """Raised when pre-signed URLs are not supported by a backend.""" + + def __init__(self, backend_type: str) -> None: + self.backend_type = backend_type + super().__init__(f"Pre-signed URLs not supported for backend: {backend_type}") + + +@dataclass(frozen=True) +class StorageConfig: + """Configuration for storage backend. + + Attributes: + backend_type: Type of storage backend ("local", "azure_blob", or "s3"). + connection_string: Azure Blob Storage connection string (for azure_blob). + container_name: Azure Blob Storage container name (for azure_blob). + base_path: Base path for local storage (for local). + bucket_name: S3 bucket name (for s3). + region_name: AWS region name (for s3). + access_key_id: AWS access key ID (for s3). + secret_access_key: AWS secret access key (for s3). + endpoint_url: Custom endpoint URL for S3-compatible services (for s3). + presigned_url_expiry: Default expiry for pre-signed URLs in seconds. + """ + + backend_type: str + connection_string: str | None = None + container_name: str | None = None + base_path: Path | None = None + bucket_name: str | None = None + region_name: str | None = None + access_key_id: str | None = None + secret_access_key: str | None = None + endpoint_url: str | None = None + presigned_url_expiry: int = 3600 + + +class StorageBackend(ABC): + """Abstract base class for storage backends. + + Provides a unified interface for storing and retrieving files + from different storage systems (local filesystem, Azure Blob, etc.). + """ + + @abstractmethod + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + """Upload a file to storage. + + Args: + local_path: Path to the local file to upload. + remote_path: Destination path in storage. + overwrite: If True, overwrite existing file. + + Returns: + The remote path where the file was stored. + + Raises: + FileNotFoundStorageError: If local_path doesn't exist. + StorageError: If file exists and overwrite is False. 
+ """ + pass + + @abstractmethod + def download(self, remote_path: str, local_path: Path) -> Path: + """Download a file from storage. + + Args: + remote_path: Path to the file in storage. + local_path: Local destination path. + + Returns: + The local path where the file was downloaded. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + pass + + @abstractmethod + def exists(self, remote_path: str) -> bool: + """Check if a file exists in storage. + + Args: + remote_path: Path to check in storage. + + Returns: + True if the file exists, False otherwise. + """ + pass + + @abstractmethod + def list_files(self, prefix: str) -> list[str]: + """List files in storage with given prefix. + + Args: + prefix: Path prefix to filter files. + + Returns: + List of file paths matching the prefix. + """ + pass + + @abstractmethod + def delete(self, remote_path: str) -> bool: + """Delete a file from storage. + + Args: + remote_path: Path to the file to delete. + + Returns: + True if file was deleted, False if it didn't exist. + """ + pass + + @abstractmethod + def get_url(self, remote_path: str) -> str: + """Get a URL or path to access a file. + + Args: + remote_path: Path to the file in storage. + + Returns: + URL or path to access the file. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + pass + + @abstractmethod + def get_presigned_url( + self, + remote_path: str, + expires_in_seconds: int = 3600, + ) -> str: + """Generate a pre-signed URL for temporary access. + + Args: + remote_path: Path to the file in storage. + expires_in_seconds: URL validity duration (default 1 hour). + + Returns: + Pre-signed URL string. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + PresignedUrlNotSupportedError: If backend doesn't support pre-signed URLs. + """ + pass + + def upload_bytes( + self, data: bytes, remote_path: str, overwrite: bool = False + ) -> str: + """Upload bytes directly to storage. + + Default implementation writes to temp file then uploads. + Subclasses may override for more efficient implementation. + + Args: + data: Bytes to upload. + remote_path: Destination path in storage. + overwrite: If True, overwrite existing file. + + Returns: + The remote path where the data was stored. + """ + import tempfile + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(data) + temp_path = Path(f.name) + + try: + return self.upload(temp_path, remote_path, overwrite=overwrite) + finally: + temp_path.unlink(missing_ok=True) + + def download_bytes(self, remote_path: str) -> bytes: + """Download a file as bytes. + + Default implementation downloads to temp file then reads. + Subclasses may override for more efficient implementation. + + Args: + remote_path: Path to the file in storage. + + Returns: + The file contents as bytes. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + import tempfile + + with tempfile.NamedTemporaryFile(delete=False) as f: + temp_path = Path(f.name) + + try: + self.download(remote_path, temp_path) + return temp_path.read_bytes() + finally: + temp_path.unlink(missing_ok=True) diff --git a/packages/shared/shared/storage/config_loader.py b/packages/shared/shared/storage/config_loader.py new file mode 100644 index 0000000..ec68d47 --- /dev/null +++ b/packages/shared/shared/storage/config_loader.py @@ -0,0 +1,242 @@ +""" +Configuration file loader for storage backends. + +Supports YAML configuration files with environment variable substitution. 
+""" + +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass(frozen=True) +class LocalConfig: + """Local storage backend configuration.""" + + base_path: Path + + +@dataclass(frozen=True) +class AzureConfig: + """Azure Blob Storage configuration.""" + + connection_string: str + container_name: str + create_container: bool = False + + +@dataclass(frozen=True) +class S3Config: + """AWS S3 configuration.""" + + bucket_name: str + region_name: str | None = None + access_key_id: str | None = None + secret_access_key: str | None = None + endpoint_url: str | None = None + create_bucket: bool = False + + +@dataclass(frozen=True) +class StorageFileConfig: + """Extended storage configuration from file. + + Attributes: + backend_type: Type of storage backend. + local: Local backend configuration. + azure: Azure Blob configuration. + s3: S3 configuration. + presigned_url_expiry: Default expiry for pre-signed URLs in seconds. + """ + + backend_type: str + local: LocalConfig | None = None + azure: AzureConfig | None = None + s3: S3Config | None = None + presigned_url_expiry: int = 3600 + + +def substitute_env_vars(value: str) -> str: + """Substitute environment variables in a string. + + Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax. + + Args: + value: String potentially containing env var references. + + Returns: + String with env vars substituted. + """ + pattern = r"\$\{([A-Z_][A-Z0-9_]*)(?::-([^}]*))?\}" + + def replace(match: re.Match[str]) -> str: + var_name = match.group(1) + default = match.group(2) + return os.environ.get(var_name, default or "") + + return re.sub(pattern, replace, value) + + +def _substitute_in_dict(data: dict[str, Any]) -> dict[str, Any]: + """Recursively substitute env vars in a dictionary. + + Args: + data: Dictionary to process. + + Returns: + New dictionary with substitutions applied. + """ + result: dict[str, Any] = {} + for key, value in data.items(): + if isinstance(value, str): + result[key] = substitute_env_vars(value) + elif isinstance(value, dict): + result[key] = _substitute_in_dict(value) + elif isinstance(value, list): + result[key] = [ + substitute_env_vars(item) if isinstance(item, str) else item + for item in value + ] + else: + result[key] = value + return result + + +def _parse_local_config(data: dict[str, Any]) -> LocalConfig: + """Parse local configuration section. + + Args: + data: Dictionary containing local config. + + Returns: + LocalConfig instance. + + Raises: + ValueError: If required fields are missing. + """ + base_path = data.get("base_path") + if not base_path: + raise ValueError("local.base_path is required") + return LocalConfig(base_path=Path(base_path)) + + +def _parse_azure_config(data: dict[str, Any]) -> AzureConfig: + """Parse Azure configuration section. + + Args: + data: Dictionary containing Azure config. + + Returns: + AzureConfig instance. + + Raises: + ValueError: If required fields are missing. + """ + connection_string = data.get("connection_string") + container_name = data.get("container_name") + + if not connection_string: + raise ValueError("azure.connection_string is required") + if not container_name: + raise ValueError("azure.container_name is required") + + return AzureConfig( + connection_string=connection_string, + container_name=container_name, + create_container=data.get("create_container", False), + ) + + +def _parse_s3_config(data: dict[str, Any]) -> S3Config: + """Parse S3 configuration section. 
+ + Args: + data: Dictionary containing S3 config. + + Returns: + S3Config instance. + + Raises: + ValueError: If required fields are missing. + """ + bucket_name = data.get("bucket_name") + + if not bucket_name: + raise ValueError("s3.bucket_name is required") + + return S3Config( + bucket_name=bucket_name, + region_name=data.get("region_name"), + access_key_id=data.get("access_key_id"), + secret_access_key=data.get("secret_access_key"), + endpoint_url=data.get("endpoint_url"), + create_bucket=data.get("create_bucket", False), + ) + + +def load_storage_config(config_path: Path | str) -> StorageFileConfig: + """Load storage configuration from YAML file. + + Supports environment variable substitution using ${VAR_NAME} or + ${VAR_NAME:-default} syntax. + + Args: + config_path: Path to configuration file. + + Returns: + Parsed StorageFileConfig. + + Raises: + FileNotFoundError: If config file doesn't exist. + ValueError: If config is invalid. + """ + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + try: + raw_content = config_path.read_text(encoding="utf-8") + data = yaml.safe_load(raw_content) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in config file: {e}") from e + + if not isinstance(data, dict): + raise ValueError("Config file must contain a YAML dictionary") + + # Substitute environment variables + data = _substitute_in_dict(data) + + # Extract backend type + backend_type = data.get("backend") + if not backend_type: + raise ValueError("'backend' field is required in config file") + + # Parse presigned URL expiry + presigned_url_expiry = data.get("presigned_url_expiry", 3600) + + # Parse backend-specific configurations + local_config = None + azure_config = None + s3_config = None + + if "local" in data: + local_config = _parse_local_config(data["local"]) + + if "azure" in data: + azure_config = _parse_azure_config(data["azure"]) + + if "s3" in data: + s3_config = _parse_s3_config(data["s3"]) + + return StorageFileConfig( + backend_type=backend_type, + local=local_config, + azure=azure_config, + s3=s3_config, + presigned_url_expiry=presigned_url_expiry, + ) diff --git a/packages/shared/shared/storage/factory.py b/packages/shared/shared/storage/factory.py new file mode 100644 index 0000000..bdcc5e3 --- /dev/null +++ b/packages/shared/shared/storage/factory.py @@ -0,0 +1,296 @@ +""" +Factory functions for creating storage backends. + +Provides convenient functions for creating storage backends from +configuration or environment variables. +""" + +import os +from pathlib import Path + +from shared.storage.base import StorageBackend, StorageConfig + + +def create_storage_backend(config: StorageConfig) -> StorageBackend: + """Create a storage backend from configuration. + + Args: + config: Storage configuration. + + Returns: + A configured storage backend. + + Raises: + ValueError: If configuration is invalid. 
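+
+    Example (illustrative sketch, not part of the original change; the base
+    path is an assumption):
+
+        config = StorageConfig(backend_type="local", base_path=Path("/data/invoice-storage"))
+        backend = create_storage_backend(config)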
+ """ + if config.backend_type == "local": + if config.base_path is None: + raise ValueError("base_path is required for local storage backend") + + from shared.storage.local import LocalStorageBackend + + return LocalStorageBackend(base_path=config.base_path) + + elif config.backend_type == "azure_blob": + if config.connection_string is None: + raise ValueError( + "connection_string is required for Azure blob storage backend" + ) + if config.container_name is None: + raise ValueError( + "container_name is required for Azure blob storage backend" + ) + + # Import here to allow lazy loading of Azure SDK + from azure.storage.blob import BlobServiceClient # noqa: F401 + + from shared.storage.azure import AzureBlobStorageBackend + + return AzureBlobStorageBackend( + connection_string=config.connection_string, + container_name=config.container_name, + ) + + elif config.backend_type == "s3": + if config.bucket_name is None: + raise ValueError("bucket_name is required for S3 storage backend") + + # Import here to allow lazy loading of boto3 + import boto3 # noqa: F401 + + from shared.storage.s3 import S3StorageBackend + + return S3StorageBackend( + bucket_name=config.bucket_name, + region_name=config.region_name, + access_key_id=config.access_key_id, + secret_access_key=config.secret_access_key, + endpoint_url=config.endpoint_url, + ) + + else: + raise ValueError(f"Unknown storage backend type: {config.backend_type}") + + +def get_default_storage_config() -> StorageConfig: + """Get storage configuration from environment variables. + + Environment variables: + STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local". + STORAGE_BASE_PATH: Base path for local storage. + AZURE_STORAGE_CONNECTION_STRING: Azure connection string. + AZURE_STORAGE_CONTAINER: Azure container name. + AWS_S3_BUCKET: S3 bucket name. + AWS_REGION: AWS region name. + AWS_ACCESS_KEY_ID: AWS access key ID. + AWS_SECRET_ACCESS_KEY: AWS secret access key. + AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services. + + Returns: + StorageConfig from environment. + """ + backend_type = os.environ.get("STORAGE_BACKEND", "local") + + if backend_type == "local": + base_path_str = os.environ.get("STORAGE_BASE_PATH") + # Expand ~ to home directory + base_path = Path(os.path.expanduser(base_path_str)) if base_path_str else None + + return StorageConfig( + backend_type="local", + base_path=base_path, + ) + + elif backend_type == "azure_blob": + return StorageConfig( + backend_type="azure_blob", + connection_string=os.environ.get("AZURE_STORAGE_CONNECTION_STRING"), + container_name=os.environ.get("AZURE_STORAGE_CONTAINER"), + ) + + elif backend_type == "s3": + return StorageConfig( + backend_type="s3", + bucket_name=os.environ.get("AWS_S3_BUCKET"), + region_name=os.environ.get("AWS_REGION"), + access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + endpoint_url=os.environ.get("AWS_ENDPOINT_URL"), + ) + + else: + return StorageConfig(backend_type=backend_type) + + +def create_storage_backend_from_env() -> StorageBackend: + """Create a storage backend from environment variables. + + Environment variables: + STORAGE_BACKEND: Backend type ("local", "azure_blob", or "s3"), defaults to "local". + STORAGE_BASE_PATH: Base path for local storage. + AZURE_STORAGE_CONNECTION_STRING: Azure connection string. + AZURE_STORAGE_CONTAINER: Azure container name. + AWS_S3_BUCKET: S3 bucket name. + AWS_REGION: AWS region name. + AWS_ACCESS_KEY_ID: AWS access key ID. 
+ AWS_SECRET_ACCESS_KEY: AWS secret access key. + AWS_ENDPOINT_URL: Custom endpoint URL for S3-compatible services. + + Returns: + A configured storage backend. + + Raises: + ValueError: If required environment variables are missing or empty. + """ + backend_type = os.environ.get("STORAGE_BACKEND", "local").strip() + + if backend_type == "local": + base_path = os.environ.get("STORAGE_BASE_PATH", "").strip() + if not base_path: + raise ValueError( + "STORAGE_BASE_PATH environment variable is required and cannot be empty" + ) + + # Expand ~ to home directory + base_path_expanded = os.path.expanduser(base_path) + + from shared.storage.local import LocalStorageBackend + + return LocalStorageBackend(base_path=Path(base_path_expanded)) + + elif backend_type == "azure_blob": + connection_string = os.environ.get( + "AZURE_STORAGE_CONNECTION_STRING", "" + ).strip() + if not connection_string: + raise ValueError( + "AZURE_STORAGE_CONNECTION_STRING environment variable is required " + "and cannot be empty" + ) + + container_name = os.environ.get("AZURE_STORAGE_CONTAINER", "").strip() + if not container_name: + raise ValueError( + "AZURE_STORAGE_CONTAINER environment variable is required " + "and cannot be empty" + ) + + # Import here to allow lazy loading of Azure SDK + from azure.storage.blob import BlobServiceClient # noqa: F401 + + from shared.storage.azure import AzureBlobStorageBackend + + return AzureBlobStorageBackend( + connection_string=connection_string, + container_name=container_name, + ) + + elif backend_type == "s3": + bucket_name = os.environ.get("AWS_S3_BUCKET", "").strip() + if not bucket_name: + raise ValueError( + "AWS_S3_BUCKET environment variable is required and cannot be empty" + ) + + # Import here to allow lazy loading of boto3 + import boto3 # noqa: F401 + + from shared.storage.s3 import S3StorageBackend + + return S3StorageBackend( + bucket_name=bucket_name, + region_name=os.environ.get("AWS_REGION", "").strip() or None, + access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "").strip() or None, + secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "").strip() + or None, + endpoint_url=os.environ.get("AWS_ENDPOINT_URL", "").strip() or None, + ) + + else: + raise ValueError(f"Unknown storage backend type: {backend_type}") + + +def create_storage_backend_from_file(config_path: Path | str) -> StorageBackend: + """Create a storage backend from a configuration file. + + Args: + config_path: Path to YAML configuration file. + + Returns: + A configured storage backend. + + Raises: + FileNotFoundError: If config file doesn't exist. + ValueError: If configuration is invalid. 
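+
+    Example (illustrative sketch, not part of the original change; the config
+    path and blob path are assumptions):
+
+        backend = create_storage_backend_from_file("config/storage.yaml")
+        backend.upload_bytes(b"hello", "uploads/hello.txt", overwrite=True)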
+ """ + from shared.storage.config_loader import load_storage_config + + file_config = load_storage_config(config_path) + + if file_config.backend_type == "local": + if file_config.local is None: + raise ValueError("local configuration section is required") + + from shared.storage.local import LocalStorageBackend + + return LocalStorageBackend(base_path=file_config.local.base_path) + + elif file_config.backend_type == "azure_blob": + if file_config.azure is None: + raise ValueError("azure configuration section is required") + + # Import here to allow lazy loading of Azure SDK + from azure.storage.blob import BlobServiceClient # noqa: F401 + + from shared.storage.azure import AzureBlobStorageBackend + + return AzureBlobStorageBackend( + connection_string=file_config.azure.connection_string, + container_name=file_config.azure.container_name, + create_container=file_config.azure.create_container, + ) + + elif file_config.backend_type == "s3": + if file_config.s3 is None: + raise ValueError("s3 configuration section is required") + + # Import here to allow lazy loading of boto3 + import boto3 # noqa: F401 + + from shared.storage.s3 import S3StorageBackend + + return S3StorageBackend( + bucket_name=file_config.s3.bucket_name, + region_name=file_config.s3.region_name, + access_key_id=file_config.s3.access_key_id, + secret_access_key=file_config.s3.secret_access_key, + endpoint_url=file_config.s3.endpoint_url, + create_bucket=file_config.s3.create_bucket, + ) + + else: + raise ValueError(f"Unknown storage backend type: {file_config.backend_type}") + + +def get_storage_backend(config_path: Path | str | None = None) -> StorageBackend: + """Get storage backend with fallback chain. + + Priority: + 1. Config file (if provided) + 2. Environment variables + + Args: + config_path: Optional path to config file. + + Returns: + A configured storage backend. + + Raises: + ValueError: If configuration is invalid. + FileNotFoundError: If specified config file doesn't exist. + """ + if config_path: + return create_storage_backend_from_file(config_path) + + # Fall back to environment variables + return create_storage_backend_from_env() diff --git a/packages/shared/shared/storage/local.py b/packages/shared/shared/storage/local.py new file mode 100644 index 0000000..55c5a20 --- /dev/null +++ b/packages/shared/shared/storage/local.py @@ -0,0 +1,262 @@ +""" +Local filesystem storage backend. + +Provides storage operations using the local filesystem. +""" + +import shutil +from pathlib import Path + +from shared.storage.base import ( + FileNotFoundStorageError, + StorageBackend, + StorageError, +) + + +class LocalStorageBackend(StorageBackend): + """Storage backend using local filesystem. + + Files are stored relative to a base path on the local filesystem. + """ + + def __init__(self, base_path: str | Path) -> None: + """Initialize local storage backend. + + Args: + base_path: Base directory for all storage operations. + Will be created if it doesn't exist. + """ + self._base_path = Path(base_path) + self._base_path.mkdir(parents=True, exist_ok=True) + + @property + def base_path(self) -> Path: + """Get the base path for this storage backend.""" + return self._base_path + + def _get_full_path(self, remote_path: str) -> Path: + """Convert a remote path to a full local path with security validation. + + Args: + remote_path: The remote path to resolve. + + Returns: + The full local path. + + Raises: + StorageError: If the path attempts to escape the base directory. 
+ """ + # Reject absolute paths + if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"): + raise StorageError(f"Absolute paths not allowed: {remote_path}") + + # Resolve to prevent path traversal attacks + full_path = (self._base_path / remote_path).resolve() + base_resolved = self._base_path.resolve() + + # Verify the resolved path is within base_path + try: + full_path.relative_to(base_resolved) + except ValueError: + raise StorageError(f"Path traversal not allowed: {remote_path}") + + return full_path + + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + """Upload a file to local storage. + + Args: + local_path: Path to the local file to upload. + remote_path: Destination path in storage. + overwrite: If True, overwrite existing file. + + Returns: + The remote path where the file was stored. + + Raises: + FileNotFoundStorageError: If local_path doesn't exist. + StorageError: If file exists and overwrite is False. + """ + if not local_path.exists(): + raise FileNotFoundStorageError(str(local_path)) + + dest_path = self._get_full_path(remote_path) + + if dest_path.exists() and not overwrite: + raise StorageError(f"File already exists: {remote_path}") + + dest_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(local_path, dest_path) + + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + """Download a file from local storage. + + Args: + remote_path: Path to the file in storage. + local_path: Local destination path. + + Returns: + The local path where the file was downloaded. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + source_path = self._get_full_path(remote_path) + + if not source_path.exists(): + raise FileNotFoundStorageError(remote_path) + + local_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source_path, local_path) + + return local_path + + def exists(self, remote_path: str) -> bool: + """Check if a file exists in storage. + + Args: + remote_path: Path to check in storage. + + Returns: + True if the file exists, False otherwise. + """ + return self._get_full_path(remote_path).exists() + + def list_files(self, prefix: str) -> list[str]: + """List files in storage with given prefix. + + Args: + prefix: Path prefix to filter files. + + Returns: + Sorted list of file paths matching the prefix. + """ + if prefix: + search_path = self._get_full_path(prefix) + if not search_path.exists(): + return [] + base_for_relative = self._base_path + else: + search_path = self._base_path + base_for_relative = self._base_path + + files: list[str] = [] + if search_path.is_file(): + files.append(str(search_path.relative_to(self._base_path))) + elif search_path.is_dir(): + for file_path in search_path.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(self._base_path) + files.append(str(relative_path).replace("\\", "/")) + + return sorted(files) + + def delete(self, remote_path: str) -> bool: + """Delete a file from storage. + + Args: + remote_path: Path to the file to delete. + + Returns: + True if file was deleted, False if it didn't exist. + """ + file_path = self._get_full_path(remote_path) + + if not file_path.exists(): + return False + + file_path.unlink() + return True + + def get_url(self, remote_path: str) -> str: + """Get a file:// URL to access a file. + + Args: + remote_path: Path to the file in storage. + + Returns: + file:// URL to access the file. 
+ + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + file_path = self._get_full_path(remote_path) + + if not file_path.exists(): + raise FileNotFoundStorageError(remote_path) + + return file_path.as_uri() + + def upload_bytes( + self, data: bytes, remote_path: str, overwrite: bool = False + ) -> str: + """Upload bytes directly to storage. + + Args: + data: Bytes to upload. + remote_path: Destination path in storage. + overwrite: If True, overwrite existing file. + + Returns: + The remote path where the data was stored. + """ + dest_path = self._get_full_path(remote_path) + + if dest_path.exists() and not overwrite: + raise StorageError(f"File already exists: {remote_path}") + + dest_path.parent.mkdir(parents=True, exist_ok=True) + dest_path.write_bytes(data) + + return remote_path + + def download_bytes(self, remote_path: str) -> bytes: + """Download a file as bytes. + + Args: + remote_path: Path to the file in storage. + + Returns: + The file contents as bytes. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + file_path = self._get_full_path(remote_path) + + if not file_path.exists(): + raise FileNotFoundStorageError(remote_path) + + return file_path.read_bytes() + + def get_presigned_url( + self, + remote_path: str, + expires_in_seconds: int = 3600, + ) -> str: + """Get a file:// URL for local file access. + + For local storage, this returns a file:// URI. + Note: Local file:// URLs don't actually expire. + + Args: + remote_path: Path to the file in storage. + expires_in_seconds: Ignored for local storage (URLs don't expire). + + Returns: + file:// URL to access the file. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + file_path = self._get_full_path(remote_path) + + if not file_path.exists(): + raise FileNotFoundStorageError(remote_path) + + return file_path.as_uri() diff --git a/packages/shared/shared/storage/prefixes.py b/packages/shared/shared/storage/prefixes.py new file mode 100644 index 0000000..d7dc1da --- /dev/null +++ b/packages/shared/shared/storage/prefixes.py @@ -0,0 +1,158 @@ +""" +Storage path prefixes for unified file organization. + +Provides standardized path prefixes for organizing files within +the storage backend, ensuring consistent structure across +local, Azure Blob, and S3 storage. +""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class StoragePrefixes: + """Standardized storage path prefixes. + + All paths are relative to the storage backend root. + These prefixes ensure consistent file organization across + all storage backends (local, Azure, S3). 
+ + Usage: + from shared.storage.prefixes import PREFIXES + + path = f"{PREFIXES.DOCUMENTS}/{document_id}.pdf" + storage.upload_bytes(content, path) + """ + + # Document storage + DOCUMENTS: str = "documents" + """Original document files (PDFs, etc.)""" + + IMAGES: str = "images" + """Page images extracted from documents""" + + # Processing directories + UPLOADS: str = "uploads" + """Temporary upload staging area""" + + RESULTS: str = "results" + """Inference results and visualizations""" + + EXPORTS: str = "exports" + """Exported datasets and annotations""" + + # Training data + DATASETS: str = "datasets" + """Training dataset files""" + + MODELS: str = "models" + """Trained model weights and checkpoints""" + + # Data pipeline directories (legacy compatibility) + RAW_PDFS: str = "raw_pdfs" + """Raw PDF files for auto-labeling pipeline""" + + STRUCTURED_DATA: str = "structured_data" + """CSV/structured data for matching""" + + ADMIN_IMAGES: str = "admin_images" + """Admin UI page images""" + + @staticmethod + def document_path(document_id: str, extension: str = ".pdf") -> str: + """Get path for a document file. + + Args: + document_id: Unique document identifier. + extension: File extension (include leading dot). + + Returns: + Storage path like "documents/abc123.pdf" + """ + ext = extension if extension.startswith(".") else f".{extension}" + return f"{PREFIXES.DOCUMENTS}/{document_id}{ext}" + + @staticmethod + def image_path(document_id: str, page_num: int, extension: str = ".png") -> str: + """Get path for a page image file. + + Args: + document_id: Unique document identifier. + page_num: Page number (1-indexed). + extension: File extension (include leading dot). + + Returns: + Storage path like "images/abc123/page_1.png" + """ + ext = extension if extension.startswith(".") else f".{extension}" + return f"{PREFIXES.IMAGES}/{document_id}/page_{page_num}{ext}" + + @staticmethod + def upload_path(filename: str, subfolder: str | None = None) -> str: + """Get path for a temporary upload file. + + Args: + filename: Original filename. + subfolder: Optional subfolder (e.g., "async"). + + Returns: + Storage path like "uploads/filename.pdf" or "uploads/async/filename.pdf" + """ + if subfolder: + return f"{PREFIXES.UPLOADS}/{subfolder}/{filename}" + return f"{PREFIXES.UPLOADS}/{filename}" + + @staticmethod + def result_path(filename: str) -> str: + """Get path for a result file. + + Args: + filename: Result filename. + + Returns: + Storage path like "results/filename.json" + """ + return f"{PREFIXES.RESULTS}/{filename}" + + @staticmethod + def export_path(export_id: str, filename: str) -> str: + """Get path for an export file. + + Args: + export_id: Unique export identifier. + filename: Export filename. + + Returns: + Storage path like "exports/abc123/filename.zip" + """ + return f"{PREFIXES.EXPORTS}/{export_id}/{filename}" + + @staticmethod + def dataset_path(dataset_id: str, filename: str) -> str: + """Get path for a dataset file. + + Args: + dataset_id: Unique dataset identifier. + filename: Dataset filename. + + Returns: + Storage path like "datasets/abc123/filename.yaml" + """ + return f"{PREFIXES.DATASETS}/{dataset_id}/{filename}" + + @staticmethod + def model_path(version: str, filename: str) -> str: + """Get path for a model file. + + Args: + version: Model version string. + filename: Model filename. 
+ + Returns: + Storage path like "models/v1.0.0/best.pt" + """ + return f"{PREFIXES.MODELS}/{version}/{filename}" + + +# Default instance for convenient access +PREFIXES = StoragePrefixes() diff --git a/packages/shared/shared/storage/s3.py b/packages/shared/shared/storage/s3.py new file mode 100644 index 0000000..0af7ed7 --- /dev/null +++ b/packages/shared/shared/storage/s3.py @@ -0,0 +1,309 @@ +""" +AWS S3 Storage backend. + +Provides storage operations using AWS S3. +""" + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from mypy_boto3_s3 import S3Client + +from shared.storage.base import ( + FileNotFoundStorageError, + StorageBackend, + StorageError, +) + + +class S3StorageBackend(StorageBackend): + """Storage backend using AWS S3. + + Files are stored as objects in an S3 bucket. + """ + + def __init__( + self, + bucket_name: str, + region_name: str | None = None, + access_key_id: str | None = None, + secret_access_key: str | None = None, + endpoint_url: str | None = None, + create_bucket: bool = False, + ) -> None: + """Initialize S3 storage backend. + + Args: + bucket_name: Name of the S3 bucket. + region_name: AWS region name (optional, uses default if not set). + access_key_id: AWS access key ID (optional, uses credentials chain). + secret_access_key: AWS secret access key (optional). + endpoint_url: Custom endpoint URL (for S3-compatible services). + create_bucket: If True, create the bucket if it doesn't exist. + """ + import boto3 + + self._bucket_name = bucket_name + self._region_name = region_name + + # Build client kwargs + client_kwargs: dict[str, Any] = {} + if region_name: + client_kwargs["region_name"] = region_name + if endpoint_url: + client_kwargs["endpoint_url"] = endpoint_url + if access_key_id and secret_access_key: + client_kwargs["aws_access_key_id"] = access_key_id + client_kwargs["aws_secret_access_key"] = secret_access_key + + self._s3: "S3Client" = boto3.client("s3", **client_kwargs) + + if create_bucket: + self._ensure_bucket_exists() + + def _ensure_bucket_exists(self) -> None: + """Create the bucket if it doesn't exist.""" + from botocore.exceptions import ClientError + + try: + self._s3.head_bucket(Bucket=self._bucket_name) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code in ("404", "NoSuchBucket"): + # Bucket doesn't exist, create it + create_kwargs: dict[str, Any] = {"Bucket": self._bucket_name} + if self._region_name and self._region_name != "us-east-1": + create_kwargs["CreateBucketConfiguration"] = { + "LocationConstraint": self._region_name + } + self._s3.create_bucket(**create_kwargs) + else: + # Re-raise permission errors, network issues, etc. + raise + + def _object_exists(self, key: str) -> bool: + """Check if an object exists in S3. + + Args: + key: Object key to check. + + Returns: + True if object exists, False otherwise. + """ + from botocore.exceptions import ClientError + + try: + self._s3.head_object(Bucket=self._bucket_name, Key=key) + return True + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code in ("404", "NoSuchKey"): + return False + raise + + @property + def bucket_name(self) -> str: + """Get the bucket name for this storage backend.""" + return self._bucket_name + + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + """Upload a file to S3. + + Args: + local_path: Path to the local file to upload. + remote_path: Destination object key. 
+ overwrite: If True, overwrite existing object. + + Returns: + The remote path where the file was stored. + + Raises: + FileNotFoundStorageError: If local_path doesn't exist. + StorageError: If object exists and overwrite is False. + """ + if not local_path.exists(): + raise FileNotFoundStorageError(str(local_path)) + + if not overwrite and self._object_exists(remote_path): + raise StorageError(f"File already exists: {remote_path}") + + self._s3.upload_file(str(local_path), self._bucket_name, remote_path) + + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + """Download an object from S3. + + Args: + remote_path: Object key in S3. + local_path: Local destination path. + + Returns: + The local path where the file was downloaded. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + if not self._object_exists(remote_path): + raise FileNotFoundStorageError(remote_path) + + local_path.parent.mkdir(parents=True, exist_ok=True) + + self._s3.download_file(self._bucket_name, remote_path, str(local_path)) + + return local_path + + def exists(self, remote_path: str) -> bool: + """Check if an object exists in S3. + + Args: + remote_path: Object key to check. + + Returns: + True if the object exists, False otherwise. + """ + return self._object_exists(remote_path) + + def list_files(self, prefix: str) -> list[str]: + """List objects in S3 with given prefix. + + Handles pagination to return all matching objects (S3 returns max 1000 per request). + + Args: + prefix: Object key prefix to filter. + + Returns: + List of object keys matching the prefix. + """ + kwargs: dict[str, Any] = {"Bucket": self._bucket_name} + if prefix: + kwargs["Prefix"] = prefix + + all_keys: list[str] = [] + while True: + response = self._s3.list_objects_v2(**kwargs) + contents = response.get("Contents", []) + all_keys.extend(obj["Key"] for obj in contents) + + if not response.get("IsTruncated"): + break + kwargs["ContinuationToken"] = response["NextContinuationToken"] + + return all_keys + + def delete(self, remote_path: str) -> bool: + """Delete an object from S3. + + Args: + remote_path: Object key to delete. + + Returns: + True if object was deleted, False if it didn't exist. + """ + if not self._object_exists(remote_path): + return False + + self._s3.delete_object(Bucket=self._bucket_name, Key=remote_path) + return True + + def get_url(self, remote_path: str) -> str: + """Get a URL for an object. + + Args: + remote_path: Object key in S3. + + Returns: + URL to access the object. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + if not self._object_exists(remote_path): + raise FileNotFoundStorageError(remote_path) + + return self._s3.generate_presigned_url( + "get_object", + Params={"Bucket": self._bucket_name, "Key": remote_path}, + ExpiresIn=3600, + ) + + def get_presigned_url( + self, + remote_path: str, + expires_in_seconds: int = 3600, + ) -> str: + """Generate a pre-signed URL for temporary object access. + + Args: + remote_path: Object key in S3. + expires_in_seconds: URL validity duration (1 to 604800 seconds / 7 days). + + Returns: + Pre-signed URL string. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + ValueError: If expires_in_seconds is out of valid range. 
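+
+        Example (illustrative sketch, not part of the original change; the
+        object key and expiry below are assumed):
+
+            url = backend.get_presigned_url("models/v1.0.0/best.pt", expires_in_seconds=3600)
+            # Time-limited HTTPS URL signed with the configured credentials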
+ """ + if expires_in_seconds < 1 or expires_in_seconds > 604800: + raise ValueError( + "expires_in_seconds must be between 1 and 604800 (7 days)" + ) + + if not self._object_exists(remote_path): + raise FileNotFoundStorageError(remote_path) + + return self._s3.generate_presigned_url( + "get_object", + Params={"Bucket": self._bucket_name, "Key": remote_path}, + ExpiresIn=expires_in_seconds, + ) + + def upload_bytes( + self, data: bytes, remote_path: str, overwrite: bool = False + ) -> str: + """Upload bytes directly to S3. + + Args: + data: Bytes to upload. + remote_path: Destination object key. + overwrite: If True, overwrite existing object. + + Returns: + The remote path where the data was stored. + + Raises: + StorageError: If object exists and overwrite is False. + """ + if not overwrite and self._object_exists(remote_path): + raise StorageError(f"File already exists: {remote_path}") + + self._s3.put_object(Bucket=self._bucket_name, Key=remote_path, Body=data) + + return remote_path + + def download_bytes(self, remote_path: str) -> bytes: + """Download an object as bytes. + + Args: + remote_path: Object key in S3. + + Returns: + The object contents as bytes. + + Raises: + FileNotFoundStorageError: If remote_path doesn't exist. + """ + from botocore.exceptions import ClientError + + try: + response = self._s3.get_object(Bucket=self._bucket_name, Key=remote_path) + return response["Body"].read() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code in ("404", "NoSuchKey"): + raise FileNotFoundStorageError(remote_path) from e + raise diff --git a/packages/training/training/cli/analyze_labels.py b/packages/training/training/cli/analyze_labels.py index c8e9b0f..fbce5a6 100644 --- a/packages/training/training/cli/analyze_labels.py +++ b/packages/training/training/cli/analyze_labels.py @@ -20,7 +20,7 @@ from shared.config import get_db_connection_string from shared.normalize import normalize_field from shared.matcher import FieldMatcher from shared.pdf import is_text_pdf, extract_text_tokens -from training.yolo.annotation_generator import FIELD_CLASSES +from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES from shared.data.db import DocumentDB diff --git a/packages/training/training/cli/autolabel.py b/packages/training/training/cli/autolabel.py index 09791a4..f5c6a66 100644 --- a/packages/training/training/cli/autolabel.py +++ b/packages/training/training/cli/autolabel.py @@ -113,7 +113,7 @@ def process_single_document(args_tuple): # Import inside worker to avoid pickling issues from training.data.autolabel_report import AutoLabelReport from shared.pdf import PDFDocument - from training.yolo.annotation_generator import FIELD_CLASSES + from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES from training.processing.document_processor import process_page, record_unmatched_fields start_time = time.time() @@ -342,7 +342,8 @@ def main(): from shared.ocr import OCREngine from shared.matcher import FieldMatcher from shared.normalize import normalize_field - from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES + from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES + from training.yolo.annotation_generator import AnnotationGenerator # Handle comma-separated CSV paths csv_input = args.csv diff --git a/packages/training/training/processing/autolabel_tasks.py b/packages/training/training/processing/autolabel_tasks.py index df012ab..f938f21 100644 --- 
a/packages/training/training/processing/autolabel_tasks.py +++ b/packages/training/training/processing/autolabel_tasks.py @@ -90,7 +90,7 @@ def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: import shutil from training.data.autolabel_report import AutoLabelReport from shared.pdf import PDFDocument - from training.yolo.annotation_generator import FIELD_CLASSES + from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES from training.processing.document_processor import process_page, record_unmatched_fields row_dict = task_data["row_dict"] @@ -208,7 +208,7 @@ def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]: import shutil from training.data.autolabel_report import AutoLabelReport from shared.pdf import PDFDocument - from training.yolo.annotation_generator import FIELD_CLASSES + from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES from training.processing.document_processor import process_page, record_unmatched_fields row_dict = task_data["row_dict"] diff --git a/packages/training/training/processing/document_processor.py b/packages/training/training/processing/document_processor.py index fb099a9..a6cca2d 100644 --- a/packages/training/training/processing/document_processor.py +++ b/packages/training/training/processing/document_processor.py @@ -15,7 +15,8 @@ from training.data.autolabel_report import FieldMatchResult from shared.matcher import FieldMatcher from shared.normalize import normalize_field from shared.ocr.machine_code_parser import MachineCodeParser -from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES +from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES +from training.yolo.annotation_generator import AnnotationGenerator def match_supplier_accounts( diff --git a/packages/training/training/yolo/annotation_generator.py b/packages/training/training/yolo/annotation_generator.py index 9a95a86..b8327b0 100644 --- a/packages/training/training/yolo/annotation_generator.py +++ b/packages/training/training/yolo/annotation_generator.py @@ -9,43 +9,12 @@ from pathlib import Path from typing import Any import csv - -# Field class mapping for YOLO -# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro -FIELD_CLASSES = { - 'InvoiceNumber': 0, - 'InvoiceDate': 1, - 'InvoiceDueDate': 2, - 'OCR': 3, - 'Bankgiro': 4, - 'Plusgiro': 5, - 'Amount': 6, - 'supplier_organisation_number': 7, - 'customer_number': 8, - 'payment_line': 9, # Machine code payment line at bottom of invoice -} - -# Fields that need matching but map to other YOLO classes -# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type -ACCOUNT_FIELD_MAPPING = { - 'supplier_accounts': { - 'BG': 'Bankgiro', # BG:xxx -> Bankgiro class - 'PG': 'Plusgiro', # PG:xxx -> Plusgiro class - } -} - -CLASS_NAMES = [ - 'invoice_number', - 'invoice_date', - 'invoice_due_date', - 'ocr_number', - 'bankgiro', - 'plusgiro', - 'amount', - 'supplier_org_number', - 'customer_number', - 'payment_line', # Machine code payment line at bottom of invoice -] +# Import field mappings from single source of truth +from shared.fields import ( + TRAINING_FIELD_CLASSES as FIELD_CLASSES, + CLASS_NAMES, + ACCOUNT_FIELD_MAPPING, +) @dataclass diff --git a/packages/training/training/yolo/dataset_builder.py b/packages/training/training/yolo/dataset_builder.py index 97bbf8d..3114cb1 100644 --- a/packages/training/training/yolo/dataset_builder.py +++ b/packages/training/training/yolo/dataset_builder.py @@ 
-101,7 +101,8 @@ class DatasetBuilder: Returns: DatasetStats with build results """ - from .annotation_generator import AnnotationGenerator, CLASS_NAMES + from shared.fields import CLASS_NAMES + from .annotation_generator import AnnotationGenerator random.seed(seed) diff --git a/packages/training/training/yolo/db_dataset.py b/packages/training/training/yolo/db_dataset.py index dc0f5be..74c3e7e 100644 --- a/packages/training/training/yolo/db_dataset.py +++ b/packages/training/training/yolo/db_dataset.py @@ -18,7 +18,8 @@ import numpy as np from PIL import Image from shared.config import DEFAULT_DPI -from .annotation_generator import FIELD_CLASSES, YOLOAnnotation +from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES +from .annotation_generator import YOLOAnnotation logger = logging.getLogger(__name__) diff --git a/run_migration.py b/run_migration.py new file mode 100644 index 0000000..35cc7f8 --- /dev/null +++ b/run_migration.py @@ -0,0 +1,73 @@ +"""Run database migration for training_status fields.""" +import psycopg2 +import os + +# Read password from .env file +password = "" +try: + with open(".env") as f: + for line in f: + if line.startswith("DB_PASSWORD="): + password = line.strip().split("=", 1)[1].strip('"').strip("'") + break +except Exception as e: + print(f"Error reading .env: {e}") + +print(f"Password found: {bool(password)}") + +conn = psycopg2.connect( + host="192.168.68.31", + port=5432, + database="docmaster", + user="docmaster", + password=password +) +conn.autocommit = True +cur = conn.cursor() + +# Add training_status column +try: + cur.execute("ALTER TABLE training_datasets ADD COLUMN training_status VARCHAR(20) DEFAULT NULL") + print("Added training_status column") +except Exception as e: + print(f"training_status: {e}") + +# Add active_training_task_id column +try: + cur.execute("ALTER TABLE training_datasets ADD COLUMN active_training_task_id UUID DEFAULT NULL") + print("Added active_training_task_id column") +except Exception as e: + print(f"active_training_task_id: {e}") + +# Create indexes +try: + cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status ON training_datasets(training_status)") + print("Created training_status index") +except Exception as e: + print(f"index training_status: {e}") + +try: + cur.execute("CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id ON training_datasets(active_training_task_id)") + print("Created active_training_task_id index") +except Exception as e: + print(f"index active_training_task_id: {e}") + +# Update existing datasets that have been used in completed training tasks to trained status +try: + cur.execute(""" + UPDATE training_datasets d + SET status = 'trained' + WHERE d.status = 'ready' + AND EXISTS ( + SELECT 1 FROM training_tasks t + WHERE t.dataset_id = d.dataset_id + AND t.status = 'completed' + ) + """) + print(f"Updated {cur.rowcount} datasets to trained status") +except Exception as e: + print(f"update status: {e}") + +cur.close() +conn.close() +print("Migration complete!") diff --git a/tests/data/test_admin_models_v2.py b/tests/data/test_admin_models_v2.py index 7593283..5396845 100644 --- a/tests/data/test_admin_models_v2.py +++ b/tests/data/test_admin_models_v2.py @@ -17,9 +17,8 @@ from inference.data.admin_models import ( AdminDocument, AdminAnnotation, TrainingTask, - FIELD_CLASSES, - CSV_TO_CLASS_MAPPING, ) +from shared.fields import FIELD_CLASSES, CSV_TO_CLASS_MAPPING class TestBatchUpload: @@ -507,7 +506,10 @@ class TestCSVToClassMapping: assert 
len(CSV_TO_CLASS_MAPPING) > 0 def test_csv_mapping_values(self): - """Test specific CSV column mappings.""" + """Test specific CSV column mappings. + + Note: customer_number is class 8 (verified from trained model best.pt). + """ assert CSV_TO_CLASS_MAPPING["InvoiceNumber"] == 0 assert CSV_TO_CLASS_MAPPING["InvoiceDate"] == 1 assert CSV_TO_CLASS_MAPPING["InvoiceDueDate"] == 2 @@ -516,7 +518,7 @@ class TestCSVToClassMapping: assert CSV_TO_CLASS_MAPPING["Plusgiro"] == 5 assert CSV_TO_CLASS_MAPPING["Amount"] == 6 assert CSV_TO_CLASS_MAPPING["supplier_organisation_number"] == 7 - assert CSV_TO_CLASS_MAPPING["customer_number"] == 9 + assert CSV_TO_CLASS_MAPPING["customer_number"] == 8 # Fixed: was 9, model uses 8 def test_csv_mapping_matches_field_classes(self): """Test that CSV mapping is consistent with FIELD_CLASSES.""" diff --git a/tests/shared/fields/__init__.py b/tests/shared/fields/__init__.py new file mode 100644 index 0000000..3987cb8 --- /dev/null +++ b/tests/shared/fields/__init__.py @@ -0,0 +1 @@ +"""Tests for shared.fields module.""" diff --git a/tests/shared/fields/test_field_config.py b/tests/shared/fields/test_field_config.py new file mode 100644 index 0000000..d4fc76d --- /dev/null +++ b/tests/shared/fields/test_field_config.py @@ -0,0 +1,200 @@ +""" +Tests for field configuration - Single Source of Truth. + +These tests ensure consistency across all field definitions and prevent +accidental changes that could break model inference. + +CRITICAL: These tests verify that field definitions match the trained YOLO model. +If these tests fail, it likely means someone modified field IDs incorrectly. +""" + +import pytest + +from shared.fields import ( + FIELD_DEFINITIONS, + CLASS_NAMES, + FIELD_CLASSES, + FIELD_CLASS_IDS, + CLASS_TO_FIELD, + CSV_TO_CLASS_MAPPING, + TRAINING_FIELD_CLASSES, + NUM_CLASSES, + FieldDefinition, +) + + +class TestFieldDefinitionsIntegrity: + """Tests to ensure field definitions are complete and consistent.""" + + def test_exactly_10_field_definitions(self): + """Verify we have exactly 10 field classes (matching trained model).""" + assert len(FIELD_DEFINITIONS) == 10 + assert NUM_CLASSES == 10 + + def test_class_ids_are_sequential(self): + """Verify class IDs are 0-9 without gaps.""" + class_ids = {fd.class_id for fd in FIELD_DEFINITIONS} + assert class_ids == set(range(10)) + + def test_class_ids_are_unique(self): + """Verify no duplicate class IDs.""" + class_ids = [fd.class_id for fd in FIELD_DEFINITIONS] + assert len(class_ids) == len(set(class_ids)) + + def test_class_names_are_unique(self): + """Verify no duplicate class names.""" + class_names = [fd.class_name for fd in FIELD_DEFINITIONS] + assert len(class_names) == len(set(class_names)) + + def test_field_definition_is_immutable(self): + """Verify FieldDefinition is frozen (immutable).""" + fd = FIELD_DEFINITIONS[0] + with pytest.raises(AttributeError): + fd.class_id = 99 # type: ignore + + +class TestModelCompatibility: + """Tests to verify field definitions match the trained YOLO model. + + These exact values are read from runs/train/invoice_fields/weights/best.pt + and MUST NOT be changed without retraining the model. 
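+
+    To re-check these values against a trained checkpoint (illustrative sketch,
+    not part of the original change; assumes the ultralytics package is installed):
+
+        from ultralytics import YOLO
+        print(YOLO("runs/train/invoice_fields/weights/best.pt").names)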
+ """ + + # Expected model.names from best.pt - DO NOT CHANGE + EXPECTED_MODEL_NAMES = { + 0: "invoice_number", + 1: "invoice_date", + 2: "invoice_due_date", + 3: "ocr_number", + 4: "bankgiro", + 5: "plusgiro", + 6: "amount", + 7: "supplier_org_number", + 8: "customer_number", + 9: "payment_line", + } + + def test_field_classes_match_model(self): + """CRITICAL: Verify FIELD_CLASSES matches trained model exactly.""" + assert FIELD_CLASSES == self.EXPECTED_MODEL_NAMES + + def test_class_names_order_matches_model(self): + """CRITICAL: Verify CLASS_NAMES order matches model class IDs.""" + expected_order = [ + self.EXPECTED_MODEL_NAMES[i] for i in range(10) + ] + assert CLASS_NAMES == expected_order + + def test_customer_number_is_class_8(self): + """CRITICAL: customer_number must be class 8 (not 9).""" + assert FIELD_CLASS_IDS["customer_number"] == 8 + assert FIELD_CLASSES[8] == "customer_number" + + def test_payment_line_is_class_9(self): + """CRITICAL: payment_line must be class 9 (not 8).""" + assert FIELD_CLASS_IDS["payment_line"] == 9 + assert FIELD_CLASSES[9] == "payment_line" + + +class TestMappingConsistency: + """Tests to verify all mappings are consistent with each other.""" + + def test_field_classes_and_field_class_ids_are_inverses(self): + """Verify FIELD_CLASSES and FIELD_CLASS_IDS are proper inverses.""" + for class_id, class_name in FIELD_CLASSES.items(): + assert FIELD_CLASS_IDS[class_name] == class_id + + for class_name, class_id in FIELD_CLASS_IDS.items(): + assert FIELD_CLASSES[class_id] == class_name + + def test_class_names_matches_field_classes_values(self): + """Verify CLASS_NAMES list matches FIELD_CLASSES values in order.""" + for i, class_name in enumerate(CLASS_NAMES): + assert FIELD_CLASSES[i] == class_name + + def test_class_to_field_has_all_classes(self): + """Verify CLASS_TO_FIELD has mapping for all class names.""" + for class_name in CLASS_NAMES: + assert class_name in CLASS_TO_FIELD + + def test_csv_mapping_excludes_derived_fields(self): + """Verify CSV_TO_CLASS_MAPPING excludes derived fields like payment_line.""" + # payment_line is derived, should not be in CSV mapping + assert "payment_line" not in CSV_TO_CLASS_MAPPING + + # All non-derived fields should be in CSV mapping + for fd in FIELD_DEFINITIONS: + if not fd.is_derived: + assert fd.field_name in CSV_TO_CLASS_MAPPING + + def test_training_field_classes_includes_all(self): + """Verify TRAINING_FIELD_CLASSES includes all fields including derived.""" + for fd in FIELD_DEFINITIONS: + assert fd.field_name in TRAINING_FIELD_CLASSES + assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id + + +class TestSpecificFieldDefinitions: + """Tests for specific field definitions to catch common mistakes.""" + + @pytest.mark.parametrize( + "class_id,expected_class_name", + [ + (0, "invoice_number"), + (1, "invoice_date"), + (2, "invoice_due_date"), + (3, "ocr_number"), + (4, "bankgiro"), + (5, "plusgiro"), + (6, "amount"), + (7, "supplier_org_number"), + (8, "customer_number"), + (9, "payment_line"), + ], + ) + def test_class_id_to_name_mapping(self, class_id: int, expected_class_name: str): + """Verify each class ID maps to the correct class name.""" + assert FIELD_CLASSES[class_id] == expected_class_name + + def test_payment_line_is_derived(self): + """Verify payment_line is marked as derived.""" + payment_line_def = next( + fd for fd in FIELD_DEFINITIONS if fd.class_name == "payment_line" + ) + assert payment_line_def.is_derived is True + + def test_other_fields_are_not_derived(self): + """Verify all fields 
except payment_line are not derived.""" + for fd in FIELD_DEFINITIONS: + if fd.class_name != "payment_line": + assert fd.is_derived is False, f"{fd.class_name} should not be derived" + + +class TestBackwardCompatibility: + """Tests to ensure backward compatibility with existing code.""" + + def test_csv_to_class_mapping_field_names(self): + """Verify CSV_TO_CLASS_MAPPING uses correct field names.""" + # These are the field names used in CSV files + expected_fields = { + "InvoiceNumber": 0, + "InvoiceDate": 1, + "InvoiceDueDate": 2, + "OCR": 3, + "Bankgiro": 4, + "Plusgiro": 5, + "Amount": 6, + "supplier_organisation_number": 7, + "customer_number": 8, + # payment_line (9) is derived, not in CSV + } + assert CSV_TO_CLASS_MAPPING == expected_fields + + def test_class_to_field_returns_field_names(self): + """Verify CLASS_TO_FIELD maps class names to field names correctly.""" + # Sample checks for key fields + assert CLASS_TO_FIELD["invoice_number"] == "InvoiceNumber" + assert CLASS_TO_FIELD["invoice_date"] == "InvoiceDate" + assert CLASS_TO_FIELD["ocr_number"] == "OCR" + assert CLASS_TO_FIELD["customer_number"] == "customer_number" + assert CLASS_TO_FIELD["payment_line"] == "payment_line" diff --git a/tests/shared/storage/__init__.py b/tests/shared/storage/__init__.py new file mode 100644 index 0000000..6261261 --- /dev/null +++ b/tests/shared/storage/__init__.py @@ -0,0 +1 @@ +# Tests for storage module diff --git a/tests/shared/storage/test_azure.py b/tests/shared/storage/test_azure.py new file mode 100644 index 0000000..dccf7e2 --- /dev/null +++ b/tests/shared/storage/test_azure.py @@ -0,0 +1,718 @@ +""" +Tests for AzureBlobStorageBackend. + +TDD Phase 1: RED - Write tests first, then implement to pass. +Uses mocking to avoid requiring actual Azure credentials. +""" + +import tempfile +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + + +@pytest.fixture +def mock_blob_service_client() -> MagicMock: + """Create a mock BlobServiceClient.""" + return MagicMock() + + +@pytest.fixture +def mock_container_client(mock_blob_service_client: MagicMock) -> MagicMock: + """Create a mock ContainerClient.""" + container_client = MagicMock() + mock_blob_service_client.get_container_client.return_value = container_client + return container_client + + +@pytest.fixture +def mock_blob_client(mock_container_client: MagicMock) -> MagicMock: + """Create a mock BlobClient.""" + blob_client = MagicMock() + mock_container_client.get_blob_client.return_value = blob_client + return blob_client + + +class TestAzureBlobStorageBackendCreation: + """Tests for AzureBlobStorageBackend instantiation.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_with_connection_string( + self, mock_service_class: MagicMock + ) -> None: + """Test creating backend with connection string.""" + from shared.storage.azure import AzureBlobStorageBackend + + connection_string = "DefaultEndpointsProtocol=https;AccountName=test;..." 
+ backend = AzureBlobStorageBackend( + connection_string=connection_string, + container_name="training-images", + ) + + mock_service_class.from_connection_string.assert_called_once_with( + connection_string + ) + assert backend.container_name == "training-images" + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_creates_container_if_not_exists( + self, mock_service_class: MagicMock + ) -> None: + """Test that container is created if it doesn't exist.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_container.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="new-container", + create_container=True, + ) + + mock_container.create_container.assert_called_once() + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_does_not_create_container_by_default( + self, mock_service_class: MagicMock + ) -> None: + """Test that container is not created by default.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_container.exists.return_value = True + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="existing-container", + ) + + mock_container.create_container.assert_not_called() + + @patch("shared.storage.azure.BlobServiceClient") + def test_is_storage_backend_subclass( + self, mock_service_class: MagicMock + ) -> None: + """Test that AzureBlobStorageBackend is a StorageBackend.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import StorageBackend + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + assert isinstance(backend, StorageBackend) + + +class TestAzureBlobStorageBackendUpload: + """Tests for AzureBlobStorageBackend.upload method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_file(self, mock_service_class: MagicMock) -> None: + """Test uploading a file.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: + f.write(b"Hello, World!") + temp_path = Path(f.name) + + try: + result = backend.upload(temp_path, "uploads/sample.txt") + + assert result == "uploads/sample.txt" + mock_container.get_blob_client.assert_called_with("uploads/sample.txt") + mock_blob.upload_blob.assert_called_once() + finally: + temp_path.unlink() + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_fails_if_blob_exists_without_overwrite( + self, mock_service_class: MagicMock + ) -> None: + """Test that upload fails if blob exists and overwrite is False.""" + from 
shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import StorageError + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: + f.write(b"content") + temp_path = Path(f.name) + + try: + with pytest.raises(StorageError, match="already exists"): + backend.upload(temp_path, "existing.txt", overwrite=False) + finally: + temp_path.unlink() + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_succeeds_with_overwrite( + self, mock_service_class: MagicMock + ) -> None: + """Test that upload succeeds with overwrite=True.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: + f.write(b"content") + temp_path = Path(f.name) + + try: + result = backend.upload(temp_path, "existing.txt", overwrite=True) + + assert result == "existing.txt" + mock_blob.upload_blob.assert_called_once() + # Check overwrite=True was passed + call_kwargs = mock_blob.upload_blob.call_args[1] + assert call_kwargs.get("overwrite") is True + finally: + temp_path.unlink() + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_nonexistent_file_fails( + self, mock_service_class: MagicMock + ) -> None: + """Test that uploading nonexistent file fails.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import FileNotFoundStorageError + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with pytest.raises(FileNotFoundStorageError): + backend.upload(Path("/nonexistent/file.txt"), "sample.txt") + + +class TestAzureBlobStorageBackendDownload: + """Tests for AzureBlobStorageBackend.download method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_download_file(self, mock_service_class: MagicMock) -> None: + """Test downloading a file.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + # Mock download_blob to return stream + mock_stream = MagicMock() + mock_stream.readall.return_value = b"Hello, World!" 
+ mock_blob.download_blob.return_value = mock_stream + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) / "downloaded.txt" + result = backend.download("remote/sample.txt", local_path) + + assert result == local_path + assert local_path.exists() + assert local_path.read_bytes() == b"Hello, World!" + + @patch("shared.storage.azure.BlobServiceClient") + def test_download_creates_parent_directories( + self, mock_service_class: MagicMock + ) -> None: + """Test that download creates parent directories.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + mock_stream = MagicMock() + mock_stream.readall.return_value = b"content" + mock_blob.download_blob.return_value = mock_stream + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) / "deep" / "nested" / "downloaded.txt" + result = backend.download("sample.txt", local_path) + + assert local_path.exists() + + @patch("shared.storage.azure.BlobServiceClient") + def test_download_nonexistent_blob_fails( + self, mock_service_class: MagicMock + ) -> None: + """Test that downloading nonexistent blob fails.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import FileNotFoundStorageError + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"): + backend.download("nonexistent.txt", Path("/tmp/file.txt")) + + +class TestAzureBlobStorageBackendExists: + """Tests for AzureBlobStorageBackend.exists method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_exists_returns_true_for_existing_blob( + self, mock_service_class: MagicMock + ) -> None: + """Test exists returns True for existing blob.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + assert backend.exists("existing.txt") is True + + @patch("shared.storage.azure.BlobServiceClient") + def test_exists_returns_false_for_nonexistent_blob( + self, mock_service_class: MagicMock + ) -> None: + """Test exists returns False for nonexistent blob.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + 
mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + assert backend.exists("nonexistent.txt") is False + + +class TestAzureBlobStorageBackendListFiles: + """Tests for AzureBlobStorageBackend.list_files method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_list_files_empty_container( + self, mock_service_class: MagicMock + ) -> None: + """Test listing files in empty container.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_container.list_blobs.return_value = [] + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + assert backend.list_files("") == [] + + @patch("shared.storage.azure.BlobServiceClient") + def test_list_files_returns_all_blobs( + self, mock_service_class: MagicMock + ) -> None: + """Test listing all blobs.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + + # Create mock blob items + mock_blob1 = MagicMock() + mock_blob1.name = "file1.txt" + mock_blob2 = MagicMock() + mock_blob2.name = "file2.txt" + mock_blob3 = MagicMock() + mock_blob3.name = "subdir/file3.txt" + mock_container.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3] + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + files = backend.list_files("") + + assert len(files) == 3 + assert "file1.txt" in files + assert "file2.txt" in files + assert "subdir/file3.txt" in files + + @patch("shared.storage.azure.BlobServiceClient") + def test_list_files_with_prefix( + self, mock_service_class: MagicMock + ) -> None: + """Test listing files with prefix filter.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + + mock_blob1 = MagicMock() + mock_blob1.name = "images/a.png" + mock_blob2 = MagicMock() + mock_blob2.name = "images/b.png" + mock_container.list_blobs.return_value = [mock_blob1, mock_blob2] + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + files = backend.list_files("images/") + + mock_container.list_blobs.assert_called_with(name_starts_with="images/") + assert len(files) == 2 + + +class TestAzureBlobStorageBackendDelete: + """Tests for AzureBlobStorageBackend.delete method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_delete_existing_blob( + self, mock_service_class: MagicMock + ) -> None: + """Test deleting an existing blob.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = 
mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + result = backend.delete("sample.txt") + + assert result is True + mock_blob.delete_blob.assert_called_once() + + @patch("shared.storage.azure.BlobServiceClient") + def test_delete_nonexistent_blob_returns_false( + self, mock_service_class: MagicMock + ) -> None: + """Test deleting nonexistent blob returns False.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + result = backend.delete("nonexistent.txt") + + assert result is False + mock_blob.delete_blob.assert_not_called() + + +class TestAzureBlobStorageBackendGetUrl: + """Tests for AzureBlobStorageBackend.get_url method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_get_url_returns_blob_url( + self, mock_service_class: MagicMock + ) -> None: + """Test get_url returns blob URL.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + mock_blob.url = "https://account.blob.core.windows.net/container/sample.txt" + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + url = backend.get_url("sample.txt") + + assert url == "https://account.blob.core.windows.net/container/sample.txt" + + @patch("shared.storage.azure.BlobServiceClient") + def test_get_url_nonexistent_blob_fails( + self, mock_service_class: MagicMock + ) -> None: + """Test get_url for nonexistent blob fails.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import FileNotFoundStorageError + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with pytest.raises(FileNotFoundStorageError): + backend.get_url("nonexistent.txt") + + +class TestAzureBlobStorageBackendUploadBytes: + """Tests for AzureBlobStorageBackend.upload_bytes method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_bytes(self, mock_service_class: MagicMock) -> None: + """Test uploading bytes directly.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container 
= MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + data = b"Binary content here" + result = backend.upload_bytes(data, "binary.dat") + + assert result == "binary.dat" + mock_blob.upload_blob.assert_called_once() + + +class TestAzureBlobStorageBackendDownloadBytes: + """Tests for AzureBlobStorageBackend.download_bytes method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_download_bytes(self, mock_service_class: MagicMock) -> None: + """Test downloading blob as bytes.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = True + + mock_stream = MagicMock() + mock_stream.readall.return_value = b"Hello, World!" + mock_blob.download_blob.return_value = mock_stream + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + data = backend.download_bytes("sample.txt") + + assert data == b"Hello, World!" + + @patch("shared.storage.azure.BlobServiceClient") + def test_download_bytes_nonexistent( + self, mock_service_class: MagicMock + ) -> None: + """Test downloading nonexistent blob as bytes.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import FileNotFoundStorageError + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with pytest.raises(FileNotFoundStorageError): + backend.download_bytes("nonexistent.txt") + + +class TestAzureBlobStorageBackendBatchOperations: + """Tests for batch operations in AzureBlobStorageBackend.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_upload_directory(self, mock_service_class: MagicMock) -> None: + """Test uploading an entire directory.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + mock_blob = MagicMock() + mock_container.get_blob_client.return_value = mock_blob + mock_blob.exists.return_value = False + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + (temp_path / "file1.txt").write_text("content1") + (temp_path / "subdir").mkdir() + (temp_path / "subdir" / "file2.txt").write_text("content2") + + results = backend.upload_directory(temp_path, "uploads/") + + assert len(results) == 2 + assert "uploads/file1.txt" in results + assert "uploads/subdir/file2.txt" in results + + 
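Taken together, the mocks above spell out the `azure-storage-blob` calls the backend is expected to make: `BlobServiceClient.from_connection_string`, a container client, per-blob clients, `upload_blob`, `download_blob().readall()`, `delete_blob`, and `list_blobs(name_starts_with=...)`. A minimal sketch that would satisfy those expectations, with the byte-level and directory helpers omitted for brevity; it is inferred from the tests, not the shipped implementation:

```python
# Sketch of the core of shared/storage/azure.py; details beyond the asserted
# behaviour are assumptions.
from pathlib import Path

from azure.storage.blob import BlobServiceClient

from shared.storage.base import FileNotFoundStorageError, StorageBackend, StorageError


class AzureBlobStorageBackend(StorageBackend):
    def __init__(
        self,
        connection_string: str,
        container_name: str,
        create_container: bool = False,
    ) -> None:
        self.container_name = container_name
        self._service = BlobServiceClient.from_connection_string(connection_string)
        self._container = self._service.get_container_client(container_name)
        if create_container and not self._container.exists():
            self._container.create_container()

    def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
        if not local_path.exists():
            raise FileNotFoundStorageError(str(local_path))
        blob = self._container.get_blob_client(remote_path)
        if not overwrite and blob.exists():
            raise StorageError(f"{remote_path} already exists")
        with local_path.open("rb") as handle:
            blob.upload_blob(handle, overwrite=overwrite)
        return remote_path

    def download(self, remote_path: str, local_path: Path) -> Path:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            raise FileNotFoundStorageError(remote_path)
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_bytes(blob.download_blob().readall())
        return local_path

    def exists(self, remote_path: str) -> bool:
        return bool(self._container.get_blob_client(remote_path).exists())

    def list_files(self, prefix: str) -> list[str]:
        return [blob.name for blob in self._container.list_blobs(name_starts_with=prefix)]

    def delete(self, remote_path: str) -> bool:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            return False
        blob.delete_blob()
        return True

    def get_url(self, remote_path: str) -> str:
        blob = self._container.get_blob_client(remote_path)
        if not blob.exists():
            raise FileNotFoundStorageError(remote_path)
        return blob.url
```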
@patch("shared.storage.azure.BlobServiceClient") + def test_download_directory(self, mock_service_class: MagicMock) -> None: + """Test downloading blobs matching a prefix.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_service = MagicMock() + mock_service_class.from_connection_string.return_value = mock_service + mock_container = MagicMock() + mock_service.get_container_client.return_value = mock_container + + # Mock blob listing + mock_blob1 = MagicMock() + mock_blob1.name = "images/a.png" + mock_blob2 = MagicMock() + mock_blob2.name = "images/b.png" + mock_container.list_blobs.return_value = [mock_blob1, mock_blob2] + + # Mock blob clients + mock_blob_client = MagicMock() + mock_container.get_blob_client.return_value = mock_blob_client + mock_blob_client.exists.return_value = True + mock_stream = MagicMock() + mock_stream.readall.return_value = b"image content" + mock_blob_client.download_blob.return_value = mock_stream + + backend = AzureBlobStorageBackend( + connection_string="connection_string", + container_name="container", + ) + + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) + results = backend.download_directory("images/", local_path) + + assert len(results) == 2 + # Files should be created relative to prefix + assert (local_path / "a.png").exists() or (local_path / "images" / "a.png").exists() diff --git a/tests/shared/storage/test_base.py b/tests/shared/storage/test_base.py new file mode 100644 index 0000000..dafaf80 --- /dev/null +++ b/tests/shared/storage/test_base.py @@ -0,0 +1,301 @@ +""" +Tests for storage base module. + +TDD Phase 1: RED - Write tests first, then implement to pass. +""" + +from abc import ABC +from pathlib import Path +from typing import BinaryIO +from unittest.mock import MagicMock, patch + +import pytest + + +class TestStorageBackendInterface: + """Tests for StorageBackend abstract base class.""" + + def test_cannot_instantiate_directly(self) -> None: + """Test that StorageBackend cannot be instantiated.""" + from shared.storage.base import StorageBackend + + with pytest.raises(TypeError): + StorageBackend() # type: ignore + + def test_is_abstract_base_class(self) -> None: + """Test that StorageBackend is an ABC.""" + from shared.storage.base import StorageBackend + + assert issubclass(StorageBackend, ABC) + + def test_subclass_must_implement_upload(self) -> None: + """Test that subclass must implement upload method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_subclass_must_implement_download(self) -> None: + """Test that subclass must implement download method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + with 
pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_subclass_must_implement_exists(self) -> None: + """Test that subclass must implement exists method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_subclass_must_implement_list_files(self) -> None: + """Test that subclass must implement list_files method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_subclass_must_implement_delete(self) -> None: + """Test that subclass must implement delete method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def get_url(self, remote_path: str) -> str: + return "" + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_subclass_must_implement_get_url(self) -> None: + """Test that subclass must implement get_url method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_valid_subclass_can_be_instantiated(self) -> None: + """Test that a complete subclass can be instantiated.""" + from shared.storage.base import StorageBackend + + class CompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + def get_presigned_url( + self, remote_path: str, expires_in_seconds: int = 3600 + ) -> str: + return "" + + backend = 
CompleteBackend() + assert isinstance(backend, StorageBackend) + + +class TestStorageError: + """Tests for StorageError exception.""" + + def test_storage_error_is_exception(self) -> None: + """Test that StorageError is an Exception.""" + from shared.storage.base import StorageError + + assert issubclass(StorageError, Exception) + + def test_storage_error_with_message(self) -> None: + """Test StorageError with message.""" + from shared.storage.base import StorageError + + error = StorageError("Upload failed") + assert str(error) == "Upload failed" + + def test_storage_error_can_be_raised(self) -> None: + """Test that StorageError can be raised and caught.""" + from shared.storage.base import StorageError + + with pytest.raises(StorageError, match="test error"): + raise StorageError("test error") + + +class TestFileNotFoundError: + """Tests for FileNotFoundStorageError exception.""" + + def test_file_not_found_is_storage_error(self) -> None: + """Test that FileNotFoundStorageError is a StorageError.""" + from shared.storage.base import FileNotFoundStorageError, StorageError + + assert issubclass(FileNotFoundStorageError, StorageError) + + def test_file_not_found_with_path(self) -> None: + """Test FileNotFoundStorageError with path.""" + from shared.storage.base import FileNotFoundStorageError + + error = FileNotFoundStorageError("images/test.png") + assert "images/test.png" in str(error) + + +class TestStorageConfig: + """Tests for StorageConfig dataclass.""" + + def test_storage_config_creation(self) -> None: + """Test creating StorageConfig.""" + from shared.storage.base import StorageConfig + + config = StorageConfig( + backend_type="azure_blob", + connection_string="DefaultEndpointsProtocol=https;...", + container_name="training-images", + ) + + assert config.backend_type == "azure_blob" + assert config.connection_string == "DefaultEndpointsProtocol=https;..." + assert config.container_name == "training-images" + + def test_storage_config_defaults(self) -> None: + """Test StorageConfig with defaults.""" + from shared.storage.base import StorageConfig + + config = StorageConfig(backend_type="local") + + assert config.backend_type == "local" + assert config.connection_string is None + assert config.container_name is None + assert config.base_path is None + + def test_storage_config_with_base_path(self) -> None: + """Test StorageConfig with base_path for local backend.""" + from shared.storage.base import StorageConfig + + config = StorageConfig( + backend_type="local", + base_path=Path("/data/images"), + ) + + assert config.backend_type == "local" + assert config.base_path == Path("/data/images") + + def test_storage_config_immutable(self) -> None: + """Test that StorageConfig is immutable (frozen).""" + from shared.storage.base import StorageConfig + + config = StorageConfig(backend_type="local") + + with pytest.raises(AttributeError): + config.backend_type = "azure_blob" # type: ignore diff --git a/tests/shared/storage/test_config_loader.py b/tests/shared/storage/test_config_loader.py new file mode 100644 index 0000000..8f4c909 --- /dev/null +++ b/tests/shared/storage/test_config_loader.py @@ -0,0 +1,348 @@ +""" +Tests for storage configuration file loader. + +TDD Phase 1: RED - Write tests first, then implement to pass. 
+""" + +import os +import shutil +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def temp_dir() -> Path: + """Create a temporary directory for tests.""" + temp_dir = Path(tempfile.mkdtemp()) + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +class TestEnvVarSubstitution: + """Tests for environment variable substitution in config values.""" + + def test_substitute_simple_env_var(self) -> None: + """Test substituting a simple environment variable.""" + from shared.storage.config_loader import substitute_env_vars + + with patch.dict(os.environ, {"MY_VAR": "my_value"}): + result = substitute_env_vars("${MY_VAR}") + assert result == "my_value" + + def test_substitute_env_var_with_default(self) -> None: + """Test substituting env var with default when var is not set.""" + from shared.storage.config_loader import substitute_env_vars + + # Ensure var is not set + os.environ.pop("UNSET_VAR", None) + + result = substitute_env_vars("${UNSET_VAR:-default_value}") + assert result == "default_value" + + def test_substitute_env_var_ignores_default_when_set(self) -> None: + """Test that default is ignored when env var is set.""" + from shared.storage.config_loader import substitute_env_vars + + with patch.dict(os.environ, {"SET_VAR": "actual_value"}): + result = substitute_env_vars("${SET_VAR:-default_value}") + assert result == "actual_value" + + def test_substitute_multiple_env_vars(self) -> None: + """Test substituting multiple env vars in one string.""" + from shared.storage.config_loader import substitute_env_vars + + with patch.dict(os.environ, {"HOST": "localhost", "PORT": "5432"}): + result = substitute_env_vars("postgres://${HOST}:${PORT}/db") + assert result == "postgres://localhost:5432/db" + + def test_substitute_preserves_non_env_text(self) -> None: + """Test that non-env-var text is preserved.""" + from shared.storage.config_loader import substitute_env_vars + + with patch.dict(os.environ, {"VAR": "value"}): + result = substitute_env_vars("prefix_${VAR}_suffix") + assert result == "prefix_value_suffix" + + def test_substitute_empty_string_when_not_set_and_no_default(self) -> None: + """Test that empty string is returned when var not set and no default.""" + from shared.storage.config_loader import substitute_env_vars + + os.environ.pop("MISSING_VAR", None) + + result = substitute_env_vars("${MISSING_VAR}") + assert result == "" + + +class TestLoadStorageConfigYaml: + """Tests for loading storage configuration from YAML files.""" + + def test_load_local_backend_config(self, temp_dir: Path) -> None: + """Test loading configuration for local backend.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +backend: local +presigned_url_expiry: 3600 + +local: + base_path: ./data/storage +""") + + config = load_storage_config(config_path) + + assert config.backend_type == "local" + assert config.presigned_url_expiry == 3600 + assert config.local is not None + assert config.local.base_path == Path("./data/storage") + + def test_load_azure_backend_config(self, temp_dir: Path) -> None: + """Test loading configuration for Azure backend.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +backend: azure_blob +presigned_url_expiry: 7200 + +azure: + connection_string: DefaultEndpointsProtocol=https;AccountName=test + container_name: documents + 
create_container: true +""") + + config = load_storage_config(config_path) + + assert config.backend_type == "azure_blob" + assert config.presigned_url_expiry == 7200 + assert config.azure is not None + assert config.azure.connection_string == "DefaultEndpointsProtocol=https;AccountName=test" + assert config.azure.container_name == "documents" + assert config.azure.create_container is True + + def test_load_s3_backend_config(self, temp_dir: Path) -> None: + """Test loading configuration for S3 backend.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +backend: s3 +presigned_url_expiry: 1800 + +s3: + bucket_name: my-bucket + region_name: us-west-2 + endpoint_url: http://localhost:9000 + create_bucket: false +""") + + config = load_storage_config(config_path) + + assert config.backend_type == "s3" + assert config.presigned_url_expiry == 1800 + assert config.s3 is not None + assert config.s3.bucket_name == "my-bucket" + assert config.s3.region_name == "us-west-2" + assert config.s3.endpoint_url == "http://localhost:9000" + assert config.s3.create_bucket is False + + def test_load_config_with_env_var_substitution(self, temp_dir: Path) -> None: + """Test that environment variables are substituted in config.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +backend: ${STORAGE_BACKEND:-local} + +local: + base_path: ${STORAGE_PATH:-./default/path} +""") + + with patch.dict(os.environ, {"STORAGE_BACKEND": "local", "STORAGE_PATH": "/custom/path"}): + config = load_storage_config(config_path) + + assert config.backend_type == "local" + assert config.local is not None + assert config.local.base_path == Path("/custom/path") + + def test_load_config_file_not_found_raises(self, temp_dir: Path) -> None: + """Test that FileNotFoundError is raised for missing config file.""" + from shared.storage.config_loader import load_storage_config + + with pytest.raises(FileNotFoundError): + load_storage_config(temp_dir / "nonexistent.yaml") + + def test_load_config_invalid_yaml_raises(self, temp_dir: Path) -> None: + """Test that ValueError is raised for invalid YAML.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text("invalid: yaml: content: [") + + with pytest.raises(ValueError, match="Invalid"): + load_storage_config(config_path) + + def test_load_config_missing_backend_raises(self, temp_dir: Path) -> None: + """Test that ValueError is raised when backend is missing.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +local: + base_path: ./data +""") + + with pytest.raises(ValueError, match="backend"): + load_storage_config(config_path) + + def test_load_config_default_presigned_url_expiry(self, temp_dir: Path) -> None: + """Test default presigned_url_expiry when not specified.""" + from shared.storage.config_loader import load_storage_config + + config_path = temp_dir / "storage.yaml" + config_path.write_text(""" +backend: local + +local: + base_path: ./data +""") + + config = load_storage_config(config_path) + + assert config.presigned_url_expiry == 3600 # Default value + + +class TestStorageFileConfig: + """Tests for StorageFileConfig dataclass.""" + + def test_storage_file_config_is_immutable(self) -> None: + """Test that StorageFileConfig is frozen (immutable).""" 
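The `TestEnvVarSubstitution` cases above fully specify the `${VAR}` / `${VAR:-default}` grammar, so a small regex pass is enough. A sketch under that assumption (unset variables become an empty string, literal text is preserved):

```python
# Sketch of substitute_env_vars matching the behaviour asserted above.
import os
import re

_ENV_PATTERN = re.compile(
    r"\$\{(?P<name>[A-Za-z_][A-Za-z0-9_]*)(?::-(?P<default>[^}]*))?\}"
)


def substitute_env_vars(value: str) -> str:
    def _replace(match: re.Match[str]) -> str:
        name = match.group("name")
        default = match.group("default")
        # Unset variable: fall back to the ":-" default, or "" when none is given.
        return os.environ.get(name, default if default is not None else "")

    return _ENV_PATTERN.sub(_replace, value)
```

`load_storage_config` would presumably apply this to every string scalar after `yaml.safe_load` and then build the frozen config dataclasses from the result.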
+ from shared.storage.config_loader import StorageFileConfig + + config = StorageFileConfig(backend_type="local") + + with pytest.raises(AttributeError): + config.backend_type = "azure_blob" # type: ignore + + def test_storage_file_config_defaults(self) -> None: + """Test StorageFileConfig default values.""" + from shared.storage.config_loader import StorageFileConfig + + config = StorageFileConfig(backend_type="local") + + assert config.backend_type == "local" + assert config.local is None + assert config.azure is None + assert config.s3 is None + assert config.presigned_url_expiry == 3600 + + +class TestLocalConfig: + """Tests for LocalConfig dataclass.""" + + def test_local_config_creation(self) -> None: + """Test creating LocalConfig.""" + from shared.storage.config_loader import LocalConfig + + config = LocalConfig(base_path=Path("/data/storage")) + + assert config.base_path == Path("/data/storage") + + def test_local_config_is_immutable(self) -> None: + """Test that LocalConfig is frozen.""" + from shared.storage.config_loader import LocalConfig + + config = LocalConfig(base_path=Path("/data")) + + with pytest.raises(AttributeError): + config.base_path = Path("/other") # type: ignore + + +class TestAzureConfig: + """Tests for AzureConfig dataclass.""" + + def test_azure_config_creation(self) -> None: + """Test creating AzureConfig.""" + from shared.storage.config_loader import AzureConfig + + config = AzureConfig( + connection_string="test_connection", + container_name="test_container", + create_container=True, + ) + + assert config.connection_string == "test_connection" + assert config.container_name == "test_container" + assert config.create_container is True + + def test_azure_config_defaults(self) -> None: + """Test AzureConfig default values.""" + from shared.storage.config_loader import AzureConfig + + config = AzureConfig( + connection_string="conn", + container_name="container", + ) + + assert config.create_container is False + + def test_azure_config_is_immutable(self) -> None: + """Test that AzureConfig is frozen.""" + from shared.storage.config_loader import AzureConfig + + config = AzureConfig( + connection_string="conn", + container_name="container", + ) + + with pytest.raises(AttributeError): + config.container_name = "other" # type: ignore + + +class TestS3Config: + """Tests for S3Config dataclass.""" + + def test_s3_config_creation(self) -> None: + """Test creating S3Config.""" + from shared.storage.config_loader import S3Config + + config = S3Config( + bucket_name="my-bucket", + region_name="us-east-1", + access_key_id="AKIAIOSFODNN7EXAMPLE", + secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + endpoint_url="http://localhost:9000", + create_bucket=True, + ) + + assert config.bucket_name == "my-bucket" + assert config.region_name == "us-east-1" + assert config.access_key_id == "AKIAIOSFODNN7EXAMPLE" + assert config.secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + assert config.endpoint_url == "http://localhost:9000" + assert config.create_bucket is True + + def test_s3_config_minimal(self) -> None: + """Test S3Config with only required fields.""" + from shared.storage.config_loader import S3Config + + config = S3Config(bucket_name="bucket") + + assert config.bucket_name == "bucket" + assert config.region_name is None + assert config.access_key_id is None + assert config.secret_access_key is None + assert config.endpoint_url is None + assert config.create_bucket is False + + def test_s3_config_is_immutable(self) -> None: + """Test that 
S3Config is frozen.""" + from shared.storage.config_loader import S3Config + + config = S3Config(bucket_name="bucket") + + with pytest.raises(AttributeError): + config.bucket_name = "other" # type: ignore diff --git a/tests/shared/storage/test_factory.py b/tests/shared/storage/test_factory.py new file mode 100644 index 0000000..ebd6464 --- /dev/null +++ b/tests/shared/storage/test_factory.py @@ -0,0 +1,423 @@ +""" +Tests for storage factory. + +TDD Phase 1: RED - Write tests first, then implement to pass. +""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +class TestStorageFactory: + """Tests for create_storage_backend factory function.""" + + def test_create_local_backend(self) -> None: + """Test creating local storage backend.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + from shared.storage.local import LocalStorageBackend + + with tempfile.TemporaryDirectory() as temp_dir: + config = StorageConfig( + backend_type="local", + base_path=Path(temp_dir), + ) + + backend = create_storage_backend(config) + + assert isinstance(backend, LocalStorageBackend) + assert backend.base_path == Path(temp_dir) + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_azure_backend(self, mock_service_class: MagicMock) -> None: + """Test creating Azure blob storage backend.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig( + backend_type="azure_blob", + connection_string="DefaultEndpointsProtocol=https;...", + container_name="training-images", + ) + + backend = create_storage_backend(config) + + assert isinstance(backend, AzureBlobStorageBackend) + + def test_create_unknown_backend_raises(self) -> None: + """Test that unknown backend type raises ValueError.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig(backend_type="unknown_backend") + + with pytest.raises(ValueError, match="Unknown storage backend"): + create_storage_backend(config) + + def test_create_local_requires_base_path(self) -> None: + """Test that local backend requires base_path.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig(backend_type="local") + + with pytest.raises(ValueError, match="base_path"): + create_storage_backend(config) + + def test_create_azure_requires_connection_string(self) -> None: + """Test that Azure backend requires connection_string.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig( + backend_type="azure_blob", + container_name="container", + ) + + with pytest.raises(ValueError, match="connection_string"): + create_storage_backend(config) + + def test_create_azure_requires_container_name(self) -> None: + """Test that Azure backend requires container_name.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig( + backend_type="azure_blob", + connection_string="connection_string", + ) + + with pytest.raises(ValueError, match="container_name"): + create_storage_backend(config) + + +class TestStorageFactoryFromEnv: + """Tests for create_storage_backend_from_env factory 
function.""" + + def test_create_from_env_local(self) -> None: + """Test creating local backend from environment variables.""" + from shared.storage.factory import create_storage_backend_from_env + from shared.storage.local import LocalStorageBackend + + with tempfile.TemporaryDirectory() as temp_dir: + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": temp_dir, + } + + with patch.dict(os.environ, env, clear=False): + backend = create_storage_backend_from_env() + + assert isinstance(backend, LocalStorageBackend) + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_from_env_azure(self, mock_service_class: MagicMock) -> None: + """Test creating Azure backend from environment variables.""" + from shared.storage.azure import AzureBlobStorageBackend + from shared.storage.factory import create_storage_backend_from_env + + env = { + "STORAGE_BACKEND": "azure_blob", + "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...", + "AZURE_STORAGE_CONTAINER": "training-images", + } + + with patch.dict(os.environ, env, clear=False): + backend = create_storage_backend_from_env() + + assert isinstance(backend, AzureBlobStorageBackend) + + def test_create_from_env_defaults_to_local(self) -> None: + """Test that factory defaults to local backend.""" + from shared.storage.factory import create_storage_backend_from_env + from shared.storage.local import LocalStorageBackend + + with tempfile.TemporaryDirectory() as temp_dir: + env = { + "STORAGE_BASE_PATH": temp_dir, + } + + # Remove STORAGE_BACKEND if present + with patch.dict(os.environ, env, clear=False): + if "STORAGE_BACKEND" in os.environ: + del os.environ["STORAGE_BACKEND"] + backend = create_storage_backend_from_env() + + assert isinstance(backend, LocalStorageBackend) + + def test_create_from_env_missing_azure_vars(self) -> None: + """Test error when Azure env vars are missing.""" + from shared.storage.factory import create_storage_backend_from_env + + env = { + "STORAGE_BACKEND": "azure_blob", + # Missing AZURE_STORAGE_CONNECTION_STRING + } + + with patch.dict(os.environ, env, clear=False): + # Remove the connection string if present + if "AZURE_STORAGE_CONNECTION_STRING" in os.environ: + del os.environ["AZURE_STORAGE_CONNECTION_STRING"] + + with pytest.raises(ValueError, match="AZURE_STORAGE_CONNECTION_STRING"): + create_storage_backend_from_env() + + +class TestGetDefaultStorageConfig: + """Tests for get_default_storage_config function.""" + + def test_get_default_config_local(self) -> None: + """Test getting default local config.""" + from shared.storage.factory import get_default_storage_config + + with tempfile.TemporaryDirectory() as temp_dir: + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": temp_dir, + } + + with patch.dict(os.environ, env, clear=False): + config = get_default_storage_config() + + assert config.backend_type == "local" + assert config.base_path == Path(temp_dir) + + def test_get_default_config_azure(self) -> None: + """Test getting default Azure config.""" + from shared.storage.factory import get_default_storage_config + + env = { + "STORAGE_BACKEND": "azure_blob", + "AZURE_STORAGE_CONNECTION_STRING": "DefaultEndpointsProtocol=https;...", + "AZURE_STORAGE_CONTAINER": "training-images", + } + + with patch.dict(os.environ, env, clear=False): + config = get_default_storage_config() + + assert config.backend_type == "azure_blob" + assert config.connection_string == "DefaultEndpointsProtocol=https;..." 
+ assert config.container_name == "training-images" + + +class TestStorageFactoryS3: + """Tests for S3 backend support in factory.""" + + @patch("boto3.client") + def test_create_s3_backend(self, mock_boto3_client: MagicMock) -> None: + """Test creating S3 storage backend.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + from shared.storage.s3 import S3StorageBackend + + config = StorageConfig( + backend_type="s3", + bucket_name="test-bucket", + region_name="us-west-2", + ) + + backend = create_storage_backend(config) + + assert isinstance(backend, S3StorageBackend) + + def test_create_s3_requires_bucket_name(self) -> None: + """Test that S3 backend requires bucket_name.""" + from shared.storage.base import StorageConfig + from shared.storage.factory import create_storage_backend + + config = StorageConfig( + backend_type="s3", + region_name="us-west-2", + ) + + with pytest.raises(ValueError, match="bucket_name"): + create_storage_backend(config) + + @patch("boto3.client") + def test_create_from_env_s3(self, mock_boto3_client: MagicMock) -> None: + """Test creating S3 backend from environment variables.""" + from shared.storage.factory import create_storage_backend_from_env + from shared.storage.s3 import S3StorageBackend + + env = { + "STORAGE_BACKEND": "s3", + "AWS_S3_BUCKET": "test-bucket", + "AWS_REGION": "us-east-1", + } + + with patch.dict(os.environ, env, clear=False): + backend = create_storage_backend_from_env() + + assert isinstance(backend, S3StorageBackend) + + def test_create_from_env_s3_missing_bucket(self) -> None: + """Test error when S3 bucket env var is missing.""" + from shared.storage.factory import create_storage_backend_from_env + + env = { + "STORAGE_BACKEND": "s3", + # Missing AWS_S3_BUCKET + } + + with patch.dict(os.environ, env, clear=False): + if "AWS_S3_BUCKET" in os.environ: + del os.environ["AWS_S3_BUCKET"] + + with pytest.raises(ValueError, match="AWS_S3_BUCKET"): + create_storage_backend_from_env() + + def test_get_default_config_s3(self) -> None: + """Test getting default S3 config.""" + from shared.storage.factory import get_default_storage_config + + env = { + "STORAGE_BACKEND": "s3", + "AWS_S3_BUCKET": "test-bucket", + "AWS_REGION": "us-west-2", + "AWS_ENDPOINT_URL": "http://localhost:9000", + } + + with patch.dict(os.environ, env, clear=False): + config = get_default_storage_config() + + assert config.backend_type == "s3" + assert config.bucket_name == "test-bucket" + assert config.region_name == "us-west-2" + assert config.endpoint_url == "http://localhost:9000" + + +class TestStorageFactoryFromFile: + """Tests for create_storage_backend_from_file factory function.""" + + def test_create_from_yaml_file_local(self, tmp_path: Path) -> None: + """Test creating local backend from YAML config file.""" + from shared.storage.factory import create_storage_backend_from_file + from shared.storage.local import LocalStorageBackend + + config_file = tmp_path / "storage.yaml" + storage_path = tmp_path / "storage" + config_file.write_text(f""" +backend: local + +local: + base_path: {storage_path} +""") + + backend = create_storage_backend_from_file(config_file) + + assert isinstance(backend, LocalStorageBackend) + + @patch("shared.storage.azure.BlobServiceClient") + def test_create_from_yaml_file_azure( + self, mock_service_class: MagicMock, tmp_path: Path + ) -> None: + """Test creating Azure backend from YAML config file.""" + from shared.storage.azure import AzureBlobStorageBackend + from 
shared.storage.factory import create_storage_backend_from_file + + config_file = tmp_path / "storage.yaml" + config_file.write_text(""" +backend: azure_blob + +azure: + connection_string: DefaultEndpointsProtocol=https;AccountName=test + container_name: documents +""") + + backend = create_storage_backend_from_file(config_file) + + assert isinstance(backend, AzureBlobStorageBackend) + + @patch("boto3.client") + def test_create_from_yaml_file_s3( + self, mock_boto3_client: MagicMock, tmp_path: Path + ) -> None: + """Test creating S3 backend from YAML config file.""" + from shared.storage.factory import create_storage_backend_from_file + from shared.storage.s3 import S3StorageBackend + + config_file = tmp_path / "storage.yaml" + config_file.write_text(""" +backend: s3 + +s3: + bucket_name: my-bucket + region_name: us-east-1 +""") + + backend = create_storage_backend_from_file(config_file) + + assert isinstance(backend, S3StorageBackend) + + def test_create_from_file_with_env_substitution(self, tmp_path: Path) -> None: + """Test that env vars are substituted in config file.""" + from shared.storage.factory import create_storage_backend_from_file + from shared.storage.local import LocalStorageBackend + + config_file = tmp_path / "storage.yaml" + storage_path = tmp_path / "storage" + config_file.write_text(""" +backend: ${STORAGE_BACKEND:-local} + +local: + base_path: ${CUSTOM_STORAGE_PATH} +""") + + with patch.dict( + os.environ, + {"STORAGE_BACKEND": "local", "CUSTOM_STORAGE_PATH": str(storage_path)}, + ): + backend = create_storage_backend_from_file(config_file) + + assert isinstance(backend, LocalStorageBackend) + + def test_create_from_file_not_found_raises(self, tmp_path: Path) -> None: + """Test that FileNotFoundError is raised for missing file.""" + from shared.storage.factory import create_storage_backend_from_file + + with pytest.raises(FileNotFoundError): + create_storage_backend_from_file(tmp_path / "nonexistent.yaml") + + +class TestGetStorageBackend: + """Tests for get_storage_backend convenience function.""" + + def test_get_storage_backend_from_file(self, tmp_path: Path) -> None: + """Test getting backend from explicit config file.""" + from shared.storage.factory import get_storage_backend + from shared.storage.local import LocalStorageBackend + + config_file = tmp_path / "storage.yaml" + storage_path = tmp_path / "storage" + config_file.write_text(f""" +backend: local + +local: + base_path: {storage_path} +""") + + backend = get_storage_backend(config_path=config_file) + + assert isinstance(backend, LocalStorageBackend) + + def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None: + """Test that get_storage_backend falls back to env vars.""" + from shared.storage.factory import get_storage_backend + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": str(storage_path), + } + + with patch.dict(os.environ, env, clear=False): + # No config file provided, should use env vars + backend = get_storage_backend(config_path=None) + + assert isinstance(backend, LocalStorageBackend) diff --git a/tests/shared/storage/test_local.py b/tests/shared/storage/test_local.py new file mode 100644 index 0000000..9b97348 --- /dev/null +++ b/tests/shared/storage/test_local.py @@ -0,0 +1,712 @@ +""" +Tests for LocalStorageBackend. + +TDD Phase 1: RED - Write tests first, then implement to pass. 
+""" + +import shutil +import tempfile +from pathlib import Path + +import pytest + + +@pytest.fixture +def temp_storage_dir() -> Path: + """Create a temporary directory for storage tests.""" + temp_dir = Path(tempfile.mkdtemp()) + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def sample_file(temp_storage_dir: Path) -> Path: + """Create a sample file for testing.""" + file_path = temp_storage_dir / "sample.txt" + file_path.write_text("Hello, World!") + return file_path + + +@pytest.fixture +def sample_image(temp_storage_dir: Path) -> Path: + """Create a sample PNG file for testing.""" + file_path = temp_storage_dir / "sample.png" + # Minimal valid PNG (1x1 transparent pixel) + png_data = bytes( + [ + 0x89, + 0x50, + 0x4E, + 0x47, + 0x0D, + 0x0A, + 0x1A, + 0x0A, # PNG signature + 0x00, + 0x00, + 0x00, + 0x0D, # IHDR length + 0x49, + 0x48, + 0x44, + 0x52, # IHDR + 0x00, + 0x00, + 0x00, + 0x01, # width: 1 + 0x00, + 0x00, + 0x00, + 0x01, # height: 1 + 0x08, + 0x06, + 0x00, + 0x00, + 0x00, # 8-bit RGBA + 0x1F, + 0x15, + 0xC4, + 0x89, # CRC + 0x00, + 0x00, + 0x00, + 0x0A, # IDAT length + 0x49, + 0x44, + 0x41, + 0x54, # IDAT + 0x78, + 0x9C, + 0x63, + 0x00, + 0x01, + 0x00, + 0x00, + 0x05, + 0x00, + 0x01, # compressed data + 0x0D, + 0x0A, + 0x2D, + 0xB4, # CRC + 0x00, + 0x00, + 0x00, + 0x00, # IEND length + 0x49, + 0x45, + 0x4E, + 0x44, # IEND + 0xAE, + 0x42, + 0x60, + 0x82, # CRC + ] + ) + file_path.write_bytes(png_data) + return file_path + + +class TestLocalStorageBackendCreation: + """Tests for LocalStorageBackend instantiation.""" + + def test_create_with_base_path(self, temp_storage_dir: Path) -> None: + """Test creating backend with base path.""" + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + assert backend.base_path == temp_storage_dir + + def test_create_with_string_path(self, temp_storage_dir: Path) -> None: + """Test creating backend with string path.""" + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=str(temp_storage_dir)) + + assert backend.base_path == temp_storage_dir + + def test_create_creates_directory_if_not_exists( + self, temp_storage_dir: Path + ) -> None: + """Test that base directory is created if it doesn't exist.""" + from shared.storage.local import LocalStorageBackend + + new_dir = temp_storage_dir / "new_storage" + assert not new_dir.exists() + + backend = LocalStorageBackend(base_path=new_dir) + + assert new_dir.exists() + assert backend.base_path == new_dir + + def test_is_storage_backend_subclass(self, temp_storage_dir: Path) -> None: + """Test that LocalStorageBackend is a StorageBackend.""" + from shared.storage.base import StorageBackend + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + assert isinstance(backend, StorageBackend) + + +class TestLocalStorageBackendUpload: + """Tests for LocalStorageBackend.upload method.""" + + def test_upload_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test uploading a file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + result = backend.upload(sample_file, "uploads/sample.txt") + + assert result == "uploads/sample.txt" + assert (storage_dir / "uploads" / "sample.txt").exists() + assert (storage_dir / "uploads" / "sample.txt").read_text() == "Hello, World!" 
+ + def test_upload_creates_subdirectories( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that upload creates necessary subdirectories.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + result = backend.upload(sample_file, "deep/nested/path/sample.txt") + + assert (storage_dir / "deep" / "nested" / "path" / "sample.txt").exists() + + def test_upload_fails_if_file_exists_without_overwrite( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that upload fails if file exists and overwrite is False.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + # First upload succeeds + backend.upload(sample_file, "sample.txt") + + # Second upload should fail + with pytest.raises(StorageError, match="already exists"): + backend.upload(sample_file, "sample.txt", overwrite=False) + + def test_upload_succeeds_with_overwrite( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that upload succeeds with overwrite=True.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + # First upload + backend.upload(sample_file, "sample.txt") + + # Modify original file + sample_file.write_text("Modified content") + + # Second upload with overwrite + result = backend.upload(sample_file, "sample.txt", overwrite=True) + + assert result == "sample.txt" + assert (storage_dir / "sample.txt").read_text() == "Modified content" + + def test_upload_nonexistent_file_fails(self, temp_storage_dir: Path) -> None: + """Test that uploading nonexistent file fails.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(FileNotFoundStorageError): + backend.upload(Path("/nonexistent/file.txt"), "sample.txt") + + def test_upload_binary_file( + self, temp_storage_dir: Path, sample_image: Path + ) -> None: + """Test uploading a binary file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + result = backend.upload(sample_image, "images/sample.png") + + assert result == "images/sample.png" + uploaded_content = (storage_dir / "images" / "sample.png").read_bytes() + assert uploaded_content == sample_image.read_bytes() + + +class TestLocalStorageBackendDownload: + """Tests for LocalStorageBackend.download method.""" + + def test_download_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test downloading a file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + download_dir = temp_storage_dir / "downloads" + download_dir.mkdir() + backend = LocalStorageBackend(base_path=storage_dir) + + # First upload + backend.upload(sample_file, "sample.txt") + + # Then download + local_path = download_dir / "downloaded.txt" + result = backend.download("sample.txt", local_path) + + assert result == local_path + assert local_path.exists() + assert local_path.read_text() == "Hello, World!" 
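Once an implementation along the lines of the sketch above is in place (the "green" phase these tests drive toward), a quick round trip doubles as a smoke check; the paths and file names here are illustrative only:

```python
# Smoke check for the local backend, assuming shared.storage.local exists.
import tempfile
from pathlib import Path

from shared.storage.local import LocalStorageBackend

with tempfile.TemporaryDirectory() as root:
    backend = LocalStorageBackend(base_path=Path(root) / "storage")

    source = Path(root) / "invoice.txt"
    source.write_text("hello")

    key = backend.upload(source, "docs/invoice.txt")
    print(backend.exists(key))            # True
    print(backend.list_files("docs/"))    # ['docs/invoice.txt']

    copy = backend.download(key, Path(root) / "copy.txt")
    print(copy.read_text())               # hello
    print(backend.delete(key))            # True
```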
+ + def test_download_creates_parent_directories( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that download creates parent directories.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + local_path = temp_storage_dir / "deep" / "nested" / "downloaded.txt" + result = backend.download("sample.txt", local_path) + + assert local_path.exists() + assert local_path.read_text() == "Hello, World!" + + def test_download_nonexistent_file_fails(self, temp_storage_dir: Path) -> None: + """Test that downloading nonexistent file fails.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(FileNotFoundStorageError, match="nonexistent.txt"): + backend.download("nonexistent.txt", Path("/tmp/file.txt")) + + def test_download_nested_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test downloading a file from nested path.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "a/b/c/sample.txt") + + local_path = temp_storage_dir / "downloaded.txt" + result = backend.download("a/b/c/sample.txt", local_path) + + assert local_path.read_text() == "Hello, World!" + + +class TestLocalStorageBackendExists: + """Tests for LocalStorageBackend.exists method.""" + + def test_exists_returns_true_for_existing_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test exists returns True for existing file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + assert backend.exists("sample.txt") is True + + def test_exists_returns_false_for_nonexistent_file( + self, temp_storage_dir: Path + ) -> None: + """Test exists returns False for nonexistent file.""" + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + assert backend.exists("nonexistent.txt") is False + + def test_exists_with_nested_path( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test exists with nested path.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "a/b/sample.txt") + + assert backend.exists("a/b/sample.txt") is True + assert backend.exists("a/b/other.txt") is False + + +class TestLocalStorageBackendListFiles: + """Tests for LocalStorageBackend.list_files method.""" + + def test_list_files_empty_storage(self, temp_storage_dir: Path) -> None: + """Test listing files in empty storage.""" + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + assert backend.list_files("") == [] + + def test_list_files_returns_all_files( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test listing all files.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + # Upload multiple files + 
backend.upload(sample_file, "file1.txt") + backend.upload(sample_file, "file2.txt") + backend.upload(sample_file, "subdir/file3.txt") + + files = backend.list_files("") + + assert len(files) == 3 + assert "file1.txt" in files + assert "file2.txt" in files + assert "subdir/file3.txt" in files + + def test_list_files_with_prefix( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test listing files with prefix filter.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + backend.upload(sample_file, "images/a.png") + backend.upload(sample_file, "images/b.png") + backend.upload(sample_file, "labels/a.txt") + + files = backend.list_files("images/") + + assert len(files) == 2 + assert "images/a.png" in files + assert "images/b.png" in files + assert "labels/a.txt" not in files + + def test_list_files_returns_sorted( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that list_files returns sorted list.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + backend.upload(sample_file, "c.txt") + backend.upload(sample_file, "a.txt") + backend.upload(sample_file, "b.txt") + + files = backend.list_files("") + + assert files == ["a.txt", "b.txt", "c.txt"] + + +class TestLocalStorageBackendDelete: + """Tests for LocalStorageBackend.delete method.""" + + def test_delete_existing_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test deleting an existing file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + result = backend.delete("sample.txt") + + assert result is True + assert not (storage_dir / "sample.txt").exists() + + def test_delete_nonexistent_file_returns_false( + self, temp_storage_dir: Path + ) -> None: + """Test deleting nonexistent file returns False.""" + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + result = backend.delete("nonexistent.txt") + + assert result is False + + def test_delete_nested_file( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test deleting a nested file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "a/b/sample.txt") + + result = backend.delete("a/b/sample.txt") + + assert result is True + assert not (storage_dir / "a" / "b" / "sample.txt").exists() + + +class TestLocalStorageBackendGetUrl: + """Tests for LocalStorageBackend.get_url method.""" + + def test_get_url_returns_file_path( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test get_url returns file:// URL.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + url = backend.get_url("sample.txt") + + # Should return file:// URL or absolute path + assert "sample.txt" in url + # URL should be usable to locate the file + expected_path = storage_dir / "sample.txt" + assert str(expected_path) in url or expected_path.as_uri() == url + + def test_get_url_nonexistent_file(self, 
temp_storage_dir: Path) -> None: + """Test get_url for nonexistent file.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(FileNotFoundStorageError): + backend.get_url("nonexistent.txt") + + +class TestLocalStorageBackendUploadBytes: + """Tests for LocalStorageBackend.upload_bytes method.""" + + def test_upload_bytes(self, temp_storage_dir: Path) -> None: + """Test uploading bytes directly.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + data = b"Binary content here" + result = backend.upload_bytes(data, "binary.dat") + + assert result == "binary.dat" + assert (storage_dir / "binary.dat").read_bytes() == data + + def test_upload_bytes_creates_subdirectories( + self, temp_storage_dir: Path + ) -> None: + """Test that upload_bytes creates subdirectories.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + data = b"content" + backend.upload_bytes(data, "a/b/c/file.dat") + + assert (storage_dir / "a" / "b" / "c" / "file.dat").exists() + + +class TestLocalStorageBackendDownloadBytes: + """Tests for LocalStorageBackend.download_bytes method.""" + + def test_download_bytes( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test downloading file as bytes.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + data = backend.download_bytes("sample.txt") + + assert data == b"Hello, World!" 
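The security tests that follow all assert the same two error messages, "Absolute paths not allowed" and "Path traversal not allowed", across upload, download, exists, delete, get_url, upload_bytes and download_bytes. One plausible shape for a shared guard is sketched below; the function name is invented, the StorageError class is a placeholder for shared.storage.base.StorageError, and the real backend may implement the check differently.

from pathlib import Path


class StorageError(Exception):
    """Placeholder for shared.storage.base.StorageError (assumed)."""


def resolve_remote_path(base_path: Path, remote_path: str) -> Path:
    """Reject absolute paths and any '..' escape before touching the filesystem."""
    # Block Unix absolute paths ("/etc/passwd") and Windows drive paths ("C:\...").
    if remote_path.startswith("/") or (len(remote_path) > 1 and remote_path[1] == ":"):
        raise StorageError(f"Absolute paths not allowed: {remote_path}")
    # Resolve and verify the result still lives under the storage root,
    # which catches "../escape.txt" and "a/b/c/../../../../escape.txt".
    candidate = (base_path / remote_path).resolve()
    if not candidate.is_relative_to(base_path.resolve()):
        raise StorageError(f"Path traversal not allowed: {remote_path}")
    return candidate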
+ + def test_download_bytes_nonexistent(self, temp_storage_dir: Path) -> None: + """Test downloading nonexistent file as bytes.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(FileNotFoundStorageError): + backend.download_bytes("nonexistent.txt") + + +class TestLocalStorageBackendSecurity: + """Security tests for LocalStorageBackend - path traversal prevention.""" + + def test_path_traversal_with_dotdot_blocked( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that path traversal using ../ is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.upload(sample_file, "../escape.txt") + + def test_path_traversal_with_nested_dotdot_blocked( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that nested path traversal is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.upload(sample_file, "subdir/../../escape.txt") + + def test_path_traversal_with_many_dotdot_blocked( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that deeply nested path traversal is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.upload(sample_file, "a/b/c/../../../../escape.txt") + + def test_absolute_path_unix_blocked( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that absolute Unix paths are blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Absolute paths not allowed"): + backend.upload(sample_file, "/etc/passwd") + + def test_absolute_path_windows_blocked( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that absolute Windows paths are blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Absolute paths not allowed"): + backend.upload(sample_file, "C:\\Windows\\System32\\config") + + def test_download_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in download is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.download("../escape.txt", Path("/tmp/file.txt")) + + def test_exists_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in exists is blocked.""" + from shared.storage.base 
import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.exists("../escape.txt") + + def test_delete_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in delete is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.delete("../escape.txt") + + def test_get_url_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in get_url is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.get_url("../escape.txt") + + def test_upload_bytes_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in upload_bytes is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.upload_bytes(b"content", "../escape.txt") + + def test_download_bytes_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in download_bytes is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.download_bytes("../escape.txt") + + def test_valid_nested_path_still_works( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test that valid nested paths still work after security fix.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + + # Valid nested paths should still work + result = backend.upload(sample_file, "a/b/c/d/file.txt") + + assert result == "a/b/c/d/file.txt" + assert (storage_dir / "a" / "b" / "c" / "d" / "file.txt").exists() diff --git a/tests/shared/storage/test_prefixes.py b/tests/shared/storage/test_prefixes.py new file mode 100644 index 0000000..7e3b990 --- /dev/null +++ b/tests/shared/storage/test_prefixes.py @@ -0,0 +1,158 @@ +"""Tests for storage prefixes module.""" + +import pytest + +from shared.storage.prefixes import PREFIXES, StoragePrefixes + + +class TestStoragePrefixes: + """Tests for StoragePrefixes class.""" + + def test_prefixes_are_strings(self) -> None: + """All prefix constants should be strings.""" + assert isinstance(PREFIXES.DOCUMENTS, str) + assert isinstance(PREFIXES.IMAGES, str) + assert isinstance(PREFIXES.UPLOADS, str) + assert isinstance(PREFIXES.RESULTS, str) + assert isinstance(PREFIXES.EXPORTS, str) + assert isinstance(PREFIXES.DATASETS, str) + assert isinstance(PREFIXES.MODELS, str) + assert isinstance(PREFIXES.RAW_PDFS, str) + assert isinstance(PREFIXES.STRUCTURED_DATA, str) + assert isinstance(PREFIXES.ADMIN_IMAGES, str) + + def test_prefixes_are_non_empty(self) -> None: + """All prefix constants should be 
non-empty.""" + assert PREFIXES.DOCUMENTS + assert PREFIXES.IMAGES + assert PREFIXES.UPLOADS + assert PREFIXES.RESULTS + assert PREFIXES.EXPORTS + assert PREFIXES.DATASETS + assert PREFIXES.MODELS + assert PREFIXES.RAW_PDFS + assert PREFIXES.STRUCTURED_DATA + assert PREFIXES.ADMIN_IMAGES + + def test_prefixes_have_no_leading_slash(self) -> None: + """Prefixes should not start with a slash for portability.""" + assert not PREFIXES.DOCUMENTS.startswith("/") + assert not PREFIXES.IMAGES.startswith("/") + assert not PREFIXES.UPLOADS.startswith("/") + assert not PREFIXES.RESULTS.startswith("/") + + def test_prefixes_have_no_trailing_slash(self) -> None: + """Prefixes should not end with a slash.""" + assert not PREFIXES.DOCUMENTS.endswith("/") + assert not PREFIXES.IMAGES.endswith("/") + assert not PREFIXES.UPLOADS.endswith("/") + assert not PREFIXES.RESULTS.endswith("/") + + def test_frozen_dataclass(self) -> None: + """StoragePrefixes should be immutable.""" + with pytest.raises(Exception): # FrozenInstanceError + PREFIXES.DOCUMENTS = "new_value" # type: ignore + + +class TestDocumentPath: + """Tests for document_path helper.""" + + def test_document_path_with_extension(self) -> None: + """Should generate correct document path with extension.""" + path = PREFIXES.document_path("abc123", ".pdf") + assert path == "documents/abc123.pdf" + + def test_document_path_without_leading_dot(self) -> None: + """Should handle extension without leading dot.""" + path = PREFIXES.document_path("abc123", "pdf") + assert path == "documents/abc123.pdf" + + def test_document_path_default_extension(self) -> None: + """Should use .pdf as default extension.""" + path = PREFIXES.document_path("abc123") + assert path == "documents/abc123.pdf" + + +class TestImagePath: + """Tests for image_path helper.""" + + def test_image_path_basic(self) -> None: + """Should generate correct image path.""" + path = PREFIXES.image_path("doc123", 1) + assert path == "images/doc123/page_1.png" + + def test_image_path_page_number(self) -> None: + """Should include page number in path.""" + path = PREFIXES.image_path("doc123", 5) + assert path == "images/doc123/page_5.png" + + def test_image_path_custom_extension(self) -> None: + """Should support custom extension.""" + path = PREFIXES.image_path("doc123", 1, ".jpg") + assert path == "images/doc123/page_1.jpg" + + +class TestUploadPath: + """Tests for upload_path helper.""" + + def test_upload_path_basic(self) -> None: + """Should generate correct upload path.""" + path = PREFIXES.upload_path("invoice.pdf") + assert path == "uploads/invoice.pdf" + + def test_upload_path_with_subfolder(self) -> None: + """Should include subfolder when provided.""" + path = PREFIXES.upload_path("invoice.pdf", "async") + assert path == "uploads/async/invoice.pdf" + + +class TestResultPath: + """Tests for result_path helper.""" + + def test_result_path_basic(self) -> None: + """Should generate correct result path.""" + path = PREFIXES.result_path("output.json") + assert path == "results/output.json" + + +class TestExportPath: + """Tests for export_path helper.""" + + def test_export_path_basic(self) -> None: + """Should generate correct export path.""" + path = PREFIXES.export_path("exp123", "dataset.zip") + assert path == "exports/exp123/dataset.zip" + + +class TestDatasetPath: + """Tests for dataset_path helper.""" + + def test_dataset_path_basic(self) -> None: + """Should generate correct dataset path.""" + path = PREFIXES.dataset_path("ds123", "data.yaml") + assert path == "datasets/ds123/data.yaml" + + 
+class TestModelPath: + """Tests for model_path helper.""" + + def test_model_path_basic(self) -> None: + """Should generate correct model path.""" + path = PREFIXES.model_path("v1.0.0", "best.pt") + assert path == "models/v1.0.0/best.pt" + + +class TestExportsFromInit: + """Tests for exports from storage __init__.py.""" + + def test_prefixes_exported(self) -> None: + """PREFIXES should be exported from storage module.""" + from shared.storage import PREFIXES as exported_prefixes + + assert exported_prefixes is PREFIXES + + def test_storage_prefixes_exported(self) -> None: + """StoragePrefixes should be exported from storage module.""" + from shared.storage import StoragePrefixes as exported_class + + assert exported_class is StoragePrefixes diff --git a/tests/shared/storage/test_presigned_urls.py b/tests/shared/storage/test_presigned_urls.py new file mode 100644 index 0000000..6cbbfa8 --- /dev/null +++ b/tests/shared/storage/test_presigned_urls.py @@ -0,0 +1,264 @@ +""" +Tests for pre-signed URL functionality across all storage backends. + +TDD Phase 1: RED - Write tests first, then implement to pass. +""" + +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def temp_storage_dir() -> Path: + """Create a temporary directory for storage tests.""" + temp_dir = Path(tempfile.mkdtemp()) + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def sample_file(temp_storage_dir: Path) -> Path: + """Create a sample file for testing.""" + file_path = temp_storage_dir / "sample.txt" + file_path.write_text("Hello, World!") + return file_path + + +class TestStorageBackendInterfacePresignedUrl: + """Tests for get_presigned_url in StorageBackend interface.""" + + def test_subclass_must_implement_get_presigned_url(self) -> None: + """Test that subclass must implement get_presigned_url method.""" + from shared.storage.base import StorageBackend + + class IncompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + with pytest.raises(TypeError): + IncompleteBackend() # type: ignore + + def test_valid_subclass_with_get_presigned_url_can_be_instantiated(self) -> None: + """Test that a complete subclass with get_presigned_url can be instantiated.""" + from shared.storage.base import StorageBackend + + class CompleteBackend(StorageBackend): + def upload( + self, local_path: Path, remote_path: str, overwrite: bool = False + ) -> str: + return remote_path + + def download(self, remote_path: str, local_path: Path) -> Path: + return local_path + + def exists(self, remote_path: str) -> bool: + return False + + def list_files(self, prefix: str) -> list[str]: + return [] + + def delete(self, remote_path: str) -> bool: + return True + + def get_url(self, remote_path: str) -> str: + return "" + + def get_presigned_url( + self, remote_path: str, expires_in_seconds: int = 3600 + ) -> str: + return f"https://example.com/{remote_path}?token=abc" + + backend = CompleteBackend() + assert isinstance(backend, StorageBackend) + + +class TestLocalStorageBackendPresignedUrl: + """Tests for 
LocalStorageBackend.get_presigned_url method.""" + + def test_get_presigned_url_returns_file_uri( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test get_presigned_url returns file:// URI for existing file.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + url = backend.get_presigned_url("sample.txt") + + assert url.startswith("file://") + assert "sample.txt" in url + + def test_get_presigned_url_with_custom_expiry( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test get_presigned_url accepts expires_in_seconds parameter.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "sample.txt") + + # For local storage, expiry is ignored but should not raise error + url = backend.get_presigned_url("sample.txt", expires_in_seconds=7200) + + assert url.startswith("file://") + + def test_get_presigned_url_nonexistent_file_raises( + self, temp_storage_dir: Path + ) -> None: + """Test get_presigned_url raises FileNotFoundStorageError for missing file.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(FileNotFoundStorageError): + backend.get_presigned_url("nonexistent.txt") + + def test_get_presigned_url_path_traversal_blocked( + self, temp_storage_dir: Path + ) -> None: + """Test that path traversal in get_presigned_url is blocked.""" + from shared.storage.base import StorageError + from shared.storage.local import LocalStorageBackend + + backend = LocalStorageBackend(base_path=temp_storage_dir) + + with pytest.raises(StorageError, match="Path traversal not allowed"): + backend.get_presigned_url("../escape.txt") + + def test_get_presigned_url_nested_path( + self, temp_storage_dir: Path, sample_file: Path + ) -> None: + """Test get_presigned_url works with nested paths.""" + from shared.storage.local import LocalStorageBackend + + storage_dir = temp_storage_dir / "storage" + backend = LocalStorageBackend(base_path=storage_dir) + backend.upload(sample_file, "a/b/c/sample.txt") + + url = backend.get_presigned_url("a/b/c/sample.txt") + + assert url.startswith("file://") + assert "sample.txt" in url + + +class TestAzureBlobStorageBackendPresignedUrl: + """Tests for AzureBlobStorageBackend.get_presigned_url method.""" + + @patch("shared.storage.azure.BlobServiceClient") + def test_get_presigned_url_generates_sas_url( + self, mock_blob_service_class: MagicMock + ) -> None: + """Test get_presigned_url generates URL with SAS token.""" + from shared.storage.azure import AzureBlobStorageBackend + + # Setup mocks + mock_blob_service = MagicMock() + mock_blob_service.account_name = "testaccount" + mock_blob_service_class.from_connection_string.return_value = mock_blob_service + + mock_container = MagicMock() + mock_container.exists.return_value = True + mock_blob_service.get_container_client.return_value = mock_container + + mock_blob_client = MagicMock() + mock_blob_client.exists.return_value = True + mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt" + mock_container.get_blob_client.return_value = mock_blob_client + + backend = AzureBlobStorageBackend( + 
connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net", + container_name="container", + ) + + with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas: + mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123" + + url = backend.get_presigned_url("test.txt", expires_in_seconds=3600) + + assert "https://testaccount.blob.core.windows.net" in url + assert "sv=2021-06-08" in url or "test.txt" in url + + @patch("shared.storage.azure.BlobServiceClient") + def test_get_presigned_url_nonexistent_blob_raises( + self, mock_blob_service_class: MagicMock + ) -> None: + """Test get_presigned_url raises for nonexistent blob.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.azure import AzureBlobStorageBackend + + mock_blob_service = MagicMock() + mock_blob_service_class.from_connection_string.return_value = mock_blob_service + + mock_container = MagicMock() + mock_container.exists.return_value = True + mock_blob_service.get_container_client.return_value = mock_container + + mock_blob_client = MagicMock() + mock_blob_client.exists.return_value = False + mock_container.get_blob_client.return_value = mock_blob_client + + backend = AzureBlobStorageBackend( + connection_string="DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key==;EndpointSuffix=core.windows.net", + container_name="container", + ) + + with pytest.raises(FileNotFoundStorageError): + backend.get_presigned_url("nonexistent.txt") + + @patch("shared.storage.azure.BlobServiceClient") + def test_get_presigned_url_uses_custom_expiry( + self, mock_blob_service_class: MagicMock + ) -> None: + """Test get_presigned_url uses custom expiry time.""" + from shared.storage.azure import AzureBlobStorageBackend + + mock_blob_service = MagicMock() + mock_blob_service.account_name = "testaccount" + mock_blob_service_class.from_connection_string.return_value = mock_blob_service + + mock_container = MagicMock() + mock_container.exists.return_value = True + mock_blob_service.get_container_client.return_value = mock_container + + mock_blob_client = MagicMock() + mock_blob_client.exists.return_value = True + mock_blob_client.url = "https://testaccount.blob.core.windows.net/container/test.txt" + mock_container.get_blob_client.return_value = mock_blob_client + + backend = AzureBlobStorageBackend( + connection_string="DefaultEndpointsProtocol=https;AccountName=testaccount;AccountKey=testkey==;EndpointSuffix=core.windows.net", + container_name="container", + ) + + with patch("shared.storage.azure.generate_blob_sas") as mock_generate_sas: + mock_generate_sas.return_value = "sv=2021-06-08&sr=b&sig=abc123" + + backend.get_presigned_url("test.txt", expires_in_seconds=7200) + + # Verify generate_blob_sas was called (expiry is part of the call) + mock_generate_sas.assert_called_once() diff --git a/tests/shared/storage/test_s3.py b/tests/shared/storage/test_s3.py new file mode 100644 index 0000000..4b20457 --- /dev/null +++ b/tests/shared/storage/test_s3.py @@ -0,0 +1,520 @@ +""" +Tests for S3StorageBackend. + +TDD Phase 1: RED - Write tests first, then implement to pass. 
+""" + +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +import pytest + + +@pytest.fixture +def temp_dir() -> Path: + """Create a temporary directory for tests.""" + temp_dir = Path(tempfile.mkdtemp()) + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def sample_file(temp_dir: Path) -> Path: + """Create a sample file for testing.""" + file_path = temp_dir / "sample.txt" + file_path.write_text("Hello, World!") + return file_path + + +@pytest.fixture +def mock_boto3_client(): + """Create a mock boto3 S3 client.""" + with patch("boto3.client") as mock_client_func: + mock_client = MagicMock() + mock_client_func.return_value = mock_client + yield mock_client + + +class TestS3StorageBackendCreation: + """Tests for S3StorageBackend instantiation.""" + + def test_create_with_bucket_name(self, mock_boto3_client: MagicMock) -> None: + """Test creating backend with bucket name.""" + from shared.storage.s3 import S3StorageBackend + + backend = S3StorageBackend(bucket_name="test-bucket") + + assert backend.bucket_name == "test-bucket" + + def test_create_with_region(self, mock_boto3_client: MagicMock) -> None: + """Test creating backend with region.""" + from shared.storage.s3 import S3StorageBackend + + with patch("boto3.client") as mock_client: + S3StorageBackend( + bucket_name="test-bucket", + region_name="us-west-2", + ) + + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert call_kwargs.get("region_name") == "us-west-2" + + def test_create_with_credentials(self, mock_boto3_client: MagicMock) -> None: + """Test creating backend with explicit credentials.""" + from shared.storage.s3 import S3StorageBackend + + with patch("boto3.client") as mock_client: + S3StorageBackend( + bucket_name="test-bucket", + access_key_id="AKIATEST", + secret_access_key="secret123", + ) + + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert call_kwargs.get("aws_access_key_id") == "AKIATEST" + assert call_kwargs.get("aws_secret_access_key") == "secret123" + + def test_create_with_endpoint_url(self, mock_boto3_client: MagicMock) -> None: + """Test creating backend with custom endpoint (for S3-compatible services).""" + from shared.storage.s3 import S3StorageBackend + + with patch("boto3.client") as mock_client: + S3StorageBackend( + bucket_name="test-bucket", + endpoint_url="http://localhost:9000", + ) + + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert call_kwargs.get("endpoint_url") == "http://localhost:9000" + + def test_create_bucket_when_requested(self, mock_boto3_client: MagicMock) -> None: + """Test that bucket is created when create_bucket=True.""" + from botocore.exceptions import ClientError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_bucket.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadBucket" + ) + + S3StorageBackend( + bucket_name="test-bucket", + create_bucket=True, + ) + + mock_boto3_client.create_bucket.assert_called_once() + + def test_is_storage_backend_subclass(self, mock_boto3_client: MagicMock) -> None: + """Test that S3StorageBackend is a StorageBackend.""" + from shared.storage.base import StorageBackend + from shared.storage.s3 import S3StorageBackend + + backend = S3StorageBackend(bucket_name="test-bucket") + + assert isinstance(backend, StorageBackend) + + +class TestS3StorageBackendUpload: + """Tests for S3StorageBackend.upload method.""" + + def 
test_upload_file( + self, mock_boto3_client: MagicMock, temp_dir: Path, sample_file: Path + ) -> None: + """Test uploading a file.""" + from botocore.exceptions import ClientError + from shared.storage.s3 import S3StorageBackend + + # Object does not exist + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + + result = backend.upload(sample_file, "uploads/sample.txt") + + assert result == "uploads/sample.txt" + mock_boto3_client.upload_file.assert_called_once() + + def test_upload_fails_if_exists_without_overwrite( + self, mock_boto3_client: MagicMock, sample_file: Path + ) -> None: + """Test that upload fails if object exists and overwrite is False.""" + from shared.storage.base import StorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} # Object exists + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(StorageError, match="already exists"): + backend.upload(sample_file, "sample.txt", overwrite=False) + + def test_upload_succeeds_with_overwrite( + self, mock_boto3_client: MagicMock, sample_file: Path + ) -> None: + """Test that upload succeeds with overwrite=True.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} # Object exists + + backend = S3StorageBackend(bucket_name="test-bucket") + result = backend.upload(sample_file, "sample.txt", overwrite=True) + + assert result == "sample.txt" + mock_boto3_client.upload_file.assert_called_once() + + def test_upload_nonexistent_file_fails( + self, mock_boto3_client: MagicMock, temp_dir: Path + ) -> None: + """Test that uploading nonexistent file fails.""" + from shared.storage.base import FileNotFoundStorageError + from shared.storage.s3 import S3StorageBackend + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(FileNotFoundStorageError): + backend.upload(temp_dir / "nonexistent.txt", "sample.txt") + + +class TestS3StorageBackendDownload: + """Tests for S3StorageBackend.download method.""" + + def test_download_file( + self, mock_boto3_client: MagicMock, temp_dir: Path + ) -> None: + """Test downloading a file.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} # Object exists + + backend = S3StorageBackend(bucket_name="test-bucket") + local_path = temp_dir / "downloaded.txt" + + result = backend.download("sample.txt", local_path) + + assert result == local_path + mock_boto3_client.download_file.assert_called_once() + + def test_download_creates_parent_directories( + self, mock_boto3_client: MagicMock, temp_dir: Path + ) -> None: + """Test that download creates parent directories.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + + backend = S3StorageBackend(bucket_name="test-bucket") + local_path = temp_dir / "deep" / "nested" / "downloaded.txt" + + backend.download("sample.txt", local_path) + + assert local_path.parent.exists() + + def test_download_nonexistent_object_fails( + self, mock_boto3_client: MagicMock, temp_dir: Path + ) -> None: + """Test that downloading nonexistent object fails.""" + from botocore.exceptions import ClientError + from shared.storage.base import FileNotFoundStorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + 
backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(FileNotFoundStorageError): + backend.download("nonexistent.txt", temp_dir / "file.txt") + + +class TestS3StorageBackendExists: + """Tests for S3StorageBackend.exists method.""" + + def test_exists_returns_true_for_existing_object( + self, mock_boto3_client: MagicMock + ) -> None: + """Test exists returns True for existing object.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + + backend = S3StorageBackend(bucket_name="test-bucket") + + assert backend.exists("sample.txt") is True + + def test_exists_returns_false_for_nonexistent_object( + self, mock_boto3_client: MagicMock + ) -> None: + """Test exists returns False for nonexistent object.""" + from botocore.exceptions import ClientError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + + assert backend.exists("nonexistent.txt") is False + + +class TestS3StorageBackendListFiles: + """Tests for S3StorageBackend.list_files method.""" + + def test_list_files_returns_objects( + self, mock_boto3_client: MagicMock + ) -> None: + """Test listing objects.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "file1.txt"}, + {"Key": "file2.txt"}, + {"Key": "subdir/file3.txt"}, + ] + } + + backend = S3StorageBackend(bucket_name="test-bucket") + files = backend.list_files("") + + assert len(files) == 3 + assert "file1.txt" in files + assert "file2.txt" in files + assert "subdir/file3.txt" in files + + def test_list_files_with_prefix( + self, mock_boto3_client: MagicMock + ) -> None: + """Test listing objects with prefix filter.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "images/a.png"}, + {"Key": "images/b.png"}, + ] + } + + backend = S3StorageBackend(bucket_name="test-bucket") + files = backend.list_files("images/") + + mock_boto3_client.list_objects_v2.assert_called_with( + Bucket="test-bucket", Prefix="images/" + ) + + def test_list_files_empty_bucket( + self, mock_boto3_client: MagicMock + ) -> None: + """Test listing files in empty bucket.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.list_objects_v2.return_value = {} # No Contents key + + backend = S3StorageBackend(bucket_name="test-bucket") + files = backend.list_files("") + + assert files == [] + + +class TestS3StorageBackendDelete: + """Tests for S3StorageBackend.delete method.""" + + def test_delete_existing_object( + self, mock_boto3_client: MagicMock + ) -> None: + """Test deleting an existing object.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + + backend = S3StorageBackend(bucket_name="test-bucket") + result = backend.delete("sample.txt") + + assert result is True + mock_boto3_client.delete_object.assert_called_once() + + def test_delete_nonexistent_object_returns_false( + self, mock_boto3_client: MagicMock + ) -> None: + """Test deleting nonexistent object returns False.""" + from botocore.exceptions import ClientError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + result 
= backend.delete("nonexistent.txt") + + assert result is False + + +class TestS3StorageBackendGetUrl: + """Tests for S3StorageBackend.get_url method.""" + + def test_get_url_returns_s3_url( + self, mock_boto3_client: MagicMock + ) -> None: + """Test get_url returns S3 URL.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + mock_boto3_client.generate_presigned_url.return_value = ( + "https://test-bucket.s3.amazonaws.com/sample.txt" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + url = backend.get_url("sample.txt") + + assert "sample.txt" in url + + def test_get_url_nonexistent_object_raises( + self, mock_boto3_client: MagicMock + ) -> None: + """Test get_url raises for nonexistent object.""" + from botocore.exceptions import ClientError + from shared.storage.base import FileNotFoundStorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(FileNotFoundStorageError): + backend.get_url("nonexistent.txt") + + +class TestS3StorageBackendUploadBytes: + """Tests for S3StorageBackend.upload_bytes method.""" + + def test_upload_bytes( + self, mock_boto3_client: MagicMock + ) -> None: + """Test uploading bytes directly.""" + from shared.storage.s3 import S3StorageBackend + + from botocore.exceptions import ClientError + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + data = b"Binary content here" + + result = backend.upload_bytes(data, "binary.dat") + + assert result == "binary.dat" + mock_boto3_client.put_object.assert_called_once() + + def test_upload_bytes_fails_if_exists_without_overwrite( + self, mock_boto3_client: MagicMock + ) -> None: + """Test upload_bytes fails if object exists and overwrite is False.""" + from shared.storage.base import StorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} # Object exists + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(StorageError, match="already exists"): + backend.upload_bytes(b"content", "sample.txt", overwrite=False) + + +class TestS3StorageBackendDownloadBytes: + """Tests for S3StorageBackend.download_bytes method.""" + + def test_download_bytes( + self, mock_boto3_client: MagicMock + ) -> None: + """Test downloading object as bytes.""" + from shared.storage.s3 import S3StorageBackend + + mock_response = MagicMock() + mock_response.read.return_value = b"Hello, World!" + mock_boto3_client.get_object.return_value = {"Body": mock_response} + + backend = S3StorageBackend(bucket_name="test-bucket") + data = backend.download_bytes("sample.txt") + + assert data == b"Hello, World!" 
+ + def test_download_bytes_nonexistent_raises( + self, mock_boto3_client: MagicMock + ) -> None: + """Test downloading nonexistent object as bytes.""" + from botocore.exceptions import ClientError + from shared.storage.base import FileNotFoundStorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.get_object.side_effect = ClientError( + {"Error": {"Code": "NoSuchKey"}}, "GetObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(FileNotFoundStorageError): + backend.download_bytes("nonexistent.txt") + + +class TestS3StorageBackendPresignedUrl: + """Tests for S3StorageBackend.get_presigned_url method.""" + + def test_get_presigned_url_generates_url( + self, mock_boto3_client: MagicMock + ) -> None: + """Test get_presigned_url generates presigned URL.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + mock_boto3_client.generate_presigned_url.return_value = ( + "https://test-bucket.s3.amazonaws.com/sample.txt?X-Amz-Algorithm=..." + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + url = backend.get_presigned_url("sample.txt") + + assert "X-Amz-Algorithm" in url or "sample.txt" in url + mock_boto3_client.generate_presigned_url.assert_called_once() + + def test_get_presigned_url_with_custom_expiry( + self, mock_boto3_client: MagicMock + ) -> None: + """Test get_presigned_url uses custom expiry.""" + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.return_value = {} + mock_boto3_client.generate_presigned_url.return_value = "https://..." + + backend = S3StorageBackend(bucket_name="test-bucket") + backend.get_presigned_url("sample.txt", expires_in_seconds=7200) + + call_args = mock_boto3_client.generate_presigned_url.call_args + assert call_args[1].get("ExpiresIn") == 7200 + + def test_get_presigned_url_nonexistent_raises( + self, mock_boto3_client: MagicMock + ) -> None: + """Test get_presigned_url raises for nonexistent object.""" + from botocore.exceptions import ClientError + from shared.storage.base import FileNotFoundStorageError + from shared.storage.s3 import S3StorageBackend + + mock_boto3_client.head_object.side_effect = ClientError( + {"Error": {"Code": "404"}}, "HeadObject" + ) + + backend = S3StorageBackend(bucket_name="test-bucket") + + with pytest.raises(FileNotFoundStorageError): + backend.get_presigned_url("nonexistent.txt") diff --git a/tests/web/test_admin_annotations.py b/tests/web/test_admin_annotations.py index 9265810..0140b03 100644 --- a/tests/web/test_admin_annotations.py +++ b/tests/web/test_admin_annotations.py @@ -9,7 +9,8 @@ from uuid import UUID from fastapi import HTTPException -from inference.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES +from inference.data.admin_models import AdminAnnotation, AdminDocument +from shared.fields import FIELD_CLASSES from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router from inference.web.schemas.admin import ( AnnotationCreate, diff --git a/tests/web/test_admin_routes_enhanced.py b/tests/web/test_admin_routes_enhanced.py index 55b0563..7c23ce4 100644 --- a/tests/web/test_admin_routes_enhanced.py +++ b/tests/web/test_admin_routes_enhanced.py @@ -31,6 +31,7 @@ class MockAdminDocument: self.batch_id = kwargs.get('batch_id', None) self.csv_field_values = kwargs.get('csv_field_values', None) self.annotation_lock_until = kwargs.get('annotation_lock_until', None) + self.category = kwargs.get('category', 'invoice') 
self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) @@ -67,12 +68,13 @@ class MockAdminDB: def get_documents_by_token( self, - admin_token, + admin_token=None, status=None, upload_source=None, has_annotations=None, auto_label_status=None, batch_id=None, + category=None, limit=20, offset=0 ): @@ -95,6 +97,8 @@ class MockAdminDB: docs = [d for d in docs if d.auto_label_status == auto_label_status] if batch_id: docs = [d for d in docs if str(d.batch_id) == str(batch_id)] + if category: + docs = [d for d in docs if d.category == category] total = len(docs) return docs[offset:offset+limit], total diff --git a/tests/web/test_async_service.py b/tests/web/test_async_service.py index 556dc1e..0999852 100644 --- a/tests/web/test_async_service.py +++ b/tests/web/test_async_service.py @@ -215,8 +215,10 @@ class TestAsyncProcessingService: def test_cleanup_orphan_files(self, async_service, mock_db): """Test cleanup of orphan files.""" - # Create an orphan file + # Create the async upload directory temp_dir = async_service._async_config.temp_upload_dir + temp_dir.mkdir(parents=True, exist_ok=True) + orphan_file = temp_dir / "orphan-request.pdf" orphan_file.write_bytes(b"orphan content") @@ -228,7 +230,13 @@ class TestAsyncProcessingService: # Mock database to say file doesn't exist mock_db.get_request.return_value = None - count = async_service._cleanup_orphan_files() + # Mock the storage helper to return the same directory as the fixture + with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage: + mock_helper = MagicMock() + mock_helper.get_uploads_base_path.return_value = temp_dir + mock_storage.return_value = mock_helper + + count = async_service._cleanup_orphan_files() assert count == 1 assert not orphan_file.exists() diff --git a/tests/web/test_augmentation_routes.py b/tests/web/test_augmentation_routes.py index 698d876..f6bd2bc 100644 --- a/tests/web/test_augmentation_routes.py +++ b/tests/web/test_augmentation_routes.py @@ -5,7 +5,75 @@ TDD Phase 5: RED - Write tests first, then implement to pass. 
""" import pytest +from unittest.mock import MagicMock, patch +from fastapi import FastAPI from fastapi.testclient import TestClient +import numpy as np + +from inference.web.api.v1.admin.augmentation import create_augmentation_router +from inference.web.core.auth import validate_admin_token, get_admin_db + + +TEST_ADMIN_TOKEN = "test-admin-token-12345" +TEST_DOCUMENT_UUID = "550e8400-e29b-41d4-a716-446655440001" +TEST_DATASET_UUID = "660e8400-e29b-41d4-a716-446655440001" + + +@pytest.fixture +def admin_token() -> str: + """Provide admin token for testing.""" + return TEST_ADMIN_TOKEN + + +@pytest.fixture +def mock_admin_db() -> MagicMock: + """Create a mock AdminDB for testing.""" + mock = MagicMock() + # Default return values + mock.get_document_by_token.return_value = None + mock.get_dataset.return_value = None + mock.get_augmented_datasets.return_value = ([], 0) + return mock + + +@pytest.fixture +def admin_client(mock_admin_db: MagicMock) -> TestClient: + """Create test client with admin authentication.""" + app = FastAPI() + + # Override dependencies + def get_token_override(): + return TEST_ADMIN_TOKEN + + def get_db_override(): + return mock_admin_db + + app.dependency_overrides[validate_admin_token] = get_token_override + app.dependency_overrides[get_admin_db] = get_db_override + + # Include router - the router already has /augmentation prefix + # so we add /api/v1/admin to get /api/v1/admin/augmentation + router = create_augmentation_router() + app.include_router(router, prefix="/api/v1/admin") + + return TestClient(app) + + +@pytest.fixture +def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient: + """Create test client WITHOUT admin authentication override.""" + app = FastAPI() + + # Only override the database, NOT the token validation + def get_db_override(): + return mock_admin_db + + app.dependency_overrides[get_admin_db] = get_db_override + + router = create_augmentation_router() + app.include_router(router, prefix="/api/v1/admin") + + return TestClient(app) class TestAugmentationTypesEndpoint: @@ -34,10 +102,10 @@ class TestAugmentationTypesEndpoint: assert "stage" in aug_type def test_list_augmentation_types_unauthorized( - self, admin_client: TestClient + self, unauthenticated_client: TestClient ) -> None: """Test that unauthorized request is rejected.""" - response = admin_client.get("/api/v1/admin/augmentation/types") + response = unauthenticated_client.get("/api/v1/admin/augmentation/types") assert response.status_code == 401 @@ -74,16 +142,30 @@ class TestAugmentationPreviewEndpoint: admin_client: TestClient, admin_token: str, sample_document_id: str, + mock_admin_db: MagicMock, ) -> None: """Test previewing augmentation on a document.""" - response = admin_client.post( - f"/api/v1/admin/augmentation/preview/{sample_document_id}", - headers={"X-Admin-Token": admin_token}, - json={ - "augmentation_type": "gaussian_noise", - "params": {"std": 15}, - }, - ) + # Mock document exists + mock_document = MagicMock() + mock_document.images_dir = "/fake/path" + mock_admin_db.get_document.return_value = mock_document + + # Create a fake image (100x100 RGB) + fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + + with patch( + "inference.web.services.augmentation_service.AugmentationService._load_document_page" + ) as mock_load: + mock_load.return_value = fake_image + + response = admin_client.post( + f"/api/v1/admin/augmentation/preview/{sample_document_id}", + headers={"X-Admin-Token": admin_token}, + json={ + "augmentation_type": 
"gaussian_noise", + "params": {"std": 15}, + }, + ) assert response.status_code == 200 data = response.json() @@ -136,18 +218,32 @@ class TestAugmentationPreviewConfigEndpoint: admin_client: TestClient, admin_token: str, sample_document_id: str, + mock_admin_db: MagicMock, ) -> None: """Test previewing full config on a document.""" - response = admin_client.post( - f"/api/v1/admin/augmentation/preview-config/{sample_document_id}", - headers={"X-Admin-Token": admin_token}, - json={ - "gaussian_noise": {"enabled": True, "probability": 1.0}, - "lighting_variation": {"enabled": True, "probability": 1.0}, - "preserve_bboxes": True, - "seed": 42, - }, - ) + # Mock document exists + mock_document = MagicMock() + mock_document.images_dir = "/fake/path" + mock_admin_db.get_document.return_value = mock_document + + # Create a fake image (100x100 RGB) + fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + + with patch( + "inference.web.services.augmentation_service.AugmentationService._load_document_page" + ) as mock_load: + mock_load.return_value = fake_image + + response = admin_client.post( + f"/api/v1/admin/augmentation/preview-config/{sample_document_id}", + headers={"X-Admin-Token": admin_token}, + json={ + "gaussian_noise": {"enabled": True, "probability": 1.0}, + "lighting_variation": {"enabled": True, "probability": 1.0}, + "preserve_bboxes": True, + "seed": 42, + }, + ) assert response.status_code == 200 data = response.json() @@ -164,8 +260,14 @@ class TestAugmentationBatchEndpoint: admin_client: TestClient, admin_token: str, sample_dataset_id: str, + mock_admin_db: MagicMock, ) -> None: """Test creating augmented dataset.""" + # Mock dataset exists + mock_dataset = MagicMock() + mock_dataset.total_images = 100 + mock_admin_db.get_dataset.return_value = mock_dataset + response = admin_client.post( "/api/v1/admin/augmentation/batch", headers={"X-Admin-Token": admin_token}, @@ -250,12 +352,10 @@ class TestAugmentedDatasetsListEndpoint: @pytest.fixture def sample_document_id() -> str: """Provide a sample document ID for testing.""" - # This would need to be created in test setup - return "test-document-id" + return TEST_DOCUMENT_UUID @pytest.fixture def sample_dataset_id() -> str: """Provide a sample dataset ID for testing.""" - # This would need to be created in test setup - return "test-dataset-id" + return TEST_DATASET_UUID diff --git a/tests/web/test_dataset_routes.py b/tests/web/test_dataset_routes.py index 2f1e5a5..4063161 100644 --- a/tests/web/test_dataset_routes.py +++ b/tests/web/test_dataset_routes.py @@ -35,6 +35,8 @@ def _make_dataset(**overrides) -> MagicMock: name="test-dataset", description="Test dataset", status="ready", + training_status=None, + active_training_task_id=None, train_ratio=0.8, val_ratio=0.1, seed=42, @@ -183,6 +185,8 @@ class TestListDatasetsRoute: mock_db = MagicMock() mock_db.get_datasets.return_value = ([_make_dataset()], 1) + # Mock the active training tasks lookup to return empty dict + mock_db.get_active_training_tasks_for_datasets.return_value = {} result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0)) diff --git a/tests/web/test_dataset_training_status.py b/tests/web/test_dataset_training_status.py new file mode 100644 index 0000000..e2e330b --- /dev/null +++ b/tests/web/test_dataset_training_status.py @@ -0,0 +1,363 @@ +""" +Tests for dataset training status feature. + +Tests cover: +1. Database model fields (training_status, active_training_task_id) +2. 
AdminDB update_dataset_training_status method +3. API response includes training status fields +4. Scheduler updates dataset status during training lifecycle +""" + +import pytest +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +# ============================================================================= +# Test Database Model +# ============================================================================= + + +class TestTrainingDatasetModel: + """Tests for TrainingDataset model fields.""" + + def test_training_dataset_has_training_status_field(self): + """TrainingDataset model should have training_status field.""" + from inference.data.admin_models import TrainingDataset + + dataset = TrainingDataset( + name="test-dataset", + training_status="running", + ) + assert dataset.training_status == "running" + + def test_training_dataset_has_active_training_task_id_field(self): + """TrainingDataset model should have active_training_task_id field.""" + from inference.data.admin_models import TrainingDataset + + task_id = uuid4() + dataset = TrainingDataset( + name="test-dataset", + active_training_task_id=task_id, + ) + assert dataset.active_training_task_id == task_id + + def test_training_dataset_defaults(self): + """TrainingDataset should have correct defaults for new fields.""" + from inference.data.admin_models import TrainingDataset + + dataset = TrainingDataset(name="test-dataset") + assert dataset.training_status is None + assert dataset.active_training_task_id is None + + +# ============================================================================= +# Test AdminDB Methods +# ============================================================================= + + +class TestAdminDBDatasetTrainingStatus: + """Tests for AdminDB.update_dataset_training_status method.""" + + @pytest.fixture + def mock_session(self): + """Create mock database session.""" + session = MagicMock() + return session + + def test_update_dataset_training_status_sets_status(self, mock_session): + """update_dataset_training_status should set training_status.""" + from inference.data.admin_models import TrainingDataset + + dataset_id = uuid4() + dataset = TrainingDataset( + dataset_id=dataset_id, + name="test-dataset", + status="ready", + ) + mock_session.get.return_value = dataset + + with patch("inference.data.admin_db.get_session_context") as mock_ctx: + mock_ctx.return_value.__enter__.return_value = mock_session + + from inference.data.admin_db import AdminDB + + db = AdminDB() + db.update_dataset_training_status( + dataset_id=str(dataset_id), + training_status="running", + ) + + assert dataset.training_status == "running" + mock_session.add.assert_called_once_with(dataset) + mock_session.commit.assert_called_once() + + def test_update_dataset_training_status_sets_task_id(self, mock_session): + """update_dataset_training_status should set active_training_task_id.""" + from inference.data.admin_models import TrainingDataset + + dataset_id = uuid4() + task_id = uuid4() + dataset = TrainingDataset( + dataset_id=dataset_id, + name="test-dataset", + status="ready", + ) + mock_session.get.return_value = dataset + + with patch("inference.data.admin_db.get_session_context") as mock_ctx: + mock_ctx.return_value.__enter__.return_value = mock_session + + from inference.data.admin_db import AdminDB + + db = AdminDB() + db.update_dataset_training_status( + dataset_id=str(dataset_id), + 
training_status="running", + active_training_task_id=str(task_id), + ) + + assert dataset.active_training_task_id == task_id + + def test_update_dataset_training_status_updates_main_status_on_complete( + self, mock_session + ): + """update_dataset_training_status should update main status to 'trained' when completed.""" + from inference.data.admin_models import TrainingDataset + + dataset_id = uuid4() + dataset = TrainingDataset( + dataset_id=dataset_id, + name="test-dataset", + status="ready", + ) + mock_session.get.return_value = dataset + + with patch("inference.data.admin_db.get_session_context") as mock_ctx: + mock_ctx.return_value.__enter__.return_value = mock_session + + from inference.data.admin_db import AdminDB + + db = AdminDB() + db.update_dataset_training_status( + dataset_id=str(dataset_id), + training_status="completed", + update_main_status=True, + ) + + assert dataset.status == "trained" + assert dataset.training_status == "completed" + + def test_update_dataset_training_status_clears_task_id_on_complete( + self, mock_session + ): + """update_dataset_training_status should clear task_id when training completes.""" + from inference.data.admin_models import TrainingDataset + + dataset_id = uuid4() + task_id = uuid4() + dataset = TrainingDataset( + dataset_id=dataset_id, + name="test-dataset", + status="ready", + training_status="running", + active_training_task_id=task_id, + ) + mock_session.get.return_value = dataset + + with patch("inference.data.admin_db.get_session_context") as mock_ctx: + mock_ctx.return_value.__enter__.return_value = mock_session + + from inference.data.admin_db import AdminDB + + db = AdminDB() + db.update_dataset_training_status( + dataset_id=str(dataset_id), + training_status="completed", + active_training_task_id=None, + ) + + assert dataset.active_training_task_id is None + + def test_update_dataset_training_status_handles_missing_dataset(self, mock_session): + """update_dataset_training_status should handle missing dataset gracefully.""" + mock_session.get.return_value = None + + with patch("inference.data.admin_db.get_session_context") as mock_ctx: + mock_ctx.return_value.__enter__.return_value = mock_session + + from inference.data.admin_db import AdminDB + + db = AdminDB() + # Should not raise + db.update_dataset_training_status( + dataset_id=str(uuid4()), + training_status="running", + ) + + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +# ============================================================================= +# Test API Response +# ============================================================================= + + +class TestDatasetDetailResponseTrainingStatus: + """Tests for DatasetDetailResponse including training status fields.""" + + def test_dataset_detail_response_includes_training_status(self): + """DatasetDetailResponse schema should include training_status field.""" + from inference.web.schemas.admin.datasets import DatasetDetailResponse + + response = DatasetDetailResponse( + dataset_id=str(uuid4()), + name="test-dataset", + description=None, + status="ready", + training_status="running", + active_training_task_id=str(uuid4()), + train_ratio=0.8, + val_ratio=0.1, + seed=42, + total_documents=10, + total_images=15, + total_annotations=100, + dataset_path="/path/to/dataset", + error_message=None, + documents=[], + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + assert response.training_status == "running" + assert response.active_training_task_id is not None + + def 
test_dataset_detail_response_allows_null_training_status(self): + """DatasetDetailResponse should allow null training_status.""" + from inference.web.schemas.admin.datasets import DatasetDetailResponse + + response = DatasetDetailResponse( + dataset_id=str(uuid4()), + name="test-dataset", + description=None, + status="ready", + training_status=None, + active_training_task_id=None, + train_ratio=0.8, + val_ratio=0.1, + seed=42, + total_documents=10, + total_images=15, + total_annotations=100, + dataset_path=None, + error_message=None, + documents=[], + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + assert response.training_status is None + assert response.active_training_task_id is None + + +# ============================================================================= +# Test Scheduler Training Status Updates +# ============================================================================= + + +class TestSchedulerDatasetStatusUpdates: + """Tests for scheduler updating dataset status during training.""" + + @pytest.fixture + def mock_db(self): + """Create mock AdminDB.""" + mock = MagicMock() + mock.get_dataset.return_value = MagicMock( + dataset_id=uuid4(), + name="test-dataset", + dataset_path="/path/to/dataset", + total_images=100, + ) + mock.get_pending_training_tasks.return_value = [] + return mock + + def test_scheduler_sets_running_status_on_task_start(self, mock_db): + """Scheduler should set dataset training_status to 'running' when task starts.""" + from inference.web.core.scheduler import TrainingScheduler + + with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train: + mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}} + + scheduler = TrainingScheduler() + scheduler._db = mock_db + + task_id = str(uuid4()) + dataset_id = str(uuid4()) + + # Execute task (will fail but we check the status update call) + try: + scheduler._execute_task( + task_id=task_id, + config={"model_name": "yolo11n.pt"}, + dataset_id=dataset_id, + ) + except Exception: + pass # Expected to fail in test environment + + # Check that training status was updated to running + mock_db.update_dataset_training_status.assert_called() + first_call = mock_db.update_dataset_training_status.call_args_list[0] + assert first_call.kwargs["training_status"] == "running" + assert first_call.kwargs["active_training_task_id"] == task_id + + +# ============================================================================= +# Test Dataset Status Values +# ============================================================================= + + +class TestDatasetStatusValues: + """Tests for valid dataset status values.""" + + def test_dataset_status_building(self): + """Dataset can have status 'building'.""" + from inference.data.admin_models import TrainingDataset + + dataset = TrainingDataset(name="test", status="building") + assert dataset.status == "building" + + def test_dataset_status_ready(self): + """Dataset can have status 'ready'.""" + from inference.data.admin_models import TrainingDataset + + dataset = TrainingDataset(name="test", status="ready") + assert dataset.status == "ready" + + def test_dataset_status_trained(self): + """Dataset can have status 'trained'.""" + from inference.data.admin_models import TrainingDataset + + dataset = TrainingDataset(name="test", status="trained") + assert dataset.status == "trained" + + def test_dataset_status_failed(self): + """Dataset can have status 'failed'.""" + from inference.data.admin_models import TrainingDataset + + 
dataset = TrainingDataset(name="test", status="failed") + assert dataset.status == "failed" + + def test_training_status_values(self): + """Training status can have various values.""" + from inference.data.admin_models import TrainingDataset + + valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"] + for status in valid_statuses: + dataset = TrainingDataset(name="test", training_status=status) + assert dataset.training_status == status diff --git a/tests/web/test_document_category.py b/tests/web/test_document_category.py new file mode 100644 index 0000000..4dc4f1d --- /dev/null +++ b/tests/web/test_document_category.py @@ -0,0 +1,207 @@ +""" +Tests for Document Category Feature. + +TDD tests for adding category field to admin_documents table. +Documents can be categorized (e.g., invoice, letter, receipt) for training different models. +""" + +import pytest +from datetime import datetime +from unittest.mock import MagicMock +from uuid import UUID, uuid4 + +from inference.data.admin_models import AdminDocument + + +# Test constants +TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000" +TEST_TOKEN = "test-admin-token-12345" + + +class TestAdminDocumentCategoryField: + """Tests for AdminDocument category field.""" + + def test_document_has_category_field(self): + """Test AdminDocument model has category field.""" + doc = AdminDocument( + document_id=UUID(TEST_DOC_UUID), + filename="test.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/path/to/file.pdf", + ) + assert hasattr(doc, "category") + + def test_document_category_defaults_to_invoice(self): + """Test category defaults to 'invoice' when not specified.""" + doc = AdminDocument( + document_id=UUID(TEST_DOC_UUID), + filename="test.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/path/to/file.pdf", + ) + assert doc.category == "invoice" + + def test_document_accepts_custom_category(self): + """Test document accepts custom category values.""" + categories = ["invoice", "letter", "receipt", "contract", "custom_type"] + + for cat in categories: + doc = AdminDocument( + document_id=uuid4(), + filename="test.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/path/to/file.pdf", + category=cat, + ) + assert doc.category == cat + + def test_document_category_is_string_type(self): + """Test category field is a string type.""" + doc = AdminDocument( + document_id=UUID(TEST_DOC_UUID), + filename="test.pdf", + file_size=1024, + content_type="application/pdf", + file_path="/path/to/file.pdf", + category="letter", + ) + assert isinstance(doc.category, str) + + +class TestDocumentCategoryInReadModel: + """Tests for category in response models.""" + + def test_admin_document_read_has_category(self): + """Test AdminDocumentRead includes category field.""" + from inference.data.admin_models import AdminDocumentRead + + # Check the model has category field in its schema + assert "category" in AdminDocumentRead.model_fields + + +class TestDocumentCategoryAPI: + """Tests for document category in API endpoints.""" + + @pytest.fixture + def mock_admin_db(self): + """Create mock AdminDB.""" + db = MagicMock() + db.is_valid_admin_token.return_value = True + return db + + def test_upload_document_with_category(self, mock_admin_db): + """Test uploading document with category parameter.""" + from inference.web.schemas.admin import DocumentUploadResponse + + # Verify response schema supports category + response = DocumentUploadResponse( + document_id=TEST_DOC_UUID, + 
filename="test.pdf", + file_size=1024, + page_count=1, + status="pending", + message="Upload successful", + category="letter", + ) + assert response.category == "letter" + + def test_list_documents_returns_category(self, mock_admin_db): + """Test list documents endpoint returns category.""" + from inference.web.schemas.admin import DocumentItem + + item = DocumentItem( + document_id=TEST_DOC_UUID, + filename="test.pdf", + file_size=1024, + page_count=1, + status="pending", + annotation_count=0, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + category="invoice", + ) + assert item.category == "invoice" + + def test_document_detail_includes_category(self, mock_admin_db): + """Test document detail response includes category.""" + from inference.web.schemas.admin import DocumentDetailResponse + + # Check schema has category + assert "category" in DocumentDetailResponse.model_fields + + +class TestDocumentCategoryFiltering: + """Tests for filtering documents by category.""" + + @pytest.fixture + def mock_admin_db(self): + """Create mock AdminDB with category filtering support.""" + db = MagicMock() + db.is_valid_admin_token.return_value = True + + # Mock documents with different categories + invoice_doc = MagicMock() + invoice_doc.document_id = uuid4() + invoice_doc.category = "invoice" + + letter_doc = MagicMock() + letter_doc.document_id = uuid4() + letter_doc.category = "letter" + + db.get_documents_by_category.return_value = [invoice_doc] + return db + + def test_filter_documents_by_category(self, mock_admin_db): + """Test filtering documents by category.""" + # This tests the DB method signature + result = mock_admin_db.get_documents_by_category("invoice") + assert len(result) == 1 + assert result[0].category == "invoice" + + +class TestDocumentCategoryUpdate: + """Tests for updating document category.""" + + def test_update_document_category_schema(self): + """Test update document request supports category.""" + from inference.web.schemas.admin import DocumentUpdateRequest + + request = DocumentUpdateRequest(category="letter") + assert request.category == "letter" + + def test_update_document_category_optional(self): + """Test category is optional in update request.""" + from inference.web.schemas.admin import DocumentUpdateRequest + + # Should not raise - category is optional + request = DocumentUpdateRequest() + assert request.category is None + + +class TestDatasetWithCategory: + """Tests for dataset creation with category filtering.""" + + def test_dataset_create_with_category_filter(self): + """Test creating dataset can filter by document category.""" + from inference.web.schemas.admin import DatasetCreateRequest + + request = DatasetCreateRequest( + name="Invoice Training Set", + document_ids=[TEST_DOC_UUID], + category="invoice", # Optional filter + ) + assert request.category == "invoice" + + def test_dataset_create_category_is_optional(self): + """Test category filter is optional when creating dataset.""" + from inference.web.schemas.admin import DatasetCreateRequest + + request = DatasetCreateRequest( + name="Mixed Training Set", + document_ids=[TEST_DOC_UUID], + ) + # category should be optional + assert not hasattr(request, "category") or request.category is None diff --git a/tests/web/test_document_category_api.py b/tests/web/test_document_category_api.py new file mode 100644 index 0000000..8822361 --- /dev/null +++ b/tests/web/test_document_category_api.py @@ -0,0 +1,165 @@ +""" +Tests for Document Category API Endpoints. 
+ +TDD tests for category filtering and management in document endpoints. +""" + +import pytest +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +# Test constants +TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000" +TEST_TOKEN = "test-admin-token-12345" + + +class TestGetCategoriesEndpoint: + """Tests for GET /admin/documents/categories endpoint.""" + + def test_categories_endpoint_returns_list(self): + """Test categories endpoint returns list of available categories.""" + from inference.web.schemas.admin import DocumentCategoriesResponse + + # Test schema exists and works + response = DocumentCategoriesResponse( + categories=["invoice", "letter", "receipt"], + total=3, + ) + assert response.categories == ["invoice", "letter", "receipt"] + assert response.total == 3 + + def test_categories_response_schema(self): + """Test DocumentCategoriesResponse schema structure.""" + from inference.web.schemas.admin import DocumentCategoriesResponse + + assert "categories" in DocumentCategoriesResponse.model_fields + assert "total" in DocumentCategoriesResponse.model_fields + + +class TestDocumentListFilterByCategory: + """Tests for filtering documents by category.""" + + @pytest.fixture + def mock_admin_db(self): + """Create mock AdminDB.""" + db = MagicMock() + db.is_valid_admin_token.return_value = True + + # Mock documents with different categories + invoice_doc = MagicMock() + invoice_doc.document_id = uuid4() + invoice_doc.category = "invoice" + invoice_doc.filename = "invoice1.pdf" + + letter_doc = MagicMock() + letter_doc.document_id = uuid4() + letter_doc.category = "letter" + letter_doc.filename = "letter1.pdf" + + db.get_documents.return_value = ([invoice_doc], 1) + db.get_document_categories.return_value = ["invoice", "letter", "receipt"] + return db + + def test_list_documents_accepts_category_filter(self, mock_admin_db): + """Test list documents endpoint accepts category query parameter.""" + # The endpoint should accept ?category=invoice parameter + # This test verifies the schema/query parameter exists + from inference.web.schemas.admin import DocumentListResponse + + # Schema should work with category filter applied + assert DocumentListResponse is not None + + def test_get_document_categories_from_db(self, mock_admin_db): + """Test fetching unique categories from database.""" + categories = mock_admin_db.get_document_categories() + assert "invoice" in categories + assert "letter" in categories + assert len(categories) == 3 + + +class TestDocumentUploadWithCategory: + """Tests for uploading documents with category.""" + + def test_upload_request_accepts_category(self): + """Test upload request can include category field.""" + # When uploading via form data, category should be accepted + # This is typically a form field, not a schema + pass + + def test_upload_response_includes_category(self): + """Test upload response includes the category that was set.""" + from inference.web.schemas.admin import DocumentUploadResponse + + response = DocumentUploadResponse( + document_id=TEST_DOC_UUID, + filename="test.pdf", + file_size=1024, + page_count=1, + status="pending", + category="letter", # Custom category + message="Upload successful", + ) + assert response.category == "letter" + + def test_upload_defaults_to_invoice_category(self): + """Test upload defaults to 'invoice' if no category specified.""" + from inference.web.schemas.admin import 
DocumentUploadResponse + + response = DocumentUploadResponse( + document_id=TEST_DOC_UUID, + filename="test.pdf", + file_size=1024, + page_count=1, + status="pending", + message="Upload successful", + # No category specified - should default to "invoice" + ) + assert response.category == "invoice" + + +class TestAdminDBCategoryMethods: + """Tests for AdminDB category-related methods.""" + + def test_get_document_categories_method_exists(self): + """Test AdminDB has get_document_categories method.""" + from inference.data.admin_db import AdminDB + + db = AdminDB() + assert hasattr(db, "get_document_categories") + + def test_get_documents_accepts_category_filter(self): + """Test get_documents_by_token method accepts category parameter.""" + from inference.data.admin_db import AdminDB + import inspect + + db = AdminDB() + # Check the method exists and accepts category parameter + method = getattr(db, "get_documents_by_token", None) + assert callable(method) + + # Check category is in the method signature + sig = inspect.signature(method) + assert "category" in sig.parameters + + +class TestUpdateDocumentCategory: + """Tests for updating document category.""" + + def test_update_document_category_method_exists(self): + """Test AdminDB has method to update document category.""" + from inference.data.admin_db import AdminDB + + db = AdminDB() + assert hasattr(db, "update_document_category") + + def test_update_request_schema(self): + """Test DocumentUpdateRequest can update category.""" + from inference.web.schemas.admin import DocumentUpdateRequest + + request = DocumentUpdateRequest(category="receipt") + assert request.category == "receipt" diff --git a/tests/web/test_inference_api.py b/tests/web/test_inference_api.py index bc09f53..466e35d 100644 --- a/tests/web/test_inference_api.py +++ b/tests/web/test_inference_api.py @@ -32,10 +32,10 @@ def test_app(tmp_path): use_gpu=False, dpi=150, ), - storage=StorageConfig( + file=StorageConfig( upload_dir=upload_dir, result_dir=result_dir, - allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"}, + allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"), max_file_size_mb=50, ), ) @@ -252,20 +252,25 @@ class TestResultsEndpoint: response = client.get("/api/v1/results/nonexistent.png") assert response.status_code == 404 - def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path): + def test_get_result_image_returns_file_if_exists(self, client, tmp_path): """Test that existing result file is returned.""" - # Get storage config from app - storage_config = test_app.extra.get("storage_config") - if not storage_config: - pytest.skip("Storage config not available in test app") - - # Create a test result file - result_file = storage_config.result_dir / "test_result.png" + # Create a test result file in temp directory + result_dir = tmp_path / "results" + result_dir.mkdir(exist_ok=True) + result_file = result_dir / "test_result.png" img = Image.new('RGB', (100, 100), color='red') img.save(result_file) - # Request the file - response = client.get("/api/v1/results/test_result.png") + # Mock the storage helper to return our test file path + with patch( + "inference.web.api.v1.public.inference.get_storage_helper" + ) as mock_storage: + mock_helper = Mock() + mock_helper.get_result_local_path.return_value = result_file + mock_storage.return_value = mock_helper + + # Request the file + response = client.get("/api/v1/results/test_result.png") assert response.status_code == 200 assert response.headers["content-type"] == "image/png" diff --git 
a/tests/web/test_model_versions.py b/tests/web/test_model_versions.py index e28249a..353f281 100644 --- a/tests/web/test_model_versions.py +++ b/tests/web/test_model_versions.py @@ -266,7 +266,11 @@ class TestActivateModelVersionRoute: mock_db = MagicMock() mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True) - result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db)) + # Create mock request with app state + mock_request = MagicMock() + mock_request.app.state.inference_service = None + + result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db)) mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID) assert result.status == "active" @@ -278,10 +282,14 @@ class TestActivateModelVersionRoute: mock_db = MagicMock() mock_db.activate_model_version.return_value = None + # Create mock request with app state + mock_request = MagicMock() + mock_request.app.state.inference_service = None + from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: - asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db)) + asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db)) assert exc_info.value.status_code == 404 diff --git a/tests/web/test_storage_helpers.py b/tests/web/test_storage_helpers.py new file mode 100644 index 0000000..bab09fa --- /dev/null +++ b/tests/web/test_storage_helpers.py @@ -0,0 +1,828 @@ +"""Tests for storage helpers module.""" + +import pytest +from unittest.mock import MagicMock, patch + +from inference.web.services.storage_helpers import StorageHelper, get_storage_helper +from shared.storage import PREFIXES + + +@pytest.fixture +def mock_storage() -> MagicMock: + """Create a mock storage backend.""" + storage = MagicMock() + storage.upload_bytes = MagicMock() + storage.download_bytes = MagicMock(return_value=b"test content") + storage.get_presigned_url = MagicMock(return_value="https://example.com/file") + storage.exists = MagicMock(return_value=True) + storage.delete = MagicMock(return_value=True) + storage.list_files = MagicMock(return_value=[]) + return storage + + +@pytest.fixture +def helper(mock_storage: MagicMock) -> StorageHelper: + """Create a storage helper with mock backend.""" + return StorageHelper(storage=mock_storage) + + +class TestStorageHelperInit: + """Tests for StorageHelper initialization.""" + + def test_init_with_storage(self, mock_storage: MagicMock) -> None: + """Should use provided storage backend.""" + helper = StorageHelper(storage=mock_storage) + assert helper.storage is mock_storage + + def test_storage_property(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Storage property should return the backend.""" + assert helper.storage is mock_storage + + +class TestDocumentOperations: + """Tests for document storage operations.""" + + def test_upload_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should upload document with correct path.""" + doc_id, path = helper.upload_document(b"pdf content", "invoice.pdf", "doc123") + + assert doc_id == "doc123" + assert path == "documents/doc123.pdf" + mock_storage.upload_bytes.assert_called_once_with( + b"pdf content", "documents/doc123.pdf", overwrite=True + ) + + def test_upload_document_generates_id(self, helper: StorageHelper) -> None: + """Should generate document ID if not provided.""" + doc_id, path = 
helper.upload_document(b"content", "file.pdf") + + assert doc_id is not None + assert len(doc_id) > 0 + assert path.startswith("documents/") + + def test_download_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should download document from correct path.""" + content = helper.download_document("doc123") + + assert content == b"test content" + mock_storage.download_bytes.assert_called_once_with("documents/doc123.pdf") + + def test_get_document_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get presigned URL for document.""" + url = helper.get_document_url("doc123", expires_in_seconds=7200) + + assert url == "https://example.com/file" + mock_storage.get_presigned_url.assert_called_once_with( + "documents/doc123.pdf", 7200 + ) + + def test_document_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should check document existence.""" + exists = helper.document_exists("doc123") + + assert exists is True + mock_storage.exists.assert_called_once_with("documents/doc123.pdf") + + def test_delete_document(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should delete document.""" + result = helper.delete_document("doc123") + + assert result is True + mock_storage.delete.assert_called_once_with("documents/doc123.pdf") + + +class TestImageOperations: + """Tests for image storage operations.""" + + def test_save_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save page image with correct path.""" + path = helper.save_page_image("doc123", 1, b"image data") + + assert path == "images/doc123/page_1.png" + mock_storage.upload_bytes.assert_called_once_with( + b"image data", "images/doc123/page_1.png", overwrite=True + ) + + def test_get_page_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get page image from correct path.""" + content = helper.get_page_image("doc123", 2) + + assert content == b"test content" + mock_storage.download_bytes.assert_called_once_with("images/doc123/page_2.png") + + def test_get_page_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get presigned URL for page image.""" + url = helper.get_page_image_url("doc123", 3) + + assert url == "https://example.com/file" + mock_storage.get_presigned_url.assert_called_once_with( + "images/doc123/page_3.png", 3600 + ) + + def test_delete_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should delete all images for a document.""" + mock_storage.list_files.return_value = [ + "images/doc123/page_1.png", + "images/doc123/page_2.png", + ] + + deleted = helper.delete_document_images("doc123") + + assert deleted == 2 + mock_storage.list_files.assert_called_once_with("images/doc123/") + + def test_list_document_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should list all images for a document.""" + mock_storage.list_files.return_value = ["images/doc123/page_1.png"] + + images = helper.list_document_images("doc123") + + assert images == ["images/doc123/page_1.png"] + + +class TestUploadOperations: + """Tests for upload staging operations.""" + + def test_save_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save upload to correct path.""" + path = helper.save_upload(b"content", "file.pdf") + + assert path == "uploads/file.pdf" + mock_storage.upload_bytes.assert_called_once() + + def test_save_upload_with_subfolder(self, helper: StorageHelper, 
mock_storage: MagicMock) -> None: + """Should save upload with subfolder.""" + path = helper.save_upload(b"content", "file.pdf", "async") + + assert path == "uploads/async/file.pdf" + + def test_get_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get upload from correct path.""" + content = helper.get_upload("file.pdf", "async") + + mock_storage.download_bytes.assert_called_once_with("uploads/async/file.pdf") + + def test_delete_upload(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should delete upload.""" + result = helper.delete_upload("file.pdf") + + assert result is True + mock_storage.delete.assert_called_once_with("uploads/file.pdf") + + +class TestResultOperations: + """Tests for result file operations.""" + + def test_save_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save result to correct path.""" + path = helper.save_result(b"result data", "output.json") + + assert path == "results/output.json" + mock_storage.upload_bytes.assert_called_once() + + def test_get_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get result from correct path.""" + content = helper.get_result("output.json") + + mock_storage.download_bytes.assert_called_once_with("results/output.json") + + def test_get_result_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get presigned URL for result.""" + url = helper.get_result_url("output.json") + + mock_storage.get_presigned_url.assert_called_once_with("results/output.json", 3600) + + def test_result_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should check result existence.""" + exists = helper.result_exists("output.json") + + assert exists is True + mock_storage.exists.assert_called_once_with("results/output.json") + + +class TestExportOperations: + """Tests for export file operations.""" + + def test_save_export(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save export to correct path.""" + path = helper.save_export(b"export data", "exp123", "dataset.zip") + + assert path == "exports/exp123/dataset.zip" + mock_storage.upload_bytes.assert_called_once() + + def test_get_export_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get presigned URL for export.""" + url = helper.get_export_url("exp123", "dataset.zip") + + mock_storage.get_presigned_url.assert_called_once_with( + "exports/exp123/dataset.zip", 3600 + ) + + +class TestRawPdfOperations: + """Tests for raw PDF operations (legacy compatibility).""" + + def test_save_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save raw PDF to correct path.""" + path = helper.save_raw_pdf(b"pdf data", "invoice.pdf") + + assert path == "raw_pdfs/invoice.pdf" + mock_storage.upload_bytes.assert_called_once() + + def test_get_raw_pdf(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get raw PDF from correct path.""" + content = helper.get_raw_pdf("invoice.pdf") + + mock_storage.download_bytes.assert_called_once_with("raw_pdfs/invoice.pdf") + + def test_raw_pdf_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should check raw PDF existence.""" + exists = helper.raw_pdf_exists("invoice.pdf") + + assert exists is True + mock_storage.exists.assert_called_once_with("raw_pdfs/invoice.pdf") + + +class TestAdminImageOperations: + """Tests for admin image storage operations.""" + + def test_save_admin_image(self, 
helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should save admin image with correct path.""" + path = helper.save_admin_image("doc123", 1, b"image data") + + assert path == "admin_images/doc123/page_1.png" + mock_storage.upload_bytes.assert_called_once_with( + b"image data", "admin_images/doc123/page_1.png", overwrite=True + ) + + def test_get_admin_image(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get admin image from correct path.""" + content = helper.get_admin_image("doc123", 2) + + assert content == b"test content" + mock_storage.download_bytes.assert_called_once_with("admin_images/doc123/page_2.png") + + def test_get_admin_image_url(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should get presigned URL for admin image.""" + url = helper.get_admin_image_url("doc123", 3) + + assert url == "https://example.com/file" + mock_storage.get_presigned_url.assert_called_once_with( + "admin_images/doc123/page_3.png", 3600 + ) + + def test_admin_image_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should check admin image existence.""" + exists = helper.admin_image_exists("doc123", 1) + + assert exists is True + mock_storage.exists.assert_called_once_with("admin_images/doc123/page_1.png") + + def test_get_admin_image_path(self, helper: StorageHelper) -> None: + """Should return correct admin image path.""" + path = helper.get_admin_image_path("doc123", 2) + + assert path == "admin_images/doc123/page_2.png" + + def test_list_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should list all admin images for a document.""" + mock_storage.list_files.return_value = [ + "admin_images/doc123/page_1.png", + "admin_images/doc123/page_2.png", + ] + + images = helper.list_admin_images("doc123") + + assert images == ["admin_images/doc123/page_1.png", "admin_images/doc123/page_2.png"] + mock_storage.list_files.assert_called_once_with("admin_images/doc123/") + + def test_delete_admin_images(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should delete all admin images for a document.""" + mock_storage.list_files.return_value = [ + "admin_images/doc123/page_1.png", + "admin_images/doc123/page_2.png", + ] + + deleted = helper.delete_admin_images("doc123") + + assert deleted == 2 + mock_storage.list_files.assert_called_once_with("admin_images/doc123/") + + +class TestGetLocalPath: + """Tests for get_local_path method.""" + + def test_get_admin_image_local_path_with_local_storage(self) -> None: + """Should return local path when using local storage backend.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test image + test_path = Path(temp_dir) / "admin_images" / "doc123" + test_path.mkdir(parents=True, exist_ok=True) + (test_path / "page_1.png").write_bytes(b"test image") + + local_path = helper.get_admin_image_local_path("doc123", 1) + + assert local_path is not None + assert local_path.exists() + assert local_path.name == "page_1.png" + + def test_get_admin_image_local_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when storage doesn't support local paths.""" + # Mock storage without get_local_path method (simulating cloud storage) + mock_storage.get_local_path = MagicMock(return_value=None) + helper = 
StorageHelper(storage=mock_storage) + + local_path = helper.get_admin_image_local_path("doc123", 1) + + assert local_path is None + + def test_get_admin_image_local_path_nonexistent_file(self) -> None: + """Should return None when file doesn't exist.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + local_path = helper.get_admin_image_local_path("nonexistent", 1) + + assert local_path is None + + +class TestGetAdminImageDimensions: + """Tests for get_admin_image_dimensions method.""" + + def test_get_dimensions_with_local_storage(self) -> None: + """Should return image dimensions when using local storage.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + from PIL import Image + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test image with known dimensions + test_path = Path(temp_dir) / "admin_images" / "doc123" + test_path.mkdir(parents=True, exist_ok=True) + img = Image.new("RGB", (800, 600), color="white") + img.save(test_path / "page_1.png") + + dimensions = helper.get_admin_image_dimensions("doc123", 1) + + assert dimensions == (800, 600) + + def test_get_dimensions_nonexistent_file(self) -> None: + """Should return None when file doesn't exist.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + dimensions = helper.get_admin_image_dimensions("nonexistent", 1) + + assert dimensions is None + + +class TestGetStorageHelper: + """Tests for get_storage_helper function.""" + + def test_returns_helper_instance(self) -> None: + """Should return a StorageHelper instance.""" + with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get: + mock_get.return_value = MagicMock() + # Reset the global helper + import inference.web.services.storage_helpers as module + module._default_helper = None + + helper = get_storage_helper() + + assert isinstance(helper, StorageHelper) + + def test_returns_same_instance(self) -> None: + """Should return the same instance on subsequent calls.""" + with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get: + mock_get.return_value = MagicMock() + import inference.web.services.storage_helpers as module + module._default_helper = None + + helper1 = get_storage_helper() + helper2 = get_storage_helper() + + assert helper1 is helper2 + + +class TestDeleteResult: + """Tests for delete_result method.""" + + def test_delete_result(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should delete result file.""" + result = helper.delete_result("output.json") + + assert result is True + mock_storage.delete.assert_called_once_with("results/output.json") + + +class TestResultLocalPath: + """Tests for get_result_local_path method.""" + + def test_get_result_local_path_with_local_storage(self) -> None: + """Should return local path when using local storage backend.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test result file + 
test_path = Path(temp_dir) / "results" + test_path.mkdir(parents=True, exist_ok=True) + (test_path / "output.json").write_bytes(b"test result") + + local_path = helper.get_result_local_path("output.json") + + assert local_path is not None + assert local_path.exists() + assert local_path.name == "output.json" + + def test_get_result_local_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when storage doesn't support local paths.""" + helper = StorageHelper(storage=mock_storage) + local_path = helper.get_result_local_path("output.json") + assert local_path is None + + def test_get_result_local_path_nonexistent_file(self) -> None: + """Should return None when file doesn't exist.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + local_path = helper.get_result_local_path("nonexistent.json") + + assert local_path is None + + +class TestResultsBasePath: + """Tests for get_results_base_path method.""" + + def test_get_results_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_results_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "results" + + def test_get_results_base_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_results_base_path() + assert base_path is None + + +class TestUploadLocalPath: + """Tests for get_upload_local_path method.""" + + def test_get_upload_local_path_with_local_storage(self) -> None: + """Should return local path when using local storage backend.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test upload file + test_path = Path(temp_dir) / "uploads" + test_path.mkdir(parents=True, exist_ok=True) + (test_path / "file.pdf").write_bytes(b"test upload") + + local_path = helper.get_upload_local_path("file.pdf") + + assert local_path is not None + assert local_path.exists() + assert local_path.name == "file.pdf" + + def test_get_upload_local_path_with_subfolder(self) -> None: + """Should return local path with subfolder.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test upload file with subfolder + test_path = Path(temp_dir) / "uploads" / "async" + test_path.mkdir(parents=True, exist_ok=True) + (test_path / "file.pdf").write_bytes(b"test upload") + + local_path = helper.get_upload_local_path("file.pdf", "async") + + assert local_path is not None + assert local_path.exists() + + def test_get_upload_local_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = 
StorageHelper(storage=mock_storage) + local_path = helper.get_upload_local_path("file.pdf") + assert local_path is None + + +class TestUploadsBasePath: + """Tests for get_uploads_base_path method.""" + + def test_get_uploads_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_uploads_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "uploads" + + def test_get_uploads_base_path_with_subfolder(self) -> None: + """Should return base path with subfolder.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_uploads_base_path("async") + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "async" + + def test_get_uploads_base_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_uploads_base_path() + assert base_path is None + + +class TestUploadExists: + """Tests for upload_exists method.""" + + def test_upload_exists(self, helper: StorageHelper, mock_storage: MagicMock) -> None: + """Should check upload existence.""" + exists = helper.upload_exists("file.pdf") + + assert exists is True + mock_storage.exists.assert_called_once_with("uploads/file.pdf") + + def test_upload_exists_with_subfolder( + self, helper: StorageHelper, mock_storage: MagicMock + ) -> None: + """Should check upload existence with subfolder.""" + helper.upload_exists("file.pdf", "async") + + mock_storage.exists.assert_called_once_with("uploads/async/file.pdf") + + +class TestDatasetsBasePath: + """Tests for get_datasets_base_path method.""" + + def test_get_datasets_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_datasets_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "datasets" + + def test_get_datasets_base_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_datasets_base_path() + assert base_path is None + + +class TestAdminImagesBasePath: + """Tests for get_admin_images_base_path method.""" + + def test_get_admin_images_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_admin_images_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "admin_images" + + def test_get_admin_images_base_path_returns_none_for_cloud( + self, 
mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_admin_images_base_path() + assert base_path is None + + +class TestRawPdfsBasePath: + """Tests for get_raw_pdfs_base_path method.""" + + def test_get_raw_pdfs_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_raw_pdfs_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "raw_pdfs" + + def test_get_raw_pdfs_base_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_raw_pdfs_base_path() + assert base_path is None + + +class TestRawPdfLocalPath: + """Tests for get_raw_pdf_local_path method.""" + + def test_get_raw_pdf_local_path_with_local_storage(self) -> None: + """Should return local path when using local storage backend.""" + from pathlib import Path + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + # Create a test raw PDF + test_path = Path(temp_dir) / "raw_pdfs" + test_path.mkdir(parents=True, exist_ok=True) + (test_path / "invoice.pdf").write_bytes(b"test pdf") + + local_path = helper.get_raw_pdf_local_path("invoice.pdf") + + assert local_path is not None + assert local_path.exists() + assert local_path.name == "invoice.pdf" + + def test_get_raw_pdf_local_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + local_path = helper.get_raw_pdf_local_path("invoice.pdf") + assert local_path is None + + +class TestRawPdfPath: + """Tests for get_raw_pdf_path method.""" + + def test_get_raw_pdf_path(self, helper: StorageHelper) -> None: + """Should return correct storage path.""" + path = helper.get_raw_pdf_path("invoice.pdf") + assert path == "raw_pdfs/invoice.pdf" + + +class TestAutolabelOutputPath: + """Tests for get_autolabel_output_path method.""" + + def test_get_autolabel_output_path_with_local_storage(self) -> None: + """Should return output path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + output_path = helper.get_autolabel_output_path() + + assert output_path is not None + assert output_path.exists() + assert output_path.name == "autolabel_output" + + def test_get_autolabel_output_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + output_path = helper.get_autolabel_output_path() + assert output_path is None + + +class TestTrainingDataPath: + """Tests for get_training_data_path method.""" + + def test_get_training_data_path_with_local_storage(self) -> None: + """Should return training path when using local storage.""" + from shared.storage.local 
import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + training_path = helper.get_training_data_path() + + assert training_path is not None + assert training_path.exists() + assert training_path.name == "training" + + def test_get_training_data_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + training_path = helper.get_training_data_path() + assert training_path is None + + +class TestExportsBasePath: + """Tests for get_exports_base_path method.""" + + def test_get_exports_base_path_with_local_storage(self) -> None: + """Should return base path when using local storage.""" + from shared.storage.local import LocalStorageBackend + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + storage = LocalStorageBackend(temp_dir) + helper = StorageHelper(storage=storage) + + base_path = helper.get_exports_base_path() + + assert base_path is not None + assert base_path.exists() + assert base_path.name == "exports" + + def test_get_exports_base_path_returns_none_for_cloud( + self, mock_storage: MagicMock + ) -> None: + """Should return None when not using local storage.""" + helper = StorageHelper(storage=mock_storage) + base_path = helper.get_exports_base_path() + assert base_path is None diff --git a/tests/web/test_storage_integration.py b/tests/web/test_storage_integration.py new file mode 100644 index 0000000..2fb081d --- /dev/null +++ b/tests/web/test_storage_integration.py @@ -0,0 +1,306 @@ +""" +Tests for storage backend integration in web application. + +TDD Phase 1: RED - Write tests first, then implement to pass. 
+""" + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +class TestStorageBackendInitialization: + """Tests for storage backend initialization in web config.""" + + def test_get_storage_backend_returns_backend(self, tmp_path: Path) -> None: + """Test that get_storage_backend returns a StorageBackend instance.""" + from shared.storage.base import StorageBackend + + from inference.web.config import get_storage_backend + + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": str(tmp_path / "storage"), + } + + with patch.dict(os.environ, env, clear=False): + backend = get_storage_backend() + + assert isinstance(backend, StorageBackend) + + def test_get_storage_backend_uses_config_file_if_exists( + self, tmp_path: Path + ) -> None: + """Test that storage config file is used when present.""" + from shared.storage.local import LocalStorageBackend + + from inference.web.config import get_storage_backend + + config_file = tmp_path / "storage.yaml" + storage_path = tmp_path / "storage" + config_file.write_text(f""" +backend: local + +local: + base_path: {storage_path} +""") + + backend = get_storage_backend(config_path=config_file) + + assert isinstance(backend, LocalStorageBackend) + + def test_get_storage_backend_falls_back_to_env(self, tmp_path: Path) -> None: + """Test fallback to environment variables when no config file.""" + from shared.storage.local import LocalStorageBackend + + from inference.web.config import get_storage_backend + + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": str(tmp_path / "storage"), + } + + with patch.dict(os.environ, env, clear=False): + backend = get_storage_backend(config_path=None) + + assert isinstance(backend, LocalStorageBackend) + + def test_app_config_has_storage_backend(self, tmp_path: Path) -> None: + """Test that AppConfig can be created with storage backend.""" + from shared.storage.base import StorageBackend + + from inference.web.config import AppConfig, create_app_config + + env = { + "STORAGE_BACKEND": "local", + "STORAGE_BASE_PATH": str(tmp_path / "storage"), + } + + with patch.dict(os.environ, env, clear=False): + config = create_app_config() + + assert hasattr(config, "storage_backend") + assert isinstance(config.storage_backend, StorageBackend) + + +class TestStorageBackendInDocumentUpload: + """Tests for storage backend usage in document upload.""" + + def test_upload_document_uses_storage_backend( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document upload uses storage backend.""" + from unittest.mock import AsyncMock + + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create a mock upload file + pdf_content = b"%PDF-1.4 test content" + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + # Upload should use storage backend + result = service.upload_document( + content=pdf_content, + filename="test.pdf", + dataset_id="dataset-1", + ) + + assert result is not None + # Verify file was stored via storage backend + assert backend.exists(f"documents/{result.id}.pdf") + + def test_upload_document_stores_logical_path( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document stores logical path, not absolute path.""" + from shared.storage.local import 
LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + pdf_content = b"%PDF-1.4 test content" + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + result = service.upload_document( + content=pdf_content, + filename="test.pdf", + dataset_id="dataset-1", + ) + + # Path should be logical (relative), not absolute + assert not result.file_path.startswith("/") + assert not result.file_path.startswith("C:") + assert result.file_path.startswith("documents/") + + +class TestStorageBackendInDocumentDownload: + """Tests for storage backend usage in document download/serving.""" + + def test_get_document_url_returns_presigned_url( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document URL uses presigned URL from storage backend.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create a test file + doc_path = "documents/test-doc.pdf" + backend.upload_bytes(b"%PDF-1.4 test", doc_path) + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + url = service.get_document_url(doc_path) + + # Should return a URL (file:// for local, https:// for cloud) + assert url is not None + assert "test-doc.pdf" in url + + def test_download_document_uses_storage_backend( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document download uses storage backend.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create a test file + doc_path = "documents/test-doc.pdf" + original_content = b"%PDF-1.4 test content" + backend.upload_bytes(original_content, doc_path) + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + content = service.download_document(doc_path) + + assert content == original_content + + +class TestStorageBackendInImageServing: + """Tests for storage backend usage in image serving.""" + + def test_get_page_image_url_returns_presigned_url( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that page image URL uses presigned URL.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create a test image + image_path = "images/doc-123/page_1.png" + backend.upload_bytes(b"fake png content", image_path) + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + url = service.get_page_image_url("doc-123", 1) + + assert url is not None + assert "page_1.png" in url + + def test_save_page_image_uses_storage_backend( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that page image saving uses storage backend.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = 
LocalStorageBackend(str(storage_path)) + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + image_content = b"fake png content" + service.save_page_image("doc-123", 1, image_content) + + # Verify image was stored + assert backend.exists("images/doc-123/page_1.png") + + +class TestStorageBackendInDocumentDeletion: + """Tests for storage backend usage in document deletion.""" + + def test_delete_document_removes_from_storage( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document deletion removes file from storage.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create test files + doc_path = "documents/test-doc.pdf" + backend.upload_bytes(b"%PDF-1.4 test", doc_path) + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + service.delete_document_files(doc_path) + + assert not backend.exists(doc_path) + + def test_delete_document_removes_images( + self, tmp_path: Path, mock_admin_db: MagicMock + ) -> None: + """Test that document deletion removes associated images.""" + from shared.storage.local import LocalStorageBackend + + storage_path = tmp_path / "storage" + storage_path.mkdir(parents=True, exist_ok=True) + backend = LocalStorageBackend(str(storage_path)) + + # Create test files + doc_id = "test-doc-123" + backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png") + backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png") + + from inference.web.services.document_service import DocumentService + + service = DocumentService(admin_db=mock_admin_db, storage_backend=backend) + + service.delete_document_images(doc_id) + + assert not backend.exists(f"images/{doc_id}/page_1.png") + assert not backend.exists(f"images/{doc_id}/page_2.png") + + +@pytest.fixture +def mock_admin_db() -> MagicMock: + """Create a mock AdminDB for testing.""" + mock = MagicMock() + mock.get_document.return_value = None + mock.create_document.return_value = MagicMock( + id="test-doc-id", + file_path="documents/test-doc-id.pdf", + ) + return mock diff --git a/tests/web/test_training_phase4.py b/tests/web/test_training_phase4.py index 1162985..c27e3d8 100644 --- a/tests/web/test_training_phase4.py +++ b/tests/web/test_training_phase4.py @@ -103,6 +103,31 @@ class MockAnnotation: self.updated_at = kwargs.get('updated_at', datetime.utcnow()) +class MockModelVersion: + """Mock ModelVersion for testing.""" + + def __init__(self, **kwargs): + self.version_id = kwargs.get('version_id', uuid4()) + self.version = kwargs.get('version', '1.0.0') + self.name = kwargs.get('name', 'Test Model') + self.description = kwargs.get('description', None) + self.model_path = kwargs.get('model_path', 'runs/train/test/weights/best.pt') + self.status = kwargs.get('status', 'inactive') + self.is_active = kwargs.get('is_active', False) + self.task_id = kwargs.get('task_id', None) + self.dataset_id = kwargs.get('dataset_id', None) + self.metrics_mAP = kwargs.get('metrics_mAP', 0.935) + self.metrics_precision = kwargs.get('metrics_precision', 0.92) + self.metrics_recall = kwargs.get('metrics_recall', 0.88) + self.document_count = kwargs.get('document_count', 100) + self.training_config = kwargs.get('training_config', {}) + self.file_size = kwargs.get('file_size', 
52428800) + self.trained_at = kwargs.get('trained_at', datetime.utcnow()) + self.activated_at = kwargs.get('activated_at', None) + self.created_at = kwargs.get('created_at', datetime.utcnow()) + self.updated_at = kwargs.get('updated_at', datetime.utcnow()) + + class MockAdminDB: """Mock AdminDB for testing Phase 4.""" @@ -111,6 +136,7 @@ class MockAdminDB: self.annotations = {} self.training_tasks = {} self.training_links = {} + self.model_versions = {} def get_documents_for_training( self, @@ -174,6 +200,14 @@ class MockAdminDB: """Get training task by ID.""" return self.training_tasks.get(str(task_id)) + def get_model_versions(self, status=None, limit=20, offset=0): + """Get model versions with optional filtering.""" + models = list(self.model_versions.values()) + if status: + models = [m for m in models if m.status == status] + total = len(models) + return models[offset:offset+limit], total + @pytest.fixture def app(): @@ -241,6 +275,30 @@ def app(): ) mock_db.training_links[str(doc1.document_id)] = [link1] + # Add model versions + model1 = MockModelVersion( + version="1.0.0", + name="Model v1.0.0", + status="inactive", + is_active=False, + metrics_mAP=0.935, + metrics_precision=0.92, + metrics_recall=0.88, + document_count=500, + ) + model2 = MockModelVersion( + version="1.1.0", + name="Model v1.1.0", + status="active", + is_active=True, + metrics_mAP=0.951, + metrics_precision=0.94, + metrics_recall=0.92, + document_count=600, + ) + mock_db.model_versions[str(model1.version_id)] = model1 + mock_db.model_versions[str(model2.version_id)] = model2 + # Override dependencies app.dependency_overrides[validate_admin_token] = lambda: "test-token" app.dependency_overrides[get_admin_db] = lambda: mock_db @@ -324,10 +382,10 @@ class TestTrainingDocuments: class TestTrainingModels: - """Tests for GET /admin/training/models endpoint.""" + """Tests for GET /admin/training/models endpoint (ModelVersionListResponse).""" def test_get_training_models_success(self, client): - """Test getting trained models list.""" + """Test getting model versions list.""" response = client.get("/admin/training/models") assert response.status_code == 200 @@ -338,43 +396,44 @@ class TestTrainingModels: assert len(data["models"]) == 2 def test_get_training_models_includes_metrics(self, client): - """Test that models include metrics.""" + """Test that model versions include metrics.""" response = client.get("/admin/training/models") assert response.status_code == 200 data = response.json() - # Check first model has metrics + # Check first model has metrics fields model = data["models"][0] - assert "metrics" in model - assert "mAP" in model["metrics"] - assert model["metrics"]["mAP"] is not None - assert "precision" in model["metrics"] - assert "recall" in model["metrics"] + assert "metrics_mAP" in model + assert model["metrics_mAP"] is not None - def test_get_training_models_includes_download_url(self, client): - """Test that completed models have download URLs.""" + def test_get_training_models_includes_version_fields(self, client): + """Test that model versions include version fields.""" response = client.get("/admin/training/models") assert response.status_code == 200 data = response.json() - # Check completed models have download URLs - for model in data["models"]: - if model["status"] == "completed": - assert "download_url" in model - assert model["download_url"] is not None + # Check model has expected fields + model = data["models"][0] + assert "version_id" in model + assert "version" in model + assert "name" in 
model + assert "status" in model + assert "is_active" in model + assert "document_count" in model def test_get_training_models_filter_by_status(self, client): - """Test filtering models by status.""" - response = client.get("/admin/training/models?status=completed") + """Test filtering model versions by status.""" + response = client.get("/admin/training/models?status=active") assert response.status_code == 200 data = response.json() - # All returned models should be completed + assert data["total"] == 1 + # All returned models should be active for model in data["models"]: - assert model["status"] == "completed" + assert model["status"] == "active" def test_get_training_models_pagination(self, client): - """Test pagination for models.""" + """Test pagination for model versions.""" response = client.get("/admin/training/models?limit=1&offset=0") assert response.status_code == 200