Compare commits
10 Commits
8fd61ea928
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a564ac9d70 | ||
|
|
4126196dea | ||
|
|
a516de4320 | ||
|
|
33ada0350d | ||
|
|
d2489a97d4 | ||
|
|
d6550375b0 | ||
|
|
58bf75db68 | ||
|
|
e83a0cae36 | ||
|
|
d5101e3604 | ||
|
|
e599424a92 |
@@ -1,263 +1,143 @@
|
||||
[角色]
|
||||
你是废才,一位资深产品经理兼全栈开发教练。
|
||||
# Invoice Master POC v2
|
||||
|
||||
你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。
|
||||
你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。
|
||||
Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
|
||||
|
||||
你负责引导用户完成产品开发的完整旅程:从脑子里的模糊想法,到可运行的产品。
|
||||
## Tech Stack
|
||||
|
||||
[任务]
|
||||
引导用户完成产品开发的完整流程:
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Object Detection | YOLOv11 (Ultralytics) |
|
||||
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
|
||||
| PDF Processing | PyMuPDF (fitz) |
|
||||
| Database | PostgreSQL + psycopg2 |
|
||||
| Web Framework | FastAPI + Uvicorn |
|
||||
| Deep Learning | PyTorch + CUDA 12.x |
|
||||
|
||||
1. **需求收集** → 调用 product-spec-builder,生成 Product-Spec.md
|
||||
2. **原型设计** → 调用 ui-prompt-generator,生成 UI-Prompts.md(可选)
|
||||
3. **项目开发** → 调用 dev-builder,实现项目代码
|
||||
4. **本地运行** → 启动项目,输出使用指南
|
||||
## WSL Environment (REQUIRED)
|
||||
|
||||
[文件结构]
|
||||
project/
|
||||
├── Product-Spec.md # 产品需求文档
|
||||
├── Product-Spec-CHANGELOG.md # 需求变更记录
|
||||
├── UI-Prompts.md # 原型图提示词(可选)
|
||||
├── [项目源代码]/ # 代码文件
|
||||
└── .claude/
|
||||
├── CLAUDE.md # 主控(本文件)
|
||||
└── skills/
|
||||
├── product-spec-builder/ # 需求收集
|
||||
├── ui-prompt-generator/ # 原型图提示词
|
||||
└── dev-builder/ # 项目开发
|
||||
**Prefix ALL commands with:**
|
||||
|
||||
[总体规则]
|
||||
- 严格按照 需求收集 → 原型设计(可选)→ 项目开发 → 本地运行 的流程引导
|
||||
- **任何功能变更、UI 修改、需求调整,都必须先更新 Product Spec,再实现代码**
|
||||
- 无论用户如何打断或提出新问题,完成当前回答后始终引导用户进入下一步
|
||||
- 始终使用**中文**进行交流
|
||||
|
||||
[运行环境要求]
|
||||
**强制要求**:所有程序运行、命令执行必须在 WSL 环境中进行
|
||||
|
||||
- **WSL**:所有 bash 命令必须通过 `wsl` 前缀执行
|
||||
- **Conda 环境**:必须使用 `invoice-py311` 环境
|
||||
|
||||
命令执行格式:
|
||||
```bash
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <你的命令>"
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <command>"
|
||||
```
|
||||
|
||||
示例:
|
||||
**NEVER run Python commands directly in Windows PowerShell/CMD.**
|
||||
|
||||
## Project-Specific Rules
|
||||
|
||||
- Python 3.11+ with type hints
|
||||
- No print() in production - use logging
|
||||
- Run tests: `pytest --cov=src`
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── cli/ # autolabel, train, infer, serve
|
||||
├── pdf/ # extractor, renderer, detector
|
||||
├── ocr/ # PaddleOCR wrapper, machine_code_parser
|
||||
├── inference/ # pipeline, yolo_detector, field_extractor
|
||||
├── normalize/ # Per-field normalizers
|
||||
├── matcher/ # Exact, substring, fuzzy strategies
|
||||
├── processing/ # CPU/GPU pool architecture
|
||||
├── web/ # FastAPI app, routes, services, schemas
|
||||
├── utils/ # validators, text_cleaner, fuzzy_matcher
|
||||
└── data/ # Database operations
|
||||
tests/ # Mirror of src structure
|
||||
runs/train/ # Training outputs
|
||||
```
|
||||
|
||||
## Supported Fields
|
||||
|
||||
| ID | Field | Description |
|
||||
|----|-------|-------------|
|
||||
| 0 | invoice_number | Invoice number |
|
||||
| 1 | invoice_date | Invoice date |
|
||||
| 2 | invoice_due_date | Due date |
|
||||
| 3 | ocr_number | OCR reference (Swedish payment) |
|
||||
| 4 | bankgiro | Bankgiro account |
|
||||
| 5 | plusgiro | Plusgiro account |
|
||||
| 6 | amount | Amount |
|
||||
| 7 | supplier_organisation_number | Supplier org number |
|
||||
| 8 | payment_line | Payment line (machine-readable) |
|
||||
| 9 | customer_number | Customer number |
|
||||
|
||||
## Key Patterns
|
||||
|
||||
### Inference Result
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
document_id: str
|
||||
document_type: str # "invoice" or "letter"
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
cross_validation: CrossValidationResult | None
|
||||
processing_time_ms: float
|
||||
```
|
||||
|
||||
### API Schemas
|
||||
|
||||
See `src/web/schemas.py` for request/response models.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```bash
|
||||
# 运行 Python 脚本
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python main.py"
|
||||
# Required
|
||||
DB_PASSWORD=
|
||||
|
||||
# 安装依赖
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pip install -r requirements.txt"
|
||||
|
||||
# 运行测试
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pytest"
|
||||
# Optional (with defaults)
|
||||
DB_HOST=192.168.68.31
|
||||
DB_PORT=5432
|
||||
DB_NAME=docmaster
|
||||
DB_USER=docmaster
|
||||
MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
||||
CONFIDENCE_THRESHOLD=0.5
|
||||
SERVER_HOST=0.0.0.0
|
||||
SERVER_PORT=8000
|
||||
```
|
||||
|
||||
**注意**:
|
||||
- 不要直接在 Windows PowerShell/CMD 中运行 Python 命令
|
||||
- 每次执行命令都需要激活 conda 环境(因为是非交互式 shell)
|
||||
- 路径需要转换为 WSL 格式(如 `/mnt/c/Users/...`)
|
||||
## CLI Commands
|
||||
|
||||
[Skill 调用规则]
|
||||
[product-spec-builder]
|
||||
**自动调用**:
|
||||
- 用户表达想要开发产品、应用、工具时
|
||||
- 用户描述产品想法、功能需求时
|
||||
- 用户要修改 UI、改界面、调整布局时(迭代模式)
|
||||
- 用户要增加功能、新增功能时(迭代模式)
|
||||
- 用户要改需求、调整功能、修改逻辑时(迭代模式)
|
||||
```bash
|
||||
# Auto-labeling
|
||||
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
|
||||
|
||||
**手动调用**:/prd
|
||||
# Training
|
||||
python -m src.cli.train --model yolo11n.pt --epochs 100 --batch 16 --name invoice_fields
|
||||
|
||||
[ui-prompt-generator]
|
||||
**手动调用**:/ui
|
||||
# Inference
|
||||
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt --input invoice.pdf --gpu
|
||||
|
||||
前置条件:Product-Spec.md 必须存在
|
||||
|
||||
[dev-builder]
|
||||
**手动调用**:/dev
|
||||
|
||||
前置条件:Product-Spec.md 必须存在
|
||||
|
||||
[项目状态检测与路由]
|
||||
初始化时自动检测项目进度,路由到对应阶段:
|
||||
|
||||
检测逻辑:
|
||||
- 无 Product-Spec.md → 全新项目 → 引导用户描述想法或输入 /prd
|
||||
- 有 Product-Spec.md,无代码 → Spec 已完成 → 输出交付指南
|
||||
- 有 Product-Spec.md,有代码 → 项目已创建 → 可执行 /check 或 /run
|
||||
|
||||
显示格式:
|
||||
"📊 **项目进度检测**
|
||||
|
||||
- Product Spec:[已完成/未完成]
|
||||
- 原型图提示词:[已生成/未生成]
|
||||
- 项目代码:[已创建/未创建]
|
||||
|
||||
**当前阶段**:[阶段名称]
|
||||
**下一步**:[具体指令或操作]"
|
||||
|
||||
[工作流程]
|
||||
[需求收集阶段]
|
||||
触发:用户表达产品想法(自动)或输入 /prd(手动)
|
||||
|
||||
执行:调用 product-spec-builder skill
|
||||
|
||||
完成后:输出交付指南,引导下一步
|
||||
|
||||
[交付阶段]
|
||||
触发:Product Spec 生成完成后自动执行
|
||||
|
||||
输出:
|
||||
"✅ **Product Spec 已生成!**
|
||||
|
||||
文件:Product-Spec.md
|
||||
|
||||
---
|
||||
|
||||
## 📘 接下来
|
||||
|
||||
- 输入 /ui 生成原型图提示词(可选)
|
||||
- 输入 /dev 开始开发项目
|
||||
- 直接对话可以改 UI、加功能"
|
||||
|
||||
[原型图阶段]
|
||||
触发:用户输入 /ui
|
||||
|
||||
执行:调用 ui-prompt-generator skill
|
||||
|
||||
完成后:
|
||||
"✅ **原型图提示词已生成!**
|
||||
|
||||
文件:UI-Prompts.md
|
||||
|
||||
把提示词发给 AI 绘图工具生成原型图,然后输入 /dev 开始开发。"
|
||||
|
||||
[项目开发阶段]
|
||||
触发:用户输入 /dev
|
||||
|
||||
第一步:询问原型图
|
||||
询问用户:"有原型图或设计稿吗?有的话发给我参考。"
|
||||
用户发送图片 → 记录,开发时参考
|
||||
用户说没有 → 继续
|
||||
|
||||
第二步:执行开发
|
||||
调用 dev-builder skill
|
||||
|
||||
完成后:引导用户执行 /run
|
||||
|
||||
[代码检查阶段]
|
||||
触发:用户输入 /check
|
||||
|
||||
执行:
|
||||
第一步:读取 Product Spec 文档
|
||||
加载 Product-Spec.md 文件
|
||||
解析功能需求、UI 布局
|
||||
|
||||
第二步:扫描项目代码
|
||||
遍历项目目录下的代码文件
|
||||
识别已实现的功能、组件
|
||||
|
||||
第三步:功能完整度检查
|
||||
- 功能需求:Product Spec 功能需求 vs 代码实现
|
||||
- UI 布局:Product Spec 布局描述 vs 界面代码
|
||||
|
||||
第四步:输出检查报告
|
||||
|
||||
输出:
|
||||
"📋 **项目完整度检查报告**
|
||||
|
||||
**对照文档**:Product-Spec.md
|
||||
|
||||
---
|
||||
|
||||
✅ **已完成(X项)**
|
||||
- [功能名称]:[实现位置]
|
||||
|
||||
⚠️ **部分完成(X项)**
|
||||
- [功能名称]:[缺失内容]
|
||||
|
||||
❌ **缺失(X项)**
|
||||
- [功能名称]:未实现
|
||||
|
||||
---
|
||||
|
||||
💡 **改进建议**
|
||||
1. [具体建议]
|
||||
2. [具体建议]
|
||||
|
||||
---
|
||||
|
||||
需要我帮你补充这些功能吗?或输入 /run 先跑起来看看。"
|
||||
|
||||
[本地运行阶段]
|
||||
触发:用户输入 /run
|
||||
|
||||
执行:自动检测项目类型,安装依赖,启动项目
|
||||
|
||||
输出:
|
||||
"🚀 **项目已启动!**
|
||||
|
||||
**访问地址**:http://localhost:[端口号]
|
||||
|
||||
---
|
||||
|
||||
## 📖 使用指南
|
||||
|
||||
[根据 Product Spec 生成简要使用说明]
|
||||
|
||||
---
|
||||
|
||||
💡 **提示**:
|
||||
- /stop 停止服务
|
||||
- /check 检查完整度
|
||||
- /prd 修改需求"
|
||||
|
||||
[内容修订]
|
||||
当用户提出修改意见时:
|
||||
|
||||
**流程**:先更新文档 → 再实现代码
|
||||
|
||||
1. 调用 product-spec-builder(迭代模式)
|
||||
- 通过追问明确变更内容
|
||||
- 更新 Product-Spec.md
|
||||
- 更新 Product-Spec-CHANGELOG.md
|
||||
2. 调用 dev-builder 实现代码变更
|
||||
3. 建议用户执行 /check 验证
|
||||
|
||||
[指令集]
|
||||
/prd - 需求收集,生成 Product Spec
|
||||
/ui - 生成原型图提示词
|
||||
/dev - 开发项目代码
|
||||
/check - 对照 Spec 检查代码完整度
|
||||
/run - 本地运行项目
|
||||
/stop - 停止运行中的服务
|
||||
/status - 显示项目进度
|
||||
/help - 显示所有指令
|
||||
|
||||
[初始化]
|
||||
以下ASCII艺术应该显示"FEICAI"字样。如果您看到乱码或显示异常,请帮忙纠正,使用ASCII艺术生成显示"FEICAI"
|
||||
```
|
||||
"███████╗███████╗██╗ ██████╗ █████╗ ██╗
|
||||
██╔════╝██╔════╝██║██╔════╝██╔══██╗██║
|
||||
█████╗ █████╗ ██║██║ ███████║██║
|
||||
██╔══╝ ██╔══╝ ██║██║ ██╔══██║██║
|
||||
██║ ███████╗██║╚██████╗██║ ██║██║
|
||||
╚═╝ ╚══════╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝"
|
||||
# Web Server
|
||||
python run_server.py --port 8000
|
||||
```
|
||||
|
||||
"👋 我是废才,产品经理兼开发教练。
|
||||
## API Endpoints
|
||||
|
||||
我不聊理想,只聊产品。你负责想,我负责问到你想清楚。
|
||||
从需求文档到本地运行,全程我带着走。
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| GET | `/` | Web UI |
|
||||
| GET | `/api/v1/health` | Health check |
|
||||
| POST | `/api/v1/infer` | Process invoice |
|
||||
| GET | `/api/v1/results/{filename}` | Get visualization |
|
||||
|
||||
过程中我会问很多问题,有些可能让你不舒服。不过放心,我只是想让你的产品能落地,仅此而已。
|
||||
## Current Status
|
||||
|
||||
💡 输入 /help 查看所有指令
|
||||
- **Tests**: 688 passing
|
||||
- **Coverage**: 37%
|
||||
- **Model**: 93.5% mAP@0.5
|
||||
- **Documents Labeled**: 9,738
|
||||
|
||||
现在,说说你想做什么?"
|
||||
## Quick Start
|
||||
|
||||
执行 [项目状态检测与路由]
|
||||
```bash
|
||||
# Start server
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py"
|
||||
|
||||
# Run tests
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"
|
||||
|
||||
# Access UI: http://localhost:8000
|
||||
```
|
||||
22
.claude/commands/build-fix.md
Normal file
22
.claude/commands/build-fix.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Build and Fix
|
||||
|
||||
Incrementally fix Python errors and test failures.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Run check: `mypy src/ --ignore-missing-imports` or `pytest -x --tb=short`
|
||||
2. Parse errors, group by file, sort by severity (ImportError > TypeError > other)
|
||||
3. For each error:
|
||||
- Show context (5 lines)
|
||||
- Explain and propose fix
|
||||
- Apply fix
|
||||
- Re-run test for that file
|
||||
- Verify resolved
|
||||
4. Stop if: fix introduces new errors, same error after 3 attempts, or user pauses
|
||||
5. Show summary: fixed / remaining / new errors
|
||||
|
||||
## Rules
|
||||
|
||||
- Fix ONE error at a time
|
||||
- Re-run tests after each fix
|
||||
- Never batch multiple unrelated fixes
|
||||
74
.claude/commands/checkpoint.md
Normal file
74
.claude/commands/checkpoint.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Checkpoint Command
|
||||
|
||||
Create or verify a checkpoint in your workflow.
|
||||
|
||||
## Usage
|
||||
|
||||
`/checkpoint [create|verify|list] [name]`
|
||||
|
||||
## Create Checkpoint
|
||||
|
||||
When creating a checkpoint:
|
||||
|
||||
1. Run `/verify quick` to ensure current state is clean
|
||||
2. Create a git stash or commit with checkpoint name
|
||||
3. Log checkpoint to `.claude/checkpoints.log`:
|
||||
|
||||
```bash
|
||||
echo "$(date +%Y-%m-%d-%H:%M) | $CHECKPOINT_NAME | $(git rev-parse --short HEAD)" >> .claude/checkpoints.log
|
||||
```
|
||||
|
||||
4. Report checkpoint created
|
||||
|
||||
## Verify Checkpoint
|
||||
|
||||
When verifying against a checkpoint:
|
||||
|
||||
1. Read checkpoint from log
|
||||
2. Compare current state to checkpoint:
|
||||
- Files added since checkpoint
|
||||
- Files modified since checkpoint
|
||||
- Test pass rate now vs then
|
||||
- Coverage now vs then
|
||||
|
||||
3. Report:
|
||||
```
|
||||
CHECKPOINT COMPARISON: $NAME
|
||||
============================
|
||||
Files changed: X
|
||||
Tests: +Y passed / -Z failed
|
||||
Coverage: +X% / -Y%
|
||||
Build: [PASS/FAIL]
|
||||
```
|
||||
|
||||
## List Checkpoints
|
||||
|
||||
Show all checkpoints with:
|
||||
- Name
|
||||
- Timestamp
|
||||
- Git SHA
|
||||
- Status (current, behind, ahead)
|
||||
|
||||
## Workflow
|
||||
|
||||
Typical checkpoint flow:
|
||||
|
||||
```
|
||||
[Start] --> /checkpoint create "feature-start"
|
||||
|
|
||||
[Implement] --> /checkpoint create "core-done"
|
||||
|
|
||||
[Test] --> /checkpoint verify "core-done"
|
||||
|
|
||||
[Refactor] --> /checkpoint create "refactor-done"
|
||||
|
|
||||
[PR] --> /checkpoint verify "feature-start"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
$ARGUMENTS:
|
||||
- `create <name>` - Create named checkpoint
|
||||
- `verify <name>` - Verify against named checkpoint
|
||||
- `list` - Show all checkpoints
|
||||
- `clear` - Remove old checkpoints (keeps last 5)
|
||||
46
.claude/commands/code-review.md
Normal file
46
.claude/commands/code-review.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Code Review
|
||||
|
||||
Security and quality review of uncommitted changes.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Get changed files: `git diff --name-only HEAD` and `git diff --staged --name-only`
|
||||
2. Review each file for issues (see checklist below)
|
||||
3. Run automated checks: `mypy src/`, `ruff check src/`, `pytest -x`
|
||||
4. Generate report with severity, location, description, suggested fix
|
||||
5. Block commit if CRITICAL or HIGH issues found
|
||||
|
||||
## Checklist
|
||||
|
||||
### CRITICAL (Block)
|
||||
|
||||
- Hardcoded credentials, API keys, tokens, passwords
|
||||
- SQL injection (must use parameterized queries)
|
||||
- Path traversal risks
|
||||
- Missing input validation on API endpoints
|
||||
- Missing authentication/authorization
|
||||
|
||||
### HIGH (Block)
|
||||
|
||||
- Functions > 50 lines, files > 800 lines
|
||||
- Nesting depth > 4 levels
|
||||
- Missing error handling or bare `except:`
|
||||
- `print()` in production code (use logging)
|
||||
- Mutable default arguments
|
||||
|
||||
### MEDIUM (Warn)
|
||||
|
||||
- Missing type hints on public functions
|
||||
- Missing tests for new code
|
||||
- Duplicate code, magic numbers
|
||||
- Unused imports/variables
|
||||
- TODO/FIXME comments
|
||||
|
||||
## Report Format
|
||||
|
||||
```
|
||||
[SEVERITY] file:line - Issue description
|
||||
Suggested fix: ...
|
||||
```
|
||||
|
||||
## Never Approve Code With Security Vulnerabilities!
|
||||
40
.claude/commands/e2e.md
Normal file
40
.claude/commands/e2e.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# E2E Testing
|
||||
|
||||
End-to-end testing for the Invoice Field Extraction API.
|
||||
|
||||
## When to Use
|
||||
|
||||
- Testing complete inference pipeline (PDF -> Fields)
|
||||
- Verifying API endpoints work end-to-end
|
||||
- Validating YOLO + OCR + field extraction integration
|
||||
- Pre-deployment verification
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Ensure server is running: `python run_server.py`
|
||||
2. Run health check: `curl http://localhost:8000/api/v1/health`
|
||||
3. Run E2E tests: `pytest tests/e2e/ -v`
|
||||
4. Verify results and capture any failures
|
||||
|
||||
## Critical Scenarios (Must Pass)
|
||||
|
||||
1. Health check returns `{"status": "healthy", "model_loaded": true}`
|
||||
2. PDF upload returns valid response with fields
|
||||
3. Fields extracted with confidence scores
|
||||
4. Visualization image generated
|
||||
5. Cross-validation included for invoices with payment_line
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Server running on http://localhost:8000
|
||||
- [ ] Health check passes
|
||||
- [ ] PDF inference returns valid JSON
|
||||
- [ ] At least one field extracted
|
||||
- [ ] Visualization URL returns image
|
||||
- [ ] Response time < 10 seconds
|
||||
- [ ] No server errors in logs
|
||||
|
||||
## Test Location
|
||||
|
||||
E2E tests: `tests/e2e/`
|
||||
Sample fixtures: `tests/fixtures/`
|
||||
174
.claude/commands/eval.md
Normal file
174
.claude/commands/eval.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# Eval Command
|
||||
|
||||
Evaluate model performance and field extraction accuracy.
|
||||
|
||||
## Usage
|
||||
|
||||
`/eval [model|accuracy|compare|report]`
|
||||
|
||||
## Model Evaluation
|
||||
|
||||
`/eval model`
|
||||
|
||||
Evaluate YOLO model performance on test dataset:
|
||||
|
||||
```bash
|
||||
# Run model evaluation
|
||||
python -m src.cli.train --model runs/train/invoice_fields/weights/best.pt --eval-only
|
||||
|
||||
# Or use ultralytics directly
|
||||
yolo val model=runs/train/invoice_fields/weights/best.pt data=data.yaml
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
Model Evaluation: invoice_fields/best.pt
|
||||
========================================
|
||||
mAP@0.5: 93.5%
|
||||
mAP@0.5-0.95: 83.0%
|
||||
|
||||
Per-class AP:
|
||||
- invoice_number: 95.2%
|
||||
- invoice_date: 94.8%
|
||||
- invoice_due_date: 93.1%
|
||||
- ocr_number: 91.5%
|
||||
- bankgiro: 92.3%
|
||||
- plusgiro: 90.8%
|
||||
- amount: 88.7%
|
||||
- supplier_org_num: 85.2%
|
||||
- payment_line: 82.4%
|
||||
- customer_number: 81.1%
|
||||
```
|
||||
|
||||
## Accuracy Evaluation
|
||||
|
||||
`/eval accuracy`
|
||||
|
||||
Evaluate field extraction accuracy against ground truth:
|
||||
|
||||
```bash
|
||||
# Run accuracy evaluation on labeled data
|
||||
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt \
|
||||
--input ~/invoice-data/test/*.pdf \
|
||||
--ground-truth ~/invoice-data/test/labels.csv \
|
||||
--output eval_results.json
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
Field Extraction Accuracy
|
||||
=========================
|
||||
Documents tested: 500
|
||||
|
||||
Per-field accuracy:
|
||||
- InvoiceNumber: 98.9% (494/500)
|
||||
- InvoiceDate: 95.5% (478/500)
|
||||
- InvoiceDueDate: 95.9% (480/500)
|
||||
- OCR: 99.1% (496/500)
|
||||
- Bankgiro: 99.0% (495/500)
|
||||
- Plusgiro: 99.4% (497/500)
|
||||
- Amount: 91.3% (457/500)
|
||||
- supplier_org: 78.2% (391/500)
|
||||
|
||||
Overall: 94.8%
|
||||
```
|
||||
|
||||
## Compare Models
|
||||
|
||||
`/eval compare`
|
||||
|
||||
Compare two model versions:
|
||||
|
||||
```bash
|
||||
# Compare old vs new model
|
||||
python -m src.cli.eval compare \
|
||||
--model-a runs/train/invoice_v1/weights/best.pt \
|
||||
--model-b runs/train/invoice_v2/weights/best.pt \
|
||||
--test-data ~/invoice-data/test/
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
Model Comparison
|
||||
================
|
||||
Model A Model B Delta
|
||||
mAP@0.5: 91.2% 93.5% +2.3%
|
||||
Accuracy: 92.1% 94.8% +2.7%
|
||||
Speed (ms): 1850 1520 -330
|
||||
|
||||
Per-field improvements:
|
||||
- amount: +4.2%
|
||||
- payment_line: +3.8%
|
||||
- customer_num: +2.1%
|
||||
|
||||
Recommendation: Deploy Model B
|
||||
```
|
||||
|
||||
## Generate Report
|
||||
|
||||
`/eval report`
|
||||
|
||||
Generate comprehensive evaluation report:
|
||||
|
||||
```bash
|
||||
python -m src.cli.eval report --output eval_report.md
|
||||
```
|
||||
|
||||
Output:
|
||||
```markdown
|
||||
# Evaluation Report
|
||||
Generated: 2026-01-25
|
||||
|
||||
## Model Performance
|
||||
- Model: runs/train/invoice_fields/weights/best.pt
|
||||
- mAP@0.5: 93.5%
|
||||
- Training samples: 9,738
|
||||
|
||||
## Field Extraction Accuracy
|
||||
| Field | Accuracy | Errors |
|
||||
|-------|----------|--------|
|
||||
| InvoiceNumber | 98.9% | 6 |
|
||||
| Amount | 91.3% | 43 |
|
||||
...
|
||||
|
||||
## Error Analysis
|
||||
### Common Errors
|
||||
1. Amount: OCR misreads comma as period
|
||||
2. supplier_org: Missing from some invoices
|
||||
3. payment_line: Partially obscured by stamps
|
||||
|
||||
## Recommendations
|
||||
1. Add more training data for low-accuracy fields
|
||||
2. Implement OCR error correction for amounts
|
||||
3. Consider confidence threshold tuning
|
||||
```
|
||||
|
||||
## Quick Commands
|
||||
|
||||
```bash
|
||||
# Evaluate model metrics
|
||||
yolo val model=runs/train/invoice_fields/weights/best.pt
|
||||
|
||||
# Test inference on sample
|
||||
python -m src.cli.infer --input sample.pdf --output result.json --gpu
|
||||
|
||||
# Check test coverage
|
||||
pytest --cov=src --cov-report=html
|
||||
```
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
| Metric | Target | Current |
|
||||
|--------|--------|---------|
|
||||
| mAP@0.5 | >90% | 93.5% |
|
||||
| Overall Accuracy | >90% | 94.8% |
|
||||
| Test Coverage | >60% | 37% |
|
||||
| Tests Passing | 100% | 100% |
|
||||
|
||||
## When to Evaluate
|
||||
|
||||
- After training a new model
|
||||
- Before deploying to production
|
||||
- After adding new training data
|
||||
- When accuracy complaints arise
|
||||
- Weekly performance monitoring
|
||||
70
.claude/commands/learn.md
Normal file
70
.claude/commands/learn.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# /learn - Extract Reusable Patterns
|
||||
|
||||
Analyze the current session and extract any patterns worth saving as skills.
|
||||
|
||||
## Trigger
|
||||
|
||||
Run `/learn` at any point during a session when you've solved a non-trivial problem.
|
||||
|
||||
## What to Extract
|
||||
|
||||
Look for:
|
||||
|
||||
1. **Error Resolution Patterns**
|
||||
- What error occurred?
|
||||
- What was the root cause?
|
||||
- What fixed it?
|
||||
- Is this reusable for similar errors?
|
||||
|
||||
2. **Debugging Techniques**
|
||||
- Non-obvious debugging steps
|
||||
- Tool combinations that worked
|
||||
- Diagnostic patterns
|
||||
|
||||
3. **Workarounds**
|
||||
- Library quirks
|
||||
- API limitations
|
||||
- Version-specific fixes
|
||||
|
||||
4. **Project-Specific Patterns**
|
||||
- Codebase conventions discovered
|
||||
- Architecture decisions made
|
||||
- Integration patterns
|
||||
|
||||
## Output Format
|
||||
|
||||
Create a skill file at `~/.claude/skills/learned/[pattern-name].md`:
|
||||
|
||||
```markdown
|
||||
# [Descriptive Pattern Name]
|
||||
|
||||
**Extracted:** [Date]
|
||||
**Context:** [Brief description of when this applies]
|
||||
|
||||
## Problem
|
||||
[What problem this solves - be specific]
|
||||
|
||||
## Solution
|
||||
[The pattern/technique/workaround]
|
||||
|
||||
## Example
|
||||
[Code example if applicable]
|
||||
|
||||
## When to Use
|
||||
[Trigger conditions - what should activate this skill]
|
||||
```
|
||||
|
||||
## Process
|
||||
|
||||
1. Review the session for extractable patterns
|
||||
2. Identify the most valuable/reusable insight
|
||||
3. Draft the skill file
|
||||
4. Ask user to confirm before saving
|
||||
5. Save to `~/.claude/skills/learned/`
|
||||
|
||||
## Notes
|
||||
|
||||
- Don't extract trivial fixes (typos, simple syntax errors)
|
||||
- Don't extract one-time issues (specific API outages, etc.)
|
||||
- Focus on patterns that will save time in future sessions
|
||||
- Keep skills focused - one pattern per skill
|
||||
172
.claude/commands/orchestrate.md
Normal file
172
.claude/commands/orchestrate.md
Normal file
@@ -0,0 +1,172 @@
|
||||
# Orchestrate Command
|
||||
|
||||
Sequential agent workflow for complex tasks.
|
||||
|
||||
## Usage
|
||||
|
||||
`/orchestrate [workflow-type] [task-description]`
|
||||
|
||||
## Workflow Types
|
||||
|
||||
### feature
|
||||
Full feature implementation workflow:
|
||||
```
|
||||
planner -> tdd-guide -> code-reviewer -> security-reviewer
|
||||
```
|
||||
|
||||
### bugfix
|
||||
Bug investigation and fix workflow:
|
||||
```
|
||||
explorer -> tdd-guide -> code-reviewer
|
||||
```
|
||||
|
||||
### refactor
|
||||
Safe refactoring workflow:
|
||||
```
|
||||
architect -> code-reviewer -> tdd-guide
|
||||
```
|
||||
|
||||
### security
|
||||
Security-focused review:
|
||||
```
|
||||
security-reviewer -> code-reviewer -> architect
|
||||
```
|
||||
|
||||
## Execution Pattern
|
||||
|
||||
For each agent in the workflow:
|
||||
|
||||
1. **Invoke agent** with context from previous agent
|
||||
2. **Collect output** as structured handoff document
|
||||
3. **Pass to next agent** in chain
|
||||
4. **Aggregate results** into final report
|
||||
|
||||
## Handoff Document Format
|
||||
|
||||
Between agents, create handoff document:
|
||||
|
||||
```markdown
|
||||
## HANDOFF: [previous-agent] -> [next-agent]
|
||||
|
||||
### Context
|
||||
[Summary of what was done]
|
||||
|
||||
### Findings
|
||||
[Key discoveries or decisions]
|
||||
|
||||
### Files Modified
|
||||
[List of files touched]
|
||||
|
||||
### Open Questions
|
||||
[Unresolved items for next agent]
|
||||
|
||||
### Recommendations
|
||||
[Suggested next steps]
|
||||
```
|
||||
|
||||
## Example: Feature Workflow
|
||||
|
||||
```
|
||||
/orchestrate feature "Add user authentication"
|
||||
```
|
||||
|
||||
Executes:
|
||||
|
||||
1. **Planner Agent**
|
||||
- Analyzes requirements
|
||||
- Creates implementation plan
|
||||
- Identifies dependencies
|
||||
- Output: `HANDOFF: planner -> tdd-guide`
|
||||
|
||||
2. **TDD Guide Agent**
|
||||
- Reads planner handoff
|
||||
- Writes tests first
|
||||
- Implements to pass tests
|
||||
- Output: `HANDOFF: tdd-guide -> code-reviewer`
|
||||
|
||||
3. **Code Reviewer Agent**
|
||||
- Reviews implementation
|
||||
- Checks for issues
|
||||
- Suggests improvements
|
||||
- Output: `HANDOFF: code-reviewer -> security-reviewer`
|
||||
|
||||
4. **Security Reviewer Agent**
|
||||
- Security audit
|
||||
- Vulnerability check
|
||||
- Final approval
|
||||
- Output: Final Report
|
||||
|
||||
## Final Report Format
|
||||
|
||||
```
|
||||
ORCHESTRATION REPORT
|
||||
====================
|
||||
Workflow: feature
|
||||
Task: Add user authentication
|
||||
Agents: planner -> tdd-guide -> code-reviewer -> security-reviewer
|
||||
|
||||
SUMMARY
|
||||
-------
|
||||
[One paragraph summary]
|
||||
|
||||
AGENT OUTPUTS
|
||||
-------------
|
||||
Planner: [summary]
|
||||
TDD Guide: [summary]
|
||||
Code Reviewer: [summary]
|
||||
Security Reviewer: [summary]
|
||||
|
||||
FILES CHANGED
|
||||
-------------
|
||||
[List all files modified]
|
||||
|
||||
TEST RESULTS
|
||||
------------
|
||||
[Test pass/fail summary]
|
||||
|
||||
SECURITY STATUS
|
||||
---------------
|
||||
[Security findings]
|
||||
|
||||
RECOMMENDATION
|
||||
--------------
|
||||
[SHIP / NEEDS WORK / BLOCKED]
|
||||
```
|
||||
|
||||
## Parallel Execution
|
||||
|
||||
For independent checks, run agents in parallel:
|
||||
|
||||
```markdown
|
||||
### Parallel Phase
|
||||
Run simultaneously:
|
||||
- code-reviewer (quality)
|
||||
- security-reviewer (security)
|
||||
- architect (design)
|
||||
|
||||
### Merge Results
|
||||
Combine outputs into single report
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
$ARGUMENTS:
|
||||
- `feature <description>` - Full feature workflow
|
||||
- `bugfix <description>` - Bug fix workflow
|
||||
- `refactor <description>` - Refactoring workflow
|
||||
- `security <description>` - Security review workflow
|
||||
- `custom <agents> <description>` - Custom agent sequence
|
||||
|
||||
## Custom Workflow Example
|
||||
|
||||
```
|
||||
/orchestrate custom "architect,tdd-guide,code-reviewer" "Redesign caching layer"
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Start with planner** for complex features
|
||||
2. **Always include code-reviewer** before merge
|
||||
3. **Use security-reviewer** for auth/payment/PII
|
||||
4. **Keep handoffs concise** - focus on what next agent needs
|
||||
5. **Run verification** between agents if needed
|
||||
113
.claude/commands/plan.md
Normal file
113
.claude/commands/plan.md
Normal file
@@ -0,0 +1,113 @@
|
||||
---
|
||||
description: Restate requirements, assess risks, and create step-by-step implementation plan. WAIT for user CONFIRM before touching any code.
|
||||
---
|
||||
|
||||
# Plan Command
|
||||
|
||||
This command invokes the **planner** agent to create a comprehensive implementation plan before writing any code.
|
||||
|
||||
## What This Command Does
|
||||
|
||||
1. **Restate Requirements** - Clarify what needs to be built
|
||||
2. **Identify Risks** - Surface potential issues and blockers
|
||||
3. **Create Step Plan** - Break down implementation into phases
|
||||
4. **Wait for Confirmation** - MUST receive user approval before proceeding
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `/plan` when:
|
||||
- Starting a new feature
|
||||
- Making significant architectural changes
|
||||
- Working on complex refactoring
|
||||
- Multiple files/components will be affected
|
||||
- Requirements are unclear or ambiguous
|
||||
|
||||
## How It Works
|
||||
|
||||
The planner agent will:
|
||||
|
||||
1. **Analyze the request** and restate requirements in clear terms
|
||||
2. **Break down into phases** with specific, actionable steps
|
||||
3. **Identify dependencies** between components
|
||||
4. **Assess risks** and potential blockers
|
||||
5. **Estimate complexity** (High/Medium/Low)
|
||||
6. **Present the plan** and WAIT for your explicit confirmation
|
||||
|
||||
## Example Usage
|
||||
|
||||
```
|
||||
User: /plan I need to add real-time notifications when markets resolve
|
||||
|
||||
Agent (planner):
|
||||
# Implementation Plan: Real-Time Market Resolution Notifications
|
||||
|
||||
## Requirements Restatement
|
||||
- Send notifications to users when markets they're watching resolve
|
||||
- Support multiple notification channels (in-app, email, webhook)
|
||||
- Ensure notifications are delivered reliably
|
||||
- Include market outcome and user's position result
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Database Schema
|
||||
- Add notifications table with columns: id, user_id, market_id, type, status, created_at
|
||||
- Add user_notification_preferences table for channel preferences
|
||||
- Create indexes on user_id and market_id for performance
|
||||
|
||||
### Phase 2: Notification Service
|
||||
- Create notification service in lib/notifications.ts
|
||||
- Implement notification queue using BullMQ/Redis
|
||||
- Add retry logic for failed deliveries
|
||||
- Create notification templates
|
||||
|
||||
### Phase 3: Integration Points
|
||||
- Hook into market resolution logic (when status changes to "resolved")
|
||||
- Query all users with positions in market
|
||||
- Enqueue notifications for each user
|
||||
|
||||
### Phase 4: Frontend Components
|
||||
- Create NotificationBell component in header
|
||||
- Add NotificationList modal
|
||||
- Implement real-time updates via Supabase subscriptions
|
||||
- Add notification preferences page
|
||||
|
||||
## Dependencies
|
||||
- Redis (for queue)
|
||||
- Email service (SendGrid/Resend)
|
||||
- Supabase real-time subscriptions
|
||||
|
||||
## Risks
|
||||
- HIGH: Email deliverability (SPF/DKIM required)
|
||||
- MEDIUM: Performance with 1000+ users per market
|
||||
- MEDIUM: Notification spam if markets resolve frequently
|
||||
- LOW: Real-time subscription overhead
|
||||
|
||||
## Estimated Complexity: MEDIUM
|
||||
- Backend: 4-6 hours
|
||||
- Frontend: 3-4 hours
|
||||
- Testing: 2-3 hours
|
||||
- Total: 9-13 hours
|
||||
|
||||
**WAITING FOR CONFIRMATION**: Proceed with this plan? (yes/no/modify)
|
||||
```
|
||||
|
||||
## Important Notes
|
||||
|
||||
**CRITICAL**: The planner agent will **NOT** write any code until you explicitly confirm the plan with "yes" or "proceed" or similar affirmative response.
|
||||
|
||||
If you want changes, respond with:
|
||||
- "modify: [your changes]"
|
||||
- "different approach: [alternative]"
|
||||
- "skip phase 2 and do phase 3 first"
|
||||
|
||||
## Integration with Other Commands
|
||||
|
||||
After planning:
|
||||
- Use `/tdd` to implement with test-driven development
|
||||
- Use `/build-and-fix` if build errors occur
|
||||
- Use `/code-review` to review completed implementation
|
||||
|
||||
## Related Agents
|
||||
|
||||
This command invokes the `planner` agent located at:
|
||||
`~/.claude/agents/planner.md`
|
||||
28
.claude/commands/refactor-clean.md
Normal file
28
.claude/commands/refactor-clean.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Refactor Clean
|
||||
|
||||
Safely identify and remove dead code with test verification:
|
||||
|
||||
1. Run dead code analysis tools:
|
||||
- knip: Find unused exports and files
|
||||
- depcheck: Find unused dependencies
|
||||
- ts-prune: Find unused TypeScript exports
|
||||
|
||||
2. Generate comprehensive report in .reports/dead-code-analysis.md
|
||||
|
||||
3. Categorize findings by severity:
|
||||
- SAFE: Test files, unused utilities
|
||||
- CAUTION: API routes, components
|
||||
- DANGER: Config files, main entry points
|
||||
|
||||
4. Propose safe deletions only
|
||||
|
||||
5. Before each deletion:
|
||||
- Run full test suite
|
||||
- Verify tests pass
|
||||
- Apply change
|
||||
- Re-run tests
|
||||
- Rollback if tests fail
|
||||
|
||||
6. Show summary of cleaned items
|
||||
|
||||
Never delete code without running tests first!
|
||||
80
.claude/commands/setup-pm.md
Normal file
80
.claude/commands/setup-pm.md
Normal file
@@ -0,0 +1,80 @@
|
||||
---
|
||||
description: Configure your preferred package manager (npm/pnpm/yarn/bun)
|
||||
disable-model-invocation: true
|
||||
---
|
||||
|
||||
# Package Manager Setup
|
||||
|
||||
Configure your preferred package manager for this project or globally.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Detect current package manager
|
||||
node scripts/setup-package-manager.js --detect
|
||||
|
||||
# Set global preference
|
||||
node scripts/setup-package-manager.js --global pnpm
|
||||
|
||||
# Set project preference
|
||||
node scripts/setup-package-manager.js --project bun
|
||||
|
||||
# List available package managers
|
||||
node scripts/setup-package-manager.js --list
|
||||
```
|
||||
|
||||
## Detection Priority
|
||||
|
||||
When determining which package manager to use, the following order is checked:
|
||||
|
||||
1. **Environment variable**: `CLAUDE_PACKAGE_MANAGER`
|
||||
2. **Project config**: `.claude/package-manager.json`
|
||||
3. **package.json**: `packageManager` field
|
||||
4. **Lock file**: Presence of package-lock.json, yarn.lock, pnpm-lock.yaml, or bun.lockb
|
||||
5. **Global config**: `~/.claude/package-manager.json`
|
||||
6. **Fallback**: First available package manager (pnpm > bun > yarn > npm)
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### Global Configuration
|
||||
```json
|
||||
// ~/.claude/package-manager.json
|
||||
{
|
||||
"packageManager": "pnpm"
|
||||
}
|
||||
```
|
||||
|
||||
### Project Configuration
|
||||
```json
|
||||
// .claude/package-manager.json
|
||||
{
|
||||
"packageManager": "bun"
|
||||
}
|
||||
```
|
||||
|
||||
### package.json
|
||||
```json
|
||||
{
|
||||
"packageManager": "pnpm@8.6.0"
|
||||
}
|
||||
```
|
||||
|
||||
## Environment Variable
|
||||
|
||||
Set `CLAUDE_PACKAGE_MANAGER` to override all other detection methods:
|
||||
|
||||
```bash
|
||||
# Windows (PowerShell)
|
||||
$env:CLAUDE_PACKAGE_MANAGER = "pnpm"
|
||||
|
||||
# macOS/Linux
|
||||
export CLAUDE_PACKAGE_MANAGER=pnpm
|
||||
```
|
||||
|
||||
## Run the Detection
|
||||
|
||||
To see current package manager detection results, run:
|
||||
|
||||
```bash
|
||||
node scripts/setup-package-manager.js --detect
|
||||
```
|
||||
326
.claude/commands/tdd.md
Normal file
326
.claude/commands/tdd.md
Normal file
@@ -0,0 +1,326 @@
|
||||
---
|
||||
description: Enforce test-driven development workflow. Scaffold interfaces, generate tests FIRST, then implement minimal code to pass. Ensure 80%+ coverage.
|
||||
---
|
||||
|
||||
# TDD Command
|
||||
|
||||
This command invokes the **tdd-guide** agent to enforce test-driven development methodology.
|
||||
|
||||
## What This Command Does
|
||||
|
||||
1. **Scaffold Interfaces** - Define types/interfaces first
|
||||
2. **Generate Tests First** - Write failing tests (RED)
|
||||
3. **Implement Minimal Code** - Write just enough to pass (GREEN)
|
||||
4. **Refactor** - Improve code while keeping tests green (REFACTOR)
|
||||
5. **Verify Coverage** - Ensure 80%+ test coverage
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `/tdd` when:
|
||||
- Implementing new features
|
||||
- Adding new functions/components
|
||||
- Fixing bugs (write test that reproduces bug first)
|
||||
- Refactoring existing code
|
||||
- Building critical business logic
|
||||
|
||||
## How It Works
|
||||
|
||||
The tdd-guide agent will:
|
||||
|
||||
1. **Define interfaces** for inputs/outputs
|
||||
2. **Write tests that will FAIL** (because code doesn't exist yet)
|
||||
3. **Run tests** and verify they fail for the right reason
|
||||
4. **Write minimal implementation** to make tests pass
|
||||
5. **Run tests** and verify they pass
|
||||
6. **Refactor** code while keeping tests green
|
||||
7. **Check coverage** and add more tests if below 80%
|
||||
|
||||
## TDD Cycle
|
||||
|
||||
```
|
||||
RED → GREEN → REFACTOR → REPEAT
|
||||
|
||||
RED: Write a failing test
|
||||
GREEN: Write minimal code to pass
|
||||
REFACTOR: Improve code, keep tests passing
|
||||
REPEAT: Next feature/scenario
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
```
|
||||
User: /tdd I need a function to calculate market liquidity score
|
||||
|
||||
Agent (tdd-guide):
|
||||
# TDD Session: Market Liquidity Score Calculator
|
||||
|
||||
## Step 1: Define Interface (SCAFFOLD)
|
||||
|
||||
```typescript
|
||||
// lib/liquidity.ts
|
||||
export interface MarketData {
|
||||
totalVolume: number
|
||||
bidAskSpread: number
|
||||
activeTraders: number
|
||||
lastTradeTime: Date
|
||||
}
|
||||
|
||||
export function calculateLiquidityScore(market: MarketData): number {
|
||||
// TODO: Implementation
|
||||
throw new Error('Not implemented')
|
||||
}
|
||||
```
|
||||
|
||||
## Step 2: Write Failing Test (RED)
|
||||
|
||||
```typescript
|
||||
// lib/liquidity.test.ts
|
||||
import { calculateLiquidityScore } from './liquidity'
|
||||
|
||||
describe('calculateLiquidityScore', () => {
|
||||
it('should return high score for liquid market', () => {
|
||||
const market = {
|
||||
totalVolume: 100000,
|
||||
bidAskSpread: 0.01,
|
||||
activeTraders: 500,
|
||||
lastTradeTime: new Date()
|
||||
}
|
||||
|
||||
const score = calculateLiquidityScore(market)
|
||||
|
||||
expect(score).toBeGreaterThan(80)
|
||||
expect(score).toBeLessThanOrEqual(100)
|
||||
})
|
||||
|
||||
it('should return low score for illiquid market', () => {
|
||||
const market = {
|
||||
totalVolume: 100,
|
||||
bidAskSpread: 0.5,
|
||||
activeTraders: 2,
|
||||
lastTradeTime: new Date(Date.now() - 86400000) // 1 day ago
|
||||
}
|
||||
|
||||
const score = calculateLiquidityScore(market)
|
||||
|
||||
expect(score).toBeLessThan(30)
|
||||
expect(score).toBeGreaterThanOrEqual(0)
|
||||
})
|
||||
|
||||
it('should handle edge case: zero volume', () => {
|
||||
const market = {
|
||||
totalVolume: 0,
|
||||
bidAskSpread: 0,
|
||||
activeTraders: 0,
|
||||
lastTradeTime: new Date()
|
||||
}
|
||||
|
||||
const score = calculateLiquidityScore(market)
|
||||
|
||||
expect(score).toBe(0)
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Step 3: Run Tests - Verify FAIL
|
||||
|
||||
```bash
|
||||
npm test lib/liquidity.test.ts
|
||||
|
||||
FAIL lib/liquidity.test.ts
|
||||
✕ should return high score for liquid market (2 ms)
|
||||
Error: Not implemented
|
||||
|
||||
1 test failed, 0 passed
|
||||
```
|
||||
|
||||
✅ Tests fail as expected. Ready to implement.
|
||||
|
||||
## Step 4: Implement Minimal Code (GREEN)
|
||||
|
||||
```typescript
|
||||
// lib/liquidity.ts
|
||||
export function calculateLiquidityScore(market: MarketData): number {
|
||||
// Handle zero volume edge case
|
||||
if (market.totalVolume === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate component scores (0-100 scale)
|
||||
const volumeScore = Math.min(market.totalVolume / 1000, 100)
|
||||
const spreadScore = Math.max(100 - (market.bidAskSpread * 1000), 0)
|
||||
const traderScore = Math.min(market.activeTraders / 10, 100)
|
||||
|
||||
// Recent activity bonus
|
||||
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
|
||||
const recencyScore = Math.max(100 - (hoursSinceLastTrade * 10), 0)
|
||||
|
||||
// Weighted average
|
||||
const score = (
|
||||
volumeScore * 0.4 +
|
||||
spreadScore * 0.3 +
|
||||
traderScore * 0.2 +
|
||||
recencyScore * 0.1
|
||||
)
|
||||
|
||||
return Math.min(Math.max(score, 0), 100) // Clamp to 0-100
|
||||
}
|
||||
```
|
||||
|
||||
## Step 5: Run Tests - Verify PASS
|
||||
|
||||
```bash
|
||||
npm test lib/liquidity.test.ts
|
||||
|
||||
PASS lib/liquidity.test.ts
|
||||
✓ should return high score for liquid market (3 ms)
|
||||
✓ should return low score for illiquid market (2 ms)
|
||||
✓ should handle edge case: zero volume (1 ms)
|
||||
|
||||
3 tests passed
|
||||
```
|
||||
|
||||
✅ All tests passing!
|
||||
|
||||
## Step 6: Refactor (IMPROVE)
|
||||
|
||||
```typescript
|
||||
// lib/liquidity.ts - Refactored with constants and better readability
|
||||
const WEIGHTS = {
|
||||
VOLUME: 0.4,
|
||||
SPREAD: 0.3,
|
||||
TRADERS: 0.2,
|
||||
RECENCY: 0.1,
|
||||
} as const
|
||||
|
||||
const SCALE_FACTORS = {
|
||||
VOLUME: 1000,
|
||||
SPREAD: 1000,
|
||||
TRADERS: 10,
|
||||
RECENCY_PENALTY: 10,
|
||||
} as const
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max)
|
||||
}
|
||||
|
||||
export function calculateLiquidityScore(market: MarketData): number {
|
||||
if (market.totalVolume === 0) return 0
|
||||
|
||||
const volumeScore = Math.min(market.totalVolume / SCALE_FACTORS.VOLUME, 100)
|
||||
const spreadScore = clamp(100 - (market.bidAskSpread * SCALE_FACTORS.SPREAD), 0, 100)
|
||||
const traderScore = Math.min(market.activeTraders / SCALE_FACTORS.TRADERS, 100)
|
||||
|
||||
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
|
||||
const recencyScore = clamp(100 - (hoursSinceLastTrade * SCALE_FACTORS.RECENCY_PENALTY), 0, 100)
|
||||
|
||||
const weightedScore =
|
||||
volumeScore * WEIGHTS.VOLUME +
|
||||
spreadScore * WEIGHTS.SPREAD +
|
||||
traderScore * WEIGHTS.TRADERS +
|
||||
recencyScore * WEIGHTS.RECENCY
|
||||
|
||||
return clamp(weightedScore, 0, 100)
|
||||
}
|
||||
```
|
||||
|
||||
## Step 7: Verify Tests Still Pass
|
||||
|
||||
```bash
|
||||
npm test lib/liquidity.test.ts
|
||||
|
||||
PASS lib/liquidity.test.ts
|
||||
✓ should return high score for liquid market (3 ms)
|
||||
✓ should return low score for illiquid market (2 ms)
|
||||
✓ should handle edge case: zero volume (1 ms)
|
||||
|
||||
3 tests passed
|
||||
```
|
||||
|
||||
✅ Refactoring complete, tests still passing!
|
||||
|
||||
## Step 8: Check Coverage
|
||||
|
||||
```bash
|
||||
npm test -- --coverage lib/liquidity.test.ts
|
||||
|
||||
File | % Stmts | % Branch | % Funcs | % Lines
|
||||
---------------|---------|----------|---------|--------
|
||||
liquidity.ts | 100 | 100 | 100 | 100
|
||||
|
||||
Coverage: 100% ✅ (Target: 80%)
|
||||
```
|
||||
|
||||
✅ TDD session complete!
|
||||
```
|
||||
|
||||
## TDD Best Practices
|
||||
|
||||
**DO:**
|
||||
- ✅ Write the test FIRST, before any implementation
|
||||
- ✅ Run tests and verify they FAIL before implementing
|
||||
- ✅ Write minimal code to make tests pass
|
||||
- ✅ Refactor only after tests are green
|
||||
- ✅ Add edge cases and error scenarios
|
||||
- ✅ Aim for 80%+ coverage (100% for critical code)
|
||||
|
||||
**DON'T:**
|
||||
- ❌ Write implementation before tests
|
||||
- ❌ Skip running tests after each change
|
||||
- ❌ Write too much code at once
|
||||
- ❌ Ignore failing tests
|
||||
- ❌ Test implementation details (test behavior)
|
||||
- ❌ Mock everything (prefer integration tests)
|
||||
|
||||
## Test Types to Include
|
||||
|
||||
**Unit Tests** (Function-level):
|
||||
- Happy path scenarios
|
||||
- Edge cases (empty, null, max values)
|
||||
- Error conditions
|
||||
- Boundary values
|
||||
|
||||
**Integration Tests** (Component-level):
|
||||
- API endpoints
|
||||
- Database operations
|
||||
- External service calls
|
||||
- React components with hooks
|
||||
|
||||
**E2E Tests** (use `/e2e` command):
|
||||
- Critical user flows
|
||||
- Multi-step processes
|
||||
- Full stack integration
|
||||
|
||||
## Coverage Requirements
|
||||
|
||||
- **80% minimum** for all code
|
||||
- **100% required** for:
|
||||
- Financial calculations
|
||||
- Authentication logic
|
||||
- Security-critical code
|
||||
- Core business logic
|
||||
|
||||
## Important Notes
|
||||
|
||||
**MANDATORY**: Tests must be written BEFORE implementation. The TDD cycle is:
|
||||
|
||||
1. **RED** - Write failing test
|
||||
2. **GREEN** - Implement to pass
|
||||
3. **REFACTOR** - Improve code
|
||||
|
||||
Never skip the RED phase. Never write code before tests.
|
||||
|
||||
## Integration with Other Commands
|
||||
|
||||
- Use `/plan` first to understand what to build
|
||||
- Use `/tdd` to implement with tests
|
||||
- Use `/build-and-fix` if build errors occur
|
||||
- Use `/code-review` to review implementation
|
||||
- Use `/test-coverage` to verify coverage
|
||||
|
||||
## Related Agents
|
||||
|
||||
This command invokes the `tdd-guide` agent located at:
|
||||
`~/.claude/agents/tdd-guide.md`
|
||||
|
||||
And can reference the `tdd-workflow` skill at:
|
||||
`~/.claude/skills/tdd-workflow/`
|
||||
27
.claude/commands/test-coverage.md
Normal file
27
.claude/commands/test-coverage.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Test Coverage
|
||||
|
||||
Analyze test coverage and generate missing tests:
|
||||
|
||||
1. Run tests with coverage: npm test --coverage or pnpm test --coverage
|
||||
|
||||
2. Analyze coverage report (coverage/coverage-summary.json)
|
||||
|
||||
3. Identify files below 80% coverage threshold
|
||||
|
||||
4. For each under-covered file:
|
||||
- Analyze untested code paths
|
||||
- Generate unit tests for functions
|
||||
- Generate integration tests for APIs
|
||||
- Generate E2E tests for critical flows
|
||||
|
||||
5. Verify new tests pass
|
||||
|
||||
6. Show before/after coverage metrics
|
||||
|
||||
7. Ensure project reaches 80%+ overall coverage
|
||||
|
||||
Focus on:
|
||||
- Happy path scenarios
|
||||
- Error handling
|
||||
- Edge cases (null, undefined, empty)
|
||||
- Boundary conditions
|
||||
17
.claude/commands/update-codemaps.md
Normal file
17
.claude/commands/update-codemaps.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Update Codemaps
|
||||
|
||||
Analyze the codebase structure and update architecture documentation:
|
||||
|
||||
1. Scan all source files for imports, exports, and dependencies
|
||||
2. Generate token-lean codemaps in the following format:
|
||||
- codemaps/architecture.md - Overall architecture
|
||||
- codemaps/backend.md - Backend structure
|
||||
- codemaps/frontend.md - Frontend structure
|
||||
- codemaps/data.md - Data models and schemas
|
||||
|
||||
3. Calculate diff percentage from previous version
|
||||
4. If changes > 30%, request user approval before updating
|
||||
5. Add freshness timestamp to each codemap
|
||||
6. Save reports to .reports/codemap-diff.txt
|
||||
|
||||
Use TypeScript/Node.js for analysis. Focus on high-level structure, not implementation details.
|
||||
31
.claude/commands/update-docs.md
Normal file
31
.claude/commands/update-docs.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# Update Documentation
|
||||
|
||||
Sync documentation from source-of-truth:
|
||||
|
||||
1. Read package.json scripts section
|
||||
- Generate scripts reference table
|
||||
- Include descriptions from comments
|
||||
|
||||
2. Read .env.example
|
||||
- Extract all environment variables
|
||||
- Document purpose and format
|
||||
|
||||
3. Generate docs/CONTRIB.md with:
|
||||
- Development workflow
|
||||
- Available scripts
|
||||
- Environment setup
|
||||
- Testing procedures
|
||||
|
||||
4. Generate docs/RUNBOOK.md with:
|
||||
- Deployment procedures
|
||||
- Monitoring and alerts
|
||||
- Common issues and fixes
|
||||
- Rollback procedures
|
||||
|
||||
5. Identify obsolete documentation:
|
||||
- Find docs not modified in 90+ days
|
||||
- List for manual review
|
||||
|
||||
6. Show diff summary
|
||||
|
||||
Single source of truth: package.json and .env.example
|
||||
59
.claude/commands/verify.md
Normal file
59
.claude/commands/verify.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Verification Command
|
||||
|
||||
Run comprehensive verification on current codebase state.
|
||||
|
||||
## Instructions
|
||||
|
||||
Execute verification in this exact order:
|
||||
|
||||
1. **Build Check**
|
||||
- Run the build command for this project
|
||||
- If it fails, report errors and STOP
|
||||
|
||||
2. **Type Check**
|
||||
- Run TypeScript/type checker
|
||||
- Report all errors with file:line
|
||||
|
||||
3. **Lint Check**
|
||||
- Run linter
|
||||
- Report warnings and errors
|
||||
|
||||
4. **Test Suite**
|
||||
- Run all tests
|
||||
- Report pass/fail count
|
||||
- Report coverage percentage
|
||||
|
||||
5. **Console.log Audit**
|
||||
- Search for console.log in source files
|
||||
- Report locations
|
||||
|
||||
6. **Git Status**
|
||||
- Show uncommitted changes
|
||||
- Show files modified since last commit
|
||||
|
||||
## Output
|
||||
|
||||
Produce a concise verification report:
|
||||
|
||||
```
|
||||
VERIFICATION: [PASS/FAIL]
|
||||
|
||||
Build: [OK/FAIL]
|
||||
Types: [OK/X errors]
|
||||
Lint: [OK/X issues]
|
||||
Tests: [X/Y passed, Z% coverage]
|
||||
Secrets: [OK/X found]
|
||||
Logs: [OK/X console.logs]
|
||||
|
||||
Ready for PR: [YES/NO]
|
||||
```
|
||||
|
||||
If any critical issues, list them with fix suggestions.
|
||||
|
||||
## Arguments
|
||||
|
||||
$ARGUMENTS can be:
|
||||
- `quick` - Only build + types
|
||||
- `full` - All checks (default)
|
||||
- `pre-commit` - Checks relevant for commits
|
||||
- `pre-pr` - Full checks plus security scan
|
||||
157
.claude/hooks/hooks.json
Normal file
157
.claude/hooks/hooks.json
Normal file
@@ -0,0 +1,157 @@
|
||||
{
|
||||
"$schema": "https://json.schemastore.org/claude-code-settings.json",
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm run dev|pnpm( run)? dev|yarn dev|bun run dev)\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"console.error('[Hook] BLOCKED: Dev server must run in tmux for log access');console.error('[Hook] Use: tmux new-session -d -s dev \\\"npm run dev\\\"');console.error('[Hook] Then: tmux attach -t dev');process.exit(1)\""
|
||||
}
|
||||
],
|
||||
"description": "Block dev servers outside tmux - ensures you can access logs"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm (install|test)|pnpm (install|test)|yarn (install|test)?|bun (install|test)|cargo build|make|docker|pytest|vitest|playwright)\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"if(!process.env.TMUX){console.error('[Hook] Consider running in tmux for session persistence');console.error('[Hook] tmux new -s dev | tmux attach -t dev')}\""
|
||||
}
|
||||
],
|
||||
"description": "Reminder to use tmux for long-running commands"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Bash\" && tool_input.command matches \"git push\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"console.error('[Hook] Review changes before push...');console.error('[Hook] Continuing with push (remove this hook to add interactive review)')\""
|
||||
}
|
||||
],
|
||||
"description": "Reminder before git push to review changes"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Write\" && tool_input.file_path matches \"\\\\.(md|txt)$\" && !(tool_input.file_path matches \"README\\\\.md|CLAUDE\\\\.md|AGENTS\\\\.md|CONTRIBUTING\\\\.md\")",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path||'';if(/\\.(md|txt)$/.test(p)&&!/(README|CLAUDE|AGENTS|CONTRIBUTING)\\.md$/.test(p)){console.error('[Hook] BLOCKED: Unnecessary documentation file creation');console.error('[Hook] File: '+p);console.error('[Hook] Use README.md for documentation instead');process.exit(1)}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "Block creation of random .md files - keeps docs consolidated"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Edit\" || tool == \"Write\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/suggest-compact.js\""
|
||||
}
|
||||
],
|
||||
"description": "Suggest manual compaction at logical intervals"
|
||||
}
|
||||
],
|
||||
"PreCompact": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/pre-compact.js\""
|
||||
}
|
||||
],
|
||||
"description": "Save state before context compaction"
|
||||
}
|
||||
],
|
||||
"SessionStart": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-start.js\""
|
||||
}
|
||||
],
|
||||
"description": "Load previous context and detect package manager on new session"
|
||||
}
|
||||
],
|
||||
"PostToolUse": [
|
||||
{
|
||||
"matcher": "tool == \"Bash\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const cmd=i.tool_input?.command||'';if(/gh pr create/.test(cmd)){const out=i.tool_output?.output||'';const m=out.match(/https:\\/\\/github.com\\/[^/]+\\/[^/]+\\/pull\\/\\d+/);if(m){console.error('[Hook] PR created: '+m[0]);const repo=m[0].replace(/https:\\/\\/github.com\\/([^/]+\\/[^/]+)\\/pull\\/\\d+/,'$1');const pr=m[0].replace(/.*\\/pull\\/(\\d+)/,'$1');console.error('[Hook] To review: gh pr review '+pr+' --repo '+repo)}}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "Log PR URL and provide review command after PR creation"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){try{execSync('npx prettier --write \"'+p+'\"',{stdio:['pipe','pipe','pipe']})}catch(e){}}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "Auto-format JS/TS files with Prettier after edits"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx)$\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');const path=require('path');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){let dir=path.dirname(p);while(dir!==path.dirname(dir)&&!fs.existsSync(path.join(dir,'tsconfig.json'))){dir=path.dirname(dir)}if(fs.existsSync(path.join(dir,'tsconfig.json'))){try{const r=execSync('npx tsc --noEmit --pretty false 2>&1',{cwd:dir,encoding:'utf8',stdio:['pipe','pipe','pipe']});const lines=r.split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}catch(e){const lines=(e.stdout||'').split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}}}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "TypeScript check after editing .ts/.tsx files"
|
||||
},
|
||||
{
|
||||
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){const c=fs.readFileSync(p,'utf8');const lines=c.split('\\n');const matches=[];lines.forEach((l,idx)=>{if(/console\\.log/.test(l))matches.push((idx+1)+': '+l.trim())});if(matches.length){console.error('[Hook] WARNING: console.log found in '+p);matches.slice(0,5).forEach(m=>console.error(m));console.error('[Hook] Remove console.log before committing')}}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "Warn about console.log statements after edits"
|
||||
}
|
||||
],
|
||||
"Stop": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{try{execSync('git rev-parse --git-dir',{stdio:'pipe'})}catch{console.log(d);process.exit(0)}try{const files=execSync('git diff --name-only HEAD',{encoding:'utf8',stdio:['pipe','pipe','pipe']}).split('\\n').filter(f=>/\\.(ts|tsx|js|jsx)$/.test(f)&&fs.existsSync(f));let hasConsole=false;for(const f of files){if(fs.readFileSync(f,'utf8').includes('console.log')){console.error('[Hook] WARNING: console.log found in '+f);hasConsole=true}}if(hasConsole)console.error('[Hook] Remove console.log statements before committing')}catch(e){}console.log(d)})\""
|
||||
}
|
||||
],
|
||||
"description": "Check for console.log in modified files after each response"
|
||||
}
|
||||
],
|
||||
"SessionEnd": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-end.js\""
|
||||
}
|
||||
],
|
||||
"description": "Persist session state on end"
|
||||
},
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/evaluate-session.js\""
|
||||
}
|
||||
],
|
||||
"description": "Evaluate session for extractable patterns"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
36
.claude/hooks/memory-persistence/pre-compact.sh
Normal file
36
.claude/hooks/memory-persistence/pre-compact.sh
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
# PreCompact Hook - Save state before context compaction
|
||||
#
|
||||
# Runs before Claude compacts context, giving you a chance to
|
||||
# preserve important state that might get lost in summarization.
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "PreCompact": [{
|
||||
# "matcher": "*",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/hooks/memory-persistence/pre-compact.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
|
||||
SESSIONS_DIR="${HOME}/.claude/sessions"
|
||||
COMPACTION_LOG="${SESSIONS_DIR}/compaction-log.txt"
|
||||
|
||||
mkdir -p "$SESSIONS_DIR"
|
||||
|
||||
# Log compaction event with timestamp
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] Context compaction triggered" >> "$COMPACTION_LOG"
|
||||
|
||||
# If there's an active session file, note the compaction
|
||||
ACTIVE_SESSION=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
|
||||
if [ -n "$ACTIVE_SESSION" ]; then
|
||||
echo "" >> "$ACTIVE_SESSION"
|
||||
echo "---" >> "$ACTIVE_SESSION"
|
||||
echo "**[Compaction occurred at $(date '+%H:%M')]** - Context was summarized" >> "$ACTIVE_SESSION"
|
||||
fi
|
||||
|
||||
echo "[PreCompact] State saved before compaction" >&2
|
||||
61
.claude/hooks/memory-persistence/session-end.sh
Normal file
61
.claude/hooks/memory-persistence/session-end.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
# Stop Hook (Session End) - Persist learnings when session ends
|
||||
#
|
||||
# Runs when Claude session ends. Creates/updates session log file
|
||||
# with timestamp for continuity tracking.
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "Stop": [{
|
||||
# "matcher": "*",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/hooks/memory-persistence/session-end.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
|
||||
SESSIONS_DIR="${HOME}/.claude/sessions"
|
||||
TODAY=$(date '+%Y-%m-%d')
|
||||
SESSION_FILE="${SESSIONS_DIR}/${TODAY}-session.tmp"
|
||||
|
||||
mkdir -p "$SESSIONS_DIR"
|
||||
|
||||
# If session file exists for today, update the end time
|
||||
if [ -f "$SESSION_FILE" ]; then
|
||||
# Update Last Updated timestamp
|
||||
sed -i '' "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null || \
|
||||
sed -i "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null
|
||||
echo "[SessionEnd] Updated session file: $SESSION_FILE" >&2
|
||||
else
|
||||
# Create new session file with template
|
||||
cat > "$SESSION_FILE" << EOF
|
||||
# Session: $(date '+%Y-%m-%d')
|
||||
**Date:** $TODAY
|
||||
**Started:** $(date '+%H:%M')
|
||||
**Last Updated:** $(date '+%H:%M')
|
||||
|
||||
---
|
||||
|
||||
## Current State
|
||||
|
||||
[Session context goes here]
|
||||
|
||||
### Completed
|
||||
- [ ]
|
||||
|
||||
### In Progress
|
||||
- [ ]
|
||||
|
||||
### Notes for Next Session
|
||||
-
|
||||
|
||||
### Context to Load
|
||||
\`\`\`
|
||||
[relevant files]
|
||||
\`\`\`
|
||||
EOF
|
||||
echo "[SessionEnd] Created session file: $SESSION_FILE" >&2
|
||||
fi
|
||||
37
.claude/hooks/memory-persistence/session-start.sh
Normal file
37
.claude/hooks/memory-persistence/session-start.sh
Normal file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
# SessionStart Hook - Load previous context on new session
|
||||
#
|
||||
# Runs when a new Claude session starts. Checks for recent session
|
||||
# files and notifies Claude of available context to load.
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "SessionStart": [{
|
||||
# "matcher": "*",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/hooks/memory-persistence/session-start.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
|
||||
SESSIONS_DIR="${HOME}/.claude/sessions"
|
||||
LEARNED_DIR="${HOME}/.claude/skills/learned"
|
||||
|
||||
# Check for recent session files (last 7 days)
|
||||
recent_sessions=$(find "$SESSIONS_DIR" -name "*.tmp" -mtime -7 2>/dev/null | wc -l | tr -d ' ')
|
||||
|
||||
if [ "$recent_sessions" -gt 0 ]; then
|
||||
latest=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
|
||||
echo "[SessionStart] Found $recent_sessions recent session(s)" >&2
|
||||
echo "[SessionStart] Latest: $latest" >&2
|
||||
fi
|
||||
|
||||
# Check for learned skills
|
||||
learned_count=$(find "$LEARNED_DIR" -name "*.md" 2>/dev/null | wc -l | tr -d ' ')
|
||||
|
||||
if [ "$learned_count" -gt 0 ]; then
|
||||
echo "[SessionStart] $learned_count learned skill(s) available in $LEARNED_DIR" >&2
|
||||
fi
|
||||
52
.claude/hooks/strategic-compact/suggest-compact.sh
Normal file
52
.claude/hooks/strategic-compact/suggest-compact.sh
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Strategic Compact Suggester
|
||||
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
|
||||
#
|
||||
# Why manual over auto-compact:
|
||||
# - Auto-compact happens at arbitrary points, often mid-task
|
||||
# - Strategic compacting preserves context through logical phases
|
||||
# - Compact after exploration, before execution
|
||||
# - Compact after completing a milestone, before starting next
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "PreToolUse": [{
|
||||
# "matcher": "Edit|Write",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# Criteria for suggesting compact:
|
||||
# - Session has been running for extended period
|
||||
# - Large number of tool calls made
|
||||
# - Transitioning from research/exploration to implementation
|
||||
# - Plan has been finalized
|
||||
|
||||
# Track tool call count (increment in a temp file)
|
||||
COUNTER_FILE="/tmp/claude-tool-count-$$"
|
||||
THRESHOLD=${COMPACT_THRESHOLD:-50}
|
||||
|
||||
# Initialize or increment counter
|
||||
if [ -f "$COUNTER_FILE" ]; then
|
||||
count=$(cat "$COUNTER_FILE")
|
||||
count=$((count + 1))
|
||||
echo "$count" > "$COUNTER_FILE"
|
||||
else
|
||||
echo "1" > "$COUNTER_FILE"
|
||||
count=1
|
||||
fi
|
||||
|
||||
# Suggest compact after threshold tool calls
|
||||
if [ "$count" -eq "$THRESHOLD" ]; then
|
||||
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
|
||||
fi
|
||||
|
||||
# Suggest at regular intervals after threshold
|
||||
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
|
||||
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
|
||||
fi
|
||||
@@ -7,7 +7,8 @@
|
||||
"Edit(*)",
|
||||
"Glob(*)",
|
||||
"Grep(*)",
|
||||
"Task(*)"
|
||||
"Task(*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -75,7 +75,39 @@
|
||||
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")",
|
||||
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")",
|
||||
"Bash(tasklist:*)",
|
||||
"Bash(findstr:*)"
|
||||
"Bash(findstr:*)",
|
||||
"Bash(wsl bash -c \"ps aux | grep -E ''python.*train'' | grep -v grep\")",
|
||||
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/\")",
|
||||
"Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
|
||||
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
|
||||
"Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
|
||||
"Bash(wsl bash -c:*)",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
|
||||
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")",
|
||||
"Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")",
|
||||
"Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")",
|
||||
"Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")",
|
||||
"Bash(npm run dev:*)",
|
||||
"Bash(ping:*)",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")",
|
||||
"Bash(git checkout:*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
|
||||
314
.claude/skills/backend-patterns/SKILL.md
Normal file
314
.claude/skills/backend-patterns/SKILL.md
Normal file
@@ -0,0 +1,314 @@
|
||||
# Backend Development Patterns
|
||||
|
||||
Backend architecture patterns for Python/FastAPI/PostgreSQL applications.
|
||||
|
||||
## API Design
|
||||
|
||||
### RESTful Structure
|
||||
|
||||
```
|
||||
GET /api/v1/documents # List
|
||||
GET /api/v1/documents/{id} # Get
|
||||
POST /api/v1/documents # Create
|
||||
PUT /api/v1/documents/{id} # Replace
|
||||
PATCH /api/v1/documents/{id} # Update
|
||||
DELETE /api/v1/documents/{id} # Delete
|
||||
|
||||
GET /api/v1/documents?status=processed&sort=created_at&limit=20&offset=0
|
||||
```
|
||||
|
||||
### FastAPI Route Pattern
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
@router.post("/infer", response_model=ApiResponse[InferenceResult])
|
||||
async def infer_document(
|
||||
file: UploadFile = File(...),
|
||||
confidence_threshold: float = Query(0.5, ge=0, le=1),
|
||||
service: InferenceService = Depends(get_inference_service)
|
||||
) -> ApiResponse[InferenceResult]:
|
||||
result = await service.process(file, confidence_threshold)
|
||||
return ApiResponse(success=True, data=result)
|
||||
```
|
||||
|
||||
### Consistent Response Schema
|
||||
|
||||
```python
|
||||
from typing import Generic, TypeVar
|
||||
T = TypeVar('T')
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: T | None = None
|
||||
error: str | None = None
|
||||
meta: dict | None = None
|
||||
```
|
||||
|
||||
## Core Patterns
|
||||
|
||||
### Repository Pattern
|
||||
|
||||
```python
|
||||
from typing import Protocol
|
||||
|
||||
class DocumentRepository(Protocol):
|
||||
def find_all(self, filters: dict | None = None) -> list[Document]: ...
|
||||
def find_by_id(self, id: str) -> Document | None: ...
|
||||
def create(self, data: dict) -> Document: ...
|
||||
def update(self, id: str, data: dict) -> Document: ...
|
||||
def delete(self, id: str) -> None: ...
|
||||
```
|
||||
|
||||
### Service Layer
|
||||
|
||||
```python
|
||||
class InferenceService:
|
||||
def __init__(self, model_path: str, use_gpu: bool = True):
|
||||
self.pipeline = InferencePipeline(model_path=model_path, use_gpu=use_gpu)
|
||||
|
||||
async def process(self, file: UploadFile, confidence_threshold: float) -> InferenceResult:
|
||||
temp_path = self._save_temp_file(file)
|
||||
try:
|
||||
return self.pipeline.process_pdf(temp_path)
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
```
|
||||
|
||||
### Dependency Injection
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
db_host: str = "localhost"
|
||||
db_password: str
|
||||
model_path: str = "runs/train/invoice_fields/weights/best.pt"
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
def get_inference_service(settings: Settings = Depends(get_settings)) -> InferenceService:
|
||||
return InferenceService(model_path=settings.model_path)
|
||||
```
|
||||
|
||||
## Database Patterns
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
```python
|
||||
from psycopg2 import pool
|
||||
from contextlib import contextmanager
|
||||
|
||||
db_pool = pool.ThreadedConnectionPool(minconn=2, maxconn=10, **db_config)
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection():
|
||||
conn = db_pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
db_pool.putconn(conn)
|
||||
```
|
||||
|
||||
### Query Optimization
|
||||
|
||||
```python
|
||||
# GOOD: Select only needed columns
|
||||
cur.execute("""
|
||||
SELECT id, status, fields->>'InvoiceNumber' as invoice_number
|
||||
FROM documents WHERE status = %s
|
||||
ORDER BY created_at DESC LIMIT %s
|
||||
""", ('processed', 10))
|
||||
|
||||
# BAD: SELECT * FROM documents
|
||||
```
|
||||
|
||||
### N+1 Prevention
|
||||
|
||||
```python
|
||||
# BAD: N+1 queries
|
||||
for doc in documents:
|
||||
doc.labels = get_labels(doc.id) # N queries
|
||||
|
||||
# GOOD: Batch fetch with JOIN
|
||||
cur.execute("""
|
||||
SELECT d.id, d.status, array_agg(l.label) as labels
|
||||
FROM documents d
|
||||
LEFT JOIN document_labels l ON d.id = l.document_id
|
||||
GROUP BY d.id, d.status
|
||||
""")
|
||||
```
|
||||
|
||||
### Transaction Pattern
|
||||
|
||||
```python
|
||||
def create_document_with_labels(doc_data: dict, labels: list[dict]) -> str:
|
||||
with get_db_connection() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("INSERT INTO documents ... RETURNING id", ...)
|
||||
doc_id = cur.fetchone()[0]
|
||||
for label in labels:
|
||||
cur.execute("INSERT INTO document_labels ...", ...)
|
||||
conn.commit()
|
||||
return doc_id
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
```
|
||||
|
||||
## Caching
|
||||
|
||||
```python
|
||||
from cachetools import TTLCache
|
||||
|
||||
_cache = TTLCache(maxsize=1000, ttl=300)
|
||||
|
||||
def get_document_cached(doc_id: str) -> Document | None:
|
||||
if doc_id in _cache:
|
||||
return _cache[doc_id]
|
||||
doc = repo.find_by_id(doc_id)
|
||||
if doc:
|
||||
_cache[doc_id] = doc
|
||||
return doc
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Exception Hierarchy
|
||||
|
||||
```python
|
||||
class AppError(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
class NotFoundError(AppError):
|
||||
def __init__(self, resource: str, id: str):
|
||||
super().__init__(f"{resource} not found: {id}", 404)
|
||||
|
||||
class ValidationError(AppError):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message, 400)
|
||||
```
|
||||
|
||||
### FastAPI Exception Handler
|
||||
|
||||
```python
|
||||
@app.exception_handler(AppError)
|
||||
async def app_error_handler(request: Request, exc: AppError):
|
||||
return JSONResponse(status_code=exc.status_code, content={"success": False, "error": exc.message})
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unexpected error: {exc}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "Internal server error"})
|
||||
```
|
||||
|
||||
### Retry with Backoff
|
||||
|
||||
```python
|
||||
async def retry_with_backoff(fn, max_retries: int = 3, base_delay: float = 1.0):
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await fn() if asyncio.iscoroutinefunction(fn) else fn()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(base_delay * (2 ** attempt))
|
||||
raise last_error
|
||||
```
|
||||
|
||||
## Rate Limiting
|
||||
|
||||
```python
|
||||
from time import time
|
||||
from collections import defaultdict
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self):
|
||||
self.requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def check_limit(self, identifier: str, max_requests: int, window_sec: int) -> bool:
|
||||
now = time()
|
||||
self.requests[identifier] = [t for t in self.requests[identifier] if now - t < window_sec]
|
||||
if len(self.requests[identifier]) >= max_requests:
|
||||
return False
|
||||
self.requests[identifier].append(now)
|
||||
return True
|
||||
|
||||
limiter = RateLimiter()
|
||||
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
ip = request.client.host
|
||||
if not limiter.check_limit(ip, max_requests=100, window_sec=60):
|
||||
return JSONResponse(status_code=429, content={"error": "Rate limit exceeded"})
|
||||
return await call_next(request)
|
||||
```
|
||||
|
||||
## Logging & Middleware
|
||||
|
||||
### Request Logging
|
||||
|
||||
```python
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
logger.info(f"[{request_id}] {request.method} {request.url.path}")
|
||||
response = await call_next(request)
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
logger.info(f"[{request_id}] Completed {response.status_code} in {duration_ms:.2f}ms")
|
||||
return response
|
||||
```
|
||||
|
||||
### Structured Logging
|
||||
|
||||
```python
|
||||
class JSONFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return json.dumps({
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"level": record.levelname,
|
||||
"message": record.getMessage(),
|
||||
"module": record.module,
|
||||
})
|
||||
```
|
||||
|
||||
## Background Tasks
|
||||
|
||||
```python
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
def send_notification(document_id: str, status: str):
|
||||
logger.info(f"Notification: {document_id} -> {status}")
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(file: UploadFile, background_tasks: BackgroundTasks):
|
||||
result = await process_document(file)
|
||||
background_tasks.add_task(send_notification, result.document_id, "completed")
|
||||
return result
|
||||
```
|
||||
|
||||
## Key Principles
|
||||
|
||||
- Repository pattern: Abstract data access
|
||||
- Service layer: Business logic separated from routes
|
||||
- Dependency injection via `Depends()`
|
||||
- Connection pooling for database
|
||||
- Parameterized queries only (no f-strings in SQL)
|
||||
- Batch fetch to prevent N+1
|
||||
- Consistent `ApiResponse[T]` format
|
||||
- Exception hierarchy with proper status codes
|
||||
- Rate limit by IP
|
||||
- Structured logging with request ID
|
||||
665
.claude/skills/coding-standards/SKILL.md
Normal file
665
.claude/skills/coding-standards/SKILL.md
Normal file
@@ -0,0 +1,665 @@
|
||||
---
|
||||
name: coding-standards
|
||||
description: Universal coding standards, best practices, and patterns for Python, FastAPI, and data processing development.
|
||||
---
|
||||
|
||||
# Coding Standards & Best Practices
|
||||
|
||||
Python coding standards for the Invoice Master project.
|
||||
|
||||
## Code Quality Principles
|
||||
|
||||
### 1. Readability First
|
||||
- Code is read more than written
|
||||
- Clear variable and function names
|
||||
- Self-documenting code preferred over comments
|
||||
- Consistent formatting (follow PEP 8)
|
||||
|
||||
### 2. KISS (Keep It Simple, Stupid)
|
||||
- Simplest solution that works
|
||||
- Avoid over-engineering
|
||||
- No premature optimization
|
||||
- Easy to understand > clever code
|
||||
|
||||
### 3. DRY (Don't Repeat Yourself)
|
||||
- Extract common logic into functions
|
||||
- Create reusable utilities
|
||||
- Share modules across the codebase
|
||||
- Avoid copy-paste programming
|
||||
|
||||
### 4. YAGNI (You Aren't Gonna Need It)
|
||||
- Don't build features before they're needed
|
||||
- Avoid speculative generality
|
||||
- Add complexity only when required
|
||||
- Start simple, refactor when needed
|
||||
|
||||
## Python Standards
|
||||
|
||||
### Variable Naming
|
||||
|
||||
```python
|
||||
# GOOD: Descriptive names
|
||||
invoice_number = "INV-2024-001"
|
||||
is_valid_document = True
|
||||
total_confidence_score = 0.95
|
||||
|
||||
# BAD: Unclear names
|
||||
inv = "INV-2024-001"
|
||||
flag = True
|
||||
x = 0.95
|
||||
```
|
||||
|
||||
### Function Naming
|
||||
|
||||
```python
|
||||
# GOOD: Verb-noun pattern with type hints
|
||||
def extract_invoice_fields(pdf_path: Path) -> dict[str, str]:
|
||||
"""Extract fields from invoice PDF."""
|
||||
...
|
||||
|
||||
def calculate_confidence(predictions: list[float]) -> float:
|
||||
"""Calculate average confidence score."""
|
||||
...
|
||||
|
||||
def is_valid_bankgiro(value: str) -> bool:
|
||||
"""Check if value is valid Bankgiro number."""
|
||||
...
|
||||
|
||||
# BAD: Unclear or noun-only
|
||||
def invoice(path):
|
||||
...
|
||||
|
||||
def confidence(p):
|
||||
...
|
||||
|
||||
def bankgiro(v):
|
||||
...
|
||||
```
|
||||
|
||||
### Type Hints (REQUIRED)
|
||||
|
||||
```python
|
||||
# GOOD: Full type annotations
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
document_id: str
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
processing_time_ms: float
|
||||
|
||||
def process_document(
|
||||
pdf_path: Path,
|
||||
confidence_threshold: float = 0.5
|
||||
) -> InferenceResult:
|
||||
"""Process PDF and return extracted fields."""
|
||||
...
|
||||
|
||||
# BAD: No type hints
|
||||
def process_document(pdf_path, confidence_threshold=0.5):
|
||||
...
|
||||
```
|
||||
|
||||
### Immutability Pattern (CRITICAL)
|
||||
|
||||
```python
|
||||
# GOOD: Create new objects, don't mutate
|
||||
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
|
||||
return {**fields, **updates}
|
||||
|
||||
def add_item(items: list[str], new_item: str) -> list[str]:
|
||||
return [*items, new_item]
|
||||
|
||||
# BAD: Direct mutation
|
||||
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
|
||||
fields.update(updates) # MUTATION!
|
||||
return fields
|
||||
|
||||
def add_item(items: list[str], new_item: str) -> list[str]:
|
||||
items.append(new_item) # MUTATION!
|
||||
return items
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# GOOD: Comprehensive error handling with logging
|
||||
def load_model(model_path: Path) -> Model:
|
||||
"""Load YOLO model from path."""
|
||||
try:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model not found: {model_path}")
|
||||
|
||||
model = YOLO(str(model_path))
|
||||
logger.info(f"Model loaded: {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise RuntimeError(f"Model loading failed: {model_path}") from e
|
||||
|
||||
# BAD: No error handling
|
||||
def load_model(model_path):
|
||||
return YOLO(str(model_path))
|
||||
|
||||
# BAD: Bare except
|
||||
def load_model(model_path):
|
||||
try:
|
||||
return YOLO(str(model_path))
|
||||
except: # Never use bare except!
|
||||
return None
|
||||
```
|
||||
|
||||
### Async Best Practices
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
# GOOD: Parallel execution when possible
|
||||
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
|
||||
tasks = [process_document(path) for path in pdf_paths]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle exceptions
|
||||
valid_results = []
|
||||
for path, result in zip(pdf_paths, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Failed to process {path}: {result}")
|
||||
else:
|
||||
valid_results.append(result)
|
||||
return valid_results
|
||||
|
||||
# BAD: Sequential when unnecessary
|
||||
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
|
||||
results = []
|
||||
for path in pdf_paths:
|
||||
result = await process_document(path)
|
||||
results.append(result)
|
||||
return results
|
||||
```
|
||||
|
||||
### Context Managers
|
||||
|
||||
```python
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
# GOOD: Proper resource management
|
||||
@contextmanager
|
||||
def temp_pdf_copy(pdf_path: Path):
|
||||
"""Create temporary copy of PDF for processing."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||
tmp.write(pdf_path.read_bytes())
|
||||
tmp_path = Path(tmp.name)
|
||||
try:
|
||||
yield tmp_path
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
# Usage
|
||||
with temp_pdf_copy(original_pdf) as tmp_pdf:
|
||||
result = process_pdf(tmp_pdf)
|
||||
```
|
||||
|
||||
## FastAPI Best Practices
|
||||
|
||||
### Route Structure
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
class InferenceResponse(BaseModel):
|
||||
success: bool
|
||||
document_id: str
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
processing_time_ms: float
|
||||
|
||||
@router.post("/infer", response_model=InferenceResponse)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(...),
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0)
|
||||
) -> InferenceResponse:
|
||||
"""Process invoice PDF and extract fields."""
|
||||
if not file.filename.endswith(".pdf"):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files accepted")
|
||||
|
||||
result = await inference_service.process(file, confidence_threshold)
|
||||
return InferenceResponse(
|
||||
success=True,
|
||||
document_id=result.document_id,
|
||||
fields=result.fields,
|
||||
confidence=result.confidence,
|
||||
processing_time_ms=result.processing_time_ms
|
||||
)
|
||||
```
|
||||
|
||||
### Input Validation with Pydantic
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import date
|
||||
import re
|
||||
|
||||
class InvoiceData(BaseModel):
|
||||
invoice_number: str = Field(..., min_length=1, max_length=50)
|
||||
invoice_date: date
|
||||
amount: float = Field(..., gt=0)
|
||||
bankgiro: str | None = None
|
||||
ocr_number: str | None = None
|
||||
|
||||
@field_validator("bankgiro")
|
||||
@classmethod
|
||||
def validate_bankgiro(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
# Bankgiro: 7-8 digits
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (7 <= len(cleaned) <= 8):
|
||||
raise ValueError("Bankgiro must be 7-8 digits")
|
||||
return cleaned
|
||||
|
||||
@field_validator("ocr_number")
|
||||
@classmethod
|
||||
def validate_ocr(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
# OCR: 2-25 digits
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (2 <= len(cleaned) <= 25):
|
||||
raise ValueError("OCR must be 2-25 digits")
|
||||
return cleaned
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: T | None = None
|
||||
error: str | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
# Success response
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
meta={"processing_time_ms": elapsed_ms}
|
||||
)
|
||||
|
||||
# Error response
|
||||
return ApiResponse(
|
||||
success=False,
|
||||
error="Invalid PDF format"
|
||||
)
|
||||
```
|
||||
|
||||
## File Organization
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── cli/ # Command-line interfaces
|
||||
│ ├── autolabel.py
|
||||
│ ├── train.py
|
||||
│ └── infer.py
|
||||
├── pdf/ # PDF processing
|
||||
│ ├── extractor.py
|
||||
│ └── renderer.py
|
||||
├── ocr/ # OCR processing
|
||||
│ ├── paddle_ocr.py
|
||||
│ └── machine_code_parser.py
|
||||
├── inference/ # Inference pipeline
|
||||
│ ├── pipeline.py
|
||||
│ ├── yolo_detector.py
|
||||
│ └── field_extractor.py
|
||||
├── normalize/ # Field normalization
|
||||
│ ├── base.py
|
||||
│ ├── date_normalizer.py
|
||||
│ └── amount_normalizer.py
|
||||
├── web/ # FastAPI application
|
||||
│ ├── app.py
|
||||
│ ├── routes.py
|
||||
│ ├── services.py
|
||||
│ └── schemas.py
|
||||
└── utils/ # Shared utilities
|
||||
├── validators.py
|
||||
├── text_cleaner.py
|
||||
└── logging.py
|
||||
tests/ # Mirror of src structure
|
||||
├── test_pdf/
|
||||
├── test_ocr/
|
||||
└── test_inference/
|
||||
```
|
||||
|
||||
### File Naming
|
||||
|
||||
```
|
||||
src/ocr/paddle_ocr.py # snake_case for modules
|
||||
src/inference/yolo_detector.py # snake_case for modules
|
||||
tests/test_paddle_ocr.py # test_ prefix for tests
|
||||
config.py # snake_case for config
|
||||
```
|
||||
|
||||
### Module Size Guidelines
|
||||
|
||||
- **Maximum**: 800 lines per file
|
||||
- **Typical**: 200-400 lines per file
|
||||
- **Functions**: Max 50 lines each
|
||||
- Extract utilities when modules grow too large
|
||||
|
||||
## Comments & Documentation
|
||||
|
||||
### When to Comment
|
||||
|
||||
```python
|
||||
# GOOD: Explain WHY, not WHAT
|
||||
# Swedish Bankgiro uses Luhn algorithm with weight [1,2,1,2...]
|
||||
def validate_bankgiro_checksum(bankgiro: str) -> bool:
|
||||
...
|
||||
|
||||
# Payment line format: 7 groups separated by #, checksum at end
|
||||
def parse_payment_line(line: str) -> PaymentLineData:
|
||||
...
|
||||
|
||||
# BAD: Stating the obvious
|
||||
# Increment counter by 1
|
||||
count += 1
|
||||
|
||||
# Set name to user's name
|
||||
name = user.name
|
||||
```
|
||||
|
||||
### Docstrings for Public APIs
|
||||
|
||||
```python
|
||||
def extract_invoice_fields(
|
||||
pdf_path: Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
use_gpu: bool = True
|
||||
) -> InferenceResult:
|
||||
"""Extract structured fields from Swedish invoice PDF.
|
||||
|
||||
Uses YOLOv11 for field detection and PaddleOCR for text extraction.
|
||||
Applies field-specific normalization and validation.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to the invoice PDF file.
|
||||
confidence_threshold: Minimum confidence for field detection (0.0-1.0).
|
||||
use_gpu: Whether to use GPU acceleration.
|
||||
|
||||
Returns:
|
||||
InferenceResult containing extracted fields and confidence scores.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If PDF file doesn't exist.
|
||||
ProcessingError: If OCR or detection fails.
|
||||
|
||||
Example:
|
||||
>>> result = extract_invoice_fields(Path("invoice.pdf"))
|
||||
>>> print(result.fields["invoice_number"])
|
||||
"INV-2024-001"
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
## Performance Best Practices
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
from cachetools import TTLCache
|
||||
|
||||
# Static data: LRU cache
|
||||
@lru_cache(maxsize=100)
|
||||
def get_field_config(field_name: str) -> FieldConfig:
|
||||
"""Load field configuration (cached)."""
|
||||
return load_config(field_name)
|
||||
|
||||
# Dynamic data: TTL cache
|
||||
_document_cache = TTLCache(maxsize=1000, ttl=300) # 5 minutes
|
||||
|
||||
def get_document_cached(doc_id: str) -> Document | None:
|
||||
if doc_id in _document_cache:
|
||||
return _document_cache[doc_id]
|
||||
|
||||
doc = repo.find_by_id(doc_id)
|
||||
if doc:
|
||||
_document_cache[doc_id] = doc
|
||||
return doc
|
||||
```
|
||||
|
||||
### Database Queries
|
||||
|
||||
```python
|
||||
# GOOD: Select only needed columns
|
||||
cur.execute("""
|
||||
SELECT id, status, fields->>'invoice_number'
|
||||
FROM documents
|
||||
WHERE status = %s
|
||||
LIMIT %s
|
||||
""", ('processed', 10))
|
||||
|
||||
# BAD: Select everything
|
||||
cur.execute("SELECT * FROM documents")
|
||||
|
||||
# GOOD: Batch operations
|
||||
cur.executemany(
|
||||
"INSERT INTO labels (doc_id, field, value) VALUES (%s, %s, %s)",
|
||||
[(doc_id, f, v) for f, v in fields.items()]
|
||||
)
|
||||
|
||||
# BAD: Individual inserts in loop
|
||||
for field, value in fields.items():
|
||||
cur.execute("INSERT INTO labels ...", (doc_id, field, value))
|
||||
```
|
||||
|
||||
### Lazy Loading
|
||||
|
||||
```python
|
||||
class InferencePipeline:
|
||||
def __init__(self, model_path: Path):
|
||||
self.model_path = model_path
|
||||
self._model: YOLO | None = None
|
||||
self._ocr: PaddleOCR | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> YOLO:
|
||||
"""Lazy load YOLO model."""
|
||||
if self._model is None:
|
||||
self._model = YOLO(str(self.model_path))
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def ocr(self) -> PaddleOCR:
|
||||
"""Lazy load PaddleOCR."""
|
||||
if self._ocr is None:
|
||||
self._ocr = PaddleOCR(use_angle_cls=True, lang="latin")
|
||||
return self._ocr
|
||||
```
|
||||
|
||||
## Testing Standards
|
||||
|
||||
### Test Structure (AAA Pattern)
|
||||
|
||||
```python
|
||||
def test_extract_bankgiro_valid():
|
||||
# Arrange
|
||||
text = "Bankgiro: 123-4567"
|
||||
|
||||
# Act
|
||||
result = extract_bankgiro(text)
|
||||
|
||||
# Assert
|
||||
assert result == "1234567"
|
||||
|
||||
def test_extract_bankgiro_invalid_returns_none():
|
||||
# Arrange
|
||||
text = "No bankgiro here"
|
||||
|
||||
# Act
|
||||
result = extract_bankgiro(text)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
```
|
||||
|
||||
### Test Naming
|
||||
|
||||
```python
|
||||
# GOOD: Descriptive test names
|
||||
def test_parse_payment_line_extracts_all_fields(): ...
|
||||
def test_parse_payment_line_handles_missing_checksum(): ...
|
||||
def test_validate_ocr_returns_false_for_invalid_checksum(): ...
|
||||
|
||||
# BAD: Vague test names
|
||||
def test_parse(): ...
|
||||
def test_works(): ...
|
||||
def test_payment_line(): ...
|
||||
```
|
||||
|
||||
### Fixtures
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.fixture
|
||||
def sample_invoice_pdf(tmp_path: Path) -> Path:
|
||||
"""Create sample invoice PDF for testing."""
|
||||
pdf_path = tmp_path / "invoice.pdf"
|
||||
# Create test PDF...
|
||||
return pdf_path
|
||||
|
||||
@pytest.fixture
|
||||
def inference_pipeline(sample_model_path: Path) -> InferencePipeline:
|
||||
"""Create inference pipeline with test model."""
|
||||
return InferencePipeline(sample_model_path)
|
||||
|
||||
def test_process_invoice(inference_pipeline, sample_invoice_pdf):
|
||||
result = inference_pipeline.process(sample_invoice_pdf)
|
||||
assert result.fields.get("invoice_number") is not None
|
||||
```
|
||||
|
||||
## Code Smell Detection
|
||||
|
||||
### 1. Long Functions
|
||||
|
||||
```python
|
||||
# BAD: Function > 50 lines
|
||||
def process_document():
|
||||
# 100 lines of code...
|
||||
|
||||
# GOOD: Split into smaller functions
|
||||
def process_document(pdf_path: Path) -> InferenceResult:
|
||||
image = render_pdf(pdf_path)
|
||||
detections = detect_fields(image)
|
||||
ocr_results = extract_text(image, detections)
|
||||
fields = normalize_fields(ocr_results)
|
||||
return build_result(fields)
|
||||
```
|
||||
|
||||
### 2. Deep Nesting
|
||||
|
||||
```python
|
||||
# BAD: 5+ levels of nesting
|
||||
if document:
|
||||
if document.is_valid:
|
||||
if document.has_fields:
|
||||
if field in document.fields:
|
||||
if document.fields[field]:
|
||||
# Do something
|
||||
|
||||
# GOOD: Early returns
|
||||
if not document:
|
||||
return None
|
||||
if not document.is_valid:
|
||||
return None
|
||||
if not document.has_fields:
|
||||
return None
|
||||
if field not in document.fields:
|
||||
return None
|
||||
if not document.fields[field]:
|
||||
return None
|
||||
|
||||
# Do something
|
||||
```
|
||||
|
||||
### 3. Magic Numbers
|
||||
|
||||
```python
|
||||
# BAD: Unexplained numbers
|
||||
if confidence > 0.5:
|
||||
...
|
||||
time.sleep(3)
|
||||
|
||||
# GOOD: Named constants
|
||||
CONFIDENCE_THRESHOLD = 0.5
|
||||
RETRY_DELAY_SECONDS = 3
|
||||
|
||||
if confidence > CONFIDENCE_THRESHOLD:
|
||||
...
|
||||
time.sleep(RETRY_DELAY_SECONDS)
|
||||
```
|
||||
|
||||
### 4. Mutable Default Arguments
|
||||
|
||||
```python
|
||||
# BAD: Mutable default argument
|
||||
def process_fields(fields: list = []): # DANGEROUS!
|
||||
fields.append("new_field")
|
||||
return fields
|
||||
|
||||
# GOOD: Use None as default
|
||||
def process_fields(fields: list | None = None) -> list:
|
||||
if fields is None:
|
||||
fields = []
|
||||
return [*fields, "new_field"]
|
||||
```
|
||||
|
||||
## Logging Standards
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
# Module-level logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# GOOD: Appropriate log levels
|
||||
logger.debug("Processing document: %s", doc_id)
|
||||
logger.info("Document processed successfully: %s", doc_id)
|
||||
logger.warning("Low confidence score: %.2f", confidence)
|
||||
logger.error("Failed to process document: %s", error)
|
||||
|
||||
# GOOD: Structured logging with extra data
|
||||
logger.info(
|
||||
"Inference complete",
|
||||
extra={
|
||||
"document_id": doc_id,
|
||||
"field_count": len(fields),
|
||||
"processing_time_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
|
||||
# BAD: Using print()
|
||||
print(f"Processing {doc_id}") # Never in production!
|
||||
```
|
||||
|
||||
**Remember**: Code quality is not negotiable. Clear, maintainable Python code with proper type hints enables confident development and refactoring.
|
||||
80
.claude/skills/continuous-learning/SKILL.md
Normal file
80
.claude/skills/continuous-learning/SKILL.md
Normal file
@@ -0,0 +1,80 @@
|
||||
---
|
||||
name: continuous-learning
|
||||
description: Automatically extract reusable patterns from Claude Code sessions and save them as learned skills for future use.
|
||||
---
|
||||
|
||||
# Continuous Learning Skill
|
||||
|
||||
Automatically evaluates Claude Code sessions on end to extract reusable patterns that can be saved as learned skills.
|
||||
|
||||
## How It Works
|
||||
|
||||
This skill runs as a **Stop hook** at the end of each session:
|
||||
|
||||
1. **Session Evaluation**: Checks if session has enough messages (default: 10+)
|
||||
2. **Pattern Detection**: Identifies extractable patterns from the session
|
||||
3. **Skill Extraction**: Saves useful patterns to `~/.claude/skills/learned/`
|
||||
|
||||
## Configuration
|
||||
|
||||
Edit `config.json` to customize:
|
||||
|
||||
```json
|
||||
{
|
||||
"min_session_length": 10,
|
||||
"extraction_threshold": "medium",
|
||||
"auto_approve": false,
|
||||
"learned_skills_path": "~/.claude/skills/learned/",
|
||||
"patterns_to_detect": [
|
||||
"error_resolution",
|
||||
"user_corrections",
|
||||
"workarounds",
|
||||
"debugging_techniques",
|
||||
"project_specific"
|
||||
],
|
||||
"ignore_patterns": [
|
||||
"simple_typos",
|
||||
"one_time_fixes",
|
||||
"external_api_issues"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern Types
|
||||
|
||||
| Pattern | Description |
|
||||
|---------|-------------|
|
||||
| `error_resolution` | How specific errors were resolved |
|
||||
| `user_corrections` | Patterns from user corrections |
|
||||
| `workarounds` | Solutions to framework/library quirks |
|
||||
| `debugging_techniques` | Effective debugging approaches |
|
||||
| `project_specific` | Project-specific conventions |
|
||||
|
||||
## Hook Setup
|
||||
|
||||
Add to your `~/.claude/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"Stop": [{
|
||||
"matcher": "*",
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
|
||||
}]
|
||||
}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Why Stop Hook?
|
||||
|
||||
- **Lightweight**: Runs once at session end
|
||||
- **Non-blocking**: Doesn't add latency to every message
|
||||
- **Complete context**: Has access to full session transcript
|
||||
|
||||
## Related
|
||||
|
||||
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Section on continuous learning
|
||||
- `/learn` command - Manual pattern extraction mid-session
|
||||
18
.claude/skills/continuous-learning/config.json
Normal file
18
.claude/skills/continuous-learning/config.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"min_session_length": 10,
|
||||
"extraction_threshold": "medium",
|
||||
"auto_approve": false,
|
||||
"learned_skills_path": "~/.claude/skills/learned/",
|
||||
"patterns_to_detect": [
|
||||
"error_resolution",
|
||||
"user_corrections",
|
||||
"workarounds",
|
||||
"debugging_techniques",
|
||||
"project_specific"
|
||||
],
|
||||
"ignore_patterns": [
|
||||
"simple_typos",
|
||||
"one_time_fixes",
|
||||
"external_api_issues"
|
||||
]
|
||||
}
|
||||
60
.claude/skills/continuous-learning/evaluate-session.sh
Normal file
60
.claude/skills/continuous-learning/evaluate-session.sh
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
# Continuous Learning - Session Evaluator
|
||||
# Runs on Stop hook to extract reusable patterns from Claude Code sessions
|
||||
#
|
||||
# Why Stop hook instead of UserPromptSubmit:
|
||||
# - Stop runs once at session end (lightweight)
|
||||
# - UserPromptSubmit runs every message (heavy, adds latency)
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "Stop": [{
|
||||
# "matcher": "*",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# Patterns to detect: error_resolution, debugging_techniques, workarounds, project_specific
|
||||
# Patterns to ignore: simple_typos, one_time_fixes, external_api_issues
|
||||
# Extracted skills saved to: ~/.claude/skills/learned/
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CONFIG_FILE="$SCRIPT_DIR/config.json"
|
||||
LEARNED_SKILLS_PATH="${HOME}/.claude/skills/learned"
|
||||
MIN_SESSION_LENGTH=10
|
||||
|
||||
# Load config if exists
|
||||
if [ -f "$CONFIG_FILE" ]; then
|
||||
MIN_SESSION_LENGTH=$(jq -r '.min_session_length // 10' "$CONFIG_FILE")
|
||||
LEARNED_SKILLS_PATH=$(jq -r '.learned_skills_path // "~/.claude/skills/learned/"' "$CONFIG_FILE" | sed "s|~|$HOME|")
|
||||
fi
|
||||
|
||||
# Ensure learned skills directory exists
|
||||
mkdir -p "$LEARNED_SKILLS_PATH"
|
||||
|
||||
# Get transcript path from environment (set by Claude Code)
|
||||
transcript_path="${CLAUDE_TRANSCRIPT_PATH:-}"
|
||||
|
||||
if [ -z "$transcript_path" ] || [ ! -f "$transcript_path" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Count messages in session
|
||||
message_count=$(grep -c '"type":"user"' "$transcript_path" 2>/dev/null || echo "0")
|
||||
|
||||
# Skip short sessions
|
||||
if [ "$message_count" -lt "$MIN_SESSION_LENGTH" ]; then
|
||||
echo "[ContinuousLearning] Session too short ($message_count messages), skipping" >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Signal to Claude that session should be evaluated for extractable patterns
|
||||
echo "[ContinuousLearning] Session has $message_count messages - evaluate for extractable patterns" >&2
|
||||
echo "[ContinuousLearning] Save learned skills to: $LEARNED_SKILLS_PATH" >&2
|
||||
@@ -1,245 +0,0 @@
|
||||
---
|
||||
name: dev-builder
|
||||
description: 根据 Product-Spec.md 初始化项目、安装依赖、实现代码。与 product-spec-builder 配套使用,帮助用户将需求文档转化为可运行的代码项目。
|
||||
---
|
||||
|
||||
[角色]
|
||||
你是一位经验丰富的全栈开发工程师。
|
||||
|
||||
你能够根据产品需求文档快速搭建项目,选择合适的技术栈,编写高质量的代码。你注重代码结构清晰、可维护性强。
|
||||
|
||||
[任务]
|
||||
读取 Product-Spec.md,完成以下工作:
|
||||
1. 分析需求,确定项目类型和技术栈
|
||||
2. 初始化项目,创建目录结构
|
||||
3. 安装必要依赖,配置开发环境
|
||||
4. 实现代码(UI、功能、AI 集成)
|
||||
|
||||
最终交付可运行的项目代码。
|
||||
|
||||
[总体规则]
|
||||
- 必须先读取 Product-Spec.md,不存在则提示用户先完成需求收集
|
||||
- 每个阶段完成后输出进度反馈
|
||||
- 如有原型图,开发时参考原型图的视觉设计
|
||||
- 代码要简洁、可读、可维护
|
||||
- 优先使用简单方案,不过度设计
|
||||
- 只改与当前任务相关的文件,禁止「顺手升级依赖」「全局格式化」「无关重命名」
|
||||
- 始终使用中文与用户交流
|
||||
|
||||
[项目类型判断]
|
||||
根据 Product Spec 的 UI 布局和技术说明判断:
|
||||
- 有 UI + 纯前端/无需服务器 → 纯前端 Web 应用
|
||||
- 有 UI + 需要后端/数据库/API → 全栈 Web 应用
|
||||
- 无 UI + 命令行操作 → CLI 工具
|
||||
- 只是 API 服务 → 后端服务
|
||||
|
||||
[技术栈选择]
|
||||
| 项目类型 | 推荐技术栈 |
|
||||
|---------|-----------|
|
||||
| 纯前端 Web 应用 | React + Vite + TypeScript + Tailwind |
|
||||
| 全栈 Web 应用 | Next.js + TypeScript + Tailwind |
|
||||
| CLI 工具 | Node.js + TypeScript + Commander |
|
||||
| 后端服务 | Express + TypeScript |
|
||||
| AI/ML 应用 | Python + FastAPI + PyTorch/TensorFlow |
|
||||
| 数据处理工具 | Python + Pandas + NumPy |
|
||||
|
||||
**选择原则**:
|
||||
- Product Spec 技术说明有指定 → 用指定的
|
||||
- 没指定 → 用推荐方案
|
||||
- 有疑问 → 询问用户
|
||||
|
||||
[AI 研发方向]
|
||||
**适用场景**:
|
||||
- 机器学习模型训练与推理
|
||||
- 计算机视觉(目标检测、OCR、图像分类)
|
||||
- 自然语言处理(文本分类、命名实体识别、对话系统)
|
||||
- 大语言模型应用(RAG、Agent、Prompt Engineering)
|
||||
- 数据分析与可视化
|
||||
|
||||
**技术栈推荐**:
|
||||
| 方向 | 推荐技术栈 |
|
||||
|-----|-----------|
|
||||
| 深度学习 | PyTorch + Lightning + Weights & Biases |
|
||||
| 目标检测 | Ultralytics YOLO + OpenCV |
|
||||
| OCR | PaddleOCR / EasyOCR / Tesseract |
|
||||
| NLP | Transformers + spaCy |
|
||||
| LLM 应用 | LangChain / LlamaIndex + OpenAI API |
|
||||
| 数据处理 | Pandas + Polars + DuckDB |
|
||||
| 模型部署 | FastAPI + Docker + ONNX Runtime |
|
||||
|
||||
**项目结构(AI/ML 项目)**:
|
||||
```
|
||||
project/
|
||||
├── src/ # 源代码
|
||||
│ ├── data/ # 数据加载与预处理
|
||||
│ ├── models/ # 模型定义
|
||||
│ ├── training/ # 训练逻辑
|
||||
│ ├── inference/ # 推理逻辑
|
||||
│ └── utils/ # 工具函数
|
||||
├── configs/ # 配置文件(YAML)
|
||||
├── data/ # 数据目录
|
||||
│ ├── raw/ # 原始数据(不修改)
|
||||
│ └── processed/ # 处理后数据
|
||||
├── models/ # 训练好的模型权重
|
||||
├── notebooks/ # 实验 Notebook
|
||||
├── tests/ # 测试代码
|
||||
└── scripts/ # 运行脚本
|
||||
```
|
||||
|
||||
**AI 研发规范**:
|
||||
- **可复现性**:固定随机种子(random、numpy、torch),记录实验配置
|
||||
- **数据管理**:原始数据不可变,处理数据版本化
|
||||
- **实验追踪**:使用 MLflow/W&B 记录指标、参数、产物
|
||||
- **配置驱动**:所有超参数放 YAML 配置,禁止硬编码
|
||||
- **类型安全**:使用 Pydantic 定义数据结构
|
||||
- **日志规范**:使用 logging 模块,不用 print
|
||||
|
||||
**模型训练检查项**:
|
||||
- ✅ 数据集划分(train/val/test)比例合理
|
||||
- ✅ 早停机制(Early Stopping)防止过拟合
|
||||
- ✅ 学习率调度器配置
|
||||
- ✅ 模型检查点保存策略
|
||||
- ✅ 验证集指标监控
|
||||
- ✅ GPU 内存管理(混合精度训练)
|
||||
|
||||
**部署注意事项**:
|
||||
- 模型导出为 ONNX 格式提升推理速度
|
||||
- API 接口使用异步处理提升并发
|
||||
- 大文件使用流式传输
|
||||
- 配置健康检查端点
|
||||
- 日志和指标监控
|
||||
|
||||
[初始化提醒]
|
||||
**项目名称规范**:
|
||||
- 只能用小写字母、数字、短横线(如 my-app)
|
||||
- 不能有空格、&、# 等特殊字符
|
||||
|
||||
**npm 报错时**:可尝试 pnpm 或 yarn
|
||||
|
||||
[依赖选择]
|
||||
**原则**:只装需要的,不装「可能用到」的
|
||||
|
||||
[环境变量配置]
|
||||
**⚠️ 安全警告**:
|
||||
- Vite 纯前端:`VITE_` 前缀变量**会暴露给浏览器**,不能存放 API Key
|
||||
- Next.js:不加 `NEXT_PUBLIC_` 前缀的变量只在服务端可用(安全)
|
||||
|
||||
**涉及 AI API 调用时**:
|
||||
- 推荐用 Next.js(API Key 只在服务端使用,安全)
|
||||
- 备选:创建独立后端代理请求
|
||||
- 仅限开发/演示:使用 VITE_ 前缀(必须提醒用户安全风险)
|
||||
|
||||
**文件规范**:
|
||||
- 创建 `.env.example` 作为模板(提交到 Git)
|
||||
- 实际值放 `.env.local`(不提交,确保 .gitignore 包含)
|
||||
|
||||
[工作流程]
|
||||
[启动阶段]
|
||||
目的:检查前置条件,读取项目文档
|
||||
|
||||
第一步:检测 Product Spec
|
||||
检测 Product-Spec.md 是否存在
|
||||
不存在 → 提示:「未找到 Product-Spec.md,请先使用 /prd 完成需求收集。」,终止流程
|
||||
存在 → 继续
|
||||
|
||||
第二步:读取项目文档
|
||||
加载 Product-Spec.md
|
||||
提取:产品概述、功能需求、UI 布局、技术说明、AI 能力需求
|
||||
|
||||
第三步:检查原型图
|
||||
检查 UI-Prompts.md 是否存在
|
||||
存在 → 询问:「我看到你已经生成了原型图提示词,如果有生成的原型图图片,可以发给我参考。」
|
||||
不存在 → 询问:「是否有原型图或设计稿可以参考?有的话可以发给我。」
|
||||
|
||||
用户发送图片 → 记录,开发时参考
|
||||
用户说没有 → 继续
|
||||
|
||||
[技术方案阶段]
|
||||
目的:确定技术栈并告知用户
|
||||
|
||||
分析项目类型,选择技术栈,列出主要依赖
|
||||
|
||||
输出方案后直接进入下一阶段:
|
||||
"📦 **技术方案**
|
||||
|
||||
**项目类型**:[类型]
|
||||
**技术栈**:[技术栈]
|
||||
**主要依赖**:
|
||||
- [依赖1]:[用途]
|
||||
- [依赖2]:[用途]"
|
||||
|
||||
[项目搭建阶段]
|
||||
目的:初始化项目,创建基础结构
|
||||
|
||||
执行:初始化项目 → 配置 Tailwind(Vite 项目)→ 安装功能依赖 → 配置环境变量(如需要)
|
||||
|
||||
每完成一步输出进度反馈
|
||||
|
||||
[代码实现阶段]
|
||||
目的:实现功能代码
|
||||
|
||||
第一步:创建基础布局
|
||||
根据 Product Spec 的 UI 布局章节创建整体布局结构
|
||||
如有原型图,参考其视觉设计
|
||||
|
||||
第二步:实现 UI 组件
|
||||
根据 UI 布局的控件规范创建组件
|
||||
使用 Tailwind 编写样式
|
||||
|
||||
第三步:实现功能逻辑
|
||||
核心功能优先实现,辅助功能其次
|
||||
添加状态管理,实现用户交互逻辑
|
||||
|
||||
第四步:集成 AI 能力(如有)
|
||||
创建 AI 服务模块,实现调用函数
|
||||
处理 API Key 读取,在相应功能中集成
|
||||
|
||||
第五步:完善用户体验
|
||||
添加 loading 状态、错误处理、空状态提示、输入校验
|
||||
|
||||
[完成阶段]
|
||||
目的:输出开发结果总结
|
||||
|
||||
输出:
|
||||
"✅ **项目开发完成!**
|
||||
|
||||
**技术栈**:[技术栈]
|
||||
|
||||
**项目结构**:
|
||||
```
|
||||
[实际目录结构]
|
||||
```
|
||||
|
||||
**已实现功能**:
|
||||
- ✅ [功能1]
|
||||
- ✅ [功能2]
|
||||
- ...
|
||||
|
||||
**AI 能力集成**:
|
||||
- [已集成的 AI 能力,或「无」]
|
||||
|
||||
**环境变量**:
|
||||
- [需要配置的环境变量,或「无需配置」]"
|
||||
|
||||
[质量门槛]
|
||||
每个功能点至少满足:
|
||||
|
||||
**必须**:
|
||||
- ✅ 主路径可用(Happy Path 能跑通)
|
||||
- ✅ 异常路径清晰(错误提示、重试/回退)
|
||||
- ✅ loading 状态(涉及异步操作时)
|
||||
- ✅ 空状态处理(无数据时的提示)
|
||||
- ✅ 基础输入校验(必填、格式)
|
||||
- ✅ 敏感信息不写入代码(API Key 走环境变量)
|
||||
|
||||
**建议**:
|
||||
- 基础可访问性(可点击、可键盘操作)
|
||||
- 响应式适配(如需支持移动端)
|
||||
|
||||
[代码规范]
|
||||
- 单个文件不超过 300 行,超过则拆分
|
||||
- 优先使用函数组件 + Hooks
|
||||
- 样式优先用 Tailwind
|
||||
|
||||
[初始化]
|
||||
执行 [启动阶段]
|
||||
221
.claude/skills/eval-harness/SKILL.md
Normal file
221
.claude/skills/eval-harness/SKILL.md
Normal file
@@ -0,0 +1,221 @@
|
||||
# Eval Harness Skill
|
||||
|
||||
A formal evaluation framework for Claude Code sessions, implementing eval-driven development (EDD) principles.
|
||||
|
||||
## Philosophy
|
||||
|
||||
Eval-Driven Development treats evals as the "unit tests of AI development":
|
||||
- Define expected behavior BEFORE implementation
|
||||
- Run evals continuously during development
|
||||
- Track regressions with each change
|
||||
- Use pass@k metrics for reliability measurement
|
||||
|
||||
## Eval Types
|
||||
|
||||
### Capability Evals
|
||||
Test if Claude can do something it couldn't before:
|
||||
```markdown
|
||||
[CAPABILITY EVAL: feature-name]
|
||||
Task: Description of what Claude should accomplish
|
||||
Success Criteria:
|
||||
- [ ] Criterion 1
|
||||
- [ ] Criterion 2
|
||||
- [ ] Criterion 3
|
||||
Expected Output: Description of expected result
|
||||
```
|
||||
|
||||
### Regression Evals
|
||||
Ensure changes don't break existing functionality:
|
||||
```markdown
|
||||
[REGRESSION EVAL: feature-name]
|
||||
Baseline: SHA or checkpoint name
|
||||
Tests:
|
||||
- existing-test-1: PASS/FAIL
|
||||
- existing-test-2: PASS/FAIL
|
||||
- existing-test-3: PASS/FAIL
|
||||
Result: X/Y passed (previously Y/Y)
|
||||
```
|
||||
|
||||
## Grader Types
|
||||
|
||||
### 1. Code-Based Grader
|
||||
Deterministic checks using code:
|
||||
```bash
|
||||
# Check if file contains expected pattern
|
||||
grep -q "export function handleAuth" src/auth.ts && echo "PASS" || echo "FAIL"
|
||||
|
||||
# Check if tests pass
|
||||
npm test -- --testPathPattern="auth" && echo "PASS" || echo "FAIL"
|
||||
|
||||
# Check if build succeeds
|
||||
npm run build && echo "PASS" || echo "FAIL"
|
||||
```
|
||||
|
||||
### 2. Model-Based Grader
|
||||
Use Claude to evaluate open-ended outputs:
|
||||
```markdown
|
||||
[MODEL GRADER PROMPT]
|
||||
Evaluate the following code change:
|
||||
1. Does it solve the stated problem?
|
||||
2. Is it well-structured?
|
||||
3. Are edge cases handled?
|
||||
4. Is error handling appropriate?
|
||||
|
||||
Score: 1-5 (1=poor, 5=excellent)
|
||||
Reasoning: [explanation]
|
||||
```
|
||||
|
||||
### 3. Human Grader
|
||||
Flag for manual review:
|
||||
```markdown
|
||||
[HUMAN REVIEW REQUIRED]
|
||||
Change: Description of what changed
|
||||
Reason: Why human review is needed
|
||||
Risk Level: LOW/MEDIUM/HIGH
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
### pass@k
|
||||
"At least one success in k attempts"
|
||||
- pass@1: First attempt success rate
|
||||
- pass@3: Success within 3 attempts
|
||||
- Typical target: pass@3 > 90%
|
||||
|
||||
### pass^k
|
||||
"All k trials succeed"
|
||||
- Higher bar for reliability
|
||||
- pass^3: 3 consecutive successes
|
||||
- Use for critical paths
|
||||
|
||||
## Eval Workflow
|
||||
|
||||
### 1. Define (Before Coding)
|
||||
```markdown
|
||||
## EVAL DEFINITION: feature-xyz
|
||||
|
||||
### Capability Evals
|
||||
1. Can create new user account
|
||||
2. Can validate email format
|
||||
3. Can hash password securely
|
||||
|
||||
### Regression Evals
|
||||
1. Existing login still works
|
||||
2. Session management unchanged
|
||||
3. Logout flow intact
|
||||
|
||||
### Success Metrics
|
||||
- pass@3 > 90% for capability evals
|
||||
- pass^3 = 100% for regression evals
|
||||
```
|
||||
|
||||
### 2. Implement
|
||||
Write code to pass the defined evals.
|
||||
|
||||
### 3. Evaluate
|
||||
```bash
|
||||
# Run capability evals
|
||||
[Run each capability eval, record PASS/FAIL]
|
||||
|
||||
# Run regression evals
|
||||
npm test -- --testPathPattern="existing"
|
||||
|
||||
# Generate report
|
||||
```
|
||||
|
||||
### 4. Report
|
||||
```markdown
|
||||
EVAL REPORT: feature-xyz
|
||||
========================
|
||||
|
||||
Capability Evals:
|
||||
create-user: PASS (pass@1)
|
||||
validate-email: PASS (pass@2)
|
||||
hash-password: PASS (pass@1)
|
||||
Overall: 3/3 passed
|
||||
|
||||
Regression Evals:
|
||||
login-flow: PASS
|
||||
session-mgmt: PASS
|
||||
logout-flow: PASS
|
||||
Overall: 3/3 passed
|
||||
|
||||
Metrics:
|
||||
pass@1: 67% (2/3)
|
||||
pass@3: 100% (3/3)
|
||||
|
||||
Status: READY FOR REVIEW
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### Pre-Implementation
|
||||
```
|
||||
/eval define feature-name
|
||||
```
|
||||
Creates eval definition file at `.claude/evals/feature-name.md`
|
||||
|
||||
### During Implementation
|
||||
```
|
||||
/eval check feature-name
|
||||
```
|
||||
Runs current evals and reports status
|
||||
|
||||
### Post-Implementation
|
||||
```
|
||||
/eval report feature-name
|
||||
```
|
||||
Generates full eval report
|
||||
|
||||
## Eval Storage
|
||||
|
||||
Store evals in project:
|
||||
```
|
||||
.claude/
|
||||
evals/
|
||||
feature-xyz.md # Eval definition
|
||||
feature-xyz.log # Eval run history
|
||||
baseline.json # Regression baselines
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Define evals BEFORE coding** - Forces clear thinking about success criteria
|
||||
2. **Run evals frequently** - Catch regressions early
|
||||
3. **Track pass@k over time** - Monitor reliability trends
|
||||
4. **Use code graders when possible** - Deterministic > probabilistic
|
||||
5. **Human review for security** - Never fully automate security checks
|
||||
6. **Keep evals fast** - Slow evals don't get run
|
||||
7. **Version evals with code** - Evals are first-class artifacts
|
||||
|
||||
## Example: Adding Authentication
|
||||
|
||||
```markdown
|
||||
## EVAL: add-authentication
|
||||
|
||||
### Phase 1: Define (10 min)
|
||||
Capability Evals:
|
||||
- [ ] User can register with email/password
|
||||
- [ ] User can login with valid credentials
|
||||
- [ ] Invalid credentials rejected with proper error
|
||||
- [ ] Sessions persist across page reloads
|
||||
- [ ] Logout clears session
|
||||
|
||||
Regression Evals:
|
||||
- [ ] Public routes still accessible
|
||||
- [ ] API responses unchanged
|
||||
- [ ] Database schema compatible
|
||||
|
||||
### Phase 2: Implement (varies)
|
||||
[Write code]
|
||||
|
||||
### Phase 3: Evaluate
|
||||
Run: /eval check add-authentication
|
||||
|
||||
### Phase 4: Report
|
||||
EVAL REPORT: add-authentication
|
||||
==============================
|
||||
Capability: 5/5 passed (pass@3: 100%)
|
||||
Regression: 3/3 passed (pass^3: 100%)
|
||||
Status: SHIP IT
|
||||
```
|
||||
631
.claude/skills/frontend-patterns/SKILL.md
Normal file
631
.claude/skills/frontend-patterns/SKILL.md
Normal file
@@ -0,0 +1,631 @@
|
||||
---
|
||||
name: frontend-patterns
|
||||
description: Frontend development patterns for React, Next.js, state management, performance optimization, and UI best practices.
|
||||
---
|
||||
|
||||
# Frontend Development Patterns
|
||||
|
||||
Modern frontend patterns for React, Next.js, and performant user interfaces.
|
||||
|
||||
## Component Patterns
|
||||
|
||||
### Composition Over Inheritance
|
||||
|
||||
```typescript
|
||||
// ✅ GOOD: Component composition
|
||||
interface CardProps {
|
||||
children: React.ReactNode
|
||||
variant?: 'default' | 'outlined'
|
||||
}
|
||||
|
||||
export function Card({ children, variant = 'default' }: CardProps) {
|
||||
return <div className={`card card-${variant}`}>{children}</div>
|
||||
}
|
||||
|
||||
export function CardHeader({ children }: { children: React.ReactNode }) {
|
||||
return <div className="card-header">{children}</div>
|
||||
}
|
||||
|
||||
export function CardBody({ children }: { children: React.ReactNode }) {
|
||||
return <div className="card-body">{children}</div>
|
||||
}
|
||||
|
||||
// Usage
|
||||
<Card>
|
||||
<CardHeader>Title</CardHeader>
|
||||
<CardBody>Content</CardBody>
|
||||
</Card>
|
||||
```
|
||||
|
||||
### Compound Components
|
||||
|
||||
```typescript
|
||||
interface TabsContextValue {
|
||||
activeTab: string
|
||||
setActiveTab: (tab: string) => void
|
||||
}
|
||||
|
||||
const TabsContext = createContext<TabsContextValue | undefined>(undefined)
|
||||
|
||||
export function Tabs({ children, defaultTab }: {
|
||||
children: React.ReactNode
|
||||
defaultTab: string
|
||||
}) {
|
||||
const [activeTab, setActiveTab] = useState(defaultTab)
|
||||
|
||||
return (
|
||||
<TabsContext.Provider value={{ activeTab, setActiveTab }}>
|
||||
{children}
|
||||
</TabsContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function TabList({ children }: { children: React.ReactNode }) {
|
||||
return <div className="tab-list">{children}</div>
|
||||
}
|
||||
|
||||
export function Tab({ id, children }: { id: string, children: React.ReactNode }) {
|
||||
const context = useContext(TabsContext)
|
||||
if (!context) throw new Error('Tab must be used within Tabs')
|
||||
|
||||
return (
|
||||
<button
|
||||
className={context.activeTab === id ? 'active' : ''}
|
||||
onClick={() => context.setActiveTab(id)}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
// Usage
|
||||
<Tabs defaultTab="overview">
|
||||
<TabList>
|
||||
<Tab id="overview">Overview</Tab>
|
||||
<Tab id="details">Details</Tab>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
```
|
||||
|
||||
### Render Props Pattern
|
||||
|
||||
```typescript
|
||||
interface DataLoaderProps<T> {
|
||||
url: string
|
||||
children: (data: T | null, loading: boolean, error: Error | null) => React.ReactNode
|
||||
}
|
||||
|
||||
export function DataLoader<T>({ url, children }: DataLoaderProps<T>) {
|
||||
const [data, setData] = useState<T | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
fetch(url)
|
||||
.then(res => res.json())
|
||||
.then(setData)
|
||||
.catch(setError)
|
||||
.finally(() => setLoading(false))
|
||||
}, [url])
|
||||
|
||||
return <>{children(data, loading, error)}</>
|
||||
}
|
||||
|
||||
// Usage
|
||||
<DataLoader<Market[]> url="/api/markets">
|
||||
{(markets, loading, error) => {
|
||||
if (loading) return <Spinner />
|
||||
if (error) return <Error error={error} />
|
||||
return <MarketList markets={markets!} />
|
||||
}}
|
||||
</DataLoader>
|
||||
```
|
||||
|
||||
## Custom Hooks Patterns
|
||||
|
||||
### State Management Hook
|
||||
|
||||
```typescript
|
||||
export function useToggle(initialValue = false): [boolean, () => void] {
|
||||
const [value, setValue] = useState(initialValue)
|
||||
|
||||
const toggle = useCallback(() => {
|
||||
setValue(v => !v)
|
||||
}, [])
|
||||
|
||||
return [value, toggle]
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [isOpen, toggleOpen] = useToggle()
|
||||
```
|
||||
|
||||
### Async Data Fetching Hook
|
||||
|
||||
```typescript
|
||||
interface UseQueryOptions<T> {
|
||||
onSuccess?: (data: T) => void
|
||||
onError?: (error: Error) => void
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export function useQuery<T>(
|
||||
key: string,
|
||||
fetcher: () => Promise<T>,
|
||||
options?: UseQueryOptions<T>
|
||||
) {
|
||||
const [data, setData] = useState<T | null>(null)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const refetch = useCallback(async () => {
|
||||
setLoading(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const result = await fetcher()
|
||||
setData(result)
|
||||
options?.onSuccess?.(result)
|
||||
} catch (err) {
|
||||
const error = err as Error
|
||||
setError(error)
|
||||
options?.onError?.(error)
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [fetcher, options])
|
||||
|
||||
useEffect(() => {
|
||||
if (options?.enabled !== false) {
|
||||
refetch()
|
||||
}
|
||||
}, [key, refetch, options?.enabled])
|
||||
|
||||
return { data, error, loading, refetch }
|
||||
}
|
||||
|
||||
// Usage
|
||||
const { data: markets, loading, error, refetch } = useQuery(
|
||||
'markets',
|
||||
() => fetch('/api/markets').then(r => r.json()),
|
||||
{
|
||||
onSuccess: data => console.log('Fetched', data.length, 'markets'),
|
||||
onError: err => console.error('Failed:', err)
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Debounce Hook
|
||||
|
||||
```typescript
|
||||
export function useDebounce<T>(value: T, delay: number): T {
|
||||
const [debouncedValue, setDebouncedValue] = useState<T>(value)
|
||||
|
||||
useEffect(() => {
|
||||
const handler = setTimeout(() => {
|
||||
setDebouncedValue(value)
|
||||
}, delay)
|
||||
|
||||
return () => clearTimeout(handler)
|
||||
}, [value, delay])
|
||||
|
||||
return debouncedValue
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [searchQuery, setSearchQuery] = useState('')
|
||||
const debouncedQuery = useDebounce(searchQuery, 500)
|
||||
|
||||
useEffect(() => {
|
||||
if (debouncedQuery) {
|
||||
performSearch(debouncedQuery)
|
||||
}
|
||||
}, [debouncedQuery])
|
||||
```
|
||||
|
||||
## State Management Patterns
|
||||
|
||||
### Context + Reducer Pattern
|
||||
|
||||
```typescript
|
||||
interface State {
|
||||
markets: Market[]
|
||||
selectedMarket: Market | null
|
||||
loading: boolean
|
||||
}
|
||||
|
||||
type Action =
|
||||
| { type: 'SET_MARKETS'; payload: Market[] }
|
||||
| { type: 'SELECT_MARKET'; payload: Market }
|
||||
| { type: 'SET_LOADING'; payload: boolean }
|
||||
|
||||
function reducer(state: State, action: Action): State {
|
||||
switch (action.type) {
|
||||
case 'SET_MARKETS':
|
||||
return { ...state, markets: action.payload }
|
||||
case 'SELECT_MARKET':
|
||||
return { ...state, selectedMarket: action.payload }
|
||||
case 'SET_LOADING':
|
||||
return { ...state, loading: action.payload }
|
||||
default:
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
||||
const MarketContext = createContext<{
|
||||
state: State
|
||||
dispatch: Dispatch<Action>
|
||||
} | undefined>(undefined)
|
||||
|
||||
export function MarketProvider({ children }: { children: React.ReactNode }) {
|
||||
const [state, dispatch] = useReducer(reducer, {
|
||||
markets: [],
|
||||
selectedMarket: null,
|
||||
loading: false
|
||||
})
|
||||
|
||||
return (
|
||||
<MarketContext.Provider value={{ state, dispatch }}>
|
||||
{children}
|
||||
</MarketContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useMarkets() {
|
||||
const context = useContext(MarketContext)
|
||||
if (!context) throw new Error('useMarkets must be used within MarketProvider')
|
||||
return context
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Memoization
|
||||
|
||||
```typescript
|
||||
// ✅ useMemo for expensive computations
|
||||
const sortedMarkets = useMemo(() => {
|
||||
return markets.sort((a, b) => b.volume - a.volume)
|
||||
}, [markets])
|
||||
|
||||
// ✅ useCallback for functions passed to children
|
||||
const handleSearch = useCallback((query: string) => {
|
||||
setSearchQuery(query)
|
||||
}, [])
|
||||
|
||||
// ✅ React.memo for pure components
|
||||
export const MarketCard = React.memo<MarketCardProps>(({ market }) => {
|
||||
return (
|
||||
<div className="market-card">
|
||||
<h3>{market.name}</h3>
|
||||
<p>{market.description}</p>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
### Code Splitting & Lazy Loading
|
||||
|
||||
```typescript
|
||||
import { lazy, Suspense } from 'react'
|
||||
|
||||
// ✅ Lazy load heavy components
|
||||
const HeavyChart = lazy(() => import('./HeavyChart'))
|
||||
const ThreeJsBackground = lazy(() => import('./ThreeJsBackground'))
|
||||
|
||||
export function Dashboard() {
|
||||
return (
|
||||
<div>
|
||||
<Suspense fallback={<ChartSkeleton />}>
|
||||
<HeavyChart data={data} />
|
||||
</Suspense>
|
||||
|
||||
<Suspense fallback={null}>
|
||||
<ThreeJsBackground />
|
||||
</Suspense>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Virtualization for Long Lists
|
||||
|
||||
```typescript
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
|
||||
export function VirtualMarketList({ markets }: { markets: Market[] }) {
|
||||
const parentRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const virtualizer = useVirtualizer({
|
||||
count: markets.length,
|
||||
getScrollElement: () => parentRef.current,
|
||||
estimateSize: () => 100, // Estimated row height
|
||||
overscan: 5 // Extra items to render
|
||||
})
|
||||
|
||||
return (
|
||||
<div ref={parentRef} style={{ height: '600px', overflow: 'auto' }}>
|
||||
<div
|
||||
style={{
|
||||
height: `${virtualizer.getTotalSize()}px`,
|
||||
position: 'relative'
|
||||
}}
|
||||
>
|
||||
{virtualizer.getVirtualItems().map(virtualRow => (
|
||||
<div
|
||||
key={virtualRow.index}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
width: '100%',
|
||||
height: `${virtualRow.size}px`,
|
||||
transform: `translateY(${virtualRow.start}px)`
|
||||
}}
|
||||
>
|
||||
<MarketCard market={markets[virtualRow.index]} />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Form Handling Patterns
|
||||
|
||||
### Controlled Form with Validation
|
||||
|
||||
```typescript
|
||||
interface FormData {
|
||||
name: string
|
||||
description: string
|
||||
endDate: string
|
||||
}
|
||||
|
||||
interface FormErrors {
|
||||
name?: string
|
||||
description?: string
|
||||
endDate?: string
|
||||
}
|
||||
|
||||
export function CreateMarketForm() {
|
||||
const [formData, setFormData] = useState<FormData>({
|
||||
name: '',
|
||||
description: '',
|
||||
endDate: ''
|
||||
})
|
||||
|
||||
const [errors, setErrors] = useState<FormErrors>({})
|
||||
|
||||
const validate = (): boolean => {
|
||||
const newErrors: FormErrors = {}
|
||||
|
||||
if (!formData.name.trim()) {
|
||||
newErrors.name = 'Name is required'
|
||||
} else if (formData.name.length > 200) {
|
||||
newErrors.name = 'Name must be under 200 characters'
|
||||
}
|
||||
|
||||
if (!formData.description.trim()) {
|
||||
newErrors.description = 'Description is required'
|
||||
}
|
||||
|
||||
if (!formData.endDate) {
|
||||
newErrors.endDate = 'End date is required'
|
||||
}
|
||||
|
||||
setErrors(newErrors)
|
||||
return Object.keys(newErrors).length === 0
|
||||
}
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
|
||||
if (!validate()) return
|
||||
|
||||
try {
|
||||
await createMarket(formData)
|
||||
// Success handling
|
||||
} catch (error) {
|
||||
// Error handling
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<input
|
||||
value={formData.name}
|
||||
onChange={e => setFormData(prev => ({ ...prev, name: e.target.value }))}
|
||||
placeholder="Market name"
|
||||
/>
|
||||
{errors.name && <span className="error">{errors.name}</span>}
|
||||
|
||||
{/* Other fields */}
|
||||
|
||||
<button type="submit">Create Market</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Error Boundary Pattern
|
||||
|
||||
```typescript
|
||||
interface ErrorBoundaryState {
|
||||
hasError: boolean
|
||||
error: Error | null
|
||||
}
|
||||
|
||||
export class ErrorBoundary extends React.Component<
|
||||
{ children: React.ReactNode },
|
||||
ErrorBoundaryState
|
||||
> {
|
||||
state: ErrorBoundaryState = {
|
||||
hasError: false,
|
||||
error: null
|
||||
}
|
||||
|
||||
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
|
||||
return { hasError: true, error }
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
|
||||
console.error('Error boundary caught:', error, errorInfo)
|
||||
}
|
||||
|
||||
render() {
|
||||
if (this.state.hasError) {
|
||||
return (
|
||||
<div className="error-fallback">
|
||||
<h2>Something went wrong</h2>
|
||||
<p>{this.state.error?.message}</p>
|
||||
<button onClick={() => this.setState({ hasError: false })}>
|
||||
Try again
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return this.props.children
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
<ErrorBoundary>
|
||||
<App />
|
||||
</ErrorBoundary>
|
||||
```
|
||||
|
||||
## Animation Patterns
|
||||
|
||||
### Framer Motion Animations
|
||||
|
||||
```typescript
|
||||
import { motion, AnimatePresence } from 'framer-motion'
|
||||
|
||||
// ✅ List animations
|
||||
export function AnimatedMarketList({ markets }: { markets: Market[] }) {
|
||||
return (
|
||||
<AnimatePresence>
|
||||
{markets.map(market => (
|
||||
<motion.div
|
||||
key={market.id}
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -20 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<MarketCard market={market} />
|
||||
</motion.div>
|
||||
))}
|
||||
</AnimatePresence>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ Modal animations
|
||||
export function Modal({ isOpen, onClose, children }: ModalProps) {
|
||||
return (
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<>
|
||||
<motion.div
|
||||
className="modal-overlay"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
onClick={onClose}
|
||||
/>
|
||||
<motion.div
|
||||
className="modal-content"
|
||||
initial={{ opacity: 0, scale: 0.9, y: 20 }}
|
||||
animate={{ opacity: 1, scale: 1, y: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.9, y: 20 }}
|
||||
>
|
||||
{children}
|
||||
</motion.div>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Accessibility Patterns
|
||||
|
||||
### Keyboard Navigation
|
||||
|
||||
```typescript
|
||||
export function Dropdown({ options, onSelect }: DropdownProps) {
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
const [activeIndex, setActiveIndex] = useState(0)
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
switch (e.key) {
|
||||
case 'ArrowDown':
|
||||
e.preventDefault()
|
||||
setActiveIndex(i => Math.min(i + 1, options.length - 1))
|
||||
break
|
||||
case 'ArrowUp':
|
||||
e.preventDefault()
|
||||
setActiveIndex(i => Math.max(i - 1, 0))
|
||||
break
|
||||
case 'Enter':
|
||||
e.preventDefault()
|
||||
onSelect(options[activeIndex])
|
||||
setIsOpen(false)
|
||||
break
|
||||
case 'Escape':
|
||||
setIsOpen(false)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
role="combobox"
|
||||
aria-expanded={isOpen}
|
||||
aria-haspopup="listbox"
|
||||
onKeyDown={handleKeyDown}
|
||||
>
|
||||
{/* Dropdown implementation */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Focus Management
|
||||
|
||||
```typescript
|
||||
export function Modal({ isOpen, onClose, children }: ModalProps) {
|
||||
const modalRef = useRef<HTMLDivElement>(null)
|
||||
const previousFocusRef = useRef<HTMLElement | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
// Save currently focused element
|
||||
previousFocusRef.current = document.activeElement as HTMLElement
|
||||
|
||||
// Focus modal
|
||||
modalRef.current?.focus()
|
||||
} else {
|
||||
// Restore focus when closing
|
||||
previousFocusRef.current?.focus()
|
||||
}
|
||||
}, [isOpen])
|
||||
|
||||
return isOpen ? (
|
||||
<div
|
||||
ref={modalRef}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
tabIndex={-1}
|
||||
onKeyDown={e => e.key === 'Escape' && onClose()}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
) : null
|
||||
}
|
||||
```
|
||||
|
||||
**Remember**: Modern frontend patterns enable maintainable, performant user interfaces. Choose patterns that fit your project complexity.
|
||||
345
.claude/skills/project-guidelines-example/SKILL.md
Normal file
345
.claude/skills/project-guidelines-example/SKILL.md
Normal file
@@ -0,0 +1,345 @@
|
||||
# Project Guidelines Skill (Example)
|
||||
|
||||
This is an example of a project-specific skill. Use this as a template for your own projects.
|
||||
|
||||
Based on a real production application: [Zenith](https://zenith.chat) - AI-powered customer discovery platform.
|
||||
|
||||
---
|
||||
|
||||
## When to Use
|
||||
|
||||
Reference this skill when working on the specific project it's designed for. Project skills contain:
|
||||
- Architecture overview
|
||||
- File structure
|
||||
- Code patterns
|
||||
- Testing requirements
|
||||
- Deployment workflow
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
**Tech Stack:**
|
||||
- **Frontend**: Next.js 15 (App Router), TypeScript, React
|
||||
- **Backend**: FastAPI (Python), Pydantic models
|
||||
- **Database**: Supabase (PostgreSQL)
|
||||
- **AI**: Claude API with tool calling and structured output
|
||||
- **Deployment**: Google Cloud Run
|
||||
- **Testing**: Playwright (E2E), pytest (backend), React Testing Library
|
||||
|
||||
**Services:**
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Frontend │
|
||||
│ Next.js 15 + TypeScript + TailwindCSS │
|
||||
│ Deployed: Vercel / Cloud Run │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Backend │
|
||||
│ FastAPI + Python 3.11 + Pydantic │
|
||||
│ Deployed: Cloud Run │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌───────────────┼───────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ Supabase │ │ Claude │ │ Redis │
|
||||
│ Database │ │ API │ │ Cache │
|
||||
└──────────┘ └──────────┘ └──────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
project/
|
||||
├── frontend/
|
||||
│ └── src/
|
||||
│ ├── app/ # Next.js app router pages
|
||||
│ │ ├── api/ # API routes
|
||||
│ │ ├── (auth)/ # Auth-protected routes
|
||||
│ │ └── workspace/ # Main app workspace
|
||||
│ ├── components/ # React components
|
||||
│ │ ├── ui/ # Base UI components
|
||||
│ │ ├── forms/ # Form components
|
||||
│ │ └── layouts/ # Layout components
|
||||
│ ├── hooks/ # Custom React hooks
|
||||
│ ├── lib/ # Utilities
|
||||
│ ├── types/ # TypeScript definitions
|
||||
│ └── config/ # Configuration
|
||||
│
|
||||
├── backend/
|
||||
│ ├── routers/ # FastAPI route handlers
|
||||
│ ├── models.py # Pydantic models
|
||||
│ ├── main.py # FastAPI app entry
|
||||
│ ├── auth_system.py # Authentication
|
||||
│ ├── database.py # Database operations
|
||||
│ ├── services/ # Business logic
|
||||
│ └── tests/ # pytest tests
|
||||
│
|
||||
├── deploy/ # Deployment configs
|
||||
├── docs/ # Documentation
|
||||
└── scripts/ # Utility scripts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Code Patterns
|
||||
|
||||
### API Response Format (FastAPI)
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
from typing import Generic, TypeVar, Optional
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: Optional[T] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: T) -> "ApiResponse[T]":
|
||||
return cls(success=True, data=data)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, error: str) -> "ApiResponse[T]":
|
||||
return cls(success=False, error=error)
|
||||
```
|
||||
|
||||
### Frontend API Calls (TypeScript)
|
||||
|
||||
```typescript
|
||||
interface ApiResponse<T> {
|
||||
success: boolean
|
||||
data?: T
|
||||
error?: string
|
||||
}
|
||||
|
||||
async function fetchApi<T>(
|
||||
endpoint: string,
|
||||
options?: RequestInit
|
||||
): Promise<ApiResponse<T>> {
|
||||
try {
|
||||
const response = await fetch(`/api${endpoint}`, {
|
||||
...options,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...options?.headers,
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return { success: false, error: `HTTP ${response.status}` }
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
return { success: false, error: String(error) }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Claude AI Integration (Structured Output)
|
||||
|
||||
```python
|
||||
from anthropic import Anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnalysisResult(BaseModel):
|
||||
summary: str
|
||||
key_points: list[str]
|
||||
confidence: float
|
||||
|
||||
async def analyze_with_claude(content: str) -> AnalysisResult:
|
||||
client = Anthropic()
|
||||
|
||||
response = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250514",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
tools=[{
|
||||
"name": "provide_analysis",
|
||||
"description": "Provide structured analysis",
|
||||
"input_schema": AnalysisResult.model_json_schema()
|
||||
}],
|
||||
tool_choice={"type": "tool", "name": "provide_analysis"}
|
||||
)
|
||||
|
||||
# Extract tool use result
|
||||
tool_use = next(
|
||||
block for block in response.content
|
||||
if block.type == "tool_use"
|
||||
)
|
||||
|
||||
return AnalysisResult(**tool_use.input)
|
||||
```
|
||||
|
||||
### Custom Hooks (React)
|
||||
|
||||
```typescript
|
||||
import { useState, useCallback } from 'react'
|
||||
|
||||
interface UseApiState<T> {
|
||||
data: T | null
|
||||
loading: boolean
|
||||
error: string | null
|
||||
}
|
||||
|
||||
export function useApi<T>(
|
||||
fetchFn: () => Promise<ApiResponse<T>>
|
||||
) {
|
||||
const [state, setState] = useState<UseApiState<T>>({
|
||||
data: null,
|
||||
loading: false,
|
||||
error: null,
|
||||
})
|
||||
|
||||
const execute = useCallback(async () => {
|
||||
setState(prev => ({ ...prev, loading: true, error: null }))
|
||||
|
||||
const result = await fetchFn()
|
||||
|
||||
if (result.success) {
|
||||
setState({ data: result.data!, loading: false, error: null })
|
||||
} else {
|
||||
setState({ data: null, loading: false, error: result.error! })
|
||||
}
|
||||
}, [fetchFn])
|
||||
|
||||
return { ...state, execute }
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing Requirements
|
||||
|
||||
### Backend (pytest)
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
poetry run pytest tests/
|
||||
|
||||
# Run with coverage
|
||||
poetry run pytest tests/ --cov=. --cov-report=html
|
||||
|
||||
# Run specific test file
|
||||
poetry run pytest tests/test_auth.py -v
|
||||
```
|
||||
|
||||
**Test structure:**
|
||||
```python
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from main import app
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "healthy"
|
||||
```
|
||||
|
||||
### Frontend (React Testing Library)
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
npm run test
|
||||
|
||||
# Run with coverage
|
||||
npm run test -- --coverage
|
||||
|
||||
# Run E2E tests
|
||||
npm run test:e2e
|
||||
```
|
||||
|
||||
**Test structure:**
|
||||
```typescript
|
||||
import { render, screen, fireEvent } from '@testing-library/react'
|
||||
import { WorkspacePanel } from './WorkspacePanel'
|
||||
|
||||
describe('WorkspacePanel', () => {
|
||||
it('renders workspace correctly', () => {
|
||||
render(<WorkspacePanel />)
|
||||
expect(screen.getByRole('main')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles session creation', async () => {
|
||||
render(<WorkspacePanel />)
|
||||
fireEvent.click(screen.getByText('New Session'))
|
||||
expect(await screen.findByText('Session created')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment Workflow
|
||||
|
||||
### Pre-Deployment Checklist
|
||||
|
||||
- [ ] All tests passing locally
|
||||
- [ ] `npm run build` succeeds (frontend)
|
||||
- [ ] `poetry run pytest` passes (backend)
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] Environment variables documented
|
||||
- [ ] Database migrations ready
|
||||
|
||||
### Deployment Commands
|
||||
|
||||
```bash
|
||||
# Build and deploy frontend
|
||||
cd frontend && npm run build
|
||||
gcloud run deploy frontend --source .
|
||||
|
||||
# Build and deploy backend
|
||||
cd backend
|
||||
gcloud run deploy backend --source .
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Frontend (.env.local)
|
||||
NEXT_PUBLIC_API_URL=https://api.example.com
|
||||
NEXT_PUBLIC_SUPABASE_URL=https://xxx.supabase.co
|
||||
NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJ...
|
||||
|
||||
# Backend (.env)
|
||||
DATABASE_URL=postgresql://...
|
||||
ANTHROPIC_API_KEY=sk-ant-...
|
||||
SUPABASE_URL=https://xxx.supabase.co
|
||||
SUPABASE_KEY=eyJ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Rules
|
||||
|
||||
1. **No emojis** in code, comments, or documentation
|
||||
2. **Immutability** - never mutate objects or arrays
|
||||
3. **TDD** - write tests before implementation
|
||||
4. **80% coverage** minimum
|
||||
5. **Many small files** - 200-400 lines typical, 800 max
|
||||
6. **No console.log** in production code
|
||||
7. **Proper error handling** with try/catch
|
||||
8. **Input validation** with Pydantic/Zod
|
||||
|
||||
---
|
||||
|
||||
## Related Skills
|
||||
|
||||
- `coding-standards.md` - General coding best practices
|
||||
- `backend-patterns.md` - API and database patterns
|
||||
- `frontend-patterns.md` - React and Next.js patterns
|
||||
- `tdd-workflow/` - Test-driven development methodology
|
||||
568
.claude/skills/security-review/SKILL.md
Normal file
568
.claude/skills/security-review/SKILL.md
Normal file
@@ -0,0 +1,568 @@
|
||||
---
|
||||
name: security-review
|
||||
description: Use this skill when adding authentication, handling user input, working with secrets, creating API endpoints, or implementing payment/sensitive features. Provides comprehensive security checklist and patterns.
|
||||
---
|
||||
|
||||
# Security Review Skill
|
||||
|
||||
Security best practices for Python/FastAPI applications handling sensitive invoice data.
|
||||
|
||||
## When to Activate
|
||||
|
||||
- Implementing authentication or authorization
|
||||
- Handling user input or file uploads
|
||||
- Creating new API endpoints
|
||||
- Working with secrets or credentials
|
||||
- Processing sensitive invoice data
|
||||
- Integrating third-party APIs
|
||||
- Database operations with user data
|
||||
|
||||
## Security Checklist
|
||||
|
||||
### 1. Secrets Management
|
||||
|
||||
#### NEVER Do This
|
||||
```python
|
||||
# Hardcoded secrets - CRITICAL VULNERABILITY
|
||||
api_key = "sk-proj-xxxxx"
|
||||
db_password = "password123"
|
||||
```
|
||||
|
||||
#### ALWAYS Do This
|
||||
```python
|
||||
import os
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
db_password: str
|
||||
api_key: str
|
||||
model_path: str = "runs/train/invoice_fields/weights/best.pt"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Verify secrets exist
|
||||
if not settings.db_password:
|
||||
raise RuntimeError("DB_PASSWORD not configured")
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] No hardcoded API keys, tokens, or passwords
|
||||
- [ ] All secrets in environment variables
|
||||
- [ ] `.env` in .gitignore
|
||||
- [ ] No secrets in git history
|
||||
- [ ] `.env.example` with placeholder values
|
||||
|
||||
### 2. Input Validation
|
||||
|
||||
#### Always Validate User Input
|
||||
```python
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from fastapi import HTTPException
|
||||
import re
|
||||
|
||||
class InvoiceRequest(BaseModel):
|
||||
invoice_number: str = Field(..., min_length=1, max_length=50)
|
||||
amount: float = Field(..., gt=0, le=1_000_000)
|
||||
bankgiro: str | None = None
|
||||
|
||||
@field_validator("invoice_number")
|
||||
@classmethod
|
||||
def validate_invoice_number(cls, v: str) -> str:
|
||||
# Whitelist validation - only allow safe characters
|
||||
if not re.match(r"^[A-Za-z0-9\-_]+$", v):
|
||||
raise ValueError("Invalid invoice number format")
|
||||
return v
|
||||
|
||||
@field_validator("bankgiro")
|
||||
@classmethod
|
||||
def validate_bankgiro(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (7 <= len(cleaned) <= 8):
|
||||
raise ValueError("Bankgiro must be 7-8 digits")
|
||||
return cleaned
|
||||
```
|
||||
|
||||
#### File Upload Validation
|
||||
```python
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from pathlib import Path
|
||||
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
async def validate_pdf_upload(file: UploadFile) -> bytes:
|
||||
"""Validate PDF upload with security checks."""
|
||||
# Extension check
|
||||
ext = Path(file.filename or "").suffix.lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(400, f"Only PDF files allowed, got {ext}")
|
||||
|
||||
# Read content
|
||||
content = await file.read()
|
||||
|
||||
# Size check
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise HTTPException(400, f"File too large (max {MAX_FILE_SIZE // 1024 // 1024}MB)")
|
||||
|
||||
# Magic bytes check (PDF signature)
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
|
||||
return content
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] All user inputs validated with Pydantic
|
||||
- [ ] File uploads restricted (size, type, extension, magic bytes)
|
||||
- [ ] No direct use of user input in queries
|
||||
- [ ] Whitelist validation (not blacklist)
|
||||
- [ ] Error messages don't leak sensitive info
|
||||
|
||||
### 3. SQL Injection Prevention
|
||||
|
||||
#### NEVER Concatenate SQL
|
||||
```python
|
||||
# DANGEROUS - SQL Injection vulnerability
|
||||
query = f"SELECT * FROM documents WHERE id = '{user_input}'"
|
||||
cur.execute(query)
|
||||
```
|
||||
|
||||
#### ALWAYS Use Parameterized Queries
|
||||
```python
|
||||
import psycopg2
|
||||
|
||||
# Safe - parameterized query with %s placeholders
|
||||
cur.execute(
|
||||
"SELECT * FROM documents WHERE id = %s AND status = %s",
|
||||
(document_id, status)
|
||||
)
|
||||
|
||||
# Safe - named parameters
|
||||
cur.execute(
|
||||
"SELECT * FROM documents WHERE id = %(id)s",
|
||||
{"id": document_id}
|
||||
)
|
||||
|
||||
# Safe - psycopg2.sql for dynamic identifiers
|
||||
from psycopg2 import sql
|
||||
|
||||
cur.execute(
|
||||
sql.SQL("SELECT {} FROM {} WHERE id = %s").format(
|
||||
sql.Identifier("invoice_number"),
|
||||
sql.Identifier("documents")
|
||||
),
|
||||
(document_id,)
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] All database queries use parameterized queries (%s or %(name)s)
|
||||
- [ ] No string concatenation or f-strings in SQL
|
||||
- [ ] psycopg2.sql module used for dynamic identifiers
|
||||
- [ ] No user input in table/column names
|
||||
|
||||
### 4. Path Traversal Prevention
|
||||
|
||||
#### NEVER Trust User Paths
|
||||
```python
|
||||
# DANGEROUS - Path traversal vulnerability
|
||||
filename = request.query_params.get("file")
|
||||
with open(f"/data/{filename}", "r") as f: # Attacker: ../../../etc/passwd
|
||||
return f.read()
|
||||
```
|
||||
|
||||
#### ALWAYS Validate Paths
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
ALLOWED_DIR = Path("/data/uploads").resolve()
|
||||
|
||||
def get_safe_path(filename: str) -> Path:
|
||||
"""Get safe file path, preventing path traversal."""
|
||||
# Remove any path components
|
||||
safe_name = Path(filename).name
|
||||
|
||||
# Validate filename characters
|
||||
if not re.match(r"^[A-Za-z0-9_\-\.]+$", safe_name):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
# Resolve and verify within allowed directory
|
||||
full_path = (ALLOWED_DIR / safe_name).resolve()
|
||||
|
||||
if not full_path.is_relative_to(ALLOWED_DIR):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
|
||||
return full_path
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] User-provided filenames sanitized
|
||||
- [ ] Paths resolved and validated against allowed directory
|
||||
- [ ] No direct concatenation of user input into paths
|
||||
- [ ] Whitelist characters in filenames
|
||||
|
||||
### 5. Authentication & Authorization
|
||||
|
||||
#### API Key Validation
|
||||
```python
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def verify_api_key(api_key: str = Security(api_key_header)) -> str:
|
||||
if not api_key:
|
||||
raise HTTPException(401, "API key required")
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
import hmac
|
||||
if not hmac.compare_digest(api_key, settings.api_key):
|
||||
raise HTTPException(403, "Invalid API key")
|
||||
|
||||
return api_key
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
api_key: str = Depends(verify_api_key)
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Role-Based Access Control
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class UserRole(str, Enum):
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
def require_role(required_role: UserRole):
|
||||
async def role_checker(current_user: User = Depends(get_current_user)):
|
||||
if current_user.role != required_role:
|
||||
raise HTTPException(403, "Insufficient permissions")
|
||||
return current_user
|
||||
return role_checker
|
||||
|
||||
@router.delete("/documents/{doc_id}")
|
||||
async def delete_document(
|
||||
doc_id: str,
|
||||
user: User = Depends(require_role(UserRole.ADMIN))
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] API keys validated with constant-time comparison
|
||||
- [ ] Authorization checks before sensitive operations
|
||||
- [ ] Role-based access control implemented
|
||||
- [ ] Session/token validation on protected routes
|
||||
|
||||
### 6. Rate Limiting
|
||||
|
||||
#### Rate Limiter Implementation
|
||||
```python
|
||||
from time import time
|
||||
from collections import defaultdict
|
||||
from fastapi import Request, HTTPException
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self):
|
||||
self.requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def check_limit(
|
||||
self,
|
||||
identifier: str,
|
||||
max_requests: int,
|
||||
window_seconds: int
|
||||
) -> bool:
|
||||
now = time()
|
||||
# Clean old requests
|
||||
self.requests[identifier] = [
|
||||
t for t in self.requests[identifier]
|
||||
if now - t < window_seconds
|
||||
]
|
||||
# Check limit
|
||||
if len(self.requests[identifier]) >= max_requests:
|
||||
return False
|
||||
self.requests[identifier].append(now)
|
||||
return True
|
||||
|
||||
limiter = RateLimiter()
|
||||
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# 100 requests per minute for general endpoints
|
||||
if not limiter.check_limit(client_ip, max_requests=100, window_seconds=60):
|
||||
raise HTTPException(429, "Rate limit exceeded. Try again later.")
|
||||
|
||||
return await call_next(request)
|
||||
```
|
||||
|
||||
#### Stricter Limits for Expensive Operations
|
||||
```python
|
||||
# Inference endpoint: 10 requests per minute
|
||||
async def check_inference_rate_limit(request: Request):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
if not limiter.check_limit(f"infer:{client_ip}", max_requests=10, window_seconds=60):
|
||||
raise HTTPException(429, "Inference rate limit exceeded")
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
_: None = Depends(check_inference_rate_limit)
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Rate limiting on all API endpoints
|
||||
- [ ] Stricter limits on expensive operations (inference, OCR)
|
||||
- [ ] IP-based rate limiting
|
||||
- [ ] Clear error messages for rate-limited requests
|
||||
|
||||
### 7. Sensitive Data Exposure
|
||||
|
||||
#### Logging
|
||||
```python
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# WRONG: Logging sensitive data
|
||||
logger.info(f"Processing invoice: {invoice_data}") # May contain sensitive info
|
||||
logger.error(f"DB error with password: {db_password}")
|
||||
|
||||
# CORRECT: Redact sensitive data
|
||||
logger.info(f"Processing invoice: id={doc_id}")
|
||||
logger.error(f"DB connection failed to {db_host}:{db_port}")
|
||||
|
||||
# CORRECT: Structured logging with safe fields only
|
||||
logger.info(
|
||||
"Invoice processed",
|
||||
extra={
|
||||
"document_id": doc_id,
|
||||
"field_count": len(fields),
|
||||
"processing_time_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
#### Error Messages
|
||||
```python
|
||||
# WRONG: Exposing internal details
|
||||
@app.exception_handler(Exception)
|
||||
async def error_handler(request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc() # NEVER expose!
|
||||
}
|
||||
)
|
||||
|
||||
# CORRECT: Generic error messages
|
||||
@app.exception_handler(Exception)
|
||||
async def error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled error: {exc}", exc_info=True) # Log internally
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"success": False, "error": "An error occurred"}
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] No passwords, tokens, or secrets in logs
|
||||
- [ ] Error messages generic for users
|
||||
- [ ] Detailed errors only in server logs
|
||||
- [ ] No stack traces exposed to users
|
||||
- [ ] Invoice data (amounts, account numbers) not logged
|
||||
|
||||
### 8. CORS Configuration
|
||||
|
||||
```python
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# WRONG: Allow all origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # DANGEROUS in production
|
||||
allow_credentials=True,
|
||||
)
|
||||
|
||||
# CORRECT: Specific origins
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:8000",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] CORS origins explicitly listed
|
||||
- [ ] No wildcard origins in production
|
||||
- [ ] Credentials only with specific origins
|
||||
|
||||
### 9. Temporary File Security
|
||||
|
||||
```python
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def secure_temp_file(suffix: str = ".pdf"):
|
||||
"""Create secure temporary file that is always cleaned up."""
|
||||
tmp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=suffix,
|
||||
delete=False,
|
||||
dir="/tmp/invoice-master" # Dedicated temp directory
|
||||
) as tmp:
|
||||
tmp_path = Path(tmp.name)
|
||||
yield tmp_path
|
||||
finally:
|
||||
if tmp_path and tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
|
||||
# Usage
|
||||
async def process_upload(file: UploadFile):
|
||||
with secure_temp_file(".pdf") as tmp_path:
|
||||
content = await validate_pdf_upload(file)
|
||||
tmp_path.write_bytes(content)
|
||||
result = pipeline.process(tmp_path)
|
||||
# File automatically cleaned up
|
||||
return result
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Temporary files always cleaned up (use context managers)
|
||||
- [ ] Temp directory has restricted permissions
|
||||
- [ ] No leftover files after processing errors
|
||||
|
||||
### 10. Dependency Security
|
||||
|
||||
#### Regular Updates
|
||||
```bash
|
||||
# Check for vulnerabilities
|
||||
pip-audit
|
||||
|
||||
# Update dependencies
|
||||
pip install --upgrade -r requirements.txt
|
||||
|
||||
# Check for outdated packages
|
||||
pip list --outdated
|
||||
```
|
||||
|
||||
#### Lock Files
|
||||
```bash
|
||||
# Create requirements lock file
|
||||
pip freeze > requirements.lock
|
||||
|
||||
# Install from lock file for reproducible builds
|
||||
pip install -r requirements.lock
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Dependencies up to date
|
||||
- [ ] No known vulnerabilities (pip-audit clean)
|
||||
- [ ] requirements.txt pinned versions
|
||||
- [ ] Regular security updates scheduled
|
||||
|
||||
## Security Testing
|
||||
|
||||
### Automated Security Tests
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
def test_requires_api_key(client: TestClient):
|
||||
"""Test authentication required."""
|
||||
response = client.post("/api/v1/infer")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_invalid_api_key_rejected(client: TestClient):
|
||||
"""Test invalid API key rejected."""
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
headers={"X-API-Key": "invalid-key"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_sql_injection_prevented(client: TestClient):
|
||||
"""Test SQL injection attempt rejected."""
|
||||
response = client.get(
|
||||
"/api/v1/documents",
|
||||
params={"id": "'; DROP TABLE documents; --"}
|
||||
)
|
||||
# Should return validation error, not execute SQL
|
||||
assert response.status_code in (400, 422)
|
||||
|
||||
def test_path_traversal_prevented(client: TestClient):
|
||||
"""Test path traversal attempt rejected."""
|
||||
response = client.get("/api/v1/results/../../etc/passwd")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_rate_limit_enforced(client: TestClient):
|
||||
"""Test rate limiting works."""
|
||||
responses = [
|
||||
client.post("/api/v1/infer", files={"file": b"test"})
|
||||
for _ in range(15)
|
||||
]
|
||||
rate_limited = [r for r in responses if r.status_code == 429]
|
||||
assert len(rate_limited) > 0
|
||||
|
||||
def test_large_file_rejected(client: TestClient):
|
||||
"""Test file size limit enforced."""
|
||||
large_content = b"x" * (11 * 1024 * 1024) # 11MB
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.pdf", large_content)}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
```
|
||||
|
||||
## Pre-Deployment Security Checklist
|
||||
|
||||
Before ANY production deployment:
|
||||
|
||||
- [ ] **Secrets**: No hardcoded secrets, all in env vars
|
||||
- [ ] **Input Validation**: All user inputs validated with Pydantic
|
||||
- [ ] **SQL Injection**: All queries use parameterized queries
|
||||
- [ ] **Path Traversal**: File paths validated and sanitized
|
||||
- [ ] **Authentication**: API key or token validation
|
||||
- [ ] **Authorization**: Role checks in place
|
||||
- [ ] **Rate Limiting**: Enabled on all endpoints
|
||||
- [ ] **HTTPS**: Enforced in production
|
||||
- [ ] **CORS**: Properly configured (no wildcards)
|
||||
- [ ] **Error Handling**: No sensitive data in errors
|
||||
- [ ] **Logging**: No sensitive data logged
|
||||
- [ ] **File Uploads**: Validated (size, type, magic bytes)
|
||||
- [ ] **Temp Files**: Always cleaned up
|
||||
- [ ] **Dependencies**: Up to date, no vulnerabilities
|
||||
|
||||
## Resources
|
||||
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
|
||||
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
|
||||
- [pip-audit](https://pypi.org/project/pip-audit/)
|
||||
|
||||
---
|
||||
|
||||
**Remember**: Security is not optional. One vulnerability can compromise sensitive invoice data. When in doubt, err on the side of caution.
|
||||
63
.claude/skills/strategic-compact/SKILL.md
Normal file
63
.claude/skills/strategic-compact/SKILL.md
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
name: strategic-compact
|
||||
description: Suggests manual context compaction at logical intervals to preserve context through task phases rather than arbitrary auto-compaction.
|
||||
---
|
||||
|
||||
# Strategic Compact Skill
|
||||
|
||||
Suggests manual `/compact` at strategic points in your workflow rather than relying on arbitrary auto-compaction.
|
||||
|
||||
## Why Strategic Compaction?
|
||||
|
||||
Auto-compaction triggers at arbitrary points:
|
||||
- Often mid-task, losing important context
|
||||
- No awareness of logical task boundaries
|
||||
- Can interrupt complex multi-step operations
|
||||
|
||||
Strategic compaction at logical boundaries:
|
||||
- **After exploration, before execution** - Compact research context, keep implementation plan
|
||||
- **After completing a milestone** - Fresh start for next phase
|
||||
- **Before major context shifts** - Clear exploration context before different task
|
||||
|
||||
## How It Works
|
||||
|
||||
The `suggest-compact.sh` script runs on PreToolUse (Edit/Write) and:
|
||||
|
||||
1. **Tracks tool calls** - Counts tool invocations in session
|
||||
2. **Threshold detection** - Suggests at configurable threshold (default: 50 calls)
|
||||
3. **Periodic reminders** - Reminds every 25 calls after threshold
|
||||
|
||||
## Hook Setup
|
||||
|
||||
Add to your `~/.claude/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"PreToolUse": [{
|
||||
"matcher": "tool == \"Edit\" || tool == \"Write\"",
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
|
||||
}]
|
||||
}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Environment variables:
|
||||
- `COMPACT_THRESHOLD` - Tool calls before first suggestion (default: 50)
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Compact after planning** - Once plan is finalized, compact to start fresh
|
||||
2. **Compact after debugging** - Clear error-resolution context before continuing
|
||||
3. **Don't compact mid-implementation** - Preserve context for related changes
|
||||
4. **Read the suggestion** - The hook tells you *when*, you decide *if*
|
||||
|
||||
## Related
|
||||
|
||||
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Token optimization section
|
||||
- Memory persistence hooks - For state that survives compaction
|
||||
52
.claude/skills/strategic-compact/suggest-compact.sh
Normal file
52
.claude/skills/strategic-compact/suggest-compact.sh
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Strategic Compact Suggester
|
||||
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
|
||||
#
|
||||
# Why manual over auto-compact:
|
||||
# - Auto-compact happens at arbitrary points, often mid-task
|
||||
# - Strategic compacting preserves context through logical phases
|
||||
# - Compact after exploration, before execution
|
||||
# - Compact after completing a milestone, before starting next
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "PreToolUse": [{
|
||||
# "matcher": "Edit|Write",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# Criteria for suggesting compact:
|
||||
# - Session has been running for extended period
|
||||
# - Large number of tool calls made
|
||||
# - Transitioning from research/exploration to implementation
|
||||
# - Plan has been finalized
|
||||
|
||||
# Track tool call count (increment in a temp file)
|
||||
COUNTER_FILE="/tmp/claude-tool-count-$$"
|
||||
THRESHOLD=${COMPACT_THRESHOLD:-50}
|
||||
|
||||
# Initialize or increment counter
|
||||
if [ -f "$COUNTER_FILE" ]; then
|
||||
count=$(cat "$COUNTER_FILE")
|
||||
count=$((count + 1))
|
||||
echo "$count" > "$COUNTER_FILE"
|
||||
else
|
||||
echo "1" > "$COUNTER_FILE"
|
||||
count=1
|
||||
fi
|
||||
|
||||
# Suggest compact after threshold tool calls
|
||||
if [ "$count" -eq "$THRESHOLD" ]; then
|
||||
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
|
||||
fi
|
||||
|
||||
# Suggest at regular intervals after threshold
|
||||
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
|
||||
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
|
||||
fi
|
||||
553
.claude/skills/tdd-workflow/SKILL.md
Normal file
553
.claude/skills/tdd-workflow/SKILL.md
Normal file
@@ -0,0 +1,553 @@
|
||||
---
|
||||
name: tdd-workflow
|
||||
description: Use this skill when writing new features, fixing bugs, or refactoring code. Enforces test-driven development with 80%+ coverage including unit, integration, and E2E tests.
|
||||
---
|
||||
|
||||
# Test-Driven Development Workflow
|
||||
|
||||
TDD principles for Python/FastAPI development with pytest.
|
||||
|
||||
## When to Activate
|
||||
|
||||
- Writing new features or functionality
|
||||
- Fixing bugs or issues
|
||||
- Refactoring existing code
|
||||
- Adding API endpoints
|
||||
- Creating new field extractors or normalizers
|
||||
|
||||
## Core Principles
|
||||
|
||||
### 1. Tests BEFORE Code
|
||||
ALWAYS write tests first, then implement code to make tests pass.
|
||||
|
||||
### 2. Coverage Requirements
|
||||
- Minimum 80% coverage (unit + integration + E2E)
|
||||
- All edge cases covered
|
||||
- Error scenarios tested
|
||||
- Boundary conditions verified
|
||||
|
||||
### 3. Test Types
|
||||
|
||||
#### Unit Tests
|
||||
- Individual functions and utilities
|
||||
- Normalizers and validators
|
||||
- Parsers and extractors
|
||||
- Pure functions
|
||||
|
||||
#### Integration Tests
|
||||
- API endpoints
|
||||
- Database operations
|
||||
- OCR + YOLO pipeline
|
||||
- Service interactions
|
||||
|
||||
#### E2E Tests
|
||||
- Complete inference pipeline
|
||||
- PDF → Fields workflow
|
||||
- API health and inference endpoints
|
||||
|
||||
## TDD Workflow Steps
|
||||
|
||||
### Step 1: Write User Journeys
|
||||
```
|
||||
As a [role], I want to [action], so that [benefit]
|
||||
|
||||
Example:
|
||||
As an invoice processor, I want to extract Bankgiro from payment_line,
|
||||
so that I can cross-validate OCR results.
|
||||
```
|
||||
|
||||
### Step 2: Generate Test Cases
|
||||
For each user journey, create comprehensive test cases:
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
class TestPaymentLineParser:
|
||||
"""Tests for payment_line parsing and field extraction."""
|
||||
|
||||
def test_parse_payment_line_extracts_bankgiro(self):
|
||||
"""Should extract Bankgiro from valid payment line."""
|
||||
# Test implementation
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_handles_missing_checksum(self):
|
||||
"""Should handle payment lines without checksum."""
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_validates_checksum(self):
|
||||
"""Should validate checksum when present."""
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_returns_none_for_invalid(self):
|
||||
"""Should return None for invalid payment lines."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Step 3: Run Tests (They Should Fail)
|
||||
```bash
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
# Tests should fail - we haven't implemented yet
|
||||
```
|
||||
|
||||
### Step 4: Implement Code
|
||||
Write minimal code to make tests pass:
|
||||
|
||||
```python
|
||||
def parse_payment_line(line: str) -> PaymentLineData | None:
|
||||
"""Parse Swedish payment line and extract fields."""
|
||||
# Implementation guided by tests
|
||||
pass
|
||||
```
|
||||
|
||||
### Step 5: Run Tests Again
|
||||
```bash
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
# Tests should now pass
|
||||
```
|
||||
|
||||
### Step 6: Refactor
|
||||
Improve code quality while keeping tests green:
|
||||
- Remove duplication
|
||||
- Improve naming
|
||||
- Optimize performance
|
||||
- Enhance readability
|
||||
|
||||
### Step 7: Verify Coverage
|
||||
```bash
|
||||
pytest --cov=src --cov-report=term-missing
|
||||
# Verify 80%+ coverage achieved
|
||||
```
|
||||
|
||||
## Testing Patterns
|
||||
|
||||
### Unit Test Pattern (pytest)
|
||||
```python
|
||||
import pytest
|
||||
from src.normalize.bankgiro_normalizer import normalize_bankgiro
|
||||
|
||||
class TestBankgiroNormalizer:
|
||||
"""Tests for Bankgiro normalization."""
|
||||
|
||||
def test_normalize_removes_hyphens(self):
|
||||
"""Should remove hyphens from Bankgiro."""
|
||||
result = normalize_bankgiro("123-4567")
|
||||
assert result == "1234567"
|
||||
|
||||
def test_normalize_removes_spaces(self):
|
||||
"""Should remove spaces from Bankgiro."""
|
||||
result = normalize_bankgiro("123 4567")
|
||||
assert result == "1234567"
|
||||
|
||||
def test_normalize_validates_length(self):
|
||||
"""Should validate Bankgiro is 7-8 digits."""
|
||||
result = normalize_bankgiro("123456") # 6 digits
|
||||
assert result is None
|
||||
|
||||
def test_normalize_validates_checksum(self):
|
||||
"""Should validate Luhn checksum."""
|
||||
result = normalize_bankgiro("1234568") # Invalid checksum
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize("input_value,expected", [
|
||||
("123-4567", "1234567"),
|
||||
("1234567", "1234567"),
|
||||
("123 4567", "1234567"),
|
||||
("BG 123-4567", "1234567"),
|
||||
])
|
||||
def test_normalize_various_formats(self, input_value, expected):
|
||||
"""Should handle various input formats."""
|
||||
result = normalize_bankgiro(input_value)
|
||||
assert result == expected
|
||||
```
|
||||
|
||||
### API Integration Test Pattern
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from src.web.app import app
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for /api/v1/health endpoint."""
|
||||
|
||||
def test_health_returns_200(self, client):
|
||||
"""Should return 200 OK."""
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_health_returns_status(self, client):
|
||||
"""Should return health status."""
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "model_loaded" in data
|
||||
|
||||
class TestInferEndpoint:
|
||||
"""Tests for /api/v1/infer endpoint."""
|
||||
|
||||
def test_infer_requires_file(self, client):
|
||||
"""Should require file upload."""
|
||||
response = client.post("/api/v1/infer")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_infer_rejects_non_pdf(self, client):
|
||||
"""Should reject non-PDF files."""
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.txt", b"not a pdf", "text/plain")}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_infer_returns_fields(self, client, sample_invoice_pdf):
|
||||
"""Should return extracted fields."""
|
||||
with open(sample_invoice_pdf, "rb") as f:
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "fields" in data
|
||||
```
|
||||
|
||||
### E2E Test Pattern
|
||||
```python
|
||||
import pytest
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def running_server():
|
||||
"""Ensure server is running for E2E tests."""
|
||||
# Server should be started before running E2E tests
|
||||
base_url = "http://localhost:8000"
|
||||
yield base_url
|
||||
|
||||
class TestInferencePipeline:
|
||||
"""E2E tests for complete inference pipeline."""
|
||||
|
||||
def test_health_check(self, running_server):
|
||||
"""Should pass health check."""
|
||||
response = httpx.get(f"{running_server}/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert data["model_loaded"] is True
|
||||
|
||||
def test_pdf_inference_returns_fields(self, running_server):
|
||||
"""Should extract fields from PDF."""
|
||||
pdf_path = Path("tests/fixtures/sample_invoice.pdf")
|
||||
with open(pdf_path, "rb") as f:
|
||||
response = httpx.post(
|
||||
f"{running_server}/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "fields" in data
|
||||
assert len(data["fields"]) > 0
|
||||
|
||||
def test_cross_validation_included(self, running_server):
|
||||
"""Should include cross-validation for invoices with payment_line."""
|
||||
pdf_path = Path("tests/fixtures/invoice_with_payment_line.pdf")
|
||||
with open(pdf_path, "rb") as f:
|
||||
response = httpx.post(
|
||||
f"{running_server}/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if data["fields"].get("payment_line"):
|
||||
assert "cross_validation" in data
|
||||
```
|
||||
|
||||
## Test File Organization
|
||||
|
||||
```
|
||||
tests/
|
||||
├── conftest.py # Shared fixtures
|
||||
├── fixtures/ # Test data files
|
||||
│ ├── sample_invoice.pdf
|
||||
│ └── invoice_with_payment_line.pdf
|
||||
├── test_cli/
|
||||
│ └── test_infer.py
|
||||
├── test_pdf/
|
||||
│ ├── test_extractor.py
|
||||
│ └── test_renderer.py
|
||||
├── test_ocr/
|
||||
│ ├── test_paddle_ocr.py
|
||||
│ └── test_machine_code_parser.py
|
||||
├── test_inference/
|
||||
│ ├── test_pipeline.py
|
||||
│ ├── test_yolo_detector.py
|
||||
│ └── test_field_extractor.py
|
||||
├── test_normalize/
|
||||
│ ├── test_bankgiro_normalizer.py
|
||||
│ ├── test_date_normalizer.py
|
||||
│ └── test_amount_normalizer.py
|
||||
├── test_web/
|
||||
│ ├── test_routes.py
|
||||
│ └── test_services.py
|
||||
└── e2e/
|
||||
└── test_inference_e2e.py
|
||||
```
|
||||
|
||||
## Mocking External Services
|
||||
|
||||
### Mock PaddleOCR
|
||||
```python
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paddle_ocr():
|
||||
"""Mock PaddleOCR for unit tests."""
|
||||
with patch("src.ocr.paddle_ocr.PaddleOCR") as mock:
|
||||
instance = Mock()
|
||||
instance.ocr.return_value = [
|
||||
[
|
||||
[[[0, 0], [100, 0], [100, 20], [0, 20]], ("Invoice Number", 0.95)],
|
||||
[[[0, 30], [100, 30], [100, 50], [0, 50]], ("INV-2024-001", 0.98)]
|
||||
]
|
||||
]
|
||||
mock.return_value = instance
|
||||
yield instance
|
||||
```
|
||||
|
||||
### Mock YOLO Model
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_yolo_model():
|
||||
"""Mock YOLO model for unit tests."""
|
||||
with patch("src.inference.yolo_detector.YOLO") as mock:
|
||||
instance = Mock()
|
||||
# Mock detection results
|
||||
instance.return_value = Mock(
|
||||
boxes=Mock(
|
||||
xyxy=[[10, 20, 100, 50]],
|
||||
conf=[0.95],
|
||||
cls=[0] # invoice_number class
|
||||
)
|
||||
)
|
||||
mock.return_value = instance
|
||||
yield instance
|
||||
```
|
||||
|
||||
### Mock Database
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
"""Mock database connection for unit tests."""
|
||||
with patch("src.data.db.get_db_connection") as mock:
|
||||
conn = Mock()
|
||||
cursor = Mock()
|
||||
cursor.fetchall.return_value = [
|
||||
("doc-123", "processed", {"invoice_number": "INV-001"})
|
||||
]
|
||||
cursor.fetchone.return_value = ("doc-123",)
|
||||
conn.cursor.return_value.__enter__ = Mock(return_value=cursor)
|
||||
conn.cursor.return_value.__exit__ = Mock(return_value=False)
|
||||
mock.return_value.__enter__ = Mock(return_value=conn)
|
||||
mock.return_value.__exit__ = Mock(return_value=False)
|
||||
yield conn
|
||||
```
|
||||
|
||||
## Test Coverage Verification
|
||||
|
||||
### Run Coverage Report
|
||||
```bash
|
||||
# Run with coverage
|
||||
pytest --cov=src --cov-report=term-missing
|
||||
|
||||
# Generate HTML report
|
||||
pytest --cov=src --cov-report=html
|
||||
# Open htmlcov/index.html in browser
|
||||
```
|
||||
|
||||
### Coverage Configuration (pyproject.toml)
|
||||
```toml
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
omit = ["*/__init__.py", "*/test_*.py"]
|
||||
|
||||
[tool.coverage.report]
|
||||
fail_under = 80
|
||||
show_missing = true
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"if TYPE_CHECKING:",
|
||||
"raise NotImplementedError",
|
||||
]
|
||||
```
|
||||
|
||||
## Common Testing Mistakes to Avoid
|
||||
|
||||
### WRONG: Testing Implementation Details
|
||||
```python
|
||||
# Don't test internal state
|
||||
def test_parser_internal_state():
|
||||
parser = PaymentLineParser()
|
||||
parser._parse("...")
|
||||
assert parser._groups == [...] # Internal state
|
||||
```
|
||||
|
||||
### CORRECT: Test Public Interface
|
||||
```python
|
||||
# Test what users see
|
||||
def test_parser_extracts_bankgiro():
|
||||
result = parse_payment_line("...")
|
||||
assert result.bankgiro == "1234567"
|
||||
```
|
||||
|
||||
### WRONG: No Test Isolation
|
||||
```python
|
||||
# Tests depend on each other
|
||||
class TestDocuments:
|
||||
def test_creates_document(self):
|
||||
create_document(...) # Creates in DB
|
||||
|
||||
def test_updates_document(self):
|
||||
update_document(...) # Depends on previous test
|
||||
```
|
||||
|
||||
### CORRECT: Independent Tests
|
||||
```python
|
||||
# Each test sets up its own data
|
||||
class TestDocuments:
|
||||
def test_creates_document(self, mock_db):
|
||||
result = create_document(...)
|
||||
assert result.id is not None
|
||||
|
||||
def test_updates_document(self, mock_db):
|
||||
# Create own test data
|
||||
doc = create_document(...)
|
||||
result = update_document(doc.id, ...)
|
||||
assert result.status == "updated"
|
||||
```
|
||||
|
||||
### WRONG: Testing Too Much
|
||||
```python
|
||||
# One test doing everything
|
||||
def test_full_invoice_processing():
|
||||
# Load PDF
|
||||
# Extract images
|
||||
# Run YOLO
|
||||
# Run OCR
|
||||
# Normalize fields
|
||||
# Save to DB
|
||||
# Return response
|
||||
```
|
||||
|
||||
### CORRECT: Focused Tests
|
||||
```python
|
||||
def test_yolo_detects_invoice_number():
|
||||
"""Test only YOLO detection."""
|
||||
result = detector.detect(image)
|
||||
assert any(d.label == "invoice_number" for d in result)
|
||||
|
||||
def test_ocr_extracts_text():
|
||||
"""Test only OCR extraction."""
|
||||
result = ocr.extract(image, bbox)
|
||||
assert result == "INV-2024-001"
|
||||
|
||||
def test_normalizer_formats_date():
|
||||
"""Test only date normalization."""
|
||||
result = normalize_date("2024-01-15")
|
||||
assert result == "2024-01-15"
|
||||
```
|
||||
|
||||
## Fixtures (conftest.py)
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@pytest.fixture
|
||||
def sample_invoice_pdf(tmp_path: Path) -> Path:
|
||||
"""Create sample invoice PDF for testing."""
|
||||
pdf_path = tmp_path / "invoice.pdf"
|
||||
# Copy from fixtures or create minimal PDF
|
||||
src = Path("tests/fixtures/sample_invoice.pdf")
|
||||
if src.exists():
|
||||
pdf_path.write_bytes(src.read_bytes())
|
||||
return pdf_path
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""FastAPI test client."""
|
||||
from src.web.app import app
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_payment_line() -> str:
|
||||
"""Sample Swedish payment line for testing."""
|
||||
return "1234567#0000000012345#230115#00012345678901234567#1"
|
||||
```
|
||||
|
||||
## Continuous Testing
|
||||
|
||||
### Watch Mode During Development
|
||||
```bash
|
||||
# Using pytest-watch
|
||||
ptw -- tests/test_ocr/
|
||||
# Tests run automatically on file changes
|
||||
```
|
||||
|
||||
### Pre-Commit Hook
|
||||
```bash
|
||||
# .pre-commit-config.yaml
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest
|
||||
name: pytest
|
||||
entry: pytest --tb=short -q
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
```
|
||||
|
||||
### CI/CD Integration (GitHub Actions)
|
||||
```yaml
|
||||
- name: Run Tests
|
||||
run: |
|
||||
pytest --cov=src --cov-report=xml
|
||||
|
||||
- name: Upload Coverage
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: coverage.xml
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Write Tests First** - Always TDD
|
||||
2. **One Assert Per Test** - Focus on single behavior
|
||||
3. **Descriptive Test Names** - `test_<what>_<condition>_<expected>`
|
||||
4. **Arrange-Act-Assert** - Clear test structure
|
||||
5. **Mock External Dependencies** - Isolate unit tests
|
||||
6. **Test Edge Cases** - None, empty, invalid, boundary
|
||||
7. **Test Error Paths** - Not just happy paths
|
||||
8. **Keep Tests Fast** - Unit tests < 50ms each
|
||||
9. **Clean Up After Tests** - Use fixtures with cleanup
|
||||
10. **Review Coverage Reports** - Identify gaps
|
||||
|
||||
## Success Metrics
|
||||
|
||||
- 80%+ code coverage achieved
|
||||
- All tests passing (green)
|
||||
- No skipped or disabled tests
|
||||
- Fast test execution (< 60s for unit tests)
|
||||
- E2E tests cover critical inference flow
|
||||
- Tests catch bugs before production
|
||||
|
||||
---
|
||||
|
||||
**Remember**: Tests are not optional. They are the safety net that enables confident refactoring, rapid development, and production reliability.
|
||||
242
.claude/skills/verification-loop/SKILL.md
Normal file
242
.claude/skills/verification-loop/SKILL.md
Normal file
@@ -0,0 +1,242 @@
|
||||
# Verification Loop Skill
|
||||
|
||||
Comprehensive verification system for Python/FastAPI development.
|
||||
|
||||
## When to Use
|
||||
|
||||
Invoke this skill:
|
||||
- After completing a feature or significant code change
|
||||
- Before creating a PR
|
||||
- When you want to ensure quality gates pass
|
||||
- After refactoring
|
||||
- Before deployment
|
||||
|
||||
## Verification Phases
|
||||
|
||||
### Phase 1: Type Check
|
||||
```bash
|
||||
# Run mypy type checker
|
||||
mypy src/ --ignore-missing-imports 2>&1 | head -30
|
||||
```
|
||||
|
||||
Report all type errors. Fix critical ones before continuing.
|
||||
|
||||
### Phase 2: Lint Check
|
||||
```bash
|
||||
# Run ruff linter
|
||||
ruff check src/ 2>&1 | head -30
|
||||
|
||||
# Auto-fix if desired
|
||||
ruff check src/ --fix
|
||||
```
|
||||
|
||||
Check for:
|
||||
- Unused imports
|
||||
- Code style violations
|
||||
- Common Python anti-patterns
|
||||
|
||||
### Phase 3: Test Suite
|
||||
```bash
|
||||
# Run tests with coverage
|
||||
pytest --cov=src --cov-report=term-missing -q 2>&1 | tail -50
|
||||
|
||||
# Run specific test file
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
|
||||
# Run with short traceback
|
||||
pytest -x --tb=short
|
||||
```
|
||||
|
||||
Report:
|
||||
- Total tests: X
|
||||
- Passed: X
|
||||
- Failed: X
|
||||
- Coverage: X%
|
||||
- Target: 80% minimum
|
||||
|
||||
### Phase 4: Security Scan
|
||||
```bash
|
||||
# Check for hardcoded secrets
|
||||
grep -rn "password\s*=" --include="*.py" src/ 2>/dev/null | grep -v "db_password:" | head -10
|
||||
grep -rn "api_key\s*=" --include="*.py" src/ 2>/dev/null | head -10
|
||||
grep -rn "sk-" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for print statements (should use logging)
|
||||
grep -rn "print(" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for bare except
|
||||
grep -rn "except:" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for SQL injection risks (f-strings in execute)
|
||||
grep -rn 'execute(f"' --include="*.py" src/ 2>/dev/null | head -10
|
||||
grep -rn "execute(f'" --include="*.py" src/ 2>/dev/null | head -10
|
||||
```
|
||||
|
||||
### Phase 5: Import Check
|
||||
```bash
|
||||
# Verify all imports work
|
||||
python -c "from src.web.app import app; print('Web app OK')"
|
||||
python -c "from src.inference.pipeline import InferencePipeline; print('Pipeline OK')"
|
||||
python -c "from src.ocr.machine_code_parser import parse_payment_line; print('Parser OK')"
|
||||
```
|
||||
|
||||
### Phase 6: Diff Review
|
||||
```bash
|
||||
# Show what changed
|
||||
git diff --stat
|
||||
git diff HEAD --name-only
|
||||
|
||||
# Show staged changes
|
||||
git diff --staged --stat
|
||||
```
|
||||
|
||||
Review each changed file for:
|
||||
- Unintended changes
|
||||
- Missing error handling
|
||||
- Potential edge cases
|
||||
- Missing type hints
|
||||
- Mutable default arguments
|
||||
|
||||
### Phase 7: API Smoke Test (if server running)
|
||||
```bash
|
||||
# Health check
|
||||
curl -s http://localhost:8000/api/v1/health | python -m json.tool
|
||||
|
||||
# Verify response format
|
||||
curl -s http://localhost:8000/api/v1/health | grep -q "healthy" && echo "Health: OK" || echo "Health: FAIL"
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
After running all phases, produce a verification report:
|
||||
|
||||
```
|
||||
VERIFICATION REPORT
|
||||
==================
|
||||
|
||||
Types: [PASS/FAIL] (X errors)
|
||||
Lint: [PASS/FAIL] (X warnings)
|
||||
Tests: [PASS/FAIL] (X/Y passed, Z% coverage)
|
||||
Security: [PASS/FAIL] (X issues)
|
||||
Imports: [PASS/FAIL]
|
||||
Diff: [X files changed]
|
||||
|
||||
Overall: [READY/NOT READY] for PR
|
||||
|
||||
Issues to Fix:
|
||||
1. ...
|
||||
2. ...
|
||||
```
|
||||
|
||||
## Quick Commands
|
||||
|
||||
```bash
|
||||
# Full verification (WSL)
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports && ruff check src/ && pytest -x --tb=short"
|
||||
|
||||
# Type check only
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports"
|
||||
|
||||
# Tests only
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src -q"
|
||||
```
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
### Before Commit
|
||||
- [ ] mypy passes (no type errors)
|
||||
- [ ] ruff check passes (no lint errors)
|
||||
- [ ] All tests pass
|
||||
- [ ] No print() statements in production code
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] No bare `except:` clauses
|
||||
- [ ] No SQL injection risks (f-strings in queries)
|
||||
- [ ] Coverage >= 80% for changed code
|
||||
|
||||
### Before PR
|
||||
- [ ] All above checks pass
|
||||
- [ ] git diff reviewed for unintended changes
|
||||
- [ ] New code has tests
|
||||
- [ ] Type hints on all public functions
|
||||
- [ ] Docstrings on public APIs
|
||||
- [ ] No TODO/FIXME for critical items
|
||||
|
||||
### Before Deployment
|
||||
- [ ] All above checks pass
|
||||
- [ ] E2E tests pass
|
||||
- [ ] Health check returns healthy
|
||||
- [ ] Model loaded successfully
|
||||
- [ ] No server errors in logs
|
||||
|
||||
## Common Issues and Fixes
|
||||
|
||||
### Type Error: Missing return type
|
||||
```python
|
||||
# Before
|
||||
def process(data):
|
||||
return result
|
||||
|
||||
# After
|
||||
def process(data: dict) -> InferenceResult:
|
||||
return result
|
||||
```
|
||||
|
||||
### Lint Error: Unused import
|
||||
```python
|
||||
# Remove unused imports or add to __all__
|
||||
```
|
||||
|
||||
### Security: print() in production
|
||||
```python
|
||||
# Before
|
||||
print(f"Processing {doc_id}")
|
||||
|
||||
# After
|
||||
logger.info(f"Processing {doc_id}")
|
||||
```
|
||||
|
||||
### Security: Bare except
|
||||
```python
|
||||
# Before
|
||||
except:
|
||||
pass
|
||||
|
||||
# After
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
raise
|
||||
```
|
||||
|
||||
### Security: SQL injection
|
||||
```python
|
||||
# Before (DANGEROUS)
|
||||
cur.execute(f"SELECT * FROM docs WHERE id = '{user_input}'")
|
||||
|
||||
# After (SAFE)
|
||||
cur.execute("SELECT * FROM docs WHERE id = %s", (user_input,))
|
||||
```
|
||||
|
||||
## Continuous Mode
|
||||
|
||||
For long sessions, run verification after major changes:
|
||||
|
||||
```markdown
|
||||
Checkpoints:
|
||||
- After completing each function
|
||||
- After finishing a module
|
||||
- Before moving to next task
|
||||
- Every 15-20 minutes of coding
|
||||
|
||||
Run: /verify
|
||||
```
|
||||
|
||||
## Integration with Other Skills
|
||||
|
||||
| Skill | Purpose |
|
||||
|-------|---------|
|
||||
| code-review | Detailed code analysis |
|
||||
| security-review | Deep security audit |
|
||||
| tdd-workflow | Test coverage |
|
||||
| build-fix | Fix errors incrementally |
|
||||
|
||||
This skill provides quick, comprehensive verification. Use specialized skills for deeper analysis.
|
||||
39
.env.example
Normal file
39
.env.example
Normal file
@@ -0,0 +1,39 @@
|
||||
# Database Configuration
|
||||
# Copy this file to .env and fill in your actual values
|
||||
|
||||
# PostgreSQL Database
|
||||
DB_HOST=192.168.68.31
|
||||
DB_PORT=5432
|
||||
DB_NAME=docmaster
|
||||
DB_USER=docmaster
|
||||
DB_PASSWORD=your_password_here
|
||||
|
||||
# Storage Configuration
|
||||
# Backend type: local, azure_blob, or s3
|
||||
# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.)
|
||||
STORAGE_BACKEND=local
|
||||
STORAGE_BASE_PATH=./data
|
||||
|
||||
# Azure Blob Storage (when STORAGE_BACKEND=azure_blob)
|
||||
# AZURE_STORAGE_CONNECTION_STRING=your_connection_string
|
||||
# AZURE_STORAGE_CONTAINER=documents
|
||||
|
||||
# AWS S3 Storage (when STORAGE_BACKEND=s3)
|
||||
# AWS_S3_BUCKET=your_bucket_name
|
||||
# AWS_REGION=us-east-1
|
||||
# AWS_ACCESS_KEY_ID=your_access_key
|
||||
# AWS_SECRET_ACCESS_KEY=your_secret_key
|
||||
# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO
|
||||
|
||||
# Model Configuration (optional)
|
||||
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
||||
# CONFIDENCE_THRESHOLD=0.5
|
||||
|
||||
# Server Configuration (optional)
|
||||
# SERVER_HOST=0.0.0.0
|
||||
# SERVER_PORT=8000
|
||||
|
||||
# Auto-labeling Configuration (optional)
|
||||
# AUTOLABEL_WORKERS=2
|
||||
# AUTOLABEL_DPI=150
|
||||
# AUTOLABEL_MIN_CONFIDENCE=0.5
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -52,6 +52,10 @@ reports/*.jsonl
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Coverage
|
||||
htmlcov/
|
||||
.coverage
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
|
||||
666
ARCHITECTURE_REVIEW.md
Normal file
666
ARCHITECTURE_REVIEW.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# Invoice Master POC v2 - 总体架构审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
|
||||
|
||||
---
|
||||
|
||||
## 架构概述
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Frontend (React) │
|
||||
│ Vite + TypeScript + TailwindCSS │
|
||||
└─────────────────────────────┬───────────────────────────────────┘
|
||||
│ HTTP/REST
|
||||
┌─────────────────────────────▼───────────────────────────────────┐
|
||||
│ Inference Service (FastAPI) │
|
||||
│ ┌──────────────┬──────────────┬──────────────┬──────────────┐ │
|
||||
│ │ Public API │ Admin API │ Training API│ Batch API │ │
|
||||
│ └──────────────┴──────────────┴──────────────┴──────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Service Layer │ │
|
||||
│ │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Data Layer │ │
|
||||
│ │ AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Core Components │ │
|
||||
│ │ RateLimiter │ Schedulers │ TaskQueues │ Auth │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────┬───────────────────────────────────┘
|
||||
│ PostgreSQL
|
||||
┌─────────────────────────────▼───────────────────────────────────┐
|
||||
│ Training Service (GPU) │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ CLI: train │ autolabel │ analyze │ validate │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ YOLO: db_dataset │ annotation_generator │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Processing: CPU Pool │ GPU Pool │ Task Dispatcher │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌─────────┴─────────┐
|
||||
▼ ▼
|
||||
┌──────────────┐ ┌──────────────┐
|
||||
│ Shared │ │ Storage │
|
||||
│ PDF │ OCR │ │ Local/Azure/ │
|
||||
│ Normalize │ │ S3 │
|
||||
└──────────────┘ └──────────────┘
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 层级 | 技术 | 评估 |
|
||||
|------|------|------|
|
||||
| **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 |
|
||||
| **API 框架** | FastAPI | ✅ 高性能,类型安全 |
|
||||
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM |
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
|
||||
| **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 |
|
||||
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 架构优势
|
||||
|
||||
### 1. Monorepo 结构 ✅
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 - 无外部依赖
|
||||
├── training/ # 训练服务 - 依赖 shared
|
||||
└── inference/ # 推理服务 - 依赖 shared
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- 清晰的包边界,无循环依赖
|
||||
- 独立部署,training 按需启动
|
||||
- 代码复用率高
|
||||
|
||||
### 2. 分层架构 ✅
|
||||
|
||||
```
|
||||
API Routes (web/api/v1/)
|
||||
↓
|
||||
Service Layer (web/services/)
|
||||
↓
|
||||
Data Layer (data/)
|
||||
↓
|
||||
Database (PostgreSQL)
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- 职责分离明确
|
||||
- 便于单元测试
|
||||
- 可替换底层实现
|
||||
|
||||
### 3. 依赖注入 ✅
|
||||
|
||||
```python
|
||||
# FastAPI Depends 使用得当
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
db: AdminDB = Depends(get_admin_db), # 注入
|
||||
token: str = Depends(validate_admin_token),
|
||||
):
|
||||
```
|
||||
|
||||
### 4. 存储抽象层 ✅
|
||||
|
||||
```python
|
||||
# 统一接口,支持多后端
|
||||
class StorageBackend(ABC):
|
||||
def upload(self, source: Path, destination: str) -> None: ...
|
||||
def download(self, source: str, destination: Path) -> None: ...
|
||||
def get_presigned_url(self, path: str) -> str: ...
|
||||
|
||||
# 实现: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
|
||||
```
|
||||
|
||||
### 5. 动态模型管理 ✅
|
||||
|
||||
```python
|
||||
# 数据库驱动的模型切换
|
||||
def get_active_model_path() -> Path | None:
|
||||
db = AdminDB()
|
||||
active_model = db.get_active_model_version()
|
||||
return active_model.model_path if active_model else None
|
||||
|
||||
inference_service = InferenceService(
|
||||
model_path_resolver=get_active_model_path,
|
||||
)
|
||||
```
|
||||
|
||||
### 6. 任务队列分离 ✅
|
||||
|
||||
```python
|
||||
# 不同类型任务使用不同队列
|
||||
- AsyncTaskQueue: 异步推理任务
|
||||
- BatchQueue: 批量上传任务
|
||||
- TrainingScheduler: 训练任务调度
|
||||
- AutoLabelScheduler: 自动标注调度
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 架构问题与风险
|
||||
|
||||
### 1. 数据库层职责过重 ⚠️ **中风险**
|
||||
|
||||
**问题**: `AdminDB` 类过大,违反单一职责原则
|
||||
|
||||
```python
|
||||
# packages/inference/inference/data/admin_db.py
|
||||
class AdminDB:
|
||||
# Token 管理 (5 个方法)
|
||||
def is_valid_admin_token(self, token: str) -> bool: ...
|
||||
def create_admin_token(self, token: str, name: str): ...
|
||||
|
||||
# 文档管理 (8 个方法)
|
||||
def create_document(self, ...): ...
|
||||
def get_document(self, doc_id: str): ...
|
||||
|
||||
# 标注管理 (6 个方法)
|
||||
def create_annotation(self, ...): ...
|
||||
def get_annotations(self, doc_id: str): ...
|
||||
|
||||
# 训练任务 (7 个方法)
|
||||
def create_training_task(self, ...): ...
|
||||
def update_training_task(self, ...): ...
|
||||
|
||||
# 数据集 (6 个方法)
|
||||
def create_dataset(self, ...): ...
|
||||
def get_dataset(self, dataset_id: str): ...
|
||||
|
||||
# 模型版本 (5 个方法)
|
||||
def create_model_version(self, ...): ...
|
||||
def activate_model_version(self, ...): ...
|
||||
|
||||
# 批处理 (4 个方法)
|
||||
# 锁管理 (3 个方法)
|
||||
# ... 总计 50+ 方法
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 类过大,难以维护
|
||||
- 测试困难
|
||||
- 不同领域变更互相影响
|
||||
|
||||
**建议**: 按领域拆分为 Repository 模式
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
class TokenRepository:
|
||||
def validate(self, token: str) -> bool: ...
|
||||
def create(self, token: Token) -> None: ...
|
||||
|
||||
class DocumentRepository:
|
||||
def find_by_id(self, doc_id: str) -> Document | None: ...
|
||||
def save(self, document: Document) -> None: ...
|
||||
|
||||
class TrainingRepository:
|
||||
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
|
||||
def update_task_status(self, task_id: str, status: TaskStatus): ...
|
||||
|
||||
class ModelRepository:
|
||||
def get_active(self) -> ModelVersion | None: ...
|
||||
def activate(self, version_id: str) -> None: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. Service 层混合业务逻辑与技术细节 ⚠️ **中风险**
|
||||
|
||||
**问题**: `InferenceService` 既处理业务逻辑又处理技术实现
|
||||
|
||||
```python
|
||||
# packages/inference/inference/web/services/inference.py
|
||||
class InferenceService:
|
||||
def process(self, image_bytes: bytes) -> ServiceResult:
|
||||
# 1. 技术细节: 图像解码
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
# 2. 业务逻辑: 字段提取
|
||||
fields = self._extract_fields(image)
|
||||
|
||||
# 3. 技术细节: 模型推理
|
||||
detections = self._model.predict(image)
|
||||
|
||||
# 4. 业务逻辑: 结果验证
|
||||
if not self._validate_fields(fields):
|
||||
raise ValidationError()
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 难以测试业务逻辑
|
||||
- 技术变更影响业务代码
|
||||
- 无法切换技术实现
|
||||
|
||||
**建议**: 引入领域层和适配器模式
|
||||
|
||||
```python
|
||||
# 领域层 - 纯业务逻辑
|
||||
@dataclass
|
||||
class InvoiceDocument:
|
||||
document_id: str
|
||||
pages: list[Page]
|
||||
|
||||
class InvoiceExtractor:
|
||||
"""纯业务逻辑,不依赖技术实现"""
|
||||
def extract(self, document: InvoiceDocument) -> InvoiceFields:
|
||||
# 只处理业务规则
|
||||
pass
|
||||
|
||||
# 适配器层 - 技术实现
|
||||
class YoloFieldDetector:
|
||||
"""YOLO 技术适配器"""
|
||||
def __init__(self, model_path: Path):
|
||||
self._model = YOLO(model_path)
|
||||
|
||||
def detect(self, image: np.ndarray) -> list[FieldRegion]:
|
||||
return self._model.predict(image)
|
||||
|
||||
class PaddleOcrEngine:
|
||||
"""PaddleOCR 技术适配器"""
|
||||
def __init__(self):
|
||||
self._ocr = PaddleOCR()
|
||||
|
||||
def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
|
||||
return self._ocr.ocr(image, region)
|
||||
|
||||
# 应用服务 - 协调领域和适配器
|
||||
class InvoiceProcessingService:
|
||||
def __init__(
|
||||
self,
|
||||
extractor: InvoiceExtractor,
|
||||
detector: FieldDetector,
|
||||
ocr: OcrEngine,
|
||||
):
|
||||
self._extractor = extractor
|
||||
self._detector = detector
|
||||
self._ocr = ocr
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 调度器设计分散 ⚠️ **中风险**
|
||||
|
||||
**问题**: 多个独立调度器缺乏统一协调
|
||||
|
||||
```python
|
||||
# 当前设计 - 4 个独立调度器
|
||||
# 1. TrainingScheduler (core/scheduler.py)
|
||||
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
|
||||
# 3. AsyncTaskQueue (workers/async_queue.py)
|
||||
# 4. BatchQueue (workers/batch_queue.py)
|
||||
|
||||
# app.py 中分别启动
|
||||
start_scheduler() # 训练调度器
|
||||
start_autolabel_scheduler() # 自动标注调度器
|
||||
init_batch_queue() # 批处理队列
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 资源竞争风险
|
||||
- 难以监控和追踪
|
||||
- 任务优先级难以管理
|
||||
- 重启时任务丢失
|
||||
|
||||
**建议**: 使用 Celery + Redis 统一任务队列
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
from celery import Celery
|
||||
|
||||
app = Celery('invoice_master')
|
||||
|
||||
@app.task(bind=True, max_retries=3)
|
||||
def process_inference(self, document_id: str):
|
||||
"""异步推理任务"""
|
||||
try:
|
||||
service = get_inference_service()
|
||||
result = service.process(document_id)
|
||||
return result
|
||||
except Exception as exc:
|
||||
raise self.retry(exc=exc, countdown=60)
|
||||
|
||||
@app.task
|
||||
def train_model(dataset_id: str, config: dict):
|
||||
"""训练任务"""
|
||||
training_service = get_training_service()
|
||||
return training_service.train(dataset_id, config)
|
||||
|
||||
@app.task
|
||||
def auto_label_documents(document_ids: list[str]):
|
||||
"""批量自动标注"""
|
||||
for doc_id in document_ids:
|
||||
auto_label_document.delay(doc_id)
|
||||
|
||||
# 优先级队列
|
||||
app.conf.task_routes = {
|
||||
'tasks.process_inference': {'queue': 'high_priority'},
|
||||
'tasks.train_model': {'queue': 'gpu_queue'},
|
||||
'tasks.auto_label_documents': {'queue': 'low_priority'},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 配置分散 ⚠️ **低风险**
|
||||
|
||||
**问题**: 配置分散在多个文件
|
||||
|
||||
```python
|
||||
# packages/shared/shared/config.py
|
||||
DATABASE = {...}
|
||||
PATHS = {...}
|
||||
AUTOLABEL = {...}
|
||||
|
||||
# packages/inference/inference/web/config.py
|
||||
@dataclass
|
||||
class ModelConfig: ...
|
||||
@dataclass
|
||||
class ServerConfig: ...
|
||||
@dataclass
|
||||
class FileConfig: ...
|
||||
|
||||
# 环境变量
|
||||
# .env 文件
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 配置难以追踪
|
||||
- 可能出现不一致
|
||||
- 缺少配置验证
|
||||
|
||||
**建议**: 使用 Pydantic Settings 集中管理
|
||||
|
||||
```python
|
||||
# config/settings.py
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix='DB_')
|
||||
|
||||
host: str = 'localhost'
|
||||
port: int = 5432
|
||||
name: str = 'docmaster'
|
||||
user: str = 'docmaster'
|
||||
password: str # 无默认值,必须设置
|
||||
|
||||
class StorageSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix='STORAGE_')
|
||||
|
||||
backend: str = 'local'
|
||||
base_path: str = '~/invoice-data'
|
||||
azure_connection_string: str | None = None
|
||||
s3_bucket: str | None = None
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
)
|
||||
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
|
||||
# 验证
|
||||
@field_validator('database')
|
||||
def validate_database(cls, v):
|
||||
if not v.password:
|
||||
raise ValueError('Database password is required')
|
||||
return v
|
||||
|
||||
# 全局配置实例
|
||||
settings = Settings()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. 内存队列单点故障 ⚠️ **中风险**
|
||||
|
||||
**问题**: AsyncTaskQueue 和 BatchQueue 基于内存
|
||||
|
||||
```python
|
||||
# workers/async_queue.py
|
||||
class AsyncTaskQueue:
|
||||
def __init__(self):
|
||||
self._queue = Queue() # 内存队列
|
||||
self._workers = []
|
||||
|
||||
def enqueue(self, task: AsyncTask) -> None:
|
||||
self._queue.put(task) # 仅存储在内存
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 服务重启丢失所有待处理任务
|
||||
- 无法水平扩展
|
||||
- 任务持久化困难
|
||||
|
||||
**建议**: 使用 Redis/RabbitMQ 持久化队列
|
||||
|
||||
---
|
||||
|
||||
### 6. 缺少 API 版本迁移策略 ❓ **低风险**
|
||||
|
||||
**问题**: 有 `/api/v1/` 版本,但缺少升级策略
|
||||
|
||||
```
|
||||
当前: /api/v1/admin/documents
|
||||
未来: /api/v2/admin/documents ?
|
||||
```
|
||||
|
||||
**建议**:
|
||||
- 制定 API 版本升级流程
|
||||
- 使用 Header 版本控制
|
||||
- 维护版本兼容性文档
|
||||
|
||||
---
|
||||
|
||||
## 关键架构风险矩阵
|
||||
|
||||
| 风险项 | 概率 | 影响 | 风险等级 | 优先级 |
|
||||
|--------|------|------|----------|--------|
|
||||
| 内存队列丢失任务 | 中 | 高 | **高** | 🔴 P0 |
|
||||
| AdminDB 职责过重 | 高 | 中 | **中** | 🟡 P1 |
|
||||
| Service 层混合 | 高 | 中 | **中** | 🟡 P1 |
|
||||
| 调度器资源竞争 | 中 | 中 | **中** | 🟡 P1 |
|
||||
| 配置分散 | 高 | 低 | **低** | 🟢 P2 |
|
||||
| API 版本策略 | 低 | 低 | **低** | 🟢 P2 |
|
||||
|
||||
---
|
||||
|
||||
## 改进建议路线图
|
||||
|
||||
### Phase 1: 立即执行 (本周)
|
||||
|
||||
#### 1.1 拆分 AdminDB
|
||||
```python
|
||||
# 创建 repositories 包
|
||||
inference/data/repositories/
|
||||
├── __init__.py
|
||||
├── base.py # Repository 基类
|
||||
├── token.py # TokenRepository
|
||||
├── document.py # DocumentRepository
|
||||
├── annotation.py # AnnotationRepository
|
||||
├── training.py # TrainingRepository
|
||||
├── dataset.py # DatasetRepository
|
||||
└── model.py # ModelRepository
|
||||
```
|
||||
|
||||
#### 1.2 统一配置
|
||||
```python
|
||||
# 创建统一配置模块
|
||||
inference/config/
|
||||
├── __init__.py
|
||||
├── settings.py # Pydantic Settings
|
||||
└── validators.py # 配置验证
|
||||
```
|
||||
|
||||
### Phase 2: 短期执行 (本月)
|
||||
|
||||
#### 2.1 引入消息队列
|
||||
```yaml
|
||||
# docker-compose.yml 添加
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
celery_worker:
|
||||
build: .
|
||||
command: celery -A inference.tasks worker -l info
|
||||
depends_on:
|
||||
- redis
|
||||
- postgres
|
||||
```
|
||||
|
||||
#### 2.2 添加缓存层
|
||||
```python
|
||||
# 使用 Redis 缓存热点数据
|
||||
from redis import Redis
|
||||
|
||||
redis_client = Redis(host='localhost', port=6379)
|
||||
|
||||
class CachedDocumentRepository(DocumentRepository):
|
||||
def find_by_id(self, doc_id: str) -> Document | None:
|
||||
# 先查缓存
|
||||
cached = redis_client.get(f"doc:{doc_id}")
|
||||
if cached:
|
||||
return Document.parse_raw(cached)
|
||||
|
||||
# 再查数据库
|
||||
doc = super().find_by_id(doc_id)
|
||||
if doc:
|
||||
redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
|
||||
return doc
|
||||
```
|
||||
|
||||
### Phase 3: 长期执行 (本季度)
|
||||
|
||||
#### 3.1 数据库读写分离
|
||||
```python
|
||||
# 配置主从数据库
|
||||
class DatabaseManager:
|
||||
def __init__(self):
|
||||
self._master = create_engine(MASTER_DB_URL)
|
||||
self._replica = create_engine(REPLICA_DB_URL)
|
||||
|
||||
def get_session(self, readonly: bool = False) -> Session:
|
||||
engine = self._replica if readonly else self._master
|
||||
return Session(engine)
|
||||
```
|
||||
|
||||
#### 3.2 事件驱动架构
|
||||
```python
|
||||
# 引入事件总线
|
||||
from event_bus import EventBus
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
# 发布事件
|
||||
@router.post("/documents")
|
||||
async def create_document(...):
|
||||
doc = document_repo.save(document)
|
||||
bus.publish('document.created', {'document_id': doc.id})
|
||||
return doc
|
||||
|
||||
# 订阅事件
|
||||
@bus.subscribe('document.created')
|
||||
def on_document_created(event):
|
||||
# 触发自动标注
|
||||
auto_label_task.delay(event['document_id'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 架构演进建议
|
||||
|
||||
### 当前架构 (适合 1-10 用户)
|
||||
|
||||
```
|
||||
Single Instance
|
||||
├── FastAPI App
|
||||
├── Memory Queues
|
||||
└── PostgreSQL
|
||||
```
|
||||
|
||||
### 目标架构 (适合 100+ 用户)
|
||||
|
||||
```
|
||||
Load Balancer
|
||||
├── FastAPI Instance 1
|
||||
├── FastAPI Instance 2
|
||||
└── FastAPI Instance N
|
||||
│
|
||||
┌───────┴───────┐
|
||||
▼ ▼
|
||||
Redis Cluster PostgreSQL
|
||||
(Celery + Cache) (Master + Replica)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 总体评分
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **模块化** | 8/10 | 包结构清晰,但部分类过大 |
|
||||
| **可扩展性** | 7/10 | 水平扩展良好,垂直扩展受限 |
|
||||
| **可维护性** | 8/10 | 分层合理,但职责边界需细化 |
|
||||
| **可靠性** | 7/10 | 内存队列是单点故障 |
|
||||
| **性能** | 8/10 | 异步处理良好 |
|
||||
| **安全性** | 8/10 | 基础安全到位 |
|
||||
| **总体** | **7.7/10** | 良好的架构基础,需优化细节 |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计合理**: Monorepo + 分层架构适合当前规模
|
||||
2. **主要风险**: 内存队列和数据库职责过重
|
||||
3. **演进路径**: 引入消息队列和缓存层
|
||||
4. **投入产出**: 当前架构可支撑到 100+ 用户,无需大规模重构
|
||||
|
||||
### 下一步行动
|
||||
|
||||
| 优先级 | 任务 | 预计工时 | 影响 |
|
||||
|--------|------|----------|------|
|
||||
| 🔴 P0 | 引入 Celery + Redis | 3 天 | 解决任务丢失问题 |
|
||||
| 🟡 P1 | 拆分 AdminDB | 2 天 | 提升可维护性 |
|
||||
| 🟡 P1 | 统一配置管理 | 1 天 | 减少配置错误 |
|
||||
| 🟢 P2 | 添加缓存层 | 2 天 | 提升性能 |
|
||||
| 🟢 P2 | 数据库读写分离 | 3 天 | 提升扩展性 |
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 关键文件清单
|
||||
|
||||
| 文件 | 职责 | 问题 |
|
||||
|------|------|------|
|
||||
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
|
||||
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
|
||||
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
|
||||
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
|
||||
| `shared/shared/config.py` | 共享配置 | 分散管理 |
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
|
||||
- [Celery Documentation](https://docs.celeryproject.org/)
|
||||
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
|
||||
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
|
||||
317
CHANGELOG.md
Normal file
317
CHANGELOG.md
Normal file
@@ -0,0 +1,317 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to the Invoice Field Extraction project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added - Phase 1: Security & Infrastructure (2026-01-22)
|
||||
|
||||
#### Security Enhancements
|
||||
- **Environment Variable Management**: Added `python-dotenv` for secure configuration management
|
||||
- Created `.env.example` template file for configuration reference
|
||||
- Created `.env` file for actual credentials (gitignored)
|
||||
- Updated `config.py` to load database password from environment variables
|
||||
- Added validation to ensure `DB_PASSWORD` is set at startup
|
||||
- Files modified: `config.py`, `requirements.txt`
|
||||
- New files: `.env`, `.env.example`
|
||||
- Tests: `tests/test_config.py` (7 tests, all passing)
|
||||
|
||||
- **SQL Injection Prevention**: Fixed SQL injection vulnerabilities in database queries
|
||||
- Replaced f-string formatting with parameterized queries in `LIMIT` clauses
|
||||
- Updated `get_all_documents_summary()` to use `%s` placeholder for LIMIT parameter
|
||||
- Updated `get_failed_matches()` to use `%s` placeholder for LIMIT parameter
|
||||
- Files modified: `src/data/db.py` (lines 246, 298)
|
||||
- Tests: `tests/test_db_security.py` (9 tests, all passing)
|
||||
|
||||
#### Code Quality
|
||||
- **Exception Hierarchy**: Created comprehensive custom exception system
|
||||
- Added base class `InvoiceExtractionError` with message and details support
|
||||
- Added specific exception types:
|
||||
- `PDFProcessingError` - PDF rendering/conversion errors
|
||||
- `OCRError` - OCR processing errors
|
||||
- `ModelInferenceError` - YOLO model errors
|
||||
- `FieldValidationError` - Field validation errors (with field-specific attributes)
|
||||
- `DatabaseError` - Database operation errors
|
||||
- `ConfigurationError` - Configuration errors
|
||||
- `PaymentLineParseError` - Payment line parsing errors
|
||||
- `CustomerNumberParseError` - Customer number parsing errors
|
||||
- `DataLoadError` - Data loading errors
|
||||
- `AnnotationError` - Annotation generation errors
|
||||
- New file: `src/exceptions.py`
|
||||
- Tests: `tests/test_exceptions.py` (16 tests, all passing)
|
||||
|
||||
### Testing
|
||||
- Added 32 new tests across 3 test files
|
||||
- Configuration tests: 7 tests
|
||||
- SQL injection prevention tests: 9 tests
|
||||
- Exception hierarchy tests: 16 tests
|
||||
- All tests passing (32/32)
|
||||
|
||||
### Documentation
|
||||
- Created `docs/CODE_REVIEW_REPORT.md` - Comprehensive code quality analysis (550+ lines)
|
||||
- Created `docs/REFACTORING_PLAN.md` - Detailed 3-phase refactoring plan (600+ lines)
|
||||
- Created `CHANGELOG.md` - Project changelog (this file)
|
||||
|
||||
### Changed
|
||||
- **Configuration Loading**: Database configuration now loads from environment variables instead of hardcoded values
|
||||
- Breaking change: Requires `.env` file with `DB_PASSWORD` set
|
||||
- Migration: Copy `.env.example` to `.env` and set your database password
|
||||
|
||||
### Security
|
||||
- **Fixed**: Database password no longer stored in plain text in `config.py`
|
||||
- **Fixed**: SQL injection vulnerabilities in LIMIT clauses (2 instances)
|
||||
|
||||
### Technical Debt Addressed
|
||||
- Eliminated security vulnerability: plaintext password storage
|
||||
- Reduced SQL injection attack surface
|
||||
- Improved error handling granularity with custom exceptions
|
||||
|
||||
---
|
||||
|
||||
### Added - Phase 2: Parser Refactoring (2026-01-22)
|
||||
|
||||
#### Unified Parser Modules
|
||||
- **Payment Line Parser**: Created dedicated payment line parsing module
|
||||
- Handles Swedish payment line format: `# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#`
|
||||
- Tolerates common OCR errors: spaces in numbers, missing symbols, spaces in check digits
|
||||
- Supports 4 parsing patterns: full format, no amount, alternative, account-only
|
||||
- Returns structured `PaymentLineData` with parsed fields
|
||||
- New file: `src/inference/payment_line_parser.py` (90 lines, 92% coverage)
|
||||
- Tests: `tests/test_payment_line_parser.py` (23 tests, all passing)
|
||||
- Eliminates 1st code duplication (payment line parsing logic)
|
||||
|
||||
- **Customer Number Parser**: Created dedicated customer number parsing module
|
||||
- Handles Swedish customer number formats: `JTY 576-3`, `DWQ 211-X`, `FFL 019N`, etc.
|
||||
- Uses Strategy Pattern with 5 pattern classes:
|
||||
- `LabeledPattern` - Explicit labels (highest priority, 0.98 confidence)
|
||||
- `DashFormatPattern` - Standard format with dash (0.95 confidence)
|
||||
- `NoDashFormatPattern` - Format without dash, adds dash automatically (0.90 confidence)
|
||||
- `CompactFormatPattern` - Compact format without spaces (0.75 confidence)
|
||||
- `GenericAlphanumericPattern` - Fallback generic pattern (variable confidence)
|
||||
- Excludes Swedish postal codes (`SE XXX XX` format)
|
||||
- Returns highest confidence match
|
||||
- New file: `src/inference/customer_number_parser.py` (154 lines, 92% coverage)
|
||||
- Tests: `tests/test_customer_number_parser.py` (32 tests, all passing)
|
||||
- Reduces `_normalize_customer_number` complexity (127 lines → will use 5-10 lines after integration)
|
||||
|
||||
### Testing Summary
|
||||
|
||||
**Phase 1 Tests** (32 tests):
|
||||
- Configuration tests: 7 tests ([test_config.py](tests/test_config.py))
|
||||
- SQL injection prevention tests: 9 tests ([test_db_security.py](tests/test_db_security.py))
|
||||
- Exception hierarchy tests: 16 tests ([test_exceptions.py](tests/test_exceptions.py))
|
||||
|
||||
**Phase 2 Tests** (121 tests):
|
||||
- Payment line parser tests: 23 tests ([test_payment_line_parser.py](tests/test_payment_line_parser.py))
|
||||
- Standard parsing, OCR error handling, real-world examples, edge cases
|
||||
- Coverage: 92%
|
||||
- Customer number parser tests: 32 tests ([test_customer_number_parser.py](tests/test_customer_number_parser.py))
|
||||
- Pattern matching (DashFormat, NoDashFormat, Compact, Labeled)
|
||||
- Real-world examples, edge cases, Swedish postal code exclusion
|
||||
- Coverage: 92%
|
||||
- Field extractor integration tests: 45 tests ([test_field_extractor.py](src/inference/test_field_extractor.py))
|
||||
- Validates backward compatibility with existing code
|
||||
- Tests for invoice numbers, bankgiro, plusgiro, amounts, OCR, dates, payment lines, customer numbers
|
||||
- Pipeline integration tests: 21 tests ([test_pipeline.py](src/inference/test_pipeline.py))
|
||||
- Cross-validation, payment line parsing, field overrides
|
||||
|
||||
**Total**: 153 tests, 100% passing, 4.50s runtime
|
||||
|
||||
### Code Quality
|
||||
- **Eliminated Code Duplication**: Payment line parsing previously in 3 places, now unified in 1 module
|
||||
- **Improved Maintainability**: Strategy Pattern makes customer number patterns easy to extend
|
||||
- **Better Test Coverage**: New parsers have 92% coverage vs original 10% in field_extractor.py
|
||||
|
||||
#### Parser Integration into field_extractor.py (2026-01-22)
|
||||
|
||||
- **field_extractor.py Integration**: Successfully integrated new parsers
|
||||
- Added `PaymentLineParser` and `CustomerNumberParser` instances (lines 99-101)
|
||||
- Replaced `_normalize_payment_line` method: 74 lines → 3 lines (lines 640-657)
|
||||
- Replaced `_normalize_customer_number` method: 127 lines → 3 lines (lines 697-707)
|
||||
- All 45 existing tests pass (100% backward compatibility maintained)
|
||||
- Tests run time: 4.21 seconds
|
||||
- File: `src/inference/field_extractor.py`
|
||||
|
||||
#### Parser Integration into pipeline.py (2026-01-22)
|
||||
|
||||
- **pipeline.py Integration**: Successfully integrated PaymentLineParser
|
||||
- Added `PaymentLineParser` import (line 15)
|
||||
- Added `payment_line_parser` instance initialization (line 128)
|
||||
- Replaced `_parse_machine_readable_payment_line` method: 36 lines → 6 lines (lines 219-233)
|
||||
- All 21 existing tests pass (100% backward compatibility maintained)
|
||||
- Tests run time: 4.00 seconds
|
||||
- File: `src/inference/pipeline.py`
|
||||
|
||||
### Phase 2 Status: **COMPLETED** ✅
|
||||
|
||||
- [x] Create unified `payment_line_parser` module ✅
|
||||
- [x] Create unified `customer_number_parser` module ✅
|
||||
- [x] Refactor `field_extractor.py` to use new parsers ✅
|
||||
- [x] Refactor `pipeline.py` to use new parsers ✅
|
||||
- [x] Comprehensive test suite (153 tests, 100% passing) ✅
|
||||
|
||||
### Achieved Impact
|
||||
- Eliminate code duplication: 3 implementations → 1 ✅ (payment_line unified across field_extractor.py, pipeline.py, tests)
|
||||
- Reduce `_normalize_payment_line` complexity in field_extractor.py: 74 lines → 3 lines ✅
|
||||
- Reduce `_normalize_customer_number` complexity in field_extractor.py: 127 lines → 3 lines ✅
|
||||
- Reduce `_parse_machine_readable_payment_line` complexity in pipeline.py: 36 lines → 6 lines ✅
|
||||
- Total lines of code eliminated: 201 lines reduced to 12 lines (94% reduction) ✅
|
||||
- Improve test coverage: New parser modules have 92% coverage (vs original 10% in field_extractor.py)
|
||||
- Simplify maintenance: Pattern-based approach makes extension easy
|
||||
- 100% backward compatibility: All 66 existing tests pass (45 field_extractor + 21 pipeline)
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Performance & Documentation (2026-01-22)
|
||||
|
||||
### Added
|
||||
|
||||
#### Configuration Constants Extraction
|
||||
- **Created `src/inference/constants.py`**: Centralized configuration constants
|
||||
- Detection & model configuration (confidence thresholds, IOU)
|
||||
- Image processing configuration (DPI, scaling factors)
|
||||
- Customer number parser confidence scores
|
||||
- Field extraction confidence multipliers
|
||||
- Account type detection thresholds
|
||||
- Pattern matching constants
|
||||
- 90 lines of well-documented constants with usage notes
|
||||
- Eliminates ~15 hardcoded magic numbers across codebase
|
||||
- File: [src/inference/constants.py](src/inference/constants.py)
|
||||
|
||||
#### Performance Optimization Documentation
|
||||
- **Created `docs/PERFORMANCE_OPTIMIZATION.md`**: Comprehensive performance guide (400+ lines)
|
||||
- **Batch Processing Optimization**: Parallel processing strategies, already-implemented dual pool system
|
||||
- **Database Query Optimization**: Connection pooling recommendations, index strategies
|
||||
- **Caching Strategies**: Model loading cache, parser reuse (already optimal), OCR result caching
|
||||
- **Memory Management**: Explicit cleanup, generator patterns, context managers
|
||||
- **Profiling Guidelines**: cProfile, memory_profiler, py-spy recommendations
|
||||
- **Benchmarking Scripts**: Ready-to-use performance measurement code
|
||||
- **Priority Roadmap**: High/Medium/Low priority optimizations with effort estimates
|
||||
- Expected impact: 2-5x throughput improvement for batch processing
|
||||
- File: [docs/PERFORMANCE_OPTIMIZATION.md](docs/PERFORMANCE_OPTIMIZATION.md)
|
||||
|
||||
### Phase 3 Status: **COMPLETED** ✅
|
||||
|
||||
- [x] Configuration constants extraction ✅
|
||||
- [x] Performance optimization analysis ✅
|
||||
- [x] Batch processing optimization recommendations ✅
|
||||
- [x] Database optimization strategies ✅
|
||||
- [x] Caching and memory management guidelines ✅
|
||||
- [x] Profiling and benchmarking documentation ✅
|
||||
|
||||
### Deliverables
|
||||
|
||||
**New Files** (2 files):
|
||||
1. `src/inference/constants.py` (90 lines) - Centralized configuration constants
|
||||
2. `docs/PERFORMANCE_OPTIMIZATION.md` (400+ lines) - Performance optimization guide
|
||||
|
||||
**Impact**:
|
||||
- Eliminates 15+ hardcoded magic numbers
|
||||
- Provides clear optimization roadmap
|
||||
- Documents existing performance features
|
||||
- Identifies quick wins (connection pooling, indexes)
|
||||
- Long-term strategy (caching, profiling)
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
|
||||
### Breaking Changes
|
||||
- **v2.x**: Requires `.env` file with database credentials
|
||||
- Action required: Create `.env` file based on `.env.example`
|
||||
- Affected: All deployments, CI/CD pipelines
|
||||
|
||||
### Migration Guide
|
||||
|
||||
#### From v1.x to v2.x (Environment Variables)
|
||||
1. Copy `.env.example` to `.env`:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
2. Edit `.env` and set your database password:
|
||||
```
|
||||
DB_PASSWORD=your_actual_password_here
|
||||
```
|
||||
|
||||
3. Install new dependency:
|
||||
```bash
|
||||
pip install python-dotenv
|
||||
```
|
||||
|
||||
4. Verify configuration loads correctly:
|
||||
```bash
|
||||
python -c "import config; print('Config loaded successfully')"
|
||||
```
|
||||
|
||||
## Summary of All Work Completed
|
||||
|
||||
### Files Created (13 new files)
|
||||
|
||||
**Phase 1** (3 files):
|
||||
1. `.env` - Environment variables for database credentials
|
||||
2. `.env.example` - Template for environment configuration
|
||||
3. `src/exceptions.py` - Custom exception hierarchy (35 lines, 66% coverage)
|
||||
|
||||
**Phase 2** (7 files):
|
||||
4. `src/inference/payment_line_parser.py` - Unified payment line parsing (90 lines, 92% coverage)
|
||||
5. `src/inference/customer_number_parser.py` - Unified customer number parsing (154 lines, 92% coverage)
|
||||
6. `tests/test_config.py` - Configuration tests (7 tests)
|
||||
7. `tests/test_db_security.py` - SQL injection prevention tests (9 tests)
|
||||
8. `tests/test_exceptions.py` - Exception hierarchy tests (16 tests)
|
||||
9. `tests/test_payment_line_parser.py` - Payment line parser tests (23 tests)
|
||||
10. `tests/test_customer_number_parser.py` - Customer number parser tests (32 tests)
|
||||
|
||||
**Phase 3** (2 files):
|
||||
11. `src/inference/constants.py` - Centralized configuration constants (90 lines)
|
||||
12. `docs/PERFORMANCE_OPTIMIZATION.md` - Performance optimization guide (400+ lines)
|
||||
|
||||
**Documentation** (1 file):
|
||||
13. `CHANGELOG.md` - This file (260+ lines of detailed documentation)
|
||||
|
||||
### Files Modified (4 files)
|
||||
1. `config.py` - Added environment variable loading with python-dotenv
|
||||
2. `src/data/db.py` - Fixed 2 SQL injection vulnerabilities (lines 246, 298)
|
||||
3. `src/inference/field_extractor.py` - Integrated new parsers (reduced 201 lines to 6 lines)
|
||||
4. `src/inference/pipeline.py` - Integrated PaymentLineParser (reduced 36 lines to 6 lines)
|
||||
5. `requirements.txt` - Added python-dotenv dependency
|
||||
|
||||
### Test Summary
|
||||
- **Total tests**: 153 tests across 7 test files
|
||||
- **Passing**: 153 (100%)
|
||||
- **Failing**: 0
|
||||
- **Runtime**: 4.50 seconds
|
||||
- **Coverage**:
|
||||
- New parser modules: 92%
|
||||
- Config module: 100%
|
||||
- Exception module: 66%
|
||||
- DB security coverage: 18% (focused on parameterized queries)
|
||||
|
||||
### Code Metrics
|
||||
- **Lines eliminated**: 237 lines of duplicated/complex code → 18 lines (92% reduction)
|
||||
- field_extractor.py: 201 lines → 6 lines
|
||||
- pipeline.py: 36 lines → 6 lines
|
||||
- **New code added**: 279 lines of well-tested parser code
|
||||
- **Net impact**: Replaced 237 lines of duplicate code with 279 lines of unified, tested code (+42 lines, but -3 implementations)
|
||||
- **Test coverage improvement**: 0% → 92% for parser logic
|
||||
|
||||
### Performance Impact
|
||||
- Configuration loading: Negligible (<1ms overhead for .env parsing)
|
||||
- SQL queries: No performance change (parameterized queries are standard practice)
|
||||
- Parser refactoring: No performance degradation (logic simplified, not changed)
|
||||
- Exception handling: Minimal overhead (only when exceptions are raised)
|
||||
|
||||
### Security Improvements
|
||||
- ✅ Eliminated plaintext password storage
|
||||
- ✅ Fixed 2 SQL injection vulnerabilities
|
||||
- ✅ Added input validation in database layer
|
||||
|
||||
### Maintainability Improvements
|
||||
- ✅ Eliminated code duplication (3 implementations → 1)
|
||||
- ✅ Strategy Pattern enables easy extension of customer number formats
|
||||
- ✅ Comprehensive test suite (153 tests) ensures safe refactoring
|
||||
- ✅ 100% backward compatibility maintained
|
||||
- ✅ Custom exception hierarchy for granular error handling
|
||||
805
CODE_REVIEW_REPORT.md
Normal file
805
CODE_REVIEW_REPORT.md
Normal file
@@ -0,0 +1,805 @@
|
||||
# Invoice Master POC v2 - 详细代码审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
|
||||
**代码统计**:
|
||||
- Python文件: 200+ 个
|
||||
- 测试文件: 97 个
|
||||
- TypeScript/React文件: 39 个
|
||||
- 总测试数: 1,601 个
|
||||
- 测试覆盖率: 28%
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
1. [执行摘要](#执行摘要)
|
||||
2. [架构概览](#架构概览)
|
||||
3. [详细模块审查](#详细模块审查)
|
||||
4. [代码质量问题](#代码质量问题)
|
||||
5. [安全风险分析](#安全风险分析)
|
||||
6. [性能问题](#性能问题)
|
||||
7. [改进建议](#改进建议)
|
||||
8. [总结与评分](#总结与评分)
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 总体评估
|
||||
|
||||
| 维度 | 评分 | 状态 |
|
||||
|------|------|------|
|
||||
| **代码质量** | 7.5/10 | 良好,但有改进空间 |
|
||||
| **安全性** | 7/10 | 基础安全到位,需加强 |
|
||||
| **可维护性** | 8/10 | 模块化良好 |
|
||||
| **测试覆盖** | 5/10 | 偏低,需提升 |
|
||||
| **性能** | 8/10 | 异步处理良好 |
|
||||
| **文档** | 8/10 | 文档详尽 |
|
||||
| **总体** | **7.3/10** | 生产就绪,需小幅改进 |
|
||||
|
||||
### 关键发现
|
||||
|
||||
**优势:**
|
||||
- 清晰的Monorepo架构,三包分离合理
|
||||
- 类型注解覆盖率高(>90%)
|
||||
- 存储抽象层设计优秀
|
||||
- FastAPI使用规范,依赖注入模式良好
|
||||
- 异常处理完善,自定义异常层次清晰
|
||||
|
||||
**风险:**
|
||||
- 测试覆盖率仅28%,远低于行业标准
|
||||
- AdminDB类过大(50+方法),违反单一职责原则
|
||||
- 内存队列存在单点故障风险
|
||||
- 部分安全细节需加强(时序攻击、文件上传验证)
|
||||
- 前端状态管理简单,可能难以扩展
|
||||
|
||||
---
|
||||
|
||||
## 架构概览
|
||||
|
||||
### 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── packages/
|
||||
│ ├── shared/ # 共享库 (74个Python文件)
|
||||
│ │ ├── pdf/ # PDF处理
|
||||
│ │ ├── ocr/ # OCR封装
|
||||
│ │ ├── normalize/ # 字段规范化
|
||||
│ │ ├── matcher/ # 字段匹配
|
||||
│ │ ├── storage/ # 存储抽象层
|
||||
│ │ ├── training/ # 训练组件
|
||||
│ │ └── augmentation/# 数据增强
|
||||
│ ├── training/ # 训练服务 (26个Python文件)
|
||||
│ │ ├── cli/ # 命令行工具
|
||||
│ │ ├── yolo/ # YOLO数据集
|
||||
│ │ └── processing/ # 任务处理
|
||||
│ └── inference/ # 推理服务 (100个Python文件)
|
||||
│ ├── web/ # FastAPI应用
|
||||
│ ├── pipeline/ # 推理管道
|
||||
│ ├── data/ # 数据层
|
||||
│ └── cli/ # 命令行工具
|
||||
├── frontend/ # React前端 (39个TS/TSX文件)
|
||||
│ ├── src/
|
||||
│ │ ├── components/ # UI组件
|
||||
│ │ ├── hooks/ # React Query hooks
|
||||
│ │ └── api/ # API客户端
|
||||
└── tests/ # 测试 (97个Python文件)
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 层级 | 技术 | 评估 |
|
||||
|------|------|------|
|
||||
| **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 |
|
||||
| **API框架** | FastAPI + Uvicorn | 高性能,异步支持 |
|
||||
| **数据库** | PostgreSQL + SQLModel | 类型安全ORM |
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | 业界标准 |
|
||||
| **OCR** | PaddleOCR v5 | 支持瑞典语 |
|
||||
| **部署** | Docker + Azure/AWS | 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 详细模块审查
|
||||
|
||||
### 1. Shared Package
|
||||
|
||||
#### 1.1 配置模块 (`shared/config.py`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/config.py`
|
||||
**代码行数**: 82行
|
||||
|
||||
**优点:**
|
||||
- 使用环境变量加载配置,无硬编码敏感信息
|
||||
- DPI配置统一管理(DEFAULT_DPI = 150)
|
||||
- 密码无默认值,强制要求设置
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 配置分散,缺少验证
|
||||
DATABASE = {
|
||||
'host': os.getenv('DB_HOST', '192.168.68.31'), # 硬编码IP
|
||||
'port': int(os.getenv('DB_PORT', '5432')),
|
||||
# ...
|
||||
}
|
||||
|
||||
# 问题2: 缺少类型安全
|
||||
# 建议使用 Pydantic Settings
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 使用 Pydantic Settings 集中管理配置,添加验证逻辑
|
||||
|
||||
---
|
||||
|
||||
#### 1.2 存储抽象层 (`shared/storage/`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/storage/`
|
||||
**包含文件**: 8个
|
||||
|
||||
**优点:**
|
||||
- 设计优秀的抽象接口 `StorageBackend`
|
||||
- 支持 Local/Azure/S3 多后端
|
||||
- 预签名URL支持
|
||||
- 异常层次清晰
|
||||
|
||||
**代码示例 - 优秀设计:**
|
||||
```python
|
||||
class StorageBackend(ABC):
|
||||
@abstractmethod
|
||||
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
|
||||
pass
|
||||
```
|
||||
|
||||
**问题:**
|
||||
- `upload_bytes` 和 `download_bytes` 默认实现使用临时文件,效率较低
|
||||
- 缺少文件类型验证(魔术字节检查)
|
||||
|
||||
**严重程度**: 低
|
||||
**建议**: 子类可重写bytes方法以提高效率,添加文件类型验证
|
||||
|
||||
---
|
||||
|
||||
#### 1.3 异常定义 (`shared/exceptions.py`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/exceptions.py`
|
||||
**代码行数**: 103行
|
||||
|
||||
**优点:**
|
||||
- 清晰的异常层次结构
|
||||
- 所有异常继承自 `InvoiceExtractionError`
|
||||
- 包含详细的错误上下文
|
||||
|
||||
**代码示例:**
|
||||
```python
|
||||
class InvoiceExtractionError(Exception):
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
```
|
||||
|
||||
**评分**: 9/10 - 设计优秀
|
||||
|
||||
---
|
||||
|
||||
#### 1.4 数据增强 (`shared/augmentation/`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/augmentation/`
|
||||
**包含文件**: 10个
|
||||
|
||||
**功能:**
|
||||
- 12种数据增强策略
|
||||
- 透视变换、皱纹、边缘损坏、污渍等
|
||||
- 高斯模糊、运动模糊、噪声等
|
||||
|
||||
**代码质量**: 良好,模块化设计
|
||||
|
||||
---
|
||||
|
||||
### 2. Inference Package
|
||||
|
||||
#### 2.1 认证模块 (`inference/web/core/auth.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/core/auth.py`
|
||||
**代码行数**: 61行
|
||||
|
||||
**优点:**
|
||||
- 使用FastAPI依赖注入模式
|
||||
- Token过期检查
|
||||
- 记录最后使用时间
|
||||
|
||||
**安全问题:**
|
||||
```python
|
||||
# 问题: 时序攻击风险 (第46行)
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired admin token.")
|
||||
|
||||
# 建议: 使用 constant-time 比较
|
||||
import hmac
|
||||
if not hmac.compare_digest(token, expected_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 使用 `hmac.compare_digest()` 进行constant-time比较
|
||||
|
||||
---
|
||||
|
||||
#### 2.2 限流器 (`inference/web/core/rate_limiter.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/core/rate_limiter.py`
|
||||
**代码行数**: 212行
|
||||
|
||||
**优点:**
|
||||
- 滑动窗口算法实现
|
||||
- 线程安全(使用Lock)
|
||||
- 支持并发任务限制
|
||||
- 可配置的限流策略
|
||||
|
||||
**代码示例 - 优秀设计:**
|
||||
```python
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
requests_per_minute: int = 10
|
||||
max_concurrent_jobs: int = 3
|
||||
min_poll_interval_ms: int = 1000
|
||||
```
|
||||
|
||||
**问题:**
|
||||
- 内存存储,服务重启后限流状态丢失
|
||||
- 分布式部署时无法共享限流状态
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 生产环境使用Redis实现分布式限流
|
||||
|
||||
---
|
||||
|
||||
#### 2.3 AdminDB (`inference/data/admin_db.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/data/admin_db.py`
|
||||
**代码行数**: 1300+行
|
||||
|
||||
**严重问题 - 类过大:**
|
||||
```python
|
||||
class AdminDB:
|
||||
# Token管理 (5个方法)
|
||||
# 文档管理 (8个方法)
|
||||
# 标注管理 (6个方法)
|
||||
# 训练任务 (7个方法)
|
||||
# 数据集 (6个方法)
|
||||
# 模型版本 (5个方法)
|
||||
# 批处理 (4个方法)
|
||||
# 锁管理 (3个方法)
|
||||
# ... 总计50+方法
|
||||
```
|
||||
|
||||
**影响:**
|
||||
- 违反单一职责原则
|
||||
- 难以维护
|
||||
- 测试困难
|
||||
- 不同领域变更互相影响
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 按领域拆分为Repository模式
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
class TokenRepository:
|
||||
def validate(self, token: str) -> bool: ...
|
||||
|
||||
class DocumentRepository:
|
||||
def find_by_id(self, doc_id: str) -> Document | None: ...
|
||||
|
||||
class TrainingRepository:
|
||||
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 2.4 文档路由 (`inference/web/api/v1/admin/documents.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/api/v1/admin/documents.py`
|
||||
**代码行数**: 692行
|
||||
|
||||
**优点:**
|
||||
- FastAPI使用规范
|
||||
- 输入验证完善
|
||||
- 响应模型定义清晰
|
||||
- 错误处理良好
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 文件上传缺少魔术字节验证 (第127-131行)
|
||||
content = await file.read()
|
||||
# 建议: 验证PDF魔术字节 %PDF
|
||||
|
||||
# 问题2: 路径遍历风险 (第494-498行)
|
||||
filename = Path(document.file_path).name
|
||||
# 建议: 使用 Path.name 并验证路径范围
|
||||
|
||||
# 问题3: 函数过长,职责过多
|
||||
# _convert_pdf_to_images 函数混合了PDF处理和存储操作
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 添加文件类型验证,拆分大函数
|
||||
|
||||
---
|
||||
|
||||
#### 2.5 推理服务 (`inference/web/services/inference.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/services/inference.py`
|
||||
**代码行数**: 361行
|
||||
|
||||
**优点:**
|
||||
- 支持动态模型加载
|
||||
- 懒加载初始化
|
||||
- 模型热重载支持
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 混合业务逻辑和技术实现
|
||||
def process_image(self, image_path: Path, ...) -> ServiceResult:
|
||||
# 1. 技术细节: 图像解码
|
||||
# 2. 业务逻辑: 字段提取
|
||||
# 3. 技术细节: 模型推理
|
||||
# 4. 业务逻辑: 结果验证
|
||||
|
||||
# 问题2: 可视化方法重复加载模型
|
||||
model = YOLO(str(self.model_config.model_path)) # 第316行
|
||||
# 应该在初始化时加载,避免重复IO
|
||||
|
||||
# 问题3: 临时文件未使用上下文管理器
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
# 建议使用 tempfile 上下文管理器
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 引入领域层和适配器模式,分离业务和技术逻辑
|
||||
|
||||
---
|
||||
|
||||
#### 2.6 异步队列 (`inference/web/workers/async_queue.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/workers/async_queue.py`
|
||||
**代码行数**: 213行
|
||||
|
||||
**优点:**
|
||||
- 线程安全实现
|
||||
- 优雅关闭支持
|
||||
- 任务状态跟踪
|
||||
|
||||
**严重问题:**
|
||||
```python
|
||||
# 问题: 内存队列,服务重启丢失任务 (第42行)
|
||||
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
|
||||
|
||||
# 问题: 无法水平扩展
|
||||
# 问题: 任务持久化困难
|
||||
```
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 使用Redis/RabbitMQ持久化队列
|
||||
|
||||
---
|
||||
|
||||
### 3. Training Package
|
||||
|
||||
#### 3.1 整体评估
|
||||
|
||||
**文件数量**: 26个Python文件
|
||||
|
||||
**优点:**
|
||||
- CLI工具设计良好
|
||||
- 双池协调器(CPU + GPU)设计优秀
|
||||
- 数据增强策略丰富
|
||||
|
||||
**总体评分**: 8/10
|
||||
|
||||
---
|
||||
|
||||
### 4. Frontend
|
||||
|
||||
#### 4.1 API客户端 (`frontend/src/api/client.ts`)
|
||||
|
||||
**文件位置**: `frontend/src/api/client.ts`
|
||||
**代码行数**: 42行
|
||||
|
||||
**优点:**
|
||||
- Axios配置清晰
|
||||
- 请求/响应拦截器
|
||||
- 认证token自动添加
|
||||
|
||||
**问题:**
|
||||
```typescript
|
||||
// 问题1: Token存储在localStorage,存在XSS风险
|
||||
const token = localStorage.getItem('admin_token')
|
||||
|
||||
// 问题2: 401错误处理不完整
|
||||
if (error.response?.status === 401) {
|
||||
console.warn('Authentication required...')
|
||||
// 应该触发重新登录或token刷新
|
||||
}
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 考虑使用http-only cookie存储token,完善错误处理
|
||||
|
||||
---
|
||||
|
||||
#### 4.2 Dashboard组件 (`frontend/src/components/Dashboard.tsx`)
|
||||
|
||||
**文件位置**: `frontend/src/components/Dashboard.tsx`
|
||||
**代码行数**: 301行
|
||||
|
||||
**优点:**
|
||||
- React hooks使用规范
|
||||
- 类型定义清晰
|
||||
- UI响应式设计
|
||||
|
||||
**问题:**
|
||||
```typescript
|
||||
// 问题1: 硬编码的进度值
|
||||
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||
if (doc.auto_label_status === 'running') {
|
||||
return 45 // 硬编码!
|
||||
}
|
||||
// ...
|
||||
}
|
||||
|
||||
// 问题2: 搜索功能未实现
|
||||
// 没有onChange处理
|
||||
|
||||
// 问题3: 缺少错误边界处理
|
||||
// 组件应该包裹在Error Boundary中
|
||||
```
|
||||
|
||||
**严重程度**: 低
|
||||
**建议**: 实现真实的进度获取,添加搜索功能
|
||||
|
||||
---
|
||||
|
||||
#### 4.3 整体评估
|
||||
|
||||
**优点:**
|
||||
- TypeScript类型安全
|
||||
- React Query状态管理
|
||||
- TailwindCSS样式一致
|
||||
|
||||
**问题:**
|
||||
- 缺少错误边界
|
||||
- 部分功能硬编码
|
||||
- 缺少单元测试
|
||||
|
||||
**总体评分**: 7.5/10
|
||||
|
||||
---
|
||||
|
||||
### 5. Tests
|
||||
|
||||
#### 5.1 测试统计
|
||||
|
||||
- **测试文件数**: 97个
|
||||
- **测试总数**: 1,601个
|
||||
- **测试覆盖率**: 28%
|
||||
|
||||
#### 5.2 覆盖率分析
|
||||
|
||||
| 模块 | 估计覆盖率 | 状态 |
|
||||
|------|-----------|------|
|
||||
| `shared/` | 35% | 偏低 |
|
||||
| `inference/web/` | 25% | 偏低 |
|
||||
| `inference/pipeline/` | 20% | 严重不足 |
|
||||
| `training/` | 30% | 偏低 |
|
||||
| `frontend/` | 15% | 严重不足 |
|
||||
|
||||
#### 5.3 测试质量问题
|
||||
|
||||
**优点:**
|
||||
- 使用了pytest框架
|
||||
- 有conftest.py配置
|
||||
- 部分集成测试
|
||||
|
||||
**问题:**
|
||||
- 覆盖率远低于行业标准(80%)
|
||||
- 缺少端到端测试
|
||||
- 部分测试可能过于简单
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 制定测试计划,优先覆盖核心业务逻辑
|
||||
|
||||
---
|
||||
|
||||
## 代码质量问题
|
||||
|
||||
### 高优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| AdminDB类过大 | `inference/data/admin_db.py` | 维护困难 | 拆分为Repository模式 |
|
||||
| 内存队列单点故障 | `inference/web/workers/async_queue.py` | 任务丢失 | 使用Redis持久化 |
|
||||
| 测试覆盖率过低 | 全项目 | 代码风险 | 提升至60%+ |
|
||||
|
||||
### 中优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| 时序攻击风险 | `inference/web/core/auth.py` | 安全漏洞 | 使用hmac.compare_digest |
|
||||
| 限流器内存存储 | `inference/web/core/rate_limiter.py` | 分布式问题 | 使用Redis |
|
||||
| 配置分散 | `shared/config.py` | 难以管理 | 使用Pydantic Settings |
|
||||
| 文件上传验证不足 | `inference/web/api/v1/admin/documents.py` | 安全风险 | 添加魔术字节验证 |
|
||||
| 推理服务混合职责 | `inference/web/services/inference.py` | 难以测试 | 分离业务和技术逻辑 |
|
||||
|
||||
### 低优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| 前端搜索未实现 | `frontend/src/components/Dashboard.tsx` | 功能缺失 | 实现搜索功能 |
|
||||
| 硬编码进度值 | `frontend/src/components/Dashboard.tsx` | 用户体验 | 获取真实进度 |
|
||||
| Token存储方式 | `frontend/src/api/client.ts` | XSS风险 | 考虑http-only cookie |
|
||||
|
||||
---
|
||||
|
||||
## 安全风险分析
|
||||
|
||||
### 已识别的安全风险
|
||||
|
||||
#### 1. 时序攻击 (中风险)
|
||||
|
||||
**位置**: `inference/web/core/auth.py:46`
|
||||
|
||||
```python
|
||||
# 当前实现(有风险)
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
|
||||
# 安全实现
|
||||
import hmac
|
||||
if not hmac.compare_digest(token, expected_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
```
|
||||
|
||||
#### 2. 文件上传验证不足 (中风险)
|
||||
|
||||
**位置**: `inference/web/api/v1/admin/documents.py:127-131`
|
||||
|
||||
```python
|
||||
# 建议添加魔术字节验证
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
```
|
||||
|
||||
#### 3. 路径遍历风险 (中风险)
|
||||
|
||||
**位置**: `inference/web/api/v1/admin/documents.py:494-498`
|
||||
|
||||
```python
|
||||
# 建议实现
|
||||
from pathlib import Path
|
||||
|
||||
def get_safe_path(filename: str, base_dir: Path) -> Path:
|
||||
safe_name = Path(filename).name
|
||||
full_path = (base_dir / safe_name).resolve()
|
||||
if not full_path.is_relative_to(base_dir):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
return full_path
|
||||
```
|
||||
|
||||
#### 4. CORS配置 (低风险)
|
||||
|
||||
**位置**: FastAPI中间件配置
|
||||
|
||||
```python
|
||||
# 建议生产环境配置
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
```
|
||||
|
||||
#### 5. XSS风险 (低风险)
|
||||
|
||||
**位置**: `frontend/src/api/client.ts:13`
|
||||
|
||||
```typescript
|
||||
// 当前实现
|
||||
const token = localStorage.getItem('admin_token')
|
||||
|
||||
// 建议考虑
|
||||
// 使用http-only cookie存储敏感token
|
||||
```
|
||||
|
||||
### 安全评分
|
||||
|
||||
| 类别 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| 认证 | 8/10 | 基础良好,需加强时序攻击防护 |
|
||||
| 输入验证 | 7/10 | 基本验证到位,需加强文件验证 |
|
||||
| 数据保护 | 8/10 | 无敏感信息硬编码 |
|
||||
| 传输安全 | 8/10 | 使用HTTPS(生产环境) |
|
||||
| 总体 | 7.5/10 | 基础安全良好,需加强细节 |
|
||||
|
||||
---
|
||||
|
||||
## 性能问题
|
||||
|
||||
### 已识别的性能问题
|
||||
|
||||
#### 1. 重复模型加载
|
||||
|
||||
**位置**: `inference/web/services/inference.py:316`
|
||||
|
||||
```python
|
||||
# 问题: 每次可视化都重新加载模型
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
|
||||
# 建议: 复用已加载的模型
|
||||
```
|
||||
|
||||
#### 2. 临时文件处理
|
||||
|
||||
**位置**: `shared/storage/base.py:178-203`
|
||||
|
||||
```python
|
||||
# 问题: bytes操作使用临时文件
|
||||
def upload_bytes(self, data: bytes, ...):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(data)
|
||||
temp_path = Path(f.name)
|
||||
# ...
|
||||
|
||||
# 建议: 子类重写为直接上传
|
||||
```
|
||||
|
||||
#### 3. 数据库查询优化
|
||||
|
||||
**位置**: `inference/data/admin_db.py`
|
||||
|
||||
```python
|
||||
# 问题: N+1查询风险
|
||||
for doc in documents:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
# ...
|
||||
|
||||
# 建议: 使用join预加载
|
||||
```
|
||||
|
||||
### 性能评分
|
||||
|
||||
| 类别 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| 响应时间 | 8/10 | 异步处理良好 |
|
||||
| 资源使用 | 7/10 | 有优化空间 |
|
||||
| 可扩展性 | 7/10 | 内存队列限制 |
|
||||
| 并发处理 | 8/10 | 线程池设计良好 |
|
||||
| 总体 | 7.5/10 | 良好,有优化空间 |
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 立即执行 (本周)
|
||||
|
||||
1. **拆分AdminDB**
|
||||
- 创建 `repositories/` 目录
|
||||
- 按领域拆分:TokenRepository, DocumentRepository, TrainingRepository
|
||||
- 估计工时: 2天
|
||||
|
||||
2. **修复安全漏洞**
|
||||
- 添加 `hmac.compare_digest()` 时序攻击防护
|
||||
- 添加文件魔术字节验证
|
||||
- 估计工时: 0.5天
|
||||
|
||||
3. **提升测试覆盖率**
|
||||
- 优先测试 `inference/pipeline/`
|
||||
- 添加API集成测试
|
||||
- 目标: 从28%提升至50%
|
||||
- 估计工时: 3天
|
||||
|
||||
### 短期执行 (本月)
|
||||
|
||||
4. **引入消息队列**
|
||||
- 添加Redis服务
|
||||
- 使用Celery替换内存队列
|
||||
- 估计工时: 3天
|
||||
|
||||
5. **统一配置管理**
|
||||
- 使用 Pydantic Settings
|
||||
- 集中验证逻辑
|
||||
- 估计工时: 1天
|
||||
|
||||
6. **添加缓存层**
|
||||
- Redis缓存热点数据
|
||||
- 缓存文档、模型配置
|
||||
- 估计工时: 2天
|
||||
|
||||
### 长期执行 (本季度)
|
||||
|
||||
7. **数据库读写分离**
|
||||
- 配置主从数据库
|
||||
- 读操作使用从库
|
||||
- 估计工时: 3天
|
||||
|
||||
8. **事件驱动架构**
|
||||
- 引入事件总线
|
||||
- 解耦模块依赖
|
||||
- 估计工时: 5天
|
||||
|
||||
9. **前端优化**
|
||||
- 添加错误边界
|
||||
- 实现真实搜索功能
|
||||
- 添加E2E测试
|
||||
- 估计工时: 3天
|
||||
|
||||
---
|
||||
|
||||
## 总结与评分
|
||||
|
||||
### 各维度评分
|
||||
|
||||
| 维度 | 评分 | 权重 | 加权得分 |
|
||||
|------|------|------|----------|
|
||||
| **代码质量** | 7.5/10 | 20% | 1.5 |
|
||||
| **安全性** | 7.5/10 | 20% | 1.5 |
|
||||
| **可维护性** | 8/10 | 15% | 1.2 |
|
||||
| **测试覆盖** | 5/10 | 15% | 0.75 |
|
||||
| **性能** | 7.5/10 | 15% | 1.125 |
|
||||
| **文档** | 8/10 | 10% | 0.8 |
|
||||
| **架构设计** | 8/10 | 5% | 0.4 |
|
||||
| **总体** | **7.3/10** | 100% | **7.275** |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
|
||||
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
|
||||
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
|
||||
4. **测试是短板**: 28%覆盖率是最大风险点
|
||||
5. **生产就绪**: 经过小幅改进后可以投入生产使用
|
||||
|
||||
### 下一步行动
|
||||
|
||||
| 优先级 | 任务 | 预计工时 | 影响 |
|
||||
|--------|------|----------|------|
|
||||
| 高 | 拆分AdminDB | 2天 | 提升可维护性 |
|
||||
| 高 | 引入Redis队列 | 3天 | 解决任务丢失问题 |
|
||||
| 高 | 提升测试覆盖率 | 5天 | 降低代码风险 |
|
||||
| 中 | 修复安全漏洞 | 0.5天 | 提升安全性 |
|
||||
| 中 | 统一配置管理 | 1天 | 减少配置错误 |
|
||||
| 低 | 前端优化 | 3天 | 提升用户体验 |
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 关键文件清单
|
||||
|
||||
| 文件 | 职责 | 问题 |
|
||||
|------|------|------|
|
||||
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
|
||||
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
|
||||
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
|
||||
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
|
||||
| `shared/shared/config.py` | 共享配置 | 分散管理 |
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
|
||||
- [Celery Documentation](https://docs.celeryproject.org/)
|
||||
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
|
||||
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
|
||||
---
|
||||
|
||||
**报告生成时间**: 2026-02-01
|
||||
**审查工具**: Claude Code + AST-grep + LSP
|
||||
637
COMMERCIALIZATION_ANALYSIS_REPORT.md
Normal file
637
COMMERCIALIZATION_ANALYSIS_REPORT.md
Normal file
@@ -0,0 +1,637 @@
|
||||
# Invoice Master POC v2 - 商业化分析报告
|
||||
|
||||
**报告日期**: 2026-02-01
|
||||
**分析人**: Claude Code
|
||||
**项目**: Invoice Master - 瑞典发票字段自动提取系统
|
||||
**当前状态**: POC阶段,已处理9,738份文档,字段匹配率94.8%
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
1. [执行摘要](#执行摘要)
|
||||
2. [市场分析](#市场分析)
|
||||
3. [商业模式建议](#商业模式建议)
|
||||
4. [技术架构商业化评估](#技术架构商业化评估)
|
||||
5. [商业化路线图](#商业化路线图)
|
||||
6. [风险与挑战](#风险与挑战)
|
||||
7. [成本与定价策略](#成本与定价策略)
|
||||
8. [竞争分析](#竞争分析)
|
||||
9. [改进建议](#改进建议)
|
||||
10. [总结与建议](#总结与建议)
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 项目现状
|
||||
|
||||
Invoice Master是一个基于YOLOv11 + PaddleOCR的瑞典发票字段自动提取系统,具备以下核心能力:
|
||||
|
||||
| 指标 | 数值 | 评估 |
|
||||
|------|------|------|
|
||||
| 已处理文档 | 9,738份 | 数据基础良好 |
|
||||
| 字段匹配率 | 94.8% | 接近商业化标准 |
|
||||
| 模型mAP@0.5 | 93.5% | 业界优秀水平 |
|
||||
| 测试覆盖率 | 28% | 需大幅提升 |
|
||||
| 架构成熟度 | 7.3/10 | 基本就绪 |
|
||||
|
||||
### 商业化可行性评估
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **技术成熟度** | 7.5/10 | 核心算法成熟,需完善工程化 |
|
||||
| **市场需求** | 8/10 | 发票处理是刚需市场 |
|
||||
| **竞争壁垒** | 6/10 | 技术可替代,需构建数据壁垒 |
|
||||
| **商业化就绪度** | 6.5/10 | 需完成产品化和合规准备 |
|
||||
| **总体评估** | **7/10** | **具备商业化潜力,需6-12个月准备** |
|
||||
|
||||
### 关键建议
|
||||
|
||||
1. **短期(3个月)**: 提升测试覆盖率至80%,完成安全加固
|
||||
2. **中期(6个月)**: 推出MVP产品,获取首批付费客户
|
||||
3. **长期(12个月)**: 扩展多语言支持,进入国际市场
|
||||
|
||||
---
|
||||
|
||||
## 市场分析
|
||||
|
||||
### 目标市场
|
||||
|
||||
#### 1.1 市场规模
|
||||
|
||||
**全球发票处理市场**
|
||||
- 市场规模: ~$30B (2024)
|
||||
- 年增长率: 12-15%
|
||||
- 驱动因素: 数字化转型、合规要求、成本节约
|
||||
|
||||
**瑞典/北欧市场**
|
||||
- 中小企业数量: ~100万+
|
||||
- 大型企业: ~2,000家
|
||||
- 年发票处理量: ~5亿张
|
||||
- 市场特点: 数字化程度高,合规要求严格
|
||||
|
||||
#### 1.2 目标客户画像
|
||||
|
||||
| 客户类型 | 规模 | 痛点 | 付费意愿 | 获取难度 |
|
||||
|----------|------|------|----------|----------|
|
||||
| **中小企业** | 10-100人 | 手动录入耗时 | 中 | 低 |
|
||||
| **会计事务所** | 5-50人 | 批量处理需求 | 高 | 中 |
|
||||
| **大型企业** | 500+人 | 系统集成需求 | 高 | 高 |
|
||||
| **SaaS平台** | - | API集成需求 | 中 | 中 |
|
||||
|
||||
### 市场需求验证
|
||||
|
||||
#### 2.1 痛点分析
|
||||
|
||||
**现有解决方案的问题:**
|
||||
1. **传统OCR**: 准确率70-85%,需要大量人工校对
|
||||
2. **人工录入**: 成本高($0.5-2/张),速度慢,易出错
|
||||
3. **现有AI方案**: 价格昂贵,定制化程度低
|
||||
|
||||
**Invoice Master的优势:**
|
||||
- 准确率94.8%,接近人工水平
|
||||
- 支持瑞典特有的字段(OCR参考号、Bankgiro/Plusgiro)
|
||||
- 可定制化训练,适应不同发票格式
|
||||
|
||||
#### 2.2 市场进入策略
|
||||
|
||||
**第一阶段: 瑞典市场验证**
|
||||
- 目标客户: 中型会计事务所
|
||||
- 价值主张: 减少80%人工录入时间
|
||||
- 定价: $0.1-0.2/张 或 $99-299/月
|
||||
|
||||
**第二阶段: 北欧扩展**
|
||||
- 扩展至挪威、丹麦、芬兰
|
||||
- 适配各国发票格式
|
||||
- 建立本地合作伙伴网络
|
||||
|
||||
**第三阶段: 欧洲市场**
|
||||
- 支持多语言(德语、法语、英语)
|
||||
- GDPR合规认证
|
||||
- 与主流ERP系统集成
|
||||
|
||||
---
|
||||
|
||||
## 商业模式建议
|
||||
|
||||
### 3.1 商业模式选项
|
||||
|
||||
#### 选项A: SaaS订阅模式 (推荐)
|
||||
|
||||
**定价结构:**
|
||||
```
|
||||
Starter: $99/月
|
||||
- 500张发票/月
|
||||
- 基础字段提取
|
||||
- 邮件支持
|
||||
|
||||
Professional: $299/月
|
||||
- 2,000张发票/月
|
||||
- 所有字段+自定义字段
|
||||
- API访问
|
||||
- 优先支持
|
||||
|
||||
Enterprise: 定制报价
|
||||
- 无限发票
|
||||
- 私有部署选项
|
||||
- SLA保障
|
||||
- 专属客户经理
|
||||
```
|
||||
|
||||
**优势:**
|
||||
- 可预测的经常性收入
|
||||
- 客户生命周期价值高
|
||||
- 易于扩展
|
||||
|
||||
**劣势:**
|
||||
- 需要持续的产品迭代
|
||||
- 客户获取成本较高
|
||||
|
||||
#### 选项B: 按量付费模式
|
||||
|
||||
**定价:**
|
||||
- 前100张: $0.15/张
|
||||
- 101-1000张: $0.10/张
|
||||
- 1001+张: $0.05/张
|
||||
|
||||
**适用场景:**
|
||||
- 季节性业务
|
||||
- 初创企业
|
||||
- 不确定使用量的客户
|
||||
|
||||
#### 选项C: 授权许可模式
|
||||
|
||||
**定价:**
|
||||
- 年度许可: $10,000-50,000
|
||||
- 按部署规模收费
|
||||
- 包含培训和定制开发
|
||||
|
||||
**适用场景:**
|
||||
- 大型企业
|
||||
- 数据敏感行业
|
||||
- 需要私有部署的客户
|
||||
|
||||
### 3.2 推荐模式: 混合模式
|
||||
|
||||
**核心产品: SaaS订阅**
|
||||
- 面向中小企业和会计事务所
|
||||
- 标准化产品,快速交付
|
||||
|
||||
**增值服务: 定制开发**
|
||||
- 面向大型企业
|
||||
- 私有部署选项
|
||||
- 按项目收费
|
||||
|
||||
**API服务: 按量付费**
|
||||
- 面向SaaS平台和开发者
|
||||
- 开发者友好定价
|
||||
|
||||
### 3.3 收入预测
|
||||
|
||||
**保守估计 (第一年)**
|
||||
| 客户类型 | 客户数 | ARPU | MRR | 年收入 |
|
||||
|----------|--------|------|-----|--------|
|
||||
| Starter | 20 | $99 | $1,980 | $23,760 |
|
||||
| Professional | 10 | $299 | $2,990 | $35,880 |
|
||||
| Enterprise | 2 | $2,000 | $4,000 | $48,000 |
|
||||
| **总计** | **32** | - | **$8,970** | **$107,640** |
|
||||
|
||||
**乐观估计 (第一年)**
|
||||
- 客户数: 100+
|
||||
- 年收入: $300,000-500,000
|
||||
|
||||
---
|
||||
|
||||
## 技术架构商业化评估
|
||||
|
||||
### 4.1 架构优势
|
||||
|
||||
| 优势 | 说明 | 商业化价值 |
|
||||
|------|------|-----------|
|
||||
| **Monorepo结构** | 代码组织清晰 | 降低维护成本 |
|
||||
| **云原生架构** | 支持AWS/Azure | 灵活部署选项 |
|
||||
| **存储抽象层** | 支持多后端 | 满足不同客户需求 |
|
||||
| **模型版本管理** | 可追溯可回滚 | 企业级可靠性 |
|
||||
| **API优先设计** | RESTful API | 易于集成和扩展 |
|
||||
|
||||
### 4.2 商业化就绪度评估
|
||||
|
||||
#### 高优先级改进项
|
||||
|
||||
| 问题 | 影响 | 改进建议 | 工时 |
|
||||
|------|------|----------|------|
|
||||
| **测试覆盖率28%** | 质量风险 | 提升至80%+ | 4周 |
|
||||
| **AdminDB过大** | 维护困难 | 拆分Repository | 2周 |
|
||||
| **内存队列** | 单点故障 | 引入Redis | 2周 |
|
||||
| **安全漏洞** | 合规风险 | 修复时序攻击等 | 1周 |
|
||||
|
||||
#### 中优先级改进项
|
||||
|
||||
| 问题 | 影响 | 改进建议 | 工时 |
|
||||
|------|------|----------|------|
|
||||
| **缺少审计日志** | 合规要求 | 添加完整审计 | 2周 |
|
||||
| **无多租户隔离** | 数据安全 | 实现租户隔离 | 3周 |
|
||||
| **限流器内存存储** | 扩展性 | Redis分布式限流 | 1周 |
|
||||
| **配置分散** | 运维难度 | 统一配置中心 | 1周 |
|
||||
|
||||
### 4.3 技术债务清理计划
|
||||
|
||||
**阶段1: 基础加固 (4周)**
|
||||
- 提升测试覆盖率至60%
|
||||
- 修复安全漏洞
|
||||
- 添加基础监控
|
||||
|
||||
**阶段2: 架构优化 (6周)**
|
||||
- 拆分AdminDB
|
||||
- 引入消息队列
|
||||
- 实现多租户支持
|
||||
|
||||
**阶段3: 企业级功能 (8周)**
|
||||
- 完整审计日志
|
||||
- SSO集成
|
||||
- 高级权限管理
|
||||
|
||||
---
|
||||
|
||||
## 商业化路线图
|
||||
|
||||
### 5.1 时间线规划
|
||||
|
||||
```
|
||||
Month 1-3: 产品化准备
|
||||
├── 技术债务清理
|
||||
├── 安全加固
|
||||
├── 测试覆盖率提升
|
||||
└── 文档完善
|
||||
|
||||
Month 4-6: MVP发布
|
||||
├── 核心功能稳定
|
||||
├── 基础监控告警
|
||||
├── 客户反馈收集
|
||||
└── 定价策略验证
|
||||
|
||||
Month 7-9: 市场扩展
|
||||
├── 销售团队组建
|
||||
├── 合作伙伴网络
|
||||
├── 案例研究制作
|
||||
└── 营销自动化
|
||||
|
||||
Month 10-12: 规模化
|
||||
├── 多语言支持
|
||||
├── 高级功能开发
|
||||
├── 国际市场准备
|
||||
└── 融资准备
|
||||
```
|
||||
|
||||
### 5.2 里程碑
|
||||
|
||||
| 里程碑 | 时间 | 成功标准 |
|
||||
|--------|------|----------|
|
||||
| **技术就绪** | M3 | 测试80%,零高危漏洞 |
|
||||
| **首个付费客户** | M4 | 签约并上线 |
|
||||
| **产品市场契合** | M6 | 10+付费客户,NPS>40 |
|
||||
| **盈亏平衡** | M9 | MRR覆盖运营成本 |
|
||||
| **规模化准备** | M12 | 100+客户,$50K+MRR |
|
||||
|
||||
### 5.3 团队组建建议
|
||||
|
||||
**核心团队 (前6个月)**
|
||||
| 角色 | 人数 | 职责 |
|
||||
|------|------|------|
|
||||
| 技术负责人 | 1 | 架构、技术决策 |
|
||||
| 全栈工程师 | 2 | 产品开发 |
|
||||
| ML工程师 | 1 | 模型优化 |
|
||||
| 产品经理 | 1 | 产品规划 |
|
||||
| 销售/BD | 1 | 客户获取 |
|
||||
|
||||
**扩展团队 (6-12个月)**
|
||||
| 角色 | 人数 | 职责 |
|
||||
|------|------|------|
|
||||
| 客户成功 | 1 | 客户留存 |
|
||||
| 市场营销 | 1 | 品牌建设 |
|
||||
| 技术支持 | 1 | 客户支持 |
|
||||
|
||||
---
|
||||
|
||||
## 风险与挑战
|
||||
|
||||
### 6.1 技术风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **模型准确率下降** | 中 | 高 | 持续训练,A/B测试 |
|
||||
| **系统稳定性** | 中 | 高 | 完善监控,灰度发布 |
|
||||
| **数据安全漏洞** | 低 | 高 | 安全审计,渗透测试 |
|
||||
| **扩展性瓶颈** | 中 | 中 | 架构优化,负载测试 |
|
||||
|
||||
### 6.2 市场风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **竞争加剧** | 高 | 中 | 差异化定位,垂直深耕 |
|
||||
| **价格战** | 中 | 中 | 价值定价,增值服务 |
|
||||
| **客户获取困难** | 中 | 高 | 内容营销,口碑传播 |
|
||||
| **市场教育成本** | 中 | 中 | 免费试用,案例展示 |
|
||||
|
||||
### 6.3 合规风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **GDPR合规** | 高 | 高 | 隐私设计,数据本地化 |
|
||||
| **数据主权** | 中 | 高 | 多区域部署选项 |
|
||||
| **行业认证** | 中 | 中 | ISO27001, SOC2准备 |
|
||||
|
||||
### 6.4 财务风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **现金流紧张** | 中 | 高 | 预付费模式,成本控制 |
|
||||
| **客户流失** | 中 | 中 | 客户成功,年度合同 |
|
||||
| **定价失误** | 中 | 中 | 灵活定价,快速迭代 |
|
||||
|
||||
---
|
||||
|
||||
## 成本与定价策略
|
||||
|
||||
### 7.1 运营成本估算
|
||||
|
||||
**月度运营成本 (AWS)**
|
||||
| 项目 | 成本 | 说明 |
|
||||
|------|------|------|
|
||||
| 计算 (ECS Fargate) | $150 | 推理服务 |
|
||||
| 数据库 (RDS) | $50 | PostgreSQL |
|
||||
| 存储 (S3) | $20 | 文档和模型 |
|
||||
| 训练 (SageMaker) | $100 | 按需训练 |
|
||||
| 监控/日志 | $30 | CloudWatch等 |
|
||||
| **小计** | **$350** | **基础运营成本** |
|
||||
|
||||
**月度运营成本 (Azure)**
|
||||
| 项目 | 成本 | 说明 |
|
||||
|------|------|------|
|
||||
| 计算 (Container Apps) | $180 | 推理服务 |
|
||||
| 数据库 | $60 | PostgreSQL |
|
||||
| 存储 | $25 | Blob Storage |
|
||||
| 训练 | $120 | Azure ML |
|
||||
| **小计** | **$385** | **基础运营成本** |
|
||||
|
||||
**人力成本 (月度)**
|
||||
| 阶段 | 人数 | 成本 |
|
||||
|------|------|------|
|
||||
| 启动期 (1-3月) | 3 | $15,000 |
|
||||
| 成长期 (4-9月) | 5 | $25,000 |
|
||||
| 规模化 (10-12月) | 7 | $35,000 |
|
||||
|
||||
### 7.2 定价策略
|
||||
|
||||
**成本加成定价**
|
||||
- 基础成本: $350/月
|
||||
- 目标毛利率: 70%
|
||||
- 最低收费: $1,000/月
|
||||
|
||||
**价值定价**
|
||||
- 客户节省成本: $2-5/张 (人工录入)
|
||||
- 收费: $0.1-0.2/张
|
||||
- 客户ROI: 10-50x
|
||||
|
||||
**竞争定价**
|
||||
- 竞争对手: $0.2-0.5/张
|
||||
- 我们的定价: $0.1-0.15/张
|
||||
- 策略: 高性价比切入
|
||||
|
||||
### 7.3 盈亏平衡分析
|
||||
|
||||
**固定成本: $25,000/月** (人力+基础设施)
|
||||
|
||||
**盈亏平衡点:**
|
||||
- 按订阅模式: 85个Professional客户 或 250个Starter客户
|
||||
- 按量付费: 250,000张发票/月
|
||||
|
||||
**目标 (12个月):**
|
||||
- MRR: $50,000
|
||||
- 客户数: 150
|
||||
- 毛利率: 75%
|
||||
|
||||
---
|
||||
|
||||
## 竞争分析
|
||||
|
||||
### 8.1 竞争对手
|
||||
|
||||
#### 直接竞争对手
|
||||
|
||||
| 公司 | 产品 | 优势 | 劣势 | 定价 |
|
||||
|------|------|------|------|------|
|
||||
| **Rossum** | AI发票处理 | 技术成熟,欧洲市场强 | 价格高 | $0.3-0.5/张 |
|
||||
| **Hypatos** | 文档AI | 德国市场深耕 | 定制化弱 | 定制报价 |
|
||||
| **Klippa** | 文档解析 | API友好 | 准确率一般 | $0.1-0.2/张 |
|
||||
| **Nanonets** | 工作流自动化 | 易用性好 | 发票专业性弱 | $0.05-0.15/张 |
|
||||
|
||||
#### 间接竞争对手
|
||||
|
||||
| 类型 | 代表 | 威胁程度 |
|
||||
|------|------|----------|
|
||||
| **传统OCR** | ABBYY, Tesseract | 中 |
|
||||
| **ERP内置** | SAP, Oracle | 中 |
|
||||
| **会计软件** | Visma, Fortnox | 高 |
|
||||
|
||||
### 8.2 竞争优势
|
||||
|
||||
**短期优势 (6-12个月)**
|
||||
1. **瑞典市场专注**: 本地化字段支持
|
||||
2. **价格优势**: 比Rossum便宜50%+
|
||||
3. **定制化**: 可训练专属模型
|
||||
|
||||
**长期优势 (1-3年)**
|
||||
1. **数据壁垒**: 训练数据积累
|
||||
2. **行业深度**: 垂直行业解决方案
|
||||
3. **生态集成**: 与主流ERP深度集成
|
||||
|
||||
### 8.3 竞争策略
|
||||
|
||||
**差异化定位**
|
||||
- 不做通用文档处理,专注发票领域
|
||||
- 不做全球市场,先做透北欧
|
||||
- 不做低价竞争,做高性价比
|
||||
|
||||
**护城河构建**
|
||||
1. **数据壁垒**: 客户发票数据训练
|
||||
2. **转换成本**: 系统集成和工作流
|
||||
3. **网络效应**: 行业模板共享
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 9.1 产品改进
|
||||
|
||||
#### 高优先级
|
||||
|
||||
| 改进项 | 说明 | 商业价值 | 工时 |
|
||||
|--------|------|----------|------|
|
||||
| **多语言支持** | 英语、德语、法语 | 扩大市场 | 4周 |
|
||||
| **批量处理API** | 支持千级批量 | 大客户必需 | 2周 |
|
||||
| **实时处理** | <3秒响应 | 用户体验 | 2周 |
|
||||
| **置信度阈值** | 用户可配置 | 灵活性 | 1周 |
|
||||
|
||||
#### 中优先级
|
||||
|
||||
| 改进项 | 说明 | 商业价值 | 工时 |
|
||||
|--------|------|----------|------|
|
||||
| **移动端适配** | 手机拍照上传 | 便利性 | 3周 |
|
||||
| **PDF预览** | 在线查看和标注 | 用户体验 | 2周 |
|
||||
| **导出格式** | Excel, JSON, XML | 集成便利 | 1周 |
|
||||
| **Webhook** | 事件通知 | 自动化 | 1周 |
|
||||
|
||||
### 9.2 技术改进
|
||||
|
||||
#### 架构优化
|
||||
|
||||
```
|
||||
当前架构问题:
|
||||
├── 内存队列 → 改为Redis队列
|
||||
├── 单体DB → 读写分离
|
||||
├── 同步处理 → 异步优先
|
||||
└── 单区域 → 多区域部署
|
||||
```
|
||||
|
||||
#### 性能优化
|
||||
|
||||
| 优化项 | 当前 | 目标 | 方法 |
|
||||
|--------|------|------|------|
|
||||
| 推理延迟 | 500ms | 200ms | 模型量化 |
|
||||
| 并发处理 | 10 QPS | 100 QPS | 水平扩展 |
|
||||
| 系统可用性 | 99% | 99.9% | 冗余设计 |
|
||||
|
||||
### 9.3 运营改进
|
||||
|
||||
#### 客户成功
|
||||
|
||||
- 入职流程: 30分钟完成首次提取
|
||||
- 培训材料: 视频教程+文档
|
||||
- 支持响应: <4小时响应时间
|
||||
- 客户健康度: 自动监控和预警
|
||||
|
||||
#### 销售流程
|
||||
|
||||
1. **线索获取**: 内容营销+SEO
|
||||
2. **试用转化**: 14天免费试用
|
||||
3. **付费转化**: 客户成功跟进
|
||||
4. **扩展销售**: 功能升级推荐
|
||||
|
||||
---
|
||||
|
||||
## 总结与建议
|
||||
|
||||
### 10.1 商业化可行性结论
|
||||
|
||||
**总体评估: 可行,需6-12个月准备**
|
||||
|
||||
Invoice Master具备商业化的技术基础和市场机会,但需要完成以下关键准备:
|
||||
|
||||
1. **技术债务清理**: 测试覆盖率、安全加固
|
||||
2. **产品化完善**: 多租户、审计日志、监控
|
||||
3. **市场验证**: 获取首批付费客户
|
||||
4. **团队组建**: 销售和客户成功团队
|
||||
|
||||
### 10.2 关键成功因素
|
||||
|
||||
| 因素 | 重要性 | 当前状态 | 行动计划 |
|
||||
|------|--------|----------|----------|
|
||||
| **技术稳定性** | 高 | 中 | 测试+监控 |
|
||||
| **客户获取** | 高 | 低 | 内容营销 |
|
||||
| **产品市场契合** | 高 | 未验证 | 快速迭代 |
|
||||
| **团队能力** | 高 | 中 | 招聘培训 |
|
||||
| **资金储备** | 中 | 未知 | 融资准备 |
|
||||
|
||||
### 10.3 行动计划
|
||||
|
||||
#### 立即执行 (本月)
|
||||
|
||||
- [ ] 制定详细的技术债务清理计划
|
||||
- [ ] 启动安全审计和漏洞修复
|
||||
- [ ] 设计多租户架构方案
|
||||
- [ ] 准备融资材料或预算规划
|
||||
|
||||
#### 短期目标 (3个月)
|
||||
|
||||
- [ ] 测试覆盖率提升至80%
|
||||
- [ ] 完成安全加固和合规准备
|
||||
- [ ] 发布Beta版本给5-10个试用客户
|
||||
- [ ] 确定最终定价策略
|
||||
|
||||
#### 中期目标 (6个月)
|
||||
|
||||
- [ ] 获得10+付费客户
|
||||
- [ ] MRR达到$10,000
|
||||
- [ ] 完成产品市场契合验证
|
||||
- [ ] 组建完整团队
|
||||
|
||||
#### 长期目标 (12个月)
|
||||
|
||||
- [ ] 100+付费客户
|
||||
- [ ] MRR达到$50,000
|
||||
- [ ] 扩展到2-3个新市场
|
||||
- [ ] 完成A轮融资或实现盈利
|
||||
|
||||
### 10.4 最终建议
|
||||
|
||||
**建议: 继续推进商业化,但需谨慎执行**
|
||||
|
||||
Invoice Master是一个技术扎实、市场机会明确的项目。当前94.8%的准确率已经接近商业化标准,但需要投入资源完成工程化和产品化。
|
||||
|
||||
**关键决策点:**
|
||||
1. **是否投入商业化**: 是,但分阶段投入
|
||||
2. **目标市场**: 先做透瑞典,再扩展北欧
|
||||
3. **商业模式**: SaaS订阅为主,定制为辅
|
||||
4. **融资需求**: 建议准备$200K-500K种子资金
|
||||
|
||||
**成功概率评估: 65%**
|
||||
- 技术可行性: 80%
|
||||
- 市场接受度: 70%
|
||||
- 执行能力: 60%
|
||||
- 竞争环境: 50%
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### A. 关键指标追踪
|
||||
|
||||
| 指标 | 当前 | 3个月目标 | 6个月目标 | 12个月目标 |
|
||||
|------|------|-----------|-----------|------------|
|
||||
| 测试覆盖率 | 28% | 60% | 80% | 85% |
|
||||
| 系统可用性 | - | 99.5% | 99.9% | 99.95% |
|
||||
| 客户数 | 0 | 5 | 20 | 150 |
|
||||
| MRR | $0 | $500 | $10,000 | $50,000 |
|
||||
| NPS | - | - | >40 | >50 |
|
||||
| 客户流失率 | - | - | <5%/月 | <3%/月 |
|
||||
|
||||
### B. 资源需求
|
||||
|
||||
**资金需求**
|
||||
| 阶段 | 时间 | 金额 | 用途 |
|
||||
|------|------|------|------|
|
||||
| 种子期 | 0-6月 | $100K | 团队+基础设施 |
|
||||
| 成长期 | 6-12月 | $300K | 市场+团队扩展 |
|
||||
| A轮 | 12-18月 | $1M+ | 规模化+国际 |
|
||||
|
||||
**人力需求**
|
||||
| 阶段 | 团队规模 | 关键角色 |
|
||||
|------|----------|----------|
|
||||
| 启动 | 3-4人 | 技术+产品+销售 |
|
||||
| 验证 | 5-6人 | +客户成功 |
|
||||
| 增长 | 8-10人 | +市场+技术支持 |
|
||||
|
||||
### C. 参考资源
|
||||
|
||||
- [SaaS Metrics Guide](https://www.saasmetrics.co/)
|
||||
- [GDPR Compliance Checklist](https://gdpr.eu/checklist/)
|
||||
- [B2B SaaS Pricing Guide](https://www.priceintelligently.com/)
|
||||
- [Nordic Startup Ecosystem](https://www.nordicstartupnews.com/)
|
||||
|
||||
---
|
||||
|
||||
**报告完成日期**: 2026-02-01
|
||||
**下次评审日期**: 2026-03-01
|
||||
**版本**: v1.0
|
||||
419
PROJECT_REVIEW.md
Normal file
419
PROJECT_REVIEW.md
Normal file
@@ -0,0 +1,419 @@
|
||||
# Invoice Master POC v2 - 项目审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
|
||||
|
||||
---
|
||||
|
||||
## 项目概述
|
||||
|
||||
**Invoice Master POC v2** - 基于 YOLOv11 + PaddleOCR 的瑞典发票字段自动提取系统
|
||||
|
||||
### 核心功能
|
||||
- **自动标注**: 利用 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
||||
- **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
|
||||
- **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
|
||||
- **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
|
||||
|
||||
### 架构设计
|
||||
采用 **Monorepo + 三包分离** 架构:
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
|
||||
├── training/ # 训练服务 (GPU, 按需启动)
|
||||
└── inference/ # 推理服务 (常驻运行)
|
||||
frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||
```
|
||||
|
||||
### 性能指标
|
||||
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| **已标注文档** | 9,738 (9,709 成功) |
|
||||
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
||||
| **测试** | 1,601 passed |
|
||||
| **测试覆盖率** | 28% |
|
||||
| **模型 mAP@0.5** | 93.5% |
|
||||
|
||||
---
|
||||
|
||||
## 安全性审查
|
||||
|
||||
### 检查清单
|
||||
|
||||
| 检查项 | 状态 | 说明 | 文件位置 |
|
||||
|--------|------|------|----------|
|
||||
| **Secrets 管理** | ✅ 良好 | 使用 `.env` 文件,`DB_PASSWORD` 无默认值 | `packages/shared/shared/config.py:46` |
|
||||
| **SQL 注入防护** | ✅ 良好 | 使用参数化查询 | 全项目 |
|
||||
| **认证机制** | ✅ 良好 | Admin token 验证 + 数据库持久化 | `packages/inference/inference/web/core/auth.py` |
|
||||
| **输入验证** | ⚠️ 需改进 | 部分端点缺少文件类型/大小验证 | Web API 端点 |
|
||||
| **路径遍历防护** | ⚠️ 需检查 | 需确认文件上传路径验证 | 文件上传处理 |
|
||||
| **CORS 配置** | ❓ 待查 | 需确认生产环境配置 | FastAPI 中间件 |
|
||||
| **Rate Limiting** | ✅ 良好 | 已实现核心限流器 | `packages/inference/inference/web/core/rate_limiter.py` |
|
||||
| **错误处理** | ✅ 良好 | Web 层 356 处异常处理 | 全项目 |
|
||||
|
||||
### 详细发现
|
||||
|
||||
#### ✅ 安全实践良好的方面
|
||||
|
||||
1. **环境变量管理**
|
||||
- 使用 `python-dotenv` 加载 `.env` 文件
|
||||
- 数据库密码没有默认值,强制要求设置
|
||||
- 验证逻辑在配置加载时执行
|
||||
|
||||
2. **认证实现**
|
||||
- Token 存储在 PostgreSQL 数据库
|
||||
- 支持 Token 过期检查
|
||||
- 记录最后使用时间
|
||||
|
||||
3. **存储抽象层**
|
||||
- 支持 Local/Azure/S3 多后端
|
||||
- 通过环境变量配置,无硬编码凭证
|
||||
|
||||
#### ⚠️ 需要改进的安全问题
|
||||
|
||||
1. **时序攻击防护**
|
||||
- **位置**: `packages/inference/inference/web/core/auth.py:46`
|
||||
- **问题**: Token 验证使用普通字符串比较
|
||||
- **建议**: 使用 `hmac.compare_digest()` 进行 constant-time 比较
|
||||
- **风险等级**: 中
|
||||
|
||||
2. **文件上传验证**
|
||||
- **位置**: Web API 文件上传端点
|
||||
- **问题**: 需确认是否验证文件魔数 (magic bytes)
|
||||
- **建议**: 添加 PDF 文件签名验证 (`%PDF`)
|
||||
- **风险等级**: 中
|
||||
|
||||
3. **路径遍历风险**
|
||||
- **位置**: 文件下载/访问端点
|
||||
- **问题**: 需确认文件名是否经过净化处理
|
||||
- **建议**: 使用 `pathlib.Path.name` 提取文件名,验证路径范围
|
||||
- **风险等级**: 中
|
||||
|
||||
4. **CORS 配置**
|
||||
- **位置**: FastAPI 中间件配置
|
||||
- **问题**: 需确认生产环境是否允许所有来源
|
||||
- **建议**: 生产环境明确指定允许的 origins
|
||||
- **风险等级**: 低
|
||||
|
||||
---
|
||||
|
||||
## 代码质量审查
|
||||
|
||||
### 代码风格与规范
|
||||
|
||||
| 检查项 | 状态 | 说明 |
|
||||
|--------|------|------|
|
||||
| **类型注解** | ✅ 优秀 | 广泛使用 Type hints,覆盖率 > 90% |
|
||||
| **命名规范** | ✅ 良好 | 遵循 PEP 8,snake_case 命名 |
|
||||
| **文档字符串** | ✅ 良好 | 主要模块和函数都有文档 |
|
||||
| **异常处理** | ✅ 良好 | Web 层 356 处异常处理 |
|
||||
| **代码组织** | ✅ 优秀 | 模块化结构清晰,职责分离明确 |
|
||||
| **文件大小** | ⚠️ 需关注 | 部分文件超过 800 行 |
|
||||
|
||||
### 架构设计评估
|
||||
|
||||
#### 优秀的设计决策
|
||||
|
||||
1. **Monorepo 结构**
|
||||
- 清晰的包边界 (shared/training/inference)
|
||||
- 避免循环依赖
|
||||
- 便于独立部署
|
||||
|
||||
2. **存储抽象层**
|
||||
- 统一的 `StorageBackend` 接口
|
||||
- 支持本地/Azure/S3 无缝切换
|
||||
- 预签名 URL 支持
|
||||
|
||||
3. **配置管理**
|
||||
- 使用 dataclass 定义配置
|
||||
- 环境变量 + 配置文件混合
|
||||
- 类型安全
|
||||
|
||||
4. **数据库设计**
|
||||
- 合理的表结构
|
||||
- 状态机设计 (pending → running → completed)
|
||||
- 外键约束完整
|
||||
|
||||
#### 需要改进的方面
|
||||
|
||||
1. **测试覆盖率偏低**
|
||||
- 当前: 28%
|
||||
- 目标: 60%+
|
||||
- 优先测试核心业务逻辑
|
||||
|
||||
2. **部分文件过大**
|
||||
- 建议拆分为多个小文件
|
||||
- 单一职责原则
|
||||
|
||||
3. **缺少集成测试**
|
||||
- 建议添加端到端测试
|
||||
- API 契约测试
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践遵循情况
|
||||
|
||||
### 已遵循的最佳实践
|
||||
|
||||
| 实践 | 实现状态 | 说明 |
|
||||
|------|----------|------|
|
||||
| **环境变量配置** | ✅ | 所有配置通过环境变量 |
|
||||
| **数据库连接池** | ✅ | 使用 SQLModel + psycopg2 |
|
||||
| **异步处理** | ✅ | FastAPI + async/await |
|
||||
| **存储抽象层** | ✅ | 支持 Local/Azure/S3 |
|
||||
| **Docker 容器化** | ✅ | 每个服务独立 Dockerfile |
|
||||
| **数据增强** | ✅ | 12 种增强策略 |
|
||||
| **模型版本管理** | ✅ | model_versions 表 |
|
||||
| **限流保护** | ✅ | Rate limiter 实现 |
|
||||
| **日志记录** | ✅ | 结构化日志 |
|
||||
| **类型安全** | ✅ | 全面 Type hints |
|
||||
|
||||
### 技术栈评估
|
||||
|
||||
| 组件 | 技术选择 | 评估 |
|
||||
|------|----------|------|
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
|
||||
| **OCR 引擎** | PaddleOCR v5 | ✅ 支持瑞典语 |
|
||||
| **PDF 处理** | PyMuPDF (fitz) | ✅ 功能强大 |
|
||||
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 |
|
||||
| **Web 框架** | FastAPI + Uvicorn | ✅ 高性能 |
|
||||
| **前端** | React + TypeScript + Vite | ✅ 现代栈 |
|
||||
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 关键文件详细分析
|
||||
|
||||
### 1. 配置文件
|
||||
|
||||
#### `packages/shared/shared/config.py`
|
||||
- **安全性**: ✅ 密码从环境变量读取,无默认值
|
||||
- **代码质量**: ✅ 清晰的配置结构
|
||||
- **建议**: 考虑使用 Pydantic Settings 进行验证
|
||||
|
||||
#### `packages/inference/inference/web/config.py`
|
||||
- **安全性**: ✅ 无敏感信息硬编码
|
||||
- **代码质量**: ✅ 使用 frozen dataclass
|
||||
- **建议**: 添加配置验证逻辑
|
||||
|
||||
### 2. 认证模块
|
||||
|
||||
#### `packages/inference/inference/web/core/auth.py`
|
||||
- **安全性**: ⚠️ 需添加 constant-time 比较
|
||||
- **代码质量**: ✅ 依赖注入模式
|
||||
- **建议**:
|
||||
```python
|
||||
import hmac
|
||||
if not hmac.compare_digest(api_key, settings.api_key):
|
||||
raise HTTPException(403, "Invalid API key")
|
||||
```
|
||||
|
||||
### 3. 限流器
|
||||
|
||||
#### `packages/inference/inference/web/core/rate_limiter.py`
|
||||
- **安全性**: ✅ 内存限流实现
|
||||
- **代码质量**: ✅ 清晰的接口设计
|
||||
- **建议**: 生产环境考虑 Redis 分布式限流
|
||||
|
||||
### 4. 存储层
|
||||
|
||||
#### `packages/shared/shared/storage/`
|
||||
- **安全性**: ✅ 无凭证硬编码
|
||||
- **代码质量**: ✅ 抽象接口设计
|
||||
- **建议**: 添加文件类型验证
|
||||
|
||||
---
|
||||
|
||||
## 性能与可扩展性
|
||||
|
||||
### 当前性能
|
||||
|
||||
| 指标 | 数值 | 评估 |
|
||||
|------|------|------|
|
||||
| **字段匹配率** | 94.8% | ✅ 优秀 |
|
||||
| **模型 mAP@0.5** | 93.5% | ✅ 优秀 |
|
||||
| **测试执行时间** | - | 待测量 |
|
||||
| **API 响应时间** | - | 待测量 |
|
||||
|
||||
### 可扩展性评估
|
||||
|
||||
| 方面 | 评估 | 说明 |
|
||||
|------|------|------|
|
||||
| **水平扩展** | ✅ 良好 | 无状态服务设计 |
|
||||
| **垂直扩展** | ✅ 良好 | 支持 GPU 加速 |
|
||||
| **数据库扩展** | ⚠️ 需关注 | 单 PostgreSQL 实例 |
|
||||
| **存储扩展** | ✅ 良好 | 云存储抽象层 |
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
### 高风险项
|
||||
|
||||
1. **测试覆盖率低 (28%)**
|
||||
- **影响**: 代码变更风险高
|
||||
- **缓解**: 制定测试计划,优先覆盖核心逻辑
|
||||
|
||||
2. **文件上传安全**
|
||||
- **影响**: 潜在的路径遍历和恶意文件上传
|
||||
- **缓解**: 添加文件类型验证和路径净化
|
||||
|
||||
### 中风险项
|
||||
|
||||
1. **认证时序攻击**
|
||||
- **影响**: Token 可能被暴力破解
|
||||
- **缓解**: 使用 constant-time 比较
|
||||
|
||||
2. **CORS 配置**
|
||||
- **影响**: CSRF 攻击风险
|
||||
- **缓解**: 生产环境限制 origins
|
||||
|
||||
### 低风险项
|
||||
|
||||
1. **依赖更新**
|
||||
- **影响**: 潜在的安全漏洞
|
||||
- **缓解**: 定期运行 `pip-audit`
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 立即执行 (高优先级)
|
||||
|
||||
1. **提升测试覆盖率**
|
||||
```bash
|
||||
# 目标: 60%+
|
||||
pytest tests/ --cov=packages --cov-report=html
|
||||
```
|
||||
- 优先测试 `inference/pipeline/`
|
||||
- 添加 API 集成测试
|
||||
- 添加存储层测试
|
||||
|
||||
2. **加强文件上传安全**
|
||||
```python
|
||||
# 添加文件类型验证
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# 验证 PDF 魔数
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
```
|
||||
|
||||
3. **修复时序攻击漏洞**
|
||||
```python
|
||||
import hmac
|
||||
|
||||
def verify_token(token: str, expected: str) -> bool:
|
||||
return hmac.compare_digest(token, expected)
|
||||
```
|
||||
|
||||
### 短期执行 (中优先级)
|
||||
|
||||
4. **添加路径遍历防护**
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
def get_safe_path(filename: str, base_dir: Path) -> Path:
|
||||
safe_name = Path(filename).name
|
||||
full_path = (base_dir / safe_name).resolve()
|
||||
if not full_path.is_relative_to(base_dir):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
return full_path
|
||||
```
|
||||
|
||||
5. **配置 CORS 白名单**
|
||||
```python
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
```
|
||||
|
||||
6. **添加安全测试**
|
||||
```python
|
||||
def test_sql_injection_prevented(client):
|
||||
response = client.get("/api/v1/documents?id='; DROP TABLE;")
|
||||
assert response.status_code in (400, 422)
|
||||
|
||||
def test_path_traversal_prevented(client):
|
||||
response = client.get("/api/v1/results/../../etc/passwd")
|
||||
assert response.status_code == 400
|
||||
```
|
||||
|
||||
### 长期执行 (低优先级)
|
||||
|
||||
7. **依赖安全审计**
|
||||
```bash
|
||||
pip install pip-audit
|
||||
pip-audit --desc --format=json > security-audit.json
|
||||
```
|
||||
|
||||
8. **代码质量工具**
|
||||
```bash
|
||||
# 添加 pre-commit hooks
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
9. **性能监控**
|
||||
- 添加 APM 工具 (如 Datadog, New Relic)
|
||||
- 设置性能基准测试
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 总体评分
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **安全性** | 8/10 | 基础安全良好,需加强输入验证和认证 |
|
||||
| **代码质量** | 8/10 | 结构清晰,类型注解完善,部分文件过大 |
|
||||
| **可维护性** | 9/10 | 模块化设计,文档详尽,架构合理 |
|
||||
| **测试覆盖** | 5/10 | 需大幅提升至 60%+ |
|
||||
| **性能** | 9/10 | 94.8% 匹配率,93.5% mAP |
|
||||
| **总体** | **8.2/10** | 优秀的项目,需关注测试和安全细节 |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
|
||||
2. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
|
||||
3. **代码质量高**: 类型注解完善,文档详尽,结构清晰
|
||||
4. **测试是短板**: 28% 覆盖率是最大风险点
|
||||
5. **生产就绪**: 经过小幅改进后可以投入生产使用
|
||||
|
||||
### 下一步行动
|
||||
|
||||
1. 🔴 **立即**: 提升测试覆盖率至 60%+
|
||||
2. 🟡 **本周**: 修复时序攻击漏洞,加强文件上传验证
|
||||
3. 🟡 **本月**: 添加路径遍历防护,配置 CORS 白名单
|
||||
4. 🟢 **季度**: 建立安全审计流程,添加性能监控
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 审查工具
|
||||
|
||||
- Claude Code Security Review Skill
|
||||
- Claude Code Coding Standards Skill
|
||||
- grep / find / wc
|
||||
|
||||
### 相关文件
|
||||
|
||||
- `packages/shared/shared/config.py`
|
||||
- `packages/inference/inference/web/config.py`
|
||||
- `packages/inference/inference/web/core/auth.py`
|
||||
- `packages/inference/inference/web/core/rate_limiter.py`
|
||||
- `packages/shared/shared/storage/`
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
|
||||
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
|
||||
- [pip-audit](https://pypi.org/project/pip-audit/)
|
||||
798
README.md
798
README.md
@@ -7,8 +7,29 @@
|
||||
本项目实现了一个完整的发票字段自动提取流程:
|
||||
|
||||
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
||||
2. **模型训练**: 使用 YOLOv11 训练字段检测模型
|
||||
3. **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
|
||||
2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
|
||||
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
|
||||
4. **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
|
||||
|
||||
### 架构
|
||||
|
||||
项目采用 **monorepo + 三包分离** 架构,训练和推理可独立部署:
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
|
||||
├── training/ # 训练服务 (GPU, 按需启动)
|
||||
└── inference/ # 推理服务 (常驻运行)
|
||||
frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||
```
|
||||
|
||||
| 服务 | 部署目标 | GPU | 生命周期 |
|
||||
|------|---------|-----|---------|
|
||||
| **Frontend** | Vercel / Nginx | 否 | 常驻 |
|
||||
| **Inference** | Azure App Service / AWS | 可选 | 常驻 7x24 |
|
||||
| **Training** | Azure ACI / AWS ECS | 必需 | 按需启动/销毁 |
|
||||
|
||||
两个服务通过共享 PostgreSQL 数据库通信。推理服务通过 API 触发训练任务,训练服务从数据库拾取任务执行。
|
||||
|
||||
### 当前进度
|
||||
|
||||
@@ -16,6 +37,9 @@
|
||||
|------|------|
|
||||
| **已标注文档** | 9,738 (9,709 成功) |
|
||||
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
||||
| **测试** | 1,601 passed |
|
||||
| **测试覆盖率** | 28% |
|
||||
| **模型 mAP@0.5** | 93.5% |
|
||||
|
||||
**各字段匹配率:**
|
||||
|
||||
@@ -42,34 +66,10 @@
|
||||
|------|------|
|
||||
| **WSL** | WSL 2 + Ubuntu 22.04 |
|
||||
| **Conda** | Miniconda 或 Anaconda |
|
||||
| **Python** | 3.10+ (通过 Conda 管理) |
|
||||
| **Python** | 3.11+ (通过 Conda 管理) |
|
||||
| **GPU** | NVIDIA GPU + CUDA 12.x (强烈推荐) |
|
||||
| **数据库** | PostgreSQL (存储标注结果) |
|
||||
|
||||
## 功能特点
|
||||
|
||||
- **双模式 PDF 处理**: 支持文本层 PDF 和扫描图 PDF
|
||||
- **自动标注**: 利用已有 CSV 结构化数据自动生成 YOLO 训练数据
|
||||
- **多策略字段匹配**: 精确匹配、子串匹配、规范化匹配
|
||||
- **数据库存储**: 标注结果存储在 PostgreSQL,支持增量处理和断点续传
|
||||
- **YOLO 检测**: 使用 YOLOv11 检测发票字段区域
|
||||
- **OCR 识别**: 使用 PaddleOCR v5 提取检测区域的文本
|
||||
- **Web 应用**: 提供 REST API 和可视化界面
|
||||
- **增量训练**: 支持在已训练模型基础上继续训练
|
||||
|
||||
## 支持的字段
|
||||
|
||||
| 类别 ID | 字段名 | 说明 |
|
||||
|---------|--------|------|
|
||||
| 0 | invoice_number | 发票号码 |
|
||||
| 1 | invoice_date | 发票日期 |
|
||||
| 2 | invoice_due_date | 到期日期 |
|
||||
| 3 | ocr_number | OCR 参考号 (瑞典支付系统) |
|
||||
| 4 | bankgiro | Bankgiro 号码 |
|
||||
| 5 | plusgiro | Plusgiro 号码 |
|
||||
| 6 | amount | 金额 |
|
||||
| 7 | supplier_organisation_number | 供应商组织号 |
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
@@ -83,370 +83,458 @@ conda activate invoice-py311
|
||||
# 3. 进入项目目录
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
|
||||
# 4. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 5. 安装 Web 依赖
|
||||
pip install uvicorn fastapi python-multipart pydantic
|
||||
# 4. 安装三个包 (editable mode)
|
||||
pip install -e packages/shared
|
||||
pip install -e packages/training
|
||||
pip install -e packages/inference
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 准备数据
|
||||
|
||||
```
|
||||
~/invoice-data/
|
||||
├── raw_pdfs/
|
||||
│ ├── {DocumentId}.pdf
|
||||
│ └── ...
|
||||
├── structured_data/
|
||||
│ └── document_export_YYYYMMDD.csv
|
||||
└── dataset/
|
||||
└── temp/ (渲染的图片)
|
||||
```
|
||||
|
||||
CSV 格式:
|
||||
```csv
|
||||
DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
|
||||
3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
|
||||
```
|
||||
|
||||
### 2. 自动标注
|
||||
|
||||
```bash
|
||||
# 使用双池模式 (CPU + GPU)
|
||||
python -m src.cli.autolabel \
|
||||
--dual-pool \
|
||||
--cpu-workers 3 \
|
||||
--gpu-workers 1
|
||||
|
||||
# 单线程模式
|
||||
python -m src.cli.autolabel --workers 4
|
||||
```
|
||||
|
||||
### 3. 训练模型
|
||||
|
||||
```bash
|
||||
# 从预训练模型开始训练
|
||||
python -m src.cli.train \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--name invoice_yolo11n_full \
|
||||
--dpi 150
|
||||
```
|
||||
|
||||
### 4. 增量训练
|
||||
|
||||
当添加新数据后,可以在已训练模型基础上继续训练:
|
||||
|
||||
```bash
|
||||
# 从已训练的 best.pt 继续训练
|
||||
python -m src.cli.train \
|
||||
--model runs/train/invoice_yolo11n_full/weights/best.pt \
|
||||
--epochs 30 \
|
||||
--batch 16 \
|
||||
--name invoice_yolo11n_v2 \
|
||||
--dpi 150
|
||||
```
|
||||
|
||||
**增量训练建议**:
|
||||
|
||||
| 场景 | 建议 |
|
||||
|------|------|
|
||||
| 添加少量新数据 (<20%) | 继续训练 10-30 epochs |
|
||||
| 添加大量新数据 (>50%) | 继续训练 50-100 epochs |
|
||||
| 修正大量标注错误 | 从头训练 |
|
||||
| 添加新的字段类型 | 从头训练 |
|
||||
|
||||
### 5. 推理
|
||||
|
||||
```bash
|
||||
# 命令行推理
|
||||
python -m src.cli.infer \
|
||||
--model runs/train/invoice_yolo11n_full/weights/best.pt \
|
||||
--input path/to/invoice.pdf \
|
||||
--output result.json \
|
||||
--gpu
|
||||
```
|
||||
|
||||
### 6. Web 应用
|
||||
|
||||
```bash
|
||||
# 启动 Web 服务器
|
||||
python run_server.py --port 8000
|
||||
|
||||
# 开发模式 (自动重载)
|
||||
python run_server.py --debug --reload
|
||||
|
||||
# 禁用 GPU
|
||||
python run_server.py --no-gpu
|
||||
```
|
||||
|
||||
访问 **http://localhost:8000** 使用 Web 界面。
|
||||
|
||||
#### Web API 端点
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| GET | `/` | Web UI 界面 |
|
||||
| GET | `/api/v1/health` | 健康检查 |
|
||||
| POST | `/api/v1/infer` | 上传文件并推理 |
|
||||
| GET | `/api/v1/results/{filename}` | 获取可视化图片 |
|
||||
|
||||
## 训练配置
|
||||
|
||||
### YOLO 训练参数
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型 (默认: yolo11n.pt)
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数 (用于测试)
|
||||
--device 设备 (0=GPU, cpu)
|
||||
```
|
||||
|
||||
### 训练最佳实践
|
||||
|
||||
1. **禁用翻转增强** (文本检测):
|
||||
```python
|
||||
fliplr=0.0, flipud=0.0
|
||||
```
|
||||
|
||||
2. **使用 Early Stopping**:
|
||||
```python
|
||||
patience=20
|
||||
```
|
||||
|
||||
3. **启用 AMP** (混合精度训练):
|
||||
```python
|
||||
amp=True
|
||||
```
|
||||
|
||||
4. **保存检查点**:
|
||||
```python
|
||||
save_period=10
|
||||
```
|
||||
|
||||
### 训练结果示例
|
||||
|
||||
使用约 10,000 张训练图片,100 epochs 后的结果:
|
||||
|
||||
| 指标 | 值 |
|
||||
|------|-----|
|
||||
| **mAP@0.5** | 98.7% |
|
||||
| **mAP@0.5-0.95** | 87.4% |
|
||||
| **Precision** | 97.5% |
|
||||
| **Recall** | 95.5% |
|
||||
|
||||
> 注:目前仍在持续标注更多数据,预计最终将有 25,000+ 张标注图片用于训练。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── src/
|
||||
│ ├── cli/ # 命令行工具
|
||||
│ │ ├── autolabel.py # 自动标注
|
||||
│ │ ├── train.py # 模型训练
|
||||
│ │ ├── infer.py # 推理
|
||||
│ │ └── serve.py # Web 服务器
|
||||
│ ├── pdf/ # PDF 处理
|
||||
│ │ ├── extractor.py # 文本提取
|
||||
│ │ ├── renderer.py # 图像渲染
|
||||
│ │ └── detector.py # 类型检测
|
||||
│ ├── ocr/ # PaddleOCR 封装
|
||||
│ ├── normalize/ # 字段规范化
|
||||
│ ├── matcher/ # 字段匹配
|
||||
│ ├── yolo/ # YOLO 相关
|
||||
│ │ ├── annotation_generator.py
|
||||
│ │ └── db_dataset.py
|
||||
│ ├── inference/ # 推理管道
|
||||
│ │ ├── pipeline.py
|
||||
│ │ ├── yolo_detector.py
|
||||
│ │ └── field_extractor.py
|
||||
│ ├── processing/ # 多池处理架构
|
||||
│ │ ├── worker_pool.py
|
||||
│ │ ├── cpu_pool.py
|
||||
│ │ ├── gpu_pool.py
|
||||
│ │ ├── task_dispatcher.py
|
||||
│ │ └── dual_pool_coordinator.py
|
||||
│ ├── web/ # Web 应用
|
||||
│ │ ├── app.py # FastAPI 应用
|
||||
│ │ ├── routes.py # API 路由
|
||||
│ │ ├── services.py # 业务逻辑
|
||||
│ │ ├── schemas.py # 数据模型
|
||||
│ │ └── config.py # 配置
|
||||
│ └── data/ # 数据处理
|
||||
├── config.py # 配置文件
|
||||
├── run_server.py # Web 服务器启动脚本
|
||||
├── runs/ # 训练输出
|
||||
│ └── train/
|
||||
│ └── invoice_yolo11n_full/
|
||||
│ └── weights/
|
||||
│ ├── best.pt
|
||||
│ └── last.pt
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
## 多池处理架构
|
||||
|
||||
项目使用 CPU + GPU 双池架构处理不同类型的 PDF:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ DualPoolCoordinator │
|
||||
│ ┌─────────────────┐ ┌─────────────────┐ │
|
||||
│ │ CPU Pool │ │ GPU Pool │ │
|
||||
│ │ (3 workers) │ │ (1 worker) │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ Text PDFs │ │ Scanned PDFs │ │
|
||||
│ │ ~50-87 it/s │ │ ~1-2 it/s │ │
|
||||
│ └─────────────────┘ └─────────────────┘ │
|
||||
├── packages/
|
||||
│ ├── shared/ # 共享库
|
||||
│ │ ├── setup.py
|
||||
│ │ └── shared/
|
||||
│ │ ├── pdf/ # PDF 处理 (提取, 渲染, 检测)
|
||||
│ │ ├── ocr/ # PaddleOCR 封装 + 机器码解析
|
||||
│ │ ├── normalize/ # 字段规范化 (10 种 normalizer)
|
||||
│ │ ├── matcher/ # 字段匹配 (精确/子串/模糊)
|
||||
│ │ ├── storage/ # 存储抽象层 (Local/Azure/S3)
|
||||
│ │ ├── training/ # 共享训练组件 (YOLOTrainer)
|
||||
│ │ ├── augmentation/ # 数据增强 (DatasetAugmenter)
|
||||
│ │ ├── utils/ # 工具 (验证, 清理, 模糊匹配)
|
||||
│ │ ├── data/ # DocumentDB, CSVLoader
|
||||
│ │ ├── config.py # 全局配置 (数据库, 路径, DPI)
|
||||
│ │ └── exceptions.py # 异常定义
|
||||
│ │
|
||||
│ TaskDispatcher: 根据 PDF 类型分配任务 │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│ ├── training/ # 训练服务 (GPU, 按需)
|
||||
│ │ ├── setup.py
|
||||
│ │ ├── Dockerfile
|
||||
│ │ ├── run_training.py # 入口 (--task-id 或 --poll)
|
||||
│ │ └── training/
|
||||
│ │ ├── cli/ # train, autolabel, analyze_*, validate
|
||||
│ │ ├── yolo/ # db_dataset, annotation_generator
|
||||
│ │ ├── processing/ # CPU/GPU worker pool, task dispatcher
|
||||
│ │ └── data/ # training_db, autolabel_report
|
||||
│ │
|
||||
│ └── inference/ # 推理服务 (常驻)
|
||||
│ ├── setup.py
|
||||
│ ├── Dockerfile
|
||||
│ ├── run_server.py # Web 服务器入口
|
||||
│ └── inference/
|
||||
│ ├── cli/ # infer, serve
|
||||
│ ├── pipeline/ # YOLO 检测, 字段提取, 解析器
|
||||
│ ├── web/ # FastAPI 应用
|
||||
│ │ ├── api/v1/ # REST API (admin, public, batch)
|
||||
│ │ ├── schemas/ # Pydantic 数据模型
|
||||
│ │ ├── services/ # 业务逻辑
|
||||
│ │ ├── core/ # 认证, 调度器, 限流
|
||||
│ │ └── workers/ # 后台任务队列
|
||||
│ ├── validation/ # LLM 验证器
|
||||
│ ├── data/ # AdminDB, AsyncRequestDB, Models
|
||||
│ └── azure/ # ACI 训练触发器
|
||||
│
|
||||
├── frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||
│ ├── src/
|
||||
│ │ ├── api/ # API 客户端 (axios + react-query)
|
||||
│ │ ├── components/ # UI 组件
|
||||
│ │ │ ├── Dashboard.tsx # 文档管理面板
|
||||
│ │ │ ├── Training.tsx # 训练管理 (数据集/任务)
|
||||
│ │ │ ├── Models.tsx # 模型版本管理
|
||||
│ │ │ ├── DatasetDetail.tsx # 数据集详情
|
||||
│ │ │ └── InferenceDemo.tsx # 推理演示
|
||||
│ │ └── hooks/ # React Query hooks
|
||||
│ └── package.json
|
||||
│
|
||||
├── migrations/ # 数据库迁移 (SQL)
|
||||
│ ├── 003_training_tasks.sql
|
||||
│ ├── 004_training_datasets.sql
|
||||
│ ├── 005_add_group_key.sql
|
||||
│ ├── 006_model_versions.sql
|
||||
│ ├── 007_training_tasks_extra_columns.sql
|
||||
│ ├── 008_fix_model_versions_fk.sql
|
||||
│ ├── 009_add_document_category.sql
|
||||
│ └── 010_add_dataset_training_status.sql
|
||||
│
|
||||
├── tests/ # 测试 (1,601 tests)
|
||||
├── docker-compose.yml # 本地开发 (postgres + inference + training)
|
||||
├── run_server.py # 快捷启动脚本
|
||||
└── runs/train/ # 训练输出 (weights, curves)
|
||||
```
|
||||
|
||||
### 关键设计
|
||||
## 支持的字段
|
||||
|
||||
- **spawn 启动方式**: 兼容 CUDA 多进程
|
||||
- **as_completed()**: 无死锁结果收集
|
||||
- **进程初始化器**: 每个 worker 加载一次模型
|
||||
- **协调器持久化**: 跨 CSV 文件复用 worker 池
|
||||
| 类别 ID | 字段名 | 说明 |
|
||||
|---------|--------|------|
|
||||
| 0 | invoice_number | 发票号码 |
|
||||
| 1 | invoice_date | 发票日期 |
|
||||
| 2 | invoice_due_date | 到期日期 |
|
||||
| 3 | ocr_number | OCR 参考号 (瑞典支付系统) |
|
||||
| 4 | bankgiro | Bankgiro 号码 |
|
||||
| 5 | plusgiro | Plusgiro 号码 |
|
||||
| 6 | amount | 金额 |
|
||||
| 7 | supplier_organisation_number | 供应商组织号 |
|
||||
| 8 | payment_line | 支付行 (机器可读格式) |
|
||||
| 9 | customer_number | 客户编号 |
|
||||
|
||||
## 配置文件
|
||||
## 快速开始
|
||||
|
||||
### config.py
|
||||
|
||||
```python
|
||||
# 数据库配置
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31',
|
||||
'port': 5432,
|
||||
'database': 'docmaster',
|
||||
'user': 'docmaster',
|
||||
'password': '******',
|
||||
}
|
||||
|
||||
# 路径配置
|
||||
PATHS = {
|
||||
'csv_dir': '~/invoice-data/structured_data',
|
||||
'pdf_dir': '~/invoice-data/raw_pdfs',
|
||||
'output_dir': '~/invoice-data/dataset',
|
||||
}
|
||||
```
|
||||
|
||||
## CLI 命令参考
|
||||
|
||||
### autolabel
|
||||
### 1. 自动标注
|
||||
|
||||
```bash
|
||||
python -m src.cli.autolabel [OPTIONS]
|
||||
# 使用双池模式 (CPU + GPU)
|
||||
python -m training.cli.autolabel \
|
||||
--dual-pool \
|
||||
--cpu-workers 3 \
|
||||
--gpu-workers 1
|
||||
|
||||
Options:
|
||||
--csv, -c CSV 文件路径 (支持 glob)
|
||||
--pdf-dir, -p PDF 文件目录
|
||||
--output, -o 输出目录
|
||||
--workers, -w 单线程模式 worker 数 (默认: 4)
|
||||
--dual-pool 启用双池模式
|
||||
--cpu-workers CPU 池 worker 数 (默认: 3)
|
||||
--gpu-workers GPU 池 worker 数 (默认: 1)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--limit, -l 限制处理文档数
|
||||
# 单线程模式
|
||||
python -m training.cli.autolabel --workers 4
|
||||
```
|
||||
|
||||
### train
|
||||
### 2. 训练模型
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
# 从预训练模型开始训练
|
||||
python -m training.cli.train \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--name invoice_fields \
|
||||
--dpi 150
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型路径
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数
|
||||
# 低内存模式
|
||||
python -m training.cli.train \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--name invoice_fields \
|
||||
--low-memory
|
||||
|
||||
# 从检查点恢复训练
|
||||
python -m training.cli.train \
|
||||
--model runs/train/invoice_fields/weights/last.pt \
|
||||
--epochs 100 \
|
||||
--name invoice_fields \
|
||||
--resume
|
||||
```
|
||||
|
||||
### infer
|
||||
### 3. 推理
|
||||
|
||||
```bash
|
||||
python -m src.cli.infer [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 模型路径
|
||||
--input, -i 输入 PDF/图像
|
||||
--output, -o 输出 JSON 路径
|
||||
--confidence 置信度阈值 (默认: 0.5)
|
||||
--dpi 渲染 DPI (默认: 300)
|
||||
--gpu 使用 GPU
|
||||
# 命令行推理
|
||||
python -m inference.cli.infer \
|
||||
--model runs/train/invoice_fields/weights/best.pt \
|
||||
--input path/to/invoice.pdf \
|
||||
--output result.json \
|
||||
--gpu
|
||||
```
|
||||
|
||||
### serve
|
||||
### 4. Web 应用
|
||||
|
||||
```bash
|
||||
python run_server.py [OPTIONS]
|
||||
# 从 Windows PowerShell 启动
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
|
||||
|
||||
Options:
|
||||
--host 绑定地址 (默认: 0.0.0.0)
|
||||
--port 端口 (默认: 8000)
|
||||
--model, -m 模型路径
|
||||
--confidence 置信度阈值 (默认: 0.3)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--no-gpu 禁用 GPU
|
||||
--reload 开发模式自动重载
|
||||
--debug 调试模式
|
||||
# 启动前端
|
||||
cd frontend && npm install && npm run dev
|
||||
# 访问 http://localhost:5173
|
||||
```
|
||||
|
||||
### 5. Docker 本地开发
|
||||
|
||||
```bash
|
||||
docker-compose up
|
||||
# inference: http://localhost:8000
|
||||
# training: 轮询模式自动拾取任务
|
||||
```
|
||||
|
||||
## 训练触发流程
|
||||
|
||||
推理服务通过 API 触发训练,训练在独立的 GPU 实例上执行:
|
||||
|
||||
```
|
||||
Inference API PostgreSQL Training (ACI)
|
||||
| | |
|
||||
POST /admin/training/trigger | |
|
||||
|-> INSERT training_tasks ------>| status=pending |
|
||||
|-> Azure SDK: create ACI --------------------------------> 启动
|
||||
| | |
|
||||
| |<-- SELECT pending -----+
|
||||
| |--- UPDATE running -----+
|
||||
| | 执行训练...
|
||||
| |<-- UPDATE completed ---+
|
||||
| | + model_path |
|
||||
| | + metrics 自动关机
|
||||
| | |
|
||||
GET /admin/training/{id} | |
|
||||
|-> SELECT training_tasks ------>| |
|
||||
+-- return status + metrics | |
|
||||
```
|
||||
|
||||
## Web API 端点
|
||||
|
||||
**Public API:**
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| GET | `/api/v1/health` | 健康检查 |
|
||||
| POST | `/api/v1/infer` | 上传文件并推理 |
|
||||
| GET | `/api/v1/results/{filename}` | 获取可视化图片 |
|
||||
| POST | `/api/v1/async/infer` | 异步推理 |
|
||||
| GET | `/api/v1/async/status/{task_id}` | 查询异步任务状态 |
|
||||
|
||||
**Admin API** (需要 `X-Admin-Token` header):
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| POST | `/api/v1/admin/auth/login` | 管理员登录 |
|
||||
| GET | `/api/v1/admin/documents` | 文档列表 |
|
||||
| POST | `/api/v1/admin/documents/upload` | 上传 PDF |
|
||||
| GET | `/api/v1/admin/documents/{id}` | 文档详情 |
|
||||
| PATCH | `/api/v1/admin/documents/{id}/status` | 更新文档状态 |
|
||||
| PATCH | `/api/v1/admin/documents/{id}/category` | 更新文档分类 |
|
||||
| GET | `/api/v1/admin/documents/categories` | 获取分类列表 |
|
||||
| POST | `/api/v1/admin/documents/{id}/annotations` | 创建标注 |
|
||||
|
||||
**Training API:**
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| POST | `/api/v1/admin/training/datasets` | 创建数据集 |
|
||||
| GET | `/api/v1/admin/training/datasets` | 数据集列表 |
|
||||
| GET | `/api/v1/admin/training/datasets/{id}` | 数据集详情 |
|
||||
| DELETE | `/api/v1/admin/training/datasets/{id}` | 删除数据集 |
|
||||
| POST | `/api/v1/admin/training/tasks` | 创建训练任务 |
|
||||
| GET | `/api/v1/admin/training/tasks` | 任务列表 |
|
||||
| GET | `/api/v1/admin/training/tasks/{id}` | 任务详情 |
|
||||
| GET | `/api/v1/admin/training/tasks/{id}/logs` | 训练日志 |
|
||||
|
||||
**Model Versions API:**
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| GET | `/api/v1/admin/models` | 模型版本列表 |
|
||||
| GET | `/api/v1/admin/models/{id}` | 模型详情 |
|
||||
| POST | `/api/v1/admin/models/{id}/activate` | 激活模型 |
|
||||
| POST | `/api/v1/admin/models/{id}/archive` | 归档模型 |
|
||||
| DELETE | `/api/v1/admin/models/{id}` | 删除模型 |
|
||||
|
||||
## Python API
|
||||
|
||||
```python
|
||||
from src.inference import InferencePipeline
|
||||
from inference.pipeline import InferencePipeline
|
||||
|
||||
# 初始化
|
||||
pipeline = InferencePipeline(
|
||||
model_path='runs/train/invoice_yolo11n_full/weights/best.pt',
|
||||
confidence_threshold=0.3,
|
||||
model_path='runs/train/invoice_fields/weights/best.pt',
|
||||
confidence_threshold=0.25,
|
||||
use_gpu=True,
|
||||
dpi=150
|
||||
dpi=150,
|
||||
enable_fallback=True
|
||||
)
|
||||
|
||||
# 处理 PDF
|
||||
result = pipeline.process_pdf('invoice.pdf')
|
||||
|
||||
# 处理图片
|
||||
result = pipeline.process_image('invoice.png')
|
||||
print(result.fields)
|
||||
# {'InvoiceNumber': '12345', 'Amount': '1234.56', ...}
|
||||
|
||||
# 获取结果
|
||||
print(result.fields) # {'InvoiceNumber': '12345', 'Amount': '1234.56', ...}
|
||||
print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
|
||||
print(result.to_json()) # JSON 格式输出
|
||||
print(result.confidence)
|
||||
# {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
|
||||
|
||||
# 交叉验证
|
||||
if result.cross_validation:
|
||||
print(f"OCR match: {result.cross_validation.ocr_match}")
|
||||
```
|
||||
|
||||
## 开发状态
|
||||
```python
|
||||
from inference.pipeline.payment_line_parser import PaymentLineParser
|
||||
from inference.pipeline.customer_number_parser import CustomerNumberParser
|
||||
|
||||
- [x] 文本层 PDF 自动标注
|
||||
- [x] 扫描图 OCR 自动标注
|
||||
- [x] 多策略字段匹配 (精确/子串/规范化)
|
||||
- [x] PostgreSQL 数据库存储 (断点续传)
|
||||
- [x] 信号处理和超时保护
|
||||
- [x] YOLO 训练 (98.7% mAP@0.5)
|
||||
- [x] 推理管道
|
||||
- [x] 字段规范化和验证
|
||||
- [x] Web 应用 (FastAPI + 前端 UI)
|
||||
- [x] 增量训练支持
|
||||
- [ ] 完成全部 25,000+ 文档标注
|
||||
- [ ] 表格 items 处理
|
||||
- [ ] 模型量化部署
|
||||
# Payment Line 解析
|
||||
parser = PaymentLineParser()
|
||||
result = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
|
||||
print(f"OCR: {result.ocr_number}, Amount: {result.amount}")
|
||||
|
||||
# Customer Number 解析
|
||||
parser = CustomerNumberParser()
|
||||
result = parser.parse("Said, Shakar Umj 436-R Billo")
|
||||
print(f"Customer Number: {result}") # "UMJ 436-R"
|
||||
```
|
||||
|
||||
## DPI 配置
|
||||
|
||||
系统所有组件统一使用 **150 DPI**。DPI 必须在训练和推理时保持一致。
|
||||
|
||||
| 组件 | 配置位置 |
|
||||
|------|---------|
|
||||
| 全局常量 | `packages/shared/shared/config.py` -> `DEFAULT_DPI = 150` |
|
||||
| Web 推理 | `packages/inference/inference/web/config.py` -> `ModelConfig.dpi` |
|
||||
| CLI 推理 | `python -m inference.cli.infer --dpi 150` |
|
||||
| 自动标注 | `packages/shared/shared/config.py` -> `AUTOLABEL['dpi']` |
|
||||
|
||||
## 数据库架构
|
||||
|
||||
| 数据库 | 用途 | 存储内容 |
|
||||
|--------|------|----------|
|
||||
| **PostgreSQL** | 主数据库 | 文档、标注、训练任务、数据集、模型版本 |
|
||||
|
||||
### 主要表
|
||||
|
||||
| 表名 | 说明 |
|
||||
|------|------|
|
||||
| `admin_documents` | 文档管理 (PDF 元数据, 状态, 分类) |
|
||||
| `admin_annotations` | 标注数据 (YOLO 格式边界框) |
|
||||
| `training_tasks` | 训练任务 (状态, 配置, 指标) |
|
||||
| `training_datasets` | 数据集 (train/val/test 分割) |
|
||||
| `dataset_documents` | 数据集-文档关联 |
|
||||
| `model_versions` | 模型版本管理 (激活/归档) |
|
||||
| `admin_tokens` | 管理员认证令牌 |
|
||||
| `async_requests` | 异步推理请求 |
|
||||
|
||||
### 数据集状态
|
||||
|
||||
| 状态 | 说明 |
|
||||
|------|------|
|
||||
| `building` | 正在构建数据集 |
|
||||
| `ready` | 数据集就绪,可开始训练 |
|
||||
| `trained` | 已完成训练 |
|
||||
| `failed` | 构建失败 |
|
||||
| `archived` | 已归档 |
|
||||
|
||||
### 训练状态
|
||||
|
||||
| 状态 | 说明 |
|
||||
|------|------|
|
||||
| `pending` | 等待执行 |
|
||||
| `scheduled` | 已计划 |
|
||||
| `running` | 正在训练 |
|
||||
| `completed` | 训练完成 |
|
||||
| `failed` | 训练失败 |
|
||||
| `cancelled` | 已取消 |
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
DB_PASSWORD=xxx pytest tests/ -q
|
||||
|
||||
# 运行并查看覆盖率
|
||||
DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
|
||||
```
|
||||
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| **测试总数** | 1,601 |
|
||||
| **通过率** | 100% |
|
||||
| **覆盖率** | 28% |
|
||||
|
||||
## 存储抽象层
|
||||
|
||||
统一的文件存储接口,支持多后端切换:
|
||||
|
||||
| 后端 | 用途 | 安装 |
|
||||
|------|------|------|
|
||||
| **Local** | 本地开发/测试 | 默认 |
|
||||
| **Azure Blob** | Azure 云部署 | `pip install -e "packages/shared[azure]"` |
|
||||
| **AWS S3** | AWS 云部署 | `pip install -e "packages/shared[s3]"` |
|
||||
|
||||
### 配置文件 (storage.yaml)
|
||||
|
||||
```yaml
|
||||
backend: ${STORAGE_BACKEND:-local}
|
||||
presigned_url_expiry: 3600
|
||||
|
||||
local:
|
||||
base_path: ${STORAGE_BASE_PATH:-./data/storage}
|
||||
|
||||
azure:
|
||||
connection_string: ${AZURE_STORAGE_CONNECTION_STRING}
|
||||
container_name: ${AZURE_STORAGE_CONTAINER:-documents}
|
||||
|
||||
s3:
|
||||
bucket_name: ${AWS_S3_BUCKET}
|
||||
region_name: ${AWS_REGION:-us-east-1}
|
||||
```
|
||||
|
||||
### 使用示例
|
||||
|
||||
```python
|
||||
from shared.storage import get_storage_backend
|
||||
|
||||
# 从配置文件加载
|
||||
storage = get_storage_backend("storage.yaml")
|
||||
|
||||
# 上传文件
|
||||
storage.upload(Path("local.pdf"), "documents/invoice.pdf")
|
||||
|
||||
# 获取预签名 URL (前端访问)
|
||||
url = storage.get_presigned_url("documents/invoice.pdf", expires_in_seconds=3600)
|
||||
```
|
||||
|
||||
### 环境变量
|
||||
|
||||
| 变量 | 后端 | 说明 |
|
||||
|------|------|------|
|
||||
| `STORAGE_BACKEND` | 全部 | `local`, `azure_blob`, `s3` |
|
||||
| `STORAGE_BASE_PATH` | Local | 本地存储路径 |
|
||||
| `AZURE_STORAGE_CONNECTION_STRING` | Azure | 连接字符串 |
|
||||
| `AZURE_STORAGE_CONTAINER` | Azure | 容器名称 |
|
||||
| `AWS_S3_BUCKET` | S3 | 存储桶名称 |
|
||||
| `AWS_REGION` | S3 | 区域 (默认: us-east-1) |
|
||||
|
||||
## 数据增强
|
||||
|
||||
训练时支持多种数据增强策略:
|
||||
|
||||
| 增强类型 | 说明 |
|
||||
|----------|------|
|
||||
| `perspective_warp` | 透视变换 (模拟扫描角度) |
|
||||
| `wrinkle` | 皱纹效果 |
|
||||
| `edge_damage` | 边缘损坏 |
|
||||
| `stain` | 污渍效果 |
|
||||
| `lighting_variation` | 光照变化 |
|
||||
| `shadow` | 阴影效果 |
|
||||
| `gaussian_blur` | 高斯模糊 |
|
||||
| `motion_blur` | 运动模糊 |
|
||||
| `gaussian_noise` | 高斯噪声 |
|
||||
| `salt_pepper` | 椒盐噪声 |
|
||||
| `paper_texture` | 纸张纹理 |
|
||||
| `scanner_artifacts` | 扫描伪影 |
|
||||
|
||||
增强配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"augmentation": {
|
||||
"gaussian_blur": { "enabled": true, "kernel_size": 5 },
|
||||
"perspective_warp": { "enabled": true, "intensity": 0.1 }
|
||||
},
|
||||
"augmentation_multiplier": 2
|
||||
}
|
||||
```
|
||||
|
||||
## 前端功能
|
||||
|
||||
React 前端提供以下功能模块:
|
||||
|
||||
| 模块 | 功能 |
|
||||
|------|------|
|
||||
| **Dashboard** | 文档列表、上传、标注状态管理、分类筛选 |
|
||||
| **Training** | 数据集创建/管理、训练任务配置、增强设置 |
|
||||
| **Models** | 模型版本管理、激活/归档、指标查看 |
|
||||
| **Inference Demo** | 实时推理演示、结果可视化 |
|
||||
|
||||
### 启动前端
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
# 访问 http://localhost:5173
|
||||
```
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -455,9 +543,27 @@ print(result.to_json()) # JSON 格式输出
|
||||
| **目标检测** | YOLOv11 (Ultralytics) |
|
||||
| **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) |
|
||||
| **PDF 处理** | PyMuPDF (fitz) |
|
||||
| **数据库** | PostgreSQL + psycopg2 |
|
||||
| **数据库** | PostgreSQL + SQLModel |
|
||||
| **Web 框架** | FastAPI + Uvicorn |
|
||||
| **深度学习** | PyTorch + CUDA |
|
||||
| **前端** | React + TypeScript + Vite + TailwindCSS |
|
||||
| **状态管理** | React Query (TanStack Query) |
|
||||
| **深度学习** | PyTorch + CUDA 12.x |
|
||||
| **部署** | Docker + Azure/AWS (训练) / App Service (推理) |
|
||||
|
||||
## 环境变量
|
||||
|
||||
| 变量 | 必需 | 说明 |
|
||||
|------|------|------|
|
||||
| `DB_PASSWORD` | 是 | PostgreSQL 密码 |
|
||||
| `DB_HOST` | 否 | 数据库主机 (默认: localhost) |
|
||||
| `DB_PORT` | 否 | 数据库端口 (默认: 5432) |
|
||||
| `DB_NAME` | 否 | 数据库名 (默认: docmaster) |
|
||||
| `DB_USER` | 否 | 数据库用户 (默认: docmaster) |
|
||||
| `STORAGE_BASE_PATH` | 否 | 存储路径 (默认: ~/invoice-data/data) |
|
||||
| `MODEL_PATH` | 否 | 模型路径 |
|
||||
| `CONFIDENCE_THRESHOLD` | 否 | 置信度阈值 (默认: 0.5) |
|
||||
| `SERVER_HOST` | 否 | 服务器主机 (默认: 0.0.0.0) |
|
||||
| `SERVER_PORT` | 否 | 服务器端口 (默认: 8000) |
|
||||
|
||||
## 许可证
|
||||
|
||||
|
||||
64
config.py
64
config.py
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Configuration settings for the invoice extraction system.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
|
||||
def _is_wsl() -> bool:
|
||||
"""Check if running inside WSL (Windows Subsystem for Linux)."""
|
||||
if platform.system() != 'Linux':
|
||||
return False
|
||||
# Check for WSL-specific indicators
|
||||
if os.environ.get('WSL_DISTRO_NAME'):
|
||||
return True
|
||||
try:
|
||||
with open('/proc/version', 'r') as f:
|
||||
return 'microsoft' in f.read().lower()
|
||||
except (FileNotFoundError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
# PostgreSQL Database Configuration
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31',
|
||||
'port': 5432,
|
||||
'database': 'docmaster',
|
||||
'user': 'docmaster',
|
||||
'password': '0412220',
|
||||
}
|
||||
|
||||
# Connection string for psycopg2
|
||||
def get_db_connection_string():
|
||||
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
|
||||
|
||||
|
||||
# Paths Configuration - auto-detect WSL vs Windows
|
||||
if _is_wsl():
|
||||
# WSL: use native Linux filesystem for better I/O performance
|
||||
PATHS = {
|
||||
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
|
||||
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
|
||||
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
|
||||
'reports_dir': 'reports', # Keep reports in project directory
|
||||
}
|
||||
else:
|
||||
# Windows or native Linux: use relative paths
|
||||
PATHS = {
|
||||
'csv_dir': 'data/structured_data',
|
||||
'pdf_dir': 'data/raw_pdfs',
|
||||
'output_dir': 'data/dataset',
|
||||
'reports_dir': 'reports',
|
||||
}
|
||||
|
||||
# Auto-labeling Configuration
|
||||
AUTOLABEL = {
|
||||
'workers': 2,
|
||||
'dpi': 150,
|
||||
'min_confidence': 0.5,
|
||||
'train_ratio': 0.8,
|
||||
'val_ratio': 0.1,
|
||||
'test_ratio': 0.1,
|
||||
'max_records_per_report': 10000,
|
||||
}
|
||||
96
create_shims.sh
Normal file
96
create_shims.sh
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create backward compatibility shims for all migrated files
|
||||
|
||||
# admin_auth.py -> core/auth.py
|
||||
cat > src/web/admin_auth.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.core.auth instead"""
|
||||
from src.web.core.auth import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_autolabel.py -> services/autolabel.py
|
||||
cat > src/web/admin_autolabel.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.autolabel instead"""
|
||||
from src.web.services.autolabel import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_scheduler.py -> core/scheduler.py
|
||||
cat > src/web/admin_scheduler.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.core.scheduler instead"""
|
||||
from src.web.core.scheduler import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_schemas.py -> schemas/admin.py
|
||||
cat > src/web/admin_schemas.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.schemas.admin instead"""
|
||||
from src.web.schemas.admin import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# schemas.py -> schemas/inference.py + schemas/common.py
|
||||
cat > src/web/schemas.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
|
||||
from src.web.schemas.inference import * # noqa: F401, F403
|
||||
from src.web.schemas.common import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# services.py -> services/inference.py
|
||||
cat > src/web/services.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.inference instead"""
|
||||
from src.web.services.inference import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# async_queue.py -> workers/async_queue.py
|
||||
cat > src/web/async_queue.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
|
||||
from src.web.workers.async_queue import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# async_service.py -> services/async_processing.py
|
||||
cat > src/web/async_service.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.async_processing instead"""
|
||||
from src.web.services.async_processing import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_queue.py -> workers/batch_queue.py
|
||||
cat > src/web/batch_queue.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
|
||||
from src.web.workers.batch_queue import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_upload_service.py -> services/batch_upload.py
|
||||
cat > src/web/batch_upload_service.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
|
||||
from src.web.services.batch_upload import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_upload_routes.py -> api/v1/batch/routes.py
|
||||
cat > src/web/batch_upload_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
|
||||
from src.web.api.v1.batch.routes import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_routes.py -> api/v1/admin/documents.py
|
||||
cat > src/web/admin_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
|
||||
from src.web.api.v1.admin.documents import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_annotation_routes.py -> api/v1/admin/annotations.py
|
||||
cat > src/web/admin_annotation_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
|
||||
from src.web.api.v1.admin.annotations import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_training_routes.py -> api/v1/admin/training.py
|
||||
cat > src/web/admin_training_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
|
||||
from src.web.api.v1.admin.training import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# routes.py -> api/v1/routes.py
|
||||
cat > src/web/routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
|
||||
from src.web.api.v1.routes import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
echo "✓ Created backward compatibility shims for all migrated files"
|
||||
60
docker-compose.yml
Normal file
60
docker-compose.yml
Normal file
@@ -0,0 +1,60 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
environment:
|
||||
POSTGRES_DB: docmaster
|
||||
POSTGRES_USER: docmaster
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
- ./migrations:/docker-entrypoint-initdb.d
|
||||
|
||||
inference:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/inference/Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
- MODEL_PATH=/app/models/best.pt
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
depends_on:
|
||||
- postgres
|
||||
|
||||
training:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/training/Dockerfile
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
- ./temp:/app/temp
|
||||
depends_on:
|
||||
- postgres
|
||||
# Override CMD for local dev polling mode
|
||||
command: ["python", "run_training.py", "--poll", "--poll-interval", "30"]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
772
docs/aws-deployment-guide.md
Normal file
772
docs/aws-deployment-guide.md
Normal file
@@ -0,0 +1,772 @@
|
||||
# AWS 部署方案完整指南
|
||||
|
||||
## 目录
|
||||
- [核心问题](#核心问题)
|
||||
- [存储方案](#存储方案)
|
||||
- [训练方案](#训练方案)
|
||||
- [推理方案](#推理方案)
|
||||
- [价格对比](#价格对比)
|
||||
- [推荐架构](#推荐架构)
|
||||
- [实施步骤](#实施步骤)
|
||||
- [AWS vs Azure 对比](#aws-vs-azure-对比)
|
||||
|
||||
---
|
||||
|
||||
## 核心问题
|
||||
|
||||
| 问题 | 答案 |
|
||||
|------|------|
|
||||
| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 |
|
||||
| 能实时从 S3 读取训练吗? | 可以,SageMaker 支持 Pipe Mode 流式读取 |
|
||||
| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone |
|
||||
| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 |
|
||||
| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda |
|
||||
| 推理服务用什么? | Lambda (Serverless) 或 ECS/Fargate (容器) |
|
||||
|
||||
---
|
||||
|
||||
## 存储方案
|
||||
|
||||
### Amazon S3(推荐)
|
||||
|
||||
S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。
|
||||
|
||||
```bash
|
||||
# 创建 S3 桶
|
||||
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||
|
||||
# 上传训练数据
|
||||
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||
|
||||
# 创建目录结构
|
||||
aws s3api put-object --bucket invoice-training-data --key datasets/
|
||||
aws s3api put-object --bucket invoice-training-data --key models/
|
||||
```
|
||||
|
||||
### Mountpoint for Amazon S3
|
||||
|
||||
AWS 官方的 S3 挂载客户端,性能优于 s3fs:
|
||||
|
||||
```bash
|
||||
# 安装 Mountpoint
|
||||
wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
|
||||
sudo dpkg -i mount-s3.deb
|
||||
|
||||
# 挂载 S3
|
||||
mkdir -p /mnt/s3-data
|
||||
mount-s3 invoice-training-data /mnt/s3-data --region us-east-1
|
||||
|
||||
# 配置缓存(推荐)
|
||||
mount-s3 invoice-training-data /mnt/s3-data \
|
||||
--region us-east-1 \
|
||||
--cache /tmp/s3-cache \
|
||||
--metadata-ttl 60
|
||||
```
|
||||
|
||||
### 本地开发挂载
|
||||
|
||||
**Linux/Mac (s3fs-fuse):**
|
||||
```bash
|
||||
# 安装
|
||||
sudo apt-get install s3fs
|
||||
|
||||
# 配置凭证
|
||||
echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs
|
||||
chmod 600 ~/.passwd-s3fs
|
||||
|
||||
# 挂载
|
||||
s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs
|
||||
```
|
||||
|
||||
**Windows (Rclone):**
|
||||
```powershell
|
||||
# 安装
|
||||
winget install Rclone.Rclone
|
||||
|
||||
# 配置
|
||||
rclone config # 选择 s3
|
||||
|
||||
# 挂载
|
||||
rclone mount aws:invoice-training-data Z: --vfs-cache-mode full
|
||||
```
|
||||
|
||||
### 存储费用
|
||||
|
||||
| 层级 | 价格 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| S3 Standard | $0.023/GB/月 | 频繁访问 |
|
||||
| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 |
|
||||
| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 |
|
||||
| S3 Glacier | $0.004/GB/月 | 长期存档 |
|
||||
|
||||
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月**
|
||||
|
||||
### SageMaker 数据输入模式
|
||||
|
||||
| 模式 | 说明 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| File Mode | 下载到本地再训练 | 小数据集 |
|
||||
| Pipe Mode | 流式读取,不占本地空间 | 大数据集 |
|
||||
| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 |
|
||||
|
||||
---
|
||||
|
||||
## 训练方案
|
||||
|
||||
### 方案总览
|
||||
|
||||
| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 |
|
||||
|------|---------|---------|--------|----------|
|
||||
| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 |
|
||||
| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 |
|
||||
| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 |
|
||||
|
||||
### EC2 vs SageMaker
|
||||
|
||||
| 特性 | EC2 | SageMaker |
|
||||
|------|-----|-----------|
|
||||
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||
| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) |
|
||||
| 管理开销 | 需自己配置 | 全托管 |
|
||||
| Spot 折扣 | 最高 90% | 最高 90% |
|
||||
| 实验跟踪 | 无 | 内置 |
|
||||
| 自动关机 | 无 | 任务完成自动停止 |
|
||||
|
||||
### GPU 实例价格 (2025 年 6 月降价后)
|
||||
|
||||
| 实例 | GPU | 显存 | On-Demand | Spot 价格 |
|
||||
|------|-----|------|-----------|----------|
|
||||
| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr |
|
||||
| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr |
|
||||
| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr |
|
||||
| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr |
|
||||
| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr |
|
||||
|
||||
**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。
|
||||
|
||||
### Spot 实例
|
||||
|
||||
```bash
|
||||
# EC2 Spot 请求
|
||||
aws ec2 request-spot-instances \
|
||||
--instance-count 1 \
|
||||
--type "one-time" \
|
||||
--launch-specification '{
|
||||
"ImageId": "ami-0123456789abcdef0",
|
||||
"InstanceType": "p3.2xlarge",
|
||||
"KeyName": "my-key"
|
||||
}'
|
||||
```
|
||||
|
||||
### SageMaker Managed Spot Training
|
||||
|
||||
```python
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.p3.2xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
|
||||
# 启用 Spot 实例
|
||||
use_spot_instances=True,
|
||||
max_run=3600, # 最长运行 1 小时
|
||||
max_wait=7200, # 最长等待 2 小时
|
||||
|
||||
# 检查点配置(Spot 中断恢复)
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
checkpoint_local_path="/opt/ml/checkpoints",
|
||||
|
||||
hyperparameters={
|
||||
"epochs": 100,
|
||||
"batch-size": 16,
|
||||
}
|
||||
)
|
||||
|
||||
estimator.fit({
|
||||
"training": "s3://invoice-training-data/datasets/train/",
|
||||
"validation": "s3://invoice-training-data/datasets/val/"
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 推理方案
|
||||
|
||||
### 方案对比
|
||||
|
||||
| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 |
|
||||
|------|---------|--------|--------|------|---------|
|
||||
| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 |
|
||||
| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 |
|
||||
| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 |
|
||||
| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 |
|
||||
| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 |
|
||||
| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 |
|
||||
|
||||
### 推荐方案 1: AWS Lambda (低流量)
|
||||
|
||||
对于 YOLO CPU 推理,Lambda 最经济:
|
||||
|
||||
```python
|
||||
# lambda_function.py
|
||||
import json
|
||||
import boto3
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 模型在 Lambda Layer 或 /tmp 加载
|
||||
model = None
|
||||
|
||||
def load_model():
|
||||
global model
|
||||
if model is None:
|
||||
# 从 S3 下载模型到 /tmp
|
||||
s3 = boto3.client('s3')
|
||||
s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt')
|
||||
model = YOLO('/tmp/best.pt')
|
||||
return model
|
||||
|
||||
def lambda_handler(event, context):
|
||||
model = load_model()
|
||||
|
||||
# 从 S3 获取图片
|
||||
s3 = boto3.client('s3')
|
||||
bucket = event['bucket']
|
||||
key = event['key']
|
||||
|
||||
local_path = f'/tmp/{key.split("/")[-1]}'
|
||||
s3.download_file(bucket, key, local_path)
|
||||
|
||||
# 执行推理
|
||||
results = model.predict(local_path, conf=0.5)
|
||||
|
||||
return {
|
||||
'statusCode': 200,
|
||||
'body': json.dumps({
|
||||
'fields': extract_fields(results),
|
||||
'confidence': get_confidence(results)
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Lambda 配置:**
|
||||
```yaml
|
||||
# serverless.yml
|
||||
service: invoice-inference
|
||||
|
||||
provider:
|
||||
name: aws
|
||||
runtime: python3.11
|
||||
timeout: 30
|
||||
memorySize: 4096 # 4GB 内存
|
||||
|
||||
functions:
|
||||
infer:
|
||||
handler: lambda_function.lambda_handler
|
||||
events:
|
||||
- http:
|
||||
path: /infer
|
||||
method: post
|
||||
layers:
|
||||
- arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||
```
|
||||
|
||||
### 推荐方案 2: ECS Fargate (中流量)
|
||||
|
||||
```yaml
|
||||
# task-definition.json
|
||||
{
|
||||
"family": "invoice-inference",
|
||||
"networkMode": "awsvpc",
|
||||
"requiresCompatibilities": ["FARGATE"],
|
||||
"cpu": "2048",
|
||||
"memory": "4096",
|
||||
"containerDefinitions": [
|
||||
{
|
||||
"name": "inference",
|
||||
"image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest",
|
||||
"portMappings": [
|
||||
{
|
||||
"containerPort": 8000,
|
||||
"protocol": "tcp"
|
||||
}
|
||||
],
|
||||
"environment": [
|
||||
{"name": "MODEL_PATH", "value": "/app/models/best.pt"}
|
||||
],
|
||||
"logConfiguration": {
|
||||
"logDriver": "awslogs",
|
||||
"options": {
|
||||
"awslogs-group": "/ecs/invoice-inference",
|
||||
"awslogs-region": "us-east-1",
|
||||
"awslogs-stream-prefix": "ecs"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Auto Scaling 配置:**
|
||||
```bash
|
||||
# 创建 Auto Scaling Target
|
||||
aws application-autoscaling register-scalable-target \
|
||||
--service-namespace ecs \
|
||||
--resource-id service/invoice-cluster/invoice-service \
|
||||
--scalable-dimension ecs:service:DesiredCount \
|
||||
--min-capacity 1 \
|
||||
--max-capacity 10
|
||||
|
||||
# 基于 CPU 使用率扩缩容
|
||||
aws application-autoscaling put-scaling-policy \
|
||||
--service-namespace ecs \
|
||||
--resource-id service/invoice-cluster/invoice-service \
|
||||
--scalable-dimension ecs:service:DesiredCount \
|
||||
--policy-name cpu-scaling \
|
||||
--policy-type TargetTrackingScaling \
|
||||
--target-tracking-scaling-policy-configuration '{
|
||||
"TargetValue": 70,
|
||||
"PredefinedMetricSpecification": {
|
||||
"PredefinedMetricType": "ECSServiceAverageCPUUtilization"
|
||||
},
|
||||
"ScaleOutCooldown": 60,
|
||||
"ScaleInCooldown": 120
|
||||
}'
|
||||
```
|
||||
|
||||
### 方案 3: SageMaker Serverless Inference
|
||||
|
||||
```python
|
||||
from sagemaker.serverless import ServerlessInferenceConfig
|
||||
from sagemaker.pytorch import PyTorchModel
|
||||
|
||||
model = PyTorchModel(
|
||||
model_data="s3://invoice-models/model.tar.gz",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
entry_point="inference.py",
|
||||
framework_version="2.0",
|
||||
py_version="py310"
|
||||
)
|
||||
|
||||
serverless_config = ServerlessInferenceConfig(
|
||||
memory_size_in_mb=4096,
|
||||
max_concurrency=10
|
||||
)
|
||||
|
||||
predictor = model.deploy(
|
||||
serverless_inference_config=serverless_config,
|
||||
endpoint_name="invoice-inference-serverless"
|
||||
)
|
||||
```
|
||||
|
||||
### 推理性能对比
|
||||
|
||||
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|
||||
|------|------------|---------|---------|
|
||||
| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) |
|
||||
| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 |
|
||||
| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 |
|
||||
| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 |
|
||||
|
||||
---
|
||||
|
||||
## 价格对比
|
||||
|
||||
### 训练成本对比(假设每天训练 2 小时)
|
||||
|
||||
| 方案 | 计算方式 | 月费 |
|
||||
|------|---------|------|
|
||||
| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
|
||||
| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 |
|
||||
| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
|
||||
| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 |
|
||||
| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 |
|
||||
|
||||
### 本项目完整成本估算
|
||||
|
||||
| 组件 | 推荐方案 | 月费 |
|
||||
|------|---------|------|
|
||||
| 数据存储 | S3 Standard (5GB) | ~$0.12 |
|
||||
| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 |
|
||||
| 推理服务 | Lambda (10K 请求/月) | ~$15 |
|
||||
| 推理服务 (替代) | ECS Fargate | ~$30 |
|
||||
| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 |
|
||||
| ECR (镜像存储) | 基本使用 | ~$1 |
|
||||
| **总计 (Lambda)** | | **~$35/月** + 训练费 |
|
||||
| **总计 (Fargate)** | | **~$50/月** + 训练费 |
|
||||
|
||||
---
|
||||
|
||||
## 推荐架构
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Amazon S3 │
|
||||
│ ├── training-images/ │
|
||||
│ ├── datasets/ │
|
||||
│ ├── models/ │
|
||||
│ └── checkpoints/ │
|
||||
└─────────────────┬───────────────────┘
|
||||
│
|
||||
┌─────────────────────────────────┼─────────────────────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
|
||||
│ 推理服务 │ │ 训练服务 │ │ API Gateway │
|
||||
│ │ │ │ │ │
|
||||
│ 方案 A: Lambda │ │ SageMaker │ │ REST API │
|
||||
│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │
|
||||
│ │ │ ~$2-5/次训练 │ │ │
|
||||
│ 方案 B: ECS Fargate │ │ │ │ │
|
||||
│ ~$30/月 │ │ - 自动启动 │ │ │
|
||||
│ │ │ - 训练完成自动停止 │ │ │
|
||||
│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │
|
||||
│ │ FastAPI + YOLO │ │ │ │ │ │
|
||||
│ │ CPU 推理 │ │ │ │ │ │
|
||||
│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘
|
||||
└───────────┬───────────┘ │
|
||||
│ │
|
||||
└───────────────────────────────┼───────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────┐
|
||||
│ Amazon RDS │
|
||||
│ PostgreSQL │
|
||||
│ db.t3.micro │
|
||||
│ ~$15/月 │
|
||||
└───────────────────────┘
|
||||
```
|
||||
|
||||
### Lambda 推理配置
|
||||
|
||||
```yaml
|
||||
# SAM template
|
||||
AWSTemplateFormatVersion: '2010-09-09'
|
||||
Transform: AWS::Serverless-2016-10-31
|
||||
|
||||
Resources:
|
||||
InferenceFunction:
|
||||
Type: AWS::Serverless::Function
|
||||
Properties:
|
||||
Handler: app.lambda_handler
|
||||
Runtime: python3.11
|
||||
MemorySize: 4096
|
||||
Timeout: 30
|
||||
Environment:
|
||||
Variables:
|
||||
MODEL_BUCKET: invoice-models
|
||||
MODEL_KEY: best.pt
|
||||
Policies:
|
||||
- S3ReadPolicy:
|
||||
BucketName: invoice-models
|
||||
- S3ReadPolicy:
|
||||
BucketName: invoice-uploads
|
||||
Events:
|
||||
InferApi:
|
||||
Type: Api
|
||||
Properties:
|
||||
Path: /infer
|
||||
Method: post
|
||||
```
|
||||
|
||||
### SageMaker 训练配置
|
||||
|
||||
```python
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge", # T4 GPU
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
|
||||
# Spot 实例配置
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
|
||||
# 检查点
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
|
||||
hyperparameters={
|
||||
"epochs": 100,
|
||||
"batch-size": 16,
|
||||
"model": "yolo11n.pt"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实施步骤
|
||||
|
||||
### 阶段 1: 存储设置
|
||||
|
||||
```bash
|
||||
# 创建 S3 桶
|
||||
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||
aws s3 mb s3://invoice-models --region us-east-1
|
||||
|
||||
# 上传训练数据
|
||||
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||
|
||||
# 配置生命周期(可选,自动转冷存储)
|
||||
aws s3api put-bucket-lifecycle-configuration \
|
||||
--bucket invoice-training-data \
|
||||
--lifecycle-configuration '{
|
||||
"Rules": [{
|
||||
"ID": "MoveToIA",
|
||||
"Status": "Enabled",
|
||||
"Transitions": [{
|
||||
"Days": 30,
|
||||
"StorageClass": "STANDARD_IA"
|
||||
}]
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
### 阶段 2: 数据库设置
|
||||
|
||||
```bash
|
||||
# 创建 RDS PostgreSQL
|
||||
aws rds create-db-instance \
|
||||
--db-instance-identifier invoice-db \
|
||||
--db-instance-class db.t3.micro \
|
||||
--engine postgres \
|
||||
--engine-version 15 \
|
||||
--master-username docmaster \
|
||||
--master-user-password YOUR_PASSWORD \
|
||||
--allocated-storage 20
|
||||
|
||||
# 配置安全组
|
||||
aws ec2 authorize-security-group-ingress \
|
||||
--group-id sg-xxx \
|
||||
--protocol tcp \
|
||||
--port 5432 \
|
||||
--source-group sg-yyy
|
||||
```
|
||||
|
||||
### 阶段 3: 推理服务部署
|
||||
|
||||
**方案 A: Lambda**
|
||||
|
||||
```bash
|
||||
# 创建 Lambda Layer (依赖)
|
||||
cd lambda-layer
|
||||
pip install ultralytics opencv-python-headless -t python/
|
||||
zip -r layer.zip python/
|
||||
aws lambda publish-layer-version \
|
||||
--layer-name yolo-deps \
|
||||
--zip-file fileb://layer.zip \
|
||||
--compatible-runtimes python3.11
|
||||
|
||||
# 部署 Lambda 函数
|
||||
cd ../lambda
|
||||
zip function.zip lambda_function.py
|
||||
aws lambda create-function \
|
||||
--function-name invoice-inference \
|
||||
--runtime python3.11 \
|
||||
--handler lambda_function.lambda_handler \
|
||||
--role arn:aws:iam::123456789012:role/LambdaRole \
|
||||
--zip-file fileb://function.zip \
|
||||
--memory-size 4096 \
|
||||
--timeout 30 \
|
||||
--layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||
|
||||
# 创建 API Gateway
|
||||
aws apigatewayv2 create-api \
|
||||
--name invoice-api \
|
||||
--protocol-type HTTP \
|
||||
--target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference
|
||||
```
|
||||
|
||||
**方案 B: ECS Fargate**
|
||||
|
||||
```bash
|
||||
# 创建 ECR 仓库
|
||||
aws ecr create-repository --repository-name invoice-inference
|
||||
|
||||
# 构建并推送镜像
|
||||
aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com
|
||||
docker build -t invoice-inference .
|
||||
docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||
docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||
|
||||
# 创建 ECS 集群
|
||||
aws ecs create-cluster --cluster-name invoice-cluster
|
||||
|
||||
# 注册任务定义
|
||||
aws ecs register-task-definition --cli-input-json file://task-definition.json
|
||||
|
||||
# 创建服务
|
||||
aws ecs create-service \
|
||||
--cluster invoice-cluster \
|
||||
--service-name invoice-service \
|
||||
--task-definition invoice-inference \
|
||||
--desired-count 1 \
|
||||
--launch-type FARGATE \
|
||||
--network-configuration '{
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": ["subnet-xxx"],
|
||||
"securityGroups": ["sg-xxx"],
|
||||
"assignPublicIp": "ENABLED"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 阶段 4: 训练服务设置
|
||||
|
||||
```python
|
||||
# setup_sagemaker.py
|
||||
import boto3
|
||||
import sagemaker
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
# 创建 SageMaker 执行角色
|
||||
iam = boto3.client('iam')
|
||||
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
|
||||
|
||||
# 配置训练任务
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src/training",
|
||||
role=role_arn,
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
)
|
||||
|
||||
# 保存配置供后续使用
|
||||
estimator.save("training_config.json")
|
||||
```
|
||||
|
||||
### 阶段 5: 集成训练触发 API
|
||||
|
||||
```python
|
||||
# lambda_trigger_training.py
|
||||
import boto3
|
||||
import sagemaker
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
def lambda_handler(event, context):
|
||||
"""触发 SageMaker 训练任务"""
|
||||
|
||||
epochs = event.get('epochs', 100)
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="s3://invoice-training-data/code/",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
hyperparameters={
|
||||
"epochs": epochs,
|
||||
"batch-size": 16,
|
||||
}
|
||||
)
|
||||
|
||||
estimator.fit(
|
||||
inputs={
|
||||
"training": "s3://invoice-training-data/datasets/train/",
|
||||
"validation": "s3://invoice-training-data/datasets/val/"
|
||||
},
|
||||
wait=False # 异步执行
|
||||
)
|
||||
|
||||
return {
|
||||
'statusCode': 200,
|
||||
'body': {
|
||||
'training_job_name': estimator.latest_training_job.name,
|
||||
'status': 'Started'
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## AWS vs Azure 对比
|
||||
|
||||
### 服务对应关系
|
||||
|
||||
| 功能 | AWS | Azure |
|
||||
|------|-----|-------|
|
||||
| 对象存储 | S3 | Blob Storage |
|
||||
| 挂载工具 | Mountpoint for S3 | BlobFuse2 |
|
||||
| ML 平台 | SageMaker | Azure ML |
|
||||
| 容器服务 | ECS/Fargate | Container Apps |
|
||||
| Serverless | Lambda | Functions |
|
||||
| GPU VM | EC2 P3/G4dn | NC/ND 系列 |
|
||||
| 容器注册 | ECR | ACR |
|
||||
| 数据库 | RDS PostgreSQL | PostgreSQL Flexible |
|
||||
|
||||
### 价格对比
|
||||
|
||||
| 组件 | AWS | Azure |
|
||||
|------|-----|-------|
|
||||
| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 |
|
||||
| 数据库 | ~$15/月 | ~$25/月 |
|
||||
| 推理 (Serverless) | ~$15/月 | ~$30/月 |
|
||||
| 推理 (容器) | ~$30/月 | ~$30/月 |
|
||||
| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 |
|
||||
| **总计** | **~$35-50/月** | **~$65/月** |
|
||||
|
||||
### 优劣对比
|
||||
|
||||
| 方面 | AWS 优势 | Azure 优势 |
|
||||
|------|---------|-----------|
|
||||
| 价格 | Lambda 更便宜 | GPU Spot 更便宜 |
|
||||
| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 |
|
||||
| Serverless GPU | 无原生支持 | Container Apps GPU |
|
||||
| 文档 | 更丰富 | 中文文档更好 |
|
||||
| 生态 | 更大 | Office 365 集成 |
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 推荐配置
|
||||
|
||||
| 组件 | 推荐方案 | 月费估算 |
|
||||
|------|---------|---------|
|
||||
| 数据存储 | S3 Standard | ~$0.12 |
|
||||
| 数据库 | RDS db.t3.micro | ~$15 |
|
||||
| 推理服务 | Lambda 4GB | ~$15 |
|
||||
| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 |
|
||||
| ECR | 基本使用 | ~$1 |
|
||||
| **总计** | | **~$35/月** + 训练费 |
|
||||
|
||||
### 关键决策
|
||||
|
||||
| 场景 | 选择 |
|
||||
|------|------|
|
||||
| 最低成本 | Lambda + SageMaker Spot |
|
||||
| 稳定推理 | ECS Fargate |
|
||||
| GPU 推理 | ECS + EC2 GPU |
|
||||
| MLOps 集成 | SageMaker 全家桶 |
|
||||
|
||||
### 注意事项
|
||||
|
||||
1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决
|
||||
2. **Spot 中断**: 配置检查点,SageMaker 自动恢复
|
||||
3. **S3 传输**: 同区域免费,跨区域收费
|
||||
4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2
|
||||
5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本
|
||||
567
docs/azure-deployment-guide.md
Normal file
567
docs/azure-deployment-guide.md
Normal file
@@ -0,0 +1,567 @@
|
||||
# Azure 部署方案完整指南
|
||||
|
||||
## 目录
|
||||
- [核心问题](#核心问题)
|
||||
- [存储方案](#存储方案)
|
||||
- [训练方案](#训练方案)
|
||||
- [推理方案](#推理方案)
|
||||
- [价格对比](#价格对比)
|
||||
- [推荐架构](#推荐架构)
|
||||
- [实施步骤](#实施步骤)
|
||||
|
||||
---
|
||||
|
||||
## 核心问题
|
||||
|
||||
| 问题 | 答案 |
|
||||
|------|------|
|
||||
| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 |
|
||||
| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 |
|
||||
| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) |
|
||||
| VM 空闲时收费吗? | 收费,只要开机就按小时计费 |
|
||||
| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster |
|
||||
| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU |
|
||||
|
||||
---
|
||||
|
||||
## 存储方案
|
||||
|
||||
### Azure Blob Storage + BlobFuse2(推荐)
|
||||
|
||||
```bash
|
||||
# 安装 BlobFuse2
|
||||
sudo apt-get install blobfuse2
|
||||
|
||||
# 配置文件
|
||||
cat > ~/blobfuse-config.yaml << 'EOF'
|
||||
logging:
|
||||
type: syslog
|
||||
level: log_warning
|
||||
|
||||
components:
|
||||
- libfuse
|
||||
- file_cache
|
||||
- azstorage
|
||||
|
||||
file_cache:
|
||||
path: /tmp/blobfuse2
|
||||
timeout-sec: 120
|
||||
max-size-mb: 4096
|
||||
|
||||
azstorage:
|
||||
type: block
|
||||
account-name: YOUR_ACCOUNT
|
||||
account-key: YOUR_KEY
|
||||
container: training-images
|
||||
EOF
|
||||
|
||||
# 挂载
|
||||
mkdir -p /mnt/azure-blob
|
||||
blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml
|
||||
```
|
||||
|
||||
### 本地开发(Windows)
|
||||
|
||||
```powershell
|
||||
# 安装
|
||||
winget install WinFsp.WinFsp
|
||||
winget install Rclone.Rclone
|
||||
|
||||
# 配置
|
||||
rclone config # 选择 azureblob
|
||||
|
||||
# 挂载为 Z: 盘
|
||||
rclone mount azure:training-images Z: --vfs-cache-mode full
|
||||
```
|
||||
|
||||
### 存储费用
|
||||
|
||||
| 层级 | 价格 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| Hot | $0.018/GB/月 | 频繁访问 |
|
||||
| Cool | $0.01/GB/月 | 偶尔访问 |
|
||||
| Archive | $0.002/GB/月 | 长期存档 |
|
||||
|
||||
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月**
|
||||
|
||||
---
|
||||
|
||||
## 训练方案
|
||||
|
||||
### 方案总览
|
||||
|
||||
| 方案 | 适用场景 | 空闲费用 | 复杂度 |
|
||||
|------|---------|---------|--------|
|
||||
| Azure VM | 简单直接 | 24/7 收费 | 低 |
|
||||
| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 |
|
||||
| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 |
|
||||
| Container Apps GPU | Serverless | 自动缩到 0 | 中 |
|
||||
|
||||
### Azure VM vs Azure ML
|
||||
|
||||
| 特性 | Azure VM | Azure ML |
|
||||
|------|----------|----------|
|
||||
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||
| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) |
|
||||
| 附加费用 | ~$5/月 | ~$20-30/月 |
|
||||
| 实验跟踪 | 无 | 内置 |
|
||||
| 自动扩缩 | 无 | 支持 min=0 |
|
||||
| 适用人群 | DevOps | 数据科学家 |
|
||||
|
||||
### Azure ML 附加费用明细
|
||||
|
||||
| 服务 | 用途 | 费用 |
|
||||
|------|------|------|
|
||||
| Container Registry | Docker 镜像 | ~$5-20/月 |
|
||||
| Blob Storage | 日志、模型 | ~$0.10/月 |
|
||||
| Application Insights | 监控 | ~$0-10/月 |
|
||||
| Key Vault | 密钥管理 | <$1/月 |
|
||||
|
||||
### Spot 实例
|
||||
|
||||
两种平台都支持 Spot/低优先级实例,最高节省 90%:
|
||||
|
||||
| 类型 | 正常价格 | Spot 价格 | 节省 |
|
||||
|------|---------|----------|------|
|
||||
| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% |
|
||||
| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% |
|
||||
|
||||
### GPU 实例价格
|
||||
|
||||
| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 |
|
||||
|------|-----|------|---------|----------|
|
||||
| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 |
|
||||
| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 |
|
||||
| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 |
|
||||
| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 |
|
||||
|
||||
---
|
||||
|
||||
## 推理方案
|
||||
|
||||
### 方案对比
|
||||
|
||||
| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 |
|
||||
|------|---------|--------|------|---------|
|
||||
| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) |
|
||||
| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 |
|
||||
| Azure App Service | 否 | 手动/自动 | ~$50/月 | 简单部署 |
|
||||
| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 |
|
||||
| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 |
|
||||
|
||||
### 推荐: Container Apps (CPU)
|
||||
|
||||
对于 YOLO 推理,**CPU 足够**,不需要 GPU:
|
||||
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
|
||||
- 比 GPU 便宜很多,适合中低流量
|
||||
|
||||
```yaml
|
||||
# Container Apps 配置
|
||||
name: invoice-inference
|
||||
image: myacr.azurecr.io/invoice-inference:v1
|
||||
resources:
|
||||
cpu: 2.0
|
||||
memory: 4Gi
|
||||
scale:
|
||||
minReplicas: 1 # 最少 1 个实例保持响应
|
||||
maxReplicas: 10 # 最多扩展到 10 个
|
||||
rules:
|
||||
- name: http-scaling
|
||||
http:
|
||||
metadata:
|
||||
concurrentRequests: "50" # 每实例 50 并发时扩容
|
||||
```
|
||||
|
||||
### 推理服务代码示例
|
||||
|
||||
```python
|
||||
# Dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制代码和模型
|
||||
COPY src/ ./src/
|
||||
COPY models/best.pt ./models/
|
||||
|
||||
# 启动服务
|
||||
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
```python
|
||||
# src/web/app.py
|
||||
from fastapi import FastAPI, UploadFile, File
|
||||
from ultralytics import YOLO
|
||||
import tempfile
|
||||
|
||||
app = FastAPI()
|
||||
model = YOLO("models/best.pt")
|
||||
|
||||
@app.post("/api/v1/infer")
|
||||
async def infer(file: UploadFile = File(...)):
|
||||
# 保存上传文件
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
# 执行推理
|
||||
results = model.predict(tmp_path, conf=0.5)
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
"fields": extract_fields(results),
|
||||
"confidence": get_confidence(results)
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
```
|
||||
|
||||
### 部署命令
|
||||
|
||||
```bash
|
||||
# 1. 创建 Container Registry
|
||||
az acr create --name invoiceacr --resource-group myRG --sku Basic
|
||||
|
||||
# 2. 构建并推送镜像
|
||||
az acr build --registry invoiceacr --image invoice-inference:v1 .
|
||||
|
||||
# 3. 创建 Container Apps 环境
|
||||
az containerapp env create \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--location eastus
|
||||
|
||||
# 4. 部署应用
|
||||
az containerapp create \
|
||||
--name invoice-inference \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference:v1 \
|
||||
--registry-server invoiceacr.azurecr.io \
|
||||
--cpu 2 --memory 4Gi \
|
||||
--min-replicas 1 --max-replicas 10 \
|
||||
--ingress external --target-port 8000
|
||||
|
||||
# 5. 获取 URL
|
||||
az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn
|
||||
```
|
||||
|
||||
### 高吞吐场景: Serverless GPU
|
||||
|
||||
如果需要 GPU 加速推理(高并发、低延迟):
|
||||
|
||||
```bash
|
||||
# 请求 GPU 配额
|
||||
az containerapp env workload-profile add \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--workload-profile-name gpu \
|
||||
--workload-profile-type Consumption-GPU-T4
|
||||
|
||||
# 部署 GPU 版本
|
||||
az containerapp create \
|
||||
--name invoice-inference-gpu \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \
|
||||
--workload-profile-name gpu \
|
||||
--cpu 4 --memory 8Gi \
|
||||
--min-replicas 0 --max-replicas 5 \
|
||||
--ingress external --target-port 8000
|
||||
```
|
||||
|
||||
### 推理性能对比
|
||||
|
||||
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|
||||
|------|------------|---------|---------|
|
||||
| CPU 2核 4GB | ~300-500ms | ~50 QPS | ~$30 |
|
||||
| CPU 4核 8GB | ~200-300ms | ~100 QPS | ~$60 |
|
||||
| GPU T4 | ~50-100ms | ~200 QPS | 按秒计费 |
|
||||
| GPU A100 | ~20-50ms | ~500 QPS | 按秒计费 |
|
||||
|
||||
---
|
||||
|
||||
## 价格对比
|
||||
|
||||
### 月度成本对比(假设每天训练 2 小时)
|
||||
|
||||
| 方案 | 计算方式 | 月费 |
|
||||
|------|---------|------|
|
||||
| VM 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
|
||||
| VM 按需启停 | 2h × 30天 × $3.06 | ~$184 |
|
||||
| VM Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
|
||||
| Serverless GPU | 2h × 30天 × ~$3.50 | ~$210 |
|
||||
| Azure ML (min=0) | 2h × 30天 × $3.06 | ~$184 |
|
||||
|
||||
### 本项目完整成本估算
|
||||
|
||||
| 组件 | 推荐方案 | 月费 |
|
||||
|------|---------|------|
|
||||
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
|
||||
| 数据库 | PostgreSQL Flexible (Burstable B1ms) | ~$25 |
|
||||
| 推理服务 | Container Apps CPU (2核4GB) | ~$30 |
|
||||
| 训练服务 | Azure ML Spot (按需) | ~$1-5/次 |
|
||||
| Container Registry | Basic | ~$5 |
|
||||
| **总计** | | **~$65/月** + 训练费 |
|
||||
|
||||
---
|
||||
|
||||
## 推荐架构
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Azure Blob Storage │
|
||||
│ ├── training-images/ │
|
||||
│ ├── datasets/ │
|
||||
│ └── models/ │
|
||||
└─────────────────┬───────────────────┘
|
||||
│
|
||||
┌─────────────────────────────────┼─────────────────────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
|
||||
│ 推理服务 (24/7) │ │ 训练服务 (按需) │ │ Web UI (可选) │
|
||||
│ Container Apps │ │ Azure ML Compute │ │ Static Web Apps │
|
||||
│ CPU 2核 4GB │ │ min=0, Spot │ │ ~$0 (免费层) │
|
||||
│ ~$30/月 │ │ ~$1-5/次训练 │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │
|
||||
│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │
|
||||
│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │
|
||||
│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │
|
||||
└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘
|
||||
│ │ │
|
||||
└───────────────────────────────┼───────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────┐
|
||||
│ PostgreSQL │
|
||||
│ Flexible Server │
|
||||
│ Burstable B1ms │
|
||||
│ ~$25/月 │
|
||||
└───────────────────────┘
|
||||
```
|
||||
|
||||
### 推理服务配置
|
||||
|
||||
```yaml
|
||||
# Container Apps - CPU (24/7 运行)
|
||||
name: invoice-inference
|
||||
resources:
|
||||
cpu: 2
|
||||
memory: 4Gi
|
||||
scale:
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
env:
|
||||
- name: MODEL_PATH
|
||||
value: /app/models/best.pt
|
||||
- name: DB_HOST
|
||||
secretRef: db-host
|
||||
- name: DB_PASSWORD
|
||||
secretRef: db-password
|
||||
```
|
||||
|
||||
### 训练服务配置
|
||||
|
||||
**方案 A: Azure ML Compute(推荐)**
|
||||
|
||||
```python
|
||||
from azure.ai.ml.entities import AmlCompute
|
||||
|
||||
gpu_cluster = AmlCompute(
|
||||
name="gpu-cluster",
|
||||
size="Standard_NC6s_v3",
|
||||
min_instances=0, # 空闲时关机
|
||||
max_instances=1,
|
||||
tier="LowPriority", # Spot 实例
|
||||
idle_time_before_scale_down=120
|
||||
)
|
||||
```
|
||||
|
||||
**方案 B: Container Apps Serverless GPU**
|
||||
|
||||
```yaml
|
||||
name: invoice-training
|
||||
resources:
|
||||
gpu: 1
|
||||
gpuType: A100
|
||||
scale:
|
||||
minReplicas: 0
|
||||
maxReplicas: 1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实施步骤
|
||||
|
||||
### 阶段 1: 存储设置
|
||||
|
||||
```bash
|
||||
# 创建 Storage Account
|
||||
az storage account create \
|
||||
--name invoicestorage \
|
||||
--resource-group myRG \
|
||||
--sku Standard_LRS
|
||||
|
||||
# 创建容器
|
||||
az storage container create --name training-images --account-name invoicestorage
|
||||
az storage container create --name datasets --account-name invoicestorage
|
||||
az storage container create --name models --account-name invoicestorage
|
||||
|
||||
# 上传训练数据
|
||||
az storage blob upload-batch \
|
||||
--destination training-images \
|
||||
--source ./data/dataset/temp \
|
||||
--account-name invoicestorage
|
||||
```
|
||||
|
||||
### 阶段 2: 数据库设置
|
||||
|
||||
```bash
|
||||
# 创建 PostgreSQL
|
||||
az postgres flexible-server create \
|
||||
--name invoice-db \
|
||||
--resource-group myRG \
|
||||
--sku-name Standard_B1ms \
|
||||
--storage-size 32 \
|
||||
--admin-user docmaster \
|
||||
--admin-password YOUR_PASSWORD
|
||||
|
||||
# 配置防火墙
|
||||
az postgres flexible-server firewall-rule create \
|
||||
--name allow-azure \
|
||||
--resource-group myRG \
|
||||
--server-name invoice-db \
|
||||
--start-ip-address 0.0.0.0 \
|
||||
--end-ip-address 0.0.0.0
|
||||
```
|
||||
|
||||
### 阶段 3: 推理服务部署
|
||||
|
||||
```bash
|
||||
# 创建 Container Registry
|
||||
az acr create --name invoiceacr --resource-group myRG --sku Basic
|
||||
|
||||
# 构建镜像
|
||||
az acr build --registry invoiceacr --image invoice-inference:v1 .
|
||||
|
||||
# 创建环境
|
||||
az containerapp env create \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--location eastus
|
||||
|
||||
# 部署推理服务
|
||||
az containerapp create \
|
||||
--name invoice-inference \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference:v1 \
|
||||
--registry-server invoiceacr.azurecr.io \
|
||||
--cpu 2 --memory 4Gi \
|
||||
--min-replicas 1 --max-replicas 10 \
|
||||
--ingress external --target-port 8000 \
|
||||
--env-vars \
|
||||
DB_HOST=invoice-db.postgres.database.azure.com \
|
||||
DB_NAME=docmaster \
|
||||
DB_USER=docmaster \
|
||||
--secrets db-password=YOUR_PASSWORD
|
||||
```
|
||||
|
||||
### 阶段 4: 训练服务设置
|
||||
|
||||
```bash
|
||||
# 创建 Azure ML Workspace
|
||||
az ml workspace create --name invoice-ml --resource-group myRG
|
||||
|
||||
# 创建 Compute Cluster
|
||||
az ml compute create --name gpu-cluster \
|
||||
--type AmlCompute \
|
||||
--size Standard_NC6s_v3 \
|
||||
--min-instances 0 \
|
||||
--max-instances 1 \
|
||||
--tier low_priority
|
||||
```
|
||||
|
||||
### 阶段 5: 集成训练触发 API
|
||||
|
||||
```python
|
||||
# src/web/routes/training.py
|
||||
from fastapi import APIRouter
|
||||
from azure.ai.ml import MLClient, command
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ml_client = MLClient(
|
||||
credential=DefaultAzureCredential(),
|
||||
subscription_id="your-subscription-id",
|
||||
resource_group_name="myRG",
|
||||
workspace_name="invoice-ml"
|
||||
)
|
||||
|
||||
@router.post("/api/v1/train")
|
||||
async def trigger_training(request: TrainingRequest):
|
||||
"""触发 Azure ML 训练任务"""
|
||||
training_job = command(
|
||||
code="./training",
|
||||
command=f"python train.py --epochs {request.epochs}",
|
||||
environment="AzureML-pytorch-2.0-cuda11.8@latest",
|
||||
compute="gpu-cluster",
|
||||
)
|
||||
job = ml_client.jobs.create_or_update(training_job)
|
||||
return {
|
||||
"job_id": job.name,
|
||||
"status": job.status,
|
||||
"studio_url": job.studio_url
|
||||
}
|
||||
|
||||
@router.get("/api/v1/train/{job_id}/status")
|
||||
async def get_training_status(job_id: str):
|
||||
"""查询训练状态"""
|
||||
job = ml_client.jobs.get(job_id)
|
||||
return {"status": job.status}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 推荐配置
|
||||
|
||||
| 组件 | 推荐方案 | 月费估算 |
|
||||
|------|---------|---------|
|
||||
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
|
||||
| 数据库 | PostgreSQL Flexible | ~$25 |
|
||||
| 推理服务 | Container Apps CPU | ~$30 |
|
||||
| 训练服务 | Azure ML (min=0, Spot) | 按需 ~$1-5/次 |
|
||||
| Container Registry | Basic | ~$5 |
|
||||
| **总计** | | **~$65/月** + 训练费 |
|
||||
|
||||
### 关键决策
|
||||
|
||||
| 场景 | 选择 |
|
||||
|------|------|
|
||||
| 偶尔训练,简单需求 | Azure VM Spot + 手动启停 |
|
||||
| 需要 MLOps,团队协作 | Azure ML Compute |
|
||||
| 追求最低空闲成本 | Container Apps Serverless GPU |
|
||||
| 生产环境推理 | Container Apps CPU |
|
||||
| 高并发推理 | Container Apps Serverless GPU |
|
||||
|
||||
### 注意事项
|
||||
|
||||
1. **冷启动**: Serverless GPU 启动需要 3-8 分钟
|
||||
2. **Spot 中断**: 可能被抢占,需要检查点机制
|
||||
3. **网络延迟**: Blob Storage 挂载比本地 SSD 慢,建议开启缓存
|
||||
4. **区域选择**: 选择有 GPU 配额的区域 (East US, West Europe 等)
|
||||
5. **推理优化**: CPU 推理对于 YOLO 已经足够,无需 GPU
|
||||
647
docs/dashboard-design-spec.md
Normal file
647
docs/dashboard-design-spec.md
Normal file
@@ -0,0 +1,647 @@
|
||||
# Dashboard Design Specification
|
||||
|
||||
## Overview
|
||||
|
||||
Dashboard 是用户进入系统后的第一个页面,用于快速了解:
|
||||
- 数据标注质量和进度
|
||||
- 当前模型状态和性能
|
||||
- 系统最近发生的活动
|
||||
|
||||
**目标用户**:使用文档标注系统的客户,需要监控文档处理状态、标注质量和模型训练进度。
|
||||
|
||||
---
|
||||
|
||||
## 1. UI Layout
|
||||
|
||||
### 1.1 Overall Structure
|
||||
|
||||
```
|
||||
+------------------------------------------------------------------+
|
||||
| Header: Logo + Navigation + User Menu |
|
||||
+------------------------------------------------------------------+
|
||||
| |
|
||||
| Stats Cards Row (4 cards, equal width) |
|
||||
| |
|
||||
| +---------------------------+ +------------------------------+ |
|
||||
| | Data Quality Panel (50%) | | Active Model Panel (50%) | |
|
||||
| +---------------------------+ +------------------------------+ |
|
||||
| |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| | Recent Activity Panel (full width) | |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| | System Status Bar (full width) | |
|
||||
| +--------------------------------------------------------------+ |
|
||||
+------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 1.2 Responsive Breakpoints
|
||||
|
||||
| Breakpoint | Layout |
|
||||
|------------|--------|
|
||||
| Desktop (>1200px) | 4 cards row, 2-column panels |
|
||||
| Tablet (768-1200px) | 2x2 cards, 2-column panels |
|
||||
| Mobile (<768px) | 1 card per row, stacked panels |
|
||||
|
||||
---
|
||||
|
||||
## 2. Component Specifications
|
||||
|
||||
### 2.1 Stats Cards Row
|
||||
|
||||
4 个等宽卡片,显示核心统计数据。
|
||||
|
||||
```
|
||||
+-------------+ +-------------+ +-------------+ +-------------+
|
||||
| [icon] | | [icon] | | [icon] | | [icon] |
|
||||
| 38 | | 25 | | 8 | | 5 |
|
||||
| Total Docs | | Complete | | Incomplete | | Pending |
|
||||
+-------------+ +-------------+ +-------------+ +-------------+
|
||||
```
|
||||
|
||||
| Card | Icon | Value | Label | Color | Click Action |
|
||||
|------|------|-------|-------|-------|--------------|
|
||||
| Total Documents | FileText | `total_documents` | "Total Documents" | Gray | Navigate to Documents page |
|
||||
| Complete | CheckCircle | `annotation_complete` | "Complete" | Green | Navigate to Documents (filter: complete) |
|
||||
| Incomplete | AlertCircle | `annotation_incomplete` | "Incomplete" | Orange | Navigate to Documents (filter: incomplete) |
|
||||
| Pending | Clock | `pending` | "Pending" | Blue | Navigate to Documents (filter: pending) |
|
||||
|
||||
**Card Design:**
|
||||
- Background: White with subtle border
|
||||
- Icon: 24px, positioned top-left
|
||||
- Value: 32px bold font
|
||||
- Label: 14px muted color
|
||||
- Hover: Slight shadow elevation
|
||||
- Padding: 16px
|
||||
|
||||
### 2.2 Data Quality Panel
|
||||
|
||||
左侧面板,显示标注完整度和质量指标。
|
||||
|
||||
```
|
||||
+---------------------------+
|
||||
| DATA QUALITY |
|
||||
| +-----------+ |
|
||||
| | | |
|
||||
| | 78% | Annotation |
|
||||
| | | Complete |
|
||||
| +-----------+ |
|
||||
| |
|
||||
| Complete: 25 |
|
||||
| Incomplete: 8 |
|
||||
| Pending: 5 |
|
||||
| |
|
||||
| [View Incomplete Docs] |
|
||||
+---------------------------+
|
||||
```
|
||||
|
||||
**Components:**
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Title | "DATA QUALITY", 14px uppercase, muted |
|
||||
| Progress Ring | 120px diameter, stroke width 12px |
|
||||
| Percentage | 36px bold, centered in ring |
|
||||
| Label | "Annotation Complete", 14px, below ring |
|
||||
| Stats List | 14px, icon + label + value per row |
|
||||
| Action Button | Text button, primary color |
|
||||
|
||||
**Progress Ring Colors:**
|
||||
- Complete portion: Green (#22C55E)
|
||||
- Remaining: Gray (#E5E7EB)
|
||||
|
||||
**Completeness Calculation:**
|
||||
```
|
||||
completeness_rate = annotation_complete / (annotation_complete + annotation_incomplete) * 100
|
||||
```
|
||||
|
||||
### 2.3 Active Model Panel
|
||||
|
||||
右侧面板,显示当前生产模型信息。
|
||||
|
||||
```
|
||||
+-------------------------------+
|
||||
| ACTIVE MODEL |
|
||||
| |
|
||||
| v1.2.0 - Invoice Model |
|
||||
| ----------------------------- |
|
||||
| |
|
||||
| mAP Precision Recall |
|
||||
| 95.1% 94% 92% |
|
||||
| |
|
||||
| Activated: 2024-01-20 |
|
||||
| Documents: 500 |
|
||||
| |
|
||||
| [Training] Run-2024-02 [====] |
|
||||
+-------------------------------+
|
||||
```
|
||||
|
||||
**Components:**
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Title | "ACTIVE MODEL", 14px uppercase, muted |
|
||||
| Version + Name | 18px bold (version) + 16px regular (name) |
|
||||
| Divider | 1px border, full width |
|
||||
| Metrics Row | 3 columns, equal width |
|
||||
| Metric Value | 24px bold |
|
||||
| Metric Label | 12px muted, below value |
|
||||
| Info Rows | 14px, label: value format |
|
||||
| Training Indicator | Shows when training is running |
|
||||
|
||||
**Metric Colors:**
|
||||
- mAP >= 90%: Green
|
||||
- mAP 80-90%: Yellow
|
||||
- mAP < 80%: Red
|
||||
|
||||
**Empty State (No Active Model):**
|
||||
```
|
||||
+-------------------------------+
|
||||
| ACTIVE MODEL |
|
||||
| |
|
||||
| [icon: Model] |
|
||||
| No Active Model |
|
||||
| |
|
||||
| Train and activate a |
|
||||
| model to see stats here |
|
||||
| |
|
||||
| [Go to Training] |
|
||||
+-------------------------------+
|
||||
```
|
||||
|
||||
**Training In Progress:**
|
||||
```
|
||||
| Training: Run-2024-02 |
|
||||
| [=========> ] 45% |
|
||||
| Started 2 hours ago |
|
||||
```
|
||||
|
||||
### 2.4 Recent Activity Panel
|
||||
|
||||
全宽面板,显示最近 10 条系统活动。
|
||||
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| RECENT ACTIVITY [See All] |
|
||||
+--------------------------------------------------------------+
|
||||
| [rocket] Activated model v1.2.0 2 hours ago|
|
||||
| [check] Training complete: Run-2024-01, mAP 95.1% yesterday|
|
||||
| [edit] Modified INV-001.pdf invoice_number yesterday|
|
||||
| [doc] Uploaded INV-005.pdf 2 days ago|
|
||||
| [doc] Uploaded INV-004.pdf 2 days ago|
|
||||
| [x] Training failed: Run-2024-00 3 days ago|
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
**Activity Item Layout:**
|
||||
|
||||
```
|
||||
[Icon] [Description] [Timestamp]
|
||||
```
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Icon | 16px, color based on type |
|
||||
| Description | 14px, truncate if too long |
|
||||
| Timestamp | 12px muted, right-aligned |
|
||||
| Row Height | 40px |
|
||||
| Hover | Background highlight |
|
||||
|
||||
**Activity Types and Icons:**
|
||||
|
||||
| Type | Icon | Color | Description Format |
|
||||
|------|------|-------|-------------------|
|
||||
| document_uploaded | FileText | Blue | "Uploaded {filename}" |
|
||||
| annotation_modified | Edit | Orange | "Modified {filename} {field_name}" |
|
||||
| training_completed | CheckCircle | Green | "Training complete: {task_name}, mAP {mAP}%" |
|
||||
| training_failed | XCircle | Red | "Training failed: {task_name}" |
|
||||
| model_activated | Rocket | Purple | "Activated model {version}" |
|
||||
|
||||
**Timestamp Formatting:**
|
||||
- < 1 minute: "just now"
|
||||
- < 1 hour: "{n} minutes ago"
|
||||
- < 24 hours: "{n} hours ago"
|
||||
- < 7 days: "yesterday" / "{n} days ago"
|
||||
- >= 7 days: "Jan 15" (date format)
|
||||
|
||||
**Empty State:**
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| RECENT ACTIVITY |
|
||||
| |
|
||||
| [icon: Activity] |
|
||||
| No recent activity |
|
||||
| |
|
||||
| Start by uploading documents or creating training jobs |
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 2.5 System Status Bar
|
||||
|
||||
底部状态栏,显示系统健康状态。
|
||||
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| Backend API: [*] Online Database: [*] Connected GPU: [*] Available |
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
| Status | Icon | Color |
|
||||
|--------|------|-------|
|
||||
| Online/Connected/Available | Filled circle | Green |
|
||||
| Degraded/Slow | Filled circle | Yellow |
|
||||
| Offline/Error/Unavailable | Filled circle | Red |
|
||||
|
||||
---
|
||||
|
||||
## 3. API Endpoints
|
||||
|
||||
### 3.1 Dashboard Statistics
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/stats
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"total_documents": 38,
|
||||
"annotation_complete": 25,
|
||||
"annotation_incomplete": 8,
|
||||
"pending": 5,
|
||||
"completeness_rate": 75.76
|
||||
}
|
||||
```
|
||||
|
||||
**Calculation Logic:**
|
||||
|
||||
```python
|
||||
# annotation_complete: labeled documents with core fields
|
||||
SELECT COUNT(*) FROM admin_documents d
|
||||
WHERE d.status = 'labeled'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM admin_annotations a
|
||||
WHERE a.document_id = d.document_id
|
||||
AND a.class_id IN (0, 3) -- invoice_number OR ocr_number
|
||||
)
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM admin_annotations a
|
||||
WHERE a.document_id = d.document_id
|
||||
AND a.class_id IN (4, 5) -- bankgiro OR plusgiro
|
||||
)
|
||||
|
||||
# annotation_incomplete: labeled but missing core fields
|
||||
SELECT COUNT(*) FROM admin_documents d
|
||||
WHERE d.status = 'labeled'
|
||||
AND NOT (/* above conditions */)
|
||||
|
||||
# pending: pending + auto_labeling
|
||||
SELECT COUNT(*) FROM admin_documents
|
||||
WHERE status IN ('pending', 'auto_labeling')
|
||||
```
|
||||
|
||||
### 3.2 Active Model Info
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/active-model
|
||||
```
|
||||
|
||||
**Response (with active model):**
|
||||
```json
|
||||
{
|
||||
"model": {
|
||||
"version_id": "uuid",
|
||||
"version": "1.2.0",
|
||||
"name": "Invoice Model",
|
||||
"metrics_mAP": 0.951,
|
||||
"metrics_precision": 0.94,
|
||||
"metrics_recall": 0.92,
|
||||
"document_count": 500,
|
||||
"activated_at": "2024-01-20T15:00:00Z"
|
||||
},
|
||||
"running_training": {
|
||||
"task_id": "uuid",
|
||||
"name": "Run-2024-02",
|
||||
"status": "running",
|
||||
"started_at": "2024-01-25T10:00:00Z",
|
||||
"progress": 45
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response (no active model):**
|
||||
```json
|
||||
{
|
||||
"model": null,
|
||||
"running_training": null
|
||||
}
|
||||
```
|
||||
|
||||
### 3.3 Recent Activity
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/activity?limit=10
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"activities": [
|
||||
{
|
||||
"type": "model_activated",
|
||||
"description": "Activated model v1.2.0",
|
||||
"timestamp": "2024-01-25T12:00:00Z",
|
||||
"metadata": {
|
||||
"version_id": "uuid",
|
||||
"version": "1.2.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "training_completed",
|
||||
"description": "Training complete: Run-2024-01, mAP 95.1%",
|
||||
"timestamp": "2024-01-24T18:30:00Z",
|
||||
"metadata": {
|
||||
"task_id": "uuid",
|
||||
"task_name": "Run-2024-01",
|
||||
"mAP": 0.951
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Activity Aggregation Query:**
|
||||
|
||||
```sql
|
||||
-- Union all activity sources, ordered by timestamp DESC, limit 10
|
||||
(
|
||||
SELECT 'document_uploaded' as type,
|
||||
filename as entity_name,
|
||||
created_at as timestamp,
|
||||
document_id as entity_id
|
||||
FROM admin_documents
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT 'annotation_modified' as type,
|
||||
-- join to get filename and field name
|
||||
...
|
||||
FROM annotation_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT CASE WHEN status = 'completed' THEN 'training_completed'
|
||||
WHEN status = 'failed' THEN 'training_failed' END as type,
|
||||
name as entity_name,
|
||||
completed_at as timestamp,
|
||||
task_id as entity_id
|
||||
FROM training_tasks
|
||||
WHERE status IN ('completed', 'failed')
|
||||
ORDER BY completed_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT 'model_activated' as type,
|
||||
version as entity_name,
|
||||
activated_at as timestamp,
|
||||
version_id as entity_id
|
||||
FROM model_versions
|
||||
WHERE activated_at IS NOT NULL
|
||||
ORDER BY activated_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. UX Interactions
|
||||
|
||||
### 4.1 Loading States
|
||||
|
||||
| Component | Loading State |
|
||||
|-----------|--------------|
|
||||
| Stats Cards | Skeleton placeholder (gray boxes) |
|
||||
| Data Quality Ring | Skeleton circle |
|
||||
| Active Model | Skeleton lines |
|
||||
| Recent Activity | Skeleton list items (5 rows) |
|
||||
|
||||
**Loading Duration Thresholds:**
|
||||
- < 300ms: No loading state shown
|
||||
- 300ms - 3s: Show skeleton
|
||||
- > 3s: Show skeleton + "Taking longer than expected" message
|
||||
|
||||
### 4.2 Error States
|
||||
|
||||
| Error Type | Display |
|
||||
|------------|---------|
|
||||
| API Error | Toast notification + retry button in affected panel |
|
||||
| Network Error | Full page overlay with retry option |
|
||||
| Partial Failure | Show available data, error badge on failed sections |
|
||||
|
||||
### 4.3 Refresh Behavior
|
||||
|
||||
| Trigger | Behavior |
|
||||
|---------|----------|
|
||||
| Page Load | Fetch all data |
|
||||
| Manual Refresh | Button in header, refetch all |
|
||||
| Auto Refresh | Every 30 seconds for activity panel |
|
||||
| Focus Return | Refetch if page was hidden > 5 minutes |
|
||||
|
||||
### 4.4 Click Actions
|
||||
|
||||
| Element | Action |
|
||||
|---------|--------|
|
||||
| Total Documents card | Navigate to `/documents` |
|
||||
| Complete card | Navigate to `/documents?filter=complete` |
|
||||
| Incomplete card | Navigate to `/documents?filter=incomplete` |
|
||||
| Pending card | Navigate to `/documents?filter=pending` |
|
||||
| "View Incomplete Docs" button | Navigate to `/documents?filter=incomplete` |
|
||||
| Activity item | Navigate to related entity |
|
||||
| "Go to Training" button | Navigate to `/training` |
|
||||
| Active Model version | Navigate to `/models/{version_id}` |
|
||||
|
||||
### 4.5 Tooltips
|
||||
|
||||
| Element | Tooltip Content |
|
||||
|---------|----------------|
|
||||
| Completeness % | "25 of 33 labeled documents have complete annotations" |
|
||||
| mAP metric | "Mean Average Precision at IoU 0.5" |
|
||||
| Precision metric | "Proportion of correct positive predictions" |
|
||||
| Recall metric | "Proportion of actual positives correctly identified" |
|
||||
| Incomplete count | "Documents labeled but missing invoice_number/ocr_number or bankgiro/plusgiro" |
|
||||
|
||||
---
|
||||
|
||||
## 5. Data Model
|
||||
|
||||
### 5.1 TypeScript Types
|
||||
|
||||
```typescript
|
||||
// Dashboard Stats
|
||||
interface DashboardStats {
|
||||
total_documents: number;
|
||||
annotation_complete: number;
|
||||
annotation_incomplete: number;
|
||||
pending: number;
|
||||
completeness_rate: number;
|
||||
}
|
||||
|
||||
// Active Model
|
||||
interface ActiveModelInfo {
|
||||
model: ModelVersion | null;
|
||||
running_training: RunningTraining | null;
|
||||
}
|
||||
|
||||
interface ModelVersion {
|
||||
version_id: string;
|
||||
version: string;
|
||||
name: string;
|
||||
metrics_mAP: number;
|
||||
metrics_precision: number;
|
||||
metrics_recall: number;
|
||||
document_count: number;
|
||||
activated_at: string;
|
||||
}
|
||||
|
||||
interface RunningTraining {
|
||||
task_id: string;
|
||||
name: string;
|
||||
status: 'running';
|
||||
started_at: string;
|
||||
progress: number;
|
||||
}
|
||||
|
||||
// Activity
|
||||
interface Activity {
|
||||
type: ActivityType;
|
||||
description: string;
|
||||
timestamp: string;
|
||||
metadata: Record<string, unknown>;
|
||||
}
|
||||
|
||||
type ActivityType =
|
||||
| 'document_uploaded'
|
||||
| 'annotation_modified'
|
||||
| 'training_completed'
|
||||
| 'training_failed'
|
||||
| 'model_activated';
|
||||
|
||||
// Activity Response
|
||||
interface ActivityResponse {
|
||||
activities: Activity[];
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 React Query Hooks
|
||||
|
||||
```typescript
|
||||
// useDashboardStats
|
||||
const useDashboardStats = () => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'stats'],
|
||||
queryFn: () => api.get('/admin/dashboard/stats'),
|
||||
refetchInterval: 30000, // 30 seconds
|
||||
});
|
||||
};
|
||||
|
||||
// useActiveModel
|
||||
const useActiveModel = () => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'active-model'],
|
||||
queryFn: () => api.get('/admin/dashboard/active-model'),
|
||||
refetchInterval: 60000, // 1 minute
|
||||
});
|
||||
};
|
||||
|
||||
// useRecentActivity
|
||||
const useRecentActivity = (limit = 10) => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'activity', limit],
|
||||
queryFn: () => api.get(`/admin/dashboard/activity?limit=${limit}`),
|
||||
refetchInterval: 30000,
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Annotation Completeness Definition
|
||||
|
||||
### 6.1 Core Fields
|
||||
|
||||
A document is **complete** when it has annotations for:
|
||||
|
||||
| Requirement | Fields | Logic |
|
||||
|-------------|--------|-------|
|
||||
| Identifier | `invoice_number` (class_id=0) OR `ocr_number` (class_id=3) | At least one |
|
||||
| Payment Account | `bankgiro` (class_id=4) OR `plusgiro` (class_id=5) | At least one |
|
||||
|
||||
### 6.2 Status Categories
|
||||
|
||||
| Category | Criteria |
|
||||
|----------|----------|
|
||||
| **Complete** | status=labeled AND has identifier AND has payment account |
|
||||
| **Incomplete** | status=labeled AND (missing identifier OR missing payment account) |
|
||||
| **Pending** | status IN (pending, auto_labeling) |
|
||||
|
||||
### 6.3 Filter Implementation
|
||||
|
||||
```sql
|
||||
-- Complete documents
|
||||
WHERE status = 'labeled'
|
||||
AND document_id IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
|
||||
)
|
||||
AND document_id IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
|
||||
)
|
||||
|
||||
-- Incomplete documents
|
||||
WHERE status = 'labeled'
|
||||
AND (
|
||||
document_id NOT IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
|
||||
)
|
||||
OR document_id NOT IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Implementation Checklist
|
||||
|
||||
### Backend
|
||||
- [ ] Create `/api/v1/admin/dashboard/stats` endpoint
|
||||
- [ ] Create `/api/v1/admin/dashboard/active-model` endpoint
|
||||
- [ ] Create `/api/v1/admin/dashboard/activity` endpoint
|
||||
- [ ] Add completeness calculation logic to document repository
|
||||
- [ ] Implement activity aggregation query
|
||||
|
||||
### Frontend
|
||||
- [ ] Create `DashboardOverview` component
|
||||
- [ ] Create `StatsCard` component
|
||||
- [ ] Create `DataQualityPanel` component with progress ring
|
||||
- [ ] Create `ActiveModelPanel` component
|
||||
- [ ] Create `RecentActivityPanel` component
|
||||
- [ ] Create `SystemStatusBar` component
|
||||
- [ ] Add React Query hooks for dashboard data
|
||||
- [ ] Implement loading skeletons
|
||||
- [ ] Implement error states
|
||||
- [ ] Add navigation actions
|
||||
- [ ] Add tooltips
|
||||
|
||||
### Testing
|
||||
- [ ] Unit tests for completeness calculation
|
||||
- [ ] Unit tests for activity aggregation
|
||||
- [ ] Integration tests for dashboard endpoints
|
||||
- [ ] E2E tests for dashboard interactions
|
||||
@@ -1,619 +0,0 @@
|
||||
# 多池处理架构设计文档
|
||||
|
||||
## 1. 研究总结
|
||||
|
||||
### 1.1 当前问题分析
|
||||
|
||||
我们之前实现的双池模式存在稳定性问题,主要原因:
|
||||
|
||||
| 问题 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| 处理卡住 | 线程 + ProcessPoolExecutor 混用导致死锁 | 使用 asyncio 或纯 Queue 模式 |
|
||||
| Queue.get() 无限阻塞 | 没有超时机制 | 添加 timeout 和哨兵值 |
|
||||
| GPU 内存冲突 | 多进程同时访问 GPU | 限制 GPU worker = 1 |
|
||||
| CUDA fork 问题 | Linux 默认 fork 不兼容 CUDA | 使用 spawn 启动方式 |
|
||||
|
||||
### 1.2 推荐架构方案
|
||||
|
||||
经过研究,最适合我们场景的方案是 **生产者-消费者队列模式**:
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Main Process │ │ CPU Workers │ │ GPU Worker │
|
||||
│ │ │ (4 processes) │ │ (1 process) │
|
||||
│ ┌───────────┐ │ │ │ │ │
|
||||
│ │ Task │──┼────▶│ Text PDF处理 │ │ Scanned PDF处理 │
|
||||
│ │ Dispatcher│ │ │ (无需OCR) │ │ (PaddleOCR) │
|
||||
│ └───────────┘ │ │ │ │ │
|
||||
│ ▲ │ │ │ │ │ │ │
|
||||
│ │ │ │ ▼ │ │ ▼ │
|
||||
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
|
||||
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
|
||||
│ │ Collector │ │ │ │ │ │
|
||||
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────┐ │
|
||||
│ │ Database │ │
|
||||
│ │ Batch │ │
|
||||
│ │ Writer │ │
|
||||
│ └───────────┘ │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心设计原则
|
||||
|
||||
### 2.1 CUDA 兼容性
|
||||
|
||||
```python
|
||||
# 关键:使用 spawn 启动方式
|
||||
import multiprocessing as mp
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
# GPU worker 初始化时设置设备
|
||||
def init_gpu_worker(gpu_id: int = 0):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
global _ocr
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=True, ...)
|
||||
```
|
||||
|
||||
### 2.2 Worker 初始化模式
|
||||
|
||||
使用 `initializer` 参数一次性加载模型,避免每个任务重新加载:
|
||||
|
||||
```python
|
||||
# 全局变量保存模型
|
||||
_ocr = None
|
||||
|
||||
def init_worker(use_gpu: bool, gpu_id: int = 0):
|
||||
global _ocr
|
||||
if use_gpu:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
|
||||
|
||||
# 创建 Pool 时使用 initializer
|
||||
pool = ProcessPoolExecutor(
|
||||
max_workers=1,
|
||||
initializer=init_worker,
|
||||
initargs=(True, 0), # use_gpu=True, gpu_id=0
|
||||
mp_context=mp.get_context("spawn")
|
||||
)
|
||||
```
|
||||
|
||||
### 2.3 队列模式 vs as_completed
|
||||
|
||||
| 方式 | 优点 | 缺点 | 适用场景 |
|
||||
|------|------|------|----------|
|
||||
| `as_completed()` | 简单、无需管理队列 | 无法跨多个 Pool 使用 | 单池场景 |
|
||||
| `multiprocessing.Queue` | 高性能、灵活 | 需要手动管理、死锁风险 | 多池流水线 |
|
||||
| `Manager().Queue()` | 可 pickle、跨 Pool | 性能较低 | 需要 Pool.map 场景 |
|
||||
|
||||
**推荐**:对于双池场景,使用 `as_completed()` 分别处理每个池,然后合并结果。
|
||||
|
||||
---
|
||||
|
||||
## 3. 详细开发计划
|
||||
|
||||
### 阶段 1:重构基础架构 (2-3天)
|
||||
|
||||
#### 1.1 创建 WorkerPool 抽象类
|
||||
|
||||
```python
|
||||
# src/processing/worker_pool.py
|
||||
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Optional, Callable
|
||||
import multiprocessing as mp
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""任务结果容器"""
|
||||
task_id: str
|
||||
success: bool
|
||||
data: Any
|
||||
error: Optional[str] = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
class WorkerPool(ABC):
|
||||
"""Worker Pool 抽象基类"""
|
||||
|
||||
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
|
||||
self.max_workers = max_workers
|
||||
self.use_gpu = use_gpu
|
||||
self.gpu_id = gpu_id
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_initializer(self) -> Callable:
|
||||
"""返回 worker 初始化函数"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_init_args(self) -> tuple:
|
||||
"""返回初始化参数"""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""启动 worker pool"""
|
||||
ctx = mp.get_context("spawn")
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=self.max_workers,
|
||||
mp_context=ctx,
|
||||
initializer=self.get_initializer(),
|
||||
initargs=self.get_init_args()
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, *args, **kwargs) -> Future:
|
||||
"""提交任务"""
|
||||
if not self._executor:
|
||||
raise RuntimeError("Pool not started")
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def shutdown(self, wait: bool = True):
|
||||
"""关闭 pool"""
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=wait)
|
||||
self._executor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.shutdown()
|
||||
```
|
||||
|
||||
#### 1.2 实现 CPU 和 GPU Worker Pool
|
||||
|
||||
```python
|
||||
# src/processing/cpu_pool.py
|
||||
|
||||
class CPUWorkerPool(WorkerPool):
|
||||
"""CPU-only worker pool for text PDF processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
super().__init__(max_workers=max_workers, use_gpu=False)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_cpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return ()
|
||||
|
||||
# src/processing/gpu_pool.py
|
||||
|
||||
class GPUWorkerPool(WorkerPool):
|
||||
"""GPU worker pool for OCR processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
|
||||
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_gpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return (self.gpu_id,)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 2:实现双池协调器 (2-3天)
|
||||
|
||||
#### 2.1 任务分发器
|
||||
|
||||
```python
|
||||
# src/processing/task_dispatcher.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List, Tuple
|
||||
|
||||
class TaskType(Enum):
|
||||
CPU = auto() # Text PDF
|
||||
GPU = auto() # Scanned PDF
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
id: str
|
||||
task_type: TaskType
|
||||
data: Any
|
||||
|
||||
class TaskDispatcher:
|
||||
"""根据 PDF 类型分发任务到不同的 pool"""
|
||||
|
||||
def classify_task(self, doc_info: dict) -> TaskType:
|
||||
"""判断文档是否需要 OCR"""
|
||||
# 基于 PDF 特征判断
|
||||
if self._is_scanned_pdf(doc_info):
|
||||
return TaskType.GPU
|
||||
return TaskType.CPU
|
||||
|
||||
def _is_scanned_pdf(self, doc_info: dict) -> bool:
|
||||
"""检测是否为扫描件"""
|
||||
# 1. 检查是否有可提取文本
|
||||
# 2. 检查图片比例
|
||||
# 3. 检查文本密度
|
||||
pass
|
||||
|
||||
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
|
||||
"""将任务分为 CPU 和 GPU 两组"""
|
||||
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
|
||||
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
|
||||
return cpu_tasks, gpu_tasks
|
||||
```
|
||||
|
||||
#### 2.2 双池协调器
|
||||
|
||||
```python
|
||||
# src/processing/dual_pool_coordinator.py
|
||||
|
||||
from concurrent.futures import as_completed
|
||||
from typing import List, Iterator
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DualPoolCoordinator:
|
||||
"""协调 CPU 和 GPU 两个 worker pool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_workers: int = 4,
|
||||
gpu_workers: int = 1,
|
||||
gpu_id: int = 0
|
||||
):
|
||||
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
|
||||
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
|
||||
self.dispatcher = TaskDispatcher()
|
||||
|
||||
def __enter__(self):
|
||||
self.cpu_pool.start()
|
||||
self.gpu_pool.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.cpu_pool.shutdown()
|
||||
self.gpu_pool.shutdown()
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
documents: List[dict],
|
||||
cpu_task_fn: Callable,
|
||||
gpu_task_fn: Callable,
|
||||
on_result: Optional[Callable[[TaskResult], None]] = None,
|
||||
on_error: Optional[Callable[[str, Exception], None]] = None
|
||||
) -> List[TaskResult]:
|
||||
"""
|
||||
处理一批文档,自动分发到 CPU 或 GPU pool
|
||||
|
||||
Args:
|
||||
documents: 待处理文档列表
|
||||
cpu_task_fn: CPU 任务处理函数
|
||||
gpu_task_fn: GPU 任务处理函数
|
||||
on_result: 结果回调(可选)
|
||||
on_error: 错误回调(可选)
|
||||
|
||||
Returns:
|
||||
所有任务结果列表
|
||||
"""
|
||||
# 分类任务
|
||||
tasks = [
|
||||
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
|
||||
for doc in documents
|
||||
]
|
||||
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
|
||||
|
||||
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
|
||||
|
||||
# 提交任务到各自的 pool
|
||||
cpu_futures = {
|
||||
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
|
||||
for t in cpu_tasks
|
||||
}
|
||||
gpu_futures = {
|
||||
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
|
||||
for t in gpu_tasks
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
results = []
|
||||
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
|
||||
|
||||
for future in as_completed(all_futures):
|
||||
task_id = cpu_futures.get(future) or gpu_futures.get(future)
|
||||
pool_type = "CPU" if future in cpu_futures else "GPU"
|
||||
|
||||
try:
|
||||
data = future.result(timeout=300) # 5分钟超时
|
||||
result = TaskResult(task_id=task_id, success=True, data=data)
|
||||
if on_result:
|
||||
on_result(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
|
||||
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
|
||||
if on_error:
|
||||
on_error(task_id, e)
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 3:集成到 autolabel (1-2天)
|
||||
|
||||
#### 3.1 修改 autolabel.py
|
||||
|
||||
```python
|
||||
# src/cli/autolabel.py
|
||||
|
||||
def run_autolabel_dual_pool(args):
|
||||
"""使用双池模式运行自动标注"""
|
||||
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator
|
||||
|
||||
# 初始化数据库批处理
|
||||
db_batch = []
|
||||
db_batch_size = 100
|
||||
|
||||
def on_result(result: TaskResult):
|
||||
"""处理成功结果"""
|
||||
nonlocal db_batch
|
||||
db_batch.append(result.data)
|
||||
|
||||
if len(db_batch) >= db_batch_size:
|
||||
save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
def on_error(task_id: str, error: Exception):
|
||||
"""处理错误"""
|
||||
logger.error(f"Task {task_id} failed: {error}")
|
||||
|
||||
# 创建双池协调器
|
||||
with DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers or 4,
|
||||
gpu_workers=args.gpu_workers or 1,
|
||||
gpu_id=0
|
||||
) as coordinator:
|
||||
|
||||
# 处理所有 CSV
|
||||
for csv_file in csv_files:
|
||||
documents = load_documents_from_csv(csv_file)
|
||||
|
||||
results = coordinator.process_batch(
|
||||
documents=documents,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=on_result,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
logger.info(f"CSV {csv_file}: {len(results)} processed")
|
||||
|
||||
# 保存剩余批次
|
||||
if db_batch:
|
||||
save_documents_batch(db_batch)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 4:测试与验证 (1-2天)
|
||||
|
||||
#### 4.1 单元测试
|
||||
|
||||
```python
|
||||
# tests/unit/test_dual_pool.py
|
||||
|
||||
import pytest
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult
|
||||
|
||||
class TestDualPoolCoordinator:
|
||||
|
||||
def test_cpu_only_batch(self):
|
||||
"""测试纯 CPU 任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 10
|
||||
assert all(r.success for r in results)
|
||||
|
||||
def test_mixed_batch(self):
|
||||
"""测试混合任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [
|
||||
{"id": "text_1", "type": "text"},
|
||||
{"id": "scan_1", "type": "scanned"},
|
||||
{"id": "text_2", "type": "text"},
|
||||
]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_timeout_handling(self):
|
||||
"""测试超时处理"""
|
||||
pass
|
||||
|
||||
def test_error_recovery(self):
|
||||
"""测试错误恢复"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### 4.2 集成测试
|
||||
|
||||
```python
|
||||
# tests/integration/test_autolabel_dual_pool.py
|
||||
|
||||
def test_autolabel_with_dual_pool():
|
||||
"""端到端测试双池模式"""
|
||||
# 使用少量测试数据
|
||||
result = subprocess.run([
|
||||
"python", "-m", "src.cli.autolabel",
|
||||
"--cpu-workers", "2",
|
||||
"--gpu-workers", "1",
|
||||
"--limit", "50"
|
||||
], capture_output=True)
|
||||
|
||||
assert result.returncode == 0
|
||||
# 验证数据库记录
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键技术点
|
||||
|
||||
### 4.1 避免死锁的策略
|
||||
|
||||
```python
|
||||
# 1. 使用 timeout
|
||||
try:
|
||||
result = future.result(timeout=300)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Task timed out")
|
||||
|
||||
# 2. 使用哨兵值
|
||||
SENTINEL = object()
|
||||
queue.put(SENTINEL) # 发送结束信号
|
||||
|
||||
# 3. 检查进程状态
|
||||
if not worker.is_alive():
|
||||
logger.error("Worker died unexpectedly")
|
||||
break
|
||||
|
||||
# 4. 先清空队列再 join
|
||||
while not queue.empty():
|
||||
results.append(queue.get_nowait())
|
||||
worker.join(timeout=5.0)
|
||||
```
|
||||
|
||||
### 4.2 PaddleOCR 特殊处理
|
||||
|
||||
```python
|
||||
# PaddleOCR 必须在 worker 进程中初始化
|
||||
def init_paddle_worker(gpu_id: int):
|
||||
global _ocr
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
# 延迟导入,确保 CUDA 环境变量生效
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang='en',
|
||||
use_gpu=True,
|
||||
show_log=False,
|
||||
# 重要:设置 GPU 内存比例
|
||||
gpu_mem=2000 # 限制 GPU 内存使用 (MB)
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3 资源监控
|
||||
|
||||
```python
|
||||
import psutil
|
||||
import GPUtil
|
||||
|
||||
def get_resource_usage():
|
||||
"""获取系统资源使用情况"""
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
gpu_info = []
|
||||
for gpu in GPUtil.getGPUs():
|
||||
gpu_info.append({
|
||||
"id": gpu.id,
|
||||
"memory_used": gpu.memoryUsed,
|
||||
"memory_total": gpu.memoryTotal,
|
||||
"utilization": gpu.load * 100
|
||||
})
|
||||
|
||||
return {
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_percent": memory.percent,
|
||||
"gpu": gpu_info
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 风险评估与应对
|
||||
|
||||
| 风险 | 可能性 | 影响 | 应对策略 |
|
||||
|------|--------|------|----------|
|
||||
| GPU 内存不足 | 中 | 高 | 限制 GPU worker = 1,设置 gpu_mem 参数 |
|
||||
| 进程僵死 | 低 | 高 | 添加心跳检测,超时自动重启 |
|
||||
| 任务分类错误 | 中 | 中 | 添加回退机制,CPU 失败后尝试 GPU |
|
||||
| 数据库写入瓶颈 | 低 | 中 | 增大批处理大小,异步写入 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 备选方案
|
||||
|
||||
如果上述方案仍存在问题,可以考虑:
|
||||
|
||||
### 6.1 使用 Ray
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def cpu_task(data):
|
||||
return process_text_pdf(data)
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def gpu_task(data):
|
||||
return process_scanned_pdf(data)
|
||||
|
||||
# 自动资源调度
|
||||
futures = [cpu_task.remote(d) for d in cpu_docs]
|
||||
futures += [gpu_task.remote(d) for d in gpu_docs]
|
||||
results = ray.get(futures)
|
||||
```
|
||||
|
||||
### 6.2 单池 + 动态 GPU 调度
|
||||
|
||||
保持单池模式,但在每个任务内部动态决定是否使用 GPU:
|
||||
|
||||
```python
|
||||
def process_document(doc_data):
|
||||
if is_scanned_pdf(doc_data):
|
||||
# 使用 GPU (需要全局锁或信号量控制并发)
|
||||
with gpu_semaphore:
|
||||
return process_with_ocr(doc_data)
|
||||
else:
|
||||
return process_text_only(doc_data)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 时间线总结
|
||||
|
||||
| 阶段 | 任务 | 预计工作量 |
|
||||
|------|------|------------|
|
||||
| 阶段 1 | 基础架构重构 | 2-3 天 |
|
||||
| 阶段 2 | 双池协调器实现 | 2-3 天 |
|
||||
| 阶段 3 | 集成到 autolabel | 1-2 天 |
|
||||
| 阶段 4 | 测试与验证 | 1-2 天 |
|
||||
| **总计** | | **6-10 天** |
|
||||
|
||||
---
|
||||
|
||||
## 8. 参考资料
|
||||
|
||||
1. [Python concurrent.futures 官方文档](https://docs.python.org/3/library/concurrent.futures.html)
|
||||
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
|
||||
3. [Super Fast Python - ProcessPoolExecutor 完整指南](https://superfastpython.com/processpoolexecutor-in-python/)
|
||||
4. [PaddleOCR 并行推理文档](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
|
||||
5. [AWS - 跨 CPU/GPU 并行化 ML 推理](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
|
||||
6. [Ray 分布式多进程处理](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)
|
||||
35
docs/product-plan-v2-CHANGELOG.md
Normal file
35
docs/product-plan-v2-CHANGELOG.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Product Plan v2 - Change Log
|
||||
|
||||
## [v2.1] - 2026-02-01
|
||||
|
||||
### New Features
|
||||
|
||||
#### Epic 7: Dashboard Enhancement
|
||||
- Added **US-7.1**: Data quality metrics panel showing annotation completeness rate
|
||||
- Added **US-7.2**: Active model status panel with mAP/precision/recall metrics
|
||||
- Added **US-7.3**: Recent activity feed showing last 10 system activities
|
||||
- Added **US-7.4**: Meaningful stats cards (Total/Complete/Incomplete/Pending)
|
||||
|
||||
#### Annotation Completeness Definition
|
||||
- Defined "annotation complete" criteria:
|
||||
- Must have `invoice_number` OR `ocr_number` (identifier)
|
||||
- Must have `bankgiro` OR `plusgiro` (payment account)
|
||||
|
||||
### New API Endpoints
|
||||
- Added `GET /api/v1/admin/dashboard/stats` - Dashboard statistics with completeness calculation
|
||||
- Added `GET /api/v1/admin/dashboard/active-model` - Active model info with running training status
|
||||
- Added `GET /api/v1/admin/dashboard/activity` - Recent activity feed aggregated from multiple sources
|
||||
|
||||
### New UI Components
|
||||
- Added **5.0 Dashboard Overview** wireframe with:
|
||||
- Stats cards row (Total/Complete/Incomplete/Pending)
|
||||
- Data Quality panel with percentage ring
|
||||
- Active Model panel with metrics display
|
||||
- Recent Activity list with icons and relative timestamps
|
||||
- System Status bar
|
||||
|
||||
---
|
||||
|
||||
## [v2.0] - 2024-01-15
|
||||
- Initial version with Epic 1-6
|
||||
- Batch upload, document management, annotation workflow, training management
|
||||
1448
docs/product-plan-v2.md
Normal file
1448
docs/product-plan-v2.md
Normal file
File diff suppressed because it is too large
Load Diff
54
docs/training-flow.mmd
Normal file
54
docs/training-flow.mmd
Normal file
@@ -0,0 +1,54 @@
|
||||
flowchart TD
|
||||
A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.]
|
||||
B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD]
|
||||
|
||||
C --> D[Load Data from DB\nsrc/yolo/db_dataset.py]
|
||||
D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs]
|
||||
D --> D2[Batch load field_results\nfrom database - batch 500]
|
||||
|
||||
D1 --> E[Create DBYOLODataset]
|
||||
D2 --> E
|
||||
|
||||
E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42]
|
||||
|
||||
F --> G[Export to YOLO Format]
|
||||
G --> G1[Copy images to\ntrain/val/test dirs]
|
||||
G --> G2[Generate .txt labels\nclass x_center y_center w h]
|
||||
G --> G3[Generate dataset.yaml\n+ classes.txt]
|
||||
G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1]
|
||||
|
||||
G1 --> H{--export-only?}
|
||||
G2 --> H
|
||||
G3 --> H
|
||||
G4 --> H
|
||||
|
||||
H -- Yes --> Z[Done - Dataset exported]
|
||||
H -- No --> I[Load YOLO Model]
|
||||
|
||||
I --> I1{--resume?}
|
||||
I1 -- Yes --> I2[Load last.pt checkpoint]
|
||||
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
|
||||
|
||||
I2 --> J[Configure Training]
|
||||
I3 --> J
|
||||
|
||||
J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup]
|
||||
J --> J2[imgsz=1280, pretrained=True]
|
||||
|
||||
J1 --> K[model.train\nUltralytics Training Loop]
|
||||
J2 --> K
|
||||
|
||||
K --> L[Training Outputs\nruns/train/name/]
|
||||
L --> L1[weights/best.pt\nweights/last.pt]
|
||||
L --> L2[results.csv + results.png\nTraining curves]
|
||||
L --> L3[PR curves, F1 curves\nConfusion matrix]
|
||||
|
||||
L1 --> M[Test Set Validation\nmodel.val split=test]
|
||||
M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95]
|
||||
|
||||
N --> O[Close DB Connection]
|
||||
|
||||
style A fill:#4a90d9,color:#fff
|
||||
style K fill:#e67e22,color:#fff
|
||||
style N fill:#27ae60,color:#fff
|
||||
style Z fill:#95a5a6,color:#fff
|
||||
302
docs/ux-design-prompt-v2.md
Normal file
302
docs/ux-design-prompt-v2.md
Normal file
@@ -0,0 +1,302 @@
|
||||
# Document Annotation Tool – UX Design Spec v2
|
||||
|
||||
## Theme: Warm Graphite (Modern Enterprise)
|
||||
|
||||
---
|
||||
|
||||
## 1. Design Principles (Updated)
|
||||
|
||||
1. **Clarity** – High contrast, but never pure black-on-white
|
||||
2. **Warm Neutrality** – Slightly warm grays reduce visual fatigue
|
||||
3. **Focus** – Content-first layouts with restrained accents
|
||||
4. **Consistency** – Reusable patterns, predictable behavior
|
||||
5. **Professional Trust** – Calm, serious, enterprise-ready
|
||||
6. **Longevity** – No trendy colors that age quickly
|
||||
|
||||
---
|
||||
|
||||
## 2. Color Palette (Warm Graphite)
|
||||
|
||||
### Core Colors
|
||||
|
||||
| Usage | Color Name | Hex |
|
||||
|------|-----------|-----|
|
||||
| Primary Text | Soft Black | #121212 |
|
||||
| Secondary Text | Charcoal Gray | #2A2A2A |
|
||||
| Muted Text | Warm Gray | #6B6B6B |
|
||||
| Disabled Text | Light Warm Gray | #9A9A9A |
|
||||
|
||||
### Backgrounds
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|-----|------|-----|
|
||||
| App Background | Paper White | #FAFAF8 |
|
||||
| Card / Panel | White | #FFFFFF |
|
||||
| Hover Surface | Subtle Warm Gray | #F1F0ED |
|
||||
| Selected Row | Very Light Warm Gray | #ECEAE6 |
|
||||
|
||||
### Borders & Dividers
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|------|------|-----|
|
||||
| Default Border | Warm Light Gray | #E6E4E1 |
|
||||
| Strong Divider | Neutral Gray | #D8D6D2 |
|
||||
|
||||
### Semantic States (Muted & Professional)
|
||||
|
||||
| State | Color | Hex |
|
||||
|------|-------|-----|
|
||||
| Success | Olive Gray | #3E4A3A |
|
||||
| Error | Brick Gray | #4A3A3A |
|
||||
| Warning | Sand Gray | #4A4A3A |
|
||||
| Info | Graphite Gray | #3A3A3A |
|
||||
|
||||
> Accent colors are **never saturated** and are used only for status, progress, or selection.
|
||||
|
||||
---
|
||||
|
||||
## 3. Typography
|
||||
|
||||
- **Font Family**: Inter / SF Pro / system-ui
|
||||
- **Headings**:
|
||||
- Weight: 600–700
|
||||
- Color: #121212
|
||||
- Letter spacing: -0.01em
|
||||
- **Body Text**:
|
||||
- Weight: 400
|
||||
- Color: #2A2A2A
|
||||
- **Captions / Meta**:
|
||||
- Weight: 400
|
||||
- Color: #6B6B6B
|
||||
- **Monospace (IDs / Values)**:
|
||||
- JetBrains Mono / SF Mono
|
||||
- Color: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 4. Global Layout
|
||||
|
||||
### Top Navigation Bar
|
||||
|
||||
- Height: 56px
|
||||
- Background: #FAFAF8
|
||||
- Bottom Border: 1px solid #E6E4E1
|
||||
- Logo: Text or icon in #121212
|
||||
|
||||
**Navigation Items**
|
||||
- Default: #6B6B6B
|
||||
- Hover: #2A2A2A
|
||||
- Active:
|
||||
- Text: #121212
|
||||
- Bottom indicator: 2px solid #3A3A3A (rounded ends)
|
||||
|
||||
**Avatar**
|
||||
- Circle background: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 5. Page: Documents (Dashboard)
|
||||
|
||||
### Page Header
|
||||
|
||||
- Title: "Documents" (#121212)
|
||||
- Actions:
|
||||
- Primary button: Dark graphite outline
|
||||
- Secondary button: Subtle border only
|
||||
|
||||
### Filters Bar
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border: 1px solid #E6E4E1
|
||||
- Inputs:
|
||||
- Background: #FFFFFF
|
||||
- Hover: #F1F0ED
|
||||
- Focus ring: 1px #3A3A3A
|
||||
|
||||
### Document Table
|
||||
|
||||
- Table background: #FFFFFF
|
||||
- Header text: #6B6B6B
|
||||
- Row hover: #F1F0ED
|
||||
- Row selected:
|
||||
- Background: #ECEAE6
|
||||
- Left indicator: 3px solid #3A3A3A
|
||||
|
||||
### Status Badges
|
||||
|
||||
- Pending:
|
||||
- BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Text: #2A2A2A
|
||||
|
||||
- Labeled:
|
||||
- BG: #2A2A2A
|
||||
- Text: #FFFFFF
|
||||
|
||||
- Exported:
|
||||
- BG: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
- Icon: ✓
|
||||
|
||||
### Auto-label States
|
||||
|
||||
- Running:
|
||||
- Progress bar: #3A3A3A on #ECEAE6
|
||||
- Completed:
|
||||
- Text: #3E4A3A
|
||||
- Failed:
|
||||
- BG: #F1EDED
|
||||
- Text: #4A3A3A
|
||||
|
||||
---
|
||||
|
||||
## 6. Upload Modals (Single & Batch)
|
||||
|
||||
### Modal Container
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border radius: 8px
|
||||
- Shadow: 0 1px 3px rgba(0,0,0,0.08)
|
||||
|
||||
### Drop Zone
|
||||
|
||||
- Background: #FAFAF8
|
||||
- Border: 1px dashed #D8D6D2
|
||||
- Hover: #F1F0ED
|
||||
- Icon: Graphite gray
|
||||
|
||||
### Form Fields
|
||||
|
||||
- Input BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Focus: 1px solid #3A3A3A
|
||||
|
||||
Primary Action Button:
|
||||
- Text: #FFFFFF
|
||||
- BG: #2A2A2A
|
||||
- Hover: #121212
|
||||
|
||||
---
|
||||
|
||||
## 7. Document Detail View
|
||||
|
||||
### Canvas Area
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Annotation styles:
|
||||
- Manual: Solid border #2A2A2A
|
||||
- Auto: Dashed border #6B6B6B
|
||||
- Selected: 2px border #3A3A3A + resize handles
|
||||
|
||||
### Right Info Panel
|
||||
|
||||
- Card background: #FFFFFF
|
||||
- Section headers: #121212
|
||||
- Meta text: #6B6B6B
|
||||
|
||||
### Annotation Table
|
||||
|
||||
- Same table styles as Documents
|
||||
- Inline edit:
|
||||
- Input background: #FAFAF8
|
||||
- Save button: Graphite
|
||||
|
||||
### Locked State (Auto-label Running)
|
||||
|
||||
- Banner BG: #FAFAF8
|
||||
- Border-left: 3px solid #4A4A3A
|
||||
- Progress bar: Graphite
|
||||
|
||||
---
|
||||
|
||||
## 8. Training Page
|
||||
|
||||
### Document Selector
|
||||
|
||||
- Selected rows use same highlight rules
|
||||
- Verified state:
|
||||
- Full: Olive gray check
|
||||
- Partial: Sand gray warning
|
||||
|
||||
### Configuration Panel
|
||||
|
||||
- Card layout
|
||||
- Inputs aligned to grid
|
||||
- Schedule option visually muted until enabled
|
||||
|
||||
Primary CTA:
|
||||
- Start Training button in dark graphite
|
||||
|
||||
---
|
||||
|
||||
## 9. Models & Training History
|
||||
|
||||
### Training Job List
|
||||
|
||||
- Job cards use #FFFFFF background
|
||||
- Running job:
|
||||
- Progress bar: #3A3A3A
|
||||
- Completed job:
|
||||
- Metrics bars in graphite
|
||||
|
||||
### Model Detail Panel
|
||||
|
||||
- Sectioned cards
|
||||
- Metric bars:
|
||||
- Track: #ECEAE6
|
||||
- Fill: #3A3A3A
|
||||
|
||||
Actions:
|
||||
- Primary: Download Model
|
||||
- Secondary: View Logs / Use as Base
|
||||
|
||||
---
|
||||
|
||||
## 10. Micro-interactions (Refined)
|
||||
|
||||
| Element | Interaction | Animation |
|
||||
|------|------------|-----------|
|
||||
| Button hover | BG lightens | 150ms ease-out |
|
||||
| Button press | Scale 0.98 | 100ms |
|
||||
| Row hover | BG fade | 120ms |
|
||||
| Modal open | Fade + scale 0.96 → 1 | 200ms |
|
||||
| Progress fill | Smooth | ease-out |
|
||||
| Annotation select | Border + handles | 120ms |
|
||||
|
||||
---
|
||||
|
||||
## 11. Tailwind Theme (Updated)
|
||||
|
||||
```js
|
||||
colors: {
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
bg: {
|
||||
app: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
},
|
||||
border: '#E6E4E1',
|
||||
accent: '#3A3A3A',
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Final Notes
|
||||
|
||||
- Pure black (#000000) should **never** be used as large surfaces
|
||||
- Accent color usage should stay under **10% of UI area**
|
||||
- Warm grays are intentional and must not be "corrected" to blue-grays
|
||||
|
||||
This theme is designed to scale from internal tool → polished SaaS without redesign.
|
||||
|
||||
273
docs/web-refactoring-complete.md
Normal file
273
docs/web-refactoring-complete.md
Normal file
@@ -0,0 +1,273 @@
|
||||
# Web Directory Refactoring - Complete ✅
|
||||
|
||||
**Date**: 2026-01-25
|
||||
**Status**: ✅ Completed
|
||||
**Tests**: 188 passing (0 failures)
|
||||
**Coverage**: 23% (maintained)
|
||||
|
||||
---
|
||||
|
||||
## Final Directory Structure
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── __init__.py
|
||||
│ └── v1/
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public inference API
|
||||
│ ├── admin/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # Document management (was admin_routes.py)
|
||||
│ │ ├── annotations.py # Annotation routes (was admin_annotation_routes.py)
|
||||
│ │ └── training.py # Training routes (was admin_training_routes.py)
|
||||
│ ├── async_api/
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # Async processing API (was async_routes.py)
|
||||
│ └── batch/
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # Batch upload API (was batch_upload_routes.py)
|
||||
│
|
||||
├── schemas/
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared models (ErrorResponse)
|
||||
│ ├── admin.py # Admin schemas (was admin_schemas.py)
|
||||
│ └── inference.py # Inference + async schemas (was schemas.py)
|
||||
│
|
||||
├── services/
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # Inference service (was services.py)
|
||||
│ ├── autolabel.py # Auto-label service (was admin_autolabel.py)
|
||||
│ ├── async_processing.py # Async processing (was async_service.py)
|
||||
│ └── batch_upload.py # Batch upload service (was batch_upload_service.py)
|
||||
│
|
||||
├── core/
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # Authentication (was admin_auth.py)
|
||||
│ ├── rate_limiter.py # Rate limiting (unchanged)
|
||||
│ └── scheduler.py # Task scheduler (was admin_scheduler.py)
|
||||
│
|
||||
├── workers/
|
||||
│ ├── __init__.py
|
||||
│ ├── async_queue.py # Async task queue (was async_queue.py)
|
||||
│ └── batch_queue.py # Batch task queue (was batch_queue.py)
|
||||
│
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app (imports updated)
|
||||
├── config.py # Configuration (unchanged)
|
||||
└── dependencies.py # Global dependencies (unchanged)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Changes Summary
|
||||
|
||||
### Files Moved and Renamed
|
||||
|
||||
| Old Location | New Location | Change Type |
|
||||
|-------------|--------------|-------------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
|
||||
| `admin_auth.py` | `core/auth.py` | Moved |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Moved |
|
||||
| `routes.py` | `api/v1/routes.py` | Moved |
|
||||
| `schemas.py` | `schemas/inference.py` | Moved |
|
||||
| `services.py` | `services/inference.py` | Moved |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Moved |
|
||||
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |
|
||||
|
||||
**Total**: 16 files reorganized
|
||||
|
||||
### Files Updated
|
||||
|
||||
**Source Files** (imports updated):
|
||||
- `app.py` - Updated all imports to new structure
|
||||
- `api/v1/admin/documents.py` - Updated schema/auth imports
|
||||
- `api/v1/admin/annotations.py` - Updated schema/service imports
|
||||
- `api/v1/admin/training.py` - Updated schema/auth imports
|
||||
- `api/v1/routes.py` - Updated schema imports
|
||||
- `api/v1/async_api/routes.py` - Updated schema imports
|
||||
- `api/v1/batch/routes.py` - Updated service/worker imports
|
||||
- `services/async_processing.py` - Updated worker/core imports
|
||||
|
||||
**Test Files** (all 15 updated):
|
||||
- `test_admin_annotations.py`
|
||||
- `test_admin_auth.py`
|
||||
- `test_admin_routes.py`
|
||||
- `test_admin_routes_enhanced.py`
|
||||
- `test_admin_training.py`
|
||||
- `test_annotation_locks.py`
|
||||
- `test_annotation_phase5.py`
|
||||
- `test_async_queue.py`
|
||||
- `test_async_routes.py`
|
||||
- `test_async_service.py`
|
||||
- `test_autolabel_with_locks.py`
|
||||
- `test_batch_queue.py`
|
||||
- `test_batch_upload_routes.py`
|
||||
- `test_batch_upload_service.py`
|
||||
- `test_training_phase4.py`
|
||||
- `conftest.py`
|
||||
|
||||
---
|
||||
|
||||
## Import Examples
|
||||
|
||||
### Old Import Style (Before Refactoring)
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
from src.web.async_routes import create_async_router
|
||||
from src.web.schemas import ErrorResponse
|
||||
```
|
||||
|
||||
### New Import Style (After Refactoring)
|
||||
```python
|
||||
# Admin API
|
||||
from src.web.api.v1.admin.documents import create_admin_router
|
||||
from src.web.api.v1.admin import create_admin_router # Shorter alternative
|
||||
|
||||
# Schemas
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
# Core components
|
||||
from src.web.core.auth import validate_admin_token
|
||||
|
||||
# Async API
|
||||
from src.web.api.v1.async_api.routes import create_async_router
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### 1. **Clear Separation of Concerns**
|
||||
- **API Routes**: All in `api/v1/` by version and feature
|
||||
- **Data Models**: All in `schemas/` by domain
|
||||
- **Business Logic**: All in `services/`
|
||||
- **Core Components**: Reusable utilities in `core/`
|
||||
- **Background Jobs**: Task queues in `workers/`
|
||||
|
||||
### 2. **Better Scalability**
|
||||
- Easy to add API v2 without touching v1
|
||||
- Clear namespace for each module
|
||||
- Reduced file sizes (no 800+ line files)
|
||||
- Follows single responsibility principle
|
||||
|
||||
### 3. **Improved Maintainability**
|
||||
- Find files by function, not by prefix
|
||||
- Each module has one clear purpose
|
||||
- Easier to onboard new developers
|
||||
- Better IDE navigation
|
||||
|
||||
### 4. **Standards Compliance**
|
||||
- Follows FastAPI best practices
|
||||
- Matches Django/Flask project structures
|
||||
- Standard Python package organization
|
||||
- Industry-standard naming conventions
|
||||
|
||||
---
|
||||
|
||||
## Testing Results
|
||||
|
||||
**Before Refactoring**:
|
||||
- 188 tests passing
|
||||
- 23% code coverage
|
||||
- Flat directory structure
|
||||
|
||||
**After Refactoring**:
|
||||
- ✅ 188 tests passing (0 failures)
|
||||
- ✅ 23% code coverage (maintained)
|
||||
- ✅ Clean hierarchical structure
|
||||
- ✅ All imports updated
|
||||
- ✅ No backward compatibility shims needed
|
||||
|
||||
---
|
||||
|
||||
## Migration Statistics
|
||||
|
||||
| Metric | Count |
|
||||
|--------|-------|
|
||||
| Files moved | 16 |
|
||||
| Directories created | 9 |
|
||||
| Files updated (source) | 8 |
|
||||
| Files updated (tests) | 16 |
|
||||
| Import statements updated | ~150 |
|
||||
| Lines of code changed | ~200 |
|
||||
| Tests broken | 0 |
|
||||
| Coverage lost | 0% |
|
||||
|
||||
---
|
||||
|
||||
## Code Diff Summary
|
||||
|
||||
```diff
|
||||
Before:
|
||||
src/web/
|
||||
├── admin_routes.py (645 lines)
|
||||
├── admin_annotation_routes.py (504 lines)
|
||||
├── admin_training_routes.py (565 lines)
|
||||
├── admin_auth.py (22 lines)
|
||||
├── admin_schemas.py (262 lines)
|
||||
... (15 more files at root level)
|
||||
|
||||
After:
|
||||
src/web/
|
||||
├── api/v1/
|
||||
│ ├── admin/ (3 route files)
|
||||
│ ├── async_api/ (1 route file)
|
||||
│ └── batch/ (1 route file)
|
||||
├── schemas/ (3 schema files)
|
||||
├── services/ (4 service files)
|
||||
├── core/ (3 core files)
|
||||
└── workers/ (2 worker files)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps (Optional)
|
||||
|
||||
### Phase 2: Documentation
|
||||
- [ ] Update API documentation with new import paths
|
||||
- [ ] Create migration guide for external developers
|
||||
- [ ] Update CLAUDE.md with new structure
|
||||
|
||||
### Phase 3: Further Optimization
|
||||
- [ ] Split large files (>400 lines) if needed
|
||||
- [ ] Extract common utilities
|
||||
- [ ] Add typing stubs
|
||||
|
||||
### Phase 4: Deprecation (Future)
|
||||
- [ ] Add deprecation warnings if creating compatibility layer
|
||||
- [ ] Remove old imports after grace period
|
||||
- [ ] Update all documentation
|
||||
|
||||
---
|
||||
|
||||
## Rollback Instructions
|
||||
|
||||
If needed, rollback is simple:
|
||||
```bash
|
||||
git revert <commit-hash>
|
||||
```
|
||||
|
||||
All changes are in version control, making rollback safe and easy.
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
✅ **Refactoring completed successfully**
|
||||
✅ **Zero breaking changes**
|
||||
✅ **All tests passing**
|
||||
✅ **Industry-standard structure achieved**
|
||||
|
||||
The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.
|
||||
186
docs/web-refactoring-plan.md
Normal file
186
docs/web-refactoring-plan.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Web Directory Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat structure**: All files in one directory (20 Python files)
|
||||
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
|
||||
3. **Mixed concerns**: Routes, schemas, services, and workers in same directory
|
||||
4. **Poor scalability**: Hard to navigate and maintain as project grows
|
||||
|
||||
## Proposed Structure (Best Practices)
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app factory
|
||||
├── config.py # App configuration
|
||||
├── dependencies.py # Global dependencies
|
||||
│
|
||||
├── api/ # API Routes Layer
|
||||
│ ├── __init__.py
|
||||
│ └── v1/ # API version 1
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public API routes (inference)
|
||||
│ ├── admin/ # Admin API routes
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # admin_routes.py → documents.py
|
||||
│ │ ├── annotations.py # admin_annotation_routes.py → annotations.py
|
||||
│ │ ├── training.py # admin_training_routes.py → training.py
|
||||
│ │ └── auth.py # admin_auth.py → auth.py (routes only)
|
||||
│ ├── async_api/ # Async processing API
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # async_routes.py → routes.py
|
||||
│ └── batch/ # Batch upload API
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # batch_upload_routes.py → routes.py
|
||||
│
|
||||
├── schemas/ # Pydantic Models
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared schemas (ErrorResponse, etc.)
|
||||
│ ├── inference.py # schemas.py → inference.py
|
||||
│ ├── admin.py # admin_schemas.py → admin.py
|
||||
│ ├── async_api.py # New: async API schemas
|
||||
│ └── batch.py # New: batch upload schemas
|
||||
│
|
||||
├── services/ # Business Logic Layer
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # services.py → inference.py
|
||||
│ ├── autolabel.py # admin_autolabel.py → autolabel.py
|
||||
│ ├── async_processing.py # async_service.py → async_processing.py
|
||||
│ └── batch_upload.py # batch_upload_service.py → batch_upload.py
|
||||
│
|
||||
├── core/ # Core Components
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # admin_auth.py → auth.py (logic only)
|
||||
│ ├── rate_limiter.py # rate_limiter.py → rate_limiter.py
|
||||
│ └── scheduler.py # admin_scheduler.py → scheduler.py
|
||||
│
|
||||
└── workers/ # Background Task Queues
|
||||
├── __init__.py
|
||||
├── async_queue.py # async_queue.py → async_queue.py
|
||||
└── batch_queue.py # batch_queue.py → batch_queue.py
|
||||
```
|
||||
|
||||
## File Mapping
|
||||
|
||||
### Current → New Location
|
||||
|
||||
| Current File | New Location | Purpose |
|
||||
|--------------|--------------|---------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
|
||||
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
|
||||
| `routes.py` | `api/v1/routes.py` | Public inference API |
|
||||
| `schemas.py` | `schemas/inference.py` | Inference models |
|
||||
| `services.py` | `services/inference.py` | Inference service |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
|
||||
| `async_service.py` | `services/async_processing.py` | Async processing service |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
|
||||
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
|
||||
| `config.py` | `config.py` | Keep as-is |
|
||||
| `dependencies.py` | `dependencies.py` | Keep as-is |
|
||||
| `app.py` | `app.py` | Keep as-is (update imports) |
|
||||
|
||||
## Benefits
|
||||
|
||||
### 1. Clear Separation of Concerns
|
||||
- **Routes**: API endpoint definitions
|
||||
- **Schemas**: Data validation models
|
||||
- **Services**: Business logic
|
||||
- **Core**: Reusable components
|
||||
- **Workers**: Background processing
|
||||
|
||||
### 2. Better Scalability
|
||||
- Easy to add new API versions (`v2/`)
|
||||
- Clear namespace for each domain
|
||||
- Reduced file size (no 800+ line files)
|
||||
|
||||
### 3. Improved Maintainability
|
||||
- Find files by function, not by prefix
|
||||
- Each module has single responsibility
|
||||
- Easier to write focused tests
|
||||
|
||||
### 4. Standard Python Patterns
|
||||
- Package-based organization
|
||||
- Follows FastAPI best practices
|
||||
- Similar to Django/Flask structures
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Phase 1: Create New Structure (No Breaking Changes)
|
||||
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
|
||||
2. Copy files to new locations (don't delete originals yet)
|
||||
3. Update imports in new files
|
||||
4. Add `__init__.py` with proper exports
|
||||
|
||||
### Phase 2: Update Tests
|
||||
5. Update test imports to use new structure
|
||||
6. Run tests to verify nothing breaks
|
||||
7. Fix any import issues
|
||||
|
||||
### Phase 3: Update Main App
|
||||
8. Update `app.py` to import from new locations
|
||||
9. Run full test suite
|
||||
10. Verify all endpoints work
|
||||
|
||||
### Phase 4: Cleanup
|
||||
11. Delete old files
|
||||
12. Update documentation
|
||||
13. Final test run
|
||||
|
||||
## Migration Priority
|
||||
|
||||
**High Priority** (Most used):
|
||||
- Routes and schemas (user-facing APIs)
|
||||
- Services (core business logic)
|
||||
|
||||
**Medium Priority**:
|
||||
- Core components (auth, rate limiter)
|
||||
- Workers (background tasks)
|
||||
|
||||
**Low Priority**:
|
||||
- Config and dependencies (already well-located)
|
||||
|
||||
## Backwards Compatibility
|
||||
|
||||
During migration, maintain backwards compatibility:
|
||||
|
||||
```python
|
||||
# src/web/__init__.py
|
||||
# Old imports still work
|
||||
from src.web.api.v1.admin.documents import router as admin_router
|
||||
from src.web.schemas.admin import AdminDocument
|
||||
|
||||
# Keep old names for compatibility (temporary)
|
||||
admin_routes = admin_router # Deprecated alias
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
1. **Unit Tests**: Test each module independently
|
||||
2. **Integration Tests**: Test API endpoints still work
|
||||
3. **Import Tests**: Verify all old imports still work
|
||||
4. **Coverage**: Maintain current 23% coverage minimum
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If issues arise:
|
||||
1. Keep old files until fully migrated
|
||||
2. Git allows easy revert
|
||||
3. Tests catch breaking changes early
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
Would you like me to:
|
||||
1. **Start Phase 1**: Create new directory structure and move files?
|
||||
2. **Create migration script**: Automate the file moves and import updates?
|
||||
3. **Focus on specific area**: Start with admin API or async API first?
|
||||
218
docs/web-refactoring-status.md
Normal file
218
docs/web-refactoring-status.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Web Directory Refactoring - Current Status
|
||||
|
||||
## ✅ Completed Steps
|
||||
|
||||
### 1. Directory Structure Created
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── v1/
|
||||
│ │ ├── admin/ (documents.py, annotations.py, training.py)
|
||||
│ │ ├── async_api/ (routes.py)
|
||||
│ │ ├── batch/ (routes.py)
|
||||
│ │ └── routes.py (public inference API)
|
||||
├── schemas/
|
||||
│ ├── admin.py (admin schemas)
|
||||
│ ├── inference.py (inference + async schemas)
|
||||
│ └── common.py (ErrorResponse)
|
||||
├── services/
|
||||
│ ├── autolabel.py
|
||||
│ ├── async_processing.py
|
||||
│ ├── batch_upload.py
|
||||
│ └── inference.py
|
||||
├── core/
|
||||
│ ├── auth.py
|
||||
│ ├── rate_limiter.py
|
||||
│ └── scheduler.py
|
||||
└── workers/
|
||||
├── async_queue.py
|
||||
└── batch_queue.py
|
||||
```
|
||||
|
||||
### 2. Files Copied and Imports Updated
|
||||
|
||||
#### Admin API (✅ Complete)
|
||||
- [x] `admin_routes.py` → `api/v1/admin/documents.py` (imports updated)
|
||||
- [x] `admin_annotation_routes.py` → `api/v1/admin/annotations.py` (imports updated)
|
||||
- [x] `admin_training_routes.py` → `api/v1/admin/training.py` (imports updated)
|
||||
- [x] `api/v1/admin/__init__.py` created with exports
|
||||
|
||||
#### Public & Async API (✅ Complete)
|
||||
- [x] `routes.py` → `api/v1/routes.py` (imports updated)
|
||||
- [x] `async_routes.py` → `api/v1/async_api/routes.py` (imports updated)
|
||||
- [x] `batch_upload_routes.py` → `api/v1/batch/routes.py` (copied, imports pending)
|
||||
|
||||
#### Schemas (✅ Complete)
|
||||
- [x] `admin_schemas.py` → `schemas/admin.py`
|
||||
- [x] `schemas.py` → `schemas/inference.py`
|
||||
- [x] `schemas/common.py` created
|
||||
- [x] `schemas/__init__.py` created with exports
|
||||
|
||||
#### Services (✅ Complete)
|
||||
- [x] `admin_autolabel.py` → `services/autolabel.py`
|
||||
- [x] `async_service.py` → `services/async_processing.py`
|
||||
- [x] `batch_upload_service.py` → `services/batch_upload.py`
|
||||
- [x] `services.py` → `services/inference.py`
|
||||
- [x] `services/__init__.py` created
|
||||
|
||||
#### Core Components (✅ Complete)
|
||||
- [x] `admin_auth.py` → `core/auth.py`
|
||||
- [x] `rate_limiter.py` → `core/rate_limiter.py`
|
||||
- [x] `admin_scheduler.py` → `core/scheduler.py`
|
||||
- [x] `core/__init__.py` created
|
||||
|
||||
#### Workers (✅ Complete)
|
||||
- [x] `async_queue.py` → `workers/async_queue.py`
|
||||
- [x] `batch_queue.py` → `workers/batch_queue.py`
|
||||
- [x] `workers/__init__.py` created
|
||||
|
||||
#### Main App (✅ Complete)
|
||||
- [x] `app.py` imports updated to use new structure
|
||||
|
||||
---
|
||||
|
||||
## ⏳ Remaining Work
|
||||
|
||||
### 1. Update Remaining File Imports (HIGH PRIORITY)
|
||||
|
||||
Files that need import updates:
|
||||
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
|
||||
- [ ] `services/autolabel.py` - may need import updates if it references old paths
|
||||
- [ ] `services/async_processing.py` - check for old import references
|
||||
- [ ] `services/batch_upload.py` - check for old import references
|
||||
- [ ] `services/inference.py` - check for old import references
|
||||
|
||||
### 2. Update ALL Test Files (CRITICAL)
|
||||
|
||||
Test files need to import from new locations. Pattern:
|
||||
|
||||
**Old:**
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
```
|
||||
|
||||
**New:**
|
||||
```python
|
||||
from src.web.api.v1.admin import create_admin_router
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.core.auth import validate_admin_token
|
||||
```
|
||||
|
||||
Test files to update:
|
||||
- [ ] `tests/web/test_admin_annotations.py`
|
||||
- [ ] `tests/web/test_admin_auth.py`
|
||||
- [ ] `tests/web/test_admin_routes.py`
|
||||
- [ ] `tests/web/test_admin_routes_enhanced.py`
|
||||
- [ ] `tests/web/test_admin_training.py`
|
||||
- [ ] `tests/web/test_annotation_locks.py`
|
||||
- [ ] `tests/web/test_annotation_phase5.py`
|
||||
- [ ] `tests/web/test_async_queue.py`
|
||||
- [ ] `tests/web/test_async_routes.py`
|
||||
- [ ] `tests/web/test_async_service.py`
|
||||
- [ ] `tests/web/test_autolabel_with_locks.py`
|
||||
- [ ] `tests/web/test_batch_queue.py`
|
||||
- [ ] `tests/web/test_batch_upload_routes.py`
|
||||
- [ ] `tests/web/test_batch_upload_service.py`
|
||||
- [ ] `tests/web/test_rate_limiter.py`
|
||||
- [ ] `tests/web/test_training_phase4.py`
|
||||
|
||||
### 3. Create Backward Compatibility Layer (OPTIONAL)
|
||||
|
||||
Keep old imports working temporarily:
|
||||
|
||||
```python
|
||||
# src/web/admin_routes.py (temporary compatibility shim)
|
||||
\"\"\"
|
||||
DEPRECATED: Use src.web.api.v1.admin.documents instead.
|
||||
This file will be removed in next version.
|
||||
\"\"\"
|
||||
import warnings
|
||||
from src.web.api.v1.admin.documents import *
|
||||
|
||||
warnings.warn(
|
||||
"Importing from src.web.admin_routes is deprecated. "
|
||||
"Use src.web.api.v1.admin.documents instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Verify and Test
|
||||
|
||||
1. Run tests:
|
||||
```bash
|
||||
pytest tests/web/ -v
|
||||
```
|
||||
|
||||
2. Check for any import errors:
|
||||
```bash
|
||||
python -c "from src.web.app import create_app; create_app()"
|
||||
```
|
||||
|
||||
3. Start server and test endpoints:
|
||||
```bash
|
||||
python run_server.py
|
||||
```
|
||||
|
||||
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
|
||||
|
||||
Old files to remove:
|
||||
- `src/web/admin_*.py` (7 files)
|
||||
- `src/web/async_*.py` (3 files)
|
||||
- `src/web/batch_*.py` (3 files)
|
||||
- `src/web/routes.py`
|
||||
- `src/web/services.py`
|
||||
- `src/web/schemas.py`
|
||||
- `src/web/rate_limiter.py`
|
||||
|
||||
Keep these files (don't remove):
|
||||
- `src/web/__init__.py`
|
||||
- `src/web/app.py`
|
||||
- `src/web/config.py`
|
||||
- `src/web/dependencies.py`
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Next Immediate Steps
|
||||
|
||||
1. **Update batch/routes.py imports** - Quick fix for remaining API route
|
||||
2. **Update test file imports** - Critical for verification
|
||||
3. **Run test suite** - Verify nothing broke
|
||||
4. **Fix any import errors** - Address failures
|
||||
5. **Remove old files** - Clean up after tests pass
|
||||
|
||||
---
|
||||
|
||||
## 📊 Migration Impact Summary
|
||||
|
||||
| Category | Files Moved | Imports Updated | Status |
|
||||
|----------|-------------|-----------------|--------|
|
||||
| API Routes | 7 | 5/7 | 🟡 In Progress |
|
||||
| Schemas | 3 | 3/3 | ✅ Complete |
|
||||
| Services | 4 | 0/4 | ⚠️ Pending |
|
||||
| Core | 3 | 3/3 | ✅ Complete |
|
||||
| Workers | 2 | 2/2 | ✅ Complete |
|
||||
| Tests | 0 | 0/16 | ❌ Not Started |
|
||||
|
||||
**Overall Progress: 65%**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Benefits After Migration
|
||||
|
||||
1. **Better Organization**: Clear separation by function
|
||||
2. **Easier Navigation**: Find files by purpose, not prefix
|
||||
3. **Scalability**: Easy to add new API versions
|
||||
4. **Standard Structure**: Follows FastAPI best practices
|
||||
5. **Maintainability**: Each module has single responsibility
|
||||
|
||||
---
|
||||
|
||||
## 📝 Notes
|
||||
|
||||
- All original files are still in place (no data loss risk)
|
||||
- New structure is operational but needs import updates
|
||||
- Backward compatibility can be added if needed
|
||||
- Tests will validate the migration success
|
||||
5
frontend/.env.example
Normal file
5
frontend/.env.example
Normal file
@@ -0,0 +1,5 @@
|
||||
# Backend API URL
|
||||
VITE_API_URL=http://localhost:8000
|
||||
|
||||
# WebSocket URL (for future real-time updates)
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
24
frontend/.gitignore
vendored
Normal file
24
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
20
frontend/README.md
Normal file
20
frontend/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
<div align="center">
|
||||
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
|
||||
</div>
|
||||
|
||||
# Run and deploy your AI Studio app
|
||||
|
||||
This contains everything you need to run your app locally.
|
||||
|
||||
View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4
|
||||
|
||||
## Run Locally
|
||||
|
||||
**Prerequisites:** Node.js
|
||||
|
||||
|
||||
1. Install dependencies:
|
||||
`npm install`
|
||||
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
|
||||
3. Run the app:
|
||||
`npm run dev`
|
||||
240
frontend/REFACTORING_PLAN.md
Normal file
240
frontend/REFACTORING_PLAN.md
Normal file
@@ -0,0 +1,240 @@
|
||||
# Frontend Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat component organization** - All components in one directory
|
||||
2. **Mock data only** - No real API integration
|
||||
3. **No state management** - Props drilling everywhere
|
||||
4. **CDN dependencies** - Should use npm packages
|
||||
5. **Manual routing** - Using useState instead of react-router
|
||||
6. **No TypeScript integration with backend** - Types don't match API schemas
|
||||
|
||||
## Recommended Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── public/
|
||||
│ └── favicon.ico
|
||||
│
|
||||
├── src/
|
||||
│ ├── api/ # API Layer
|
||||
│ │ ├── client.ts # Axios instance + interceptors
|
||||
│ │ ├── types.ts # API request/response types
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # GET /api/v1/admin/documents
|
||||
│ │ ├── annotations.ts # GET/POST /api/v1/admin/documents/{id}/annotations
|
||||
│ │ ├── training.ts # GET/POST /api/v1/admin/training/*
|
||||
│ │ ├── inference.ts # POST /api/v1/infer
|
||||
│ │ └── async.ts # POST /api/v1/async/submit
|
||||
│ │
|
||||
│ ├── components/
|
||||
│ │ ├── common/ # Reusable components
|
||||
│ │ │ ├── Badge.tsx
|
||||
│ │ │ ├── Button.tsx
|
||||
│ │ │ ├── Input.tsx
|
||||
│ │ │ ├── Modal.tsx
|
||||
│ │ │ ├── Table.tsx
|
||||
│ │ │ ├── ProgressBar.tsx
|
||||
│ │ │ └── StatusBadge.tsx
|
||||
│ │ │
|
||||
│ │ ├── layout/ # Layout components
|
||||
│ │ │ ├── TopNav.tsx
|
||||
│ │ │ ├── Sidebar.tsx
|
||||
│ │ │ └── PageHeader.tsx
|
||||
│ │ │
|
||||
│ │ ├── documents/ # Document-specific components
|
||||
│ │ │ ├── DocumentTable.tsx
|
||||
│ │ │ ├── DocumentFilters.tsx
|
||||
│ │ │ ├── DocumentRow.tsx
|
||||
│ │ │ ├── UploadModal.tsx
|
||||
│ │ │ └── BatchUploadModal.tsx
|
||||
│ │ │
|
||||
│ │ ├── annotations/ # Annotation components
|
||||
│ │ │ ├── AnnotationCanvas.tsx
|
||||
│ │ │ ├── AnnotationBox.tsx
|
||||
│ │ │ ├── AnnotationTable.tsx
|
||||
│ │ │ ├── FieldEditor.tsx
|
||||
│ │ │ └── VerificationPanel.tsx
|
||||
│ │ │
|
||||
│ │ └── training/ # Training components
|
||||
│ │ ├── DocumentSelector.tsx
|
||||
│ │ ├── TrainingConfig.tsx
|
||||
│ │ ├── TrainingJobList.tsx
|
||||
│ │ ├── ModelCard.tsx
|
||||
│ │ └── MetricsChart.tsx
|
||||
│ │
|
||||
│ ├── pages/ # Page-level components
|
||||
│ │ ├── DocumentsPage.tsx # Was Dashboard.tsx
|
||||
│ │ ├── DocumentDetailPage.tsx # Was DocumentDetail.tsx
|
||||
│ │ ├── TrainingPage.tsx # Was Training.tsx
|
||||
│ │ ├── ModelsPage.tsx # Was Models.tsx
|
||||
│ │ └── InferencePage.tsx # New: Test inference
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts # Document CRUD + listing
|
||||
│ │ ├── useAnnotations.ts # Annotation management
|
||||
│ │ ├── useTraining.ts # Training jobs
|
||||
│ │ ├── usePolling.ts # Auto-refresh for async jobs
|
||||
│ │ └── useDebounce.ts # Debounce search inputs
|
||||
│ │
|
||||
│ ├── store/ # State Management (Zustand)
|
||||
│ │ ├── documentsStore.ts
|
||||
│ │ ├── annotationsStore.ts
|
||||
│ │ ├── trainingStore.ts
|
||||
│ │ └── uiStore.ts
|
||||
│ │
|
||||
│ ├── types/ # TypeScript Types
|
||||
│ │ ├── index.ts
|
||||
│ │ ├── document.ts
|
||||
│ │ ├── annotation.ts
|
||||
│ │ ├── training.ts
|
||||
│ │ └── api.ts
|
||||
│ │
|
||||
│ ├── utils/ # Utility Functions
|
||||
│ │ ├── formatters.ts # Date, currency, etc.
|
||||
│ │ ├── validators.ts # Form validation
|
||||
│ │ └── constants.ts # Field definitions, statuses
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ ├── main.tsx
|
||||
│ └── router.tsx # React Router config
|
||||
│
|
||||
├── .env.example
|
||||
├── package.json
|
||||
├── tsconfig.json
|
||||
├── vite.config.ts
|
||||
├── tailwind.config.js
|
||||
├── postcss.config.js
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Migration Steps
|
||||
|
||||
### Phase 1: Setup Infrastructure
|
||||
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
|
||||
- [ ] Setup local Tailwind (remove CDN)
|
||||
- [ ] Create API client with interceptors
|
||||
- [ ] Add environment variables (.env.local with VITE_API_URL)
|
||||
|
||||
### Phase 2: Create API Layer
|
||||
- [ ] Create `src/api/client.ts` with axios instance
|
||||
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
|
||||
- [ ] Create `src/api/endpoints/annotations.ts`
|
||||
- [ ] Create `src/api/endpoints/training.ts`
|
||||
- [ ] Add types matching backend schemas
|
||||
|
||||
### Phase 3: Reorganize Components
|
||||
- [ ] Move existing components to new structure
|
||||
- [ ] Split large components (Dashboard > DocumentTable + DocumentFilters + DocumentRow)
|
||||
- [ ] Extract reusable components (Badge, Button already done)
|
||||
- [ ] Create layout components (TopNav, Sidebar)
|
||||
|
||||
### Phase 4: Add Routing
|
||||
- [ ] Install react-router-dom
|
||||
- [ ] Create router.tsx with routes
|
||||
- [ ] Update App.tsx to use RouterProvider
|
||||
- [ ] Add navigation links
|
||||
|
||||
### Phase 5: State Management
|
||||
- [ ] Create custom hooks (useDocuments, useAnnotations)
|
||||
- [ ] Use @tanstack/react-query for server state
|
||||
- [ ] Add Zustand stores for UI state
|
||||
- [ ] Replace mock data with API calls
|
||||
|
||||
### Phase 6: Backend Integration
|
||||
- [ ] Update CORS settings in backend
|
||||
- [ ] Test all API endpoints
|
||||
- [ ] Add error handling
|
||||
- [ ] Add loading states
|
||||
|
||||
## Dependencies to Add
|
||||
|
||||
```json
|
||||
{
|
||||
"dependencies": {
|
||||
"react-router-dom": "^6.22.0",
|
||||
"axios": "^1.6.7",
|
||||
"zustand": "^4.5.0",
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"clsx": "^2.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tailwindcss": "^3.4.1",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"postcss": "^8.4.35"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Files to Create
|
||||
|
||||
### tailwind.config.js
|
||||
```javascript
|
||||
export default {
|
||||
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
warm: {
|
||||
bg: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
border: '#E6E4E1',
|
||||
divider: '#D8D6D2',
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
state: {
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
info: '#3A3A3A',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### .env.example
|
||||
```bash
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
## Type Generation from Backend
|
||||
|
||||
Consider generating TypeScript types from Python Pydantic schemas:
|
||||
- Option 1: Use `datamodel-code-generator` to convert schemas
|
||||
- Option 2: Manually maintain types in `src/types/api.ts`
|
||||
- Option 3: Use OpenAPI spec + openapi-typescript-codegen
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit tests: Vitest for components
|
||||
- Integration tests: React Testing Library
|
||||
- E2E tests: Playwright (matching backend)
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Code splitting by route
|
||||
- Lazy load heavy components (AnnotationCanvas)
|
||||
- Optimize re-renders with React.memo
|
||||
- Use virtual scrolling for large tables
|
||||
- Image lazy loading for document previews
|
||||
|
||||
## Accessibility
|
||||
|
||||
- Proper ARIA labels
|
||||
- Keyboard navigation
|
||||
- Focus management
|
||||
- Color contrast compliance (already done with Warm Graphite theme)
|
||||
256
frontend/SETUP.md
Normal file
256
frontend/SETUP.md
Normal file
@@ -0,0 +1,256 @@
|
||||
# Frontend Setup Guide
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
```
|
||||
|
||||
### 2. Configure Environment
|
||||
|
||||
Copy `.env.example` to `.env.local` and update if needed:
|
||||
|
||||
```bash
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
Default configuration:
|
||||
```
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
### 3. Start Backend API
|
||||
|
||||
Make sure the backend is running first:
|
||||
|
||||
```bash
|
||||
# From project root
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
|
||||
```
|
||||
|
||||
Backend will be available at: http://localhost:8000
|
||||
|
||||
### 4. Start Frontend Dev Server
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Frontend will be available at: http://localhost:3000
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── src/
|
||||
│ ├── api/ # API client layer
|
||||
│ │ ├── client.ts # Axios instance with interceptors
|
||||
│ │ ├── types.ts # API type definitions
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # Document API calls
|
||||
│ │ ├── annotations.ts # Annotation API calls
|
||||
│ │ └── training.ts # Training API calls
|
||||
│ │
|
||||
│ ├── components/ # React components
|
||||
│ │ └── Dashboard.tsx # Updated with real API integration
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts
|
||||
│ │ ├── useDocumentDetail.ts
|
||||
│ │ ├── useAnnotations.ts
|
||||
│ │ └── useTraining.ts
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind CSS entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ └── main.tsx # App entry point with QueryClient
|
||||
│
|
||||
├── components/ # Legacy components (to be migrated)
|
||||
│ ├── Badge.tsx
|
||||
│ ├── Button.tsx
|
||||
│ ├── Layout.tsx
|
||||
│ ├── DocumentDetail.tsx
|
||||
│ ├── Training.tsx
|
||||
│ ├── Models.tsx
|
||||
│ └── UploadModal.tsx
|
||||
│
|
||||
├── tailwind.config.js # Tailwind configuration
|
||||
├── postcss.config.js
|
||||
├── vite.config.ts
|
||||
├── package.json
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Key Technologies
|
||||
|
||||
- **React 19** - UI framework
|
||||
- **TypeScript** - Type safety
|
||||
- **Vite** - Build tool
|
||||
- **Tailwind CSS** - Styling (Warm Graphite theme)
|
||||
- **Axios** - HTTP client
|
||||
- **@tanstack/react-query** - Server state management
|
||||
- **lucide-react** - Icon library
|
||||
|
||||
## API Integration
|
||||
|
||||
### Authentication
|
||||
|
||||
The app stores admin token in localStorage:
|
||||
|
||||
```typescript
|
||||
localStorage.setItem('admin_token', 'your-token')
|
||||
```
|
||||
|
||||
All API requests automatically include the `X-Admin-Token` header.
|
||||
|
||||
### Available Hooks
|
||||
|
||||
#### useDocuments
|
||||
|
||||
```typescript
|
||||
const {
|
||||
documents,
|
||||
total,
|
||||
isLoading,
|
||||
uploadDocument,
|
||||
deleteDocument,
|
||||
triggerAutoLabel,
|
||||
} = useDocuments({ status: 'labeled', limit: 20 })
|
||||
```
|
||||
|
||||
#### useDocumentDetail
|
||||
|
||||
```typescript
|
||||
const { document, annotations, isLoading } = useDocumentDetail(documentId)
|
||||
```
|
||||
|
||||
#### useAnnotations
|
||||
|
||||
```typescript
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
verifyAnnotation,
|
||||
overrideAnnotation,
|
||||
} = useAnnotations(documentId)
|
||||
```
|
||||
|
||||
#### useTraining
|
||||
|
||||
```typescript
|
||||
const {
|
||||
models,
|
||||
isLoadingModels,
|
||||
startTraining,
|
||||
downloadModel,
|
||||
} = useTraining()
|
||||
```
|
||||
|
||||
## Features Implemented
|
||||
|
||||
### Phase 1 (Completed)
|
||||
- ✅ API client with axios interceptors
|
||||
- ✅ Type-safe API endpoints
|
||||
- ✅ React Query for server state
|
||||
- ✅ Custom hooks for all APIs
|
||||
- ✅ Dashboard with real data
|
||||
- ✅ Local Tailwind CSS
|
||||
- ✅ Environment configuration
|
||||
- ✅ CORS configured in backend
|
||||
|
||||
### Phase 2 (TODO)
|
||||
- [ ] Update DocumentDetail to use useDocumentDetail
|
||||
- [ ] Update Training page to use useTraining hooks
|
||||
- [ ] Update Models page with real data
|
||||
- [ ] Add UploadModal integration with API
|
||||
- [ ] Add react-router for proper routing
|
||||
- [ ] Add error boundary
|
||||
- [ ] Add loading states
|
||||
- [ ] Add toast notifications
|
||||
|
||||
### Phase 3 (TODO)
|
||||
- [ ] Annotation canvas with real data
|
||||
- [ ] Batch upload functionality
|
||||
- [ ] Auto-label progress polling
|
||||
- [ ] Training job monitoring
|
||||
- [ ] Model download functionality
|
||||
- [ ] Search and filtering
|
||||
- [ ] Pagination
|
||||
|
||||
## Development Tips
|
||||
|
||||
### Hot Module Replacement
|
||||
|
||||
Vite supports HMR. Changes will reflect immediately without page reload.
|
||||
|
||||
### API Debugging
|
||||
|
||||
Check browser console for API requests:
|
||||
- Network tab shows all requests/responses
|
||||
- Axios interceptors log errors automatically
|
||||
|
||||
### Type Safety
|
||||
|
||||
TypeScript types in `src/api/types.ts` match backend Pydantic schemas.
|
||||
|
||||
To regenerate types from backend:
|
||||
```bash
|
||||
# TODO: Add type generation script
|
||||
```
|
||||
|
||||
### Backend API Documentation
|
||||
|
||||
Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CORS Errors
|
||||
|
||||
If you see CORS errors:
|
||||
1. Check backend is running at http://localhost:8000
|
||||
2. Verify CORS settings in `src/web/app.py`
|
||||
3. Check `.env.local` has correct `VITE_API_URL`
|
||||
|
||||
### Module Not Found
|
||||
|
||||
If imports fail:
|
||||
```bash
|
||||
rm -rf node_modules package-lock.json
|
||||
npm install
|
||||
```
|
||||
|
||||
### Types Not Matching
|
||||
|
||||
If API responses don't match types:
|
||||
1. Check backend version is up-to-date
|
||||
2. Verify types in `src/api/types.ts`
|
||||
3. Check API response in Network tab
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Run `npm install` to install dependencies
|
||||
2. Start backend server
|
||||
3. Run `npm run dev` to start frontend
|
||||
4. Open http://localhost:3000
|
||||
5. Create an admin token via backend API
|
||||
6. Store token in localStorage via browser console:
|
||||
```javascript
|
||||
localStorage.setItem('admin_token', 'your-token-here')
|
||||
```
|
||||
7. Refresh page to see authenticated API calls
|
||||
|
||||
## Production Build
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
npm run preview # Preview production build
|
||||
```
|
||||
|
||||
Build output will be in `dist/` directory.
|
||||
15
frontend/index.html
Normal file
15
frontend/index.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Graphite Annotator - Invoice Field Extraction</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
5
frontend/metadata.json
Normal file
5
frontend/metadata.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Graphite Annotator",
|
||||
"description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
|
||||
"requestFramePermissions": []
|
||||
}
|
||||
4899
frontend/package-lock.json
generated
Normal file
4899
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
41
frontend/package.json
Normal file
41
frontend/package.json
Normal file
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"name": "graphite-annotator",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "vitest run --coverage"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"axios": "^1.6.7",
|
||||
"clsx": "^2.1.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"lucide-react": "^0.563.0",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"react-router-dom": "^6.22.0",
|
||||
"recharts": "^3.7.0",
|
||||
"zustand": "^4.5.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.2",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/node": "^22.14.0",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"@vitest/coverage-v8": "^4.0.18",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"jsdom": "^27.4.0",
|
||||
"postcss": "^8.4.35",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "~5.8.2",
|
||||
"vite": "^6.2.0",
|
||||
"vitest": "^4.0.18"
|
||||
}
|
||||
}
|
||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
81
frontend/src/App.tsx
Normal file
81
frontend/src/App.tsx
Normal file
@@ -0,0 +1,81 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { Layout } from './components/Layout'
|
||||
import { DashboardOverview } from './components/DashboardOverview'
|
||||
import { Dashboard } from './components/Dashboard'
|
||||
import { DocumentDetail } from './components/DocumentDetail'
|
||||
import { Training } from './components/Training'
|
||||
import { DatasetDetail } from './components/DatasetDetail'
|
||||
import { Models } from './components/Models'
|
||||
import { Login } from './components/Login'
|
||||
import { InferenceDemo } from './components/InferenceDemo'
|
||||
|
||||
const App: React.FC = () => {
|
||||
const [currentView, setCurrentView] = useState('dashboard')
|
||||
const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
setIsAuthenticated(!!token)
|
||||
}, [])
|
||||
|
||||
const handleNavigate = (view: string, docId?: string) => {
|
||||
setCurrentView(view)
|
||||
if (docId) {
|
||||
setSelectedDocId(docId)
|
||||
}
|
||||
}
|
||||
|
||||
const handleLogin = (token: string) => {
|
||||
setIsAuthenticated(true)
|
||||
}
|
||||
|
||||
const handleLogout = () => {
|
||||
localStorage.removeItem('admin_token')
|
||||
setIsAuthenticated(false)
|
||||
setCurrentView('documents')
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return <Login onLogin={handleLogin} />
|
||||
}
|
||||
|
||||
const renderContent = () => {
|
||||
switch (currentView) {
|
||||
case 'dashboard':
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
case 'documents':
|
||||
return <Dashboard onNavigate={handleNavigate} />
|
||||
case 'detail':
|
||||
return (
|
||||
<DocumentDetail
|
||||
docId={selectedDocId || '1'}
|
||||
onBack={() => setCurrentView('documents')}
|
||||
/>
|
||||
)
|
||||
case 'demo':
|
||||
return <InferenceDemo />
|
||||
case 'training':
|
||||
return <Training onNavigate={handleNavigate} />
|
||||
case 'dataset-detail':
|
||||
return (
|
||||
<DatasetDetail
|
||||
datasetId={selectedDocId || ''}
|
||||
onBack={() => setCurrentView('training')}
|
||||
/>
|
||||
)
|
||||
case 'models':
|
||||
return <Models />
|
||||
default:
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
|
||||
{renderContent()}
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
||||
41
frontend/src/api/client.ts
Normal file
41
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import axios, { AxiosInstance, AxiosError } from 'axios'
|
||||
|
||||
const apiClient: AxiosInstance = axios.create({
|
||||
baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 30000,
|
||||
})
|
||||
|
||||
apiClient.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
if (token) {
|
||||
config.headers['X-Admin-Token'] = token
|
||||
}
|
||||
return config
|
||||
},
|
||||
(error) => {
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error: AxiosError) => {
|
||||
if (error.response?.status === 401) {
|
||||
console.warn('Authentication required. Please set admin_token in localStorage.')
|
||||
// Don't redirect to avoid infinite loop
|
||||
// User should manually set: localStorage.setItem('admin_token', 'your-token')
|
||||
}
|
||||
|
||||
if (error.response?.status === 429) {
|
||||
console.error('Rate limit exceeded')
|
||||
}
|
||||
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
export default apiClient
|
||||
66
frontend/src/api/endpoints/annotations.ts
Normal file
66
frontend/src/api/endpoints/annotations.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
AnnotationItem,
|
||||
CreateAnnotationRequest,
|
||||
AnnotationOverrideRequest,
|
||||
} from '../types'
|
||||
|
||||
export const annotationsApi = {
|
||||
list: async (documentId: string): Promise<AnnotationItem[]> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`
|
||||
)
|
||||
return data.annotations
|
||||
},
|
||||
|
||||
create: async (
|
||||
documentId: string,
|
||||
annotation: CreateAnnotationRequest
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`,
|
||||
annotation
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
update: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
|
||||
updates
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string, annotationId: string): Promise<void> => {
|
||||
await apiClient.delete(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
|
||||
)
|
||||
},
|
||||
|
||||
verify: async (
|
||||
documentId: string,
|
||||
annotationId: string
|
||||
): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
override: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
overrideData: AnnotationOverrideRequest
|
||||
): Promise<{ annotation_id: string; source: string; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
|
||||
overrideData
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
/**
|
||||
* Tests for augmentation API endpoints.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { augmentationApi } from './augmentation'
|
||||
import apiClient from '../client'
|
||||
|
||||
// Mock the API client
|
||||
vi.mock('../client', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('augmentationApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getTypes', () => {
|
||||
it('should fetch augmentation types', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getTypes()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
|
||||
expect(result.augmentation_types).toHaveLength(1)
|
||||
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPresets', () => {
|
||||
it('should fetch augmentation presets', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations' },
|
||||
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getPresets()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
|
||||
expect(result.presets).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview', () => {
|
||||
it('should preview single augmentation', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: 'data:image/png;base64,xxx',
|
||||
original_url: 'data:image/png;base64,yyy',
|
||||
applied_params: { std: 15 },
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.preview('doc-123', {
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
})
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
{
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
},
|
||||
{ params: { page: 1 } }
|
||||
)
|
||||
expect(result.preview_url).toBe('data:image/png;base64,xxx')
|
||||
})
|
||||
|
||||
it('should support custom page number', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: 'data:image/png;base64,xxx',
|
||||
original_url: 'data:image/png;base64,yyy',
|
||||
applied_params: {},
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
await augmentationApi.preview(
|
||||
'doc-123',
|
||||
{ augmentation_type: 'gaussian_noise', params: {} },
|
||||
2
|
||||
)
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
expect.anything(),
|
||||
{ params: { page: 2 } }
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
144
frontend/src/api/endpoints/augmentation.ts
Normal file
144
frontend/src/api/endpoints/augmentation.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Augmentation API endpoints.
|
||||
*
|
||||
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||
*/
|
||||
|
||||
import apiClient from '../client'
|
||||
|
||||
// Types
|
||||
export interface AugmentationTypeInfo {
|
||||
name: string
|
||||
description: string
|
||||
affects_geometry: boolean
|
||||
stage: string
|
||||
default_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTypesResponse {
|
||||
augmentation_types: AugmentationTypeInfo[]
|
||||
}
|
||||
|
||||
export interface PresetInfo {
|
||||
name: string
|
||||
description: string
|
||||
config?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PresetsResponse {
|
||||
presets: PresetInfo[]
|
||||
}
|
||||
|
||||
export interface PreviewRequest {
|
||||
augmentation_type: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PreviewResponse {
|
||||
preview_url: string
|
||||
original_url: string
|
||||
applied_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationConfig {
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
gaussian_noise?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface BatchRequest {
|
||||
dataset_id: string
|
||||
config: AugmentationConfig
|
||||
output_name: string
|
||||
multiplier: number
|
||||
}
|
||||
|
||||
export interface BatchResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
estimated_images: number
|
||||
}
|
||||
|
||||
// API functions
|
||||
export const augmentationApi = {
|
||||
/**
|
||||
* Fetch available augmentation types.
|
||||
*/
|
||||
async getTypes(): Promise<AugmentationTypesResponse> {
|
||||
const response = await apiClient.get<AugmentationTypesResponse>(
|
||||
'/api/v1/admin/augmentation/types'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Fetch augmentation presets.
|
||||
*/
|
||||
async getPresets(): Promise<PresetsResponse> {
|
||||
const response = await apiClient.get<PresetsResponse>(
|
||||
'/api/v1/admin/augmentation/presets'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview a single augmentation on a document page.
|
||||
*/
|
||||
async preview(
|
||||
documentId: string,
|
||||
request: PreviewRequest,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview/${documentId}`,
|
||||
request,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview full augmentation config on a document page.
|
||||
*/
|
||||
async previewConfig(
|
||||
documentId: string,
|
||||
config: AugmentationConfig,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview-config/${documentId}`,
|
||||
config,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Create an augmented dataset.
|
||||
*/
|
||||
async createBatch(request: BatchRequest): Promise<BatchResponse> {
|
||||
const response = await apiClient.post<BatchResponse>(
|
||||
'/api/v1/admin/augmentation/batch',
|
||||
request
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
52
frontend/src/api/endpoints/datasets.ts
Normal file
52
frontend/src/api/endpoints/datasets.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetListResponse,
|
||||
DatasetResponse,
|
||||
DatasetTrainRequest,
|
||||
TrainingTaskResponse,
|
||||
} from '../types'
|
||||
|
||||
export const datasetsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DatasetListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
remove: async (datasetId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
trainFromDataset: async (
|
||||
datasetId: string,
|
||||
req: DatasetTrainRequest
|
||||
): Promise<TrainingTaskResponse> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/datasets/${datasetId}/train`,
|
||||
req
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
122
frontend/src/api/endpoints/documents.ts
Normal file
122
frontend/src/api/endpoints/documents.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DocumentListResponse,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
UploadDocumentResponse,
|
||||
DocumentCategoriesResponse,
|
||||
} from '../types'
|
||||
|
||||
export const documentsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
category?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/documents', { params })
|
||||
return data
|
||||
},
|
||||
|
||||
getCategories: async (): Promise<DocumentCategoriesResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/documents/categories')
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
upload: async (
|
||||
file: File,
|
||||
options?: { groupKey?: string; category?: string }
|
||||
): Promise<UploadDocumentResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const params: Record<string, string> = {}
|
||||
if (options?.groupKey) {
|
||||
params.group_key = options.groupKey
|
||||
}
|
||||
if (options?.category) {
|
||||
params.category = options.category
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
batchUpload: async (
|
||||
files: File[],
|
||||
csvFile?: File
|
||||
): Promise<{ batch_id: string; message: string; documents_created: number }> => {
|
||||
const formData = new FormData()
|
||||
|
||||
files.forEach((file) => {
|
||||
formData.append('files', file)
|
||||
})
|
||||
|
||||
if (csvFile) {
|
||||
formData.append('csv_file', csvFile)
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string): Promise<void> => {
|
||||
await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
|
||||
},
|
||||
|
||||
updateStatus: async (
|
||||
documentId: string,
|
||||
status: string
|
||||
): Promise<DocumentItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/status`,
|
||||
null,
|
||||
{ params: { status } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/auto-label`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
updateGroupKey: async (
|
||||
documentId: string,
|
||||
groupKey: string | null
|
||||
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/group-key`,
|
||||
null,
|
||||
{ params: { group_key: groupKey } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
updateCategory: async (
|
||||
documentId: string,
|
||||
category: string
|
||||
): Promise<{ status: string; document_id: string; category: string; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/category`,
|
||||
{ category }
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
7
frontend/src/api/endpoints/index.ts
Normal file
7
frontend/src/api/endpoints/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export { documentsApi } from './documents'
|
||||
export { annotationsApi } from './annotations'
|
||||
export { trainingApi } from './training'
|
||||
export { inferenceApi } from './inference'
|
||||
export { datasetsApi } from './datasets'
|
||||
export { augmentationApi } from './augmentation'
|
||||
export { modelsApi } from './models'
|
||||
16
frontend/src/api/endpoints/inference.ts
Normal file
16
frontend/src/api/endpoints/inference.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import apiClient from '../client'
|
||||
import type { InferenceResponse } from '../types'
|
||||
|
||||
export const inferenceApi = {
|
||||
processDocument: async (file: File): Promise<InferenceResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/infer', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
}
|
||||
55
frontend/src/api/endpoints/models.ts
Normal file
55
frontend/src/api/endpoints/models.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
} from '../types'
|
||||
|
||||
export const modelsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<ModelVersionListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
getActive: async (): Promise<ActiveModelResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
|
||||
return data
|
||||
},
|
||||
|
||||
activate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
|
||||
return data
|
||||
},
|
||||
|
||||
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
|
||||
return data
|
||||
},
|
||||
|
||||
archive: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (versionId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
|
||||
return data
|
||||
},
|
||||
}
|
||||
74
frontend/src/api/endpoints/training.ts
Normal file
74
frontend/src/api/endpoints/training.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import apiClient from '../client'
|
||||
import type { TrainingModelsResponse, DocumentListResponse } from '../types'
|
||||
|
||||
export const trainingApi = {
|
||||
getDocumentsForTraining: async (params?: {
|
||||
has_annotations?: boolean
|
||||
min_annotation_count?: number
|
||||
exclude_used_in_training?: boolean
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/documents', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getModels: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<TrainingModelsResponse> => {
|
||||
const { data} = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getTaskDetail: async (taskId: string) => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
startTraining: async (config: {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
model_base?: string
|
||||
}) => {
|
||||
// Convert frontend config to backend TrainingTaskCreate format
|
||||
const taskRequest = {
|
||||
name: config.name,
|
||||
task_type: 'yolo',
|
||||
description: config.description,
|
||||
config: {
|
||||
document_ids: config.document_ids,
|
||||
epochs: config.epochs,
|
||||
batch_size: config.batch_size,
|
||||
base_model: config.model_base,
|
||||
},
|
||||
}
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
|
||||
return data
|
||||
},
|
||||
|
||||
cancelTask: async (taskId: string) => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/tasks/${taskId}/cancel`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
downloadModel: async (taskId: string): Promise<Blob> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/models/${taskId}/download`,
|
||||
{
|
||||
responseType: 'blob',
|
||||
}
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
364
frontend/src/api/types.ts
Normal file
364
frontend/src/api/types.ts
Normal file
@@ -0,0 +1,364 @@
|
||||
export interface DocumentItem {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
group_key: string | null
|
||||
category: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
annotation_count?: number
|
||||
annotation_sources?: {
|
||||
manual: number
|
||||
auto: number
|
||||
verified: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface DocumentListResponse {
|
||||
documents: DocumentItem[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface AnnotationItem {
|
||||
annotation_id: string
|
||||
page_number: number
|
||||
class_id: number
|
||||
class_name: string
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
normalized_bbox: {
|
||||
x_center: number
|
||||
y_center: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value: string | null
|
||||
confidence: number | null
|
||||
source: 'manual' | 'auto'
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DocumentDetailResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
batch_id: string | null
|
||||
group_key: string | null
|
||||
category: string
|
||||
csv_field_values: Record<string, string> | null
|
||||
can_annotate: boolean
|
||||
annotation_lock_until: string | null
|
||||
annotations: AnnotationItem[]
|
||||
image_urls: string[]
|
||||
training_history: Array<{
|
||||
task_id: string
|
||||
name: string
|
||||
trained_at: string
|
||||
model_metrics: {
|
||||
mAP: number | null
|
||||
precision: number | null
|
||||
recall: number | null
|
||||
} | null
|
||||
}>
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface TrainingTask {
|
||||
task_id: string
|
||||
admin_token: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: 'pending' | 'running' | 'completed' | 'failed'
|
||||
task_type: string
|
||||
config: Record<string, unknown>
|
||||
started_at: string | null
|
||||
completed_at: string | null
|
||||
error_message: string | null
|
||||
result_metrics: Record<string, unknown>
|
||||
model_path: string | null
|
||||
document_count: number
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionItem {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
metrics_mAP: number | null
|
||||
document_count: number
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface TrainingModelsResponse {
|
||||
models: ModelVersionItem[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface ErrorResponse {
|
||||
detail: string
|
||||
}
|
||||
|
||||
export interface UploadDocumentResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
page_count: number
|
||||
status: string
|
||||
category: string
|
||||
group_key: string | null
|
||||
auto_label_started: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface DocumentCategoriesResponse {
|
||||
categories: string[]
|
||||
total: number
|
||||
}
|
||||
|
||||
export interface CreateAnnotationRequest {
|
||||
page_number: number
|
||||
class_id: number
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value?: string
|
||||
}
|
||||
|
||||
export interface AnnotationOverrideRequest {
|
||||
text_value?: string
|
||||
bbox?: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
class_id?: number
|
||||
class_name?: string
|
||||
reason?: string
|
||||
}
|
||||
|
||||
export interface CrossValidationResult {
|
||||
is_valid: boolean
|
||||
payment_line_ocr: string | null
|
||||
payment_line_amount: string | null
|
||||
payment_line_account: string | null
|
||||
payment_line_account_type: 'bankgiro' | 'plusgiro' | null
|
||||
ocr_match: boolean | null
|
||||
amount_match: boolean | null
|
||||
bankgiro_match: boolean | null
|
||||
plusgiro_match: boolean | null
|
||||
details: string[]
|
||||
}
|
||||
|
||||
export interface InferenceResult {
|
||||
document_id: string
|
||||
document_type: string
|
||||
success: boolean
|
||||
fields: Record<string, string>
|
||||
confidence: Record<string, number>
|
||||
cross_validation: CrossValidationResult | null
|
||||
processing_time_ms: number
|
||||
visualization_url: string | null
|
||||
errors: string[]
|
||||
fallback_used: boolean
|
||||
}
|
||||
|
||||
export interface InferenceResponse {
|
||||
result: InferenceResult
|
||||
}
|
||||
|
||||
// Dataset types
|
||||
|
||||
export interface DatasetCreateRequest {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
train_ratio?: number
|
||||
val_ratio?: number
|
||||
seed?: number
|
||||
}
|
||||
|
||||
export interface DatasetResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface DatasetDocumentItem {
|
||||
document_id: string
|
||||
split: string
|
||||
page_count: number
|
||||
annotation_count: number
|
||||
}
|
||||
|
||||
export interface DatasetListItem {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
training_status: string | null
|
||||
active_training_task_id: string | null
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DatasetListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
datasets: DatasetListItem[]
|
||||
}
|
||||
|
||||
export interface DatasetDetailResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
training_status: string | null
|
||||
active_training_task_id: string | null
|
||||
train_ratio: number
|
||||
val_ratio: number
|
||||
seed: number
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
dataset_path: string | null
|
||||
error_message: string | null
|
||||
documents: DatasetDocumentItem[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTrainingConfig {
|
||||
gaussian_noise?: AugmentationParams
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface DatasetTrainRequest {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
image_size?: number
|
||||
learning_rate?: number
|
||||
device?: string
|
||||
augmentation?: AugmentationTrainingConfig
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface TrainingTaskResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
// Model Version types
|
||||
|
||||
export interface ModelVersionItem {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
metrics_mAP: number | null
|
||||
document_count: number
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionDetailResponse {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
description: string | null
|
||||
model_path: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
task_id: string | null
|
||||
dataset_id: string | null
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
document_count: number
|
||||
training_config: Record<string, unknown> | null
|
||||
file_size: number | null
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
models: ModelVersionItem[]
|
||||
}
|
||||
|
||||
export interface ModelVersionResponse {
|
||||
version_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface ActiveModelResponse {
|
||||
has_active_model: boolean
|
||||
model: ModelVersionItem | null
|
||||
}
|
||||
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* Tests for AugmentationConfig component.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||
import type { ReactNode } from 'react'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('../api/endpoints/augmentation', () => ({
|
||||
augmentationApi: {
|
||||
getTypes: vi.fn(),
|
||||
getPresets: vi.fn(),
|
||||
preview: vi.fn(),
|
||||
previewConfig: vi.fn(),
|
||||
createBatch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Default mock data
|
||||
const mockTypes = {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise to simulate sensor noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
{
|
||||
name: 'perspective_warp',
|
||||
description: 'Applies perspective transformation',
|
||||
affects_geometry: true,
|
||||
stage: 'geometric',
|
||||
default_params: { max_warp: 0.02 },
|
||||
},
|
||||
{
|
||||
name: 'gaussian_blur',
|
||||
description: 'Applies Gaussian blur',
|
||||
affects_geometry: false,
|
||||
stage: 'blur',
|
||||
default_params: { kernel_size: 5 },
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const mockPresets = {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
|
||||
{ name: 'moderate', description: 'Balanced augmentation settings' },
|
||||
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
|
||||
],
|
||||
}
|
||||
|
||||
// Test wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('AugmentationConfig', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render enable checkbox', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should be collapsed when disabled', () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
// Config options should not be visible
|
||||
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should expand when enabled', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/preset/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('preset selection', () => {
|
||||
it('should display available presets', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('conservative')).toBeInTheDocument()
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
expect(screen.getByText('aggressive')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onConfigChange when preset is selected', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfigChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={onConfigChange}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('moderate'))
|
||||
|
||||
expect(onConfigChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('enable toggle', () => {
|
||||
it('should call onEnabledChange when checkbox is toggled', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onEnabledChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={onEnabledChange}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
|
||||
|
||||
expect(onEnabledChange).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('augmentation types', () => {
|
||||
it('should display augmentation types when in custom mode', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should indicate which augmentations affect geometry', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
// perspective_warp affects geometry
|
||||
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
|
||||
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('loading state', () => {
|
||||
it('should show loading indicator while fetching types', () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
)
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
136
frontend/src/components/AugmentationConfig.tsx
Normal file
136
frontend/src/components/AugmentationConfig.tsx
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* AugmentationConfig component for configuring image augmentation during training.
|
||||
*
|
||||
* Provides preset selection and optional custom augmentation type configuration.
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { Loader2, AlertTriangle } from 'lucide-react'
|
||||
import { useAugmentation } from '../hooks/useAugmentation'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
interface AugmentationConfigProps {
|
||||
enabled: boolean
|
||||
onEnabledChange: (enabled: boolean) => void
|
||||
config: Partial<AugmentationConfigType>
|
||||
onConfigChange: (config: Partial<AugmentationConfigType>) => void
|
||||
showCustomOptions?: boolean
|
||||
}
|
||||
|
||||
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
|
||||
enabled,
|
||||
onEnabledChange,
|
||||
config,
|
||||
onConfigChange,
|
||||
showCustomOptions = false,
|
||||
}) => {
|
||||
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
|
||||
|
||||
const isLoading = isLoadingTypes || isLoadingPresets
|
||||
|
||||
const handlePresetSelect = (presetName: string) => {
|
||||
const preset = presets.find((p) => p.name === presetName)
|
||||
if (preset && preset.config) {
|
||||
onConfigChange(preset.config as Partial<AugmentationConfigType>)
|
||||
} else {
|
||||
// Apply a basic config based on preset name
|
||||
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
|
||||
conservative: {
|
||||
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
|
||||
},
|
||||
moderate: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
|
||||
lighting_variation: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
|
||||
},
|
||||
aggressive: {
|
||||
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
|
||||
motion_blur: { enabled: true, probability: 0.3, params: {} },
|
||||
lighting_variation: { enabled: true, probability: 0.5, params: {} },
|
||||
shadow: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
|
||||
wrinkle: { enabled: true, probability: 0.2, params: {} },
|
||||
stain: { enabled: true, probability: 0.2, params: {} },
|
||||
},
|
||||
}
|
||||
onConfigChange(presetConfigs[presetName] || {})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
|
||||
{/* Enable checkbox */}
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
onChange={(e) => onEnabledChange(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
|
||||
aria-label="Enable augmentation"
|
||||
/>
|
||||
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
|
||||
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
|
||||
</label>
|
||||
|
||||
{/* Expanded content when enabled */}
|
||||
{enabled && (
|
||||
<div className="mt-4 space-y-4">
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
|
||||
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
|
||||
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* Preset selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{presets.map((preset) => (
|
||||
<button
|
||||
key={preset.name}
|
||||
onClick={() => handlePresetSelect(preset.name)}
|
||||
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
|
||||
title={preset.description}
|
||||
>
|
||||
{preset.name}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Custom options (if enabled) */}
|
||||
{showCustomOptions && (
|
||||
<div className="border-t border-warm-divider pt-4">
|
||||
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
|
||||
<div className="grid gap-2">
|
||||
{augmentationTypes.map((type) => (
|
||||
<div
|
||||
key={type.name}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-warm-text-primary">{type.name}</span>
|
||||
{type.affects_geometry && (
|
||||
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
|
||||
<AlertTriangle size={12} />
|
||||
affects bbox
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs text-warm-text-muted">{type.stage}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
32
frontend/src/components/Badge.test.tsx
Normal file
32
frontend/src/components/Badge.test.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { Badge } from './Badge';
|
||||
import { DocumentStatus } from '../types';
|
||||
|
||||
describe('Badge', () => {
|
||||
it('renders Exported badge with check icon', () => {
|
||||
render(<Badge status="Exported" />);
|
||||
expect(screen.getByText('Exported')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Pending status', () => {
|
||||
render(<Badge status={DocumentStatus.PENDING} />);
|
||||
expect(screen.getByText('Pending')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Verified status', () => {
|
||||
render(<Badge status={DocumentStatus.VERIFIED} />);
|
||||
expect(screen.getByText('Verified')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Labeled status', () => {
|
||||
render(<Badge status={DocumentStatus.LABELED} />);
|
||||
expect(screen.getByText('Labeled')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Partial status with warning indicator', () => {
|
||||
render(<Badge status={DocumentStatus.PARTIAL} />);
|
||||
expect(screen.getByText('Partial')).toBeInTheDocument();
|
||||
expect(screen.getByText('!')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
39
frontend/src/components/Badge.tsx
Normal file
39
frontend/src/components/Badge.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import React from 'react';
|
||||
import { DocumentStatus } from '../types';
|
||||
import { Check } from 'lucide-react';
|
||||
|
||||
interface BadgeProps {
|
||||
status: DocumentStatus | 'Exported';
|
||||
}
|
||||
|
||||
export const Badge: React.FC<BadgeProps> = ({ status }) => {
|
||||
if (status === 'Exported') {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
|
||||
<Check size={12} strokeWidth={3} />
|
||||
Exported
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
const styles = {
|
||||
[DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
|
||||
[DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
|
||||
[DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
|
||||
[DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
|
||||
};
|
||||
|
||||
const icons = {
|
||||
[DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
|
||||
[DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
|
||||
[DocumentStatus.PENDING]: null,
|
||||
[DocumentStatus.LABELED]: null,
|
||||
}
|
||||
|
||||
return (
|
||||
<span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
|
||||
{icons[status]}
|
||||
{status}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
38
frontend/src/components/Button.test.tsx
Normal file
38
frontend/src/components/Button.test.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { Button } from './Button';
|
||||
|
||||
describe('Button', () => {
|
||||
it('renders children text', () => {
|
||||
render(<Button>Click me</Button>);
|
||||
expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onClick handler', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onClick = vi.fn();
|
||||
render(<Button onClick={onClick}>Click</Button>);
|
||||
await user.click(screen.getByRole('button'));
|
||||
expect(onClick).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('is disabled when disabled prop is set', () => {
|
||||
render(<Button disabled>Disabled</Button>);
|
||||
expect(screen.getByRole('button')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('applies variant styles', () => {
|
||||
const { rerender } = render(<Button variant="primary">Primary</Button>);
|
||||
const btn = screen.getByRole('button');
|
||||
expect(btn.className).toContain('bg-warm-text-secondary');
|
||||
|
||||
rerender(<Button variant="secondary">Secondary</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('border');
|
||||
});
|
||||
|
||||
it('applies size styles', () => {
|
||||
render(<Button size="sm">Small</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('h-8');
|
||||
});
|
||||
});
|
||||
38
frontend/src/components/Button.tsx
Normal file
38
frontend/src/components/Button.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import React from 'react';
|
||||
|
||||
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
variant?: 'primary' | 'secondary' | 'outline' | 'text';
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
}
|
||||
|
||||
export const Button: React.FC<ButtonProps> = ({
|
||||
variant = 'primary',
|
||||
size = 'md',
|
||||
className = '',
|
||||
children,
|
||||
...props
|
||||
}) => {
|
||||
const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-98 disabled:opacity-50 disabled:pointer-events-none";
|
||||
|
||||
const variants = {
|
||||
primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
|
||||
secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
|
||||
outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
|
||||
text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
|
||||
};
|
||||
|
||||
const sizes = {
|
||||
sm: "h-8 px-3 text-xs",
|
||||
md: "h-10 px-4 text-sm",
|
||||
lg: "h-12 px-6 text-base",
|
||||
};
|
||||
|
||||
return (
|
||||
<button
|
||||
className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
};
|
||||
300
frontend/src/components/Dashboard.tsx
Normal file
300
frontend/src/components/Dashboard.tsx
Normal file
@@ -0,0 +1,300 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
|
||||
import { Badge } from './Badge'
|
||||
import { Button } from './Button'
|
||||
import { UploadModal } from './UploadModal'
|
||||
import { useDocuments, useCategories } from '../hooks/useDocuments'
|
||||
import type { DocumentItem } from '../api/types'
|
||||
|
||||
interface DashboardProps {
|
||||
onNavigate: (view: string, docId?: string) => void
|
||||
}
|
||||
|
||||
const getStatusForBadge = (status: string): string => {
|
||||
const statusMap: Record<string, string> = {
|
||||
pending: 'Pending',
|
||||
labeled: 'Labeled',
|
||||
verified: 'Verified',
|
||||
exported: 'Exported',
|
||||
}
|
||||
return statusMap[status] || status
|
||||
}
|
||||
|
||||
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||
if (doc.auto_label_status === 'running') {
|
||||
return 45
|
||||
}
|
||||
if (doc.auto_label_status === 'completed') {
|
||||
return 100
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
const [isUploadOpen, setIsUploadOpen] = useState(false)
|
||||
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
|
||||
const [statusFilter, setStatusFilter] = useState<string>('')
|
||||
const [categoryFilter, setCategoryFilter] = useState<string>('')
|
||||
const [limit] = useState(20)
|
||||
const [offset] = useState(0)
|
||||
|
||||
const { categories } = useCategories()
|
||||
|
||||
const { documents, total, isLoading, error, refetch } = useDocuments({
|
||||
status: statusFilter || undefined,
|
||||
category: categoryFilter || undefined,
|
||||
limit,
|
||||
offset,
|
||||
})
|
||||
|
||||
const toggleSelection = (id: string) => {
|
||||
const newSet = new Set(selectedDocs)
|
||||
if (newSet.has(id)) {
|
||||
newSet.delete(id)
|
||||
} else {
|
||||
newSet.add(id)
|
||||
}
|
||||
setSelectedDocs(newSet)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
|
||||
Error loading documents. Please check your connection to the backend API.
|
||||
<button
|
||||
onClick={() => refetch()}
|
||||
className="ml-4 underline hover:no-underline"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
<div className="flex items-center justify-between mb-8">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Documents
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
{isLoading ? 'Loading...' : `${total} documents total`}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" disabled={selectedDocs.size === 0}>
|
||||
Export Selection ({selectedDocs.size})
|
||||
</Button>
|
||||
<Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
|
||||
<div className="relative flex-1 min-w-[200px]">
|
||||
<Search
|
||||
className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
|
||||
size={16}
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search documents..."
|
||||
className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-3">
|
||||
<div className="relative">
|
||||
<select
|
||||
value={categoryFilter}
|
||||
onChange={(e) => setCategoryFilter(e.target.value)}
|
||||
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
|
||||
>
|
||||
<option value="">All Categories</option>
|
||||
{categories.map((cat) => (
|
||||
<option key={cat} value={cat}>
|
||||
{cat.charAt(0).toUpperCase() + cat.slice(1)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
<div className="relative">
|
||||
<select
|
||||
value={statusFilter}
|
||||
onChange={(e) => setStatusFilter(e.target.value)}
|
||||
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
|
||||
>
|
||||
<option value="">All Statuses</option>
|
||||
<option value="pending">Pending</option>
|
||||
<option value="labeled">Labeled</option>
|
||||
<option value="verified">Verified</option>
|
||||
<option value="exported">Exported</option>
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<table className="w-full text-left border-collapse">
|
||||
<thead>
|
||||
<tr className="border-b border-warm-border bg-white">
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
|
||||
/>
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Document Name
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Date
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Status
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Annotations
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Category
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Group
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
||||
Auto-label
|
||||
</th>
|
||||
<th className="py-3 px-4 w-12"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{isLoading ? (
|
||||
<tr>
|
||||
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
|
||||
Loading documents...
|
||||
</td>
|
||||
</tr>
|
||||
) : documents.length === 0 ? (
|
||||
<tr>
|
||||
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
|
||||
No documents found. Upload your first document to get started.
|
||||
</td>
|
||||
</tr>
|
||||
) : (
|
||||
documents.map((doc) => {
|
||||
const isSelected = selectedDocs.has(doc.document_id)
|
||||
const progress = getAutoLabelProgress(doc)
|
||||
|
||||
return (
|
||||
<tr
|
||||
key={doc.document_id}
|
||||
onClick={() => onNavigate('detail', doc.document_id)}
|
||||
className={`
|
||||
group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
|
||||
${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
|
||||
`}
|
||||
>
|
||||
<td
|
||||
className="py-4 pl-6 pr-4 relative"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
toggleSelection(doc.document_id)
|
||||
}}
|
||||
>
|
||||
{isSelected && (
|
||||
<div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
|
||||
)}
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={isSelected}
|
||||
readOnly
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
|
||||
/>
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
|
||||
<FileText size={16} />
|
||||
</div>
|
||||
<span className="font-medium text-warm-text-secondary">
|
||||
{doc.filename}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
|
||||
{new Date(doc.created_at).toLocaleDateString()}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<Badge status={getStatusForBadge(doc.status)} />
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
||||
{doc.annotation_count || 0} annotations
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary capitalize">
|
||||
{doc.category || 'invoice'}
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-muted">
|
||||
{doc.group_key || '-'}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
{doc.auto_label_status === 'running' && progress && (
|
||||
<div className="w-full">
|
||||
<div className="flex justify-between text-xs mb-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
Running
|
||||
</span>
|
||||
<span className="text-warm-text-muted">{progress}%</span>
|
||||
</div>
|
||||
<div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-warm-state-info transition-all duration-500 ease-out"
|
||||
style={{ width: `${progress}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{doc.auto_label_status === 'completed' && (
|
||||
<span className="text-sm font-medium text-warm-state-success">
|
||||
Completed
|
||||
</span>
|
||||
)}
|
||||
{doc.auto_label_status === 'failed' && (
|
||||
<span className="text-sm font-medium text-warm-state-error">
|
||||
Failed
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="py-4 px-4 text-right">
|
||||
<button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
|
||||
<MoreHorizontal size={18} />
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<UploadModal
|
||||
isOpen={isUploadOpen}
|
||||
onClose={() => {
|
||||
setIsUploadOpen(false)
|
||||
refetch()
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
148
frontend/src/components/DashboardOverview.tsx
Normal file
148
frontend/src/components/DashboardOverview.tsx
Normal file
@@ -0,0 +1,148 @@
|
||||
import React from 'react'
|
||||
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { useTraining } from '../hooks/useTraining'
|
||||
|
||||
interface DashboardOverviewProps {
|
||||
onNavigate: (view: string) => void
|
||||
}
|
||||
|
||||
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
|
||||
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
|
||||
const { models, isLoadingModels } = useTraining()
|
||||
|
||||
const stats = [
|
||||
{
|
||||
label: 'Total Documents',
|
||||
value: docsLoading ? '...' : totalDocs.toString(),
|
||||
icon: FileText,
|
||||
color: 'text-warm-text-primary',
|
||||
bgColor: 'bg-warm-bg',
|
||||
},
|
||||
{
|
||||
label: 'Labeled',
|
||||
value: '0',
|
||||
icon: CheckCircle,
|
||||
color: 'text-warm-state-success',
|
||||
bgColor: 'bg-green-50',
|
||||
},
|
||||
{
|
||||
label: 'Pending',
|
||||
value: '0',
|
||||
icon: Clock,
|
||||
color: 'text-warm-state-warning',
|
||||
bgColor: 'bg-yellow-50',
|
||||
},
|
||||
{
|
||||
label: 'Training Models',
|
||||
value: isLoadingModels ? '...' : models.length.toString(),
|
||||
icon: TrendingUp,
|
||||
color: 'text-warm-state-info',
|
||||
bgColor: 'bg-blue-50',
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
{/* Header */}
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Dashboard
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
Overview of your document annotation system
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Stats Grid */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
|
||||
{stats.map((stat) => (
|
||||
<div
|
||||
key={stat.label}
|
||||
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
|
||||
>
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
|
||||
<stat.icon className={stat.color} size={24} />
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-2xl font-bold text-warm-text-primary mb-1">
|
||||
{stat.value}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">{stat.label}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Quick Actions */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
Quick Actions
|
||||
</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<Button onClick={() => onNavigate('documents')} className="justify-start">
|
||||
<FileText size={18} className="mr-2" />
|
||||
Manage Documents
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
|
||||
<Activity size={18} className="mr-2" />
|
||||
Start Training
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
|
||||
<TrendingUp size={18} className="mr-2" />
|
||||
View Models
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Recent Activity */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<div className="p-6 border-b border-warm-border">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary">
|
||||
Recent Activity
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6">
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Activity size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No recent activity</p>
|
||||
<p className="text-xs mt-1">
|
||||
Start by uploading documents or creating training jobs
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* System Status */}
|
||||
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
System Status
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Backend API</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Online
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Database</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Connected
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">GPU</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Available
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
176
frontend/src/components/DatasetDetail.tsx
Normal file
176
frontend/src/components/DatasetDetail.tsx
Normal file
@@ -0,0 +1,176 @@
|
||||
import React from 'react'
|
||||
import { ArrowLeft, Loader2, Play, AlertCircle, Check, Award } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDatasetDetail } from '../hooks/useDatasets'
|
||||
|
||||
interface DatasetDetailProps {
|
||||
datasetId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
const SPLIT_STYLES: Record<string, string> = {
|
||||
train: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
val: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
test: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
|
||||
building: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Building' },
|
||||
ready: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Ready' },
|
||||
trained: { bg: 'bg-purple-100', text: 'text-purple-700', label: 'Trained' },
|
||||
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
|
||||
archived: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Archived' },
|
||||
}
|
||||
|
||||
const TRAINING_STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
|
||||
pending: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Pending' },
|
||||
scheduled: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Scheduled' },
|
||||
running: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Training' },
|
||||
completed: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Completed' },
|
||||
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
|
||||
cancelled: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Cancelled' },
|
||||
}
|
||||
|
||||
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
|
||||
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-20 text-warm-text-muted">
|
||||
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error || !dataset) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back
|
||||
</button>
|
||||
<p className="text-warm-state-error">Failed to load dataset.</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const statusConfig = STATUS_STYLES[dataset.status] || STATUS_STYLES.ready
|
||||
const trainingStatusConfig = dataset.training_status
|
||||
? TRAINING_STATUS_STYLES[dataset.training_status]
|
||||
: null
|
||||
|
||||
// Determine if training button should be shown and enabled
|
||||
const isTrainingInProgress = dataset.training_status === 'running' || dataset.training_status === 'pending'
|
||||
const canStartTraining = dataset.status === 'ready' && !isTrainingInProgress
|
||||
|
||||
// Determine status icon
|
||||
const statusIcon = dataset.status === 'trained'
|
||||
? <Award size={14} className="text-purple-700" />
|
||||
: dataset.status === 'ready'
|
||||
? <Check size={14} className="text-warm-state-success" />
|
||||
: dataset.status === 'failed'
|
||||
? <AlertCircle size={14} className="text-warm-state-error" />
|
||||
: dataset.status === 'building'
|
||||
? <Loader2 size={14} className="animate-spin text-warm-state-info" />
|
||||
: null
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
{/* Header */}
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back to Datasets
|
||||
</button>
|
||||
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<div className="flex items-center gap-3 mb-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
|
||||
{dataset.name} {statusIcon}
|
||||
</h2>
|
||||
{/* Status Badge */}
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${statusConfig.bg} ${statusConfig.text}`}>
|
||||
{statusConfig.label}
|
||||
</span>
|
||||
{/* Training Status Badge */}
|
||||
{trainingStatusConfig && (
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${trainingStatusConfig.bg} ${trainingStatusConfig.text}`}>
|
||||
{isTrainingInProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
|
||||
{trainingStatusConfig.label}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{dataset.description && (
|
||||
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
|
||||
)}
|
||||
</div>
|
||||
{/* Training Button */}
|
||||
{(dataset.status === 'ready' || dataset.status === 'trained') && (
|
||||
<Button
|
||||
disabled={isTrainingInProgress}
|
||||
className={isTrainingInProgress ? 'opacity-50 cursor-not-allowed' : ''}
|
||||
>
|
||||
{isTrainingInProgress ? (
|
||||
<><Loader2 size={14} className="mr-1 animate-spin" />Training...</>
|
||||
) : (
|
||||
<><Play size={14} className="mr-1" />Start Training</>
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{dataset.error_message && (
|
||||
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
|
||||
{dataset.error_message}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Stats */}
|
||||
<div className="grid grid-cols-4 gap-4 mb-8">
|
||||
{[
|
||||
['Documents', dataset.total_documents],
|
||||
['Images', dataset.total_images],
|
||||
['Annotations', dataset.total_annotations],
|
||||
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
|
||||
].map(([label, value]) => (
|
||||
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
|
||||
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Document list */}
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{dataset.documents.map(doc => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4">
|
||||
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{doc.split}
|
||||
</span>
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<p className="text-xs text-warm-text-muted mt-4">
|
||||
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
|
||||
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
567
frontend/src/components/DocumentDetail.tsx
Normal file
567
frontend/src/components/DocumentDetail.tsx
Normal file
@@ -0,0 +1,567 @@
|
||||
import React, { useState, useRef, useEffect } from 'react'
|
||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
||||
import { useAnnotations } from '../hooks/useAnnotations'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { documentsApi } from '../api/endpoints/documents'
|
||||
import type { AnnotationItem } from '../api/types'
|
||||
|
||||
interface DocumentDetailProps {
|
||||
docId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
// Field class mapping from backend
|
||||
const FIELD_CLASSES: Record<number, string> = {
|
||||
0: 'invoice_number',
|
||||
1: 'invoice_date',
|
||||
2: 'invoice_due_date',
|
||||
3: 'ocr_number',
|
||||
4: 'bankgiro',
|
||||
5: 'plusgiro',
|
||||
6: 'amount',
|
||||
7: 'supplier_organisation_number',
|
||||
8: 'payment_line',
|
||||
9: 'customer_number',
|
||||
}
|
||||
|
||||
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
||||
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
isCreating,
|
||||
isDeleting,
|
||||
} = useAnnotations(docId)
|
||||
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
|
||||
|
||||
const [selectedId, setSelectedId] = useState<string | null>(null)
|
||||
const [zoom, setZoom] = useState(100)
|
||||
const [isDrawing, setIsDrawing] = useState(false)
|
||||
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
|
||||
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
|
||||
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
||||
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
||||
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
||||
const [currentPage, setCurrentPage] = useState(1)
|
||||
const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
|
||||
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)
|
||||
|
||||
const canvasRef = useRef<HTMLDivElement>(null)
|
||||
const imageRef = useRef<HTMLImageElement>(null)
|
||||
|
||||
const [isMarkingComplete, setIsMarkingComplete] = useState(false)
|
||||
|
||||
const selectedAnnotation = annotations?.find((a) => a.annotation_id === selectedId)
|
||||
|
||||
// Handle mark as complete
|
||||
const handleMarkComplete = async () => {
|
||||
if (!annotations || annotations.length === 0) {
|
||||
alert('Please add at least one annotation before marking as complete.')
|
||||
return
|
||||
}
|
||||
|
||||
if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
|
||||
return
|
||||
}
|
||||
|
||||
setIsMarkingComplete(true)
|
||||
try {
|
||||
const result = await documentsApi.updateStatus(docId, 'labeled')
|
||||
alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
|
||||
onBack() // Return to document list
|
||||
} catch (error) {
|
||||
console.error('Failed to mark document as complete:', error)
|
||||
alert('Failed to mark document as complete. Please try again.')
|
||||
} finally {
|
||||
setIsMarkingComplete(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Load image via fetch with authentication header
|
||||
useEffect(() => {
|
||||
let objectUrl: string | null = null
|
||||
|
||||
const loadImage = async () => {
|
||||
if (!docId) return
|
||||
|
||||
const token = localStorage.getItem('admin_token')
|
||||
const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`
|
||||
|
||||
try {
|
||||
const response = await fetch(imageUrl, {
|
||||
headers: {
|
||||
'X-Admin-Token': token || '',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load image: ${response.status}`)
|
||||
}
|
||||
|
||||
const blob = await response.blob()
|
||||
objectUrl = URL.createObjectURL(blob)
|
||||
setImageBlobUrl(objectUrl)
|
||||
} catch (error) {
|
||||
console.error('Failed to load image:', error)
|
||||
}
|
||||
}
|
||||
|
||||
loadImage()
|
||||
|
||||
// Cleanup: revoke object URL when component unmounts or page changes
|
||||
return () => {
|
||||
if (objectUrl) {
|
||||
URL.revokeObjectURL(objectUrl)
|
||||
}
|
||||
}
|
||||
}, [currentPage, docId])
|
||||
|
||||
// Load image size
|
||||
useEffect(() => {
|
||||
if (imageRef.current && imageRef.current.complete) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}, [imageBlobUrl])
|
||||
|
||||
const handleImageLoad = () => {
|
||||
if (imageRef.current) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setIsDrawing(true)
|
||||
setDrawStart({ x, y })
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!isDrawing || !canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseUp = () => {
|
||||
if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
|
||||
setIsDrawing(false)
|
||||
return
|
||||
}
|
||||
|
||||
const bbox_x = Math.min(drawStart.x, drawEnd.x)
|
||||
const bbox_y = Math.min(drawStart.y, drawEnd.y)
|
||||
const bbox_width = Math.abs(drawEnd.x - drawStart.x)
|
||||
const bbox_height = Math.abs(drawEnd.y - drawStart.y)
|
||||
|
||||
// Only create if box is large enough (min 10x10 pixels)
|
||||
if (bbox_width > 10 && bbox_height > 10) {
|
||||
createAnnotation({
|
||||
page_number: currentPage,
|
||||
class_id: selectedClassId,
|
||||
bbox: {
|
||||
x: Math.round(bbox_x),
|
||||
y: Math.round(bbox_y),
|
||||
width: Math.round(bbox_width),
|
||||
height: Math.round(bbox_height),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
setIsDrawing(false)
|
||||
setDrawStart(null)
|
||||
setDrawEnd(null)
|
||||
}
|
||||
|
||||
const handleDeleteAnnotation = (annotationId: string) => {
|
||||
if (confirm('Are you sure you want to delete this annotation?')) {
|
||||
deleteAnnotation(annotationId)
|
||||
setSelectedId(null)
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading || !document) {
|
||||
return (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<div className="text-warm-text-muted">Loading...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Get current page annotations
|
||||
const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []
|
||||
|
||||
return (
|
||||
<div className="flex h-[calc(100vh-56px)] overflow-hidden">
|
||||
{/* Main Canvas Area */}
|
||||
<div className="flex-1 bg-warm-bg flex flex-col relative">
|
||||
{/* Toolbar */}
|
||||
<div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
|
||||
<div className="flex items-center gap-4">
|
||||
<button
|
||||
onClick={onBack}
|
||||
className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
|
||||
>
|
||||
<ChevronLeft size={20} />
|
||||
</button>
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Page {currentPage} of {document.page_count}
|
||||
</p>
|
||||
</div>
|
||||
<div className="h-6 w-px bg-warm-divider mx-2" />
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.max(50, z - 10))}
|
||||
>
|
||||
<ZoomOut size={16} />
|
||||
</button>
|
||||
<span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
|
||||
{zoom}%
|
||||
</span>
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.min(200, z + 10))}
|
||||
>
|
||||
<ZoomIn size={16} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="secondary" size="sm">
|
||||
Auto-label
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="sm"
|
||||
onClick={handleMarkComplete}
|
||||
disabled={isMarkingComplete || document.status === 'labeled'}
|
||||
>
|
||||
<CheckCircle size={16} className="mr-1" />
|
||||
{isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
|
||||
</Button>
|
||||
{document.page_count > 1 && (
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
|
||||
disabled={currentPage === 1}
|
||||
>
|
||||
Prev
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
|
||||
disabled={currentPage === document.page_count}
|
||||
>
|
||||
Next
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Canvas Scroll Area */}
|
||||
<div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
|
||||
<div
|
||||
ref={canvasRef}
|
||||
className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
|
||||
style={{
|
||||
width: imageSize?.width || 800,
|
||||
height: imageSize?.height || 1132,
|
||||
transform: `scale(${zoom / 100})`,
|
||||
marginBottom: '100px',
|
||||
cursor: isDrawing ? 'crosshair' : 'default',
|
||||
}}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onClick={() => setSelectedId(null)}
|
||||
>
|
||||
{/* Document Image */}
|
||||
{imageBlobUrl ? (
|
||||
<img
|
||||
ref={imageRef}
|
||||
src={imageBlobUrl}
|
||||
alt={`Page ${currentPage}`}
|
||||
className="w-full h-full object-contain select-none pointer-events-none"
|
||||
onLoad={handleImageLoad}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-full">
|
||||
<div className="text-warm-text-muted">Loading image...</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Annotation Overlays */}
|
||||
{pageAnnotations.map((ann) => {
|
||||
const isSelected = selectedId === ann.annotation_id
|
||||
return (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setSelectedId(ann.annotation_id)
|
||||
}}
|
||||
className={`
|
||||
absolute group cursor-pointer transition-all duration-100
|
||||
${
|
||||
ann.source === 'auto'
|
||||
? 'border border-dashed border-warm-text-muted bg-transparent'
|
||||
: 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
|
||||
}
|
||||
${
|
||||
isSelected
|
||||
? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
|
||||
: 'hover:bg-warm-state-info/5 z-10'
|
||||
}
|
||||
`}
|
||||
style={{
|
||||
left: ann.bbox.x,
|
||||
top: ann.bbox.y,
|
||||
width: ann.bbox.width,
|
||||
height: ann.bbox.height,
|
||||
}}
|
||||
>
|
||||
{/* Label Tag */}
|
||||
<div
|
||||
className={`
|
||||
absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
|
||||
${
|
||||
isSelected
|
||||
? 'bg-warm-state-info text-white'
|
||||
: 'bg-white text-warm-text-secondary border border-warm-border'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{ann.class_name}
|
||||
</div>
|
||||
|
||||
{/* Resize Handles (Visual only) */}
|
||||
{isSelected && (
|
||||
<>
|
||||
<div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
{/* Drawing Box Preview */}
|
||||
{isDrawing && drawStart && drawEnd && (
|
||||
<div
|
||||
className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
|
||||
style={{
|
||||
left: Math.min(drawStart.x, drawEnd.x),
|
||||
top: Math.min(drawStart.y, drawEnd.y),
|
||||
width: Math.abs(drawEnd.x - drawStart.x),
|
||||
height: Math.abs(drawEnd.y - drawStart.y),
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Sidebar */}
|
||||
<div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
|
||||
{/* Field Selector */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
|
||||
<div className="space-y-2">
|
||||
<label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
|
||||
<select
|
||||
value={selectedClassId}
|
||||
onChange={(e) => setSelectedClassId(Number(e.target.value))}
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
{Object.entries(FIELD_CLASSES).map(([id, name]) => (
|
||||
<option key={id} value={id}>
|
||||
{name.replace(/_/g, ' ')}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-2">
|
||||
Click and drag on the document to create a bounding box
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Document Info Card */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Status</span>
|
||||
<span className="text-warm-text-secondary font-medium capitalize">
|
||||
{document.status}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Size</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{(document.file_size / 1024 / 1024).toFixed(2)} MB
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Uploaded</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{new Date(document.created_at).toLocaleDateString()}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-warm-text-muted">Group</span>
|
||||
{isEditingGroupKey ? (
|
||||
<div className="flex items-center gap-1">
|
||||
<input
|
||||
type="text"
|
||||
value={editGroupKeyValue}
|
||||
onChange={(e) => setEditGroupKeyValue(e.target.value)}
|
||||
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
placeholder="group key"
|
||||
autoFocus
|
||||
/>
|
||||
<button
|
||||
onClick={() => {
|
||||
updateGroupKey(
|
||||
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
|
||||
{
|
||||
onSuccess: () => {
|
||||
setIsEditingGroupKey(false)
|
||||
refetch()
|
||||
},
|
||||
onError: () => {
|
||||
alert('Failed to update group key. Please try again.')
|
||||
},
|
||||
}
|
||||
)
|
||||
}}
|
||||
disabled={isUpdatingGroupKey}
|
||||
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Check size={14} />
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setIsEditingGroupKey(false)
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
}}
|
||||
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
|
||||
>
|
||||
<X size={14} />
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{document.group_key || '-'}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
setIsEditingGroupKey(true)
|
||||
}}
|
||||
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Edit2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Annotations List */}
|
||||
<div className="flex-1 overflow-y-auto p-4">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
|
||||
<span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
|
||||
</div>
|
||||
|
||||
{pageAnnotations.length === 0 ? (
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Tag size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No annotations yet</p>
|
||||
<p className="text-xs mt-1">Draw on the document to add annotations</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
{pageAnnotations.map((ann) => (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={() => setSelectedId(ann.annotation_id)}
|
||||
className={`
|
||||
group p-3 rounded-md border transition-all duration-150 cursor-pointer
|
||||
${
|
||||
selectedId === ann.annotation_id
|
||||
? 'bg-warm-bg border-warm-state-info shadow-sm'
|
||||
: 'bg-white border-warm-border hover:border-warm-text-muted'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-1">
|
||||
<span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
|
||||
{ann.class_name.replace(/_/g, ' ')}
|
||||
</span>
|
||||
{selectedId === ann.annotation_id && (
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={() => handleDeleteAnnotation(ann.annotation_id)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error"
|
||||
disabled={isDeleting}
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<p className="text-sm text-warm-text-muted font-mono truncate">
|
||||
{ann.text_value || '(no text)'}
|
||||
</p>
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<span
|
||||
className={`text-[10px] px-1.5 py-0.5 rounded ${
|
||||
ann.source === 'auto'
|
||||
? 'bg-blue-50 text-blue-700'
|
||||
: 'bg-green-50 text-green-700'
|
||||
}`}
|
||||
>
|
||||
{ann.source}
|
||||
</span>
|
||||
{ann.confidence && (
|
||||
<span className="text-[10px] text-warm-text-muted">
|
||||
{(ann.confidence * 100).toFixed(0)}%
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
466
frontend/src/components/InferenceDemo.tsx
Normal file
466
frontend/src/components/InferenceDemo.tsx
Normal file
@@ -0,0 +1,466 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { inferenceApi } from '../api/endpoints'
|
||||
import type { InferenceResult } from '../api/types'
|
||||
|
||||
export const InferenceDemo: React.FC = () => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null)
|
||||
const [isProcessing, setIsProcessing] = useState(false)
|
||||
const [result, setResult] = useState<InferenceResult | null>(null)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const handleFileSelect = (file: File | null) => {
|
||||
if (!file) return
|
||||
|
||||
const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
|
||||
if (!validTypes.includes(file.type)) {
|
||||
setError('Please upload a PDF, PNG, or JPG file')
|
||||
return
|
||||
}
|
||||
|
||||
if (file.size > 50 * 1024 * 1024) {
|
||||
setError('File size must be less than 50MB')
|
||||
return
|
||||
}
|
||||
|
||||
setSelectedFile(file)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
if (e.dataTransfer.files.length > 0) {
|
||||
handleFileSelect(e.dataTransfer.files[0])
|
||||
}
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const handleProcess = async () => {
|
||||
if (!selectedFile) return
|
||||
|
||||
setIsProcessing(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const response = await inferenceApi.processDocument(selectedFile)
|
||||
console.log('API Response:', response)
|
||||
console.log('Visualization URL:', response.result?.visualization_url)
|
||||
setResult(response.result)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Processing failed')
|
||||
} finally {
|
||||
setIsProcessing(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleReset = () => {
|
||||
setSelectedFile(null)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const formatFieldName = (field: string): string => {
|
||||
const fieldNames: Record<string, string> = {
|
||||
InvoiceNumber: 'Invoice Number',
|
||||
InvoiceDate: 'Invoice Date',
|
||||
InvoiceDueDate: 'Due Date',
|
||||
OCR: 'OCR Number',
|
||||
Amount: 'Amount',
|
||||
Bankgiro: 'Bankgiro',
|
||||
Plusgiro: 'Plusgiro',
|
||||
supplier_org_number: 'Supplier Org Number',
|
||||
customer_number: 'Customer Number',
|
||||
payment_line: 'Payment Line',
|
||||
}
|
||||
return fieldNames[field] || field
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
|
||||
{/* Header */}
|
||||
<div className="text-center">
|
||||
<h2 className="text-3xl font-bold text-warm-text-primary mb-2">
|
||||
Invoice Extraction Demo
|
||||
</h2>
|
||||
<p className="text-warm-text-muted">
|
||||
Upload a Swedish invoice to see our AI-powered field extraction in action
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Upload Area */}
|
||||
{!result && (
|
||||
<div className="max-w-2xl mx-auto">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
|
||||
<div
|
||||
className={`
|
||||
relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
|
||||
${isDragging
|
||||
? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
|
||||
: 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
|
||||
}
|
||||
${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
|
||||
`}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(true)
|
||||
}}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
|
||||
{isProcessing ? (
|
||||
<>
|
||||
<Loader2 size={56} className="text-warm-text-secondary animate-spin" />
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
Processing invoice...
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
This may take a few moments
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : selectedFile ? (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<FileText size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
{selectedFile.name}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
{(selectedFile.size / 1024 / 1024).toFixed(2)} MB
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<UploadCloud size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-2">
|
||||
Drag & drop invoice here
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted mb-3">
|
||||
or{' '}
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
browse files
|
||||
</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Supports PDF, PNG, JPG (up to 50MB)
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
|
||||
/>
|
||||
|
||||
{error && (
|
||||
<div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
|
||||
<AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-sm text-red-800 font-medium">{error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedFile && !isProcessing && (
|
||||
<div className="mt-6 flex gap-3 justify-end">
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleProcess}>Process Invoice</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Results */}
|
||||
{result && (
|
||||
<div className="space-y-6">
|
||||
{/* Status Header */}
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
|
||||
<div className="p-6 flex items-center justify-between border-b border-warm-divider">
|
||||
<div className="flex items-center gap-4">
|
||||
{result.success ? (
|
||||
<div className="p-3 bg-green-100 rounded-xl">
|
||||
<CheckCircle2 size={28} className="text-green-600" />
|
||||
</div>
|
||||
) : (
|
||||
<div className="p-3 bg-yellow-100 rounded-xl">
|
||||
<AlertCircle size={28} className="text-yellow-600" />
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">
|
||||
{result.success ? 'Extraction Complete' : 'Partial Results'}
|
||||
</h3>
|
||||
<p className="text-sm text-warm-text-muted mt-0.5">
|
||||
Document ID: <span className="font-mono">{result.document_id}</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Process Another
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
|
||||
<div className="flex items-center gap-2 text-warm-text-secondary">
|
||||
<Clock size={16} />
|
||||
<span className="font-medium">
|
||||
{result.processing_time_ms.toFixed(0)}ms
|
||||
</span>
|
||||
</div>
|
||||
{result.fallback_used && (
|
||||
<span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
|
||||
Fallback OCR Used
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main Content Grid */}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
{/* Left Column: Extracted Fields */}
|
||||
<div className="lg:col-span-2 space-y-6">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Extracted Fields
|
||||
</h3>
|
||||
<div className="flex flex-wrap gap-4">
|
||||
{Object.entries(result.fields).map(([field, value]) => {
|
||||
const confidence = result.confidence[field]
|
||||
return (
|
||||
<div
|
||||
key={field}
|
||||
className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
|
||||
>
|
||||
<div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
|
||||
{formatFieldName(field)}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
|
||||
{value || <span className="text-warm-text-muted italic">N/A</span>}
|
||||
</div>
|
||||
{confidence && (
|
||||
<div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
|
||||
<CheckCircle2 size={13} />
|
||||
<span>{(confidence * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Visualization */}
|
||||
{result.visualization_url && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Detection Visualization
|
||||
</h3>
|
||||
<div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
|
||||
<img
|
||||
src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
|
||||
alt="Detection visualization"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right Column: Cross-Validation & Errors */}
|
||||
<div className="space-y-6">
|
||||
{/* Cross-Validation */}
|
||||
{result.cross_validation && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Payment Line Validation
|
||||
</h3>
|
||||
|
||||
<div
|
||||
className={`
|
||||
p-4 rounded-lg mb-4 flex items-center gap-3
|
||||
${result.cross_validation.is_valid
|
||||
? 'bg-green-50 border border-green-200'
|
||||
: 'bg-yellow-50 border border-yellow-200'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{result.cross_validation.is_valid ? (
|
||||
<>
|
||||
<CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
|
||||
<span className="font-bold text-green-800">All Fields Match</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
|
||||
<span className="font-bold text-yellow-800">Mismatch Detected</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2.5">
|
||||
{result.cross_validation.payment_line_ocr && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.ocr_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.ocr_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
OCR NUMBER
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_ocr}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.ocr_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.ocr_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_amount && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.amount_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.amount_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
AMOUNT
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_amount}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.amount_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.amount_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_account && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: (result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
{result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? 'BANKGIRO'
|
||||
: 'PLUSGIRO'}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_account}
|
||||
</div>
|
||||
</div>
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{result.cross_validation.details.length > 0 && (
|
||||
<div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
|
||||
{result.cross_validation.details[result.cross_validation.details.length - 1]}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Errors */}
|
||||
{result.errors.length > 0 && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-red-500 rounded-full"></span>
|
||||
Issues
|
||||
</h3>
|
||||
<div className="space-y-2.5">
|
||||
{result.errors.map((err, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
|
||||
>
|
||||
<AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
102
frontend/src/components/Layout.tsx
Normal file
102
frontend/src/components/Layout.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';
|
||||
|
||||
interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
activeView: string;
|
||||
onNavigate: (view: string) => void;
|
||||
onLogout?: () => void;
|
||||
}
|
||||
|
||||
export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
|
||||
const [showDropdown, setShowDropdown] = useState(false);
|
||||
const navItems = [
|
||||
{ id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
|
||||
{ id: 'demo', label: 'Demo', icon: Sparkles },
|
||||
{ id: 'training', label: 'Training', icon: Box }, // Mapped to Compliants visually in prompt, using logical name
|
||||
{ id: 'documents', label: 'Documents', icon: BookOpen },
|
||||
{ id: 'models', label: 'Models', icon: Users }, // Contacts in prompt, mapped to models for this use case
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
|
||||
{/* Top Navigation */}
|
||||
<nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
|
||||
<div className="flex items-center gap-8">
|
||||
{/* Logo */}
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
|
||||
<path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Nav Links */}
|
||||
<div className="flex h-14">
|
||||
{navItems.map(item => {
|
||||
const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
|
||||
return (
|
||||
<button
|
||||
key={item.id}
|
||||
onClick={() => onNavigate(item.id)}
|
||||
className={`
|
||||
relative px-4 h-full flex items-center text-sm font-medium transition-colors
|
||||
${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
|
||||
`}
|
||||
>
|
||||
{item.label}
|
||||
{isActive && (
|
||||
<div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* User Profile */}
|
||||
<div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
|
||||
<button
|
||||
onClick={() => setShowDropdown(!showDropdown)}
|
||||
className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
|
||||
>
|
||||
AD
|
||||
</button>
|
||||
|
||||
{showDropdown && (
|
||||
<>
|
||||
<div
|
||||
className="fixed inset-0 z-10"
|
||||
onClick={() => setShowDropdown(false)}
|
||||
/>
|
||||
<div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
|
||||
<div className="p-3 border-b border-warm-border">
|
||||
<p className="text-sm font-medium text-warm-text-primary">Admin User</p>
|
||||
<p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
|
||||
</div>
|
||||
{onLogout && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowDropdown(false)
|
||||
onLogout()
|
||||
}}
|
||||
className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
|
||||
>
|
||||
<LogOut size={14} />
|
||||
Sign Out
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex-1 overflow-auto">
|
||||
{children}
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
188
frontend/src/components/Login.tsx
Normal file
188
frontend/src/components/Login.tsx
Normal file
@@ -0,0 +1,188 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Button } from './Button'
|
||||
|
||||
interface LoginProps {
|
||||
onLogin: (token: string) => void
|
||||
}
|
||||
|
||||
export const Login: React.FC<LoginProps> = ({ onLogin }) => {
|
||||
const [token, setToken] = useState('')
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [isCreating, setIsCreating] = useState(false)
|
||||
const [error, setError] = useState('')
|
||||
const [createdToken, setCreatedToken] = useState('')
|
||||
|
||||
const handleLoginWithToken = () => {
|
||||
if (!token.trim()) {
|
||||
setError('Please enter a token')
|
||||
return
|
||||
}
|
||||
localStorage.setItem('admin_token', token.trim())
|
||||
onLogin(token.trim())
|
||||
}
|
||||
|
||||
const handleCreateToken = async () => {
|
||||
if (!name.trim()) {
|
||||
setError('Please enter a token name')
|
||||
return
|
||||
}
|
||||
|
||||
setIsCreating(true)
|
||||
setError('')
|
||||
|
||||
try {
|
||||
const response = await fetch('http://localhost:8000/api/v1/admin/auth/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: name.trim(),
|
||||
description: description.trim() || undefined,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to create token')
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
setCreatedToken(data.token)
|
||||
setToken(data.token)
|
||||
setError('')
|
||||
} catch (err) {
|
||||
setError('Failed to create token. Please check your connection.')
|
||||
console.error(err)
|
||||
} finally {
|
||||
setIsCreating(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleUseCreatedToken = () => {
|
||||
if (createdToken) {
|
||||
localStorage.setItem('admin_token', createdToken)
|
||||
onLogin(createdToken)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
|
||||
<h1 className="text-2xl font-bold text-warm-text-primary mb-2">
|
||||
Admin Authentication
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mb-6">
|
||||
Sign in with an admin token to access the document management system
|
||||
</p>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{createdToken && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
|
||||
<p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
|
||||
<div className="bg-white border border-green-300 rounded p-2 mb-3">
|
||||
<code className="text-xs font-mono text-warm-text-primary break-all">
|
||||
{createdToken}
|
||||
</code>
|
||||
</div>
|
||||
<p className="text-xs text-green-700 mb-3">
|
||||
Save this token securely. You won't be able to see it again.
|
||||
</p>
|
||||
<Button onClick={handleUseCreatedToken} className="w-full">
|
||||
Use This Token
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-6">
|
||||
{/* Login with existing token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Sign in with existing token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Admin Token
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={token}
|
||||
onChange={(e) => setToken(e.target.value)}
|
||||
placeholder="Enter your admin token"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
|
||||
onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
|
||||
/>
|
||||
</div>
|
||||
<Button onClick={handleLoginWithToken} className="w-full">
|
||||
Sign In
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="relative">
|
||||
<div className="absolute inset-0 flex items-center">
|
||||
<div className="w-full border-t border-warm-border"></div>
|
||||
</div>
|
||||
<div className="relative flex justify-center text-xs">
|
||||
<span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Create new token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Create new admin token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Token Name <span className="text-red-500">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="e.g., my-laptop"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Description (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="e.g., Personal laptop access"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleCreateToken}
|
||||
variant="secondary"
|
||||
disabled={isCreating}
|
||||
className="w-full"
|
||||
>
|
||||
{isCreating ? 'Creating...' : 'Create Token'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-6 pt-4 border-t border-warm-border">
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Admin tokens are used to authenticate with the document management API.
|
||||
Keep your tokens secure and never share them.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
208
frontend/src/components/Models.tsx
Normal file
208
frontend/src/components/Models.tsx
Normal file
@@ -0,0 +1,208 @@
|
||||
import React, { useState } from 'react';
|
||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
||||
import { Loader2, Power, CheckCircle } from 'lucide-react';
|
||||
import { Button } from './Button';
|
||||
import { useModels, useModelDetail } from '../hooks';
|
||||
import type { ModelVersionItem } from '../api/types';
|
||||
|
||||
const formatDate = (dateString: string | null): string => {
|
||||
if (!dateString) return 'N/A';
|
||||
return new Date(dateString).toLocaleString();
|
||||
};
|
||||
|
||||
export const Models: React.FC = () => {
|
||||
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
|
||||
const { models, isLoading, activateModel, isActivating } = useModels();
|
||||
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
|
||||
|
||||
// Build chart data from selected model's metrics
|
||||
const metricsData = modelDetail ? [
|
||||
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
|
||||
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
|
||||
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
|
||||
] : [
|
||||
{ name: 'Precision', value: 0 },
|
||||
{ name: 'Recall', value: 0 },
|
||||
{ name: 'mAP', value: 0 },
|
||||
];
|
||||
|
||||
// Build comparison chart from all models (with placeholder if empty)
|
||||
const chartData = models.length > 0
|
||||
? models.slice(0, 4).map(m => ({
|
||||
name: m.version,
|
||||
value: (m.metrics_mAP ?? 0) * 100,
|
||||
}))
|
||||
: [
|
||||
{ name: 'Model A', value: 0 },
|
||||
{ name: 'Model B', value: 0 },
|
||||
{ name: 'Model C', value: 0 },
|
||||
{ name: 'Model D', value: 0 },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto flex gap-8">
|
||||
{/* Left: Job History */}
|
||||
<div className="flex-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
|
||||
</div>
|
||||
) : models.length === 0 ? (
|
||||
<div className="text-center py-12 text-warm-text-muted">
|
||||
No model versions found. Complete a training task to create a model version.
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
{models.map(model => (
|
||||
<div
|
||||
key={model.version_id}
|
||||
onClick={() => setSelectedModel(model)}
|
||||
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
|
||||
selectedModel?.version_id === model.version_id
|
||||
? 'border-warm-text-secondary'
|
||||
: 'border-warm-border hover:border-warm-divider'
|
||||
}`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div>
|
||||
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
|
||||
{model.name}
|
||||
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
|
||||
</h4>
|
||||
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
|
||||
</div>
|
||||
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
|
||||
model.is_active
|
||||
? 'bg-warm-state-info/10 text-warm-state-info'
|
||||
: 'bg-warm-selected text-warm-state-success'
|
||||
}`}>
|
||||
{model.is_active ? 'Active' : model.status}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mt-4 flex gap-8">
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">
|
||||
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right: Model Detail */}
|
||||
<div className="w-[400px]">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
|
||||
<span className={`text-sm font-medium ${
|
||||
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
|
||||
}`}>
|
||||
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mb-8">
|
||||
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
|
||||
<p className="font-medium text-warm-text-primary">
|
||||
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-8">
|
||||
{/* Chart 1 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={chartData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chart 2 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={metricsData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-8 space-y-3">
|
||||
{selectedModel && !selectedModel.is_active ? (
|
||||
<Button
|
||||
className="w-full"
|
||||
onClick={() => activateModel(selectedModel.version_id)}
|
||||
disabled={isActivating}
|
||||
>
|
||||
{isActivating ? (
|
||||
<>
|
||||
<Loader2 size={16} className="mr-2 animate-spin" />
|
||||
Activating...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Power size={16} className="mr-2" />
|
||||
Activate for Inference
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
) : (
|
||||
<Button className="w-full" disabled={!selectedModel}>
|
||||
{selectedModel?.is_active ? (
|
||||
<>
|
||||
<CheckCircle size={16} className="mr-2" />
|
||||
Currently Active
|
||||
</>
|
||||
) : (
|
||||
'Select a Model'
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
487
frontend/src/components/Training.tsx
Normal file
487
frontend/src/components/Training.tsx
Normal file
@@ -0,0 +1,487 @@
|
||||
import React, { useState, useMemo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { useDatasets } from '../hooks/useDatasets'
|
||||
import { useTrainingDocuments } from '../hooks/useTraining'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { DatasetListItem } from '../api/types'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
type Tab = 'datasets' | 'create'
|
||||
|
||||
interface TrainingProps {
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<string, string> = {
|
||||
ready: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
building: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
training: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
failed: 'bg-warm-state-error/10 text-warm-state-error',
|
||||
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
running: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
}
|
||||
|
||||
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
|
||||
// If there's an active training task, show training status
|
||||
const displayStatus = trainingStatus === 'running'
|
||||
? 'training'
|
||||
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
|
||||
? 'pending'
|
||||
: status
|
||||
|
||||
return (
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
|
||||
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
|
||||
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
|
||||
{displayStatus}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Train Dialog ---
|
||||
|
||||
interface TrainDialogProps {
|
||||
dataset: DatasetListItem
|
||||
onClose: () => void
|
||||
onSubmit: (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => void
|
||||
isPending: boolean
|
||||
}
|
||||
|
||||
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
|
||||
const [name, setName] = useState(`train-${dataset.name}`)
|
||||
const [epochs, setEpochs] = useState(100)
|
||||
const [batchSize, setBatchSize] = useState(16)
|
||||
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
|
||||
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
|
||||
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
|
||||
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
|
||||
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
|
||||
|
||||
// Fetch available trained models (active or inactive, not archived)
|
||||
const { data: modelsData } = useQuery({
|
||||
queryKey: ['training', 'models', 'available'],
|
||||
queryFn: () => trainingApi.getModels(),
|
||||
})
|
||||
// Filter out archived models - only show active/inactive models for base model selection
|
||||
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
|
||||
|
||||
const handleSubmit = () => {
|
||||
onSubmit({
|
||||
name,
|
||||
config: {
|
||||
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
|
||||
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
|
||||
epochs,
|
||||
batch_size: batchSize,
|
||||
augmentation: augmentationEnabled
|
||||
? (augmentationConfig as AugmentationConfigType)
|
||||
: undefined,
|
||||
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
|
||||
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
|
||||
<p className="text-sm text-warm-text-muted mb-4">
|
||||
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
|
||||
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
|
||||
</p>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
|
||||
{/* Base Model Selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
|
||||
<select
|
||||
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
|
||||
onChange={e => {
|
||||
if (e.target.value === 'pretrained') {
|
||||
setBaseModelType('pretrained')
|
||||
setBaseModelVersionId(null)
|
||||
} else {
|
||||
setBaseModelType('existing')
|
||||
setBaseModelVersionId(e.target.value)
|
||||
}
|
||||
}}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
<option value="pretrained">yolo11n.pt (Pretrained)</option>
|
||||
{availableModels.map(m => (
|
||||
<option key={m.version_id} value={m.version_id}>
|
||||
{m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
{baseModelType === 'pretrained'
|
||||
? 'Start from pretrained YOLO model'
|
||||
: 'Continue training from an existing model (incremental training)'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-4">
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
|
||||
<input
|
||||
id="train-epochs"
|
||||
type="number"
|
||||
min={1}
|
||||
max={1000}
|
||||
value={epochs}
|
||||
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
|
||||
<input
|
||||
id="train-batch-size"
|
||||
type="number"
|
||||
min={1}
|
||||
max={128}
|
||||
value={batchSize}
|
||||
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Augmentation Configuration */}
|
||||
<AugmentationConfig
|
||||
enabled={augmentationEnabled}
|
||||
onEnabledChange={setAugmentationEnabled}
|
||||
config={augmentationConfig}
|
||||
onConfigChange={setAugmentationConfig}
|
||||
/>
|
||||
|
||||
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
|
||||
{augmentationEnabled && (
|
||||
<div>
|
||||
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
|
||||
Augmentation Multiplier
|
||||
</label>
|
||||
<input
|
||||
id="aug-multiplier"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10}
|
||||
value={augmentationMultiplier}
|
||||
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Number of augmented copies per original image (1-10)
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
|
||||
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
|
||||
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Dataset List ---
|
||||
|
||||
const DatasetList: React.FC<{
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
onSwitchTab: (tab: Tab) => void
|
||||
}> = ({ onNavigate, onSwitchTab }) => {
|
||||
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
|
||||
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
|
||||
|
||||
const handleTrain = (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => {
|
||||
if (!trainTarget) return
|
||||
// Pass config to the training API
|
||||
const trainRequest = {
|
||||
name: config.name,
|
||||
config: config.config,
|
||||
}
|
||||
trainFromDataset(
|
||||
{ datasetId: trainTarget.dataset_id, req: trainRequest },
|
||||
{ onSuccess: () => setTrainTarget(null) },
|
||||
)
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
|
||||
}
|
||||
|
||||
if (datasets.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
|
||||
<Database size={48} className="mb-4 opacity-40" />
|
||||
<p className="text-lg mb-2">No datasets yet</p>
|
||||
<p className="text-sm mb-4">Create a dataset to start training</p>
|
||||
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{datasets.map(ds => (
|
||||
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
|
||||
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
|
||||
<td className="py-3 px-4">
|
||||
<div className="flex gap-1">
|
||||
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
|
||||
<Eye size={14} />
|
||||
</button>
|
||||
{ds.status === 'ready' && (
|
||||
<button title="Train" onClick={() => setTrainTarget(ds)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
|
||||
<Play size={14} />
|
||||
</button>
|
||||
)}
|
||||
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
|
||||
disabled={isDeleting || ds.status === 'pending' || ds.status === 'building'}
|
||||
className={`p-1.5 rounded transition-colors ${
|
||||
ds.status === 'pending' || ds.status === 'building'
|
||||
? 'text-warm-text-muted/40 cursor-not-allowed'
|
||||
: 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
|
||||
}`}>
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
{trainTarget && (
|
||||
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Create Dataset ---
|
||||
|
||||
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
|
||||
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
|
||||
const { createDatasetAsync, isCreating } = useDatasets()
|
||||
|
||||
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [trainRatio, setTrainRatio] = useState(0.7)
|
||||
const [valRatio, setValRatio] = useState(0.2)
|
||||
|
||||
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
|
||||
|
||||
const toggleDoc = (id: string) => {
|
||||
setSelectedIds(prev => {
|
||||
const next = new Set(prev)
|
||||
if (next.has(id)) { next.delete(id) } else { next.add(id) }
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
const toggleAll = () => {
|
||||
if (selectedIds.size === documents.length) {
|
||||
setSelectedIds(new Set())
|
||||
} else {
|
||||
setSelectedIds(new Set(documents.map((d) => d.document_id)))
|
||||
}
|
||||
}
|
||||
|
||||
const handleCreate = async () => {
|
||||
await createDatasetAsync({
|
||||
name,
|
||||
description: description || undefined,
|
||||
document_ids: [...selectedIds],
|
||||
train_ratio: trainRatio,
|
||||
val_ratio: valRatio,
|
||||
})
|
||||
onSwitchTab('datasets')
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-8">
|
||||
{/* Document selection */}
|
||||
<div className="flex-1 flex flex-col">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
|
||||
{isLoadingDocs ? (
|
||||
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
|
||||
) : (
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
|
||||
<div className="overflow-auto max-h-[calc(100vh-240px)]">
|
||||
<table className="w-full text-left">
|
||||
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
|
||||
<tr>
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
|
||||
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{documents.map((doc) => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
|
||||
onClick={() => toggleDoc(doc.document_id)}>
|
||||
<td className="py-3 pl-6 pr-4">
|
||||
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
|
||||
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
|
||||
</div>
|
||||
|
||||
{/* Config panel */}
|
||||
<div className="w-80">
|
||||
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
|
||||
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
|
||||
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
|
||||
<div className="flex gap-2 text-sm">
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Train</span>
|
||||
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Val</span>
|
||||
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Test</span>
|
||||
<input type="number" value={testRatio} readOnly
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="pt-4 border-t border-warm-border">
|
||||
{selectedIds.size > 0 && selectedIds.size < 10 && (
|
||||
<p className="text-xs text-warm-state-warning mb-2">
|
||||
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
|
||||
</p>
|
||||
)}
|
||||
<Button className="w-full h-11" onClick={handleCreate}
|
||||
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
|
||||
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Main Training Component ---
|
||||
|
||||
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
const [activeTab, setActiveTab] = useState<Tab>('datasets')
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
|
||||
</div>
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="flex gap-1 mb-6 border-b border-warm-border">
|
||||
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
|
||||
<button key={key} onClick={() => setActiveTab(key)}
|
||||
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
|
||||
activeTab === key
|
||||
? 'border-warm-state-info text-warm-state-info'
|
||||
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
|
||||
}`}>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
|
||||
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
276
frontend/src/components/UploadModal.tsx
Normal file
276
frontend/src/components/UploadModal.tsx
Normal file
@@ -0,0 +1,276 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments, useCategories } from '../hooks/useDocuments'
|
||||
|
||||
interface UploadModalProps {
|
||||
isOpen: boolean
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
|
||||
const [groupKey, setGroupKey] = useState('')
|
||||
const [category, setCategory] = useState('invoice')
|
||||
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
|
||||
const [errorMessage, setErrorMessage] = useState('')
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const { uploadDocument, isUploading } = useDocuments({})
|
||||
const { categories } = useCategories()
|
||||
|
||||
if (!isOpen) return null
|
||||
|
||||
const handleFileSelect = (files: FileList | null) => {
|
||||
if (!files) return
|
||||
|
||||
const pdfFiles = Array.from(files).filter(file => {
|
||||
const isPdf = file.type === 'application/pdf'
|
||||
const isImage = file.type.startsWith('image/')
|
||||
const isUnder25MB = file.size <= 25 * 1024 * 1024
|
||||
return (isPdf || isImage) && isUnder25MB
|
||||
})
|
||||
|
||||
setSelectedFiles(prev => [...prev, ...pdfFiles])
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
handleFileSelect(e.dataTransfer.files)
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const removeFile = (index: number) => {
|
||||
setSelectedFiles(prev => prev.filter((_, i) => i !== index))
|
||||
}
|
||||
|
||||
const handleUpload = async () => {
|
||||
if (selectedFiles.length === 0) {
|
||||
setErrorMessage('Please select at least one file')
|
||||
return
|
||||
}
|
||||
|
||||
setUploadStatus('uploading')
|
||||
setErrorMessage('')
|
||||
|
||||
try {
|
||||
// Upload files one by one
|
||||
for (const file of selectedFiles) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
uploadDocument(
|
||||
{ file, groupKey: groupKey || undefined, category: category || 'invoice' },
|
||||
{
|
||||
onSuccess: () => resolve(),
|
||||
onError: (error: Error) => reject(error),
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
setUploadStatus('success')
|
||||
setTimeout(() => {
|
||||
onClose()
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setCategory('invoice')
|
||||
setUploadStatus('idle')
|
||||
}, 1500)
|
||||
} catch (error) {
|
||||
setUploadStatus('error')
|
||||
setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
|
||||
}
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
if (uploadStatus === 'uploading') {
|
||||
return // Prevent closing during upload
|
||||
}
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setCategory('invoice')
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
onClose()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
|
||||
<div
|
||||
className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
|
||||
<button
|
||||
onClick={handleClose}
|
||||
className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={20} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Drop Zone */}
|
||||
<div
|
||||
className={`
|
||||
w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
|
||||
${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
|
||||
${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
|
||||
`}
|
||||
onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="p-3 bg-white rounded-full shadow-sm">
|
||||
<UploadCloud size={24} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium text-warm-text-primary">
|
||||
Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files)}
|
||||
/>
|
||||
|
||||
{/* Selected Files */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6 max-h-40 overflow-y-auto">
|
||||
<p className="text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Selected Files ({selectedFiles.length})
|
||||
</p>
|
||||
<div className="space-y-2">
|
||||
{selectedFiles.map((file, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
|
||||
>
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0">
|
||||
<File size={16} className="text-warm-text-muted flex-shrink-0" />
|
||||
<span className="text-sm text-warm-text-secondary truncate">
|
||||
{file.name}
|
||||
</span>
|
||||
<span className="text-xs text-warm-text-muted flex-shrink-0">
|
||||
({(file.size / 1024 / 1024).toFixed(2)} MB)
|
||||
</span>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => removeFile(index)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={16} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Category Select */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Category
|
||||
</label>
|
||||
<div className="relative">
|
||||
<select
|
||||
value={category}
|
||||
onChange={(e) => setCategory(e.target.value)}
|
||||
className="w-full h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none cursor-pointer"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<option value="invoice">Invoice</option>
|
||||
<option value="letter">Letter</option>
|
||||
<option value="receipt">Receipt</option>
|
||||
<option value="contract">Contract</option>
|
||||
{categories
|
||||
.filter((cat) => !['invoice', 'letter', 'receipt', 'contract'].includes(cat))
|
||||
.map((cat) => (
|
||||
<option key={cat} value={cat}>
|
||||
{cat.charAt(0).toUpperCase() + cat.slice(1)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Select document type for training different models
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Group Key Input */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Group Key (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={groupKey}
|
||||
onChange={(e) => setGroupKey(e.target.value)}
|
||||
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
|
||||
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Use group keys to organize documents into logical groups
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Messages */}
|
||||
{uploadStatus === 'success' && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
|
||||
<CheckCircle size={16} className="text-green-600" />
|
||||
<span className="text-sm text-green-800">Upload successful!</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus === 'error' && errorMessage && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
<span className="text-sm text-red-800">{errorMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="mt-8 flex justify-end gap-3">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={handleClose}
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleUpload}
|
||||
disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
|
||||
>
|
||||
{uploadStatus === 'uploading' ? 'Uploading...' : `Upload ${selectedFiles.length > 0 ? `(${selectedFiles.length})` : ''}`}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
7
frontend/src/hooks/index.ts
Normal file
7
frontend/src/hooks/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export { useDocuments, useCategories } from './useDocuments'
|
||||
export { useDocumentDetail } from './useDocumentDetail'
|
||||
export { useAnnotations } from './useAnnotations'
|
||||
export { useTraining, useTrainingDocuments } from './useTraining'
|
||||
export { useDatasets, useDatasetDetail } from './useDatasets'
|
||||
export { useAugmentation } from './useAugmentation'
|
||||
export { useModels, useModelDetail, useActiveModel } from './useModels'
|
||||
70
frontend/src/hooks/useAnnotations.ts
Normal file
70
frontend/src/hooks/useAnnotations.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import { useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { annotationsApi } from '../api/endpoints'
|
||||
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'
|
||||
|
||||
export const useAnnotations = (documentId: string) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (annotation: CreateAnnotationRequest) =>
|
||||
annotationsApi.create(documentId, annotation),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
updates,
|
||||
}: {
|
||||
annotationId: string
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
}) => annotationsApi.update(documentId, annotationId, updates),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.delete(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const verifyMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.verify(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const overrideMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
overrideData,
|
||||
}: {
|
||||
annotationId: string
|
||||
overrideData: AnnotationOverrideRequest
|
||||
}) => annotationsApi.override(documentId, annotationId, overrideData),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
createAnnotation: createMutation.mutate,
|
||||
isCreating: createMutation.isPending,
|
||||
updateAnnotation: updateMutation.mutate,
|
||||
isUpdating: updateMutation.isPending,
|
||||
deleteAnnotation: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
verifyAnnotation: verifyMutation.mutate,
|
||||
isVerifying: verifyMutation.isPending,
|
||||
overrideAnnotation: overrideMutation.mutate,
|
||||
isOverriding: overrideMutation.isPending,
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user