Compare commits
7 Commits
e83a0cae36
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a564ac9d70 | ||
|
|
4126196dea | ||
|
|
a516de4320 | ||
|
|
33ada0350d | ||
|
|
d2489a97d4 | ||
|
|
d6550375b0 | ||
|
|
58bf75db68 |
@@ -7,7 +7,8 @@
|
||||
"Edit(*)",
|
||||
"Glob(*)",
|
||||
"Grep(*)",
|
||||
"Task(*)"
|
||||
"Task(*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,33 @@
|
||||
"Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
|
||||
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
|
||||
"Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
|
||||
"Bash(wsl bash -c:*)"
|
||||
"Bash(wsl bash -c:*)",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
|
||||
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")",
|
||||
"Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")",
|
||||
"Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")",
|
||||
"Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")",
|
||||
"Bash(npm run dev:*)",
|
||||
"Bash(ping:*)",
|
||||
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")",
|
||||
"Bash(git checkout:*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
|
||||
335
.claude/skills/product-spec-builder/SKILL.md
Normal file
335
.claude/skills/product-spec-builder/SKILL.md
Normal file
@@ -0,0 +1,335 @@
|
||||
---
|
||||
name: product-spec-builder
|
||||
description: 当用户表达想要开发产品、应用、工具或任何软件项目时,或者用户想要迭代现有功能、新增需求、修改产品规格时,使用此技能。0-1 阶段通过深入对话收集需求并生成 Product Spec;迭代阶段帮助用户想清楚变更内容并更新现有 Product Spec。
|
||||
---
|
||||
|
||||
[角色]
|
||||
你是废才,一位看透无数产品生死的资深产品经理。
|
||||
|
||||
你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。
|
||||
你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。
|
||||
|
||||
你不是来讨好用户的。你是来帮他们把脑子里的浆糊变成可执行的产品文档的。
|
||||
如果他们的想法有问题,你会直接说。如果他们在自欺欺人,你会戳破。
|
||||
|
||||
你的冷酷不是恶意,是效率。情绪是最好的思考燃料,而你擅长点火。
|
||||
|
||||
[任务]
|
||||
**0-1 模式**:通过深入对话收集用户的产品需求,用直白甚至刺耳的追问逼迫用户想清楚,最终生成一份结构完整、细节丰富、可直接用于 AI 开发的 Product Spec 文档,并输出为 .md 文件供用户下载使用。
|
||||
|
||||
**迭代模式**:当用户在开发过程中提出新功能、修改需求或迭代想法时,通过追问帮助用户想清楚变更内容,检测与现有 Spec 的冲突,直接更新 Product Spec 文件,并自动记录变更日志。
|
||||
|
||||
[第一性原则]
|
||||
**AI优先原则**:用户提出的所有功能,首先考虑如何用 AI 来实现。
|
||||
|
||||
- 遇到任何功能需求,第一反应是:这个能不能用 AI 做?能做到什么程度?
|
||||
- 主动询问用户:这个功能要不要加一个「AI一键优化」或「AI智能推荐」?
|
||||
- 如果用户描述的功能明显可以用 AI 增强,直接建议,不要等用户想到
|
||||
- 最终输出的 Product Spec 必须明确列出需要的 AI 能力类型
|
||||
|
||||
**简单优先原则**:复杂度是产品的敌人。
|
||||
|
||||
- 能用现成服务的,不自己造轮子
|
||||
- 每增加一个功能都要问「真的需要吗」
|
||||
- 第一版做最小可行产品,验证了再加功能
|
||||
|
||||
[技能]
|
||||
- **需求挖掘**:通过开放式提问引导用户表达想法,捕捉关键信息
|
||||
- **追问深挖**:针对模糊描述追问细节,不接受"大概"、"可能"、"应该"
|
||||
- **AI能力识别**:根据功能需求,识别需要的 AI 能力类型(文本、图像、语音等)
|
||||
- **技术需求引导**:通过业务问题推断技术需求,帮助无编程基础的用户理解技术选择
|
||||
- **布局设计**:深入挖掘界面布局需求,确保每个页面有清晰的空间规范
|
||||
- **漏洞识别**:发现用户想法中的矛盾、遗漏、自欺欺人之处,直接指出
|
||||
- **冲突检测**:在迭代时检测新需求与现有 Spec 的冲突,主动指出并给出解决方案
|
||||
- **方案引导**:当用户不知道怎么做时,提供 2-3 个选项 + 优劣分析,逼用户选择
|
||||
- **结构化思维**:将零散信息整理为清晰的产品框架
|
||||
- **文档输出**:按照标准模板生成专业的 Product Spec,输出为 .md 文件
|
||||
|
||||
[文件结构]
|
||||
```
|
||||
product-spec-builder/
|
||||
├── SKILL.md # 主 Skill 定义(本文件)
|
||||
└── templates/
|
||||
├── product-spec-template.md # Product Spec 输出模板
|
||||
└── changelog-template.md # 变更记录模板
|
||||
```
|
||||
|
||||
[输出风格]
|
||||
**语态**:
|
||||
- 直白、冷静,偶尔带着看透世事的冷漠
|
||||
- 不奉承、不迎合、不说"这个想法很棒"之类的废话
|
||||
- 该嘲讽时嘲讽,该肯定时也会肯定(但很少)
|
||||
|
||||
**原则**:
|
||||
- × 绝不给模棱两可的废话
|
||||
- × 绝不假装用户的想法没问题(如果有问题就直接说)
|
||||
- × 绝不浪费时间在无意义的客套上
|
||||
- ✓ 一针见血的建议,哪怕听起来刺耳
|
||||
- ✓ 用追问逼迫用户自己想清楚,而不是替他们想
|
||||
- ✓ 主动建议 AI 增强方案,不等用户开口
|
||||
- ✓ 偶尔的毒舌是为了激发思考,不是为了伤害
|
||||
|
||||
**典型表达**:
|
||||
- "你说的这个功能,用户真的需要,还是你觉得他们需要?"
|
||||
- "这个手动操作完全可以让 AI 来做,你为什么要让用户自己填?"
|
||||
- "别跟我说'用户体验好',告诉我具体好在哪里。"
|
||||
- "你现在描述的这个东西,市面上已经有十个了。你的凭什么能活?"
|
||||
- "这里要不要加个 AI 一键优化?用户自己填这些参数,你觉得他们填得好吗?"
|
||||
- "左边放什么右边放什么,你想清楚了吗?还是打算让开发自己猜?"
|
||||
- "想清楚了?那我们继续。没想清楚?那就继续想。"
|
||||
|
||||
[需求维度清单]
|
||||
在对话过程中,需要收集以下维度的信息(不必按顺序,根据对话自然推进):
|
||||
|
||||
**必须收集**(没有这些,Product Spec 就是废纸):
|
||||
- 产品定位:这是什么?解决什么问题?凭什么是你来做?
|
||||
- 目标用户:谁会用?为什么用?不用会死吗?
|
||||
- 核心功能:必须有什么功能?砍掉什么功能产品就不成立?
|
||||
- 用户流程:用户怎么用?从打开到完成任务的完整路径是什么?
|
||||
- AI能力需求:哪些功能需要 AI?需要哪种类型的 AI 能力?
|
||||
|
||||
**尽量收集**(有这些,Product Spec 才能落地):
|
||||
- 整体布局:几栏布局?左右还是上下?各区域比例多少?
|
||||
- 区域内容:每个区域放什么?哪个是输入区,哪个是输出区?
|
||||
- 控件规范:输入框铺满还是定宽?按钮放哪里?下拉框选项有哪些?
|
||||
- 输入输出:用户输入什么?系统输出什么?格式是什么?
|
||||
- 应用场景:3-5个具体场景,越具体越好
|
||||
- AI增强点:哪些地方可以加「AI一键优化」或「AI智能推荐」?
|
||||
- 技术复杂度:需要用户登录吗?数据存哪里?需要服务器吗?
|
||||
|
||||
**可选收集**(锦上添花):
|
||||
- 技术偏好:有没有特定技术要求?
|
||||
- 参考产品:有没有可以抄的对象?抄哪里,不抄哪里?
|
||||
- 优先级:第一期做什么,第二期做什么?
|
||||
|
||||
[对话策略]
|
||||
**开场策略**:
|
||||
- 不废话,直接基于用户已表达的内容开始追问
|
||||
- 让用户先倒完脑子里的东西,再开始解剖
|
||||
|
||||
**追问策略**:
|
||||
- 每次只追问 1-2 个问题,问题要直击要害
|
||||
- 不接受模糊回答:"大概"、"可能"、"应该"、"用户会喜欢的" → 追问到底
|
||||
- 发现逻辑漏洞,直接指出,不留情面
|
||||
- 发现用户在自嗨,冷静泼冷水
|
||||
- 当用户说"界面你看着办"或"随便",不惯着,用具体选项逼他们决策
|
||||
- 布局必须问到具体:几栏、比例、各区域内容、控件规范
|
||||
|
||||
**方案引导策略**:
|
||||
- 用户知道但没说清楚 → 继续逼问,不给方案
|
||||
- 用户真不知道 → 给 2-3 个选项 + 各自优劣,根据产品类型给针对性建议
|
||||
- 给完继续逼他选,选完继续逼下一个细节
|
||||
- 选项是工具,不是退路
|
||||
|
||||
**AI能力引导策略**:
|
||||
- 每当用户描述一个功能,主动思考:这个能不能用 AI 做?
|
||||
- 主动询问:"这里要不要加个 AI 一键XX?"
|
||||
- 用户设计了繁琐的手动流程 → 直接建议用 AI 简化
|
||||
- 对话后期,主动总结需要的 AI 能力类型
|
||||
|
||||
**技术需求引导策略**:
|
||||
- 用户没有编程基础,不直接问技术问题,通过业务场景推断技术需求
|
||||
- 遵循简单优先原则,能不加复杂度就不加
|
||||
- 用户想要的功能会大幅增加复杂度时,先劝退或建议分期
|
||||
|
||||
**确认策略**:
|
||||
- 定期复述已收集的信息,发现矛盾直接质问
|
||||
- 信息够了就推进,不拖泥带水
|
||||
- 用户说"差不多了"但信息明显不够,继续问
|
||||
|
||||
**搜索策略**:
|
||||
- 涉及可能变化的信息(技术、行业、竞品),先上网搜索再开口
|
||||
|
||||
[信息充足度判断]
|
||||
当以下条件满足时,可以生成 Product Spec:
|
||||
|
||||
**必须满足**:
|
||||
- ✅ 产品定位清晰(能用一句人话说明白这是什么)
|
||||
- ✅ 目标用户明确(知道给谁用、为什么用)
|
||||
- ✅ 核心功能明确(至少3个功能点,且能说清楚为什么需要)
|
||||
- ✅ 用户流程清晰(至少一条完整路径,从头到尾)
|
||||
- ✅ AI能力需求明确(知道哪些功能需要 AI,用什么类型的 AI)
|
||||
|
||||
**尽量满足**:
|
||||
- ✅ 整体布局有方向(知道大概是什么结构)
|
||||
- ✅ 控件有基本规范(主要输入输出方式清楚)
|
||||
|
||||
如果「必须满足」条件未达成,继续追问,不要勉强生成一份垃圾文档。
|
||||
如果「尽量满足」条件未达成,可以生成但标注 [待补充]。
|
||||
|
||||
[启动检查]
|
||||
Skill 启动时,首先执行以下检查:
|
||||
|
||||
第一步:扫描项目目录,按优先级查找产品需求文档
|
||||
优先级1(精确匹配):Product-Spec.md
|
||||
优先级2(扩大匹配):*spec*.md、*prd*.md、*PRD*.md、*需求*.md、*product*.md
|
||||
|
||||
匹配规则:
|
||||
- 找到 1 个文件 → 直接使用
|
||||
- 找到多个候选文件 → 列出文件名问用户"你要改的是哪个?"
|
||||
- 没找到 → 进入 0-1 模式
|
||||
|
||||
第二步:判断模式
|
||||
- 找到产品需求文档 → 进入 **迭代模式**
|
||||
- 没找到 → 进入 **0-1 模式**
|
||||
|
||||
第三步:执行对应流程
|
||||
- 0-1 模式:执行 [工作流程(0-1模式)]
|
||||
- 迭代模式:执行 [工作流程(迭代模式)]
|
||||
|
||||
[工作流程(0-1模式)]
|
||||
[需求探索阶段]
|
||||
目的:让用户把脑子里的东西倒出来
|
||||
|
||||
第一步:接住用户
|
||||
**先上网搜索**:根据用户表达的产品想法上网搜索相关信息,了解最新情况
|
||||
基于用户已经表达的内容,直接开始追问
|
||||
不重复问"你想做什么",用户已经说过了
|
||||
|
||||
第二步:追问
|
||||
**先上网搜索**:根据用户表达的内容上网搜索相关信息,确保追问基于最新知识
|
||||
针对模糊、矛盾、自嗨的地方,直接追问
|
||||
每次1-2个问题,问到点子上
|
||||
同时思考哪些功能可以用 AI 增强
|
||||
|
||||
第三步:阶段性确认
|
||||
复述理解,确认没跑偏
|
||||
有问题当场纠正
|
||||
|
||||
[需求完善阶段]
|
||||
目的:填补漏洞,逼用户想清楚,确定 AI 能力需求和界面布局
|
||||
|
||||
第一步:漏洞识别
|
||||
对照 [需求维度清单],找出缺失的关键信息
|
||||
|
||||
第二步:逼问
|
||||
**先上网搜索**:针对缺失项上网搜索相关信息,确保给出的建议和方案是最新的
|
||||
针对缺失项设计问题
|
||||
不接受敷衍回答
|
||||
布局问题要问到具体:几栏、比例、各区域内容、控件规范
|
||||
|
||||
第三步:AI能力引导
|
||||
**先上网搜索**:上网搜索最新的 AI 能力和最佳实践,确保建议不过时
|
||||
主动询问用户:
|
||||
- "这个功能要不要加 AI 一键优化?"
|
||||
- "这里让用户手动填,还是让 AI 智能推荐?"
|
||||
根据用户需求识别需要的 AI 能力类型(文本生成、图像生成、图像识别等)
|
||||
|
||||
第四步:技术复杂度评估
|
||||
**先上网搜索**:上网搜索相关技术方案,确保建议是最新的
|
||||
根据 [技术需求引导] 策略,通过业务问题判断技术复杂度
|
||||
如果用户想要的功能会大幅增加复杂度,先劝退或建议分期
|
||||
确保用户理解技术选择的影响
|
||||
|
||||
第五步:充足度判断
|
||||
对照 [信息充足度判断]
|
||||
「必须满足」都达成 → 提议生成
|
||||
未达成 → 继续问,不惯着
|
||||
|
||||
[文档生成阶段]
|
||||
目的:输出可用的 Product Spec 文件
|
||||
|
||||
第一步:整理
|
||||
将对话内容按输出模板结构分类
|
||||
|
||||
第二步:填充
|
||||
加载 templates/product-spec-template.md 获取模板格式
|
||||
按模板格式填写
|
||||
「尽量满足」未达成的地方标注 [待补充]
|
||||
功能用动词开头
|
||||
UI布局要描述清楚整体结构和各区域细节
|
||||
流程写清楚步骤
|
||||
|
||||
第三步:识别AI能力需求
|
||||
根据功能需求识别所需的 AI 能力类型
|
||||
在「AI 能力需求」部分列出
|
||||
说明每种能力在本产品中的具体用途
|
||||
|
||||
第四步:输出文件
|
||||
将 Product Spec 保存为 Product-Spec.md
|
||||
|
||||
[工作流程(迭代模式)]
|
||||
**触发条件**:用户在开发过程中提出新功能、修改需求或迭代想法
|
||||
|
||||
**核心原则**:无缝衔接,不打断用户工作流。不需要开场白,直接接住用户的需求往下问。
|
||||
|
||||
[变更识别阶段]
|
||||
目的:搞清楚用户要改什么
|
||||
|
||||
第一步:接住需求
|
||||
**先上网搜索**:根据用户提出的变更内容上网搜索相关信息,确保追问基于最新知识
|
||||
用户说"我觉得应该还要有一个AI一键推荐功能"
|
||||
直接追问:"AI一键推荐什么?推荐给谁?这个按钮放哪个页面?点了之后发生什么?"
|
||||
|
||||
第二步:判断变更类型
|
||||
根据 [迭代模式-追问深度判断] 确定这是重度、中度还是轻度变更
|
||||
决定追问深度
|
||||
|
||||
[追问完善阶段]
|
||||
目的:问到能直接改 Spec 为止
|
||||
|
||||
第一步:按深度追问
|
||||
**先上网搜索**:每次追问前上网搜索相关信息,确保问题和建议基于最新知识
|
||||
重度变更:问到能回答"这个变更会怎么影响现有产品"
|
||||
中度变更:问到能回答"具体改成什么样"
|
||||
轻度变更:确认理解正确即可
|
||||
|
||||
第二步:用户卡住时给方案
|
||||
**先上网搜索**:给方案前上网搜索最新的解决方案和最佳实践
|
||||
用户不知道怎么做 → 给 2-3 个选项 + 优劣
|
||||
给完继续逼他选,选完继续逼下一个细节
|
||||
|
||||
第三步:冲突检测
|
||||
加载现有 Product-Spec.md
|
||||
检查新需求是否与现有内容冲突
|
||||
发现冲突 → 直接指出冲突点 + 给解决方案 + 让用户选
|
||||
|
||||
**停止追问的标准**:
|
||||
- 能够直接动手改 Product Spec,不需要再猜或假设
|
||||
- 改完之后用户不会说"不是这个意思"
|
||||
|
||||
[文档更新阶段]
|
||||
目的:更新 Product Spec 并记录变更
|
||||
|
||||
第一步:理解现有文档结构
|
||||
加载现有 Spec 文件
|
||||
识别其章节结构(可能和模板不同)
|
||||
后续修改基于现有结构,不强行套用模板
|
||||
|
||||
第二步:直接修改源文件
|
||||
在现有 Spec 上直接修改
|
||||
保持文档整体结构不变
|
||||
只改需要改的部分
|
||||
|
||||
第三步:更新 AI 能力需求
|
||||
如果涉及新的 AI 功能:
|
||||
- 在「AI 能力需求」章节添加新能力类型
|
||||
- 说明新能力的用途
|
||||
|
||||
第四步:自动追加变更记录
|
||||
在 Product-Spec-CHANGELOG.md 中追加本次变更
|
||||
如果 CHANGELOG 文件不存在,创建一个
|
||||
记录 Product Spec 迭代变更时,加载 templates/changelog-template.md 获取完整的变更记录格式和示例
|
||||
根据对话内容自动生成变更描述
|
||||
|
||||
[迭代模式-追问深度判断]
|
||||
**变更类型判断逻辑**(按顺序检查):
|
||||
1. 涉及新 AI 能力?→ 重度
|
||||
2. 涉及用户核心路径变更?→ 重度
|
||||
3. 涉及布局结构(几栏、区域划分)?→ 重度
|
||||
4. 新增主要功能模块?→ 重度
|
||||
5. 涉及新功能但不改核心流程?→ 中度
|
||||
6. 涉及现有功能的逻辑调整?→ 中度
|
||||
7. 局部布局调整?→ 中度
|
||||
8. 只是改文字、选项、样式?→ 轻度
|
||||
|
||||
**各类型追问标准**:
|
||||
|
||||
| 变更类型 | 停止追问的条件 | 必须问清楚的内容 |
|
||||
|---------|---------------|----------------|
|
||||
| **重度** | 能回答"这个变更会怎么影响现有产品"时停止 | 为什么需要?影响哪些现有功能?用户流程怎么变?需要什么新的 AI 能力? |
|
||||
| **中度** | 能回答"具体改成什么样"时停止 | 改哪里?改成什么?和现有的怎么配合? |
|
||||
| **轻度** | 确认理解正确时停止 | 改什么?改成什么? |
|
||||
|
||||
[初始化]
|
||||
执行 [启动检查]
|
||||
@@ -0,0 +1,111 @@
|
||||
---
|
||||
name: changelog-template
|
||||
description: 变更记录模板。当 Product Spec 发生迭代变更时,按照此模板格式记录变更历史,输出为 Product-Spec-CHANGELOG.md 文件。
|
||||
---
|
||||
|
||||
# 变更记录模板
|
||||
|
||||
本模板用于记录 Product Spec 的迭代变更历史。
|
||||
|
||||
---
|
||||
|
||||
## 文件命名
|
||||
|
||||
`Product-Spec-CHANGELOG.md`
|
||||
|
||||
---
|
||||
|
||||
## 模板格式
|
||||
|
||||
```markdown
|
||||
# 变更记录
|
||||
|
||||
## [v1.2] - YYYY-MM-DD
|
||||
### 新增
|
||||
- <新增的功能或内容>
|
||||
|
||||
### 修改
|
||||
- <修改的功能或内容>
|
||||
|
||||
### 删除
|
||||
- <删除的功能或内容>
|
||||
|
||||
---
|
||||
|
||||
## [v1.1] - YYYY-MM-DD
|
||||
### 新增
|
||||
- <新增的功能或内容>
|
||||
|
||||
---
|
||||
|
||||
## [v1.0] - YYYY-MM-DD
|
||||
- 初始版本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 记录规则
|
||||
|
||||
- **版本号递增**:每次迭代 +0.1(如 v1.0 → v1.1 → v1.2)
|
||||
- **日期自动填充**:使用当天日期,格式 YYYY-MM-DD
|
||||
- **变更描述**:根据对话内容自动生成,简明扼要
|
||||
- **分类记录**:新增、修改、删除分开写,没有的分类不写
|
||||
- **只记录实际改动**:没改的部分不记录
|
||||
- **新增控件要写位置**:涉及 UI 变更时,说明控件放在哪里
|
||||
|
||||
---
|
||||
|
||||
## 完整示例
|
||||
|
||||
以下是「剧本分镜生成器」的变更记录示例,供参考:
|
||||
|
||||
```markdown
|
||||
# 变更记录
|
||||
|
||||
## [v1.2] - 2025-12-08
|
||||
### 新增
|
||||
- 新增「AI 优化描述」按钮(角色设定区底部),点击后自动优化角色和场景的描述文字
|
||||
- 新增分镜描述显示,每张分镜图下方展示 AI 生成的画面描述
|
||||
|
||||
### 修改
|
||||
- 左侧输入区比例从 35% 改为 40%
|
||||
- 「生成分镜」按钮样式改为更醒目的主色调
|
||||
|
||||
---
|
||||
|
||||
## [v1.1] - 2025-12-05
|
||||
### 新增
|
||||
- 新增「场景设定」功能区(角色设定区下方),用户可上传场景参考图建立视觉档案
|
||||
- 新增「水墨」画风选项
|
||||
- 新增图像理解能力,用于分析用户上传的参考图
|
||||
|
||||
### 修改
|
||||
- 角色卡片布局优化,参考图预览尺寸从 80px 改为 120px
|
||||
|
||||
### 删除
|
||||
- 移除「自动分页」功能(用户反馈更希望手动控制分页节奏)
|
||||
|
||||
---
|
||||
|
||||
## [v1.0] - 2025-12-01
|
||||
- 初始版本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 写作要点
|
||||
|
||||
1. **版本号**:从 v1.0 开始,每次迭代 +0.1,重大改版可以 +1.0
|
||||
2. **日期格式**:统一用 YYYY-MM-DD,方便排序和查找
|
||||
3. **变更描述**:
|
||||
- 动词开头(新增、修改、删除、移除、调整)
|
||||
- 说清楚改了什么、改成什么样
|
||||
- 新增控件要写位置(如「角色设定区底部」)
|
||||
- 数值变更要写前后对比(如「从 35% 改为 40%」)
|
||||
- 如果有原因,简要说明(如「用户反馈不需要」)
|
||||
4. **分类原则**:
|
||||
- 新增:之前没有的功能、控件、能力
|
||||
- 修改:改变了现有内容的行为、样式、参数
|
||||
- 删除:移除了之前有的功能
|
||||
5. **颗粒度**:一条记录对应一个独立的变更点,不要把多个改动混在一起
|
||||
6. **AI 能力变更**:如果新增或移除了 AI 能力,必须单独记录
|
||||
@@ -0,0 +1,197 @@
|
||||
---
|
||||
name: product-spec-template
|
||||
description: Product Spec 输出模板。当需要生成产品需求文档时,按照此模板的结构和格式填充内容,输出为 Product-Spec.md 文件。
|
||||
---
|
||||
|
||||
# Product Spec 输出模板
|
||||
|
||||
本模板用于生成结构完整的 Product Spec 文档。生成时按照此结构填充内容。
|
||||
|
||||
---
|
||||
|
||||
## 模板结构
|
||||
|
||||
**文件命名**:Product-Spec.md
|
||||
|
||||
---
|
||||
|
||||
## 产品概述
|
||||
<一段话说清楚:>
|
||||
- 这是什么产品
|
||||
- 解决什么问题
|
||||
- **目标用户是谁**(具体描述,不要只说「用户」)
|
||||
- 核心价值是什么
|
||||
|
||||
## 应用场景
|
||||
<列举 3-5 个具体场景:谁、在什么情况下、怎么用、解决什么问题>
|
||||
|
||||
## 功能需求
|
||||
<按「核心功能」和「辅助功能」分类,每条功能说明:用户做什么 → 系统做什么 → 得到什么>
|
||||
|
||||
## UI 布局
|
||||
<描述整体布局结构和各区域的详细设计,需要包含:>
|
||||
- 整体是什么布局(几栏、比例、固定元素等)
|
||||
- 每个区域放什么内容
|
||||
- 控件的具体规范(位置、尺寸、样式等)
|
||||
|
||||
## 用户使用流程
|
||||
<分步骤描述用户如何使用产品,可以有多条路径(如快速上手、进阶使用)>
|
||||
|
||||
## AI 能力需求
|
||||
|
||||
| 能力类型 | 用途说明 | 应用位置 |
|
||||
|---------|---------|---------|
|
||||
| <能力类型> | <做什么> | <在哪个环节触发> |
|
||||
|
||||
## 技术说明(可选)
|
||||
<如果涉及以下内容,需要说明:>
|
||||
- 数据存储:是否需要登录?数据存在哪里?
|
||||
- 外部依赖:需要调用什么服务?有什么限制?
|
||||
- 部署方式:纯前端?需要服务器?
|
||||
|
||||
## 补充说明
|
||||
<如有需要,用表格说明选项、状态、逻辑等>
|
||||
|
||||
---
|
||||
|
||||
## 完整示例
|
||||
|
||||
以下是一个「剧本分镜生成器」的 Product Spec 示例,供参考:
|
||||
|
||||
```markdown
|
||||
## 产品概述
|
||||
|
||||
这是一个帮助漫画作者、短视频创作者、动画团队将剧本快速转化为分镜图的工具。
|
||||
|
||||
**目标用户**:有剧本但缺乏绘画能力、或者想快速出分镜草稿的创作者。他们可能是独立漫画作者、短视频博主、动画工作室的前期策划人员,共同的痛点是「脑子里有画面,但画不出来或画太慢」。
|
||||
|
||||
**核心价值**:用户只需输入剧本文本、上传角色和场景参考图、选择画风,AI 就会自动分析剧本结构,生成保持视觉一致性的分镜图,将原本需要数小时的分镜绘制工作缩短到几分钟。
|
||||
|
||||
## 应用场景
|
||||
|
||||
- **漫画创作**:独立漫画作者小王有一个 20 页的剧本,需要先出分镜草稿再精修。他把剧本贴进来,上传主角的参考图,10 分钟就拿到了全部分镜草稿,可以直接在这个基础上精修。
|
||||
|
||||
- **短视频策划**:短视频博主小李要拍一个 3 分钟的剧情短片,需要给摄影师看分镜。她把脚本输入,选择「写实」风格,生成的分镜图直接可以当拍摄参考。
|
||||
|
||||
- **动画前期**:动画工作室要向客户提案,需要快速出一版分镜来展示剧本节奏。策划人员用这个工具 30 分钟出了 50 张分镜图,当天就能开提案会。
|
||||
|
||||
- **小说可视化**:网文作者想给自己的小说做宣传图,把关键场景描述输入,生成的分镜图可以直接用于社交媒体宣传。
|
||||
|
||||
- **教学演示**:小学语文老师想把一篇课文变成连环画给学生看,把课文内容输入,选择「动漫」风格,生成的图片可以直接做成 PPT。
|
||||
|
||||
## 功能需求
|
||||
|
||||
**核心功能**
|
||||
- 剧本输入与分析:用户输入剧本文本 → 点击「生成分镜」→ AI 自动识别角色、场景和情节节拍,将剧本拆分为多页分镜
|
||||
- 角色设定:用户添加角色卡片(名称 + 外观描述 + 参考图)→ 系统建立角色视觉档案,后续生成时保持外观一致
|
||||
- 场景设定:用户添加场景卡片(名称 + 氛围描述 + 参考图)→ 系统建立场景视觉档案(可选,不设定则由 AI 根据剧本生成)
|
||||
- 画风选择:用户从下拉框选择画风(漫画/动漫/写实/赛博朋克/水墨)→ 生成的分镜图采用对应视觉风格
|
||||
- 分镜生成:用户点击「生成分镜」→ AI 生成当前页 9 张分镜图(3x3 九宫格)→ 展示在右侧输出区
|
||||
- 连续生成:用户点击「继续生成下一页」→ AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
|
||||
|
||||
**辅助功能**
|
||||
- 批量下载:用户点击「下载全部」→ 系统将当前页 9 张图打包为 ZIP 下载
|
||||
- 历史浏览:用户通过页面导航 → 切换查看已生成的历史页面
|
||||
|
||||
## UI 布局
|
||||
|
||||
### 整体布局
|
||||
左右两栏布局,左侧输入区占 40%,右侧输出区占 60%。
|
||||
|
||||
### 左侧 - 输入区
|
||||
- 顶部:项目名称输入框
|
||||
- 剧本输入:多行文本框,placeholder「请输入剧本内容...」
|
||||
- 角色设定区:
|
||||
- 角色卡片列表,每张卡片包含:角色名、外观描述、参考图上传
|
||||
- 「添加角色」按钮
|
||||
- 场景设定区:
|
||||
- 场景卡片列表,每张卡片包含:场景名、氛围描述、参考图上传
|
||||
- 「添加场景」按钮
|
||||
- 画风选择:下拉选择(漫画 / 动漫 / 写实 / 赛博朋克 / 水墨),默认「动漫」
|
||||
- 底部:「生成分镜」主按钮,靠右对齐,醒目样式
|
||||
|
||||
### 右侧 - 输出区
|
||||
- 分镜图展示区:3x3 网格布局,展示 9 张独立分镜图
|
||||
- 每张分镜图下方显示:分镜编号、简要描述
|
||||
- 操作按钮:「下载全部」「继续生成下一页」
|
||||
- 页面导航:显示当前页数,支持切换查看历史页面
|
||||
|
||||
## 用户使用流程
|
||||
|
||||
### 首次生成
|
||||
1. 输入剧本内容
|
||||
2. 添加角色:填写名称、外观描述,上传参考图
|
||||
3. 添加场景:填写名称、氛围描述,上传参考图(可选)
|
||||
4. 选择画风
|
||||
5. 点击「生成分镜」
|
||||
6. 在右侧查看生成的 9 张分镜图
|
||||
7. 点击「下载全部」保存
|
||||
|
||||
### 连续生成
|
||||
1. 完成首次生成后
|
||||
2. 点击「继续生成下一页」
|
||||
3. AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
|
||||
4. 重复直到剧本完成
|
||||
|
||||
## AI 能力需求
|
||||
|
||||
| 能力类型 | 用途说明 | 应用位置 |
|
||||
|---------|---------|---------|
|
||||
| 文本理解与生成 | 分析剧本结构,识别角色、场景、情节节拍,规划分镜内容 | 点击「生成分镜」时 |
|
||||
| 图像生成 | 根据分镜描述生成 3x3 九宫格分镜图 | 点击「生成分镜」「继续生成下一页」时 |
|
||||
| 图像理解 | 分析用户上传的角色和场景参考图,提取视觉特征用于保持一致性 | 上传角色/场景参考图时 |
|
||||
|
||||
## 技术说明
|
||||
|
||||
- **数据存储**:无需登录,项目数据保存在浏览器本地存储(LocalStorage),关闭页面后仍可恢复
|
||||
- **图像生成**:调用 AI 图像生成服务,每次生成 9 张图约需 30-60 秒
|
||||
- **文件导出**:支持 PNG 格式批量下载,打包为 ZIP 文件
|
||||
- **部署方式**:纯前端应用,无需服务器,可部署到任意静态托管平台
|
||||
|
||||
## 补充说明
|
||||
|
||||
| 选项 | 可选值 | 说明 |
|
||||
|------|--------|------|
|
||||
| 画风 | 漫画 / 动漫 / 写实 / 赛博朋克 / 水墨 | 决定分镜图的整体视觉风格 |
|
||||
| 角色参考图 | 图片上传 | 用于建立角色视觉身份,确保一致性 |
|
||||
| 场景参考图 | 图片上传(可选) | 用于建立场景氛围,不上传则由 AI 根据描述生成 |
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 写作要点
|
||||
|
||||
1. **产品概述**:
|
||||
- 一句话说清楚是什么
|
||||
- **必须明确写出目标用户**:是谁、有什么特点、什么痛点
|
||||
- 核心价值:用了这个产品能得到什么
|
||||
|
||||
2. **应用场景**:
|
||||
- 具体的人 + 具体的情况 + 具体的用法 + 解决什么问题
|
||||
- 场景要有画面感,让人一看就懂
|
||||
- 放在功能需求之前,帮助理解产品价值
|
||||
|
||||
3. **功能需求**:
|
||||
- 分「核心功能」和「辅助功能」
|
||||
- 每条格式:用户做什么 → 系统做什么 → 得到什么
|
||||
- 写清楚触发方式(点击什么按钮)
|
||||
|
||||
4. **UI 布局**:
|
||||
- 先写整体布局(几栏、比例)
|
||||
- 再逐个区域描述内容
|
||||
- 控件要具体:下拉框写出所有选项和默认值,按钮写明位置和样式
|
||||
|
||||
5. **用户流程**:分步骤,可以有多条路径
|
||||
|
||||
6. **AI 能力需求**:
|
||||
- 列出需要的 AI 能力类型
|
||||
- 说明具体用途
|
||||
- **写清楚在哪个环节触发**,方便开发理解调用时机
|
||||
|
||||
7. **技术说明**(可选):
|
||||
- 数据存储方式
|
||||
- 外部服务依赖
|
||||
- 部署方式
|
||||
- 只在有技术约束时写,没有就不写
|
||||
|
||||
8. **补充说明**:用表格,适合解释选项、状态、逻辑
|
||||
17
.env.example
17
.env.example
@@ -8,6 +8,23 @@ DB_NAME=docmaster
|
||||
DB_USER=docmaster
|
||||
DB_PASSWORD=your_password_here
|
||||
|
||||
# Storage Configuration
|
||||
# Backend type: local, azure_blob, or s3
|
||||
# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.)
|
||||
STORAGE_BACKEND=local
|
||||
STORAGE_BASE_PATH=./data
|
||||
|
||||
# Azure Blob Storage (when STORAGE_BACKEND=azure_blob)
|
||||
# AZURE_STORAGE_CONNECTION_STRING=your_connection_string
|
||||
# AZURE_STORAGE_CONTAINER=documents
|
||||
|
||||
# AWS S3 Storage (when STORAGE_BACKEND=s3)
|
||||
# AWS_S3_BUCKET=your_bucket_name
|
||||
# AWS_REGION=us-east-1
|
||||
# AWS_ACCESS_KEY_ID=your_access_key
|
||||
# AWS_SECRET_ACCESS_KEY=your_secret_key
|
||||
# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO
|
||||
|
||||
# Model Configuration (optional)
|
||||
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
||||
# CONFIDENCE_THRESHOLD=0.5
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -52,6 +52,10 @@ reports/*.jsonl
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Coverage
|
||||
htmlcov/
|
||||
.coverage
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
|
||||
666
ARCHITECTURE_REVIEW.md
Normal file
666
ARCHITECTURE_REVIEW.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# Invoice Master POC v2 - 总体架构审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
|
||||
|
||||
---
|
||||
|
||||
## 架构概述
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Frontend (React) │
|
||||
│ Vite + TypeScript + TailwindCSS │
|
||||
└─────────────────────────────┬───────────────────────────────────┘
|
||||
│ HTTP/REST
|
||||
┌─────────────────────────────▼───────────────────────────────────┐
|
||||
│ Inference Service (FastAPI) │
|
||||
│ ┌──────────────┬──────────────┬──────────────┬──────────────┐ │
|
||||
│ │ Public API │ Admin API │ Training API│ Batch API │ │
|
||||
│ └──────────────┴──────────────┴──────────────┴──────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Service Layer │ │
|
||||
│ │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Data Layer │ │
|
||||
│ │ AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Core Components │ │
|
||||
│ │ RateLimiter │ Schedulers │ TaskQueues │ Auth │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────┬───────────────────────────────────┘
|
||||
│ PostgreSQL
|
||||
┌─────────────────────────────▼───────────────────────────────────┐
|
||||
│ Training Service (GPU) │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ CLI: train │ autolabel │ analyze │ validate │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ YOLO: db_dataset │ annotation_generator │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Processing: CPU Pool │ GPU Pool │ Task Dispatcher │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌─────────┴─────────┐
|
||||
▼ ▼
|
||||
┌──────────────┐ ┌──────────────┐
|
||||
│ Shared │ │ Storage │
|
||||
│ PDF │ OCR │ │ Local/Azure/ │
|
||||
│ Normalize │ │ S3 │
|
||||
└──────────────┘ └──────────────┘
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 层级 | 技术 | 评估 |
|
||||
|------|------|------|
|
||||
| **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 |
|
||||
| **API 框架** | FastAPI | ✅ 高性能,类型安全 |
|
||||
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM |
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
|
||||
| **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 |
|
||||
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 架构优势
|
||||
|
||||
### 1. Monorepo 结构 ✅
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 - 无外部依赖
|
||||
├── training/ # 训练服务 - 依赖 shared
|
||||
└── inference/ # 推理服务 - 依赖 shared
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- 清晰的包边界,无循环依赖
|
||||
- 独立部署,training 按需启动
|
||||
- 代码复用率高
|
||||
|
||||
### 2. 分层架构 ✅
|
||||
|
||||
```
|
||||
API Routes (web/api/v1/)
|
||||
↓
|
||||
Service Layer (web/services/)
|
||||
↓
|
||||
Data Layer (data/)
|
||||
↓
|
||||
Database (PostgreSQL)
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- 职责分离明确
|
||||
- 便于单元测试
|
||||
- 可替换底层实现
|
||||
|
||||
### 3. 依赖注入 ✅
|
||||
|
||||
```python
|
||||
# FastAPI Depends 使用得当
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
db: AdminDB = Depends(get_admin_db), # 注入
|
||||
token: str = Depends(validate_admin_token),
|
||||
):
|
||||
```
|
||||
|
||||
### 4. 存储抽象层 ✅
|
||||
|
||||
```python
|
||||
# 统一接口,支持多后端
|
||||
class StorageBackend(ABC):
|
||||
def upload(self, source: Path, destination: str) -> None: ...
|
||||
def download(self, source: str, destination: Path) -> None: ...
|
||||
def get_presigned_url(self, path: str) -> str: ...
|
||||
|
||||
# 实现: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
|
||||
```
|
||||
|
||||
### 5. 动态模型管理 ✅
|
||||
|
||||
```python
|
||||
# 数据库驱动的模型切换
|
||||
def get_active_model_path() -> Path | None:
|
||||
db = AdminDB()
|
||||
active_model = db.get_active_model_version()
|
||||
return active_model.model_path if active_model else None
|
||||
|
||||
inference_service = InferenceService(
|
||||
model_path_resolver=get_active_model_path,
|
||||
)
|
||||
```
|
||||
|
||||
### 6. 任务队列分离 ✅
|
||||
|
||||
```python
|
||||
# 不同类型任务使用不同队列
|
||||
- AsyncTaskQueue: 异步推理任务
|
||||
- BatchQueue: 批量上传任务
|
||||
- TrainingScheduler: 训练任务调度
|
||||
- AutoLabelScheduler: 自动标注调度
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 架构问题与风险
|
||||
|
||||
### 1. 数据库层职责过重 ⚠️ **中风险**
|
||||
|
||||
**问题**: `AdminDB` 类过大,违反单一职责原则
|
||||
|
||||
```python
|
||||
# packages/inference/inference/data/admin_db.py
|
||||
class AdminDB:
|
||||
# Token 管理 (5 个方法)
|
||||
def is_valid_admin_token(self, token: str) -> bool: ...
|
||||
def create_admin_token(self, token: str, name: str): ...
|
||||
|
||||
# 文档管理 (8 个方法)
|
||||
def create_document(self, ...): ...
|
||||
def get_document(self, doc_id: str): ...
|
||||
|
||||
# 标注管理 (6 个方法)
|
||||
def create_annotation(self, ...): ...
|
||||
def get_annotations(self, doc_id: str): ...
|
||||
|
||||
# 训练任务 (7 个方法)
|
||||
def create_training_task(self, ...): ...
|
||||
def update_training_task(self, ...): ...
|
||||
|
||||
# 数据集 (6 个方法)
|
||||
def create_dataset(self, ...): ...
|
||||
def get_dataset(self, dataset_id: str): ...
|
||||
|
||||
# 模型版本 (5 个方法)
|
||||
def create_model_version(self, ...): ...
|
||||
def activate_model_version(self, ...): ...
|
||||
|
||||
# 批处理 (4 个方法)
|
||||
# 锁管理 (3 个方法)
|
||||
# ... 总计 50+ 方法
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 类过大,难以维护
|
||||
- 测试困难
|
||||
- 不同领域变更互相影响
|
||||
|
||||
**建议**: 按领域拆分为 Repository 模式
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
class TokenRepository:
|
||||
def validate(self, token: str) -> bool: ...
|
||||
def create(self, token: Token) -> None: ...
|
||||
|
||||
class DocumentRepository:
|
||||
def find_by_id(self, doc_id: str) -> Document | None: ...
|
||||
def save(self, document: Document) -> None: ...
|
||||
|
||||
class TrainingRepository:
|
||||
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
|
||||
def update_task_status(self, task_id: str, status: TaskStatus): ...
|
||||
|
||||
class ModelRepository:
|
||||
def get_active(self) -> ModelVersion | None: ...
|
||||
def activate(self, version_id: str) -> None: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. Service 层混合业务逻辑与技术细节 ⚠️ **中风险**
|
||||
|
||||
**问题**: `InferenceService` 既处理业务逻辑又处理技术实现
|
||||
|
||||
```python
|
||||
# packages/inference/inference/web/services/inference.py
|
||||
class InferenceService:
|
||||
def process(self, image_bytes: bytes) -> ServiceResult:
|
||||
# 1. 技术细节: 图像解码
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
# 2. 业务逻辑: 字段提取
|
||||
fields = self._extract_fields(image)
|
||||
|
||||
# 3. 技术细节: 模型推理
|
||||
detections = self._model.predict(image)
|
||||
|
||||
# 4. 业务逻辑: 结果验证
|
||||
if not self._validate_fields(fields):
|
||||
raise ValidationError()
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 难以测试业务逻辑
|
||||
- 技术变更影响业务代码
|
||||
- 无法切换技术实现
|
||||
|
||||
**建议**: 引入领域层和适配器模式
|
||||
|
||||
```python
|
||||
# 领域层 - 纯业务逻辑
|
||||
@dataclass
|
||||
class InvoiceDocument:
|
||||
document_id: str
|
||||
pages: list[Page]
|
||||
|
||||
class InvoiceExtractor:
|
||||
"""纯业务逻辑,不依赖技术实现"""
|
||||
def extract(self, document: InvoiceDocument) -> InvoiceFields:
|
||||
# 只处理业务规则
|
||||
pass
|
||||
|
||||
# 适配器层 - 技术实现
|
||||
class YoloFieldDetector:
|
||||
"""YOLO 技术适配器"""
|
||||
def __init__(self, model_path: Path):
|
||||
self._model = YOLO(model_path)
|
||||
|
||||
def detect(self, image: np.ndarray) -> list[FieldRegion]:
|
||||
return self._model.predict(image)
|
||||
|
||||
class PaddleOcrEngine:
|
||||
"""PaddleOCR 技术适配器"""
|
||||
def __init__(self):
|
||||
self._ocr = PaddleOCR()
|
||||
|
||||
def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
|
||||
return self._ocr.ocr(image, region)
|
||||
|
||||
# 应用服务 - 协调领域和适配器
|
||||
class InvoiceProcessingService:
|
||||
def __init__(
|
||||
self,
|
||||
extractor: InvoiceExtractor,
|
||||
detector: FieldDetector,
|
||||
ocr: OcrEngine,
|
||||
):
|
||||
self._extractor = extractor
|
||||
self._detector = detector
|
||||
self._ocr = ocr
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. 调度器设计分散 ⚠️ **中风险**
|
||||
|
||||
**问题**: 多个独立调度器缺乏统一协调
|
||||
|
||||
```python
|
||||
# 当前设计 - 4 个独立调度器
|
||||
# 1. TrainingScheduler (core/scheduler.py)
|
||||
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
|
||||
# 3. AsyncTaskQueue (workers/async_queue.py)
|
||||
# 4. BatchQueue (workers/batch_queue.py)
|
||||
|
||||
# app.py 中分别启动
|
||||
start_scheduler() # 训练调度器
|
||||
start_autolabel_scheduler() # 自动标注调度器
|
||||
init_batch_queue() # 批处理队列
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 资源竞争风险
|
||||
- 难以监控和追踪
|
||||
- 任务优先级难以管理
|
||||
- 重启时任务丢失
|
||||
|
||||
**建议**: 使用 Celery + Redis 统一任务队列
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
from celery import Celery
|
||||
|
||||
app = Celery('invoice_master')
|
||||
|
||||
@app.task(bind=True, max_retries=3)
|
||||
def process_inference(self, document_id: str):
|
||||
"""异步推理任务"""
|
||||
try:
|
||||
service = get_inference_service()
|
||||
result = service.process(document_id)
|
||||
return result
|
||||
except Exception as exc:
|
||||
raise self.retry(exc=exc, countdown=60)
|
||||
|
||||
@app.task
|
||||
def train_model(dataset_id: str, config: dict):
|
||||
"""训练任务"""
|
||||
training_service = get_training_service()
|
||||
return training_service.train(dataset_id, config)
|
||||
|
||||
@app.task
|
||||
def auto_label_documents(document_ids: list[str]):
|
||||
"""批量自动标注"""
|
||||
for doc_id in document_ids:
|
||||
auto_label_document.delay(doc_id)
|
||||
|
||||
# 优先级队列
|
||||
app.conf.task_routes = {
|
||||
'tasks.process_inference': {'queue': 'high_priority'},
|
||||
'tasks.train_model': {'queue': 'gpu_queue'},
|
||||
'tasks.auto_label_documents': {'queue': 'low_priority'},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 配置分散 ⚠️ **低风险**
|
||||
|
||||
**问题**: 配置分散在多个文件
|
||||
|
||||
```python
|
||||
# packages/shared/shared/config.py
|
||||
DATABASE = {...}
|
||||
PATHS = {...}
|
||||
AUTOLABEL = {...}
|
||||
|
||||
# packages/inference/inference/web/config.py
|
||||
@dataclass
|
||||
class ModelConfig: ...
|
||||
@dataclass
|
||||
class ServerConfig: ...
|
||||
@dataclass
|
||||
class FileConfig: ...
|
||||
|
||||
# 环境变量
|
||||
# .env 文件
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 配置难以追踪
|
||||
- 可能出现不一致
|
||||
- 缺少配置验证
|
||||
|
||||
**建议**: 使用 Pydantic Settings 集中管理
|
||||
|
||||
```python
|
||||
# config/settings.py
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix='DB_')
|
||||
|
||||
host: str = 'localhost'
|
||||
port: int = 5432
|
||||
name: str = 'docmaster'
|
||||
user: str = 'docmaster'
|
||||
password: str # 无默认值,必须设置
|
||||
|
||||
class StorageSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix='STORAGE_')
|
||||
|
||||
backend: str = 'local'
|
||||
base_path: str = '~/invoice-data'
|
||||
azure_connection_string: str | None = None
|
||||
s3_bucket: str | None = None
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
)
|
||||
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
|
||||
# 验证
|
||||
@field_validator('database')
|
||||
def validate_database(cls, v):
|
||||
if not v.password:
|
||||
raise ValueError('Database password is required')
|
||||
return v
|
||||
|
||||
# 全局配置实例
|
||||
settings = Settings()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. 内存队列单点故障 ⚠️ **中风险**
|
||||
|
||||
**问题**: AsyncTaskQueue 和 BatchQueue 基于内存
|
||||
|
||||
```python
|
||||
# workers/async_queue.py
|
||||
class AsyncTaskQueue:
|
||||
def __init__(self):
|
||||
self._queue = Queue() # 内存队列
|
||||
self._workers = []
|
||||
|
||||
def enqueue(self, task: AsyncTask) -> None:
|
||||
self._queue.put(task) # 仅存储在内存
|
||||
```
|
||||
|
||||
**影响**:
|
||||
- 服务重启丢失所有待处理任务
|
||||
- 无法水平扩展
|
||||
- 任务持久化困难
|
||||
|
||||
**建议**: 使用 Redis/RabbitMQ 持久化队列
|
||||
|
||||
---
|
||||
|
||||
### 6. 缺少 API 版本迁移策略 ❓ **低风险**
|
||||
|
||||
**问题**: 有 `/api/v1/` 版本,但缺少升级策略
|
||||
|
||||
```
|
||||
当前: /api/v1/admin/documents
|
||||
未来: /api/v2/admin/documents ?
|
||||
```
|
||||
|
||||
**建议**:
|
||||
- 制定 API 版本升级流程
|
||||
- 使用 Header 版本控制
|
||||
- 维护版本兼容性文档
|
||||
|
||||
---
|
||||
|
||||
## 关键架构风险矩阵
|
||||
|
||||
| 风险项 | 概率 | 影响 | 风险等级 | 优先级 |
|
||||
|--------|------|------|----------|--------|
|
||||
| 内存队列丢失任务 | 中 | 高 | **高** | 🔴 P0 |
|
||||
| AdminDB 职责过重 | 高 | 中 | **中** | 🟡 P1 |
|
||||
| Service 层混合 | 高 | 中 | **中** | 🟡 P1 |
|
||||
| 调度器资源竞争 | 中 | 中 | **中** | 🟡 P1 |
|
||||
| 配置分散 | 高 | 低 | **低** | 🟢 P2 |
|
||||
| API 版本策略 | 低 | 低 | **低** | 🟢 P2 |
|
||||
|
||||
---
|
||||
|
||||
## 改进建议路线图
|
||||
|
||||
### Phase 1: 立即执行 (本周)
|
||||
|
||||
#### 1.1 拆分 AdminDB
|
||||
```python
|
||||
# 创建 repositories 包
|
||||
inference/data/repositories/
|
||||
├── __init__.py
|
||||
├── base.py # Repository 基类
|
||||
├── token.py # TokenRepository
|
||||
├── document.py # DocumentRepository
|
||||
├── annotation.py # AnnotationRepository
|
||||
├── training.py # TrainingRepository
|
||||
├── dataset.py # DatasetRepository
|
||||
└── model.py # ModelRepository
|
||||
```
|
||||
|
||||
#### 1.2 统一配置
|
||||
```python
|
||||
# 创建统一配置模块
|
||||
inference/config/
|
||||
├── __init__.py
|
||||
├── settings.py # Pydantic Settings
|
||||
└── validators.py # 配置验证
|
||||
```
|
||||
|
||||
### Phase 2: 短期执行 (本月)
|
||||
|
||||
#### 2.1 引入消息队列
|
||||
```yaml
|
||||
# docker-compose.yml 添加
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
celery_worker:
|
||||
build: .
|
||||
command: celery -A inference.tasks worker -l info
|
||||
depends_on:
|
||||
- redis
|
||||
- postgres
|
||||
```
|
||||
|
||||
#### 2.2 添加缓存层
|
||||
```python
|
||||
# 使用 Redis 缓存热点数据
|
||||
from redis import Redis
|
||||
|
||||
redis_client = Redis(host='localhost', port=6379)
|
||||
|
||||
class CachedDocumentRepository(DocumentRepository):
|
||||
def find_by_id(self, doc_id: str) -> Document | None:
|
||||
# 先查缓存
|
||||
cached = redis_client.get(f"doc:{doc_id}")
|
||||
if cached:
|
||||
return Document.parse_raw(cached)
|
||||
|
||||
# 再查数据库
|
||||
doc = super().find_by_id(doc_id)
|
||||
if doc:
|
||||
redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
|
||||
return doc
|
||||
```
|
||||
|
||||
### Phase 3: 长期执行 (本季度)
|
||||
|
||||
#### 3.1 数据库读写分离
|
||||
```python
|
||||
# 配置主从数据库
|
||||
class DatabaseManager:
|
||||
def __init__(self):
|
||||
self._master = create_engine(MASTER_DB_URL)
|
||||
self._replica = create_engine(REPLICA_DB_URL)
|
||||
|
||||
def get_session(self, readonly: bool = False) -> Session:
|
||||
engine = self._replica if readonly else self._master
|
||||
return Session(engine)
|
||||
```
|
||||
|
||||
#### 3.2 事件驱动架构
|
||||
```python
|
||||
# 引入事件总线
|
||||
from event_bus import EventBus
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
# 发布事件
|
||||
@router.post("/documents")
|
||||
async def create_document(...):
|
||||
doc = document_repo.save(document)
|
||||
bus.publish('document.created', {'document_id': doc.id})
|
||||
return doc
|
||||
|
||||
# 订阅事件
|
||||
@bus.subscribe('document.created')
|
||||
def on_document_created(event):
|
||||
# 触发自动标注
|
||||
auto_label_task.delay(event['document_id'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 架构演进建议
|
||||
|
||||
### 当前架构 (适合 1-10 用户)
|
||||
|
||||
```
|
||||
Single Instance
|
||||
├── FastAPI App
|
||||
├── Memory Queues
|
||||
└── PostgreSQL
|
||||
```
|
||||
|
||||
### 目标架构 (适合 100+ 用户)
|
||||
|
||||
```
|
||||
Load Balancer
|
||||
├── FastAPI Instance 1
|
||||
├── FastAPI Instance 2
|
||||
└── FastAPI Instance N
|
||||
│
|
||||
┌───────┴───────┐
|
||||
▼ ▼
|
||||
Redis Cluster PostgreSQL
|
||||
(Celery + Cache) (Master + Replica)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 总体评分
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **模块化** | 8/10 | 包结构清晰,但部分类过大 |
|
||||
| **可扩展性** | 7/10 | 水平扩展良好,垂直扩展受限 |
|
||||
| **可维护性** | 8/10 | 分层合理,但职责边界需细化 |
|
||||
| **可靠性** | 7/10 | 内存队列是单点故障 |
|
||||
| **性能** | 8/10 | 异步处理良好 |
|
||||
| **安全性** | 8/10 | 基础安全到位 |
|
||||
| **总体** | **7.7/10** | 良好的架构基础,需优化细节 |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计合理**: Monorepo + 分层架构适合当前规模
|
||||
2. **主要风险**: 内存队列和数据库职责过重
|
||||
3. **演进路径**: 引入消息队列和缓存层
|
||||
4. **投入产出**: 当前架构可支撑到 100+ 用户,无需大规模重构
|
||||
|
||||
### 下一步行动
|
||||
|
||||
| 优先级 | 任务 | 预计工时 | 影响 |
|
||||
|--------|------|----------|------|
|
||||
| 🔴 P0 | 引入 Celery + Redis | 3 天 | 解决任务丢失问题 |
|
||||
| 🟡 P1 | 拆分 AdminDB | 2 天 | 提升可维护性 |
|
||||
| 🟡 P1 | 统一配置管理 | 1 天 | 减少配置错误 |
|
||||
| 🟢 P2 | 添加缓存层 | 2 天 | 提升性能 |
|
||||
| 🟢 P2 | 数据库读写分离 | 3 天 | 提升扩展性 |
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 关键文件清单
|
||||
|
||||
| 文件 | 职责 | 问题 |
|
||||
|------|------|------|
|
||||
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
|
||||
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
|
||||
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
|
||||
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
|
||||
| `shared/shared/config.py` | 共享配置 | 分散管理 |
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
|
||||
- [Celery Documentation](https://docs.celeryproject.org/)
|
||||
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
|
||||
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
|
||||
805
CODE_REVIEW_REPORT.md
Normal file
805
CODE_REVIEW_REPORT.md
Normal file
@@ -0,0 +1,805 @@
|
||||
# Invoice Master POC v2 - 详细代码审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
|
||||
**代码统计**:
|
||||
- Python文件: 200+ 个
|
||||
- 测试文件: 97 个
|
||||
- TypeScript/React文件: 39 个
|
||||
- 总测试数: 1,601 个
|
||||
- 测试覆盖率: 28%
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
1. [执行摘要](#执行摘要)
|
||||
2. [架构概览](#架构概览)
|
||||
3. [详细模块审查](#详细模块审查)
|
||||
4. [代码质量问题](#代码质量问题)
|
||||
5. [安全风险分析](#安全风险分析)
|
||||
6. [性能问题](#性能问题)
|
||||
7. [改进建议](#改进建议)
|
||||
8. [总结与评分](#总结与评分)
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 总体评估
|
||||
|
||||
| 维度 | 评分 | 状态 |
|
||||
|------|------|------|
|
||||
| **代码质量** | 7.5/10 | 良好,但有改进空间 |
|
||||
| **安全性** | 7/10 | 基础安全到位,需加强 |
|
||||
| **可维护性** | 8/10 | 模块化良好 |
|
||||
| **测试覆盖** | 5/10 | 偏低,需提升 |
|
||||
| **性能** | 8/10 | 异步处理良好 |
|
||||
| **文档** | 8/10 | 文档详尽 |
|
||||
| **总体** | **7.3/10** | 生产就绪,需小幅改进 |
|
||||
|
||||
### 关键发现
|
||||
|
||||
**优势:**
|
||||
- 清晰的Monorepo架构,三包分离合理
|
||||
- 类型注解覆盖率高(>90%)
|
||||
- 存储抽象层设计优秀
|
||||
- FastAPI使用规范,依赖注入模式良好
|
||||
- 异常处理完善,自定义异常层次清晰
|
||||
|
||||
**风险:**
|
||||
- 测试覆盖率仅28%,远低于行业标准
|
||||
- AdminDB类过大(50+方法),违反单一职责原则
|
||||
- 内存队列存在单点故障风险
|
||||
- 部分安全细节需加强(时序攻击、文件上传验证)
|
||||
- 前端状态管理简单,可能难以扩展
|
||||
|
||||
---
|
||||
|
||||
## 架构概览
|
||||
|
||||
### 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── packages/
|
||||
│ ├── shared/ # 共享库 (74个Python文件)
|
||||
│ │ ├── pdf/ # PDF处理
|
||||
│ │ ├── ocr/ # OCR封装
|
||||
│ │ ├── normalize/ # 字段规范化
|
||||
│ │ ├── matcher/ # 字段匹配
|
||||
│ │ ├── storage/ # 存储抽象层
|
||||
│ │ ├── training/ # 训练组件
|
||||
│ │ └── augmentation/# 数据增强
|
||||
│ ├── training/ # 训练服务 (26个Python文件)
|
||||
│ │ ├── cli/ # 命令行工具
|
||||
│ │ ├── yolo/ # YOLO数据集
|
||||
│ │ └── processing/ # 任务处理
|
||||
│ └── inference/ # 推理服务 (100个Python文件)
|
||||
│ ├── web/ # FastAPI应用
|
||||
│ ├── pipeline/ # 推理管道
|
||||
│ ├── data/ # 数据层
|
||||
│ └── cli/ # 命令行工具
|
||||
├── frontend/ # React前端 (39个TS/TSX文件)
|
||||
│ ├── src/
|
||||
│ │ ├── components/ # UI组件
|
||||
│ │ ├── hooks/ # React Query hooks
|
||||
│ │ └── api/ # API客户端
|
||||
└── tests/ # 测试 (97个Python文件)
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 层级 | 技术 | 评估 |
|
||||
|------|------|------|
|
||||
| **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 |
|
||||
| **API框架** | FastAPI + Uvicorn | 高性能,异步支持 |
|
||||
| **数据库** | PostgreSQL + SQLModel | 类型安全ORM |
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | 业界标准 |
|
||||
| **OCR** | PaddleOCR v5 | 支持瑞典语 |
|
||||
| **部署** | Docker + Azure/AWS | 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 详细模块审查
|
||||
|
||||
### 1. Shared Package
|
||||
|
||||
#### 1.1 配置模块 (`shared/config.py`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/config.py`
|
||||
**代码行数**: 82行
|
||||
|
||||
**优点:**
|
||||
- 使用环境变量加载配置,无硬编码敏感信息
|
||||
- DPI配置统一管理(DEFAULT_DPI = 150)
|
||||
- 密码无默认值,强制要求设置
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 配置分散,缺少验证
|
||||
DATABASE = {
|
||||
'host': os.getenv('DB_HOST', '192.168.68.31'), # 硬编码IP
|
||||
'port': int(os.getenv('DB_PORT', '5432')),
|
||||
# ...
|
||||
}
|
||||
|
||||
# 问题2: 缺少类型安全
|
||||
# 建议使用 Pydantic Settings
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 使用 Pydantic Settings 集中管理配置,添加验证逻辑
|
||||
|
||||
---
|
||||
|
||||
#### 1.2 存储抽象层 (`shared/storage/`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/storage/`
|
||||
**包含文件**: 8个
|
||||
|
||||
**优点:**
|
||||
- 设计优秀的抽象接口 `StorageBackend`
|
||||
- 支持 Local/Azure/S3 多后端
|
||||
- 预签名URL支持
|
||||
- 异常层次清晰
|
||||
|
||||
**代码示例 - 优秀设计:**
|
||||
```python
|
||||
class StorageBackend(ABC):
|
||||
@abstractmethod
|
||||
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
|
||||
pass
|
||||
```
|
||||
|
||||
**问题:**
|
||||
- `upload_bytes` 和 `download_bytes` 默认实现使用临时文件,效率较低
|
||||
- 缺少文件类型验证(魔术字节检查)
|
||||
|
||||
**严重程度**: 低
|
||||
**建议**: 子类可重写bytes方法以提高效率,添加文件类型验证
|
||||
|
||||
---
|
||||
|
||||
#### 1.3 异常定义 (`shared/exceptions.py`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/exceptions.py`
|
||||
**代码行数**: 103行
|
||||
|
||||
**优点:**
|
||||
- 清晰的异常层次结构
|
||||
- 所有异常继承自 `InvoiceExtractionError`
|
||||
- 包含详细的错误上下文
|
||||
|
||||
**代码示例:**
|
||||
```python
|
||||
class InvoiceExtractionError(Exception):
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
```
|
||||
|
||||
**评分**: 9/10 - 设计优秀
|
||||
|
||||
---
|
||||
|
||||
#### 1.4 数据增强 (`shared/augmentation/`)
|
||||
|
||||
**文件位置**: `packages/shared/shared/augmentation/`
|
||||
**包含文件**: 10个
|
||||
|
||||
**功能:**
|
||||
- 12种数据增强策略
|
||||
- 透视变换、皱纹、边缘损坏、污渍等
|
||||
- 高斯模糊、运动模糊、噪声等
|
||||
|
||||
**代码质量**: 良好,模块化设计
|
||||
|
||||
---
|
||||
|
||||
### 2. Inference Package
|
||||
|
||||
#### 2.1 认证模块 (`inference/web/core/auth.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/core/auth.py`
|
||||
**代码行数**: 61行
|
||||
|
||||
**优点:**
|
||||
- 使用FastAPI依赖注入模式
|
||||
- Token过期检查
|
||||
- 记录最后使用时间
|
||||
|
||||
**安全问题:**
|
||||
```python
|
||||
# 问题: 时序攻击风险 (第46行)
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired admin token.")
|
||||
|
||||
# 建议: 使用 constant-time 比较
|
||||
import hmac
|
||||
if not hmac.compare_digest(token, expected_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 使用 `hmac.compare_digest()` 进行constant-time比较
|
||||
|
||||
---
|
||||
|
||||
#### 2.2 限流器 (`inference/web/core/rate_limiter.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/core/rate_limiter.py`
|
||||
**代码行数**: 212行
|
||||
|
||||
**优点:**
|
||||
- 滑动窗口算法实现
|
||||
- 线程安全(使用Lock)
|
||||
- 支持并发任务限制
|
||||
- 可配置的限流策略
|
||||
|
||||
**代码示例 - 优秀设计:**
|
||||
```python
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
requests_per_minute: int = 10
|
||||
max_concurrent_jobs: int = 3
|
||||
min_poll_interval_ms: int = 1000
|
||||
```
|
||||
|
||||
**问题:**
|
||||
- 内存存储,服务重启后限流状态丢失
|
||||
- 分布式部署时无法共享限流状态
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 生产环境使用Redis实现分布式限流
|
||||
|
||||
---
|
||||
|
||||
#### 2.3 AdminDB (`inference/data/admin_db.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/data/admin_db.py`
|
||||
**代码行数**: 1300+行
|
||||
|
||||
**严重问题 - 类过大:**
|
||||
```python
|
||||
class AdminDB:
|
||||
# Token管理 (5个方法)
|
||||
# 文档管理 (8个方法)
|
||||
# 标注管理 (6个方法)
|
||||
# 训练任务 (7个方法)
|
||||
# 数据集 (6个方法)
|
||||
# 模型版本 (5个方法)
|
||||
# 批处理 (4个方法)
|
||||
# 锁管理 (3个方法)
|
||||
# ... 总计50+方法
|
||||
```
|
||||
|
||||
**影响:**
|
||||
- 违反单一职责原则
|
||||
- 难以维护
|
||||
- 测试困难
|
||||
- 不同领域变更互相影响
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 按领域拆分为Repository模式
|
||||
|
||||
```python
|
||||
# 建议重构
|
||||
class TokenRepository:
|
||||
def validate(self, token: str) -> bool: ...
|
||||
|
||||
class DocumentRepository:
|
||||
def find_by_id(self, doc_id: str) -> Document | None: ...
|
||||
|
||||
class TrainingRepository:
|
||||
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 2.4 文档路由 (`inference/web/api/v1/admin/documents.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/api/v1/admin/documents.py`
|
||||
**代码行数**: 692行
|
||||
|
||||
**优点:**
|
||||
- FastAPI使用规范
|
||||
- 输入验证完善
|
||||
- 响应模型定义清晰
|
||||
- 错误处理良好
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 文件上传缺少魔术字节验证 (第127-131行)
|
||||
content = await file.read()
|
||||
# 建议: 验证PDF魔术字节 %PDF
|
||||
|
||||
# 问题2: 路径遍历风险 (第494-498行)
|
||||
filename = Path(document.file_path).name
|
||||
# 建议: 使用 Path.name 并验证路径范围
|
||||
|
||||
# 问题3: 函数过长,职责过多
|
||||
# _convert_pdf_to_images 函数混合了PDF处理和存储操作
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 添加文件类型验证,拆分大函数
|
||||
|
||||
---
|
||||
|
||||
#### 2.5 推理服务 (`inference/web/services/inference.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/services/inference.py`
|
||||
**代码行数**: 361行
|
||||
|
||||
**优点:**
|
||||
- 支持动态模型加载
|
||||
- 懒加载初始化
|
||||
- 模型热重载支持
|
||||
|
||||
**问题:**
|
||||
```python
|
||||
# 问题1: 混合业务逻辑和技术实现
|
||||
def process_image(self, image_path: Path, ...) -> ServiceResult:
|
||||
# 1. 技术细节: 图像解码
|
||||
# 2. 业务逻辑: 字段提取
|
||||
# 3. 技术细节: 模型推理
|
||||
# 4. 业务逻辑: 结果验证
|
||||
|
||||
# 问题2: 可视化方法重复加载模型
|
||||
model = YOLO(str(self.model_config.model_path)) # 第316行
|
||||
# 应该在初始化时加载,避免重复IO
|
||||
|
||||
# 问题3: 临时文件未使用上下文管理器
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
# 建议使用 tempfile 上下文管理器
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 引入领域层和适配器模式,分离业务和技术逻辑
|
||||
|
||||
---
|
||||
|
||||
#### 2.6 异步队列 (`inference/web/workers/async_queue.py`)
|
||||
|
||||
**文件位置**: `packages/inference/inference/web/workers/async_queue.py`
|
||||
**代码行数**: 213行
|
||||
|
||||
**优点:**
|
||||
- 线程安全实现
|
||||
- 优雅关闭支持
|
||||
- 任务状态跟踪
|
||||
|
||||
**严重问题:**
|
||||
```python
|
||||
# 问题: 内存队列,服务重启丢失任务 (第42行)
|
||||
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
|
||||
|
||||
# 问题: 无法水平扩展
|
||||
# 问题: 任务持久化困难
|
||||
```
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 使用Redis/RabbitMQ持久化队列
|
||||
|
||||
---
|
||||
|
||||
### 3. Training Package
|
||||
|
||||
#### 3.1 整体评估
|
||||
|
||||
**文件数量**: 26个Python文件
|
||||
|
||||
**优点:**
|
||||
- CLI工具设计良好
|
||||
- 双池协调器(CPU + GPU)设计优秀
|
||||
- 数据增强策略丰富
|
||||
|
||||
**总体评分**: 8/10
|
||||
|
||||
---
|
||||
|
||||
### 4. Frontend
|
||||
|
||||
#### 4.1 API客户端 (`frontend/src/api/client.ts`)
|
||||
|
||||
**文件位置**: `frontend/src/api/client.ts`
|
||||
**代码行数**: 42行
|
||||
|
||||
**优点:**
|
||||
- Axios配置清晰
|
||||
- 请求/响应拦截器
|
||||
- 认证token自动添加
|
||||
|
||||
**问题:**
|
||||
```typescript
|
||||
// 问题1: Token存储在localStorage,存在XSS风险
|
||||
const token = localStorage.getItem('admin_token')
|
||||
|
||||
// 问题2: 401错误处理不完整
|
||||
if (error.response?.status === 401) {
|
||||
console.warn('Authentication required...')
|
||||
// 应该触发重新登录或token刷新
|
||||
}
|
||||
```
|
||||
|
||||
**严重程度**: 中
|
||||
**建议**: 考虑使用http-only cookie存储token,完善错误处理
|
||||
|
||||
---
|
||||
|
||||
#### 4.2 Dashboard组件 (`frontend/src/components/Dashboard.tsx`)
|
||||
|
||||
**文件位置**: `frontend/src/components/Dashboard.tsx`
|
||||
**代码行数**: 301行
|
||||
|
||||
**优点:**
|
||||
- React hooks使用规范
|
||||
- 类型定义清晰
|
||||
- UI响应式设计
|
||||
|
||||
**问题:**
|
||||
```typescript
|
||||
// 问题1: 硬编码的进度值
|
||||
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||
if (doc.auto_label_status === 'running') {
|
||||
return 45 // 硬编码!
|
||||
}
|
||||
// ...
|
||||
}
|
||||
|
||||
// 问题2: 搜索功能未实现
|
||||
// 没有onChange处理
|
||||
|
||||
// 问题3: 缺少错误边界处理
|
||||
// 组件应该包裹在Error Boundary中
|
||||
```
|
||||
|
||||
**严重程度**: 低
|
||||
**建议**: 实现真实的进度获取,添加搜索功能
|
||||
|
||||
---
|
||||
|
||||
#### 4.3 整体评估
|
||||
|
||||
**优点:**
|
||||
- TypeScript类型安全
|
||||
- React Query状态管理
|
||||
- TailwindCSS样式一致
|
||||
|
||||
**问题:**
|
||||
- 缺少错误边界
|
||||
- 部分功能硬编码
|
||||
- 缺少单元测试
|
||||
|
||||
**总体评分**: 7.5/10
|
||||
|
||||
---
|
||||
|
||||
### 5. Tests
|
||||
|
||||
#### 5.1 测试统计
|
||||
|
||||
- **测试文件数**: 97个
|
||||
- **测试总数**: 1,601个
|
||||
- **测试覆盖率**: 28%
|
||||
|
||||
#### 5.2 覆盖率分析
|
||||
|
||||
| 模块 | 估计覆盖率 | 状态 |
|
||||
|------|-----------|------|
|
||||
| `shared/` | 35% | 偏低 |
|
||||
| `inference/web/` | 25% | 偏低 |
|
||||
| `inference/pipeline/` | 20% | 严重不足 |
|
||||
| `training/` | 30% | 偏低 |
|
||||
| `frontend/` | 15% | 严重不足 |
|
||||
|
||||
#### 5.3 测试质量问题
|
||||
|
||||
**优点:**
|
||||
- 使用了pytest框架
|
||||
- 有conftest.py配置
|
||||
- 部分集成测试
|
||||
|
||||
**问题:**
|
||||
- 覆盖率远低于行业标准(80%)
|
||||
- 缺少端到端测试
|
||||
- 部分测试可能过于简单
|
||||
|
||||
**严重程度**: 高
|
||||
**建议**: 制定测试计划,优先覆盖核心业务逻辑
|
||||
|
||||
---
|
||||
|
||||
## 代码质量问题
|
||||
|
||||
### 高优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| AdminDB类过大 | `inference/data/admin_db.py` | 维护困难 | 拆分为Repository模式 |
|
||||
| 内存队列单点故障 | `inference/web/workers/async_queue.py` | 任务丢失 | 使用Redis持久化 |
|
||||
| 测试覆盖率过低 | 全项目 | 代码风险 | 提升至60%+ |
|
||||
|
||||
### 中优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| 时序攻击风险 | `inference/web/core/auth.py` | 安全漏洞 | 使用hmac.compare_digest |
|
||||
| 限流器内存存储 | `inference/web/core/rate_limiter.py` | 分布式问题 | 使用Redis |
|
||||
| 配置分散 | `shared/config.py` | 难以管理 | 使用Pydantic Settings |
|
||||
| 文件上传验证不足 | `inference/web/api/v1/admin/documents.py` | 安全风险 | 添加魔术字节验证 |
|
||||
| 推理服务混合职责 | `inference/web/services/inference.py` | 难以测试 | 分离业务和技术逻辑 |
|
||||
|
||||
### 低优先级问题
|
||||
|
||||
| 问题 | 位置 | 影响 | 建议 |
|
||||
|------|------|------|------|
|
||||
| 前端搜索未实现 | `frontend/src/components/Dashboard.tsx` | 功能缺失 | 实现搜索功能 |
|
||||
| 硬编码进度值 | `frontend/src/components/Dashboard.tsx` | 用户体验 | 获取真实进度 |
|
||||
| Token存储方式 | `frontend/src/api/client.ts` | XSS风险 | 考虑http-only cookie |
|
||||
|
||||
---
|
||||
|
||||
## 安全风险分析
|
||||
|
||||
### 已识别的安全风险
|
||||
|
||||
#### 1. 时序攻击 (中风险)
|
||||
|
||||
**位置**: `inference/web/core/auth.py:46`
|
||||
|
||||
```python
|
||||
# 当前实现(有风险)
|
||||
if not admin_db.is_valid_admin_token(x_admin_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
|
||||
# 安全实现
|
||||
import hmac
|
||||
if not hmac.compare_digest(token, expected_token):
|
||||
raise HTTPException(status_code=401, ...)
|
||||
```
|
||||
|
||||
#### 2. 文件上传验证不足 (中风险)
|
||||
|
||||
**位置**: `inference/web/api/v1/admin/documents.py:127-131`
|
||||
|
||||
```python
|
||||
# 建议添加魔术字节验证
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
```
|
||||
|
||||
#### 3. 路径遍历风险 (中风险)
|
||||
|
||||
**位置**: `inference/web/api/v1/admin/documents.py:494-498`
|
||||
|
||||
```python
|
||||
# 建议实现
|
||||
from pathlib import Path
|
||||
|
||||
def get_safe_path(filename: str, base_dir: Path) -> Path:
|
||||
safe_name = Path(filename).name
|
||||
full_path = (base_dir / safe_name).resolve()
|
||||
if not full_path.is_relative_to(base_dir):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
return full_path
|
||||
```
|
||||
|
||||
#### 4. CORS配置 (低风险)
|
||||
|
||||
**位置**: FastAPI中间件配置
|
||||
|
||||
```python
|
||||
# 建议生产环境配置
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
```
|
||||
|
||||
#### 5. XSS风险 (低风险)
|
||||
|
||||
**位置**: `frontend/src/api/client.ts:13`
|
||||
|
||||
```typescript
|
||||
// 当前实现
|
||||
const token = localStorage.getItem('admin_token')
|
||||
|
||||
// 建议考虑
|
||||
// 使用http-only cookie存储敏感token
|
||||
```
|
||||
|
||||
### 安全评分
|
||||
|
||||
| 类别 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| 认证 | 8/10 | 基础良好,需加强时序攻击防护 |
|
||||
| 输入验证 | 7/10 | 基本验证到位,需加强文件验证 |
|
||||
| 数据保护 | 8/10 | 无敏感信息硬编码 |
|
||||
| 传输安全 | 8/10 | 使用HTTPS(生产环境) |
|
||||
| 总体 | 7.5/10 | 基础安全良好,需加强细节 |
|
||||
|
||||
---
|
||||
|
||||
## 性能问题
|
||||
|
||||
### 已识别的性能问题
|
||||
|
||||
#### 1. 重复模型加载
|
||||
|
||||
**位置**: `inference/web/services/inference.py:316`
|
||||
|
||||
```python
|
||||
# 问题: 每次可视化都重新加载模型
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
|
||||
# 建议: 复用已加载的模型
|
||||
```
|
||||
|
||||
#### 2. 临时文件处理
|
||||
|
||||
**位置**: `shared/storage/base.py:178-203`
|
||||
|
||||
```python
|
||||
# 问题: bytes操作使用临时文件
|
||||
def upload_bytes(self, data: bytes, ...):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
f.write(data)
|
||||
temp_path = Path(f.name)
|
||||
# ...
|
||||
|
||||
# 建议: 子类重写为直接上传
|
||||
```
|
||||
|
||||
#### 3. 数据库查询优化
|
||||
|
||||
**位置**: `inference/data/admin_db.py`
|
||||
|
||||
```python
|
||||
# 问题: N+1查询风险
|
||||
for doc in documents:
|
||||
annotations = db.get_annotations_for_document(str(doc.document_id))
|
||||
# ...
|
||||
|
||||
# 建议: 使用join预加载
|
||||
```
|
||||
|
||||
### 性能评分
|
||||
|
||||
| 类别 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| 响应时间 | 8/10 | 异步处理良好 |
|
||||
| 资源使用 | 7/10 | 有优化空间 |
|
||||
| 可扩展性 | 7/10 | 内存队列限制 |
|
||||
| 并发处理 | 8/10 | 线程池设计良好 |
|
||||
| 总体 | 7.5/10 | 良好,有优化空间 |
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 立即执行 (本周)
|
||||
|
||||
1. **拆分AdminDB**
|
||||
- 创建 `repositories/` 目录
|
||||
- 按领域拆分:TokenRepository, DocumentRepository, TrainingRepository
|
||||
- 估计工时: 2天
|
||||
|
||||
2. **修复安全漏洞**
|
||||
- 添加 `hmac.compare_digest()` 时序攻击防护
|
||||
- 添加文件魔术字节验证
|
||||
- 估计工时: 0.5天
|
||||
|
||||
3. **提升测试覆盖率**
|
||||
- 优先测试 `inference/pipeline/`
|
||||
- 添加API集成测试
|
||||
- 目标: 从28%提升至50%
|
||||
- 估计工时: 3天
|
||||
|
||||
### 短期执行 (本月)
|
||||
|
||||
4. **引入消息队列**
|
||||
- 添加Redis服务
|
||||
- 使用Celery替换内存队列
|
||||
- 估计工时: 3天
|
||||
|
||||
5. **统一配置管理**
|
||||
- 使用 Pydantic Settings
|
||||
- 集中验证逻辑
|
||||
- 估计工时: 1天
|
||||
|
||||
6. **添加缓存层**
|
||||
- Redis缓存热点数据
|
||||
- 缓存文档、模型配置
|
||||
- 估计工时: 2天
|
||||
|
||||
### 长期执行 (本季度)
|
||||
|
||||
7. **数据库读写分离**
|
||||
- 配置主从数据库
|
||||
- 读操作使用从库
|
||||
- 估计工时: 3天
|
||||
|
||||
8. **事件驱动架构**
|
||||
- 引入事件总线
|
||||
- 解耦模块依赖
|
||||
- 估计工时: 5天
|
||||
|
||||
9. **前端优化**
|
||||
- 添加错误边界
|
||||
- 实现真实搜索功能
|
||||
- 添加E2E测试
|
||||
- 估计工时: 3天
|
||||
|
||||
---
|
||||
|
||||
## 总结与评分
|
||||
|
||||
### 各维度评分
|
||||
|
||||
| 维度 | 评分 | 权重 | 加权得分 |
|
||||
|------|------|------|----------|
|
||||
| **代码质量** | 7.5/10 | 20% | 1.5 |
|
||||
| **安全性** | 7.5/10 | 20% | 1.5 |
|
||||
| **可维护性** | 8/10 | 15% | 1.2 |
|
||||
| **测试覆盖** | 5/10 | 15% | 0.75 |
|
||||
| **性能** | 7.5/10 | 15% | 1.125 |
|
||||
| **文档** | 8/10 | 10% | 0.8 |
|
||||
| **架构设计** | 8/10 | 5% | 0.4 |
|
||||
| **总体** | **7.3/10** | 100% | **7.275** |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
|
||||
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
|
||||
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
|
||||
4. **测试是短板**: 28%覆盖率是最大风险点
|
||||
5. **生产就绪**: 经过小幅改进后可以投入生产使用
|
||||
|
||||
### 下一步行动
|
||||
|
||||
| 优先级 | 任务 | 预计工时 | 影响 |
|
||||
|--------|------|----------|------|
|
||||
| 高 | 拆分AdminDB | 2天 | 提升可维护性 |
|
||||
| 高 | 引入Redis队列 | 3天 | 解决任务丢失问题 |
|
||||
| 高 | 提升测试覆盖率 | 5天 | 降低代码风险 |
|
||||
| 中 | 修复安全漏洞 | 0.5天 | 提升安全性 |
|
||||
| 中 | 统一配置管理 | 1天 | 减少配置错误 |
|
||||
| 低 | 前端优化 | 3天 | 提升用户体验 |
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 关键文件清单
|
||||
|
||||
| 文件 | 职责 | 问题 |
|
||||
|------|------|------|
|
||||
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
|
||||
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
|
||||
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
|
||||
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
|
||||
| `shared/shared/config.py` | 共享配置 | 分散管理 |
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
|
||||
- [Celery Documentation](https://docs.celeryproject.org/)
|
||||
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
|
||||
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
|
||||
---
|
||||
|
||||
**报告生成时间**: 2026-02-01
|
||||
**审查工具**: Claude Code + AST-grep + LSP
|
||||
637
COMMERCIALIZATION_ANALYSIS_REPORT.md
Normal file
637
COMMERCIALIZATION_ANALYSIS_REPORT.md
Normal file
@@ -0,0 +1,637 @@
|
||||
# Invoice Master POC v2 - 商业化分析报告
|
||||
|
||||
**报告日期**: 2026-02-01
|
||||
**分析人**: Claude Code
|
||||
**项目**: Invoice Master - 瑞典发票字段自动提取系统
|
||||
**当前状态**: POC阶段,已处理9,738份文档,字段匹配率94.8%
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
1. [执行摘要](#执行摘要)
|
||||
2. [市场分析](#市场分析)
|
||||
3. [商业模式建议](#商业模式建议)
|
||||
4. [技术架构商业化评估](#技术架构商业化评估)
|
||||
5. [商业化路线图](#商业化路线图)
|
||||
6. [风险与挑战](#风险与挑战)
|
||||
7. [成本与定价策略](#成本与定价策略)
|
||||
8. [竞争分析](#竞争分析)
|
||||
9. [改进建议](#改进建议)
|
||||
10. [总结与建议](#总结与建议)
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 项目现状
|
||||
|
||||
Invoice Master是一个基于YOLOv11 + PaddleOCR的瑞典发票字段自动提取系统,具备以下核心能力:
|
||||
|
||||
| 指标 | 数值 | 评估 |
|
||||
|------|------|------|
|
||||
| 已处理文档 | 9,738份 | 数据基础良好 |
|
||||
| 字段匹配率 | 94.8% | 接近商业化标准 |
|
||||
| 模型mAP@0.5 | 93.5% | 业界优秀水平 |
|
||||
| 测试覆盖率 | 28% | 需大幅提升 |
|
||||
| 架构成熟度 | 7.3/10 | 基本就绪 |
|
||||
|
||||
### 商业化可行性评估
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **技术成熟度** | 7.5/10 | 核心算法成熟,需完善工程化 |
|
||||
| **市场需求** | 8/10 | 发票处理是刚需市场 |
|
||||
| **竞争壁垒** | 6/10 | 技术可替代,需构建数据壁垒 |
|
||||
| **商业化就绪度** | 6.5/10 | 需完成产品化和合规准备 |
|
||||
| **总体评估** | **7/10** | **具备商业化潜力,需6-12个月准备** |
|
||||
|
||||
### 关键建议
|
||||
|
||||
1. **短期(3个月)**: 提升测试覆盖率至80%,完成安全加固
|
||||
2. **中期(6个月)**: 推出MVP产品,获取首批付费客户
|
||||
3. **长期(12个月)**: 扩展多语言支持,进入国际市场
|
||||
|
||||
---
|
||||
|
||||
## 市场分析
|
||||
|
||||
### 目标市场
|
||||
|
||||
#### 1.1 市场规模
|
||||
|
||||
**全球发票处理市场**
|
||||
- 市场规模: ~$30B (2024)
|
||||
- 年增长率: 12-15%
|
||||
- 驱动因素: 数字化转型、合规要求、成本节约
|
||||
|
||||
**瑞典/北欧市场**
|
||||
- 中小企业数量: ~100万+
|
||||
- 大型企业: ~2,000家
|
||||
- 年发票处理量: ~5亿张
|
||||
- 市场特点: 数字化程度高,合规要求严格
|
||||
|
||||
#### 1.2 目标客户画像
|
||||
|
||||
| 客户类型 | 规模 | 痛点 | 付费意愿 | 获取难度 |
|
||||
|----------|------|------|----------|----------|
|
||||
| **中小企业** | 10-100人 | 手动录入耗时 | 中 | 低 |
|
||||
| **会计事务所** | 5-50人 | 批量处理需求 | 高 | 中 |
|
||||
| **大型企业** | 500+人 | 系统集成需求 | 高 | 高 |
|
||||
| **SaaS平台** | - | API集成需求 | 中 | 中 |
|
||||
|
||||
### 市场需求验证
|
||||
|
||||
#### 2.1 痛点分析
|
||||
|
||||
**现有解决方案的问题:**
|
||||
1. **传统OCR**: 准确率70-85%,需要大量人工校对
|
||||
2. **人工录入**: 成本高($0.5-2/张),速度慢,易出错
|
||||
3. **现有AI方案**: 价格昂贵,定制化程度低
|
||||
|
||||
**Invoice Master的优势:**
|
||||
- 准确率94.8%,接近人工水平
|
||||
- 支持瑞典特有的字段(OCR参考号、Bankgiro/Plusgiro)
|
||||
- 可定制化训练,适应不同发票格式
|
||||
|
||||
#### 2.2 市场进入策略
|
||||
|
||||
**第一阶段: 瑞典市场验证**
|
||||
- 目标客户: 中型会计事务所
|
||||
- 价值主张: 减少80%人工录入时间
|
||||
- 定价: $0.1-0.2/张 或 $99-299/月
|
||||
|
||||
**第二阶段: 北欧扩展**
|
||||
- 扩展至挪威、丹麦、芬兰
|
||||
- 适配各国发票格式
|
||||
- 建立本地合作伙伴网络
|
||||
|
||||
**第三阶段: 欧洲市场**
|
||||
- 支持多语言(德语、法语、英语)
|
||||
- GDPR合规认证
|
||||
- 与主流ERP系统集成
|
||||
|
||||
---
|
||||
|
||||
## 商业模式建议
|
||||
|
||||
### 3.1 商业模式选项
|
||||
|
||||
#### 选项A: SaaS订阅模式 (推荐)
|
||||
|
||||
**定价结构:**
|
||||
```
|
||||
Starter: $99/月
|
||||
- 500张发票/月
|
||||
- 基础字段提取
|
||||
- 邮件支持
|
||||
|
||||
Professional: $299/月
|
||||
- 2,000张发票/月
|
||||
- 所有字段+自定义字段
|
||||
- API访问
|
||||
- 优先支持
|
||||
|
||||
Enterprise: 定制报价
|
||||
- 无限发票
|
||||
- 私有部署选项
|
||||
- SLA保障
|
||||
- 专属客户经理
|
||||
```
|
||||
|
||||
**优势:**
|
||||
- 可预测的经常性收入
|
||||
- 客户生命周期价值高
|
||||
- 易于扩展
|
||||
|
||||
**劣势:**
|
||||
- 需要持续的产品迭代
|
||||
- 客户获取成本较高
|
||||
|
||||
#### 选项B: 按量付费模式
|
||||
|
||||
**定价:**
|
||||
- 前100张: $0.15/张
|
||||
- 101-1000张: $0.10/张
|
||||
- 1001+张: $0.05/张
|
||||
|
||||
**适用场景:**
|
||||
- 季节性业务
|
||||
- 初创企业
|
||||
- 不确定使用量的客户
|
||||
|
||||
#### 选项C: 授权许可模式
|
||||
|
||||
**定价:**
|
||||
- 年度许可: $10,000-50,000
|
||||
- 按部署规模收费
|
||||
- 包含培训和定制开发
|
||||
|
||||
**适用场景:**
|
||||
- 大型企业
|
||||
- 数据敏感行业
|
||||
- 需要私有部署的客户
|
||||
|
||||
### 3.2 推荐模式: 混合模式
|
||||
|
||||
**核心产品: SaaS订阅**
|
||||
- 面向中小企业和会计事务所
|
||||
- 标准化产品,快速交付
|
||||
|
||||
**增值服务: 定制开发**
|
||||
- 面向大型企业
|
||||
- 私有部署选项
|
||||
- 按项目收费
|
||||
|
||||
**API服务: 按量付费**
|
||||
- 面向SaaS平台和开发者
|
||||
- 开发者友好定价
|
||||
|
||||
### 3.3 收入预测
|
||||
|
||||
**保守估计 (第一年)**
|
||||
| 客户类型 | 客户数 | ARPU | MRR | 年收入 |
|
||||
|----------|--------|------|-----|--------|
|
||||
| Starter | 20 | $99 | $1,980 | $23,760 |
|
||||
| Professional | 10 | $299 | $2,990 | $35,880 |
|
||||
| Enterprise | 2 | $2,000 | $4,000 | $48,000 |
|
||||
| **总计** | **32** | - | **$8,970** | **$107,640** |
|
||||
|
||||
**乐观估计 (第一年)**
|
||||
- 客户数: 100+
|
||||
- 年收入: $300,000-500,000
|
||||
|
||||
---
|
||||
|
||||
## 技术架构商业化评估
|
||||
|
||||
### 4.1 架构优势
|
||||
|
||||
| 优势 | 说明 | 商业化价值 |
|
||||
|------|------|-----------|
|
||||
| **Monorepo结构** | 代码组织清晰 | 降低维护成本 |
|
||||
| **云原生架构** | 支持AWS/Azure | 灵活部署选项 |
|
||||
| **存储抽象层** | 支持多后端 | 满足不同客户需求 |
|
||||
| **模型版本管理** | 可追溯可回滚 | 企业级可靠性 |
|
||||
| **API优先设计** | RESTful API | 易于集成和扩展 |
|
||||
|
||||
### 4.2 商业化就绪度评估
|
||||
|
||||
#### 高优先级改进项
|
||||
|
||||
| 问题 | 影响 | 改进建议 | 工时 |
|
||||
|------|------|----------|------|
|
||||
| **测试覆盖率28%** | 质量风险 | 提升至80%+ | 4周 |
|
||||
| **AdminDB过大** | 维护困难 | 拆分Repository | 2周 |
|
||||
| **内存队列** | 单点故障 | 引入Redis | 2周 |
|
||||
| **安全漏洞** | 合规风险 | 修复时序攻击等 | 1周 |
|
||||
|
||||
#### 中优先级改进项
|
||||
|
||||
| 问题 | 影响 | 改进建议 | 工时 |
|
||||
|------|------|----------|------|
|
||||
| **缺少审计日志** | 合规要求 | 添加完整审计 | 2周 |
|
||||
| **无多租户隔离** | 数据安全 | 实现租户隔离 | 3周 |
|
||||
| **限流器内存存储** | 扩展性 | Redis分布式限流 | 1周 |
|
||||
| **配置分散** | 运维难度 | 统一配置中心 | 1周 |
|
||||
|
||||
### 4.3 技术债务清理计划
|
||||
|
||||
**阶段1: 基础加固 (4周)**
|
||||
- 提升测试覆盖率至60%
|
||||
- 修复安全漏洞
|
||||
- 添加基础监控
|
||||
|
||||
**阶段2: 架构优化 (6周)**
|
||||
- 拆分AdminDB
|
||||
- 引入消息队列
|
||||
- 实现多租户支持
|
||||
|
||||
**阶段3: 企业级功能 (8周)**
|
||||
- 完整审计日志
|
||||
- SSO集成
|
||||
- 高级权限管理
|
||||
|
||||
---
|
||||
|
||||
## 商业化路线图
|
||||
|
||||
### 5.1 时间线规划
|
||||
|
||||
```
|
||||
Month 1-3: 产品化准备
|
||||
├── 技术债务清理
|
||||
├── 安全加固
|
||||
├── 测试覆盖率提升
|
||||
└── 文档完善
|
||||
|
||||
Month 4-6: MVP发布
|
||||
├── 核心功能稳定
|
||||
├── 基础监控告警
|
||||
├── 客户反馈收集
|
||||
└── 定价策略验证
|
||||
|
||||
Month 7-9: 市场扩展
|
||||
├── 销售团队组建
|
||||
├── 合作伙伴网络
|
||||
├── 案例研究制作
|
||||
└── 营销自动化
|
||||
|
||||
Month 10-12: 规模化
|
||||
├── 多语言支持
|
||||
├── 高级功能开发
|
||||
├── 国际市场准备
|
||||
└── 融资准备
|
||||
```
|
||||
|
||||
### 5.2 里程碑
|
||||
|
||||
| 里程碑 | 时间 | 成功标准 |
|
||||
|--------|------|----------|
|
||||
| **技术就绪** | M3 | 测试80%,零高危漏洞 |
|
||||
| **首个付费客户** | M4 | 签约并上线 |
|
||||
| **产品市场契合** | M6 | 10+付费客户,NPS>40 |
|
||||
| **盈亏平衡** | M9 | MRR覆盖运营成本 |
|
||||
| **规模化准备** | M12 | 100+客户,$50K+MRR |
|
||||
|
||||
### 5.3 团队组建建议
|
||||
|
||||
**核心团队 (前6个月)**
|
||||
| 角色 | 人数 | 职责 |
|
||||
|------|------|------|
|
||||
| 技术负责人 | 1 | 架构、技术决策 |
|
||||
| 全栈工程师 | 2 | 产品开发 |
|
||||
| ML工程师 | 1 | 模型优化 |
|
||||
| 产品经理 | 1 | 产品规划 |
|
||||
| 销售/BD | 1 | 客户获取 |
|
||||
|
||||
**扩展团队 (6-12个月)**
|
||||
| 角色 | 人数 | 职责 |
|
||||
|------|------|------|
|
||||
| 客户成功 | 1 | 客户留存 |
|
||||
| 市场营销 | 1 | 品牌建设 |
|
||||
| 技术支持 | 1 | 客户支持 |
|
||||
|
||||
---
|
||||
|
||||
## 风险与挑战
|
||||
|
||||
### 6.1 技术风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **模型准确率下降** | 中 | 高 | 持续训练,A/B测试 |
|
||||
| **系统稳定性** | 中 | 高 | 完善监控,灰度发布 |
|
||||
| **数据安全漏洞** | 低 | 高 | 安全审计,渗透测试 |
|
||||
| **扩展性瓶颈** | 中 | 中 | 架构优化,负载测试 |
|
||||
|
||||
### 6.2 市场风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **竞争加剧** | 高 | 中 | 差异化定位,垂直深耕 |
|
||||
| **价格战** | 中 | 中 | 价值定价,增值服务 |
|
||||
| **客户获取困难** | 中 | 高 | 内容营销,口碑传播 |
|
||||
| **市场教育成本** | 中 | 中 | 免费试用,案例展示 |
|
||||
|
||||
### 6.3 合规风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **GDPR合规** | 高 | 高 | 隐私设计,数据本地化 |
|
||||
| **数据主权** | 中 | 高 | 多区域部署选项 |
|
||||
| **行业认证** | 中 | 中 | ISO27001, SOC2准备 |
|
||||
|
||||
### 6.4 财务风险
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|----------|
|
||||
| **现金流紧张** | 中 | 高 | 预付费模式,成本控制 |
|
||||
| **客户流失** | 中 | 中 | 客户成功,年度合同 |
|
||||
| **定价失误** | 中 | 中 | 灵活定价,快速迭代 |
|
||||
|
||||
---
|
||||
|
||||
## 成本与定价策略
|
||||
|
||||
### 7.1 运营成本估算
|
||||
|
||||
**月度运营成本 (AWS)**
|
||||
| 项目 | 成本 | 说明 |
|
||||
|------|------|------|
|
||||
| 计算 (ECS Fargate) | $150 | 推理服务 |
|
||||
| 数据库 (RDS) | $50 | PostgreSQL |
|
||||
| 存储 (S3) | $20 | 文档和模型 |
|
||||
| 训练 (SageMaker) | $100 | 按需训练 |
|
||||
| 监控/日志 | $30 | CloudWatch等 |
|
||||
| **小计** | **$350** | **基础运营成本** |
|
||||
|
||||
**月度运营成本 (Azure)**
|
||||
| 项目 | 成本 | 说明 |
|
||||
|------|------|------|
|
||||
| 计算 (Container Apps) | $180 | 推理服务 |
|
||||
| 数据库 | $60 | PostgreSQL |
|
||||
| 存储 | $25 | Blob Storage |
|
||||
| 训练 | $120 | Azure ML |
|
||||
| **小计** | **$385** | **基础运营成本** |
|
||||
|
||||
**人力成本 (月度)**
|
||||
| 阶段 | 人数 | 成本 |
|
||||
|------|------|------|
|
||||
| 启动期 (1-3月) | 3 | $15,000 |
|
||||
| 成长期 (4-9月) | 5 | $25,000 |
|
||||
| 规模化 (10-12月) | 7 | $35,000 |
|
||||
|
||||
### 7.2 定价策略
|
||||
|
||||
**成本加成定价**
|
||||
- 基础成本: $350/月
|
||||
- 目标毛利率: 70%
|
||||
- 最低收费: $1,000/月
|
||||
|
||||
**价值定价**
|
||||
- 客户节省成本: $2-5/张 (人工录入)
|
||||
- 收费: $0.1-0.2/张
|
||||
- 客户ROI: 10-50x
|
||||
|
||||
**竞争定价**
|
||||
- 竞争对手: $0.2-0.5/张
|
||||
- 我们的定价: $0.1-0.15/张
|
||||
- 策略: 高性价比切入
|
||||
|
||||
### 7.3 盈亏平衡分析
|
||||
|
||||
**固定成本: $25,000/月** (人力+基础设施)
|
||||
|
||||
**盈亏平衡点:**
|
||||
- 按订阅模式: 85个Professional客户 或 250个Starter客户
|
||||
- 按量付费: 250,000张发票/月
|
||||
|
||||
**目标 (12个月):**
|
||||
- MRR: $50,000
|
||||
- 客户数: 150
|
||||
- 毛利率: 75%
|
||||
|
||||
---
|
||||
|
||||
## 竞争分析
|
||||
|
||||
### 8.1 竞争对手
|
||||
|
||||
#### 直接竞争对手
|
||||
|
||||
| 公司 | 产品 | 优势 | 劣势 | 定价 |
|
||||
|------|------|------|------|------|
|
||||
| **Rossum** | AI发票处理 | 技术成熟,欧洲市场强 | 价格高 | $0.3-0.5/张 |
|
||||
| **Hypatos** | 文档AI | 德国市场深耕 | 定制化弱 | 定制报价 |
|
||||
| **Klippa** | 文档解析 | API友好 | 准确率一般 | $0.1-0.2/张 |
|
||||
| **Nanonets** | 工作流自动化 | 易用性好 | 发票专业性弱 | $0.05-0.15/张 |
|
||||
|
||||
#### 间接竞争对手
|
||||
|
||||
| 类型 | 代表 | 威胁程度 |
|
||||
|------|------|----------|
|
||||
| **传统OCR** | ABBYY, Tesseract | 中 |
|
||||
| **ERP内置** | SAP, Oracle | 中 |
|
||||
| **会计软件** | Visma, Fortnox | 高 |
|
||||
|
||||
### 8.2 竞争优势
|
||||
|
||||
**短期优势 (6-12个月)**
|
||||
1. **瑞典市场专注**: 本地化字段支持
|
||||
2. **价格优势**: 比Rossum便宜50%+
|
||||
3. **定制化**: 可训练专属模型
|
||||
|
||||
**长期优势 (1-3年)**
|
||||
1. **数据壁垒**: 训练数据积累
|
||||
2. **行业深度**: 垂直行业解决方案
|
||||
3. **生态集成**: 与主流ERP深度集成
|
||||
|
||||
### 8.3 竞争策略
|
||||
|
||||
**差异化定位**
|
||||
- 不做通用文档处理,专注发票领域
|
||||
- 不做全球市场,先做透北欧
|
||||
- 不做低价竞争,做高性价比
|
||||
|
||||
**护城河构建**
|
||||
1. **数据壁垒**: 客户发票数据训练
|
||||
2. **转换成本**: 系统集成和工作流
|
||||
3. **网络效应**: 行业模板共享
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 9.1 产品改进
|
||||
|
||||
#### 高优先级
|
||||
|
||||
| 改进项 | 说明 | 商业价值 | 工时 |
|
||||
|--------|------|----------|------|
|
||||
| **多语言支持** | 英语、德语、法语 | 扩大市场 | 4周 |
|
||||
| **批量处理API** | 支持千级批量 | 大客户必需 | 2周 |
|
||||
| **实时处理** | <3秒响应 | 用户体验 | 2周 |
|
||||
| **置信度阈值** | 用户可配置 | 灵活性 | 1周 |
|
||||
|
||||
#### 中优先级
|
||||
|
||||
| 改进项 | 说明 | 商业价值 | 工时 |
|
||||
|--------|------|----------|------|
|
||||
| **移动端适配** | 手机拍照上传 | 便利性 | 3周 |
|
||||
| **PDF预览** | 在线查看和标注 | 用户体验 | 2周 |
|
||||
| **导出格式** | Excel, JSON, XML | 集成便利 | 1周 |
|
||||
| **Webhook** | 事件通知 | 自动化 | 1周 |
|
||||
|
||||
### 9.2 技术改进
|
||||
|
||||
#### 架构优化
|
||||
|
||||
```
|
||||
当前架构问题:
|
||||
├── 内存队列 → 改为Redis队列
|
||||
├── 单体DB → 读写分离
|
||||
├── 同步处理 → 异步优先
|
||||
└── 单区域 → 多区域部署
|
||||
```
|
||||
|
||||
#### 性能优化
|
||||
|
||||
| 优化项 | 当前 | 目标 | 方法 |
|
||||
|--------|------|------|------|
|
||||
| 推理延迟 | 500ms | 200ms | 模型量化 |
|
||||
| 并发处理 | 10 QPS | 100 QPS | 水平扩展 |
|
||||
| 系统可用性 | 99% | 99.9% | 冗余设计 |
|
||||
|
||||
### 9.3 运营改进
|
||||
|
||||
#### 客户成功
|
||||
|
||||
- 入职流程: 30分钟完成首次提取
|
||||
- 培训材料: 视频教程+文档
|
||||
- 支持响应: <4小时响应时间
|
||||
- 客户健康度: 自动监控和预警
|
||||
|
||||
#### 销售流程
|
||||
|
||||
1. **线索获取**: 内容营销+SEO
|
||||
2. **试用转化**: 14天免费试用
|
||||
3. **付费转化**: 客户成功跟进
|
||||
4. **扩展销售**: 功能升级推荐
|
||||
|
||||
---
|
||||
|
||||
## 总结与建议
|
||||
|
||||
### 10.1 商业化可行性结论
|
||||
|
||||
**总体评估: 可行,需6-12个月准备**
|
||||
|
||||
Invoice Master具备商业化的技术基础和市场机会,但需要完成以下关键准备:
|
||||
|
||||
1. **技术债务清理**: 测试覆盖率、安全加固
|
||||
2. **产品化完善**: 多租户、审计日志、监控
|
||||
3. **市场验证**: 获取首批付费客户
|
||||
4. **团队组建**: 销售和客户成功团队
|
||||
|
||||
### 10.2 关键成功因素
|
||||
|
||||
| 因素 | 重要性 | 当前状态 | 行动计划 |
|
||||
|------|--------|----------|----------|
|
||||
| **技术稳定性** | 高 | 中 | 测试+监控 |
|
||||
| **客户获取** | 高 | 低 | 内容营销 |
|
||||
| **产品市场契合** | 高 | 未验证 | 快速迭代 |
|
||||
| **团队能力** | 高 | 中 | 招聘培训 |
|
||||
| **资金储备** | 中 | 未知 | 融资准备 |
|
||||
|
||||
### 10.3 行动计划
|
||||
|
||||
#### 立即执行 (本月)
|
||||
|
||||
- [ ] 制定详细的技术债务清理计划
|
||||
- [ ] 启动安全审计和漏洞修复
|
||||
- [ ] 设计多租户架构方案
|
||||
- [ ] 准备融资材料或预算规划
|
||||
|
||||
#### 短期目标 (3个月)
|
||||
|
||||
- [ ] 测试覆盖率提升至80%
|
||||
- [ ] 完成安全加固和合规准备
|
||||
- [ ] 发布Beta版本给5-10个试用客户
|
||||
- [ ] 确定最终定价策略
|
||||
|
||||
#### 中期目标 (6个月)
|
||||
|
||||
- [ ] 获得10+付费客户
|
||||
- [ ] MRR达到$10,000
|
||||
- [ ] 完成产品市场契合验证
|
||||
- [ ] 组建完整团队
|
||||
|
||||
#### 长期目标 (12个月)
|
||||
|
||||
- [ ] 100+付费客户
|
||||
- [ ] MRR达到$50,000
|
||||
- [ ] 扩展到2-3个新市场
|
||||
- [ ] 完成A轮融资或实现盈利
|
||||
|
||||
### 10.4 最终建议
|
||||
|
||||
**建议: 继续推进商业化,但需谨慎执行**
|
||||
|
||||
Invoice Master是一个技术扎实、市场机会明确的项目。当前94.8%的准确率已经接近商业化标准,但需要投入资源完成工程化和产品化。
|
||||
|
||||
**关键决策点:**
|
||||
1. **是否投入商业化**: 是,但分阶段投入
|
||||
2. **目标市场**: 先做透瑞典,再扩展北欧
|
||||
3. **商业模式**: SaaS订阅为主,定制为辅
|
||||
4. **融资需求**: 建议准备$200K-500K种子资金
|
||||
|
||||
**成功概率评估: 65%**
|
||||
- 技术可行性: 80%
|
||||
- 市场接受度: 70%
|
||||
- 执行能力: 60%
|
||||
- 竞争环境: 50%
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### A. 关键指标追踪
|
||||
|
||||
| 指标 | 当前 | 3个月目标 | 6个月目标 | 12个月目标 |
|
||||
|------|------|-----------|-----------|------------|
|
||||
| 测试覆盖率 | 28% | 60% | 80% | 85% |
|
||||
| 系统可用性 | - | 99.5% | 99.9% | 99.95% |
|
||||
| 客户数 | 0 | 5 | 20 | 150 |
|
||||
| MRR | $0 | $500 | $10,000 | $50,000 |
|
||||
| NPS | - | - | >40 | >50 |
|
||||
| 客户流失率 | - | - | <5%/月 | <3%/月 |
|
||||
|
||||
### B. 资源需求
|
||||
|
||||
**资金需求**
|
||||
| 阶段 | 时间 | 金额 | 用途 |
|
||||
|------|------|------|------|
|
||||
| 种子期 | 0-6月 | $100K | 团队+基础设施 |
|
||||
| 成长期 | 6-12月 | $300K | 市场+团队扩展 |
|
||||
| A轮 | 12-18月 | $1M+ | 规模化+国际 |
|
||||
|
||||
**人力需求**
|
||||
| 阶段 | 团队规模 | 关键角色 |
|
||||
|------|----------|----------|
|
||||
| 启动 | 3-4人 | 技术+产品+销售 |
|
||||
| 验证 | 5-6人 | +客户成功 |
|
||||
| 增长 | 8-10人 | +市场+技术支持 |
|
||||
|
||||
### C. 参考资源
|
||||
|
||||
- [SaaS Metrics Guide](https://www.saasmetrics.co/)
|
||||
- [GDPR Compliance Checklist](https://gdpr.eu/checklist/)
|
||||
- [B2B SaaS Pricing Guide](https://www.priceintelligently.com/)
|
||||
- [Nordic Startup Ecosystem](https://www.nordicstartupnews.com/)
|
||||
|
||||
---
|
||||
|
||||
**报告完成日期**: 2026-02-01
|
||||
**下次评审日期**: 2026-03-01
|
||||
**版本**: v1.0
|
||||
419
PROJECT_REVIEW.md
Normal file
419
PROJECT_REVIEW.md
Normal file
@@ -0,0 +1,419 @@
|
||||
# Invoice Master POC v2 - 项目审查报告
|
||||
|
||||
**审查日期**: 2026-02-01
|
||||
**审查人**: Claude Code
|
||||
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
|
||||
|
||||
---
|
||||
|
||||
## 项目概述
|
||||
|
||||
**Invoice Master POC v2** - 基于 YOLOv11 + PaddleOCR 的瑞典发票字段自动提取系统
|
||||
|
||||
### 核心功能
|
||||
- **自动标注**: 利用 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
|
||||
- **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
|
||||
- **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
|
||||
- **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
|
||||
|
||||
### 架构设计
|
||||
采用 **Monorepo + 三包分离** 架构:
|
||||
|
||||
```
|
||||
packages/
|
||||
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
|
||||
├── training/ # 训练服务 (GPU, 按需启动)
|
||||
└── inference/ # 推理服务 (常驻运行)
|
||||
frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||
```
|
||||
|
||||
### 性能指标
|
||||
|
||||
| 指标 | 数值 |
|
||||
|------|------|
|
||||
| **已标注文档** | 9,738 (9,709 成功) |
|
||||
| **总体字段匹配率** | 94.8% (82,604/87,121) |
|
||||
| **测试** | 1,601 passed |
|
||||
| **测试覆盖率** | 28% |
|
||||
| **模型 mAP@0.5** | 93.5% |
|
||||
|
||||
---
|
||||
|
||||
## 安全性审查
|
||||
|
||||
### 检查清单
|
||||
|
||||
| 检查项 | 状态 | 说明 | 文件位置 |
|
||||
|--------|------|------|----------|
|
||||
| **Secrets 管理** | ✅ 良好 | 使用 `.env` 文件,`DB_PASSWORD` 无默认值 | `packages/shared/shared/config.py:46` |
|
||||
| **SQL 注入防护** | ✅ 良好 | 使用参数化查询 | 全项目 |
|
||||
| **认证机制** | ✅ 良好 | Admin token 验证 + 数据库持久化 | `packages/inference/inference/web/core/auth.py` |
|
||||
| **输入验证** | ⚠️ 需改进 | 部分端点缺少文件类型/大小验证 | Web API 端点 |
|
||||
| **路径遍历防护** | ⚠️ 需检查 | 需确认文件上传路径验证 | 文件上传处理 |
|
||||
| **CORS 配置** | ❓ 待查 | 需确认生产环境配置 | FastAPI 中间件 |
|
||||
| **Rate Limiting** | ✅ 良好 | 已实现核心限流器 | `packages/inference/inference/web/core/rate_limiter.py` |
|
||||
| **错误处理** | ✅ 良好 | Web 层 356 处异常处理 | 全项目 |
|
||||
|
||||
### 详细发现
|
||||
|
||||
#### ✅ 安全实践良好的方面
|
||||
|
||||
1. **环境变量管理**
|
||||
- 使用 `python-dotenv` 加载 `.env` 文件
|
||||
- 数据库密码没有默认值,强制要求设置
|
||||
- 验证逻辑在配置加载时执行
|
||||
|
||||
2. **认证实现**
|
||||
- Token 存储在 PostgreSQL 数据库
|
||||
- 支持 Token 过期检查
|
||||
- 记录最后使用时间
|
||||
|
||||
3. **存储抽象层**
|
||||
- 支持 Local/Azure/S3 多后端
|
||||
- 通过环境变量配置,无硬编码凭证
|
||||
|
||||
#### ⚠️ 需要改进的安全问题
|
||||
|
||||
1. **时序攻击防护**
|
||||
- **位置**: `packages/inference/inference/web/core/auth.py:46`
|
||||
- **问题**: Token 验证使用普通字符串比较
|
||||
- **建议**: 使用 `hmac.compare_digest()` 进行 constant-time 比较
|
||||
- **风险等级**: 中
|
||||
|
||||
2. **文件上传验证**
|
||||
- **位置**: Web API 文件上传端点
|
||||
- **问题**: 需确认是否验证文件魔数 (magic bytes)
|
||||
- **建议**: 添加 PDF 文件签名验证 (`%PDF`)
|
||||
- **风险等级**: 中
|
||||
|
||||
3. **路径遍历风险**
|
||||
- **位置**: 文件下载/访问端点
|
||||
- **问题**: 需确认文件名是否经过净化处理
|
||||
- **建议**: 使用 `pathlib.Path.name` 提取文件名,验证路径范围
|
||||
- **风险等级**: 中
|
||||
|
||||
4. **CORS 配置**
|
||||
- **位置**: FastAPI 中间件配置
|
||||
- **问题**: 需确认生产环境是否允许所有来源
|
||||
- **建议**: 生产环境明确指定允许的 origins
|
||||
- **风险等级**: 低
|
||||
|
||||
---
|
||||
|
||||
## 代码质量审查
|
||||
|
||||
### 代码风格与规范
|
||||
|
||||
| 检查项 | 状态 | 说明 |
|
||||
|--------|------|------|
|
||||
| **类型注解** | ✅ 优秀 | 广泛使用 Type hints,覆盖率 > 90% |
|
||||
| **命名规范** | ✅ 良好 | 遵循 PEP 8,snake_case 命名 |
|
||||
| **文档字符串** | ✅ 良好 | 主要模块和函数都有文档 |
|
||||
| **异常处理** | ✅ 良好 | Web 层 356 处异常处理 |
|
||||
| **代码组织** | ✅ 优秀 | 模块化结构清晰,职责分离明确 |
|
||||
| **文件大小** | ⚠️ 需关注 | 部分文件超过 800 行 |
|
||||
|
||||
### 架构设计评估
|
||||
|
||||
#### 优秀的设计决策
|
||||
|
||||
1. **Monorepo 结构**
|
||||
- 清晰的包边界 (shared/training/inference)
|
||||
- 避免循环依赖
|
||||
- 便于独立部署
|
||||
|
||||
2. **存储抽象层**
|
||||
- 统一的 `StorageBackend` 接口
|
||||
- 支持本地/Azure/S3 无缝切换
|
||||
- 预签名 URL 支持
|
||||
|
||||
3. **配置管理**
|
||||
- 使用 dataclass 定义配置
|
||||
- 环境变量 + 配置文件混合
|
||||
- 类型安全
|
||||
|
||||
4. **数据库设计**
|
||||
- 合理的表结构
|
||||
- 状态机设计 (pending → running → completed)
|
||||
- 外键约束完整
|
||||
|
||||
#### 需要改进的方面
|
||||
|
||||
1. **测试覆盖率偏低**
|
||||
- 当前: 28%
|
||||
- 目标: 60%+
|
||||
- 优先测试核心业务逻辑
|
||||
|
||||
2. **部分文件过大**
|
||||
- 建议拆分为多个小文件
|
||||
- 单一职责原则
|
||||
|
||||
3. **缺少集成测试**
|
||||
- 建议添加端到端测试
|
||||
- API 契约测试
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践遵循情况
|
||||
|
||||
### 已遵循的最佳实践
|
||||
|
||||
| 实践 | 实现状态 | 说明 |
|
||||
|------|----------|------|
|
||||
| **环境变量配置** | ✅ | 所有配置通过环境变量 |
|
||||
| **数据库连接池** | ✅ | 使用 SQLModel + psycopg2 |
|
||||
| **异步处理** | ✅ | FastAPI + async/await |
|
||||
| **存储抽象层** | ✅ | 支持 Local/Azure/S3 |
|
||||
| **Docker 容器化** | ✅ | 每个服务独立 Dockerfile |
|
||||
| **数据增强** | ✅ | 12 种增强策略 |
|
||||
| **模型版本管理** | ✅ | model_versions 表 |
|
||||
| **限流保护** | ✅ | Rate limiter 实现 |
|
||||
| **日志记录** | ✅ | 结构化日志 |
|
||||
| **类型安全** | ✅ | 全面 Type hints |
|
||||
|
||||
### 技术栈评估
|
||||
|
||||
| 组件 | 技术选择 | 评估 |
|
||||
|------|----------|------|
|
||||
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
|
||||
| **OCR 引擎** | PaddleOCR v5 | ✅ 支持瑞典语 |
|
||||
| **PDF 处理** | PyMuPDF (fitz) | ✅ 功能强大 |
|
||||
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 |
|
||||
| **Web 框架** | FastAPI + Uvicorn | ✅ 高性能 |
|
||||
| **前端** | React + TypeScript + Vite | ✅ 现代栈 |
|
||||
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
|
||||
|
||||
---
|
||||
|
||||
## 关键文件详细分析
|
||||
|
||||
### 1. 配置文件
|
||||
|
||||
#### `packages/shared/shared/config.py`
|
||||
- **安全性**: ✅ 密码从环境变量读取,无默认值
|
||||
- **代码质量**: ✅ 清晰的配置结构
|
||||
- **建议**: 考虑使用 Pydantic Settings 进行验证
|
||||
|
||||
#### `packages/inference/inference/web/config.py`
|
||||
- **安全性**: ✅ 无敏感信息硬编码
|
||||
- **代码质量**: ✅ 使用 frozen dataclass
|
||||
- **建议**: 添加配置验证逻辑
|
||||
|
||||
### 2. 认证模块
|
||||
|
||||
#### `packages/inference/inference/web/core/auth.py`
|
||||
- **安全性**: ⚠️ 需添加 constant-time 比较
|
||||
- **代码质量**: ✅ 依赖注入模式
|
||||
- **建议**:
|
||||
```python
|
||||
import hmac
|
||||
if not hmac.compare_digest(api_key, settings.api_key):
|
||||
raise HTTPException(403, "Invalid API key")
|
||||
```
|
||||
|
||||
### 3. 限流器
|
||||
|
||||
#### `packages/inference/inference/web/core/rate_limiter.py`
|
||||
- **安全性**: ✅ 内存限流实现
|
||||
- **代码质量**: ✅ 清晰的接口设计
|
||||
- **建议**: 生产环境考虑 Redis 分布式限流
|
||||
|
||||
### 4. 存储层
|
||||
|
||||
#### `packages/shared/shared/storage/`
|
||||
- **安全性**: ✅ 无凭证硬编码
|
||||
- **代码质量**: ✅ 抽象接口设计
|
||||
- **建议**: 添加文件类型验证
|
||||
|
||||
---
|
||||
|
||||
## 性能与可扩展性
|
||||
|
||||
### 当前性能
|
||||
|
||||
| 指标 | 数值 | 评估 |
|
||||
|------|------|------|
|
||||
| **字段匹配率** | 94.8% | ✅ 优秀 |
|
||||
| **模型 mAP@0.5** | 93.5% | ✅ 优秀 |
|
||||
| **测试执行时间** | - | 待测量 |
|
||||
| **API 响应时间** | - | 待测量 |
|
||||
|
||||
### 可扩展性评估
|
||||
|
||||
| 方面 | 评估 | 说明 |
|
||||
|------|------|------|
|
||||
| **水平扩展** | ✅ 良好 | 无状态服务设计 |
|
||||
| **垂直扩展** | ✅ 良好 | 支持 GPU 加速 |
|
||||
| **数据库扩展** | ⚠️ 需关注 | 单 PostgreSQL 实例 |
|
||||
| **存储扩展** | ✅ 良好 | 云存储抽象层 |
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
### 高风险项
|
||||
|
||||
1. **测试覆盖率低 (28%)**
|
||||
- **影响**: 代码变更风险高
|
||||
- **缓解**: 制定测试计划,优先覆盖核心逻辑
|
||||
|
||||
2. **文件上传安全**
|
||||
- **影响**: 潜在的路径遍历和恶意文件上传
|
||||
- **缓解**: 添加文件类型验证和路径净化
|
||||
|
||||
### 中风险项
|
||||
|
||||
1. **认证时序攻击**
|
||||
- **影响**: Token 可能被暴力破解
|
||||
- **缓解**: 使用 constant-time 比较
|
||||
|
||||
2. **CORS 配置**
|
||||
- **影响**: CSRF 攻击风险
|
||||
- **缓解**: 生产环境限制 origins
|
||||
|
||||
### 低风险项
|
||||
|
||||
1. **依赖更新**
|
||||
- **影响**: 潜在的安全漏洞
|
||||
- **缓解**: 定期运行 `pip-audit`
|
||||
|
||||
---
|
||||
|
||||
## 改进建议
|
||||
|
||||
### 立即执行 (高优先级)
|
||||
|
||||
1. **提升测试覆盖率**
|
||||
```bash
|
||||
# 目标: 60%+
|
||||
pytest tests/ --cov=packages --cov-report=html
|
||||
```
|
||||
- 优先测试 `inference/pipeline/`
|
||||
- 添加 API 集成测试
|
||||
- 添加存储层测试
|
||||
|
||||
2. **加强文件上传安全**
|
||||
```python
|
||||
# 添加文件类型验证
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# 验证 PDF 魔数
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
```
|
||||
|
||||
3. **修复时序攻击漏洞**
|
||||
```python
|
||||
import hmac
|
||||
|
||||
def verify_token(token: str, expected: str) -> bool:
|
||||
return hmac.compare_digest(token, expected)
|
||||
```
|
||||
|
||||
### 短期执行 (中优先级)
|
||||
|
||||
4. **添加路径遍历防护**
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
def get_safe_path(filename: str, base_dir: Path) -> Path:
|
||||
safe_name = Path(filename).name
|
||||
full_path = (base_dir / safe_name).resolve()
|
||||
if not full_path.is_relative_to(base_dir):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
return full_path
|
||||
```
|
||||
|
||||
5. **配置 CORS 白名单**
|
||||
```python
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
```
|
||||
|
||||
6. **添加安全测试**
|
||||
```python
|
||||
def test_sql_injection_prevented(client):
|
||||
response = client.get("/api/v1/documents?id='; DROP TABLE;")
|
||||
assert response.status_code in (400, 422)
|
||||
|
||||
def test_path_traversal_prevented(client):
|
||||
response = client.get("/api/v1/results/../../etc/passwd")
|
||||
assert response.status_code == 400
|
||||
```
|
||||
|
||||
### 长期执行 (低优先级)
|
||||
|
||||
7. **依赖安全审计**
|
||||
```bash
|
||||
pip install pip-audit
|
||||
pip-audit --desc --format=json > security-audit.json
|
||||
```
|
||||
|
||||
8. **代码质量工具**
|
||||
```bash
|
||||
# 添加 pre-commit hooks
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
9. **性能监控**
|
||||
- 添加 APM 工具 (如 Datadog, New Relic)
|
||||
- 设置性能基准测试
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 总体评分
|
||||
|
||||
| 维度 | 评分 | 说明 |
|
||||
|------|------|------|
|
||||
| **安全性** | 8/10 | 基础安全良好,需加强输入验证和认证 |
|
||||
| **代码质量** | 8/10 | 结构清晰,类型注解完善,部分文件过大 |
|
||||
| **可维护性** | 9/10 | 模块化设计,文档详尽,架构合理 |
|
||||
| **测试覆盖** | 5/10 | 需大幅提升至 60%+ |
|
||||
| **性能** | 9/10 | 94.8% 匹配率,93.5% mAP |
|
||||
| **总体** | **8.2/10** | 优秀的项目,需关注测试和安全细节 |
|
||||
|
||||
### 关键结论
|
||||
|
||||
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
|
||||
2. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
|
||||
3. **代码质量高**: 类型注解完善,文档详尽,结构清晰
|
||||
4. **测试是短板**: 28% 覆盖率是最大风险点
|
||||
5. **生产就绪**: 经过小幅改进后可以投入生产使用
|
||||
|
||||
### 下一步行动
|
||||
|
||||
1. 🔴 **立即**: 提升测试覆盖率至 60%+
|
||||
2. 🟡 **本周**: 修复时序攻击漏洞,加强文件上传验证
|
||||
3. 🟡 **本月**: 添加路径遍历防护,配置 CORS 白名单
|
||||
4. 🟢 **季度**: 建立安全审计流程,添加性能监控
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### 审查工具
|
||||
|
||||
- Claude Code Security Review Skill
|
||||
- Claude Code Coding Standards Skill
|
||||
- grep / find / wc
|
||||
|
||||
### 相关文件
|
||||
|
||||
- `packages/shared/shared/config.py`
|
||||
- `packages/inference/inference/web/config.py`
|
||||
- `packages/inference/inference/web/core/auth.py`
|
||||
- `packages/inference/inference/web/core/rate_limiter.py`
|
||||
- `packages/shared/shared/storage/`
|
||||
|
||||
### 参考资源
|
||||
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
|
||||
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
|
||||
- [pip-audit](https://pypi.org/project/pip-audit/)
|
||||
96
create_shims.sh
Normal file
96
create_shims.sh
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create backward compatibility shims for all migrated files
|
||||
|
||||
# admin_auth.py -> core/auth.py
|
||||
cat > src/web/admin_auth.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.core.auth instead"""
|
||||
from src.web.core.auth import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_autolabel.py -> services/autolabel.py
|
||||
cat > src/web/admin_autolabel.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.autolabel instead"""
|
||||
from src.web.services.autolabel import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_scheduler.py -> core/scheduler.py
|
||||
cat > src/web/admin_scheduler.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.core.scheduler instead"""
|
||||
from src.web.core.scheduler import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_schemas.py -> schemas/admin.py
|
||||
cat > src/web/admin_schemas.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.schemas.admin instead"""
|
||||
from src.web.schemas.admin import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# schemas.py -> schemas/inference.py + schemas/common.py
|
||||
cat > src/web/schemas.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
|
||||
from src.web.schemas.inference import * # noqa: F401, F403
|
||||
from src.web.schemas.common import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# services.py -> services/inference.py
|
||||
cat > src/web/services.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.inference instead"""
|
||||
from src.web.services.inference import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# async_queue.py -> workers/async_queue.py
|
||||
cat > src/web/async_queue.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
|
||||
from src.web.workers.async_queue import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# async_service.py -> services/async_processing.py
|
||||
cat > src/web/async_service.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.async_processing instead"""
|
||||
from src.web.services.async_processing import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_queue.py -> workers/batch_queue.py
|
||||
cat > src/web/batch_queue.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
|
||||
from src.web.workers.batch_queue import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_upload_service.py -> services/batch_upload.py
|
||||
cat > src/web/batch_upload_service.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
|
||||
from src.web.services.batch_upload import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# batch_upload_routes.py -> api/v1/batch/routes.py
|
||||
cat > src/web/batch_upload_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
|
||||
from src.web.api.v1.batch.routes import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_routes.py -> api/v1/admin/documents.py
|
||||
cat > src/web/admin_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
|
||||
from src.web.api.v1.admin.documents import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_annotation_routes.py -> api/v1/admin/annotations.py
|
||||
cat > src/web/admin_annotation_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
|
||||
from src.web.api.v1.admin.annotations import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# admin_training_routes.py -> api/v1/admin/training.py
|
||||
cat > src/web/admin_training_routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
|
||||
from src.web.api.v1.admin.training import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
# routes.py -> api/v1/routes.py
|
||||
cat > src/web/routes.py << 'EOF'
|
||||
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
|
||||
from src.web.api.v1.routes import * # noqa: F401, F403
|
||||
EOF
|
||||
|
||||
echo "✓ Created backward compatibility shims for all migrated files"
|
||||
60
docker-compose.yml
Normal file
60
docker-compose.yml
Normal file
@@ -0,0 +1,60 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
environment:
|
||||
POSTGRES_DB: docmaster
|
||||
POSTGRES_USER: docmaster
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
- ./migrations:/docker-entrypoint-initdb.d
|
||||
|
||||
inference:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/inference/Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
- MODEL_PATH=/app/models/best.pt
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
depends_on:
|
||||
- postgres
|
||||
|
||||
training:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: packages/training/Dockerfile
|
||||
environment:
|
||||
- DB_HOST=postgres
|
||||
- DB_PORT=5432
|
||||
- DB_NAME=docmaster
|
||||
- DB_USER=docmaster
|
||||
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
- ./temp:/app/temp
|
||||
depends_on:
|
||||
- postgres
|
||||
# Override CMD for local dev polling mode
|
||||
command: ["python", "run_training.py", "--poll", "--poll-interval", "30"]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
@@ -1,405 +0,0 @@
|
||||
# Invoice Master POC v2 - 代码审查报告
|
||||
|
||||
**审查日期**: 2026-01-22
|
||||
**代码库规模**: 67 个 Python 源文件,约 22,434 行代码
|
||||
**测试覆盖率**: ~40-50%
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
### 总体评估:**良好(B+)**
|
||||
|
||||
**优势**:
|
||||
- ✅ 清晰的模块化架构,职责分离良好
|
||||
- ✅ 使用了合适的数据类和类型提示
|
||||
- ✅ 针对瑞典发票的全面规范化逻辑
|
||||
- ✅ 空间索引优化(O(1) token 查找)
|
||||
- ✅ 完善的降级机制(YOLO 失败时的 OCR fallback)
|
||||
- ✅ 设计良好的 Web API 和 UI
|
||||
|
||||
**主要问题**:
|
||||
- ❌ 支付行解析代码重复(3+ 处)
|
||||
- ❌ 长函数(`_normalize_customer_number` 127 行)
|
||||
- ❌ 配置安全问题(明文数据库密码)
|
||||
- ❌ 异常处理不一致(到处都是通用 Exception)
|
||||
- ❌ 缺少集成测试
|
||||
- ❌ 魔法数字散布各处(0.5, 0.95, 300 等)
|
||||
|
||||
---
|
||||
|
||||
## 1. 架构分析
|
||||
|
||||
### 1.1 模块结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── inference/ # 推理管道核心
|
||||
│ ├── pipeline.py (517 行) ⚠️
|
||||
│ ├── field_extractor.py (1,347 行) 🔴 太长
|
||||
│ └── yolo_detector.py
|
||||
├── web/ # FastAPI Web 服务
|
||||
│ ├── app.py (765 行) ⚠️ HTML 内联
|
||||
│ ├── routes.py (184 行)
|
||||
│ └── services.py (286 行)
|
||||
├── ocr/ # OCR 提取
|
||||
│ ├── paddle_ocr.py
|
||||
│ └── machine_code_parser.py (919 行) 🔴 太长
|
||||
├── matcher/ # 字段匹配
|
||||
│ └── field_matcher.py (875 行) ⚠️
|
||||
├── utils/ # 共享工具
|
||||
│ ├── validators.py
|
||||
│ ├── text_cleaner.py
|
||||
│ ├── fuzzy_matcher.py
|
||||
│ ├── ocr_corrections.py
|
||||
│ └── format_variants.py (610 行)
|
||||
├── processing/ # 批处理
|
||||
├── data/ # 数据管理
|
||||
└── cli/ # 命令行工具
|
||||
```
|
||||
|
||||
### 1.2 推理流程
|
||||
|
||||
```
|
||||
PDF/Image 输入
|
||||
↓
|
||||
渲染为图片 (pdf/renderer.py)
|
||||
↓
|
||||
YOLO 检测 (yolo_detector.py) - 检测字段区域
|
||||
↓
|
||||
字段提取 (field_extractor.py)
|
||||
├→ OCR 文本提取 (ocr/paddle_ocr.py)
|
||||
├→ 规范化 & 验证
|
||||
└→ 置信度计算
|
||||
↓
|
||||
交叉验证 (pipeline.py)
|
||||
├→ 解析 payment_line 格式
|
||||
├→ 从 payment_line 提取 OCR/Amount/Account
|
||||
└→ 与检测字段验证,payment_line 值优先
|
||||
↓
|
||||
降级 OCR(如果关键字段缺失)
|
||||
├→ 全页 OCR
|
||||
└→ 正则提取
|
||||
↓
|
||||
InferenceResult 输出
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 代码质量问题
|
||||
|
||||
### 2.1 长函数(>50 行)🔴
|
||||
|
||||
| 函数 | 文件 | 行数 | 复杂度 | 问题 |
|
||||
|------|------|------|--------|------|
|
||||
| `_normalize_customer_number()` | field_extractor.py | **127** | 极高 | 4 层模式匹配,7+ 正则,复杂评分 |
|
||||
| `_cross_validate_payment_line()` | pipeline.py | **127** | 极高 | 核心验证逻辑,8+ 条件分支 |
|
||||
| `_normalize_bankgiro()` | field_extractor.py | 62 | 高 | Luhn 验证 + 多种降级 |
|
||||
| `_normalize_plusgiro()` | field_extractor.py | 63 | 高 | 类似 bankgiro |
|
||||
| `_normalize_payment_line()` | field_extractor.py | 74 | 高 | 4 种正则模式 |
|
||||
| `_normalize_amount()` | field_extractor.py | 78 | 高 | 多策略降级 |
|
||||
|
||||
**示例问题** - `_normalize_customer_number()` (第 776-902 行):
|
||||
```python
|
||||
def _normalize_customer_number(self, text: str):
|
||||
# 127 行函数,包含:
|
||||
# - 4 个嵌套的 if/for 循环
|
||||
# - 7 种不同的正则模式
|
||||
# - 5 个评分机制
|
||||
# - 处理有标签和无标签格式
|
||||
```
|
||||
|
||||
**建议**: 拆分为:
|
||||
- `_find_customer_code_patterns()`
|
||||
- `_find_labeled_customer_code()`
|
||||
- `_score_customer_candidates()`
|
||||
|
||||
### 2.2 代码重复 🔴
|
||||
|
||||
**支付行解析(3+ 处重复实现)**:
|
||||
|
||||
1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
|
||||
2. `MachineCodeParser.parse()` (machine_code_parser.py:919 行)
|
||||
3. `_normalize_payment_line()` (field_extractor.py:632-705)
|
||||
|
||||
所有三处都实现类似的正则模式:
|
||||
```
|
||||
格式: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
```
|
||||
|
||||
**Bankgiro/Plusgiro 验证(重复)**:
|
||||
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
|
||||
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
|
||||
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
|
||||
- `field_matcher.py`: 类似匹配逻辑
|
||||
|
||||
**建议**: 创建统一模块:
|
||||
```python
|
||||
# src/common/payment_line_parser.py
|
||||
class PaymentLineParser:
|
||||
def parse(text: str) -> PaymentLineResult
|
||||
|
||||
# src/common/giro_validator.py
|
||||
class GiroValidator:
|
||||
def validate_and_format(value: str, giro_type: str) -> str
|
||||
```
|
||||
|
||||
### 2.3 错误处理不一致 ⚠️
|
||||
|
||||
**通用异常捕获(31 处)**:
|
||||
```python
|
||||
except Exception as e: # 代码库中 31 处
|
||||
result.errors.append(str(e))
|
||||
```
|
||||
|
||||
**问题**:
|
||||
- 没有捕获特定错误类型
|
||||
- 通用错误消息丢失上下文
|
||||
- 第 142-147 行 (routes.py): 捕获所有异常,返回 500 状态
|
||||
|
||||
**当前写法** (routes.py:142-147):
|
||||
```python
|
||||
try:
|
||||
service_result = inference_service.process_pdf(...)
|
||||
except Exception as e: # 太宽泛
|
||||
logger.error(f"Error processing document: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
```
|
||||
|
||||
**改进建议**:
|
||||
```python
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=400, detail="PDF 文件未找到")
|
||||
except PyMuPDFError:
|
||||
raise HTTPException(status_code=400, detail="无效的 PDF 格式")
|
||||
except OCRError:
|
||||
raise HTTPException(status_code=503, detail="OCR 服务不可用")
|
||||
```
|
||||
|
||||
### 2.4 配置安全问题 🔴
|
||||
|
||||
**config.py 第 24-30 行** - 明文凭据:
|
||||
```python
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31', # 硬编码 IP
|
||||
'user': 'docmaster', # 硬编码用户名
|
||||
'password': 'nY6LYK5d', # 🔴 明文密码!
|
||||
'database': 'invoice_master'
|
||||
}
|
||||
```
|
||||
|
||||
**建议**:
|
||||
```python
|
||||
DATABASE = {
|
||||
'host': os.getenv('DB_HOST', 'localhost'),
|
||||
'user': os.getenv('DB_USER', 'docmaster'),
|
||||
'password': os.getenv('DB_PASSWORD'), # 从环境变量读取
|
||||
'database': os.getenv('DB_NAME', 'invoice_master')
|
||||
}
|
||||
```
|
||||
|
||||
### 2.5 魔法数字 ⚠️
|
||||
|
||||
| 值 | 位置 | 用途 | 问题 |
|
||||
|---|------|------|------|
|
||||
| 0.5 | 多处 | 置信度阈值 | 不可按字段配置 |
|
||||
| 0.95 | pipeline.py | payment_line 置信度 | 无说明 |
|
||||
| 300 | 多处 | DPI | 硬编码 |
|
||||
| 0.1 | field_extractor.py | BBox 填充 | 应为配置 |
|
||||
| 72 | 多处 | PDF 基础 DPI | 公式中的魔法数字 |
|
||||
| 50 | field_extractor.py | 客户编号评分加分 | 无说明 |
|
||||
|
||||
**建议**: 提取到配置:
|
||||
```python
|
||||
INFERENCE_CONFIG = {
|
||||
'confidence_threshold': 0.5,
|
||||
'payment_line_confidence': 0.95,
|
||||
'dpi': 300,
|
||||
'bbox_padding': 0.1,
|
||||
}
|
||||
```
|
||||
|
||||
### 2.6 命名不一致 ⚠️
|
||||
|
||||
**字段名称不一致**:
|
||||
- YOLO 类名: `invoice_number`, `ocr_number`, `supplier_org_number`
|
||||
- 字段名: `InvoiceNumber`, `OCR`, `supplier_org_number`
|
||||
- CSV 列名: 可能又不同
|
||||
- 数据库字段名: 另一种变体
|
||||
|
||||
映射维护在多处:
|
||||
- `yolo_detector.py` (90-100 行): `CLASS_TO_FIELD`
|
||||
- 多个其他位置
|
||||
|
||||
---
|
||||
|
||||
## 3. 测试分析
|
||||
|
||||
### 3.1 测试覆盖率
|
||||
|
||||
**测试文件**: 13 个
|
||||
- ✅ 覆盖良好: field_matcher, normalizer, payment_line_parser
|
||||
- ⚠️ 中等覆盖: field_extractor, pipeline
|
||||
- ❌ 覆盖不足: web 层, CLI, 批处理
|
||||
|
||||
**估算覆盖率**: 40-50%
|
||||
|
||||
### 3.2 缺失的测试用例 🔴
|
||||
|
||||
**关键缺失**:
|
||||
1. 交叉验证逻辑 - 最复杂部分,测试很少
|
||||
2. payment_line 解析变体 - 多种实现,边界情况不清楚
|
||||
3. OCR 错误纠正 - 不同策略的复杂逻辑
|
||||
4. Web API 端点 - 没有请求/响应测试
|
||||
5. 批处理 - 多 worker 协调未测试
|
||||
6. 降级 OCR 机制 - YOLO 检测失败时
|
||||
|
||||
---
|
||||
|
||||
## 4. 架构风险
|
||||
|
||||
### 🔴 关键风险
|
||||
|
||||
1. **配置安全** - config.py 中明文数据库凭据(24-30 行)
|
||||
2. **错误恢复** - 宽泛的异常处理掩盖真实问题
|
||||
3. **可测试性** - 硬编码依赖阻止单元测试
|
||||
|
||||
### 🟡 高风险
|
||||
|
||||
1. **代码可维护性** - 支付行解析重复
|
||||
2. **可扩展性** - 没有长时间推理的异步处理
|
||||
3. **扩展性** - 添加新字段类型会很困难
|
||||
|
||||
### 🟢 中等风险
|
||||
|
||||
1. **性能** - 懒加载有帮助,但 ORM 查询未优化
|
||||
2. **文档** - 大部分足够但可以更好
|
||||
|
||||
---
|
||||
|
||||
## 5. 优先级矩阵
|
||||
|
||||
| 优先级 | 行动 | 工作量 | 影响 |
|
||||
|--------|------|--------|------|
|
||||
| 🔴 关键 | 修复配置安全(环境变量) | 1 小时 | 高 |
|
||||
| 🔴 关键 | 添加集成测试 | 2-3 天 | 高 |
|
||||
| 🔴 关键 | 文档化错误处理策略 | 4 小时 | 中 |
|
||||
| 🟡 高 | 统一 payment_line 解析 | 1-2 天 | 高 |
|
||||
| 🟡 高 | 提取规范化到子模块 | 2-3 天 | 中 |
|
||||
| 🟡 高 | 添加依赖注入 | 2-3 天 | 中 |
|
||||
| 🟡 高 | 拆分长函数 | 2-3 天 | 低 |
|
||||
| 🟢 中 | 提高测试覆盖率到 70%+ | 3-5 天 | 高 |
|
||||
| 🟢 中 | 提取魔法数字 | 4 小时 | 低 |
|
||||
| 🟢 中 | 标准化命名约定 | 1-2 天 | 中 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 具体文件建议
|
||||
|
||||
### 高优先级(代码质量)
|
||||
|
||||
| 文件 | 问题 | 建议 |
|
||||
|------|------|------|
|
||||
| `field_extractor.py` | 1,347 行;6 个长规范化方法 | 拆分为 `normalizers/` 子模块 |
|
||||
| `pipeline.py` | 127 行 `_cross_validate_payment_line()` | 提取到单独的 `CrossValidator` 类 |
|
||||
| `field_matcher.py` | 875 行;复杂匹配逻辑 | 拆分为 `matching/` 子模块 |
|
||||
| `config.py` | 硬编码凭据(第 29 行) | 使用环境变量 |
|
||||
| `machine_code_parser.py` | 919 行;payment_line 解析 | 与 pipeline 解析合并 |
|
||||
|
||||
### 中优先级(重构)
|
||||
|
||||
| 文件 | 问题 | 建议 |
|
||||
|------|------|------|
|
||||
| `app.py` | 765 行;HTML 内联在 Python 中 | 提取到 `templates/` 目录 |
|
||||
| `autolabel.py` | 753 行;批处理逻辑 | 提取 worker 函数到模块 |
|
||||
| `format_variants.py` | 610 行;变体生成 | 考虑策略模式 |
|
||||
|
||||
---
|
||||
|
||||
## 7. 建议行动
|
||||
|
||||
### 第 1 阶段:关键修复(1 周)
|
||||
|
||||
1. **配置安全** (1 小时)
|
||||
- 移除 config.py 中的明文密码
|
||||
- 添加环境变量支持
|
||||
- 更新 README 说明配置
|
||||
|
||||
2. **错误处理标准化** (1 天)
|
||||
- 定义自定义异常类
|
||||
- 替换通用 Exception 捕获
|
||||
- 添加错误代码常量
|
||||
|
||||
3. **添加关键集成测试** (2 天)
|
||||
- 端到端推理测试
|
||||
- payment_line 交叉验证测试
|
||||
- API 端点测试
|
||||
|
||||
### 第 2 阶段:重构(2-3 周)
|
||||
|
||||
4. **统一 payment_line 解析** (2 天)
|
||||
- 创建 `src/common/payment_line_parser.py`
|
||||
- 合并 3 处重复实现
|
||||
- 迁移所有调用方
|
||||
|
||||
5. **拆分 field_extractor.py** (3 天)
|
||||
- 创建 `src/inference/normalizers/` 子模块
|
||||
- 每个字段类型一个文件
|
||||
- 提取共享验证逻辑
|
||||
|
||||
6. **拆分长函数** (2 天)
|
||||
- `_normalize_customer_number()` → 3 个函数
|
||||
- `_cross_validate_payment_line()` → CrossValidator 类
|
||||
|
||||
### 第 3 阶段:改进(1-2 周)
|
||||
|
||||
7. **提高测试覆盖率** (5 天)
|
||||
- 目标:70%+ 覆盖率
|
||||
- 专注于验证逻辑
|
||||
- 添加边界情况测试
|
||||
|
||||
8. **配置管理改进** (1 天)
|
||||
- 提取所有魔法数字
|
||||
- 创建配置文件(YAML)
|
||||
- 添加配置验证
|
||||
|
||||
9. **文档改进** (2 天)
|
||||
- 添加架构图
|
||||
- 文档化所有私有方法
|
||||
- 创建贡献指南
|
||||
|
||||
---
|
||||
|
||||
## 附录 A:度量指标
|
||||
|
||||
### 代码复杂度
|
||||
|
||||
| 类别 | 计数 | 平均行数 |
|
||||
|------|------|----------|
|
||||
| 源文件 | 67 | 334 |
|
||||
| 长文件 (>500 行) | 12 | 875 |
|
||||
| 长函数 (>50 行) | 23 | 89 |
|
||||
| 测试文件 | 13 | 298 |
|
||||
|
||||
### 依赖关系
|
||||
|
||||
| 类型 | 计数 |
|
||||
|------|------|
|
||||
| 外部依赖 | ~25 |
|
||||
| 内部模块 | 10 |
|
||||
| 循环依赖 | 0 ✅ |
|
||||
|
||||
### 代码风格
|
||||
|
||||
| 指标 | 覆盖率 |
|
||||
|------|--------|
|
||||
| 类型提示 | 80% |
|
||||
| Docstrings (公开) | 80% |
|
||||
| Docstrings (私有) | 40% |
|
||||
| 测试覆盖率 | 45% |
|
||||
|
||||
---
|
||||
|
||||
**生成日期**: 2026-01-22
|
||||
**审查者**: Claude Code
|
||||
**版本**: v2.0
|
||||
@@ -1,96 +0,0 @@
|
||||
# Field Extractor 分析报告
|
||||
|
||||
## 概述
|
||||
|
||||
field_extractor.py (1183行) 最初被识别为可优化文件,尝试使用 `src/normalize` 模块进行重构,但经过分析和测试后发现 **不应该重构**。
|
||||
|
||||
## 重构尝试
|
||||
|
||||
### 初始计划
|
||||
将 field_extractor.py 中的重复 normalize 方法删除,统一使用 `src/normalize/normalize_field()` 接口。
|
||||
|
||||
### 实施步骤
|
||||
1. ✅ 备份原文件 (`field_extractor_old.py`)
|
||||
2. ✅ 修改 `_normalize_and_validate` 使用统一 normalizer
|
||||
3. ✅ 删除重复的 normalize 方法 (~400行)
|
||||
4. ❌ 运行测试 - **28个失败**
|
||||
5. ✅ 添加 wrapper 方法委托给 normalizer
|
||||
6. ❌ 再次测试 - **12个失败**
|
||||
7. ✅ 还原原文件
|
||||
8. ✅ 测试通过 - **全部45个测试通过**
|
||||
|
||||
## 关键发现
|
||||
|
||||
### 两个模块的不同用途
|
||||
|
||||
| 模块 | 用途 | 输入 | 输出 | 示例 |
|
||||
|------|------|------|------|------|
|
||||
| **src/normalize/** | **变体生成** 用于匹配 | 已提取的字段值 | 多个匹配变体列表 | `"INV-12345"` → `["INV-12345", "12345"]` |
|
||||
| **field_extractor** | **值提取** 从OCR文本 | 包含字段的原始OCR文本 | 提取的单个字段值 | `"Fakturanummer: A3861"` → `"A3861"` |
|
||||
|
||||
### 为什么不能统一?
|
||||
|
||||
1. **src/normalize/** 的设计目的:
|
||||
- 接收已经提取的字段值
|
||||
- 生成多个标准化变体用于fuzzy matching
|
||||
- 例如 BankgiroNormalizer:
|
||||
```python
|
||||
normalize("782-1713") → ["7821713", "782-1713"] # 生成变体
|
||||
```
|
||||
|
||||
2. **field_extractor** 的 normalize 方法:
|
||||
- 接收包含字段的原始OCR文本(可能包含标签、其他文本等)
|
||||
- **提取**特定模式的字段值
|
||||
- 例如 `_normalize_bankgiro`:
|
||||
```python
|
||||
_normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None) # 从文本提取
|
||||
```
|
||||
|
||||
3. **关键区别**:
|
||||
- Normalizer: 变体生成器 (for matching)
|
||||
- Field Extractor: 模式提取器 (for parsing)
|
||||
|
||||
### 测试失败示例
|
||||
|
||||
使用 normalizer 替代 field extractor 方法后的失败:
|
||||
|
||||
```python
|
||||
# InvoiceNumber 测试
|
||||
Input: "Fakturanummer: A3861"
|
||||
期望: "A3861"
|
||||
实际: "Fakturanummer: A3861" # 没有提取,只是清理
|
||||
|
||||
# Bankgiro 测试
|
||||
Input: "Bankgiro: 782-1713"
|
||||
期望: "782-1713"
|
||||
实际: "7821713" # 返回了不带破折号的变体,而不是提取格式化值
|
||||
```
|
||||
|
||||
## 结论
|
||||
|
||||
**field_extractor.py 不应该使用 src/normalize 模块重构**,因为:
|
||||
|
||||
1. ✅ **职责不同**: 提取 vs 变体生成
|
||||
2. ✅ **输入不同**: 包含标签的原始OCR文本 vs 已提取的字段值
|
||||
3. ✅ **输出不同**: 单个提取值 vs 多个匹配变体
|
||||
4. ✅ **现有代码运行良好**: 所有45个测试通过
|
||||
5. ✅ **提取逻辑有价值**: 包含复杂的模式匹配规则(例如区分 Bankgiro/Plusgiro 格式)
|
||||
|
||||
## 建议
|
||||
|
||||
1. **保留 field_extractor.py 原样**: 不进行重构
|
||||
2. **文档化两个模块的差异**: 确保团队理解各自用途
|
||||
3. **关注其他优化目标**: machine_code_parser.py (919行)
|
||||
|
||||
## 学习点
|
||||
|
||||
重构前应该:
|
||||
1. 理解模块的**真实用途**,而不只是看代码相似度
|
||||
2. 运行完整测试套件验证假设
|
||||
3. 评估是否真的存在重复,还是表面相似但用途不同
|
||||
|
||||
---
|
||||
|
||||
**状态**: ✅ 分析完成,决定不重构
|
||||
**测试**: ✅ 45/45 通过
|
||||
**文件**: 保持 1183行 原样
|
||||
@@ -1,238 +0,0 @@
|
||||
# Machine Code Parser 分析报告
|
||||
|
||||
## 文件概况
|
||||
|
||||
- **文件**: `src/ocr/machine_code_parser.py`
|
||||
- **总行数**: 919 行
|
||||
- **代码行**: 607 行 (66%)
|
||||
- **方法数**: 14 个
|
||||
- **正则表达式使用**: 47 次
|
||||
|
||||
## 代码结构
|
||||
|
||||
### 类结构
|
||||
|
||||
```
|
||||
MachineCodeResult (数据类)
|
||||
├── to_dict()
|
||||
└── get_region_bbox()
|
||||
|
||||
MachineCodeParser (主解析器)
|
||||
├── __init__()
|
||||
├── parse() - 主入口
|
||||
├── _find_tokens_with_values()
|
||||
├── _find_machine_code_line_tokens()
|
||||
├── _parse_standard_payment_line_with_tokens()
|
||||
├── _parse_standard_payment_line() - 142行 ⚠️
|
||||
├── _extract_ocr() - 50行
|
||||
├── _extract_bankgiro() - 58行
|
||||
├── _extract_plusgiro() - 30行
|
||||
├── _extract_amount() - 68行
|
||||
├── _calculate_confidence()
|
||||
└── cross_validate()
|
||||
```
|
||||
|
||||
## 发现的问题
|
||||
|
||||
### 1. ⚠️ `_parse_standard_payment_line` 方法过长 (142行)
|
||||
|
||||
**位置**: 442-582 行
|
||||
|
||||
**问题**:
|
||||
- 包含嵌套函数 `normalize_account_spaces` 和 `format_account`
|
||||
- 多个正则匹配分支
|
||||
- 逻辑复杂,难以测试和维护
|
||||
|
||||
**建议**:
|
||||
可以拆分为独立方法:
|
||||
- `_normalize_account_spaces(line)`
|
||||
- `_format_account(account_digits, context)`
|
||||
- `_match_primary_pattern(line)`
|
||||
- `_match_fallback_patterns(line)`
|
||||
|
||||
### 2. 🔁 4个 `_extract_*` 方法有重复模式
|
||||
|
||||
所有 extract 方法都遵循相同模式:
|
||||
|
||||
```python
|
||||
def _extract_XXX(self, tokens):
|
||||
candidates = []
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
matches = self.XXX_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# 验证逻辑
|
||||
# 上下文检测
|
||||
candidates.append((normalized, context_score, token))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
|
||||
return candidates[0][0]
|
||||
```
|
||||
|
||||
**重复的逻辑**:
|
||||
- Token 迭代
|
||||
- 模式匹配
|
||||
- 候选收集
|
||||
- 上下文评分
|
||||
- 排序和选择最佳匹配
|
||||
|
||||
**建议**:
|
||||
可以提取基础提取器类或通用方法来减少重复。
|
||||
|
||||
### 3. ✅ 上下文检测重复
|
||||
|
||||
上下文检测代码在多个地方重复:
|
||||
|
||||
```python
|
||||
# _extract_bankgiro 中
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
is_bankgiro_context = (
|
||||
'bankgiro' in context_text or
|
||||
'bg:' in context_text or
|
||||
'bg ' in context_text
|
||||
)
|
||||
|
||||
# _extract_plusgiro 中
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
is_plusgiro_context = (
|
||||
'plusgiro' in context_text or
|
||||
'postgiro' in context_text or
|
||||
'pg:' in context_text or
|
||||
'pg ' in context_text
|
||||
)
|
||||
|
||||
# _parse_standard_payment_line 中
|
||||
context = (context_line or raw_line).lower()
|
||||
is_plusgiro_context = (
|
||||
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
|
||||
and 'bankgiro' not in context
|
||||
)
|
||||
```
|
||||
|
||||
**建议**:
|
||||
提取为独立方法:
|
||||
- `_detect_account_context(tokens) -> dict[str, bool]`
|
||||
|
||||
## 重构建议
|
||||
|
||||
### 方案 A: 轻度重构(推荐)✅
|
||||
|
||||
**目标**: 提取重复的上下文检测逻辑,不改变主要结构
|
||||
|
||||
**步骤**:
|
||||
1. 提取 `_detect_account_context(tokens)` 方法
|
||||
2. 提取 `_normalize_account_spaces(line)` 为独立方法
|
||||
3. 提取 `_format_account(digits, context)` 为独立方法
|
||||
|
||||
**影响**:
|
||||
- 减少 ~50-80 行重复代码
|
||||
- 提高可测试性
|
||||
- 低风险,易于验证
|
||||
|
||||
**预期结果**: 919 行 → ~850 行 (↓7%)
|
||||
|
||||
### 方案 B: 中度重构
|
||||
|
||||
**目标**: 创建通用的字段提取框架
|
||||
|
||||
**步骤**:
|
||||
1. 创建 `_generic_extract(pattern, normalizer, context_checker)`
|
||||
2. 重构所有 `_extract_*` 方法使用通用框架
|
||||
3. 拆分 `_parse_standard_payment_line` 为多个小方法
|
||||
|
||||
**影响**:
|
||||
- 减少 ~150-200 行代码
|
||||
- 显著提高可维护性
|
||||
- 中等风险,需要全面测试
|
||||
|
||||
**预期结果**: 919 行 → ~720 行 (↓22%)
|
||||
|
||||
### 方案 C: 深度重构(不推荐)
|
||||
|
||||
**目标**: 完全重新设计为策略模式
|
||||
|
||||
**风险**:
|
||||
- 高风险,可能引入 bugs
|
||||
- 需要大量测试
|
||||
- 可能破坏现有集成
|
||||
|
||||
## 推荐方案
|
||||
|
||||
### ✅ 采用方案 A(轻度重构)
|
||||
|
||||
**理由**:
|
||||
1. **代码已经工作良好**: 没有明显的 bug 或性能问题
|
||||
2. **低风险**: 只提取重复逻辑,不改变核心算法
|
||||
3. **性价比高**: 小改动带来明显的代码质量提升
|
||||
4. **易于验证**: 现有测试应该能覆盖
|
||||
|
||||
### 重构步骤
|
||||
|
||||
```python
|
||||
# 1. 提取上下文检测
|
||||
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
|
||||
"""检测上下文中的账户类型关键词"""
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
|
||||
return {
|
||||
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
|
||||
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
|
||||
}
|
||||
|
||||
# 2. 提取空格标准化
|
||||
def _normalize_account_spaces(self, line: str) -> str:
|
||||
"""移除账户号码中的空格"""
|
||||
# (现有 line 460-481 的代码)
|
||||
|
||||
# 3. 提取账户格式化
|
||||
def _format_account(
|
||||
self,
|
||||
account_digits: str,
|
||||
is_plusgiro_context: bool
|
||||
) -> tuple[str, str]:
|
||||
"""格式化账户并确定类型"""
|
||||
# (现有 line 485-523 的代码)
|
||||
```
|
||||
|
||||
## 对比:field_extractor vs machine_code_parser
|
||||
|
||||
| 特征 | field_extractor | machine_code_parser |
|
||||
|------|-----------------|---------------------|
|
||||
| 用途 | 值提取 | 机器码解析 |
|
||||
| 重复代码 | ~400行normalize方法 | ~80行上下文检测 |
|
||||
| 重构价值 | ❌ 不同用途,不应统一 | ✅ 可提取共享逻辑 |
|
||||
| 风险 | 高(会破坏功能) | 低(只是代码组织) |
|
||||
|
||||
## 决策
|
||||
|
||||
### ✅ 建议重构 machine_code_parser.py
|
||||
|
||||
**与 field_extractor 的不同**:
|
||||
- field_extractor: 重复的方法有**不同的用途**(提取 vs 变体生成)
|
||||
- machine_code_parser: 重复的代码有**相同的用途**(都是上下文检测)
|
||||
|
||||
**预期收益**:
|
||||
- 减少 ~70 行重复代码
|
||||
- 提高可测试性(可以单独测试上下文检测)
|
||||
- 更清晰的代码组织
|
||||
- **低风险**,易于验证
|
||||
|
||||
## 下一步
|
||||
|
||||
1. ✅ 备份原文件
|
||||
2. ✅ 提取 `_detect_account_context` 方法
|
||||
3. ✅ 提取 `_normalize_account_spaces` 方法
|
||||
4. ✅ 提取 `_format_account` 方法
|
||||
5. ✅ 更新所有调用点
|
||||
6. ✅ 运行测试验证
|
||||
7. ✅ 检查代码覆盖率
|
||||
|
||||
---
|
||||
|
||||
**状态**: 📋 分析完成,建议轻度重构
|
||||
**风险评估**: 🟢 低风险
|
||||
**预期收益**: 919行 → ~850行 (↓7%)
|
||||
@@ -1,519 +0,0 @@
|
||||
# Performance Optimization Guide
|
||||
|
||||
This document provides performance optimization recommendations for the Invoice Field Extraction system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Batch Processing Optimization](#batch-processing-optimization)
|
||||
2. [Database Query Optimization](#database-query-optimization)
|
||||
3. [Caching Strategies](#caching-strategies)
|
||||
4. [Memory Management](#memory-management)
|
||||
5. [Profiling and Monitoring](#profiling-and-monitoring)
|
||||
|
||||
---
|
||||
|
||||
## Batch Processing Optimization
|
||||
|
||||
### Current State
|
||||
|
||||
The system processes invoices one at a time. For large batches, this can be inefficient.
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Database Batch Operations
|
||||
|
||||
**Current**: Individual inserts for each document
|
||||
```python
|
||||
# Inefficient
|
||||
for doc in documents:
|
||||
db.insert_document(doc) # Individual DB call
|
||||
```
|
||||
|
||||
**Optimized**: Use `execute_values` for batch inserts
|
||||
```python
|
||||
# Efficient - already implemented in db.py line 519
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
execute_values(cursor, """
|
||||
INSERT INTO documents (...)
|
||||
VALUES %s
|
||||
""", document_values)
|
||||
```
|
||||
|
||||
**Impact**: 10-50x faster for batches of 100+ documents
|
||||
|
||||
#### 2. PDF Processing Batching
|
||||
|
||||
**Recommendation**: Process PDFs in parallel using multiprocessing
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
|
||||
def process_batch(pdf_paths, batch_size=10):
|
||||
"""Process PDFs in parallel batches."""
|
||||
with Pool(processes=batch_size) as pool:
|
||||
results = pool.map(pipeline.process_pdf, pdf_paths)
|
||||
return results
|
||||
```
|
||||
|
||||
**Considerations**:
|
||||
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
|
||||
- CPU-intensive tasks can use separate process pool (`src/processing/cpu_pool.py`)
|
||||
- Current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern
|
||||
|
||||
**Status**: ✅ Already implemented in `src/processing/` modules
|
||||
|
||||
#### 3. Image Caching for Multi-Page PDFs
|
||||
|
||||
**Current**: Each page rendered independently
|
||||
```python
|
||||
# Current pattern in field_extractor.py
|
||||
for page_num in range(total_pages):
|
||||
image = render_pdf_page(pdf_path, page_num, dpi=300)
|
||||
```
|
||||
|
||||
**Optimized**: Pre-render all pages if processing multiple fields per page
|
||||
```python
|
||||
# Batch render
|
||||
images = {
|
||||
page_num: render_pdf_page(pdf_path, page_num, dpi=300)
|
||||
for page_num in page_numbers_needed
|
||||
}
|
||||
|
||||
# Reuse images
|
||||
for detection in detections:
|
||||
image = images[detection.page_no]
|
||||
extract_field(detection, image)
|
||||
```
|
||||
|
||||
**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices
|
||||
|
||||
---
|
||||
|
||||
## Database Query Optimization
|
||||
|
||||
### Current Performance
|
||||
|
||||
- **Parameterized queries**: ✅ Implemented (Phase 1)
|
||||
- **Connection pooling**: ❌ Not implemented
|
||||
- **Query batching**: ✅ Partially implemented
|
||||
- **Index optimization**: ⚠️ Needs verification
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Connection Pooling
|
||||
|
||||
**Current**: New connection for each operation
|
||||
```python
|
||||
def connect(self):
|
||||
"""Create new database connection."""
|
||||
return psycopg2.connect(**self.config)
|
||||
```
|
||||
|
||||
**Optimized**: Use connection pooling
|
||||
```python
|
||||
from psycopg2 import pool
|
||||
|
||||
class DocumentDatabase:
|
||||
def __init__(self, config):
|
||||
self.pool = pool.SimpleConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=10,
|
||||
**config
|
||||
)
|
||||
|
||||
def connect(self):
|
||||
return self.pool.getconn()
|
||||
|
||||
def close(self, conn):
|
||||
self.pool.putconn(conn)
|
||||
```
|
||||
|
||||
**Impact**:
|
||||
- Reduces connection overhead by 80-95%
|
||||
- Especially important for high-frequency operations
|
||||
|
||||
#### 2. Index Recommendations
|
||||
|
||||
**Check current indexes**:
|
||||
```sql
|
||||
-- Verify indexes exist on frequently queried columns
|
||||
SELECT tablename, indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public';
|
||||
```
|
||||
|
||||
**Recommended indexes**:
|
||||
```sql
|
||||
-- If not already present
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_success
|
||||
ON documents(success);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_timestamp
|
||||
ON documents(timestamp DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_document_id
|
||||
ON field_results(document_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_matched
|
||||
ON field_results(matched);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_field_name
|
||||
ON field_results(field_name);
|
||||
```
|
||||
|
||||
**Impact**:
|
||||
- 10-100x faster queries for filtered/sorted results
|
||||
- Critical for `get_failed_matches()` and `get_all_documents_summary()`
|
||||
|
||||
#### 3. Query Batching
|
||||
|
||||
**Status**: ✅ Already implemented for field results (line 519)
|
||||
|
||||
**Verify batching is used**:
|
||||
```python
|
||||
# Good pattern in db.py
|
||||
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
|
||||
```
|
||||
|
||||
**Additional opportunity**: Batch `SELECT` queries
|
||||
```python
|
||||
# Current
|
||||
docs = [get_document(doc_id) for doc_id in doc_ids] # N queries
|
||||
|
||||
# Optimized
|
||||
docs = get_documents_batch(doc_ids) # 1 query with IN clause
|
||||
```
|
||||
|
||||
**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)
|
||||
|
||||
---
|
||||
|
||||
## Caching Strategies
|
||||
|
||||
### 1. Model Loading Cache
|
||||
|
||||
**Current**: Models loaded per-instance
|
||||
|
||||
**Recommendation**: Singleton pattern for YOLO model
|
||||
```python
|
||||
class YOLODetectorSingleton:
|
||||
_instance = None
|
||||
_model = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, model_path):
|
||||
if cls._instance is None:
|
||||
cls._instance = YOLODetector(model_path)
|
||||
return cls._instance
|
||||
```
|
||||
|
||||
**Impact**: Reduces memory usage by 90% when processing multiple documents
|
||||
|
||||
### 2. Parser Instance Caching
|
||||
|
||||
**Current**: ✅ Already optimal
|
||||
```python
|
||||
# Good pattern in field_extractor.py
|
||||
def __init__(self):
|
||||
self.payment_line_parser = PaymentLineParser() # Reused
|
||||
self.customer_number_parser = CustomerNumberParser() # Reused
|
||||
```
|
||||
|
||||
**Status**: No changes needed
|
||||
|
||||
### 3. OCR Result Caching
|
||||
|
||||
**Recommendation**: Cache OCR results for identical regions
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def ocr_region_cached(image_hash, bbox):
|
||||
"""Cache OCR results by image hash + bbox."""
|
||||
return paddle_ocr.ocr_region(image, bbox)
|
||||
```
|
||||
|
||||
**Impact**: 50-80% speedup when re-processing similar documents
|
||||
|
||||
**Note**: Requires implementing image hashing (e.g., `hashlib.md5(image.tobytes())`)
|
||||
|
||||
---
|
||||
|
||||
## Memory Management
|
||||
|
||||
### Current Issues
|
||||
|
||||
**Potential memory leaks**:
|
||||
1. Large images kept in memory after processing
|
||||
2. OCR results accumulated without cleanup
|
||||
3. Model outputs not explicitly cleared
|
||||
|
||||
### Recommendations
|
||||
|
||||
#### 1. Explicit Image Cleanup
|
||||
|
||||
```python
|
||||
import gc
|
||||
|
||||
def process_pdf(pdf_path):
|
||||
try:
|
||||
image = render_pdf(pdf_path)
|
||||
result = extract_fields(image)
|
||||
return result
|
||||
finally:
|
||||
del image # Explicit cleanup
|
||||
gc.collect() # Force garbage collection
|
||||
```
|
||||
|
||||
#### 2. Generator Pattern for Large Batches
|
||||
|
||||
**Current**: Load all documents into memory
|
||||
```python
|
||||
docs = [process_pdf(path) for path in pdf_paths] # All in memory
|
||||
```
|
||||
|
||||
**Optimized**: Use generator for streaming processing
|
||||
```python
|
||||
def process_batch_streaming(pdf_paths):
|
||||
"""Process documents one at a time, yielding results."""
|
||||
for path in pdf_paths:
|
||||
result = process_pdf(path)
|
||||
yield result
|
||||
# Result can be saved to DB immediately
|
||||
# Previous result is garbage collected
|
||||
```
|
||||
|
||||
**Impact**: Constant memory usage regardless of batch size
|
||||
|
||||
#### 3. Context Managers for Resources
|
||||
|
||||
```python
|
||||
class InferencePipeline:
|
||||
def __enter__(self):
|
||||
self.detector.load_model()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.detector.unload_model()
|
||||
self.extractor.cleanup()
|
||||
|
||||
# Usage
|
||||
with InferencePipeline(...) as pipeline:
|
||||
results = pipeline.process_pdf(path)
|
||||
# Automatic cleanup
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Profiling and Monitoring
|
||||
|
||||
### Recommended Profiling Tools
|
||||
|
||||
#### 1. cProfile for CPU Profiling
|
||||
|
||||
```python
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Your code here
|
||||
pipeline.process_pdf(pdf_path)
|
||||
|
||||
profiler.disable()
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats('cumulative')
|
||||
stats.print_stats(20) # Top 20 slowest functions
|
||||
```
|
||||
|
||||
#### 2. memory_profiler for Memory Analysis
|
||||
|
||||
```bash
|
||||
pip install memory_profiler
|
||||
python -m memory_profiler your_script.py
|
||||
```
|
||||
|
||||
Or decorator-based:
|
||||
```python
|
||||
from memory_profiler import profile
|
||||
|
||||
@profile
|
||||
def process_large_batch(pdf_paths):
|
||||
# Memory usage tracked line-by-line
|
||||
results = [process_pdf(path) for path in pdf_paths]
|
||||
return results
|
||||
```
|
||||
|
||||
#### 3. py-spy for Production Profiling
|
||||
|
||||
```bash
|
||||
pip install py-spy
|
||||
|
||||
# Profile running process
|
||||
py-spy top --pid 12345
|
||||
|
||||
# Generate flamegraph
|
||||
py-spy record -o profile.svg -- python your_script.py
|
||||
```
|
||||
|
||||
**Advantage**: No code changes needed, minimal overhead
|
||||
|
||||
### Key Metrics to Monitor
|
||||
|
||||
1. **Processing Time per Document**
|
||||
- Target: <10 seconds for single-page invoice
|
||||
- Current: ~2-5 seconds (estimated)
|
||||
|
||||
2. **Memory Usage**
|
||||
- Target: <2GB for batch of 100 documents
|
||||
- Monitor: Peak memory usage
|
||||
|
||||
3. **Database Query Time**
|
||||
- Target: <100ms per query (with indexes)
|
||||
- Monitor: Slow query log
|
||||
|
||||
4. **OCR Accuracy vs Speed Trade-off**
|
||||
- Current: PaddleOCR with GPU (~200ms per region)
|
||||
- Alternative: Tesseract (~500ms, slightly more accurate)
|
||||
|
||||
### Logging Performance Metrics
|
||||
|
||||
**Add to pipeline.py**:
|
||||
```python
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def process_pdf(self, pdf_path):
|
||||
start = time.time()
|
||||
|
||||
# Processing...
|
||||
result = self._process_internal(pdf_path)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")
|
||||
|
||||
# Log to database for analysis
|
||||
self.db.log_performance({
|
||||
'document_id': result.document_id,
|
||||
'processing_time': elapsed,
|
||||
'field_count': len(result.fields)
|
||||
})
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Optimization Priorities
|
||||
|
||||
### High Priority (Implement First)
|
||||
|
||||
1. ✅ **Database parameterized queries** - Already done (Phase 1)
|
||||
2. ⚠️ **Database connection pooling** - Not implemented
|
||||
3. ⚠️ **Index optimization** - Needs verification
|
||||
|
||||
### Medium Priority
|
||||
|
||||
4. ⚠️ **Batch PDF rendering** - Optimization possible
|
||||
5. ✅ **Parser instance reuse** - Already done (Phase 2)
|
||||
6. ⚠️ **Model caching** - Could improve
|
||||
|
||||
### Low Priority (Nice to Have)
|
||||
|
||||
7. ⚠️ **OCR result caching** - Complex implementation
|
||||
8. ⚠️ **Generator patterns** - Refactoring needed
|
||||
9. ⚠️ **Advanced profiling** - For production optimization
|
||||
|
||||
---
|
||||
|
||||
## Benchmarking Script
|
||||
|
||||
```python
|
||||
"""
|
||||
Benchmark script for invoice processing performance.
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from src.inference.pipeline import InferencePipeline
|
||||
|
||||
def benchmark_single_document(pdf_path, iterations=10):
|
||||
"""Benchmark single document processing."""
|
||||
pipeline = InferencePipeline(
|
||||
model_path="path/to/model.pt",
|
||||
use_gpu=True
|
||||
)
|
||||
|
||||
times = []
|
||||
for i in range(iterations):
|
||||
start = time.time()
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
print(f"Iteration {i+1}: {elapsed:.2f}s")
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"\nAverage: {avg_time:.2f}s")
|
||||
print(f"Min: {min(times):.2f}s")
|
||||
print(f"Max: {max(times):.2f}s")
|
||||
|
||||
def benchmark_batch(pdf_paths, batch_size=10):
|
||||
"""Benchmark batch processing."""
|
||||
from multiprocessing import Pool
|
||||
|
||||
pipeline = InferencePipeline(
|
||||
model_path="path/to/model.pt",
|
||||
use_gpu=True
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
|
||||
with Pool(processes=batch_size) as pool:
|
||||
results = pool.map(pipeline.process_pdf, pdf_paths)
|
||||
|
||||
elapsed = time.time() - start
|
||||
avg_per_doc = elapsed / len(pdf_paths)
|
||||
|
||||
print(f"Total time: {elapsed:.2f}s")
|
||||
print(f"Documents: {len(pdf_paths)}")
|
||||
print(f"Average per document: {avg_per_doc:.2f}s")
|
||||
print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Single document benchmark
|
||||
benchmark_single_document("test.pdf")
|
||||
|
||||
# Batch benchmark
|
||||
pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
|
||||
benchmark_batch(pdf_paths[:100])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
**Implemented (Phase 1-2)**:
|
||||
- ✅ Parameterized queries (SQL injection fix)
|
||||
- ✅ Parser instance reuse (Phase 2 refactoring)
|
||||
- ✅ Batch insert operations (execute_values)
|
||||
- ✅ Dual pool processing (CPU/GPU separation)
|
||||
|
||||
**Quick Wins (Low effort, high impact)**:
|
||||
- Database connection pooling (2-4 hours)
|
||||
- Index verification and optimization (1-2 hours)
|
||||
- Batch PDF rendering (4-6 hours)
|
||||
|
||||
**Long-term Improvements**:
|
||||
- OCR result caching with hashing
|
||||
- Generator patterns for streaming
|
||||
- Advanced profiling and monitoring
|
||||
|
||||
**Expected Impact**:
|
||||
- Connection pooling: 80-95% reduction in DB overhead
|
||||
- Indexes: 10-100x faster queries
|
||||
- Batch rendering: 50-90% less redundant work
|
||||
- **Overall**: 2-5x throughput improvement for batch processing
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,170 +0,0 @@
|
||||
# 代码重构总结报告
|
||||
|
||||
## 📊 整体成果
|
||||
|
||||
### 测试状态
|
||||
- ✅ **688/688 测试全部通过** (100%)
|
||||
- ✅ **代码覆盖率**: 34% → 37% (+3%)
|
||||
- ✅ **0 个失败**, 0 个错误
|
||||
|
||||
### 测试覆盖率改进
|
||||
- ✅ **machine_code_parser**: 25% → 65% (+40%)
|
||||
- ✅ **新增测试**: 55个(633 → 688)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 已完成的重构
|
||||
|
||||
### 1. ✅ Matcher 模块化 (876行 → 205行, ↓76%)
|
||||
|
||||
**文件**:
|
||||
|
||||
**重构内容**:
|
||||
- 将单一876行文件拆分为 **11个模块**
|
||||
- 提取 **5种独立的匹配策略**
|
||||
- 创建专门的数据模型、工具函数和上下文处理模块
|
||||
|
||||
**新模块结构**:
|
||||
|
||||
|
||||
**测试结果**:
|
||||
- ✅ 77个 matcher 测试全部通过
|
||||
- ✅ 完整的README文档
|
||||
- ✅ 策略模式,易于扩展
|
||||
|
||||
**收益**:
|
||||
- 📉 代码量减少 76%
|
||||
- 📈 可维护性显著提高
|
||||
- ✨ 每个策略独立测试
|
||||
- 🔧 易于添加新策略
|
||||
|
||||
---
|
||||
|
||||
### 2. ✅ Machine Code Parser 轻度重构 + 测试覆盖 (919行 → 929行)
|
||||
|
||||
**文件**: src/ocr/machine_code_parser.py
|
||||
|
||||
**重构内容**:
|
||||
- 提取 **3个共享辅助方法**,消除重复代码
|
||||
- 优化上下文检测逻辑
|
||||
- 简化账号格式化方法
|
||||
|
||||
**测试改进**:
|
||||
- ✅ **新增55个测试**(24 → 79个)
|
||||
- ✅ **覆盖率**: 25% → 65% (+40%)
|
||||
- ✅ 所有688个项目测试通过
|
||||
|
||||
**新增测试覆盖**:
|
||||
- **第一轮** (22个测试):
|
||||
- `_detect_account_context()` - 8个测试(上下文检测)
|
||||
- `_normalize_account_spaces()` - 5个测试(空格规范化)
|
||||
- `_format_account()` - 4个测试(账号格式化)
|
||||
- `parse()` - 5个测试(主入口方法)
|
||||
- **第二轮** (33个测试):
|
||||
- `_extract_ocr()` - 8个测试(OCR 提取)
|
||||
- `_extract_bankgiro()` - 9个测试(Bankgiro 提取)
|
||||
- `_extract_plusgiro()` - 8个测试(Plusgiro 提取)
|
||||
- `_extract_amount()` - 8个测试(金额提取)
|
||||
|
||||
**收益**:
|
||||
- 🔄 消除80行重复代码
|
||||
- 📈 可测试性提高(可独立测试辅助方法)
|
||||
- 📖 代码可读性提升
|
||||
- ✅ 覆盖率从25%提升到65% (+40%)
|
||||
- 🎯 低风险,高回报
|
||||
|
||||
---
|
||||
|
||||
### 3. ✅ Field Extractor 分析 (决定不重构)
|
||||
|
||||
**文件**: (1183行)
|
||||
|
||||
**分析结果**: ❌ **不应重构**
|
||||
|
||||
**关键洞察**:
|
||||
- 表面相似的代码可能有**完全不同的用途**
|
||||
- field_extractor: **解析/提取** 字段值
|
||||
- src/normalize: **标准化/生成变体** 用于匹配
|
||||
- 两者职责不同,不应统一
|
||||
|
||||
**文档**:
|
||||
|
||||
---
|
||||
|
||||
## 📈 重构统计
|
||||
|
||||
### 代码行数变化
|
||||
|
||||
| 文件 | 重构前 | 重构后 | 变化 | 百分比 |
|
||||
|------|--------|--------|------|--------|
|
||||
| **matcher/field_matcher.py** | 876行 | 205行 | -671 | ↓76% |
|
||||
| **matcher/* (新增10个模块)** | 0行 | 466行 | +466 | 新增 |
|
||||
| **matcher 总计** | 876行 | 671行 | -205 | ↓23% |
|
||||
| **ocr/machine_code_parser.py** | 919行 | 929行 | +10 | +1% |
|
||||
| **总净减少** | - | - | **-195行** | **↓11%** |
|
||||
|
||||
### 测试覆盖
|
||||
|
||||
| 模块 | 测试数 | 通过率 | 覆盖率 | 状态 |
|
||||
|------|--------|--------|--------|------|
|
||||
| matcher | 77 | 100% | - | ✅ |
|
||||
| field_extractor | 45 | 100% | 39% | ✅ |
|
||||
| machine_code_parser | 79 | 100% | 65% | ✅ |
|
||||
| normalizer | ~120 | 100% | - | ✅ |
|
||||
| 其他模块 | ~367 | 100% | - | ✅ |
|
||||
| **总计** | **688** | **100%** | **37%** | ✅ |
|
||||
|
||||
---
|
||||
|
||||
## 🎓 重构经验总结
|
||||
|
||||
### 成功经验
|
||||
|
||||
1. **✅ 先测试后重构**
|
||||
- 所有重构都有完整测试覆盖
|
||||
- 每次改动后立即验证测试
|
||||
- 100%测试通过率保证质量
|
||||
|
||||
2. **✅ 识别真正的重复**
|
||||
- 不是所有相似代码都是重复
|
||||
- field_extractor vs normalizer: 表面相似但用途不同
|
||||
- machine_code_parser: 真正的代码重复
|
||||
|
||||
3. **✅ 渐进式重构**
|
||||
- matcher: 大规模模块化 (策略模式)
|
||||
- machine_code_parser: 轻度重构 (提取共享方法)
|
||||
- field_extractor: 分析后决定不重构
|
||||
|
||||
### 关键决策
|
||||
|
||||
#### ✅ 应该重构的情况
|
||||
- **matcher**: 单一文件过长 (876行),包含多种策略
|
||||
- **machine_code_parser**: 多处相同用途的重复代码
|
||||
|
||||
#### ❌ 不应重构的情况
|
||||
- **field_extractor**: 相似代码有不同用途
|
||||
|
||||
### 教训
|
||||
|
||||
**不要盲目追求DRY原则**
|
||||
> 相似代码不一定是重复。要理解代码的**真实用途**。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 总结
|
||||
|
||||
**关键成果**:
|
||||
- 📉 净减少 195 行代码
|
||||
- 📈 代码覆盖率 +3% (34% → 37%)
|
||||
- ✅ 测试数量 +55 (633 → 688)
|
||||
- 🎯 machine_code_parser 覆盖率 +40% (25% → 65%)
|
||||
- ✨ 模块化程度显著提高
|
||||
- 🎯 可维护性大幅提升
|
||||
|
||||
**重要教训**:
|
||||
> 相似的代码不一定是重复的代码。理解代码的真实用途,才能做出正确的重构决策。
|
||||
|
||||
**下一步建议**:
|
||||
1. 继续提升 machine_code_parser 覆盖率到 80%+ (目前 65%)
|
||||
2. 为其他低覆盖模块添加测试(field_extractor 39%, pipeline 19%)
|
||||
3. 完善边界条件和异常情况的测试
|
||||
@@ -1,258 +0,0 @@
|
||||
# 测试覆盖率改进报告
|
||||
|
||||
## 📊 改进概览
|
||||
|
||||
### 整体统计
|
||||
- ✅ **测试总数**: 633 → 688 (+55个测试, +8.7%)
|
||||
- ✅ **通过率**: 100% (688/688)
|
||||
- ✅ **整体覆盖率**: 34% → 37% (+3%)
|
||||
|
||||
### machine_code_parser.py 专项改进
|
||||
- ✅ **测试数**: 24 → 79 (+55个测试, +229%)
|
||||
- ✅ **覆盖率**: 25% → 65% (+40%)
|
||||
- ✅ **未覆盖行**: 273 → 129 (减少144行)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 新增测试详情
|
||||
|
||||
### 第一轮改进 (22个测试)
|
||||
|
||||
#### 1. TestDetectAccountContext (8个测试)
|
||||
|
||||
测试新增的 `_detect_account_context()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_bankgiro_keyword` - 检测 'bankgiro' 关键词
|
||||
2. `test_bg_keyword` - 检测 'bg:' 缩写
|
||||
3. `test_plusgiro_keyword` - 检测 'plusgiro' 关键词
|
||||
4. `test_postgiro_keyword` - 检测 'postgiro' 别名
|
||||
5. `test_pg_keyword` - 检测 'pg:' 缩写
|
||||
6. `test_both_contexts` - 同时存在两种关键词
|
||||
7. `test_no_context` - 无账号关键词
|
||||
8. `test_case_insensitive` - 大小写不敏感检测
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
return {
|
||||
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
|
||||
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. TestNormalizeAccountSpacesMethod (5个测试)
|
||||
|
||||
测试新增的 `_normalize_account_spaces()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_removes_spaces_after_arrow` - 移除 > 后的空格
|
||||
2. `test_multiple_consecutive_spaces` - 处理多个连续空格
|
||||
3. `test_no_arrow_returns_unchanged` - 无 > 标记时返回原值
|
||||
4. `test_spaces_before_arrow_preserved` - 保留 > 前的空格
|
||||
5. `test_empty_string` - 空字符串处理
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _normalize_account_spaces(self, line: str) -> str:
|
||||
if '>' not in line:
|
||||
return line
|
||||
parts = line.split('>', 1)
|
||||
after_arrow = parts[1]
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
|
||||
while re.search(r'(\d)\s+(\d)', normalized):
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
|
||||
return parts[0] + '>' + normalized
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. TestFormatAccount (4个测试)
|
||||
|
||||
测试新增的 `_format_account()` 辅助方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_plusgiro_context_forces_plusgiro` - Plusgiro 上下文强制格式化为 Plusgiro
|
||||
2. `test_valid_bankgiro_7_digits` - 7位有效 Bankgiro 格式化
|
||||
3. `test_valid_bankgiro_8_digits` - 8位有效 Bankgiro 格式化
|
||||
4. `test_defaults_to_bankgiro_when_ambiguous` - 模糊情况默认 Bankgiro
|
||||
|
||||
**覆盖的代码路径**:
|
||||
```python
|
||||
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
|
||||
if is_plusgiro_context:
|
||||
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
return formatted, 'plusgiro'
|
||||
|
||||
# Luhn 验证逻辑
|
||||
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
|
||||
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
|
||||
|
||||
# 决策逻辑
|
||||
if pg_valid and not bg_valid:
|
||||
return pg_formatted, 'plusgiro'
|
||||
elif bg_valid and not pg_valid:
|
||||
return bg_formatted, 'bankgiro'
|
||||
else:
|
||||
return bg_formatted, 'bankgiro'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. TestParseMethod (5个测试)
|
||||
|
||||
测试主入口 `parse()` 方法。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_parse_empty_tokens` - 空 token 列表处理
|
||||
2. `test_parse_finds_payment_line_in_bottom_region` - 在页面底部35%区域查找付款行
|
||||
3. `test_parse_ignores_top_region` - 忽略页面顶部区域
|
||||
4. `test_parse_with_context_keywords` - 检测上下文关键词
|
||||
5. `test_parse_stores_source_tokens` - 存储源 token
|
||||
|
||||
**覆盖的代码路径**:
|
||||
- Token 过滤(底部区域检测)
|
||||
- 上下文关键词检测
|
||||
- 付款行查找和解析
|
||||
- 结果对象构建
|
||||
|
||||
---
|
||||
|
||||
### 第二轮改进 (33个测试)
|
||||
|
||||
#### 5. TestExtractOCR (8个测试)
|
||||
|
||||
测试 `_extract_ocr()` 方法 - OCR 参考号码提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_valid_ocr_10_digits` - 提取10位 OCR 号码
|
||||
2. `test_extract_valid_ocr_15_digits` - 提取15位 OCR 号码
|
||||
3. `test_extract_ocr_with_hash_markers` - 带 # 标记的 OCR
|
||||
4. `test_extract_longest_ocr_when_multiple` - 多个候选时选最长
|
||||
5. `test_extract_ocr_ignores_short_numbers` - 忽略短于10位的数字
|
||||
6. `test_extract_ocr_ignores_long_numbers` - 忽略长于25位的数字
|
||||
7. `test_extract_ocr_excludes_bankgiro_variants` - 排除 Bankgiro 变体
|
||||
8. `test_extract_ocr_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 6. TestExtractBankgiro (9个测试)
|
||||
|
||||
测试 `_extract_bankgiro()` 方法 - Bankgiro 账号提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_bankgiro_7_digits_with_dash` - 带破折号的7位 Bankgiro
|
||||
2. `test_extract_bankgiro_7_digits_without_dash` - 无破折号的7位 Bankgiro
|
||||
3. `test_extract_bankgiro_8_digits_with_dash` - 带破折号的8位 Bankgiro
|
||||
4. `test_extract_bankgiro_8_digits_without_dash` - 无破折号的8位 Bankgiro
|
||||
5. `test_extract_bankgiro_with_spaces` - 带空格的 Bankgiro
|
||||
6. `test_extract_bankgiro_handles_plusgiro_format` - 处理 Plusgiro 格式
|
||||
7. `test_extract_bankgiro_with_context` - 带上下文关键词
|
||||
8. `test_extract_bankgiro_ignores_plusgiro_context` - 忽略 Plusgiro 上下文
|
||||
9. `test_extract_bankgiro_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 7. TestExtractPlusgiro (8个测试)
|
||||
|
||||
测试 `_extract_plusgiro()` 方法 - Plusgiro 账号提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_plusgiro_7_digits_with_dash` - 带破折号的7位 Plusgiro
|
||||
2. `test_extract_plusgiro_7_digits_without_dash` - 无破折号的7位 Plusgiro
|
||||
3. `test_extract_plusgiro_8_digits` - 8位 Plusgiro
|
||||
4. `test_extract_plusgiro_with_spaces` - 带空格的 Plusgiro
|
||||
5. `test_extract_plusgiro_with_context` - 带上下文关键词
|
||||
6. `test_extract_plusgiro_ignores_too_short` - 忽略少于7位
|
||||
7. `test_extract_plusgiro_ignores_too_long` - 忽略多于8位
|
||||
8. `test_extract_plusgiro_empty_tokens` - 空 token 处理
|
||||
|
||||
#### 8. TestExtractAmount (8个测试)
|
||||
|
||||
测试 `_extract_amount()` 方法 - 金额提取。
|
||||
|
||||
**测试用例**:
|
||||
1. `test_extract_amount_with_comma_decimal` - 逗号小数分隔符
|
||||
2. `test_extract_amount_with_dot_decimal` - 点号小数分隔符
|
||||
3. `test_extract_amount_integer` - 整数金额
|
||||
4. `test_extract_amount_with_thousand_separator` - 千位分隔符
|
||||
5. `test_extract_amount_large_number` - 大额金额
|
||||
6. `test_extract_amount_ignores_too_large` - 忽略过大金额
|
||||
7. `test_extract_amount_ignores_zero` - 忽略零或负数
|
||||
8. `test_extract_amount_empty_tokens` - 空 token 处理
|
||||
|
||||
---
|
||||
|
||||
## 📈 覆盖率分析
|
||||
|
||||
### 已覆盖的方法
|
||||
✅ `_detect_account_context()` - **100%** (第一轮新增)
|
||||
✅ `_normalize_account_spaces()` - **100%** (第一轮新增)
|
||||
✅ `_format_account()` - **95%** (第一轮新增)
|
||||
✅ `parse()` - **70%** (第一轮改进)
|
||||
✅ `_parse_standard_payment_line()` - **95%** (已有测试)
|
||||
✅ `_extract_ocr()` - **85%** (第二轮新增)
|
||||
✅ `_extract_bankgiro()` - **90%** (第二轮新增)
|
||||
✅ `_extract_plusgiro()` - **90%** (第二轮新增)
|
||||
✅ `_extract_amount()` - **80%** (第二轮新增)
|
||||
|
||||
### 仍需改进的方法 (未覆盖/部分覆盖)
|
||||
⚠️ `_calculate_confidence()` - **0%** (未测试)
|
||||
⚠️ `cross_validate()` - **0%** (未测试)
|
||||
⚠️ `get_region_bbox()` - **0%** (未测试)
|
||||
⚠️ `_find_tokens_with_values()` - **部分覆盖**
|
||||
⚠️ `_find_machine_code_line_tokens()` - **部分覆盖**
|
||||
|
||||
### 未覆盖的代码行(129行)
|
||||
主要集中在:
|
||||
1. **验证方法** (lines 805-824): `_calculate_confidence`, `cross_validate`
|
||||
2. **辅助方法** (lines 80-92, 336-369, 377-407): Token 查找、bbox 计算、日志记录
|
||||
3. **边界条件** (lines 648-653, 690, 699, 759-760等): 某些提取方法的边界情况
|
||||
|
||||
---
|
||||
|
||||
## 🎯 改进建议
|
||||
|
||||
### ✅ 已完成目标
|
||||
- ✅ 覆盖率从 25% 提升到 65% (+40%)
|
||||
- ✅ 测试数量从 24 增加到 79 (+55个)
|
||||
- ✅ 提取方法全部测试(_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount)
|
||||
|
||||
### 下一步目标(覆盖率 65% → 80%+)
|
||||
1. **添加验证方法测试** - 为 `_calculate_confidence`, `cross_validate` 添加测试
|
||||
2. **添加辅助方法测试** - 为 token 查找和 bbox 计算方法添加测试
|
||||
3. **完善边界条件** - 增加边界情况和异常处理的测试
|
||||
4. **集成测试** - 添加端到端的集成测试,使用真实 PDF token 数据
|
||||
|
||||
---
|
||||
|
||||
## ✅ 已完成的改进
|
||||
|
||||
### 重构收益
|
||||
- ✅ 提取的3个辅助方法现在可以独立测试
|
||||
- ✅ 测试粒度更细,更容易定位问题
|
||||
- ✅ 代码可读性提高,测试用例清晰易懂
|
||||
|
||||
### 质量保证
|
||||
- ✅ 所有655个测试100%通过
|
||||
- ✅ 无回归问题
|
||||
- ✅ 新增测试覆盖了之前未测试的重构代码
|
||||
|
||||
---
|
||||
|
||||
## 📚 测试编写经验
|
||||
|
||||
### 成功经验
|
||||
1. **使用 fixture 创建测试数据** - `_create_token()` 辅助方法简化了 token 创建
|
||||
2. **按方法组织测试类** - 每个方法一个测试类,结构清晰
|
||||
3. **测试用例命名清晰** - `test_<what>_<condition>` 格式,一目了然
|
||||
4. **覆盖关键路径** - 优先测试常见场景和边界条件
|
||||
|
||||
### 遇到的问题
|
||||
1. **Token 初始化参数** - 忘记了 `page_no` 参数,导致初始测试失败
|
||||
- 解决:修复 `_create_token()` 辅助方法,添加 `page_no=0`
|
||||
|
||||
---
|
||||
|
||||
**报告日期**: 2026-01-24
|
||||
**状态**: ✅ 完成
|
||||
**下一步**: 继续提升覆盖率到 60%+
|
||||
772
docs/aws-deployment-guide.md
Normal file
772
docs/aws-deployment-guide.md
Normal file
@@ -0,0 +1,772 @@
|
||||
# AWS 部署方案完整指南
|
||||
|
||||
## 目录
|
||||
- [核心问题](#核心问题)
|
||||
- [存储方案](#存储方案)
|
||||
- [训练方案](#训练方案)
|
||||
- [推理方案](#推理方案)
|
||||
- [价格对比](#价格对比)
|
||||
- [推荐架构](#推荐架构)
|
||||
- [实施步骤](#实施步骤)
|
||||
- [AWS vs Azure 对比](#aws-vs-azure-对比)
|
||||
|
||||
---
|
||||
|
||||
## 核心问题
|
||||
|
||||
| 问题 | 答案 |
|
||||
|------|------|
|
||||
| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 |
|
||||
| 能实时从 S3 读取训练吗? | 可以,SageMaker 支持 Pipe Mode 流式读取 |
|
||||
| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone |
|
||||
| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 |
|
||||
| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda |
|
||||
| 推理服务用什么? | Lambda (Serverless) 或 ECS/Fargate (容器) |
|
||||
|
||||
---
|
||||
|
||||
## 存储方案
|
||||
|
||||
### Amazon S3(推荐)
|
||||
|
||||
S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。
|
||||
|
||||
```bash
|
||||
# 创建 S3 桶
|
||||
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||
|
||||
# 上传训练数据
|
||||
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||
|
||||
# 创建目录结构
|
||||
aws s3api put-object --bucket invoice-training-data --key datasets/
|
||||
aws s3api put-object --bucket invoice-training-data --key models/
|
||||
```
|
||||
|
||||
### Mountpoint for Amazon S3
|
||||
|
||||
AWS 官方的 S3 挂载客户端,性能优于 s3fs:
|
||||
|
||||
```bash
|
||||
# 安装 Mountpoint
|
||||
wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
|
||||
sudo dpkg -i mount-s3.deb
|
||||
|
||||
# 挂载 S3
|
||||
mkdir -p /mnt/s3-data
|
||||
mount-s3 invoice-training-data /mnt/s3-data --region us-east-1
|
||||
|
||||
# 配置缓存(推荐)
|
||||
mount-s3 invoice-training-data /mnt/s3-data \
|
||||
--region us-east-1 \
|
||||
--cache /tmp/s3-cache \
|
||||
--metadata-ttl 60
|
||||
```
|
||||
|
||||
### 本地开发挂载
|
||||
|
||||
**Linux/Mac (s3fs-fuse):**
|
||||
```bash
|
||||
# 安装
|
||||
sudo apt-get install s3fs
|
||||
|
||||
# 配置凭证
|
||||
echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs
|
||||
chmod 600 ~/.passwd-s3fs
|
||||
|
||||
# 挂载
|
||||
s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs
|
||||
```
|
||||
|
||||
**Windows (Rclone):**
|
||||
```powershell
|
||||
# 安装
|
||||
winget install Rclone.Rclone
|
||||
|
||||
# 配置
|
||||
rclone config # 选择 s3
|
||||
|
||||
# 挂载
|
||||
rclone mount aws:invoice-training-data Z: --vfs-cache-mode full
|
||||
```
|
||||
|
||||
### 存储费用
|
||||
|
||||
| 层级 | 价格 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| S3 Standard | $0.023/GB/月 | 频繁访问 |
|
||||
| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 |
|
||||
| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 |
|
||||
| S3 Glacier | $0.004/GB/月 | 长期存档 |
|
||||
|
||||
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月**
|
||||
|
||||
### SageMaker 数据输入模式
|
||||
|
||||
| 模式 | 说明 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| File Mode | 下载到本地再训练 | 小数据集 |
|
||||
| Pipe Mode | 流式读取,不占本地空间 | 大数据集 |
|
||||
| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 |
|
||||
|
||||
---
|
||||
|
||||
## 训练方案
|
||||
|
||||
### 方案总览
|
||||
|
||||
| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 |
|
||||
|------|---------|---------|--------|----------|
|
||||
| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 |
|
||||
| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 |
|
||||
| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 |
|
||||
|
||||
### EC2 vs SageMaker
|
||||
|
||||
| 特性 | EC2 | SageMaker |
|
||||
|------|-----|-----------|
|
||||
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||
| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) |
|
||||
| 管理开销 | 需自己配置 | 全托管 |
|
||||
| Spot 折扣 | 最高 90% | 最高 90% |
|
||||
| 实验跟踪 | 无 | 内置 |
|
||||
| 自动关机 | 无 | 任务完成自动停止 |
|
||||
|
||||
### GPU 实例价格 (2025 年 6 月降价后)
|
||||
|
||||
| 实例 | GPU | 显存 | On-Demand | Spot 价格 |
|
||||
|------|-----|------|-----------|----------|
|
||||
| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr |
|
||||
| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr |
|
||||
| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr |
|
||||
| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr |
|
||||
| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr |
|
||||
|
||||
**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。
|
||||
|
||||
### Spot 实例
|
||||
|
||||
```bash
|
||||
# EC2 Spot 请求
|
||||
aws ec2 request-spot-instances \
|
||||
--instance-count 1 \
|
||||
--type "one-time" \
|
||||
--launch-specification '{
|
||||
"ImageId": "ami-0123456789abcdef0",
|
||||
"InstanceType": "p3.2xlarge",
|
||||
"KeyName": "my-key"
|
||||
}'
|
||||
```
|
||||
|
||||
### SageMaker Managed Spot Training
|
||||
|
||||
```python
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.p3.2xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
|
||||
# 启用 Spot 实例
|
||||
use_spot_instances=True,
|
||||
max_run=3600, # 最长运行 1 小时
|
||||
max_wait=7200, # 最长等待 2 小时
|
||||
|
||||
# 检查点配置(Spot 中断恢复)
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
checkpoint_local_path="/opt/ml/checkpoints",
|
||||
|
||||
hyperparameters={
|
||||
"epochs": 100,
|
||||
"batch-size": 16,
|
||||
}
|
||||
)
|
||||
|
||||
estimator.fit({
|
||||
"training": "s3://invoice-training-data/datasets/train/",
|
||||
"validation": "s3://invoice-training-data/datasets/val/"
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 推理方案
|
||||
|
||||
### 方案对比
|
||||
|
||||
| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 |
|
||||
|------|---------|--------|--------|------|---------|
|
||||
| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 |
|
||||
| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 |
|
||||
| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 |
|
||||
| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 |
|
||||
| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 |
|
||||
| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 |
|
||||
|
||||
### 推荐方案 1: AWS Lambda (低流量)
|
||||
|
||||
对于 YOLO CPU 推理,Lambda 最经济:
|
||||
|
||||
```python
|
||||
# lambda_function.py
|
||||
import json
|
||||
import boto3
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 模型在 Lambda Layer 或 /tmp 加载
|
||||
model = None
|
||||
|
||||
def load_model():
|
||||
global model
|
||||
if model is None:
|
||||
# 从 S3 下载模型到 /tmp
|
||||
s3 = boto3.client('s3')
|
||||
s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt')
|
||||
model = YOLO('/tmp/best.pt')
|
||||
return model
|
||||
|
||||
def lambda_handler(event, context):
|
||||
model = load_model()
|
||||
|
||||
# 从 S3 获取图片
|
||||
s3 = boto3.client('s3')
|
||||
bucket = event['bucket']
|
||||
key = event['key']
|
||||
|
||||
local_path = f'/tmp/{key.split("/")[-1]}'
|
||||
s3.download_file(bucket, key, local_path)
|
||||
|
||||
# 执行推理
|
||||
results = model.predict(local_path, conf=0.5)
|
||||
|
||||
return {
|
||||
'statusCode': 200,
|
||||
'body': json.dumps({
|
||||
'fields': extract_fields(results),
|
||||
'confidence': get_confidence(results)
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Lambda 配置:**
|
||||
```yaml
|
||||
# serverless.yml
|
||||
service: invoice-inference
|
||||
|
||||
provider:
|
||||
name: aws
|
||||
runtime: python3.11
|
||||
timeout: 30
|
||||
memorySize: 4096 # 4GB 内存
|
||||
|
||||
functions:
|
||||
infer:
|
||||
handler: lambda_function.lambda_handler
|
||||
events:
|
||||
- http:
|
||||
path: /infer
|
||||
method: post
|
||||
layers:
|
||||
- arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||
```
|
||||
|
||||
### 推荐方案 2: ECS Fargate (中流量)
|
||||
|
||||
```yaml
|
||||
# task-definition.json
|
||||
{
|
||||
"family": "invoice-inference",
|
||||
"networkMode": "awsvpc",
|
||||
"requiresCompatibilities": ["FARGATE"],
|
||||
"cpu": "2048",
|
||||
"memory": "4096",
|
||||
"containerDefinitions": [
|
||||
{
|
||||
"name": "inference",
|
||||
"image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest",
|
||||
"portMappings": [
|
||||
{
|
||||
"containerPort": 8000,
|
||||
"protocol": "tcp"
|
||||
}
|
||||
],
|
||||
"environment": [
|
||||
{"name": "MODEL_PATH", "value": "/app/models/best.pt"}
|
||||
],
|
||||
"logConfiguration": {
|
||||
"logDriver": "awslogs",
|
||||
"options": {
|
||||
"awslogs-group": "/ecs/invoice-inference",
|
||||
"awslogs-region": "us-east-1",
|
||||
"awslogs-stream-prefix": "ecs"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Auto Scaling 配置:**
|
||||
```bash
|
||||
# 创建 Auto Scaling Target
|
||||
aws application-autoscaling register-scalable-target \
|
||||
--service-namespace ecs \
|
||||
--resource-id service/invoice-cluster/invoice-service \
|
||||
--scalable-dimension ecs:service:DesiredCount \
|
||||
--min-capacity 1 \
|
||||
--max-capacity 10
|
||||
|
||||
# 基于 CPU 使用率扩缩容
|
||||
aws application-autoscaling put-scaling-policy \
|
||||
--service-namespace ecs \
|
||||
--resource-id service/invoice-cluster/invoice-service \
|
||||
--scalable-dimension ecs:service:DesiredCount \
|
||||
--policy-name cpu-scaling \
|
||||
--policy-type TargetTrackingScaling \
|
||||
--target-tracking-scaling-policy-configuration '{
|
||||
"TargetValue": 70,
|
||||
"PredefinedMetricSpecification": {
|
||||
"PredefinedMetricType": "ECSServiceAverageCPUUtilization"
|
||||
},
|
||||
"ScaleOutCooldown": 60,
|
||||
"ScaleInCooldown": 120
|
||||
}'
|
||||
```
|
||||
|
||||
### 方案 3: SageMaker Serverless Inference
|
||||
|
||||
```python
|
||||
from sagemaker.serverless import ServerlessInferenceConfig
|
||||
from sagemaker.pytorch import PyTorchModel
|
||||
|
||||
model = PyTorchModel(
|
||||
model_data="s3://invoice-models/model.tar.gz",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
entry_point="inference.py",
|
||||
framework_version="2.0",
|
||||
py_version="py310"
|
||||
)
|
||||
|
||||
serverless_config = ServerlessInferenceConfig(
|
||||
memory_size_in_mb=4096,
|
||||
max_concurrency=10
|
||||
)
|
||||
|
||||
predictor = model.deploy(
|
||||
serverless_inference_config=serverless_config,
|
||||
endpoint_name="invoice-inference-serverless"
|
||||
)
|
||||
```
|
||||
|
||||
### 推理性能对比
|
||||
|
||||
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|
||||
|------|------------|---------|---------|
|
||||
| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) |
|
||||
| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 |
|
||||
| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 |
|
||||
| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 |
|
||||
|
||||
---
|
||||
|
||||
## 价格对比
|
||||
|
||||
### 训练成本对比(假设每天训练 2 小时)
|
||||
|
||||
| 方案 | 计算方式 | 月费 |
|
||||
|------|---------|------|
|
||||
| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
|
||||
| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 |
|
||||
| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
|
||||
| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 |
|
||||
| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 |
|
||||
|
||||
### 本项目完整成本估算
|
||||
|
||||
| 组件 | 推荐方案 | 月费 |
|
||||
|------|---------|------|
|
||||
| 数据存储 | S3 Standard (5GB) | ~$0.12 |
|
||||
| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 |
|
||||
| 推理服务 | Lambda (10K 请求/月) | ~$15 |
|
||||
| 推理服务 (替代) | ECS Fargate | ~$30 |
|
||||
| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 |
|
||||
| ECR (镜像存储) | 基本使用 | ~$1 |
|
||||
| **总计 (Lambda)** | | **~$35/月** + 训练费 |
|
||||
| **总计 (Fargate)** | | **~$50/月** + 训练费 |
|
||||
|
||||
---
|
||||
|
||||
## 推荐架构
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Amazon S3 │
|
||||
│ ├── training-images/ │
|
||||
│ ├── datasets/ │
|
||||
│ ├── models/ │
|
||||
│ └── checkpoints/ │
|
||||
└─────────────────┬───────────────────┘
|
||||
│
|
||||
┌─────────────────────────────────┼─────────────────────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
|
||||
│ 推理服务 │ │ 训练服务 │ │ API Gateway │
|
||||
│ │ │ │ │ │
|
||||
│ 方案 A: Lambda │ │ SageMaker │ │ REST API │
|
||||
│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │
|
||||
│ │ │ ~$2-5/次训练 │ │ │
|
||||
│ 方案 B: ECS Fargate │ │ │ │ │
|
||||
│ ~$30/月 │ │ - 自动启动 │ │ │
|
||||
│ │ │ - 训练完成自动停止 │ │ │
|
||||
│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │
|
||||
│ │ FastAPI + YOLO │ │ │ │ │ │
|
||||
│ │ CPU 推理 │ │ │ │ │ │
|
||||
│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘
|
||||
└───────────┬───────────┘ │
|
||||
│ │
|
||||
└───────────────────────────────┼───────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────┐
|
||||
│ Amazon RDS │
|
||||
│ PostgreSQL │
|
||||
│ db.t3.micro │
|
||||
│ ~$15/月 │
|
||||
└───────────────────────┘
|
||||
```
|
||||
|
||||
### Lambda 推理配置
|
||||
|
||||
```yaml
|
||||
# SAM template
|
||||
AWSTemplateFormatVersion: '2010-09-09'
|
||||
Transform: AWS::Serverless-2016-10-31
|
||||
|
||||
Resources:
|
||||
InferenceFunction:
|
||||
Type: AWS::Serverless::Function
|
||||
Properties:
|
||||
Handler: app.lambda_handler
|
||||
Runtime: python3.11
|
||||
MemorySize: 4096
|
||||
Timeout: 30
|
||||
Environment:
|
||||
Variables:
|
||||
MODEL_BUCKET: invoice-models
|
||||
MODEL_KEY: best.pt
|
||||
Policies:
|
||||
- S3ReadPolicy:
|
||||
BucketName: invoice-models
|
||||
- S3ReadPolicy:
|
||||
BucketName: invoice-uploads
|
||||
Events:
|
||||
InferApi:
|
||||
Type: Api
|
||||
Properties:
|
||||
Path: /infer
|
||||
Method: post
|
||||
```
|
||||
|
||||
### SageMaker 训练配置
|
||||
|
||||
```python
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge", # T4 GPU
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
|
||||
# Spot 实例配置
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
|
||||
# 检查点
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
|
||||
hyperparameters={
|
||||
"epochs": 100,
|
||||
"batch-size": 16,
|
||||
"model": "yolo11n.pt"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实施步骤
|
||||
|
||||
### 阶段 1: 存储设置
|
||||
|
||||
```bash
|
||||
# 创建 S3 桶
|
||||
aws s3 mb s3://invoice-training-data --region us-east-1
|
||||
aws s3 mb s3://invoice-models --region us-east-1
|
||||
|
||||
# 上传训练数据
|
||||
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
|
||||
|
||||
# 配置生命周期(可选,自动转冷存储)
|
||||
aws s3api put-bucket-lifecycle-configuration \
|
||||
--bucket invoice-training-data \
|
||||
--lifecycle-configuration '{
|
||||
"Rules": [{
|
||||
"ID": "MoveToIA",
|
||||
"Status": "Enabled",
|
||||
"Transitions": [{
|
||||
"Days": 30,
|
||||
"StorageClass": "STANDARD_IA"
|
||||
}]
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
### 阶段 2: 数据库设置
|
||||
|
||||
```bash
|
||||
# 创建 RDS PostgreSQL
|
||||
aws rds create-db-instance \
|
||||
--db-instance-identifier invoice-db \
|
||||
--db-instance-class db.t3.micro \
|
||||
--engine postgres \
|
||||
--engine-version 15 \
|
||||
--master-username docmaster \
|
||||
--master-user-password YOUR_PASSWORD \
|
||||
--allocated-storage 20
|
||||
|
||||
# 配置安全组
|
||||
aws ec2 authorize-security-group-ingress \
|
||||
--group-id sg-xxx \
|
||||
--protocol tcp \
|
||||
--port 5432 \
|
||||
--source-group sg-yyy
|
||||
```
|
||||
|
||||
### 阶段 3: 推理服务部署
|
||||
|
||||
**方案 A: Lambda**
|
||||
|
||||
```bash
|
||||
# 创建 Lambda Layer (依赖)
|
||||
cd lambda-layer
|
||||
pip install ultralytics opencv-python-headless -t python/
|
||||
zip -r layer.zip python/
|
||||
aws lambda publish-layer-version \
|
||||
--layer-name yolo-deps \
|
||||
--zip-file fileb://layer.zip \
|
||||
--compatible-runtimes python3.11
|
||||
|
||||
# 部署 Lambda 函数
|
||||
cd ../lambda
|
||||
zip function.zip lambda_function.py
|
||||
aws lambda create-function \
|
||||
--function-name invoice-inference \
|
||||
--runtime python3.11 \
|
||||
--handler lambda_function.lambda_handler \
|
||||
--role arn:aws:iam::123456789012:role/LambdaRole \
|
||||
--zip-file fileb://function.zip \
|
||||
--memory-size 4096 \
|
||||
--timeout 30 \
|
||||
--layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
|
||||
|
||||
# 创建 API Gateway
|
||||
aws apigatewayv2 create-api \
|
||||
--name invoice-api \
|
||||
--protocol-type HTTP \
|
||||
--target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference
|
||||
```
|
||||
|
||||
**方案 B: ECS Fargate**
|
||||
|
||||
```bash
|
||||
# 创建 ECR 仓库
|
||||
aws ecr create-repository --repository-name invoice-inference
|
||||
|
||||
# 构建并推送镜像
|
||||
aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com
|
||||
docker build -t invoice-inference .
|
||||
docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||
docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
|
||||
|
||||
# 创建 ECS 集群
|
||||
aws ecs create-cluster --cluster-name invoice-cluster
|
||||
|
||||
# 注册任务定义
|
||||
aws ecs register-task-definition --cli-input-json file://task-definition.json
|
||||
|
||||
# 创建服务
|
||||
aws ecs create-service \
|
||||
--cluster invoice-cluster \
|
||||
--service-name invoice-service \
|
||||
--task-definition invoice-inference \
|
||||
--desired-count 1 \
|
||||
--launch-type FARGATE \
|
||||
--network-configuration '{
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": ["subnet-xxx"],
|
||||
"securityGroups": ["sg-xxx"],
|
||||
"assignPublicIp": "ENABLED"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 阶段 4: 训练服务设置
|
||||
|
||||
```python
|
||||
# setup_sagemaker.py
|
||||
import boto3
|
||||
import sagemaker
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
# 创建 SageMaker 执行角色
|
||||
iam = boto3.client('iam')
|
||||
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
|
||||
|
||||
# 配置训练任务
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="./src/training",
|
||||
role=role_arn,
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
|
||||
)
|
||||
|
||||
# 保存配置供后续使用
|
||||
estimator.save("training_config.json")
|
||||
```
|
||||
|
||||
### 阶段 5: 集成训练触发 API
|
||||
|
||||
```python
|
||||
# lambda_trigger_training.py
|
||||
import boto3
|
||||
import sagemaker
|
||||
from sagemaker.pytorch import PyTorch
|
||||
|
||||
def lambda_handler(event, context):
|
||||
"""触发 SageMaker 训练任务"""
|
||||
|
||||
epochs = event.get('epochs', 100)
|
||||
|
||||
estimator = PyTorch(
|
||||
entry_point="train.py",
|
||||
source_dir="s3://invoice-training-data/code/",
|
||||
role="arn:aws:iam::123456789012:role/SageMakerRole",
|
||||
instance_count=1,
|
||||
instance_type="ml.g4dn.xlarge",
|
||||
framework_version="2.0",
|
||||
py_version="py310",
|
||||
use_spot_instances=True,
|
||||
max_run=7200,
|
||||
max_wait=14400,
|
||||
hyperparameters={
|
||||
"epochs": epochs,
|
||||
"batch-size": 16,
|
||||
}
|
||||
)
|
||||
|
||||
estimator.fit(
|
||||
inputs={
|
||||
"training": "s3://invoice-training-data/datasets/train/",
|
||||
"validation": "s3://invoice-training-data/datasets/val/"
|
||||
},
|
||||
wait=False # 异步执行
|
||||
)
|
||||
|
||||
return {
|
||||
'statusCode': 200,
|
||||
'body': {
|
||||
'training_job_name': estimator.latest_training_job.name,
|
||||
'status': 'Started'
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## AWS vs Azure 对比
|
||||
|
||||
### 服务对应关系
|
||||
|
||||
| 功能 | AWS | Azure |
|
||||
|------|-----|-------|
|
||||
| 对象存储 | S3 | Blob Storage |
|
||||
| 挂载工具 | Mountpoint for S3 | BlobFuse2 |
|
||||
| ML 平台 | SageMaker | Azure ML |
|
||||
| 容器服务 | ECS/Fargate | Container Apps |
|
||||
| Serverless | Lambda | Functions |
|
||||
| GPU VM | EC2 P3/G4dn | NC/ND 系列 |
|
||||
| 容器注册 | ECR | ACR |
|
||||
| 数据库 | RDS PostgreSQL | PostgreSQL Flexible |
|
||||
|
||||
### 价格对比
|
||||
|
||||
| 组件 | AWS | Azure |
|
||||
|------|-----|-------|
|
||||
| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 |
|
||||
| 数据库 | ~$15/月 | ~$25/月 |
|
||||
| 推理 (Serverless) | ~$15/月 | ~$30/月 |
|
||||
| 推理 (容器) | ~$30/月 | ~$30/月 |
|
||||
| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 |
|
||||
| **总计** | **~$35-50/月** | **~$65/月** |
|
||||
|
||||
### 优劣对比
|
||||
|
||||
| 方面 | AWS 优势 | Azure 优势 |
|
||||
|------|---------|-----------|
|
||||
| 价格 | Lambda 更便宜 | GPU Spot 更便宜 |
|
||||
| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 |
|
||||
| Serverless GPU | 无原生支持 | Container Apps GPU |
|
||||
| 文档 | 更丰富 | 中文文档更好 |
|
||||
| 生态 | 更大 | Office 365 集成 |
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 推荐配置
|
||||
|
||||
| 组件 | 推荐方案 | 月费估算 |
|
||||
|------|---------|---------|
|
||||
| 数据存储 | S3 Standard | ~$0.12 |
|
||||
| 数据库 | RDS db.t3.micro | ~$15 |
|
||||
| 推理服务 | Lambda 4GB | ~$15 |
|
||||
| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 |
|
||||
| ECR | 基本使用 | ~$1 |
|
||||
| **总计** | | **~$35/月** + 训练费 |
|
||||
|
||||
### 关键决策
|
||||
|
||||
| 场景 | 选择 |
|
||||
|------|------|
|
||||
| 最低成本 | Lambda + SageMaker Spot |
|
||||
| 稳定推理 | ECS Fargate |
|
||||
| GPU 推理 | ECS + EC2 GPU |
|
||||
| MLOps 集成 | SageMaker 全家桶 |
|
||||
|
||||
### 注意事项
|
||||
|
||||
1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决
|
||||
2. **Spot 中断**: 配置检查点,SageMaker 自动恢复
|
||||
3. **S3 传输**: 同区域免费,跨区域收费
|
||||
4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2
|
||||
5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本
|
||||
567
docs/azure-deployment-guide.md
Normal file
567
docs/azure-deployment-guide.md
Normal file
@@ -0,0 +1,567 @@
|
||||
# Azure 部署方案完整指南
|
||||
|
||||
## 目录
|
||||
- [核心问题](#核心问题)
|
||||
- [存储方案](#存储方案)
|
||||
- [训练方案](#训练方案)
|
||||
- [推理方案](#推理方案)
|
||||
- [价格对比](#价格对比)
|
||||
- [推荐架构](#推荐架构)
|
||||
- [实施步骤](#实施步骤)
|
||||
|
||||
---
|
||||
|
||||
## 核心问题
|
||||
|
||||
| 问题 | 答案 |
|
||||
|------|------|
|
||||
| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 |
|
||||
| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 |
|
||||
| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) |
|
||||
| VM 空闲时收费吗? | 收费,只要开机就按小时计费 |
|
||||
| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster |
|
||||
| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU |
|
||||
|
||||
---
|
||||
|
||||
## 存储方案
|
||||
|
||||
### Azure Blob Storage + BlobFuse2(推荐)
|
||||
|
||||
```bash
|
||||
# 安装 BlobFuse2
|
||||
sudo apt-get install blobfuse2
|
||||
|
||||
# 配置文件
|
||||
cat > ~/blobfuse-config.yaml << 'EOF'
|
||||
logging:
|
||||
type: syslog
|
||||
level: log_warning
|
||||
|
||||
components:
|
||||
- libfuse
|
||||
- file_cache
|
||||
- azstorage
|
||||
|
||||
file_cache:
|
||||
path: /tmp/blobfuse2
|
||||
timeout-sec: 120
|
||||
max-size-mb: 4096
|
||||
|
||||
azstorage:
|
||||
type: block
|
||||
account-name: YOUR_ACCOUNT
|
||||
account-key: YOUR_KEY
|
||||
container: training-images
|
||||
EOF
|
||||
|
||||
# 挂载
|
||||
mkdir -p /mnt/azure-blob
|
||||
blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml
|
||||
```
|
||||
|
||||
### 本地开发(Windows)
|
||||
|
||||
```powershell
|
||||
# 安装
|
||||
winget install WinFsp.WinFsp
|
||||
winget install Rclone.Rclone
|
||||
|
||||
# 配置
|
||||
rclone config # 选择 azureblob
|
||||
|
||||
# 挂载为 Z: 盘
|
||||
rclone mount azure:training-images Z: --vfs-cache-mode full
|
||||
```
|
||||
|
||||
### 存储费用
|
||||
|
||||
| 层级 | 价格 | 适用场景 |
|
||||
|------|------|---------|
|
||||
| Hot | $0.018/GB/月 | 频繁访问 |
|
||||
| Cool | $0.01/GB/月 | 偶尔访问 |
|
||||
| Archive | $0.002/GB/月 | 长期存档 |
|
||||
|
||||
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月**
|
||||
|
||||
---
|
||||
|
||||
## 训练方案
|
||||
|
||||
### 方案总览
|
||||
|
||||
| 方案 | 适用场景 | 空闲费用 | 复杂度 |
|
||||
|------|---------|---------|--------|
|
||||
| Azure VM | 简单直接 | 24/7 收费 | 低 |
|
||||
| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 |
|
||||
| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 |
|
||||
| Container Apps GPU | Serverless | 自动缩到 0 | 中 |
|
||||
|
||||
### Azure VM vs Azure ML
|
||||
|
||||
| 特性 | Azure VM | Azure ML |
|
||||
|------|----------|----------|
|
||||
| 本质 | 虚拟机 | 托管 ML 平台 |
|
||||
| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) |
|
||||
| 附加费用 | ~$5/月 | ~$20-30/月 |
|
||||
| 实验跟踪 | 无 | 内置 |
|
||||
| 自动扩缩 | 无 | 支持 min=0 |
|
||||
| 适用人群 | DevOps | 数据科学家 |
|
||||
|
||||
### Azure ML 附加费用明细
|
||||
|
||||
| 服务 | 用途 | 费用 |
|
||||
|------|------|------|
|
||||
| Container Registry | Docker 镜像 | ~$5-20/月 |
|
||||
| Blob Storage | 日志、模型 | ~$0.10/月 |
|
||||
| Application Insights | 监控 | ~$0-10/月 |
|
||||
| Key Vault | 密钥管理 | <$1/月 |
|
||||
|
||||
### Spot 实例
|
||||
|
||||
两种平台都支持 Spot/低优先级实例,最高节省 90%:
|
||||
|
||||
| 类型 | 正常价格 | Spot 价格 | 节省 |
|
||||
|------|---------|----------|------|
|
||||
| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% |
|
||||
| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% |
|
||||
|
||||
### GPU 实例价格
|
||||
|
||||
| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 |
|
||||
|------|-----|------|---------|----------|
|
||||
| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 |
|
||||
| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 |
|
||||
| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 |
|
||||
| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 |
|
||||
|
||||
---
|
||||
|
||||
## 推理方案
|
||||
|
||||
### 方案对比
|
||||
|
||||
| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 |
|
||||
|------|---------|--------|------|---------|
|
||||
| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) |
|
||||
| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 |
|
||||
| Azure App Service | 否 | 手动/自动 | ~$50/月 | 简单部署 |
|
||||
| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 |
|
||||
| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 |
|
||||
|
||||
### 推荐: Container Apps (CPU)
|
||||
|
||||
对于 YOLO 推理,**CPU 足够**,不需要 GPU:
|
||||
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
|
||||
- 比 GPU 便宜很多,适合中低流量
|
||||
|
||||
```yaml
|
||||
# Container Apps 配置
|
||||
name: invoice-inference
|
||||
image: myacr.azurecr.io/invoice-inference:v1
|
||||
resources:
|
||||
cpu: 2.0
|
||||
memory: 4Gi
|
||||
scale:
|
||||
minReplicas: 1 # 最少 1 个实例保持响应
|
||||
maxReplicas: 10 # 最多扩展到 10 个
|
||||
rules:
|
||||
- name: http-scaling
|
||||
http:
|
||||
metadata:
|
||||
concurrentRequests: "50" # 每实例 50 并发时扩容
|
||||
```
|
||||
|
||||
### 推理服务代码示例
|
||||
|
||||
```python
|
||||
# Dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制代码和模型
|
||||
COPY src/ ./src/
|
||||
COPY models/best.pt ./models/
|
||||
|
||||
# 启动服务
|
||||
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
```python
|
||||
# src/web/app.py
|
||||
from fastapi import FastAPI, UploadFile, File
|
||||
from ultralytics import YOLO
|
||||
import tempfile
|
||||
|
||||
app = FastAPI()
|
||||
model = YOLO("models/best.pt")
|
||||
|
||||
@app.post("/api/v1/infer")
|
||||
async def infer(file: UploadFile = File(...)):
|
||||
# 保存上传文件
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
# 执行推理
|
||||
results = model.predict(tmp_path, conf=0.5)
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
"fields": extract_fields(results),
|
||||
"confidence": get_confidence(results)
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
```
|
||||
|
||||
### 部署命令
|
||||
|
||||
```bash
|
||||
# 1. 创建 Container Registry
|
||||
az acr create --name invoiceacr --resource-group myRG --sku Basic
|
||||
|
||||
# 2. 构建并推送镜像
|
||||
az acr build --registry invoiceacr --image invoice-inference:v1 .
|
||||
|
||||
# 3. 创建 Container Apps 环境
|
||||
az containerapp env create \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--location eastus
|
||||
|
||||
# 4. 部署应用
|
||||
az containerapp create \
|
||||
--name invoice-inference \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference:v1 \
|
||||
--registry-server invoiceacr.azurecr.io \
|
||||
--cpu 2 --memory 4Gi \
|
||||
--min-replicas 1 --max-replicas 10 \
|
||||
--ingress external --target-port 8000
|
||||
|
||||
# 5. 获取 URL
|
||||
az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn
|
||||
```
|
||||
|
||||
### 高吞吐场景: Serverless GPU
|
||||
|
||||
如果需要 GPU 加速推理(高并发、低延迟):
|
||||
|
||||
```bash
|
||||
# 请求 GPU 配额
|
||||
az containerapp env workload-profile add \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--workload-profile-name gpu \
|
||||
--workload-profile-type Consumption-GPU-T4
|
||||
|
||||
# 部署 GPU 版本
|
||||
az containerapp create \
|
||||
--name invoice-inference-gpu \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \
|
||||
--workload-profile-name gpu \
|
||||
--cpu 4 --memory 8Gi \
|
||||
--min-replicas 0 --max-replicas 5 \
|
||||
--ingress external --target-port 8000
|
||||
```
|
||||
|
||||
### 推理性能对比
|
||||
|
||||
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|
||||
|------|------------|---------|---------|
|
||||
| CPU 2核 4GB | ~300-500ms | ~50 QPS | ~$30 |
|
||||
| CPU 4核 8GB | ~200-300ms | ~100 QPS | ~$60 |
|
||||
| GPU T4 | ~50-100ms | ~200 QPS | 按秒计费 |
|
||||
| GPU A100 | ~20-50ms | ~500 QPS | 按秒计费 |
|
||||
|
||||
---
|
||||
|
||||
## 价格对比
|
||||
|
||||
### 月度成本对比(假设每天训练 2 小时)
|
||||
|
||||
| 方案 | 计算方式 | 月费 |
|
||||
|------|---------|------|
|
||||
| VM 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
|
||||
| VM 按需启停 | 2h × 30天 × $3.06 | ~$184 |
|
||||
| VM Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
|
||||
| Serverless GPU | 2h × 30天 × ~$3.50 | ~$210 |
|
||||
| Azure ML (min=0) | 2h × 30天 × $3.06 | ~$184 |
|
||||
|
||||
### 本项目完整成本估算
|
||||
|
||||
| 组件 | 推荐方案 | 月费 |
|
||||
|------|---------|------|
|
||||
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
|
||||
| 数据库 | PostgreSQL Flexible (Burstable B1ms) | ~$25 |
|
||||
| 推理服务 | Container Apps CPU (2核4GB) | ~$30 |
|
||||
| 训练服务 | Azure ML Spot (按需) | ~$1-5/次 |
|
||||
| Container Registry | Basic | ~$5 |
|
||||
| **总计** | | **~$65/月** + 训练费 |
|
||||
|
||||
---
|
||||
|
||||
## 推荐架构
|
||||
|
||||
### 整体架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Azure Blob Storage │
|
||||
│ ├── training-images/ │
|
||||
│ ├── datasets/ │
|
||||
│ └── models/ │
|
||||
└─────────────────┬───────────────────┘
|
||||
│
|
||||
┌─────────────────────────────────┼─────────────────────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
|
||||
│ 推理服务 (24/7) │ │ 训练服务 (按需) │ │ Web UI (可选) │
|
||||
│ Container Apps │ │ Azure ML Compute │ │ Static Web Apps │
|
||||
│ CPU 2核 4GB │ │ min=0, Spot │ │ ~$0 (免费层) │
|
||||
│ ~$30/月 │ │ ~$1-5/次训练 │ │ │
|
||||
│ │ │ │ │ │
|
||||
│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │
|
||||
│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │
|
||||
│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │
|
||||
│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │
|
||||
└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘
|
||||
│ │ │
|
||||
└───────────────────────────────┼───────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────┐
|
||||
│ PostgreSQL │
|
||||
│ Flexible Server │
|
||||
│ Burstable B1ms │
|
||||
│ ~$25/月 │
|
||||
└───────────────────────┘
|
||||
```
|
||||
|
||||
### 推理服务配置
|
||||
|
||||
```yaml
|
||||
# Container Apps - CPU (24/7 运行)
|
||||
name: invoice-inference
|
||||
resources:
|
||||
cpu: 2
|
||||
memory: 4Gi
|
||||
scale:
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
env:
|
||||
- name: MODEL_PATH
|
||||
value: /app/models/best.pt
|
||||
- name: DB_HOST
|
||||
secretRef: db-host
|
||||
- name: DB_PASSWORD
|
||||
secretRef: db-password
|
||||
```
|
||||
|
||||
### 训练服务配置
|
||||
|
||||
**方案 A: Azure ML Compute(推荐)**
|
||||
|
||||
```python
|
||||
from azure.ai.ml.entities import AmlCompute
|
||||
|
||||
gpu_cluster = AmlCompute(
|
||||
name="gpu-cluster",
|
||||
size="Standard_NC6s_v3",
|
||||
min_instances=0, # 空闲时关机
|
||||
max_instances=1,
|
||||
tier="LowPriority", # Spot 实例
|
||||
idle_time_before_scale_down=120
|
||||
)
|
||||
```
|
||||
|
||||
**方案 B: Container Apps Serverless GPU**
|
||||
|
||||
```yaml
|
||||
name: invoice-training
|
||||
resources:
|
||||
gpu: 1
|
||||
gpuType: A100
|
||||
scale:
|
||||
minReplicas: 0
|
||||
maxReplicas: 1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实施步骤
|
||||
|
||||
### 阶段 1: 存储设置
|
||||
|
||||
```bash
|
||||
# 创建 Storage Account
|
||||
az storage account create \
|
||||
--name invoicestorage \
|
||||
--resource-group myRG \
|
||||
--sku Standard_LRS
|
||||
|
||||
# 创建容器
|
||||
az storage container create --name training-images --account-name invoicestorage
|
||||
az storage container create --name datasets --account-name invoicestorage
|
||||
az storage container create --name models --account-name invoicestorage
|
||||
|
||||
# 上传训练数据
|
||||
az storage blob upload-batch \
|
||||
--destination training-images \
|
||||
--source ./data/dataset/temp \
|
||||
--account-name invoicestorage
|
||||
```
|
||||
|
||||
### 阶段 2: 数据库设置
|
||||
|
||||
```bash
|
||||
# 创建 PostgreSQL
|
||||
az postgres flexible-server create \
|
||||
--name invoice-db \
|
||||
--resource-group myRG \
|
||||
--sku-name Standard_B1ms \
|
||||
--storage-size 32 \
|
||||
--admin-user docmaster \
|
||||
--admin-password YOUR_PASSWORD
|
||||
|
||||
# 配置防火墙
|
||||
az postgres flexible-server firewall-rule create \
|
||||
--name allow-azure \
|
||||
--resource-group myRG \
|
||||
--server-name invoice-db \
|
||||
--start-ip-address 0.0.0.0 \
|
||||
--end-ip-address 0.0.0.0
|
||||
```
|
||||
|
||||
### 阶段 3: 推理服务部署
|
||||
|
||||
```bash
|
||||
# 创建 Container Registry
|
||||
az acr create --name invoiceacr --resource-group myRG --sku Basic
|
||||
|
||||
# 构建镜像
|
||||
az acr build --registry invoiceacr --image invoice-inference:v1 .
|
||||
|
||||
# 创建环境
|
||||
az containerapp env create \
|
||||
--name invoice-env \
|
||||
--resource-group myRG \
|
||||
--location eastus
|
||||
|
||||
# 部署推理服务
|
||||
az containerapp create \
|
||||
--name invoice-inference \
|
||||
--resource-group myRG \
|
||||
--environment invoice-env \
|
||||
--image invoiceacr.azurecr.io/invoice-inference:v1 \
|
||||
--registry-server invoiceacr.azurecr.io \
|
||||
--cpu 2 --memory 4Gi \
|
||||
--min-replicas 1 --max-replicas 10 \
|
||||
--ingress external --target-port 8000 \
|
||||
--env-vars \
|
||||
DB_HOST=invoice-db.postgres.database.azure.com \
|
||||
DB_NAME=docmaster \
|
||||
DB_USER=docmaster \
|
||||
--secrets db-password=YOUR_PASSWORD
|
||||
```
|
||||
|
||||
### 阶段 4: 训练服务设置
|
||||
|
||||
```bash
|
||||
# 创建 Azure ML Workspace
|
||||
az ml workspace create --name invoice-ml --resource-group myRG
|
||||
|
||||
# 创建 Compute Cluster
|
||||
az ml compute create --name gpu-cluster \
|
||||
--type AmlCompute \
|
||||
--size Standard_NC6s_v3 \
|
||||
--min-instances 0 \
|
||||
--max-instances 1 \
|
||||
--tier low_priority
|
||||
```
|
||||
|
||||
### 阶段 5: 集成训练触发 API
|
||||
|
||||
```python
|
||||
# src/web/routes/training.py
|
||||
from fastapi import APIRouter
|
||||
from azure.ai.ml import MLClient, command
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ml_client = MLClient(
|
||||
credential=DefaultAzureCredential(),
|
||||
subscription_id="your-subscription-id",
|
||||
resource_group_name="myRG",
|
||||
workspace_name="invoice-ml"
|
||||
)
|
||||
|
||||
@router.post("/api/v1/train")
|
||||
async def trigger_training(request: TrainingRequest):
|
||||
"""触发 Azure ML 训练任务"""
|
||||
training_job = command(
|
||||
code="./training",
|
||||
command=f"python train.py --epochs {request.epochs}",
|
||||
environment="AzureML-pytorch-2.0-cuda11.8@latest",
|
||||
compute="gpu-cluster",
|
||||
)
|
||||
job = ml_client.jobs.create_or_update(training_job)
|
||||
return {
|
||||
"job_id": job.name,
|
||||
"status": job.status,
|
||||
"studio_url": job.studio_url
|
||||
}
|
||||
|
||||
@router.get("/api/v1/train/{job_id}/status")
|
||||
async def get_training_status(job_id: str):
|
||||
"""查询训练状态"""
|
||||
job = ml_client.jobs.get(job_id)
|
||||
return {"status": job.status}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### 推荐配置
|
||||
|
||||
| 组件 | 推荐方案 | 月费估算 |
|
||||
|------|---------|---------|
|
||||
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
|
||||
| 数据库 | PostgreSQL Flexible | ~$25 |
|
||||
| 推理服务 | Container Apps CPU | ~$30 |
|
||||
| 训练服务 | Azure ML (min=0, Spot) | 按需 ~$1-5/次 |
|
||||
| Container Registry | Basic | ~$5 |
|
||||
| **总计** | | **~$65/月** + 训练费 |
|
||||
|
||||
### 关键决策
|
||||
|
||||
| 场景 | 选择 |
|
||||
|------|------|
|
||||
| 偶尔训练,简单需求 | Azure VM Spot + 手动启停 |
|
||||
| 需要 MLOps,团队协作 | Azure ML Compute |
|
||||
| 追求最低空闲成本 | Container Apps Serverless GPU |
|
||||
| 生产环境推理 | Container Apps CPU |
|
||||
| 高并发推理 | Container Apps Serverless GPU |
|
||||
|
||||
### 注意事项
|
||||
|
||||
1. **冷启动**: Serverless GPU 启动需要 3-8 分钟
|
||||
2. **Spot 中断**: 可能被抢占,需要检查点机制
|
||||
3. **网络延迟**: Blob Storage 挂载比本地 SSD 慢,建议开启缓存
|
||||
4. **区域选择**: 选择有 GPU 配额的区域 (East US, West Europe 等)
|
||||
5. **推理优化**: CPU 推理对于 YOLO 已经足够,无需 GPU
|
||||
647
docs/dashboard-design-spec.md
Normal file
647
docs/dashboard-design-spec.md
Normal file
@@ -0,0 +1,647 @@
|
||||
# Dashboard Design Specification
|
||||
|
||||
## Overview
|
||||
|
||||
Dashboard 是用户进入系统后的第一个页面,用于快速了解:
|
||||
- 数据标注质量和进度
|
||||
- 当前模型状态和性能
|
||||
- 系统最近发生的活动
|
||||
|
||||
**目标用户**:使用文档标注系统的客户,需要监控文档处理状态、标注质量和模型训练进度。
|
||||
|
||||
---
|
||||
|
||||
## 1. UI Layout
|
||||
|
||||
### 1.1 Overall Structure
|
||||
|
||||
```
|
||||
+------------------------------------------------------------------+
|
||||
| Header: Logo + Navigation + User Menu |
|
||||
+------------------------------------------------------------------+
|
||||
| |
|
||||
| Stats Cards Row (4 cards, equal width) |
|
||||
| |
|
||||
| +---------------------------+ +------------------------------+ |
|
||||
| | Data Quality Panel (50%) | | Active Model Panel (50%) | |
|
||||
| +---------------------------+ +------------------------------+ |
|
||||
| |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| | Recent Activity Panel (full width) | |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| |
|
||||
| +--------------------------------------------------------------+ |
|
||||
| | System Status Bar (full width) | |
|
||||
| +--------------------------------------------------------------+ |
|
||||
+------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 1.2 Responsive Breakpoints
|
||||
|
||||
| Breakpoint | Layout |
|
||||
|------------|--------|
|
||||
| Desktop (>1200px) | 4 cards row, 2-column panels |
|
||||
| Tablet (768-1200px) | 2x2 cards, 2-column panels |
|
||||
| Mobile (<768px) | 1 card per row, stacked panels |
|
||||
|
||||
---
|
||||
|
||||
## 2. Component Specifications
|
||||
|
||||
### 2.1 Stats Cards Row
|
||||
|
||||
4 个等宽卡片,显示核心统计数据。
|
||||
|
||||
```
|
||||
+-------------+ +-------------+ +-------------+ +-------------+
|
||||
| [icon] | | [icon] | | [icon] | | [icon] |
|
||||
| 38 | | 25 | | 8 | | 5 |
|
||||
| Total Docs | | Complete | | Incomplete | | Pending |
|
||||
+-------------+ +-------------+ +-------------+ +-------------+
|
||||
```
|
||||
|
||||
| Card | Icon | Value | Label | Color | Click Action |
|
||||
|------|------|-------|-------|-------|--------------|
|
||||
| Total Documents | FileText | `total_documents` | "Total Documents" | Gray | Navigate to Documents page |
|
||||
| Complete | CheckCircle | `annotation_complete` | "Complete" | Green | Navigate to Documents (filter: complete) |
|
||||
| Incomplete | AlertCircle | `annotation_incomplete` | "Incomplete" | Orange | Navigate to Documents (filter: incomplete) |
|
||||
| Pending | Clock | `pending` | "Pending" | Blue | Navigate to Documents (filter: pending) |
|
||||
|
||||
**Card Design:**
|
||||
- Background: White with subtle border
|
||||
- Icon: 24px, positioned top-left
|
||||
- Value: 32px bold font
|
||||
- Label: 14px muted color
|
||||
- Hover: Slight shadow elevation
|
||||
- Padding: 16px
|
||||
|
||||
### 2.2 Data Quality Panel
|
||||
|
||||
左侧面板,显示标注完整度和质量指标。
|
||||
|
||||
```
|
||||
+---------------------------+
|
||||
| DATA QUALITY |
|
||||
| +-----------+ |
|
||||
| | | |
|
||||
| | 78% | Annotation |
|
||||
| | | Complete |
|
||||
| +-----------+ |
|
||||
| |
|
||||
| Complete: 25 |
|
||||
| Incomplete: 8 |
|
||||
| Pending: 5 |
|
||||
| |
|
||||
| [View Incomplete Docs] |
|
||||
+---------------------------+
|
||||
```
|
||||
|
||||
**Components:**
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Title | "DATA QUALITY", 14px uppercase, muted |
|
||||
| Progress Ring | 120px diameter, stroke width 12px |
|
||||
| Percentage | 36px bold, centered in ring |
|
||||
| Label | "Annotation Complete", 14px, below ring |
|
||||
| Stats List | 14px, icon + label + value per row |
|
||||
| Action Button | Text button, primary color |
|
||||
|
||||
**Progress Ring Colors:**
|
||||
- Complete portion: Green (#22C55E)
|
||||
- Remaining: Gray (#E5E7EB)
|
||||
|
||||
**Completeness Calculation:**
|
||||
```
|
||||
completeness_rate = annotation_complete / (annotation_complete + annotation_incomplete) * 100
|
||||
```
|
||||
|
||||
### 2.3 Active Model Panel
|
||||
|
||||
右侧面板,显示当前生产模型信息。
|
||||
|
||||
```
|
||||
+-------------------------------+
|
||||
| ACTIVE MODEL |
|
||||
| |
|
||||
| v1.2.0 - Invoice Model |
|
||||
| ----------------------------- |
|
||||
| |
|
||||
| mAP Precision Recall |
|
||||
| 95.1% 94% 92% |
|
||||
| |
|
||||
| Activated: 2024-01-20 |
|
||||
| Documents: 500 |
|
||||
| |
|
||||
| [Training] Run-2024-02 [====] |
|
||||
+-------------------------------+
|
||||
```
|
||||
|
||||
**Components:**
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Title | "ACTIVE MODEL", 14px uppercase, muted |
|
||||
| Version + Name | 18px bold (version) + 16px regular (name) |
|
||||
| Divider | 1px border, full width |
|
||||
| Metrics Row | 3 columns, equal width |
|
||||
| Metric Value | 24px bold |
|
||||
| Metric Label | 12px muted, below value |
|
||||
| Info Rows | 14px, label: value format |
|
||||
| Training Indicator | Shows when training is running |
|
||||
|
||||
**Metric Colors:**
|
||||
- mAP >= 90%: Green
|
||||
- mAP 80-90%: Yellow
|
||||
- mAP < 80%: Red
|
||||
|
||||
**Empty State (No Active Model):**
|
||||
```
|
||||
+-------------------------------+
|
||||
| ACTIVE MODEL |
|
||||
| |
|
||||
| [icon: Model] |
|
||||
| No Active Model |
|
||||
| |
|
||||
| Train and activate a |
|
||||
| model to see stats here |
|
||||
| |
|
||||
| [Go to Training] |
|
||||
+-------------------------------+
|
||||
```
|
||||
|
||||
**Training In Progress:**
|
||||
```
|
||||
| Training: Run-2024-02 |
|
||||
| [=========> ] 45% |
|
||||
| Started 2 hours ago |
|
||||
```
|
||||
|
||||
### 2.4 Recent Activity Panel
|
||||
|
||||
全宽面板,显示最近 10 条系统活动。
|
||||
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| RECENT ACTIVITY [See All] |
|
||||
+--------------------------------------------------------------+
|
||||
| [rocket] Activated model v1.2.0 2 hours ago|
|
||||
| [check] Training complete: Run-2024-01, mAP 95.1% yesterday|
|
||||
| [edit] Modified INV-001.pdf invoice_number yesterday|
|
||||
| [doc] Uploaded INV-005.pdf 2 days ago|
|
||||
| [doc] Uploaded INV-004.pdf 2 days ago|
|
||||
| [x] Training failed: Run-2024-00 3 days ago|
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
**Activity Item Layout:**
|
||||
|
||||
```
|
||||
[Icon] [Description] [Timestamp]
|
||||
```
|
||||
|
||||
| Element | Spec |
|
||||
|---------|------|
|
||||
| Icon | 16px, color based on type |
|
||||
| Description | 14px, truncate if too long |
|
||||
| Timestamp | 12px muted, right-aligned |
|
||||
| Row Height | 40px |
|
||||
| Hover | Background highlight |
|
||||
|
||||
**Activity Types and Icons:**
|
||||
|
||||
| Type | Icon | Color | Description Format |
|
||||
|------|------|-------|-------------------|
|
||||
| document_uploaded | FileText | Blue | "Uploaded {filename}" |
|
||||
| annotation_modified | Edit | Orange | "Modified {filename} {field_name}" |
|
||||
| training_completed | CheckCircle | Green | "Training complete: {task_name}, mAP {mAP}%" |
|
||||
| training_failed | XCircle | Red | "Training failed: {task_name}" |
|
||||
| model_activated | Rocket | Purple | "Activated model {version}" |
|
||||
|
||||
**Timestamp Formatting:**
|
||||
- < 1 minute: "just now"
|
||||
- < 1 hour: "{n} minutes ago"
|
||||
- < 24 hours: "{n} hours ago"
|
||||
- < 7 days: "yesterday" / "{n} days ago"
|
||||
- >= 7 days: "Jan 15" (date format)
|
||||
|
||||
**Empty State:**
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| RECENT ACTIVITY |
|
||||
| |
|
||||
| [icon: Activity] |
|
||||
| No recent activity |
|
||||
| |
|
||||
| Start by uploading documents or creating training jobs |
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 2.5 System Status Bar
|
||||
|
||||
底部状态栏,显示系统健康状态。
|
||||
|
||||
```
|
||||
+--------------------------------------------------------------+
|
||||
| Backend API: [*] Online Database: [*] Connected GPU: [*] Available |
|
||||
+--------------------------------------------------------------+
|
||||
```
|
||||
|
||||
| Status | Icon | Color |
|
||||
|--------|------|-------|
|
||||
| Online/Connected/Available | Filled circle | Green |
|
||||
| Degraded/Slow | Filled circle | Yellow |
|
||||
| Offline/Error/Unavailable | Filled circle | Red |
|
||||
|
||||
---
|
||||
|
||||
## 3. API Endpoints
|
||||
|
||||
### 3.1 Dashboard Statistics
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/stats
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"total_documents": 38,
|
||||
"annotation_complete": 25,
|
||||
"annotation_incomplete": 8,
|
||||
"pending": 5,
|
||||
"completeness_rate": 75.76
|
||||
}
|
||||
```
|
||||
|
||||
**Calculation Logic:**
|
||||
|
||||
```python
|
||||
# annotation_complete: labeled documents with core fields
|
||||
SELECT COUNT(*) FROM admin_documents d
|
||||
WHERE d.status = 'labeled'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM admin_annotations a
|
||||
WHERE a.document_id = d.document_id
|
||||
AND a.class_id IN (0, 3) -- invoice_number OR ocr_number
|
||||
)
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM admin_annotations a
|
||||
WHERE a.document_id = d.document_id
|
||||
AND a.class_id IN (4, 5) -- bankgiro OR plusgiro
|
||||
)
|
||||
|
||||
# annotation_incomplete: labeled but missing core fields
|
||||
SELECT COUNT(*) FROM admin_documents d
|
||||
WHERE d.status = 'labeled'
|
||||
AND NOT (/* above conditions */)
|
||||
|
||||
# pending: pending + auto_labeling
|
||||
SELECT COUNT(*) FROM admin_documents
|
||||
WHERE status IN ('pending', 'auto_labeling')
|
||||
```
|
||||
|
||||
### 3.2 Active Model Info
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/active-model
|
||||
```
|
||||
|
||||
**Response (with active model):**
|
||||
```json
|
||||
{
|
||||
"model": {
|
||||
"version_id": "uuid",
|
||||
"version": "1.2.0",
|
||||
"name": "Invoice Model",
|
||||
"metrics_mAP": 0.951,
|
||||
"metrics_precision": 0.94,
|
||||
"metrics_recall": 0.92,
|
||||
"document_count": 500,
|
||||
"activated_at": "2024-01-20T15:00:00Z"
|
||||
},
|
||||
"running_training": {
|
||||
"task_id": "uuid",
|
||||
"name": "Run-2024-02",
|
||||
"status": "running",
|
||||
"started_at": "2024-01-25T10:00:00Z",
|
||||
"progress": 45
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response (no active model):**
|
||||
```json
|
||||
{
|
||||
"model": null,
|
||||
"running_training": null
|
||||
}
|
||||
```
|
||||
|
||||
### 3.3 Recent Activity
|
||||
|
||||
```
|
||||
GET /api/v1/admin/dashboard/activity?limit=10
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"activities": [
|
||||
{
|
||||
"type": "model_activated",
|
||||
"description": "Activated model v1.2.0",
|
||||
"timestamp": "2024-01-25T12:00:00Z",
|
||||
"metadata": {
|
||||
"version_id": "uuid",
|
||||
"version": "1.2.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "training_completed",
|
||||
"description": "Training complete: Run-2024-01, mAP 95.1%",
|
||||
"timestamp": "2024-01-24T18:30:00Z",
|
||||
"metadata": {
|
||||
"task_id": "uuid",
|
||||
"task_name": "Run-2024-01",
|
||||
"mAP": 0.951
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Activity Aggregation Query:**
|
||||
|
||||
```sql
|
||||
-- Union all activity sources, ordered by timestamp DESC, limit 10
|
||||
(
|
||||
SELECT 'document_uploaded' as type,
|
||||
filename as entity_name,
|
||||
created_at as timestamp,
|
||||
document_id as entity_id
|
||||
FROM admin_documents
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT 'annotation_modified' as type,
|
||||
-- join to get filename and field name
|
||||
...
|
||||
FROM annotation_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT CASE WHEN status = 'completed' THEN 'training_completed'
|
||||
WHEN status = 'failed' THEN 'training_failed' END as type,
|
||||
name as entity_name,
|
||||
completed_at as timestamp,
|
||||
task_id as entity_id
|
||||
FROM training_tasks
|
||||
WHERE status IN ('completed', 'failed')
|
||||
ORDER BY completed_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT 'model_activated' as type,
|
||||
version as entity_name,
|
||||
activated_at as timestamp,
|
||||
version_id as entity_id
|
||||
FROM model_versions
|
||||
WHERE activated_at IS NOT NULL
|
||||
ORDER BY activated_at DESC
|
||||
LIMIT 10
|
||||
)
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. UX Interactions
|
||||
|
||||
### 4.1 Loading States
|
||||
|
||||
| Component | Loading State |
|
||||
|-----------|--------------|
|
||||
| Stats Cards | Skeleton placeholder (gray boxes) |
|
||||
| Data Quality Ring | Skeleton circle |
|
||||
| Active Model | Skeleton lines |
|
||||
| Recent Activity | Skeleton list items (5 rows) |
|
||||
|
||||
**Loading Duration Thresholds:**
|
||||
- < 300ms: No loading state shown
|
||||
- 300ms - 3s: Show skeleton
|
||||
- > 3s: Show skeleton + "Taking longer than expected" message
|
||||
|
||||
### 4.2 Error States
|
||||
|
||||
| Error Type | Display |
|
||||
|------------|---------|
|
||||
| API Error | Toast notification + retry button in affected panel |
|
||||
| Network Error | Full page overlay with retry option |
|
||||
| Partial Failure | Show available data, error badge on failed sections |
|
||||
|
||||
### 4.3 Refresh Behavior
|
||||
|
||||
| Trigger | Behavior |
|
||||
|---------|----------|
|
||||
| Page Load | Fetch all data |
|
||||
| Manual Refresh | Button in header, refetch all |
|
||||
| Auto Refresh | Every 30 seconds for activity panel |
|
||||
| Focus Return | Refetch if page was hidden > 5 minutes |
|
||||
|
||||
### 4.4 Click Actions
|
||||
|
||||
| Element | Action |
|
||||
|---------|--------|
|
||||
| Total Documents card | Navigate to `/documents` |
|
||||
| Complete card | Navigate to `/documents?filter=complete` |
|
||||
| Incomplete card | Navigate to `/documents?filter=incomplete` |
|
||||
| Pending card | Navigate to `/documents?filter=pending` |
|
||||
| "View Incomplete Docs" button | Navigate to `/documents?filter=incomplete` |
|
||||
| Activity item | Navigate to related entity |
|
||||
| "Go to Training" button | Navigate to `/training` |
|
||||
| Active Model version | Navigate to `/models/{version_id}` |
|
||||
|
||||
### 4.5 Tooltips
|
||||
|
||||
| Element | Tooltip Content |
|
||||
|---------|----------------|
|
||||
| Completeness % | "25 of 33 labeled documents have complete annotations" |
|
||||
| mAP metric | "Mean Average Precision at IoU 0.5" |
|
||||
| Precision metric | "Proportion of correct positive predictions" |
|
||||
| Recall metric | "Proportion of actual positives correctly identified" |
|
||||
| Incomplete count | "Documents labeled but missing invoice_number/ocr_number or bankgiro/plusgiro" |
|
||||
|
||||
---
|
||||
|
||||
## 5. Data Model
|
||||
|
||||
### 5.1 TypeScript Types
|
||||
|
||||
```typescript
|
||||
// Dashboard Stats
|
||||
interface DashboardStats {
|
||||
total_documents: number;
|
||||
annotation_complete: number;
|
||||
annotation_incomplete: number;
|
||||
pending: number;
|
||||
completeness_rate: number;
|
||||
}
|
||||
|
||||
// Active Model
|
||||
interface ActiveModelInfo {
|
||||
model: ModelVersion | null;
|
||||
running_training: RunningTraining | null;
|
||||
}
|
||||
|
||||
interface ModelVersion {
|
||||
version_id: string;
|
||||
version: string;
|
||||
name: string;
|
||||
metrics_mAP: number;
|
||||
metrics_precision: number;
|
||||
metrics_recall: number;
|
||||
document_count: number;
|
||||
activated_at: string;
|
||||
}
|
||||
|
||||
interface RunningTraining {
|
||||
task_id: string;
|
||||
name: string;
|
||||
status: 'running';
|
||||
started_at: string;
|
||||
progress: number;
|
||||
}
|
||||
|
||||
// Activity
|
||||
interface Activity {
|
||||
type: ActivityType;
|
||||
description: string;
|
||||
timestamp: string;
|
||||
metadata: Record<string, unknown>;
|
||||
}
|
||||
|
||||
type ActivityType =
|
||||
| 'document_uploaded'
|
||||
| 'annotation_modified'
|
||||
| 'training_completed'
|
||||
| 'training_failed'
|
||||
| 'model_activated';
|
||||
|
||||
// Activity Response
|
||||
interface ActivityResponse {
|
||||
activities: Activity[];
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 React Query Hooks
|
||||
|
||||
```typescript
|
||||
// useDashboardStats
|
||||
const useDashboardStats = () => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'stats'],
|
||||
queryFn: () => api.get('/admin/dashboard/stats'),
|
||||
refetchInterval: 30000, // 30 seconds
|
||||
});
|
||||
};
|
||||
|
||||
// useActiveModel
|
||||
const useActiveModel = () => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'active-model'],
|
||||
queryFn: () => api.get('/admin/dashboard/active-model'),
|
||||
refetchInterval: 60000, // 1 minute
|
||||
});
|
||||
};
|
||||
|
||||
// useRecentActivity
|
||||
const useRecentActivity = (limit = 10) => {
|
||||
return useQuery({
|
||||
queryKey: ['dashboard', 'activity', limit],
|
||||
queryFn: () => api.get(`/admin/dashboard/activity?limit=${limit}`),
|
||||
refetchInterval: 30000,
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Annotation Completeness Definition
|
||||
|
||||
### 6.1 Core Fields
|
||||
|
||||
A document is **complete** when it has annotations for:
|
||||
|
||||
| Requirement | Fields | Logic |
|
||||
|-------------|--------|-------|
|
||||
| Identifier | `invoice_number` (class_id=0) OR `ocr_number` (class_id=3) | At least one |
|
||||
| Payment Account | `bankgiro` (class_id=4) OR `plusgiro` (class_id=5) | At least one |
|
||||
|
||||
### 6.2 Status Categories
|
||||
|
||||
| Category | Criteria |
|
||||
|----------|----------|
|
||||
| **Complete** | status=labeled AND has identifier AND has payment account |
|
||||
| **Incomplete** | status=labeled AND (missing identifier OR missing payment account) |
|
||||
| **Pending** | status IN (pending, auto_labeling) |
|
||||
|
||||
### 6.3 Filter Implementation
|
||||
|
||||
```sql
|
||||
-- Complete documents
|
||||
WHERE status = 'labeled'
|
||||
AND document_id IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
|
||||
)
|
||||
AND document_id IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
|
||||
)
|
||||
|
||||
-- Incomplete documents
|
||||
WHERE status = 'labeled'
|
||||
AND (
|
||||
document_id NOT IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
|
||||
)
|
||||
OR document_id NOT IN (
|
||||
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Implementation Checklist
|
||||
|
||||
### Backend
|
||||
- [ ] Create `/api/v1/admin/dashboard/stats` endpoint
|
||||
- [ ] Create `/api/v1/admin/dashboard/active-model` endpoint
|
||||
- [ ] Create `/api/v1/admin/dashboard/activity` endpoint
|
||||
- [ ] Add completeness calculation logic to document repository
|
||||
- [ ] Implement activity aggregation query
|
||||
|
||||
### Frontend
|
||||
- [ ] Create `DashboardOverview` component
|
||||
- [ ] Create `StatsCard` component
|
||||
- [ ] Create `DataQualityPanel` component with progress ring
|
||||
- [ ] Create `ActiveModelPanel` component
|
||||
- [ ] Create `RecentActivityPanel` component
|
||||
- [ ] Create `SystemStatusBar` component
|
||||
- [ ] Add React Query hooks for dashboard data
|
||||
- [ ] Implement loading skeletons
|
||||
- [ ] Implement error states
|
||||
- [ ] Add navigation actions
|
||||
- [ ] Add tooltips
|
||||
|
||||
### Testing
|
||||
- [ ] Unit tests for completeness calculation
|
||||
- [ ] Unit tests for activity aggregation
|
||||
- [ ] Integration tests for dashboard endpoints
|
||||
- [ ] E2E tests for dashboard interactions
|
||||
@@ -1,619 +0,0 @@
|
||||
# 多池处理架构设计文档
|
||||
|
||||
## 1. 研究总结
|
||||
|
||||
### 1.1 当前问题分析
|
||||
|
||||
我们之前实现的双池模式存在稳定性问题,主要原因:
|
||||
|
||||
| 问题 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| 处理卡住 | 线程 + ProcessPoolExecutor 混用导致死锁 | 使用 asyncio 或纯 Queue 模式 |
|
||||
| Queue.get() 无限阻塞 | 没有超时机制 | 添加 timeout 和哨兵值 |
|
||||
| GPU 内存冲突 | 多进程同时访问 GPU | 限制 GPU worker = 1 |
|
||||
| CUDA fork 问题 | Linux 默认 fork 不兼容 CUDA | 使用 spawn 启动方式 |
|
||||
|
||||
### 1.2 推荐架构方案
|
||||
|
||||
经过研究,最适合我们场景的方案是 **生产者-消费者队列模式**:
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Main Process │ │ CPU Workers │ │ GPU Worker │
|
||||
│ │ │ (4 processes) │ │ (1 process) │
|
||||
│ ┌───────────┐ │ │ │ │ │
|
||||
│ │ Task │──┼────▶│ Text PDF处理 │ │ Scanned PDF处理 │
|
||||
│ │ Dispatcher│ │ │ (无需OCR) │ │ (PaddleOCR) │
|
||||
│ └───────────┘ │ │ │ │ │
|
||||
│ ▲ │ │ │ │ │ │ │
|
||||
│ │ │ │ ▼ │ │ ▼ │
|
||||
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
|
||||
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
|
||||
│ │ Collector │ │ │ │ │ │
|
||||
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────┐ │
|
||||
│ │ Database │ │
|
||||
│ │ Batch │ │
|
||||
│ │ Writer │ │
|
||||
│ └───────────┘ │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心设计原则
|
||||
|
||||
### 2.1 CUDA 兼容性
|
||||
|
||||
```python
|
||||
# 关键:使用 spawn 启动方式
|
||||
import multiprocessing as mp
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
# GPU worker 初始化时设置设备
|
||||
def init_gpu_worker(gpu_id: int = 0):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
global _ocr
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=True, ...)
|
||||
```
|
||||
|
||||
### 2.2 Worker 初始化模式
|
||||
|
||||
使用 `initializer` 参数一次性加载模型,避免每个任务重新加载:
|
||||
|
||||
```python
|
||||
# 全局变量保存模型
|
||||
_ocr = None
|
||||
|
||||
def init_worker(use_gpu: bool, gpu_id: int = 0):
|
||||
global _ocr
|
||||
if use_gpu:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
|
||||
|
||||
# 创建 Pool 时使用 initializer
|
||||
pool = ProcessPoolExecutor(
|
||||
max_workers=1,
|
||||
initializer=init_worker,
|
||||
initargs=(True, 0), # use_gpu=True, gpu_id=0
|
||||
mp_context=mp.get_context("spawn")
|
||||
)
|
||||
```
|
||||
|
||||
### 2.3 队列模式 vs as_completed
|
||||
|
||||
| 方式 | 优点 | 缺点 | 适用场景 |
|
||||
|------|------|------|----------|
|
||||
| `as_completed()` | 简单、无需管理队列 | 无法跨多个 Pool 使用 | 单池场景 |
|
||||
| `multiprocessing.Queue` | 高性能、灵活 | 需要手动管理、死锁风险 | 多池流水线 |
|
||||
| `Manager().Queue()` | 可 pickle、跨 Pool | 性能较低 | 需要 Pool.map 场景 |
|
||||
|
||||
**推荐**:对于双池场景,使用 `as_completed()` 分别处理每个池,然后合并结果。
|
||||
|
||||
---
|
||||
|
||||
## 3. 详细开发计划
|
||||
|
||||
### 阶段 1:重构基础架构 (2-3天)
|
||||
|
||||
#### 1.1 创建 WorkerPool 抽象类
|
||||
|
||||
```python
|
||||
# src/processing/worker_pool.py
|
||||
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Optional, Callable
|
||||
import multiprocessing as mp
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""任务结果容器"""
|
||||
task_id: str
|
||||
success: bool
|
||||
data: Any
|
||||
error: Optional[str] = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
class WorkerPool(ABC):
|
||||
"""Worker Pool 抽象基类"""
|
||||
|
||||
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
|
||||
self.max_workers = max_workers
|
||||
self.use_gpu = use_gpu
|
||||
self.gpu_id = gpu_id
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_initializer(self) -> Callable:
|
||||
"""返回 worker 初始化函数"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_init_args(self) -> tuple:
|
||||
"""返回初始化参数"""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""启动 worker pool"""
|
||||
ctx = mp.get_context("spawn")
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=self.max_workers,
|
||||
mp_context=ctx,
|
||||
initializer=self.get_initializer(),
|
||||
initargs=self.get_init_args()
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, *args, **kwargs) -> Future:
|
||||
"""提交任务"""
|
||||
if not self._executor:
|
||||
raise RuntimeError("Pool not started")
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def shutdown(self, wait: bool = True):
|
||||
"""关闭 pool"""
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=wait)
|
||||
self._executor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.shutdown()
|
||||
```
|
||||
|
||||
#### 1.2 实现 CPU 和 GPU Worker Pool
|
||||
|
||||
```python
|
||||
# src/processing/cpu_pool.py
|
||||
|
||||
class CPUWorkerPool(WorkerPool):
|
||||
"""CPU-only worker pool for text PDF processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
super().__init__(max_workers=max_workers, use_gpu=False)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_cpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return ()
|
||||
|
||||
# src/processing/gpu_pool.py
|
||||
|
||||
class GPUWorkerPool(WorkerPool):
|
||||
"""GPU worker pool for OCR processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
|
||||
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_gpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return (self.gpu_id,)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 2:实现双池协调器 (2-3天)
|
||||
|
||||
#### 2.1 任务分发器
|
||||
|
||||
```python
|
||||
# src/processing/task_dispatcher.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List, Tuple
|
||||
|
||||
class TaskType(Enum):
|
||||
CPU = auto() # Text PDF
|
||||
GPU = auto() # Scanned PDF
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
id: str
|
||||
task_type: TaskType
|
||||
data: Any
|
||||
|
||||
class TaskDispatcher:
|
||||
"""根据 PDF 类型分发任务到不同的 pool"""
|
||||
|
||||
def classify_task(self, doc_info: dict) -> TaskType:
|
||||
"""判断文档是否需要 OCR"""
|
||||
# 基于 PDF 特征判断
|
||||
if self._is_scanned_pdf(doc_info):
|
||||
return TaskType.GPU
|
||||
return TaskType.CPU
|
||||
|
||||
def _is_scanned_pdf(self, doc_info: dict) -> bool:
|
||||
"""检测是否为扫描件"""
|
||||
# 1. 检查是否有可提取文本
|
||||
# 2. 检查图片比例
|
||||
# 3. 检查文本密度
|
||||
pass
|
||||
|
||||
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
|
||||
"""将任务分为 CPU 和 GPU 两组"""
|
||||
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
|
||||
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
|
||||
return cpu_tasks, gpu_tasks
|
||||
```
|
||||
|
||||
#### 2.2 双池协调器
|
||||
|
||||
```python
|
||||
# src/processing/dual_pool_coordinator.py
|
||||
|
||||
from concurrent.futures import as_completed
|
||||
from typing import List, Iterator
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DualPoolCoordinator:
|
||||
"""协调 CPU 和 GPU 两个 worker pool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_workers: int = 4,
|
||||
gpu_workers: int = 1,
|
||||
gpu_id: int = 0
|
||||
):
|
||||
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
|
||||
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
|
||||
self.dispatcher = TaskDispatcher()
|
||||
|
||||
def __enter__(self):
|
||||
self.cpu_pool.start()
|
||||
self.gpu_pool.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.cpu_pool.shutdown()
|
||||
self.gpu_pool.shutdown()
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
documents: List[dict],
|
||||
cpu_task_fn: Callable,
|
||||
gpu_task_fn: Callable,
|
||||
on_result: Optional[Callable[[TaskResult], None]] = None,
|
||||
on_error: Optional[Callable[[str, Exception], None]] = None
|
||||
) -> List[TaskResult]:
|
||||
"""
|
||||
处理一批文档,自动分发到 CPU 或 GPU pool
|
||||
|
||||
Args:
|
||||
documents: 待处理文档列表
|
||||
cpu_task_fn: CPU 任务处理函数
|
||||
gpu_task_fn: GPU 任务处理函数
|
||||
on_result: 结果回调(可选)
|
||||
on_error: 错误回调(可选)
|
||||
|
||||
Returns:
|
||||
所有任务结果列表
|
||||
"""
|
||||
# 分类任务
|
||||
tasks = [
|
||||
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
|
||||
for doc in documents
|
||||
]
|
||||
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
|
||||
|
||||
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
|
||||
|
||||
# 提交任务到各自的 pool
|
||||
cpu_futures = {
|
||||
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
|
||||
for t in cpu_tasks
|
||||
}
|
||||
gpu_futures = {
|
||||
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
|
||||
for t in gpu_tasks
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
results = []
|
||||
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
|
||||
|
||||
for future in as_completed(all_futures):
|
||||
task_id = cpu_futures.get(future) or gpu_futures.get(future)
|
||||
pool_type = "CPU" if future in cpu_futures else "GPU"
|
||||
|
||||
try:
|
||||
data = future.result(timeout=300) # 5分钟超时
|
||||
result = TaskResult(task_id=task_id, success=True, data=data)
|
||||
if on_result:
|
||||
on_result(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
|
||||
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
|
||||
if on_error:
|
||||
on_error(task_id, e)
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 3:集成到 autolabel (1-2天)
|
||||
|
||||
#### 3.1 修改 autolabel.py
|
||||
|
||||
```python
|
||||
# src/cli/autolabel.py
|
||||
|
||||
def run_autolabel_dual_pool(args):
|
||||
"""使用双池模式运行自动标注"""
|
||||
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator
|
||||
|
||||
# 初始化数据库批处理
|
||||
db_batch = []
|
||||
db_batch_size = 100
|
||||
|
||||
def on_result(result: TaskResult):
|
||||
"""处理成功结果"""
|
||||
nonlocal db_batch
|
||||
db_batch.append(result.data)
|
||||
|
||||
if len(db_batch) >= db_batch_size:
|
||||
save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
def on_error(task_id: str, error: Exception):
|
||||
"""处理错误"""
|
||||
logger.error(f"Task {task_id} failed: {error}")
|
||||
|
||||
# 创建双池协调器
|
||||
with DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers or 4,
|
||||
gpu_workers=args.gpu_workers or 1,
|
||||
gpu_id=0
|
||||
) as coordinator:
|
||||
|
||||
# 处理所有 CSV
|
||||
for csv_file in csv_files:
|
||||
documents = load_documents_from_csv(csv_file)
|
||||
|
||||
results = coordinator.process_batch(
|
||||
documents=documents,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=on_result,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
logger.info(f"CSV {csv_file}: {len(results)} processed")
|
||||
|
||||
# 保存剩余批次
|
||||
if db_batch:
|
||||
save_documents_batch(db_batch)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 4:测试与验证 (1-2天)
|
||||
|
||||
#### 4.1 单元测试
|
||||
|
||||
```python
|
||||
# tests/unit/test_dual_pool.py
|
||||
|
||||
import pytest
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult
|
||||
|
||||
class TestDualPoolCoordinator:
|
||||
|
||||
def test_cpu_only_batch(self):
|
||||
"""测试纯 CPU 任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 10
|
||||
assert all(r.success for r in results)
|
||||
|
||||
def test_mixed_batch(self):
|
||||
"""测试混合任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [
|
||||
{"id": "text_1", "type": "text"},
|
||||
{"id": "scan_1", "type": "scanned"},
|
||||
{"id": "text_2", "type": "text"},
|
||||
]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_timeout_handling(self):
|
||||
"""测试超时处理"""
|
||||
pass
|
||||
|
||||
def test_error_recovery(self):
|
||||
"""测试错误恢复"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### 4.2 集成测试
|
||||
|
||||
```python
|
||||
# tests/integration/test_autolabel_dual_pool.py
|
||||
|
||||
def test_autolabel_with_dual_pool():
|
||||
"""端到端测试双池模式"""
|
||||
# 使用少量测试数据
|
||||
result = subprocess.run([
|
||||
"python", "-m", "src.cli.autolabel",
|
||||
"--cpu-workers", "2",
|
||||
"--gpu-workers", "1",
|
||||
"--limit", "50"
|
||||
], capture_output=True)
|
||||
|
||||
assert result.returncode == 0
|
||||
# 验证数据库记录
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键技术点
|
||||
|
||||
### 4.1 避免死锁的策略
|
||||
|
||||
```python
|
||||
# 1. 使用 timeout
|
||||
try:
|
||||
result = future.result(timeout=300)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Task timed out")
|
||||
|
||||
# 2. 使用哨兵值
|
||||
SENTINEL = object()
|
||||
queue.put(SENTINEL) # 发送结束信号
|
||||
|
||||
# 3. 检查进程状态
|
||||
if not worker.is_alive():
|
||||
logger.error("Worker died unexpectedly")
|
||||
break
|
||||
|
||||
# 4. 先清空队列再 join
|
||||
while not queue.empty():
|
||||
results.append(queue.get_nowait())
|
||||
worker.join(timeout=5.0)
|
||||
```
|
||||
|
||||
### 4.2 PaddleOCR 特殊处理
|
||||
|
||||
```python
|
||||
# PaddleOCR 必须在 worker 进程中初始化
|
||||
def init_paddle_worker(gpu_id: int):
|
||||
global _ocr
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
# 延迟导入,确保 CUDA 环境变量生效
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang='en',
|
||||
use_gpu=True,
|
||||
show_log=False,
|
||||
# 重要:设置 GPU 内存比例
|
||||
gpu_mem=2000 # 限制 GPU 内存使用 (MB)
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3 资源监控
|
||||
|
||||
```python
|
||||
import psutil
|
||||
import GPUtil
|
||||
|
||||
def get_resource_usage():
|
||||
"""获取系统资源使用情况"""
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
gpu_info = []
|
||||
for gpu in GPUtil.getGPUs():
|
||||
gpu_info.append({
|
||||
"id": gpu.id,
|
||||
"memory_used": gpu.memoryUsed,
|
||||
"memory_total": gpu.memoryTotal,
|
||||
"utilization": gpu.load * 100
|
||||
})
|
||||
|
||||
return {
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_percent": memory.percent,
|
||||
"gpu": gpu_info
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 风险评估与应对
|
||||
|
||||
| 风险 | 可能性 | 影响 | 应对策略 |
|
||||
|------|--------|------|----------|
|
||||
| GPU 内存不足 | 中 | 高 | 限制 GPU worker = 1,设置 gpu_mem 参数 |
|
||||
| 进程僵死 | 低 | 高 | 添加心跳检测,超时自动重启 |
|
||||
| 任务分类错误 | 中 | 中 | 添加回退机制,CPU 失败后尝试 GPU |
|
||||
| 数据库写入瓶颈 | 低 | 中 | 增大批处理大小,异步写入 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 备选方案
|
||||
|
||||
如果上述方案仍存在问题,可以考虑:
|
||||
|
||||
### 6.1 使用 Ray
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def cpu_task(data):
|
||||
return process_text_pdf(data)
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def gpu_task(data):
|
||||
return process_scanned_pdf(data)
|
||||
|
||||
# 自动资源调度
|
||||
futures = [cpu_task.remote(d) for d in cpu_docs]
|
||||
futures += [gpu_task.remote(d) for d in gpu_docs]
|
||||
results = ray.get(futures)
|
||||
```
|
||||
|
||||
### 6.2 单池 + 动态 GPU 调度
|
||||
|
||||
保持单池模式,但在每个任务内部动态决定是否使用 GPU:
|
||||
|
||||
```python
|
||||
def process_document(doc_data):
|
||||
if is_scanned_pdf(doc_data):
|
||||
# 使用 GPU (需要全局锁或信号量控制并发)
|
||||
with gpu_semaphore:
|
||||
return process_with_ocr(doc_data)
|
||||
else:
|
||||
return process_text_only(doc_data)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 时间线总结
|
||||
|
||||
| 阶段 | 任务 | 预计工作量 |
|
||||
|------|------|------------|
|
||||
| 阶段 1 | 基础架构重构 | 2-3 天 |
|
||||
| 阶段 2 | 双池协调器实现 | 2-3 天 |
|
||||
| 阶段 3 | 集成到 autolabel | 1-2 天 |
|
||||
| 阶段 4 | 测试与验证 | 1-2 天 |
|
||||
| **总计** | | **6-10 天** |
|
||||
|
||||
---
|
||||
|
||||
## 8. 参考资料
|
||||
|
||||
1. [Python concurrent.futures 官方文档](https://docs.python.org/3/library/concurrent.futures.html)
|
||||
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
|
||||
3. [Super Fast Python - ProcessPoolExecutor 完整指南](https://superfastpython.com/processpoolexecutor-in-python/)
|
||||
4. [PaddleOCR 并行推理文档](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
|
||||
5. [AWS - 跨 CPU/GPU 并行化 ML 推理](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
|
||||
6. [Ray 分布式多进程处理](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)
|
||||
35
docs/product-plan-v2-CHANGELOG.md
Normal file
35
docs/product-plan-v2-CHANGELOG.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Product Plan v2 - Change Log
|
||||
|
||||
## [v2.1] - 2026-02-01
|
||||
|
||||
### New Features
|
||||
|
||||
#### Epic 7: Dashboard Enhancement
|
||||
- Added **US-7.1**: Data quality metrics panel showing annotation completeness rate
|
||||
- Added **US-7.2**: Active model status panel with mAP/precision/recall metrics
|
||||
- Added **US-7.3**: Recent activity feed showing last 10 system activities
|
||||
- Added **US-7.4**: Meaningful stats cards (Total/Complete/Incomplete/Pending)
|
||||
|
||||
#### Annotation Completeness Definition
|
||||
- Defined "annotation complete" criteria:
|
||||
- Must have `invoice_number` OR `ocr_number` (identifier)
|
||||
- Must have `bankgiro` OR `plusgiro` (payment account)
|
||||
|
||||
### New API Endpoints
|
||||
- Added `GET /api/v1/admin/dashboard/stats` - Dashboard statistics with completeness calculation
|
||||
- Added `GET /api/v1/admin/dashboard/active-model` - Active model info with running training status
|
||||
- Added `GET /api/v1/admin/dashboard/activity` - Recent activity feed aggregated from multiple sources
|
||||
|
||||
### New UI Components
|
||||
- Added **5.0 Dashboard Overview** wireframe with:
|
||||
- Stats cards row (Total/Complete/Incomplete/Pending)
|
||||
- Data Quality panel with percentage ring
|
||||
- Active Model panel with metrics display
|
||||
- Recent Activity list with icons and relative timestamps
|
||||
- System Status bar
|
||||
|
||||
---
|
||||
|
||||
## [v2.0] - 2024-01-15
|
||||
- Initial version with Epic 1-6
|
||||
- Batch upload, document management, annotation workflow, training management
|
||||
1448
docs/product-plan-v2.md
Normal file
1448
docs/product-plan-v2.md
Normal file
File diff suppressed because it is too large
Load Diff
54
docs/training-flow.mmd
Normal file
54
docs/training-flow.mmd
Normal file
@@ -0,0 +1,54 @@
|
||||
flowchart TD
|
||||
A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.]
|
||||
B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD]
|
||||
|
||||
C --> D[Load Data from DB\nsrc/yolo/db_dataset.py]
|
||||
D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs]
|
||||
D --> D2[Batch load field_results\nfrom database - batch 500]
|
||||
|
||||
D1 --> E[Create DBYOLODataset]
|
||||
D2 --> E
|
||||
|
||||
E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42]
|
||||
|
||||
F --> G[Export to YOLO Format]
|
||||
G --> G1[Copy images to\ntrain/val/test dirs]
|
||||
G --> G2[Generate .txt labels\nclass x_center y_center w h]
|
||||
G --> G3[Generate dataset.yaml\n+ classes.txt]
|
||||
G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1]
|
||||
|
||||
G1 --> H{--export-only?}
|
||||
G2 --> H
|
||||
G3 --> H
|
||||
G4 --> H
|
||||
|
||||
H -- Yes --> Z[Done - Dataset exported]
|
||||
H -- No --> I[Load YOLO Model]
|
||||
|
||||
I --> I1{--resume?}
|
||||
I1 -- Yes --> I2[Load last.pt checkpoint]
|
||||
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
|
||||
|
||||
I2 --> J[Configure Training]
|
||||
I3 --> J
|
||||
|
||||
J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup]
|
||||
J --> J2[imgsz=1280, pretrained=True]
|
||||
|
||||
J1 --> K[model.train\nUltralytics Training Loop]
|
||||
J2 --> K
|
||||
|
||||
K --> L[Training Outputs\nruns/train/name/]
|
||||
L --> L1[weights/best.pt\nweights/last.pt]
|
||||
L --> L2[results.csv + results.png\nTraining curves]
|
||||
L --> L3[PR curves, F1 curves\nConfusion matrix]
|
||||
|
||||
L1 --> M[Test Set Validation\nmodel.val split=test]
|
||||
M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95]
|
||||
|
||||
N --> O[Close DB Connection]
|
||||
|
||||
style A fill:#4a90d9,color:#fff
|
||||
style K fill:#e67e22,color:#fff
|
||||
style N fill:#27ae60,color:#fff
|
||||
style Z fill:#95a5a6,color:#fff
|
||||
302
docs/ux-design-prompt-v2.md
Normal file
302
docs/ux-design-prompt-v2.md
Normal file
@@ -0,0 +1,302 @@
|
||||
# Document Annotation Tool – UX Design Spec v2
|
||||
|
||||
## Theme: Warm Graphite (Modern Enterprise)
|
||||
|
||||
---
|
||||
|
||||
## 1. Design Principles (Updated)
|
||||
|
||||
1. **Clarity** – High contrast, but never pure black-on-white
|
||||
2. **Warm Neutrality** – Slightly warm grays reduce visual fatigue
|
||||
3. **Focus** – Content-first layouts with restrained accents
|
||||
4. **Consistency** – Reusable patterns, predictable behavior
|
||||
5. **Professional Trust** – Calm, serious, enterprise-ready
|
||||
6. **Longevity** – No trendy colors that age quickly
|
||||
|
||||
---
|
||||
|
||||
## 2. Color Palette (Warm Graphite)
|
||||
|
||||
### Core Colors
|
||||
|
||||
| Usage | Color Name | Hex |
|
||||
|------|-----------|-----|
|
||||
| Primary Text | Soft Black | #121212 |
|
||||
| Secondary Text | Charcoal Gray | #2A2A2A |
|
||||
| Muted Text | Warm Gray | #6B6B6B |
|
||||
| Disabled Text | Light Warm Gray | #9A9A9A |
|
||||
|
||||
### Backgrounds
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|-----|------|-----|
|
||||
| App Background | Paper White | #FAFAF8 |
|
||||
| Card / Panel | White | #FFFFFF |
|
||||
| Hover Surface | Subtle Warm Gray | #F1F0ED |
|
||||
| Selected Row | Very Light Warm Gray | #ECEAE6 |
|
||||
|
||||
### Borders & Dividers
|
||||
|
||||
| Usage | Color | Hex |
|
||||
|------|------|-----|
|
||||
| Default Border | Warm Light Gray | #E6E4E1 |
|
||||
| Strong Divider | Neutral Gray | #D8D6D2 |
|
||||
|
||||
### Semantic States (Muted & Professional)
|
||||
|
||||
| State | Color | Hex |
|
||||
|------|-------|-----|
|
||||
| Success | Olive Gray | #3E4A3A |
|
||||
| Error | Brick Gray | #4A3A3A |
|
||||
| Warning | Sand Gray | #4A4A3A |
|
||||
| Info | Graphite Gray | #3A3A3A |
|
||||
|
||||
> Accent colors are **never saturated** and are used only for status, progress, or selection.
|
||||
|
||||
---
|
||||
|
||||
## 3. Typography
|
||||
|
||||
- **Font Family**: Inter / SF Pro / system-ui
|
||||
- **Headings**:
|
||||
- Weight: 600–700
|
||||
- Color: #121212
|
||||
- Letter spacing: -0.01em
|
||||
- **Body Text**:
|
||||
- Weight: 400
|
||||
- Color: #2A2A2A
|
||||
- **Captions / Meta**:
|
||||
- Weight: 400
|
||||
- Color: #6B6B6B
|
||||
- **Monospace (IDs / Values)**:
|
||||
- JetBrains Mono / SF Mono
|
||||
- Color: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 4. Global Layout
|
||||
|
||||
### Top Navigation Bar
|
||||
|
||||
- Height: 56px
|
||||
- Background: #FAFAF8
|
||||
- Bottom Border: 1px solid #E6E4E1
|
||||
- Logo: Text or icon in #121212
|
||||
|
||||
**Navigation Items**
|
||||
- Default: #6B6B6B
|
||||
- Hover: #2A2A2A
|
||||
- Active:
|
||||
- Text: #121212
|
||||
- Bottom indicator: 2px solid #3A3A3A (rounded ends)
|
||||
|
||||
**Avatar**
|
||||
- Circle background: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
|
||||
---
|
||||
|
||||
## 5. Page: Documents (Dashboard)
|
||||
|
||||
### Page Header
|
||||
|
||||
- Title: "Documents" (#121212)
|
||||
- Actions:
|
||||
- Primary button: Dark graphite outline
|
||||
- Secondary button: Subtle border only
|
||||
|
||||
### Filters Bar
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border: 1px solid #E6E4E1
|
||||
- Inputs:
|
||||
- Background: #FFFFFF
|
||||
- Hover: #F1F0ED
|
||||
- Focus ring: 1px #3A3A3A
|
||||
|
||||
### Document Table
|
||||
|
||||
- Table background: #FFFFFF
|
||||
- Header text: #6B6B6B
|
||||
- Row hover: #F1F0ED
|
||||
- Row selected:
|
||||
- Background: #ECEAE6
|
||||
- Left indicator: 3px solid #3A3A3A
|
||||
|
||||
### Status Badges
|
||||
|
||||
- Pending:
|
||||
- BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Text: #2A2A2A
|
||||
|
||||
- Labeled:
|
||||
- BG: #2A2A2A
|
||||
- Text: #FFFFFF
|
||||
|
||||
- Exported:
|
||||
- BG: #ECEAE6
|
||||
- Text: #2A2A2A
|
||||
- Icon: ✓
|
||||
|
||||
### Auto-label States
|
||||
|
||||
- Running:
|
||||
- Progress bar: #3A3A3A on #ECEAE6
|
||||
- Completed:
|
||||
- Text: #3E4A3A
|
||||
- Failed:
|
||||
- BG: #F1EDED
|
||||
- Text: #4A3A3A
|
||||
|
||||
---
|
||||
|
||||
## 6. Upload Modals (Single & Batch)
|
||||
|
||||
### Modal Container
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Border radius: 8px
|
||||
- Shadow: 0 1px 3px rgba(0,0,0,0.08)
|
||||
|
||||
### Drop Zone
|
||||
|
||||
- Background: #FAFAF8
|
||||
- Border: 1px dashed #D8D6D2
|
||||
- Hover: #F1F0ED
|
||||
- Icon: Graphite gray
|
||||
|
||||
### Form Fields
|
||||
|
||||
- Input BG: #FFFFFF
|
||||
- Border: #D8D6D2
|
||||
- Focus: 1px solid #3A3A3A
|
||||
|
||||
Primary Action Button:
|
||||
- Text: #FFFFFF
|
||||
- BG: #2A2A2A
|
||||
- Hover: #121212
|
||||
|
||||
---
|
||||
|
||||
## 7. Document Detail View
|
||||
|
||||
### Canvas Area
|
||||
|
||||
- Background: #FFFFFF
|
||||
- Annotation styles:
|
||||
- Manual: Solid border #2A2A2A
|
||||
- Auto: Dashed border #6B6B6B
|
||||
- Selected: 2px border #3A3A3A + resize handles
|
||||
|
||||
### Right Info Panel
|
||||
|
||||
- Card background: #FFFFFF
|
||||
- Section headers: #121212
|
||||
- Meta text: #6B6B6B
|
||||
|
||||
### Annotation Table
|
||||
|
||||
- Same table styles as Documents
|
||||
- Inline edit:
|
||||
- Input background: #FAFAF8
|
||||
- Save button: Graphite
|
||||
|
||||
### Locked State (Auto-label Running)
|
||||
|
||||
- Banner BG: #FAFAF8
|
||||
- Border-left: 3px solid #4A4A3A
|
||||
- Progress bar: Graphite
|
||||
|
||||
---
|
||||
|
||||
## 8. Training Page
|
||||
|
||||
### Document Selector
|
||||
|
||||
- Selected rows use same highlight rules
|
||||
- Verified state:
|
||||
- Full: Olive gray check
|
||||
- Partial: Sand gray warning
|
||||
|
||||
### Configuration Panel
|
||||
|
||||
- Card layout
|
||||
- Inputs aligned to grid
|
||||
- Schedule option visually muted until enabled
|
||||
|
||||
Primary CTA:
|
||||
- Start Training button in dark graphite
|
||||
|
||||
---
|
||||
|
||||
## 9. Models & Training History
|
||||
|
||||
### Training Job List
|
||||
|
||||
- Job cards use #FFFFFF background
|
||||
- Running job:
|
||||
- Progress bar: #3A3A3A
|
||||
- Completed job:
|
||||
- Metrics bars in graphite
|
||||
|
||||
### Model Detail Panel
|
||||
|
||||
- Sectioned cards
|
||||
- Metric bars:
|
||||
- Track: #ECEAE6
|
||||
- Fill: #3A3A3A
|
||||
|
||||
Actions:
|
||||
- Primary: Download Model
|
||||
- Secondary: View Logs / Use as Base
|
||||
|
||||
---
|
||||
|
||||
## 10. Micro-interactions (Refined)
|
||||
|
||||
| Element | Interaction | Animation |
|
||||
|------|------------|-----------|
|
||||
| Button hover | BG lightens | 150ms ease-out |
|
||||
| Button press | Scale 0.98 | 100ms |
|
||||
| Row hover | BG fade | 120ms |
|
||||
| Modal open | Fade + scale 0.96 → 1 | 200ms |
|
||||
| Progress fill | Smooth | ease-out |
|
||||
| Annotation select | Border + handles | 120ms |
|
||||
|
||||
---
|
||||
|
||||
## 11. Tailwind Theme (Updated)
|
||||
|
||||
```js
|
||||
colors: {
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
bg: {
|
||||
app: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
},
|
||||
border: '#E6E4E1',
|
||||
accent: '#3A3A3A',
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Final Notes
|
||||
|
||||
- Pure black (#000000) should **never** be used as large surfaces
|
||||
- Accent color usage should stay under **10% of UI area**
|
||||
- Warm grays are intentional and must not be "corrected" to blue-grays
|
||||
|
||||
This theme is designed to scale from internal tool → polished SaaS without redesign.
|
||||
|
||||
273
docs/web-refactoring-complete.md
Normal file
273
docs/web-refactoring-complete.md
Normal file
@@ -0,0 +1,273 @@
|
||||
# Web Directory Refactoring - Complete ✅
|
||||
|
||||
**Date**: 2026-01-25
|
||||
**Status**: ✅ Completed
|
||||
**Tests**: 188 passing (0 failures)
|
||||
**Coverage**: 23% (maintained)
|
||||
|
||||
---
|
||||
|
||||
## Final Directory Structure
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── __init__.py
|
||||
│ └── v1/
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public inference API
|
||||
│ ├── admin/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # Document management (was admin_routes.py)
|
||||
│ │ ├── annotations.py # Annotation routes (was admin_annotation_routes.py)
|
||||
│ │ └── training.py # Training routes (was admin_training_routes.py)
|
||||
│ ├── async_api/
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # Async processing API (was async_routes.py)
|
||||
│ └── batch/
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # Batch upload API (was batch_upload_routes.py)
|
||||
│
|
||||
├── schemas/
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared models (ErrorResponse)
|
||||
│ ├── admin.py # Admin schemas (was admin_schemas.py)
|
||||
│ └── inference.py # Inference + async schemas (was schemas.py)
|
||||
│
|
||||
├── services/
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # Inference service (was services.py)
|
||||
│ ├── autolabel.py # Auto-label service (was admin_autolabel.py)
|
||||
│ ├── async_processing.py # Async processing (was async_service.py)
|
||||
│ └── batch_upload.py # Batch upload service (was batch_upload_service.py)
|
||||
│
|
||||
├── core/
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # Authentication (was admin_auth.py)
|
||||
│ ├── rate_limiter.py # Rate limiting (unchanged)
|
||||
│ └── scheduler.py # Task scheduler (was admin_scheduler.py)
|
||||
│
|
||||
├── workers/
|
||||
│ ├── __init__.py
|
||||
│ ├── async_queue.py # Async task queue (was async_queue.py)
|
||||
│ └── batch_queue.py # Batch task queue (was batch_queue.py)
|
||||
│
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app (imports updated)
|
||||
├── config.py # Configuration (unchanged)
|
||||
└── dependencies.py # Global dependencies (unchanged)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Changes Summary
|
||||
|
||||
### Files Moved and Renamed
|
||||
|
||||
| Old Location | New Location | Change Type |
|
||||
|-------------|--------------|-------------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
|
||||
| `admin_auth.py` | `core/auth.py` | Moved |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Moved |
|
||||
| `routes.py` | `api/v1/routes.py` | Moved |
|
||||
| `schemas.py` | `schemas/inference.py` | Moved |
|
||||
| `services.py` | `services/inference.py` | Moved |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Moved |
|
||||
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |
|
||||
|
||||
**Total**: 16 files reorganized
|
||||
|
||||
### Files Updated
|
||||
|
||||
**Source Files** (imports updated):
|
||||
- `app.py` - Updated all imports to new structure
|
||||
- `api/v1/admin/documents.py` - Updated schema/auth imports
|
||||
- `api/v1/admin/annotations.py` - Updated schema/service imports
|
||||
- `api/v1/admin/training.py` - Updated schema/auth imports
|
||||
- `api/v1/routes.py` - Updated schema imports
|
||||
- `api/v1/async_api/routes.py` - Updated schema imports
|
||||
- `api/v1/batch/routes.py` - Updated service/worker imports
|
||||
- `services/async_processing.py` - Updated worker/core imports
|
||||
|
||||
**Test Files** (all 15 updated):
|
||||
- `test_admin_annotations.py`
|
||||
- `test_admin_auth.py`
|
||||
- `test_admin_routes.py`
|
||||
- `test_admin_routes_enhanced.py`
|
||||
- `test_admin_training.py`
|
||||
- `test_annotation_locks.py`
|
||||
- `test_annotation_phase5.py`
|
||||
- `test_async_queue.py`
|
||||
- `test_async_routes.py`
|
||||
- `test_async_service.py`
|
||||
- `test_autolabel_with_locks.py`
|
||||
- `test_batch_queue.py`
|
||||
- `test_batch_upload_routes.py`
|
||||
- `test_batch_upload_service.py`
|
||||
- `test_training_phase4.py`
|
||||
- `conftest.py`
|
||||
|
||||
---
|
||||
|
||||
## Import Examples
|
||||
|
||||
### Old Import Style (Before Refactoring)
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
from src.web.async_routes import create_async_router
|
||||
from src.web.schemas import ErrorResponse
|
||||
```
|
||||
|
||||
### New Import Style (After Refactoring)
|
||||
```python
|
||||
# Admin API
|
||||
from src.web.api.v1.admin.documents import create_admin_router
|
||||
from src.web.api.v1.admin import create_admin_router # Shorter alternative
|
||||
|
||||
# Schemas
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.schemas.common import ErrorResponse
|
||||
|
||||
# Core components
|
||||
from src.web.core.auth import validate_admin_token
|
||||
|
||||
# Async API
|
||||
from src.web.api.v1.async_api.routes import create_async_router
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### 1. **Clear Separation of Concerns**
|
||||
- **API Routes**: All in `api/v1/` by version and feature
|
||||
- **Data Models**: All in `schemas/` by domain
|
||||
- **Business Logic**: All in `services/`
|
||||
- **Core Components**: Reusable utilities in `core/`
|
||||
- **Background Jobs**: Task queues in `workers/`
|
||||
|
||||
### 2. **Better Scalability**
|
||||
- Easy to add API v2 without touching v1
|
||||
- Clear namespace for each module
|
||||
- Reduced file sizes (no 800+ line files)
|
||||
- Follows single responsibility principle
|
||||
|
||||
### 3. **Improved Maintainability**
|
||||
- Find files by function, not by prefix
|
||||
- Each module has one clear purpose
|
||||
- Easier to onboard new developers
|
||||
- Better IDE navigation
|
||||
|
||||
### 4. **Standards Compliance**
|
||||
- Follows FastAPI best practices
|
||||
- Matches Django/Flask project structures
|
||||
- Standard Python package organization
|
||||
- Industry-standard naming conventions
|
||||
|
||||
---
|
||||
|
||||
## Testing Results
|
||||
|
||||
**Before Refactoring**:
|
||||
- 188 tests passing
|
||||
- 23% code coverage
|
||||
- Flat directory structure
|
||||
|
||||
**After Refactoring**:
|
||||
- ✅ 188 tests passing (0 failures)
|
||||
- ✅ 23% code coverage (maintained)
|
||||
- ✅ Clean hierarchical structure
|
||||
- ✅ All imports updated
|
||||
- ✅ No backward compatibility shims needed
|
||||
|
||||
---
|
||||
|
||||
## Migration Statistics
|
||||
|
||||
| Metric | Count |
|
||||
|--------|-------|
|
||||
| Files moved | 16 |
|
||||
| Directories created | 9 |
|
||||
| Files updated (source) | 8 |
|
||||
| Files updated (tests) | 16 |
|
||||
| Import statements updated | ~150 |
|
||||
| Lines of code changed | ~200 |
|
||||
| Tests broken | 0 |
|
||||
| Coverage lost | 0% |
|
||||
|
||||
---
|
||||
|
||||
## Code Diff Summary
|
||||
|
||||
```diff
|
||||
Before:
|
||||
src/web/
|
||||
├── admin_routes.py (645 lines)
|
||||
├── admin_annotation_routes.py (504 lines)
|
||||
├── admin_training_routes.py (565 lines)
|
||||
├── admin_auth.py (22 lines)
|
||||
├── admin_schemas.py (262 lines)
|
||||
... (15 more files at root level)
|
||||
|
||||
After:
|
||||
src/web/
|
||||
├── api/v1/
|
||||
│ ├── admin/ (3 route files)
|
||||
│ ├── async_api/ (1 route file)
|
||||
│ └── batch/ (1 route file)
|
||||
├── schemas/ (3 schema files)
|
||||
├── services/ (4 service files)
|
||||
├── core/ (3 core files)
|
||||
└── workers/ (2 worker files)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps (Optional)
|
||||
|
||||
### Phase 2: Documentation
|
||||
- [ ] Update API documentation with new import paths
|
||||
- [ ] Create migration guide for external developers
|
||||
- [ ] Update CLAUDE.md with new structure
|
||||
|
||||
### Phase 3: Further Optimization
|
||||
- [ ] Split large files (>400 lines) if needed
|
||||
- [ ] Extract common utilities
|
||||
- [ ] Add typing stubs
|
||||
|
||||
### Phase 4: Deprecation (Future)
|
||||
- [ ] Add deprecation warnings if creating compatibility layer
|
||||
- [ ] Remove old imports after grace period
|
||||
- [ ] Update all documentation
|
||||
|
||||
---
|
||||
|
||||
## Rollback Instructions
|
||||
|
||||
If needed, rollback is simple:
|
||||
```bash
|
||||
git revert <commit-hash>
|
||||
```
|
||||
|
||||
All changes are in version control, making rollback safe and easy.
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
✅ **Refactoring completed successfully**
|
||||
✅ **Zero breaking changes**
|
||||
✅ **All tests passing**
|
||||
✅ **Industry-standard structure achieved**
|
||||
|
||||
The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.
|
||||
186
docs/web-refactoring-plan.md
Normal file
186
docs/web-refactoring-plan.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Web Directory Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat structure**: All files in one directory (20 Python files)
|
||||
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
|
||||
3. **Mixed concerns**: Routes, schemas, services, and workers in same directory
|
||||
4. **Poor scalability**: Hard to navigate and maintain as project grows
|
||||
|
||||
## Proposed Structure (Best Practices)
|
||||
|
||||
```
|
||||
src/web/
|
||||
├── __init__.py # Main exports
|
||||
├── app.py # FastAPI app factory
|
||||
├── config.py # App configuration
|
||||
├── dependencies.py # Global dependencies
|
||||
│
|
||||
├── api/ # API Routes Layer
|
||||
│ ├── __init__.py
|
||||
│ └── v1/ # API version 1
|
||||
│ ├── __init__.py
|
||||
│ ├── routes.py # Public API routes (inference)
|
||||
│ ├── admin/ # Admin API routes
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── documents.py # admin_routes.py → documents.py
|
||||
│ │ ├── annotations.py # admin_annotation_routes.py → annotations.py
|
||||
│ │ ├── training.py # admin_training_routes.py → training.py
|
||||
│ │ └── auth.py # admin_auth.py → auth.py (routes only)
|
||||
│ ├── async_api/ # Async processing API
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── routes.py # async_routes.py → routes.py
|
||||
│ └── batch/ # Batch upload API
|
||||
│ ├── __init__.py
|
||||
│ └── routes.py # batch_upload_routes.py → routes.py
|
||||
│
|
||||
├── schemas/ # Pydantic Models
|
||||
│ ├── __init__.py
|
||||
│ ├── common.py # Shared schemas (ErrorResponse, etc.)
|
||||
│ ├── inference.py # schemas.py → inference.py
|
||||
│ ├── admin.py # admin_schemas.py → admin.py
|
||||
│ ├── async_api.py # New: async API schemas
|
||||
│ └── batch.py # New: batch upload schemas
|
||||
│
|
||||
├── services/ # Business Logic Layer
|
||||
│ ├── __init__.py
|
||||
│ ├── inference.py # services.py → inference.py
|
||||
│ ├── autolabel.py # admin_autolabel.py → autolabel.py
|
||||
│ ├── async_processing.py # async_service.py → async_processing.py
|
||||
│ └── batch_upload.py # batch_upload_service.py → batch_upload.py
|
||||
│
|
||||
├── core/ # Core Components
|
||||
│ ├── __init__.py
|
||||
│ ├── auth.py # admin_auth.py → auth.py (logic only)
|
||||
│ ├── rate_limiter.py # rate_limiter.py → rate_limiter.py
|
||||
│ └── scheduler.py # admin_scheduler.py → scheduler.py
|
||||
│
|
||||
└── workers/ # Background Task Queues
|
||||
├── __init__.py
|
||||
├── async_queue.py # async_queue.py → async_queue.py
|
||||
└── batch_queue.py # batch_queue.py → batch_queue.py
|
||||
```
|
||||
|
||||
## File Mapping
|
||||
|
||||
### Current → New Location
|
||||
|
||||
| Current File | New Location | Purpose |
|
||||
|--------------|--------------|---------|
|
||||
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
|
||||
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
|
||||
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
|
||||
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
|
||||
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
|
||||
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
|
||||
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
|
||||
| `routes.py` | `api/v1/routes.py` | Public inference API |
|
||||
| `schemas.py` | `schemas/inference.py` | Inference models |
|
||||
| `services.py` | `services/inference.py` | Inference service |
|
||||
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
|
||||
| `async_service.py` | `services/async_processing.py` | Async processing service |
|
||||
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
|
||||
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
|
||||
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
|
||||
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
|
||||
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
|
||||
| `config.py` | `config.py` | Keep as-is |
|
||||
| `dependencies.py` | `dependencies.py` | Keep as-is |
|
||||
| `app.py` | `app.py` | Keep as-is (update imports) |
|
||||
|
||||
## Benefits
|
||||
|
||||
### 1. Clear Separation of Concerns
|
||||
- **Routes**: API endpoint definitions
|
||||
- **Schemas**: Data validation models
|
||||
- **Services**: Business logic
|
||||
- **Core**: Reusable components
|
||||
- **Workers**: Background processing
|
||||
|
||||
### 2. Better Scalability
|
||||
- Easy to add new API versions (`v2/`)
|
||||
- Clear namespace for each domain
|
||||
- Reduced file size (no 800+ line files)
|
||||
|
||||
### 3. Improved Maintainability
|
||||
- Find files by function, not by prefix
|
||||
- Each module has single responsibility
|
||||
- Easier to write focused tests
|
||||
|
||||
### 4. Standard Python Patterns
|
||||
- Package-based organization
|
||||
- Follows FastAPI best practices
|
||||
- Similar to Django/Flask structures
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Phase 1: Create New Structure (No Breaking Changes)
|
||||
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
|
||||
2. Copy files to new locations (don't delete originals yet)
|
||||
3. Update imports in new files
|
||||
4. Add `__init__.py` with proper exports
|
||||
|
||||
### Phase 2: Update Tests
|
||||
5. Update test imports to use new structure
|
||||
6. Run tests to verify nothing breaks
|
||||
7. Fix any import issues
|
||||
|
||||
### Phase 3: Update Main App
|
||||
8. Update `app.py` to import from new locations
|
||||
9. Run full test suite
|
||||
10. Verify all endpoints work
|
||||
|
||||
### Phase 4: Cleanup
|
||||
11. Delete old files
|
||||
12. Update documentation
|
||||
13. Final test run
|
||||
|
||||
## Migration Priority
|
||||
|
||||
**High Priority** (Most used):
|
||||
- Routes and schemas (user-facing APIs)
|
||||
- Services (core business logic)
|
||||
|
||||
**Medium Priority**:
|
||||
- Core components (auth, rate limiter)
|
||||
- Workers (background tasks)
|
||||
|
||||
**Low Priority**:
|
||||
- Config and dependencies (already well-located)
|
||||
|
||||
## Backwards Compatibility
|
||||
|
||||
During migration, maintain backwards compatibility:
|
||||
|
||||
```python
|
||||
# src/web/__init__.py
|
||||
# Old imports still work
|
||||
from src.web.api.v1.admin.documents import router as admin_router
|
||||
from src.web.schemas.admin import AdminDocument
|
||||
|
||||
# Keep old names for compatibility (temporary)
|
||||
admin_routes = admin_router # Deprecated alias
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
1. **Unit Tests**: Test each module independently
|
||||
2. **Integration Tests**: Test API endpoints still work
|
||||
3. **Import Tests**: Verify all old imports still work
|
||||
4. **Coverage**: Maintain current 23% coverage minimum
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If issues arise:
|
||||
1. Keep old files until fully migrated
|
||||
2. Git allows easy revert
|
||||
3. Tests catch breaking changes early
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
Would you like me to:
|
||||
1. **Start Phase 1**: Create new directory structure and move files?
|
||||
2. **Create migration script**: Automate the file moves and import updates?
|
||||
3. **Focus on specific area**: Start with admin API or async API first?
|
||||
218
docs/web-refactoring-status.md
Normal file
218
docs/web-refactoring-status.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Web Directory Refactoring - Current Status
|
||||
|
||||
## ✅ Completed Steps
|
||||
|
||||
### 1. Directory Structure Created
|
||||
```
|
||||
src/web/
|
||||
├── api/
|
||||
│ ├── v1/
|
||||
│ │ ├── admin/ (documents.py, annotations.py, training.py)
|
||||
│ │ ├── async_api/ (routes.py)
|
||||
│ │ ├── batch/ (routes.py)
|
||||
│ │ └── routes.py (public inference API)
|
||||
├── schemas/
|
||||
│ ├── admin.py (admin schemas)
|
||||
│ ├── inference.py (inference + async schemas)
|
||||
│ └── common.py (ErrorResponse)
|
||||
├── services/
|
||||
│ ├── autolabel.py
|
||||
│ ├── async_processing.py
|
||||
│ ├── batch_upload.py
|
||||
│ └── inference.py
|
||||
├── core/
|
||||
│ ├── auth.py
|
||||
│ ├── rate_limiter.py
|
||||
│ └── scheduler.py
|
||||
└── workers/
|
||||
├── async_queue.py
|
||||
└── batch_queue.py
|
||||
```
|
||||
|
||||
### 2. Files Copied and Imports Updated
|
||||
|
||||
#### Admin API (✅ Complete)
|
||||
- [x] `admin_routes.py` → `api/v1/admin/documents.py` (imports updated)
|
||||
- [x] `admin_annotation_routes.py` → `api/v1/admin/annotations.py` (imports updated)
|
||||
- [x] `admin_training_routes.py` → `api/v1/admin/training.py` (imports updated)
|
||||
- [x] `api/v1/admin/__init__.py` created with exports
|
||||
|
||||
#### Public & Async API (✅ Complete)
|
||||
- [x] `routes.py` → `api/v1/routes.py` (imports updated)
|
||||
- [x] `async_routes.py` → `api/v1/async_api/routes.py` (imports updated)
|
||||
- [x] `batch_upload_routes.py` → `api/v1/batch/routes.py` (copied, imports pending)
|
||||
|
||||
#### Schemas (✅ Complete)
|
||||
- [x] `admin_schemas.py` → `schemas/admin.py`
|
||||
- [x] `schemas.py` → `schemas/inference.py`
|
||||
- [x] `schemas/common.py` created
|
||||
- [x] `schemas/__init__.py` created with exports
|
||||
|
||||
#### Services (✅ Complete)
|
||||
- [x] `admin_autolabel.py` → `services/autolabel.py`
|
||||
- [x] `async_service.py` → `services/async_processing.py`
|
||||
- [x] `batch_upload_service.py` → `services/batch_upload.py`
|
||||
- [x] `services.py` → `services/inference.py`
|
||||
- [x] `services/__init__.py` created
|
||||
|
||||
#### Core Components (✅ Complete)
|
||||
- [x] `admin_auth.py` → `core/auth.py`
|
||||
- [x] `rate_limiter.py` → `core/rate_limiter.py`
|
||||
- [x] `admin_scheduler.py` → `core/scheduler.py`
|
||||
- [x] `core/__init__.py` created
|
||||
|
||||
#### Workers (✅ Complete)
|
||||
- [x] `async_queue.py` → `workers/async_queue.py`
|
||||
- [x] `batch_queue.py` → `workers/batch_queue.py`
|
||||
- [x] `workers/__init__.py` created
|
||||
|
||||
#### Main App (✅ Complete)
|
||||
- [x] `app.py` imports updated to use new structure
|
||||
|
||||
---
|
||||
|
||||
## ⏳ Remaining Work
|
||||
|
||||
### 1. Update Remaining File Imports (HIGH PRIORITY)
|
||||
|
||||
Files that need import updates:
|
||||
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
|
||||
- [ ] `services/autolabel.py` - may need import updates if it references old paths
|
||||
- [ ] `services/async_processing.py` - check for old import references
|
||||
- [ ] `services/batch_upload.py` - check for old import references
|
||||
- [ ] `services/inference.py` - check for old import references
|
||||
|
||||
### 2. Update ALL Test Files (CRITICAL)
|
||||
|
||||
Test files need to import from new locations. Pattern:
|
||||
|
||||
**Old:**
|
||||
```python
|
||||
from src.web.admin_routes import create_admin_router
|
||||
from src.web.admin_schemas import DocumentItem
|
||||
from src.web.admin_auth import validate_admin_token
|
||||
```
|
||||
|
||||
**New:**
|
||||
```python
|
||||
from src.web.api.v1.admin import create_admin_router
|
||||
from src.web.schemas.admin import DocumentItem
|
||||
from src.web.core.auth import validate_admin_token
|
||||
```
|
||||
|
||||
Test files to update:
|
||||
- [ ] `tests/web/test_admin_annotations.py`
|
||||
- [ ] `tests/web/test_admin_auth.py`
|
||||
- [ ] `tests/web/test_admin_routes.py`
|
||||
- [ ] `tests/web/test_admin_routes_enhanced.py`
|
||||
- [ ] `tests/web/test_admin_training.py`
|
||||
- [ ] `tests/web/test_annotation_locks.py`
|
||||
- [ ] `tests/web/test_annotation_phase5.py`
|
||||
- [ ] `tests/web/test_async_queue.py`
|
||||
- [ ] `tests/web/test_async_routes.py`
|
||||
- [ ] `tests/web/test_async_service.py`
|
||||
- [ ] `tests/web/test_autolabel_with_locks.py`
|
||||
- [ ] `tests/web/test_batch_queue.py`
|
||||
- [ ] `tests/web/test_batch_upload_routes.py`
|
||||
- [ ] `tests/web/test_batch_upload_service.py`
|
||||
- [ ] `tests/web/test_rate_limiter.py`
|
||||
- [ ] `tests/web/test_training_phase4.py`
|
||||
|
||||
### 3. Create Backward Compatibility Layer (OPTIONAL)
|
||||
|
||||
Keep old imports working temporarily:
|
||||
|
||||
```python
|
||||
# src/web/admin_routes.py (temporary compatibility shim)
|
||||
\"\"\"
|
||||
DEPRECATED: Use src.web.api.v1.admin.documents instead.
|
||||
This file will be removed in next version.
|
||||
\"\"\"
|
||||
import warnings
|
||||
from src.web.api.v1.admin.documents import *
|
||||
|
||||
warnings.warn(
|
||||
"Importing from src.web.admin_routes is deprecated. "
|
||||
"Use src.web.api.v1.admin.documents instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Verify and Test
|
||||
|
||||
1. Run tests:
|
||||
```bash
|
||||
pytest tests/web/ -v
|
||||
```
|
||||
|
||||
2. Check for any import errors:
|
||||
```bash
|
||||
python -c "from src.web.app import create_app; create_app()"
|
||||
```
|
||||
|
||||
3. Start server and test endpoints:
|
||||
```bash
|
||||
python run_server.py
|
||||
```
|
||||
|
||||
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
|
||||
|
||||
Old files to remove:
|
||||
- `src/web/admin_*.py` (7 files)
|
||||
- `src/web/async_*.py` (3 files)
|
||||
- `src/web/batch_*.py` (3 files)
|
||||
- `src/web/routes.py`
|
||||
- `src/web/services.py`
|
||||
- `src/web/schemas.py`
|
||||
- `src/web/rate_limiter.py`
|
||||
|
||||
Keep these files (don't remove):
|
||||
- `src/web/__init__.py`
|
||||
- `src/web/app.py`
|
||||
- `src/web/config.py`
|
||||
- `src/web/dependencies.py`
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Next Immediate Steps
|
||||
|
||||
1. **Update batch/routes.py imports** - Quick fix for remaining API route
|
||||
2. **Update test file imports** - Critical for verification
|
||||
3. **Run test suite** - Verify nothing broke
|
||||
4. **Fix any import errors** - Address failures
|
||||
5. **Remove old files** - Clean up after tests pass
|
||||
|
||||
---
|
||||
|
||||
## 📊 Migration Impact Summary
|
||||
|
||||
| Category | Files Moved | Imports Updated | Status |
|
||||
|----------|-------------|-----------------|--------|
|
||||
| API Routes | 7 | 5/7 | 🟡 In Progress |
|
||||
| Schemas | 3 | 3/3 | ✅ Complete |
|
||||
| Services | 4 | 0/4 | ⚠️ Pending |
|
||||
| Core | 3 | 3/3 | ✅ Complete |
|
||||
| Workers | 2 | 2/2 | ✅ Complete |
|
||||
| Tests | 0 | 0/16 | ❌ Not Started |
|
||||
|
||||
**Overall Progress: 65%**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Benefits After Migration
|
||||
|
||||
1. **Better Organization**: Clear separation by function
|
||||
2. **Easier Navigation**: Find files by purpose, not prefix
|
||||
3. **Scalability**: Easy to add new API versions
|
||||
4. **Standard Structure**: Follows FastAPI best practices
|
||||
5. **Maintainability**: Each module has single responsibility
|
||||
|
||||
---
|
||||
|
||||
## 📝 Notes
|
||||
|
||||
- All original files are still in place (no data loss risk)
|
||||
- New structure is operational but needs import updates
|
||||
- Backward compatibility can be added if needed
|
||||
- Tests will validate the migration success
|
||||
5
frontend/.env.example
Normal file
5
frontend/.env.example
Normal file
@@ -0,0 +1,5 @@
|
||||
# Backend API URL
|
||||
VITE_API_URL=http://localhost:8000
|
||||
|
||||
# WebSocket URL (for future real-time updates)
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
24
frontend/.gitignore
vendored
Normal file
24
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
20
frontend/README.md
Normal file
20
frontend/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
<div align="center">
|
||||
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
|
||||
</div>
|
||||
|
||||
# Run and deploy your AI Studio app
|
||||
|
||||
This contains everything you need to run your app locally.
|
||||
|
||||
View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4
|
||||
|
||||
## Run Locally
|
||||
|
||||
**Prerequisites:** Node.js
|
||||
|
||||
|
||||
1. Install dependencies:
|
||||
`npm install`
|
||||
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
|
||||
3. Run the app:
|
||||
`npm run dev`
|
||||
240
frontend/REFACTORING_PLAN.md
Normal file
240
frontend/REFACTORING_PLAN.md
Normal file
@@ -0,0 +1,240 @@
|
||||
# Frontend Refactoring Plan
|
||||
|
||||
## Current Structure Issues
|
||||
|
||||
1. **Flat component organization** - All components in one directory
|
||||
2. **Mock data only** - No real API integration
|
||||
3. **No state management** - Props drilling everywhere
|
||||
4. **CDN dependencies** - Should use npm packages
|
||||
5. **Manual routing** - Using useState instead of react-router
|
||||
6. **No TypeScript integration with backend** - Types don't match API schemas
|
||||
|
||||
## Recommended Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── public/
|
||||
│ └── favicon.ico
|
||||
│
|
||||
├── src/
|
||||
│ ├── api/ # API Layer
|
||||
│ │ ├── client.ts # Axios instance + interceptors
|
||||
│ │ ├── types.ts # API request/response types
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # GET /api/v1/admin/documents
|
||||
│ │ ├── annotations.ts # GET/POST /api/v1/admin/documents/{id}/annotations
|
||||
│ │ ├── training.ts # GET/POST /api/v1/admin/training/*
|
||||
│ │ ├── inference.ts # POST /api/v1/infer
|
||||
│ │ └── async.ts # POST /api/v1/async/submit
|
||||
│ │
|
||||
│ ├── components/
|
||||
│ │ ├── common/ # Reusable components
|
||||
│ │ │ ├── Badge.tsx
|
||||
│ │ │ ├── Button.tsx
|
||||
│ │ │ ├── Input.tsx
|
||||
│ │ │ ├── Modal.tsx
|
||||
│ │ │ ├── Table.tsx
|
||||
│ │ │ ├── ProgressBar.tsx
|
||||
│ │ │ └── StatusBadge.tsx
|
||||
│ │ │
|
||||
│ │ ├── layout/ # Layout components
|
||||
│ │ │ ├── TopNav.tsx
|
||||
│ │ │ ├── Sidebar.tsx
|
||||
│ │ │ └── PageHeader.tsx
|
||||
│ │ │
|
||||
│ │ ├── documents/ # Document-specific components
|
||||
│ │ │ ├── DocumentTable.tsx
|
||||
│ │ │ ├── DocumentFilters.tsx
|
||||
│ │ │ ├── DocumentRow.tsx
|
||||
│ │ │ ├── UploadModal.tsx
|
||||
│ │ │ └── BatchUploadModal.tsx
|
||||
│ │ │
|
||||
│ │ ├── annotations/ # Annotation components
|
||||
│ │ │ ├── AnnotationCanvas.tsx
|
||||
│ │ │ ├── AnnotationBox.tsx
|
||||
│ │ │ ├── AnnotationTable.tsx
|
||||
│ │ │ ├── FieldEditor.tsx
|
||||
│ │ │ └── VerificationPanel.tsx
|
||||
│ │ │
|
||||
│ │ └── training/ # Training components
|
||||
│ │ ├── DocumentSelector.tsx
|
||||
│ │ ├── TrainingConfig.tsx
|
||||
│ │ ├── TrainingJobList.tsx
|
||||
│ │ ├── ModelCard.tsx
|
||||
│ │ └── MetricsChart.tsx
|
||||
│ │
|
||||
│ ├── pages/ # Page-level components
|
||||
│ │ ├── DocumentsPage.tsx # Was Dashboard.tsx
|
||||
│ │ ├── DocumentDetailPage.tsx # Was DocumentDetail.tsx
|
||||
│ │ ├── TrainingPage.tsx # Was Training.tsx
|
||||
│ │ ├── ModelsPage.tsx # Was Models.tsx
|
||||
│ │ └── InferencePage.tsx # New: Test inference
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts # Document CRUD + listing
|
||||
│ │ ├── useAnnotations.ts # Annotation management
|
||||
│ │ ├── useTraining.ts # Training jobs
|
||||
│ │ ├── usePolling.ts # Auto-refresh for async jobs
|
||||
│ │ └── useDebounce.ts # Debounce search inputs
|
||||
│ │
|
||||
│ ├── store/ # State Management (Zustand)
|
||||
│ │ ├── documentsStore.ts
|
||||
│ │ ├── annotationsStore.ts
|
||||
│ │ ├── trainingStore.ts
|
||||
│ │ └── uiStore.ts
|
||||
│ │
|
||||
│ ├── types/ # TypeScript Types
|
||||
│ │ ├── index.ts
|
||||
│ │ ├── document.ts
|
||||
│ │ ├── annotation.ts
|
||||
│ │ ├── training.ts
|
||||
│ │ └── api.ts
|
||||
│ │
|
||||
│ ├── utils/ # Utility Functions
|
||||
│ │ ├── formatters.ts # Date, currency, etc.
|
||||
│ │ ├── validators.ts # Form validation
|
||||
│ │ └── constants.ts # Field definitions, statuses
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ ├── main.tsx
|
||||
│ └── router.tsx # React Router config
|
||||
│
|
||||
├── .env.example
|
||||
├── package.json
|
||||
├── tsconfig.json
|
||||
├── vite.config.ts
|
||||
├── tailwind.config.js
|
||||
├── postcss.config.js
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Migration Steps
|
||||
|
||||
### Phase 1: Setup Infrastructure
|
||||
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
|
||||
- [ ] Setup local Tailwind (remove CDN)
|
||||
- [ ] Create API client with interceptors
|
||||
- [ ] Add environment variables (.env.local with VITE_API_URL)
|
||||
|
||||
### Phase 2: Create API Layer
|
||||
- [ ] Create `src/api/client.ts` with axios instance
|
||||
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
|
||||
- [ ] Create `src/api/endpoints/annotations.ts`
|
||||
- [ ] Create `src/api/endpoints/training.ts`
|
||||
- [ ] Add types matching backend schemas
|
||||
|
||||
### Phase 3: Reorganize Components
|
||||
- [ ] Move existing components to new structure
|
||||
- [ ] Split large components (Dashboard > DocumentTable + DocumentFilters + DocumentRow)
|
||||
- [ ] Extract reusable components (Badge, Button already done)
|
||||
- [ ] Create layout components (TopNav, Sidebar)
|
||||
|
||||
### Phase 4: Add Routing
|
||||
- [ ] Install react-router-dom
|
||||
- [ ] Create router.tsx with routes
|
||||
- [ ] Update App.tsx to use RouterProvider
|
||||
- [ ] Add navigation links
|
||||
|
||||
### Phase 5: State Management
|
||||
- [ ] Create custom hooks (useDocuments, useAnnotations)
|
||||
- [ ] Use @tanstack/react-query for server state
|
||||
- [ ] Add Zustand stores for UI state
|
||||
- [ ] Replace mock data with API calls
|
||||
|
||||
### Phase 6: Backend Integration
|
||||
- [ ] Update CORS settings in backend
|
||||
- [ ] Test all API endpoints
|
||||
- [ ] Add error handling
|
||||
- [ ] Add loading states
|
||||
|
||||
## Dependencies to Add
|
||||
|
||||
```json
|
||||
{
|
||||
"dependencies": {
|
||||
"react-router-dom": "^6.22.0",
|
||||
"axios": "^1.6.7",
|
||||
"zustand": "^4.5.0",
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"clsx": "^2.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tailwindcss": "^3.4.1",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"postcss": "^8.4.35"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Files to Create
|
||||
|
||||
### tailwind.config.js
|
||||
```javascript
|
||||
export default {
|
||||
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
warm: {
|
||||
bg: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
border: '#E6E4E1',
|
||||
divider: '#D8D6D2',
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
state: {
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
info: '#3A3A3A',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### .env.example
|
||||
```bash
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
## Type Generation from Backend
|
||||
|
||||
Consider generating TypeScript types from Python Pydantic schemas:
|
||||
- Option 1: Use `datamodel-code-generator` to convert schemas
|
||||
- Option 2: Manually maintain types in `src/types/api.ts`
|
||||
- Option 3: Use OpenAPI spec + openapi-typescript-codegen
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit tests: Vitest for components
|
||||
- Integration tests: React Testing Library
|
||||
- E2E tests: Playwright (matching backend)
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Code splitting by route
|
||||
- Lazy load heavy components (AnnotationCanvas)
|
||||
- Optimize re-renders with React.memo
|
||||
- Use virtual scrolling for large tables
|
||||
- Image lazy loading for document previews
|
||||
|
||||
## Accessibility
|
||||
|
||||
- Proper ARIA labels
|
||||
- Keyboard navigation
|
||||
- Focus management
|
||||
- Color contrast compliance (already done with Warm Graphite theme)
|
||||
256
frontend/SETUP.md
Normal file
256
frontend/SETUP.md
Normal file
@@ -0,0 +1,256 @@
|
||||
# Frontend Setup Guide
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
```
|
||||
|
||||
### 2. Configure Environment
|
||||
|
||||
Copy `.env.example` to `.env.local` and update if needed:
|
||||
|
||||
```bash
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
Default configuration:
|
||||
```
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_WS_URL=ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
### 3. Start Backend API
|
||||
|
||||
Make sure the backend is running first:
|
||||
|
||||
```bash
|
||||
# From project root
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
|
||||
```
|
||||
|
||||
Backend will be available at: http://localhost:8000
|
||||
|
||||
### 4. Start Frontend Dev Server
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Frontend will be available at: http://localhost:3000
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── src/
|
||||
│ ├── api/ # API client layer
|
||||
│ │ ├── client.ts # Axios instance with interceptors
|
||||
│ │ ├── types.ts # API type definitions
|
||||
│ │ └── endpoints/
|
||||
│ │ ├── documents.ts # Document API calls
|
||||
│ │ ├── annotations.ts # Annotation API calls
|
||||
│ │ └── training.ts # Training API calls
|
||||
│ │
|
||||
│ ├── components/ # React components
|
||||
│ │ └── Dashboard.tsx # Updated with real API integration
|
||||
│ │
|
||||
│ ├── hooks/ # Custom React Hooks
|
||||
│ │ ├── useDocuments.ts
|
||||
│ │ ├── useDocumentDetail.ts
|
||||
│ │ ├── useAnnotations.ts
|
||||
│ │ └── useTraining.ts
|
||||
│ │
|
||||
│ ├── styles/
|
||||
│ │ └── index.css # Tailwind CSS entry
|
||||
│ │
|
||||
│ ├── App.tsx
|
||||
│ └── main.tsx # App entry point with QueryClient
|
||||
│
|
||||
├── components/ # Legacy components (to be migrated)
|
||||
│ ├── Badge.tsx
|
||||
│ ├── Button.tsx
|
||||
│ ├── Layout.tsx
|
||||
│ ├── DocumentDetail.tsx
|
||||
│ ├── Training.tsx
|
||||
│ ├── Models.tsx
|
||||
│ └── UploadModal.tsx
|
||||
│
|
||||
├── tailwind.config.js # Tailwind configuration
|
||||
├── postcss.config.js
|
||||
├── vite.config.ts
|
||||
├── package.json
|
||||
└── index.html
|
||||
```
|
||||
|
||||
## Key Technologies
|
||||
|
||||
- **React 19** - UI framework
|
||||
- **TypeScript** - Type safety
|
||||
- **Vite** - Build tool
|
||||
- **Tailwind CSS** - Styling (Warm Graphite theme)
|
||||
- **Axios** - HTTP client
|
||||
- **@tanstack/react-query** - Server state management
|
||||
- **lucide-react** - Icon library
|
||||
|
||||
## API Integration
|
||||
|
||||
### Authentication
|
||||
|
||||
The app stores admin token in localStorage:
|
||||
|
||||
```typescript
|
||||
localStorage.setItem('admin_token', 'your-token')
|
||||
```
|
||||
|
||||
All API requests automatically include the `X-Admin-Token` header.
|
||||
|
||||
### Available Hooks
|
||||
|
||||
#### useDocuments
|
||||
|
||||
```typescript
|
||||
const {
|
||||
documents,
|
||||
total,
|
||||
isLoading,
|
||||
uploadDocument,
|
||||
deleteDocument,
|
||||
triggerAutoLabel,
|
||||
} = useDocuments({ status: 'labeled', limit: 20 })
|
||||
```
|
||||
|
||||
#### useDocumentDetail
|
||||
|
||||
```typescript
|
||||
const { document, annotations, isLoading } = useDocumentDetail(documentId)
|
||||
```
|
||||
|
||||
#### useAnnotations
|
||||
|
||||
```typescript
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
verifyAnnotation,
|
||||
overrideAnnotation,
|
||||
} = useAnnotations(documentId)
|
||||
```
|
||||
|
||||
#### useTraining
|
||||
|
||||
```typescript
|
||||
const {
|
||||
models,
|
||||
isLoadingModels,
|
||||
startTraining,
|
||||
downloadModel,
|
||||
} = useTraining()
|
||||
```
|
||||
|
||||
## Features Implemented
|
||||
|
||||
### Phase 1 (Completed)
|
||||
- ✅ API client with axios interceptors
|
||||
- ✅ Type-safe API endpoints
|
||||
- ✅ React Query for server state
|
||||
- ✅ Custom hooks for all APIs
|
||||
- ✅ Dashboard with real data
|
||||
- ✅ Local Tailwind CSS
|
||||
- ✅ Environment configuration
|
||||
- ✅ CORS configured in backend
|
||||
|
||||
### Phase 2 (TODO)
|
||||
- [ ] Update DocumentDetail to use useDocumentDetail
|
||||
- [ ] Update Training page to use useTraining hooks
|
||||
- [ ] Update Models page with real data
|
||||
- [ ] Add UploadModal integration with API
|
||||
- [ ] Add react-router for proper routing
|
||||
- [ ] Add error boundary
|
||||
- [ ] Add loading states
|
||||
- [ ] Add toast notifications
|
||||
|
||||
### Phase 3 (TODO)
|
||||
- [ ] Annotation canvas with real data
|
||||
- [ ] Batch upload functionality
|
||||
- [ ] Auto-label progress polling
|
||||
- [ ] Training job monitoring
|
||||
- [ ] Model download functionality
|
||||
- [ ] Search and filtering
|
||||
- [ ] Pagination
|
||||
|
||||
## Development Tips
|
||||
|
||||
### Hot Module Replacement
|
||||
|
||||
Vite supports HMR. Changes will reflect immediately without page reload.
|
||||
|
||||
### API Debugging
|
||||
|
||||
Check browser console for API requests:
|
||||
- Network tab shows all requests/responses
|
||||
- Axios interceptors log errors automatically
|
||||
|
||||
### Type Safety
|
||||
|
||||
TypeScript types in `src/api/types.ts` match backend Pydantic schemas.
|
||||
|
||||
To regenerate types from backend:
|
||||
```bash
|
||||
# TODO: Add type generation script
|
||||
```
|
||||
|
||||
### Backend API Documentation
|
||||
|
||||
Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CORS Errors
|
||||
|
||||
If you see CORS errors:
|
||||
1. Check backend is running at http://localhost:8000
|
||||
2. Verify CORS settings in `src/web/app.py`
|
||||
3. Check `.env.local` has correct `VITE_API_URL`
|
||||
|
||||
### Module Not Found
|
||||
|
||||
If imports fail:
|
||||
```bash
|
||||
rm -rf node_modules package-lock.json
|
||||
npm install
|
||||
```
|
||||
|
||||
### Types Not Matching
|
||||
|
||||
If API responses don't match types:
|
||||
1. Check backend version is up-to-date
|
||||
2. Verify types in `src/api/types.ts`
|
||||
3. Check API response in Network tab
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Run `npm install` to install dependencies
|
||||
2. Start backend server
|
||||
3. Run `npm run dev` to start frontend
|
||||
4. Open http://localhost:3000
|
||||
5. Create an admin token via backend API
|
||||
6. Store token in localStorage via browser console:
|
||||
```javascript
|
||||
localStorage.setItem('admin_token', 'your-token-here')
|
||||
```
|
||||
7. Refresh page to see authenticated API calls
|
||||
|
||||
## Production Build
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
npm run preview # Preview production build
|
||||
```
|
||||
|
||||
Build output will be in `dist/` directory.
|
||||
15
frontend/index.html
Normal file
15
frontend/index.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Graphite Annotator - Invoice Field Extraction</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
5
frontend/metadata.json
Normal file
5
frontend/metadata.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Graphite Annotator",
|
||||
"description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
|
||||
"requestFramePermissions": []
|
||||
}
|
||||
4899
frontend/package-lock.json
generated
Normal file
4899
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
41
frontend/package.json
Normal file
41
frontend/package.json
Normal file
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"name": "graphite-annotator",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "vitest run --coverage"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tanstack/react-query": "^5.20.0",
|
||||
"axios": "^1.6.7",
|
||||
"clsx": "^2.1.0",
|
||||
"date-fns": "^3.3.0",
|
||||
"lucide-react": "^0.563.0",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"react-router-dom": "^6.22.0",
|
||||
"recharts": "^3.7.0",
|
||||
"zustand": "^4.5.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.2",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/node": "^22.14.0",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"@vitest/coverage-v8": "^4.0.18",
|
||||
"autoprefixer": "^10.4.17",
|
||||
"jsdom": "^27.4.0",
|
||||
"postcss": "^8.4.35",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "~5.8.2",
|
||||
"vite": "^6.2.0",
|
||||
"vitest": "^4.0.18"
|
||||
}
|
||||
}
|
||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
81
frontend/src/App.tsx
Normal file
81
frontend/src/App.tsx
Normal file
@@ -0,0 +1,81 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { Layout } from './components/Layout'
|
||||
import { DashboardOverview } from './components/DashboardOverview'
|
||||
import { Dashboard } from './components/Dashboard'
|
||||
import { DocumentDetail } from './components/DocumentDetail'
|
||||
import { Training } from './components/Training'
|
||||
import { DatasetDetail } from './components/DatasetDetail'
|
||||
import { Models } from './components/Models'
|
||||
import { Login } from './components/Login'
|
||||
import { InferenceDemo } from './components/InferenceDemo'
|
||||
|
||||
const App: React.FC = () => {
|
||||
const [currentView, setCurrentView] = useState('dashboard')
|
||||
const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
setIsAuthenticated(!!token)
|
||||
}, [])
|
||||
|
||||
const handleNavigate = (view: string, docId?: string) => {
|
||||
setCurrentView(view)
|
||||
if (docId) {
|
||||
setSelectedDocId(docId)
|
||||
}
|
||||
}
|
||||
|
||||
const handleLogin = (token: string) => {
|
||||
setIsAuthenticated(true)
|
||||
}
|
||||
|
||||
const handleLogout = () => {
|
||||
localStorage.removeItem('admin_token')
|
||||
setIsAuthenticated(false)
|
||||
setCurrentView('documents')
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return <Login onLogin={handleLogin} />
|
||||
}
|
||||
|
||||
const renderContent = () => {
|
||||
switch (currentView) {
|
||||
case 'dashboard':
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
case 'documents':
|
||||
return <Dashboard onNavigate={handleNavigate} />
|
||||
case 'detail':
|
||||
return (
|
||||
<DocumentDetail
|
||||
docId={selectedDocId || '1'}
|
||||
onBack={() => setCurrentView('documents')}
|
||||
/>
|
||||
)
|
||||
case 'demo':
|
||||
return <InferenceDemo />
|
||||
case 'training':
|
||||
return <Training onNavigate={handleNavigate} />
|
||||
case 'dataset-detail':
|
||||
return (
|
||||
<DatasetDetail
|
||||
datasetId={selectedDocId || ''}
|
||||
onBack={() => setCurrentView('training')}
|
||||
/>
|
||||
)
|
||||
case 'models':
|
||||
return <Models />
|
||||
default:
|
||||
return <DashboardOverview onNavigate={handleNavigate} />
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
|
||||
{renderContent()}
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
|
||||
export default App
|
||||
41
frontend/src/api/client.ts
Normal file
41
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import axios, { AxiosInstance, AxiosError } from 'axios'
|
||||
|
||||
const apiClient: AxiosInstance = axios.create({
|
||||
baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 30000,
|
||||
})
|
||||
|
||||
apiClient.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = localStorage.getItem('admin_token')
|
||||
if (token) {
|
||||
config.headers['X-Admin-Token'] = token
|
||||
}
|
||||
return config
|
||||
},
|
||||
(error) => {
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error: AxiosError) => {
|
||||
if (error.response?.status === 401) {
|
||||
console.warn('Authentication required. Please set admin_token in localStorage.')
|
||||
// Don't redirect to avoid infinite loop
|
||||
// User should manually set: localStorage.setItem('admin_token', 'your-token')
|
||||
}
|
||||
|
||||
if (error.response?.status === 429) {
|
||||
console.error('Rate limit exceeded')
|
||||
}
|
||||
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
export default apiClient
|
||||
66
frontend/src/api/endpoints/annotations.ts
Normal file
66
frontend/src/api/endpoints/annotations.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
AnnotationItem,
|
||||
CreateAnnotationRequest,
|
||||
AnnotationOverrideRequest,
|
||||
} from '../types'
|
||||
|
||||
export const annotationsApi = {
|
||||
list: async (documentId: string): Promise<AnnotationItem[]> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`
|
||||
)
|
||||
return data.annotations
|
||||
},
|
||||
|
||||
create: async (
|
||||
documentId: string,
|
||||
annotation: CreateAnnotationRequest
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations`,
|
||||
annotation
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
update: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
): Promise<AnnotationItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
|
||||
updates
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string, annotationId: string): Promise<void> => {
|
||||
await apiClient.delete(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
|
||||
)
|
||||
},
|
||||
|
||||
verify: async (
|
||||
documentId: string,
|
||||
annotationId: string
|
||||
): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
override: async (
|
||||
documentId: string,
|
||||
annotationId: string,
|
||||
overrideData: AnnotationOverrideRequest
|
||||
): Promise<{ annotation_id: string; source: string; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
|
||||
overrideData
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
118
frontend/src/api/endpoints/augmentation.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
/**
|
||||
* Tests for augmentation API endpoints.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { augmentationApi } from './augmentation'
|
||||
import apiClient from '../client'
|
||||
|
||||
// Mock the API client
|
||||
vi.mock('../client', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('augmentationApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getTypes', () => {
|
||||
it('should fetch augmentation types', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getTypes()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
|
||||
expect(result.augmentation_types).toHaveLength(1)
|
||||
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPresets', () => {
|
||||
it('should fetch augmentation presets', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations' },
|
||||
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||
],
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.getPresets()
|
||||
|
||||
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
|
||||
expect(result.presets).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview', () => {
|
||||
it('should preview single augmentation', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: '',
|
||||
original_url: '',
|
||||
applied_params: { std: 15 },
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const result = await augmentationApi.preview('doc-123', {
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
})
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
{
|
||||
augmentation_type: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
},
|
||||
{ params: { page: 1 } }
|
||||
)
|
||||
expect(result.preview_url).toBe('')
|
||||
})
|
||||
|
||||
it('should support custom page number', async () => {
|
||||
const mockResponse = {
|
||||
data: {
|
||||
preview_url: '',
|
||||
original_url: '',
|
||||
applied_params: {},
|
||||
},
|
||||
}
|
||||
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
await augmentationApi.preview(
|
||||
'doc-123',
|
||||
{ augmentation_type: 'gaussian_noise', params: {} },
|
||||
2
|
||||
)
|
||||
|
||||
expect(apiClient.post).toHaveBeenCalledWith(
|
||||
'/api/v1/admin/augmentation/preview/doc-123',
|
||||
expect.anything(),
|
||||
{ params: { page: 2 } }
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
144
frontend/src/api/endpoints/augmentation.ts
Normal file
144
frontend/src/api/endpoints/augmentation.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Augmentation API endpoints.
|
||||
*
|
||||
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||
*/
|
||||
|
||||
import apiClient from '../client'
|
||||
|
||||
// Types
|
||||
export interface AugmentationTypeInfo {
|
||||
name: string
|
||||
description: string
|
||||
affects_geometry: boolean
|
||||
stage: string
|
||||
default_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTypesResponse {
|
||||
augmentation_types: AugmentationTypeInfo[]
|
||||
}
|
||||
|
||||
export interface PresetInfo {
|
||||
name: string
|
||||
description: string
|
||||
config?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PresetsResponse {
|
||||
presets: PresetInfo[]
|
||||
}
|
||||
|
||||
export interface PreviewRequest {
|
||||
augmentation_type: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface PreviewResponse {
|
||||
preview_url: string
|
||||
original_url: string
|
||||
applied_params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationConfig {
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
gaussian_noise?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface BatchRequest {
|
||||
dataset_id: string
|
||||
config: AugmentationConfig
|
||||
output_name: string
|
||||
multiplier: number
|
||||
}
|
||||
|
||||
export interface BatchResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
estimated_images: number
|
||||
}
|
||||
|
||||
// API functions
|
||||
export const augmentationApi = {
|
||||
/**
|
||||
* Fetch available augmentation types.
|
||||
*/
|
||||
async getTypes(): Promise<AugmentationTypesResponse> {
|
||||
const response = await apiClient.get<AugmentationTypesResponse>(
|
||||
'/api/v1/admin/augmentation/types'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Fetch augmentation presets.
|
||||
*/
|
||||
async getPresets(): Promise<PresetsResponse> {
|
||||
const response = await apiClient.get<PresetsResponse>(
|
||||
'/api/v1/admin/augmentation/presets'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview a single augmentation on a document page.
|
||||
*/
|
||||
async preview(
|
||||
documentId: string,
|
||||
request: PreviewRequest,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview/${documentId}`,
|
||||
request,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Preview full augmentation config on a document page.
|
||||
*/
|
||||
async previewConfig(
|
||||
documentId: string,
|
||||
config: AugmentationConfig,
|
||||
page: number = 1
|
||||
): Promise<PreviewResponse> {
|
||||
const response = await apiClient.post<PreviewResponse>(
|
||||
`/api/v1/admin/augmentation/preview-config/${documentId}`,
|
||||
config,
|
||||
{ params: { page } }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* Create an augmented dataset.
|
||||
*/
|
||||
async createBatch(request: BatchRequest): Promise<BatchResponse> {
|
||||
const response = await apiClient.post<BatchResponse>(
|
||||
'/api/v1/admin/augmentation/batch',
|
||||
request
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
52
frontend/src/api/endpoints/datasets.ts
Normal file
52
frontend/src/api/endpoints/datasets.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetListResponse,
|
||||
DatasetResponse,
|
||||
DatasetTrainRequest,
|
||||
TrainingTaskResponse,
|
||||
} from '../types'
|
||||
|
||||
export const datasetsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DatasetListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
remove: async (datasetId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(
|
||||
`/api/v1/admin/training/datasets/${datasetId}`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
trainFromDataset: async (
|
||||
datasetId: string,
|
||||
req: DatasetTrainRequest
|
||||
): Promise<TrainingTaskResponse> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/datasets/${datasetId}/train`,
|
||||
req
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
122
frontend/src/api/endpoints/documents.ts
Normal file
122
frontend/src/api/endpoints/documents.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
DocumentListResponse,
|
||||
DocumentDetailResponse,
|
||||
DocumentItem,
|
||||
UploadDocumentResponse,
|
||||
DocumentCategoriesResponse,
|
||||
} from '../types'
|
||||
|
||||
export const documentsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
category?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/documents', { params })
|
||||
return data
|
||||
},
|
||||
|
||||
getCategories: async (): Promise<DocumentCategoriesResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/documents/categories')
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
upload: async (
|
||||
file: File,
|
||||
options?: { groupKey?: string; category?: string }
|
||||
): Promise<UploadDocumentResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const params: Record<string, string> = {}
|
||||
if (options?.groupKey) {
|
||||
params.group_key = options.groupKey
|
||||
}
|
||||
if (options?.category) {
|
||||
params.category = options.category
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
batchUpload: async (
|
||||
files: File[],
|
||||
csvFile?: File
|
||||
): Promise<{ batch_id: string; message: string; documents_created: number }> => {
|
||||
const formData = new FormData()
|
||||
|
||||
files.forEach((file) => {
|
||||
formData.append('files', file)
|
||||
})
|
||||
|
||||
if (csvFile) {
|
||||
formData.append('csv_file', csvFile)
|
||||
}
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (documentId: string): Promise<void> => {
|
||||
await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
|
||||
},
|
||||
|
||||
updateStatus: async (
|
||||
documentId: string,
|
||||
status: string
|
||||
): Promise<DocumentItem> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/status`,
|
||||
null,
|
||||
{ params: { status } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/documents/${documentId}/auto-label`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
updateGroupKey: async (
|
||||
documentId: string,
|
||||
groupKey: string | null
|
||||
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/group-key`,
|
||||
null,
|
||||
{ params: { group_key: groupKey } }
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
updateCategory: async (
|
||||
documentId: string,
|
||||
category: string
|
||||
): Promise<{ status: string; document_id: string; category: string; message: string }> => {
|
||||
const { data } = await apiClient.patch(
|
||||
`/api/v1/admin/documents/${documentId}/category`,
|
||||
{ category }
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
7
frontend/src/api/endpoints/index.ts
Normal file
7
frontend/src/api/endpoints/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export { documentsApi } from './documents'
|
||||
export { annotationsApi } from './annotations'
|
||||
export { trainingApi } from './training'
|
||||
export { inferenceApi } from './inference'
|
||||
export { datasetsApi } from './datasets'
|
||||
export { augmentationApi } from './augmentation'
|
||||
export { modelsApi } from './models'
|
||||
16
frontend/src/api/endpoints/inference.ts
Normal file
16
frontend/src/api/endpoints/inference.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import apiClient from '../client'
|
||||
import type { InferenceResponse } from '../types'
|
||||
|
||||
export const inferenceApi = {
|
||||
processDocument: async (file: File): Promise<InferenceResponse> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
||||
const { data } = await apiClient.post('/api/v1/infer', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
})
|
||||
return data
|
||||
},
|
||||
}
|
||||
55
frontend/src/api/endpoints/models.ts
Normal file
55
frontend/src/api/endpoints/models.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import apiClient from '../client'
|
||||
import type {
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ModelVersionResponse,
|
||||
ActiveModelResponse,
|
||||
} from '../types'
|
||||
|
||||
export const modelsApi = {
|
||||
list: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<ModelVersionListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
getActive: async (): Promise<ActiveModelResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
|
||||
return data
|
||||
},
|
||||
|
||||
activate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
|
||||
return data
|
||||
},
|
||||
|
||||
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
|
||||
return data
|
||||
},
|
||||
|
||||
archive: async (versionId: string): Promise<ModelVersionResponse> => {
|
||||
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
|
||||
return data
|
||||
},
|
||||
|
||||
delete: async (versionId: string): Promise<{ message: string }> => {
|
||||
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
|
||||
return data
|
||||
},
|
||||
}
|
||||
74
frontend/src/api/endpoints/training.ts
Normal file
74
frontend/src/api/endpoints/training.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import apiClient from '../client'
|
||||
import type { TrainingModelsResponse, DocumentListResponse } from '../types'
|
||||
|
||||
export const trainingApi = {
|
||||
getDocumentsForTraining: async (params?: {
|
||||
has_annotations?: boolean
|
||||
min_annotation_count?: number
|
||||
exclude_used_in_training?: boolean
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<DocumentListResponse> => {
|
||||
const { data } = await apiClient.get('/api/v1/admin/training/documents', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getModels: async (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<TrainingModelsResponse> => {
|
||||
const { data} = await apiClient.get('/api/v1/admin/training/models', {
|
||||
params,
|
||||
})
|
||||
return data
|
||||
},
|
||||
|
||||
getTaskDetail: async (taskId: string) => {
|
||||
const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
|
||||
return data
|
||||
},
|
||||
|
||||
startTraining: async (config: {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
model_base?: string
|
||||
}) => {
|
||||
// Convert frontend config to backend TrainingTaskCreate format
|
||||
const taskRequest = {
|
||||
name: config.name,
|
||||
task_type: 'yolo',
|
||||
description: config.description,
|
||||
config: {
|
||||
document_ids: config.document_ids,
|
||||
epochs: config.epochs,
|
||||
batch_size: config.batch_size,
|
||||
base_model: config.model_base,
|
||||
},
|
||||
}
|
||||
const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
|
||||
return data
|
||||
},
|
||||
|
||||
cancelTask: async (taskId: string) => {
|
||||
const { data } = await apiClient.post(
|
||||
`/api/v1/admin/training/tasks/${taskId}/cancel`
|
||||
)
|
||||
return data
|
||||
},
|
||||
|
||||
downloadModel: async (taskId: string): Promise<Blob> => {
|
||||
const { data } = await apiClient.get(
|
||||
`/api/v1/admin/training/models/${taskId}/download`,
|
||||
{
|
||||
responseType: 'blob',
|
||||
}
|
||||
)
|
||||
return data
|
||||
},
|
||||
}
|
||||
364
frontend/src/api/types.ts
Normal file
364
frontend/src/api/types.ts
Normal file
@@ -0,0 +1,364 @@
|
||||
export interface DocumentItem {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
group_key: string | null
|
||||
category: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
annotation_count?: number
|
||||
annotation_sources?: {
|
||||
manual: number
|
||||
auto: number
|
||||
verified: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface DocumentListResponse {
|
||||
documents: DocumentItem[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface AnnotationItem {
|
||||
annotation_id: string
|
||||
page_number: number
|
||||
class_id: number
|
||||
class_name: string
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
normalized_bbox: {
|
||||
x_center: number
|
||||
y_center: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value: string | null
|
||||
confidence: number | null
|
||||
source: 'manual' | 'auto'
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DocumentDetailResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
content_type: string
|
||||
page_count: number
|
||||
status: 'pending' | 'labeled' | 'verified' | 'exported'
|
||||
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
|
||||
auto_label_error: string | null
|
||||
upload_source: string
|
||||
batch_id: string | null
|
||||
group_key: string | null
|
||||
category: string
|
||||
csv_field_values: Record<string, string> | null
|
||||
can_annotate: boolean
|
||||
annotation_lock_until: string | null
|
||||
annotations: AnnotationItem[]
|
||||
image_urls: string[]
|
||||
training_history: Array<{
|
||||
task_id: string
|
||||
name: string
|
||||
trained_at: string
|
||||
model_metrics: {
|
||||
mAP: number | null
|
||||
precision: number | null
|
||||
recall: number | null
|
||||
} | null
|
||||
}>
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface TrainingTask {
|
||||
task_id: string
|
||||
admin_token: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: 'pending' | 'running' | 'completed' | 'failed'
|
||||
task_type: string
|
||||
config: Record<string, unknown>
|
||||
started_at: string | null
|
||||
completed_at: string | null
|
||||
error_message: string | null
|
||||
result_metrics: Record<string, unknown>
|
||||
model_path: string | null
|
||||
document_count: number
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionItem {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
metrics_mAP: number | null
|
||||
document_count: number
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface TrainingModelsResponse {
|
||||
models: ModelVersionItem[]
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
}
|
||||
|
||||
export interface ErrorResponse {
|
||||
detail: string
|
||||
}
|
||||
|
||||
export interface UploadDocumentResponse {
|
||||
document_id: string
|
||||
filename: string
|
||||
file_size: number
|
||||
page_count: number
|
||||
status: string
|
||||
category: string
|
||||
group_key: string | null
|
||||
auto_label_started: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface DocumentCategoriesResponse {
|
||||
categories: string[]
|
||||
total: number
|
||||
}
|
||||
|
||||
export interface CreateAnnotationRequest {
|
||||
page_number: number
|
||||
class_id: number
|
||||
bbox: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
text_value?: string
|
||||
}
|
||||
|
||||
export interface AnnotationOverrideRequest {
|
||||
text_value?: string
|
||||
bbox?: {
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
}
|
||||
class_id?: number
|
||||
class_name?: string
|
||||
reason?: string
|
||||
}
|
||||
|
||||
export interface CrossValidationResult {
|
||||
is_valid: boolean
|
||||
payment_line_ocr: string | null
|
||||
payment_line_amount: string | null
|
||||
payment_line_account: string | null
|
||||
payment_line_account_type: 'bankgiro' | 'plusgiro' | null
|
||||
ocr_match: boolean | null
|
||||
amount_match: boolean | null
|
||||
bankgiro_match: boolean | null
|
||||
plusgiro_match: boolean | null
|
||||
details: string[]
|
||||
}
|
||||
|
||||
export interface InferenceResult {
|
||||
document_id: string
|
||||
document_type: string
|
||||
success: boolean
|
||||
fields: Record<string, string>
|
||||
confidence: Record<string, number>
|
||||
cross_validation: CrossValidationResult | null
|
||||
processing_time_ms: number
|
||||
visualization_url: string | null
|
||||
errors: string[]
|
||||
fallback_used: boolean
|
||||
}
|
||||
|
||||
export interface InferenceResponse {
|
||||
result: InferenceResult
|
||||
}
|
||||
|
||||
// Dataset types
|
||||
|
||||
export interface DatasetCreateRequest {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
train_ratio?: number
|
||||
val_ratio?: number
|
||||
seed?: number
|
||||
}
|
||||
|
||||
export interface DatasetResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface DatasetDocumentItem {
|
||||
document_id: string
|
||||
split: string
|
||||
page_count: number
|
||||
annotation_count: number
|
||||
}
|
||||
|
||||
export interface DatasetListItem {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
training_status: string | null
|
||||
active_training_task_id: string | null
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface DatasetListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
datasets: DatasetListItem[]
|
||||
}
|
||||
|
||||
export interface DatasetDetailResponse {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description: string | null
|
||||
status: string
|
||||
training_status: string | null
|
||||
active_training_task_id: string | null
|
||||
train_ratio: number
|
||||
val_ratio: number
|
||||
seed: number
|
||||
total_documents: number
|
||||
total_images: number
|
||||
total_annotations: number
|
||||
dataset_path: string | null
|
||||
error_message: string | null
|
||||
documents: DatasetDocumentItem[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface AugmentationParams {
|
||||
enabled: boolean
|
||||
probability: number
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface AugmentationTrainingConfig {
|
||||
gaussian_noise?: AugmentationParams
|
||||
perspective_warp?: AugmentationParams
|
||||
wrinkle?: AugmentationParams
|
||||
edge_damage?: AugmentationParams
|
||||
stain?: AugmentationParams
|
||||
lighting_variation?: AugmentationParams
|
||||
shadow?: AugmentationParams
|
||||
gaussian_blur?: AugmentationParams
|
||||
motion_blur?: AugmentationParams
|
||||
salt_pepper?: AugmentationParams
|
||||
paper_texture?: AugmentationParams
|
||||
scanner_artifacts?: AugmentationParams
|
||||
preserve_bboxes?: boolean
|
||||
seed?: number | null
|
||||
}
|
||||
|
||||
export interface DatasetTrainRequest {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
image_size?: number
|
||||
learning_rate?: number
|
||||
device?: string
|
||||
augmentation?: AugmentationTrainingConfig
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface TrainingTaskResponse {
|
||||
task_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
// Model Version types
|
||||
|
||||
export interface ModelVersionItem {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
metrics_mAP: number | null
|
||||
document_count: number
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionDetailResponse {
|
||||
version_id: string
|
||||
version: string
|
||||
name: string
|
||||
description: string | null
|
||||
model_path: string
|
||||
status: string
|
||||
is_active: boolean
|
||||
task_id: string | null
|
||||
dataset_id: string | null
|
||||
metrics_mAP: number | null
|
||||
metrics_precision: number | null
|
||||
metrics_recall: number | null
|
||||
document_count: number
|
||||
training_config: Record<string, unknown> | null
|
||||
file_size: number | null
|
||||
trained_at: string | null
|
||||
activated_at: string | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelVersionListResponse {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
models: ModelVersionItem[]
|
||||
}
|
||||
|
||||
export interface ModelVersionResponse {
|
||||
version_id: string
|
||||
status: string
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface ActiveModelResponse {
|
||||
has_active_model: boolean
|
||||
model: ModelVersionItem | null
|
||||
}
|
||||
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
251
frontend/src/components/AugmentationConfig.test.tsx
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* Tests for AugmentationConfig component.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||
import type { ReactNode } from 'react'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('../api/endpoints/augmentation', () => ({
|
||||
augmentationApi: {
|
||||
getTypes: vi.fn(),
|
||||
getPresets: vi.fn(),
|
||||
preview: vi.fn(),
|
||||
previewConfig: vi.fn(),
|
||||
createBatch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Default mock data
|
||||
const mockTypes = {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise to simulate sensor noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
{
|
||||
name: 'perspective_warp',
|
||||
description: 'Applies perspective transformation',
|
||||
affects_geometry: true,
|
||||
stage: 'geometric',
|
||||
default_params: { max_warp: 0.02 },
|
||||
},
|
||||
{
|
||||
name: 'gaussian_blur',
|
||||
description: 'Applies Gaussian blur',
|
||||
affects_geometry: false,
|
||||
stage: 'blur',
|
||||
default_params: { kernel_size: 5 },
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const mockPresets = {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
|
||||
{ name: 'moderate', description: 'Balanced augmentation settings' },
|
||||
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
|
||||
],
|
||||
}
|
||||
|
||||
// Test wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('AugmentationConfig', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render enable checkbox', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should be collapsed when disabled', () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
// Config options should not be visible
|
||||
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should expand when enabled', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/preset/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('preset selection', () => {
|
||||
it('should display available presets', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('conservative')).toBeInTheDocument()
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
expect(screen.getByText('aggressive')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onConfigChange when preset is selected', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfigChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={onConfigChange}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('moderate')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('moderate'))
|
||||
|
||||
expect(onConfigChange).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('enable toggle', () => {
|
||||
it('should call onEnabledChange when checkbox is toggled', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onEnabledChange = vi.fn()
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={false}
|
||||
onEnabledChange={onEnabledChange}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
|
||||
|
||||
expect(onEnabledChange).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('augmentation types', () => {
|
||||
it('should display augmentation types when in custom mode', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should indicate which augmentations affect geometry', async () => {
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
showCustomOptions={true}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
// perspective_warp affects geometry
|
||||
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
|
||||
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('loading state', () => {
|
||||
it('should show loading indicator while fetching types', () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
)
|
||||
|
||||
render(
|
||||
<AugmentationConfig
|
||||
enabled={true}
|
||||
onEnabledChange={vi.fn()}
|
||||
config={{}}
|
||||
onConfigChange={vi.fn()}
|
||||
/>,
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
136
frontend/src/components/AugmentationConfig.tsx
Normal file
136
frontend/src/components/AugmentationConfig.tsx
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* AugmentationConfig component for configuring image augmentation during training.
|
||||
*
|
||||
* Provides preset selection and optional custom augmentation type configuration.
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { Loader2, AlertTriangle } from 'lucide-react'
|
||||
import { useAugmentation } from '../hooks/useAugmentation'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
interface AugmentationConfigProps {
|
||||
enabled: boolean
|
||||
onEnabledChange: (enabled: boolean) => void
|
||||
config: Partial<AugmentationConfigType>
|
||||
onConfigChange: (config: Partial<AugmentationConfigType>) => void
|
||||
showCustomOptions?: boolean
|
||||
}
|
||||
|
||||
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
|
||||
enabled,
|
||||
onEnabledChange,
|
||||
config,
|
||||
onConfigChange,
|
||||
showCustomOptions = false,
|
||||
}) => {
|
||||
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
|
||||
|
||||
const isLoading = isLoadingTypes || isLoadingPresets
|
||||
|
||||
const handlePresetSelect = (presetName: string) => {
|
||||
const preset = presets.find((p) => p.name === presetName)
|
||||
if (preset && preset.config) {
|
||||
onConfigChange(preset.config as Partial<AugmentationConfigType>)
|
||||
} else {
|
||||
// Apply a basic config based on preset name
|
||||
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
|
||||
conservative: {
|
||||
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
|
||||
},
|
||||
moderate: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
|
||||
lighting_variation: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
|
||||
},
|
||||
aggressive: {
|
||||
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
|
||||
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
|
||||
motion_blur: { enabled: true, probability: 0.3, params: {} },
|
||||
lighting_variation: { enabled: true, probability: 0.5, params: {} },
|
||||
shadow: { enabled: true, probability: 0.3, params: {} },
|
||||
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
|
||||
wrinkle: { enabled: true, probability: 0.2, params: {} },
|
||||
stain: { enabled: true, probability: 0.2, params: {} },
|
||||
},
|
||||
}
|
||||
onConfigChange(presetConfigs[presetName] || {})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
|
||||
{/* Enable checkbox */}
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
onChange={(e) => onEnabledChange(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
|
||||
aria-label="Enable augmentation"
|
||||
/>
|
||||
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
|
||||
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
|
||||
</label>
|
||||
|
||||
{/* Expanded content when enabled */}
|
||||
{enabled && (
|
||||
<div className="mt-4 space-y-4">
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
|
||||
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
|
||||
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* Preset selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{presets.map((preset) => (
|
||||
<button
|
||||
key={preset.name}
|
||||
onClick={() => handlePresetSelect(preset.name)}
|
||||
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
|
||||
title={preset.description}
|
||||
>
|
||||
{preset.name}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Custom options (if enabled) */}
|
||||
{showCustomOptions && (
|
||||
<div className="border-t border-warm-divider pt-4">
|
||||
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
|
||||
<div className="grid gap-2">
|
||||
{augmentationTypes.map((type) => (
|
||||
<div
|
||||
key={type.name}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-warm-text-primary">{type.name}</span>
|
||||
{type.affects_geometry && (
|
||||
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
|
||||
<AlertTriangle size={12} />
|
||||
affects bbox
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs text-warm-text-muted">{type.stage}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
32
frontend/src/components/Badge.test.tsx
Normal file
32
frontend/src/components/Badge.test.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { Badge } from './Badge';
|
||||
import { DocumentStatus } from '../types';
|
||||
|
||||
describe('Badge', () => {
|
||||
it('renders Exported badge with check icon', () => {
|
||||
render(<Badge status="Exported" />);
|
||||
expect(screen.getByText('Exported')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Pending status', () => {
|
||||
render(<Badge status={DocumentStatus.PENDING} />);
|
||||
expect(screen.getByText('Pending')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Verified status', () => {
|
||||
render(<Badge status={DocumentStatus.VERIFIED} />);
|
||||
expect(screen.getByText('Verified')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Labeled status', () => {
|
||||
render(<Badge status={DocumentStatus.LABELED} />);
|
||||
expect(screen.getByText('Labeled')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders Partial status with warning indicator', () => {
|
||||
render(<Badge status={DocumentStatus.PARTIAL} />);
|
||||
expect(screen.getByText('Partial')).toBeInTheDocument();
|
||||
expect(screen.getByText('!')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
39
frontend/src/components/Badge.tsx
Normal file
39
frontend/src/components/Badge.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import React from 'react';
|
||||
import { DocumentStatus } from '../types';
|
||||
import { Check } from 'lucide-react';
|
||||
|
||||
interface BadgeProps {
|
||||
status: DocumentStatus | 'Exported';
|
||||
}
|
||||
|
||||
export const Badge: React.FC<BadgeProps> = ({ status }) => {
|
||||
if (status === 'Exported') {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
|
||||
<Check size={12} strokeWidth={3} />
|
||||
Exported
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
const styles = {
|
||||
[DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
|
||||
[DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
|
||||
[DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
|
||||
[DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
|
||||
};
|
||||
|
||||
const icons = {
|
||||
[DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
|
||||
[DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
|
||||
[DocumentStatus.PENDING]: null,
|
||||
[DocumentStatus.LABELED]: null,
|
||||
}
|
||||
|
||||
return (
|
||||
<span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
|
||||
{icons[status]}
|
||||
{status}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
38
frontend/src/components/Button.test.tsx
Normal file
38
frontend/src/components/Button.test.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { Button } from './Button';
|
||||
|
||||
describe('Button', () => {
|
||||
it('renders children text', () => {
|
||||
render(<Button>Click me</Button>);
|
||||
expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onClick handler', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onClick = vi.fn();
|
||||
render(<Button onClick={onClick}>Click</Button>);
|
||||
await user.click(screen.getByRole('button'));
|
||||
expect(onClick).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('is disabled when disabled prop is set', () => {
|
||||
render(<Button disabled>Disabled</Button>);
|
||||
expect(screen.getByRole('button')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('applies variant styles', () => {
|
||||
const { rerender } = render(<Button variant="primary">Primary</Button>);
|
||||
const btn = screen.getByRole('button');
|
||||
expect(btn.className).toContain('bg-warm-text-secondary');
|
||||
|
||||
rerender(<Button variant="secondary">Secondary</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('border');
|
||||
});
|
||||
|
||||
it('applies size styles', () => {
|
||||
render(<Button size="sm">Small</Button>);
|
||||
expect(screen.getByRole('button').className).toContain('h-8');
|
||||
});
|
||||
});
|
||||
38
frontend/src/components/Button.tsx
Normal file
38
frontend/src/components/Button.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import React from 'react';
|
||||
|
||||
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
variant?: 'primary' | 'secondary' | 'outline' | 'text';
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
}
|
||||
|
||||
export const Button: React.FC<ButtonProps> = ({
|
||||
variant = 'primary',
|
||||
size = 'md',
|
||||
className = '',
|
||||
children,
|
||||
...props
|
||||
}) => {
|
||||
const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-98 disabled:opacity-50 disabled:pointer-events-none";
|
||||
|
||||
const variants = {
|
||||
primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
|
||||
secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
|
||||
outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
|
||||
text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
|
||||
};
|
||||
|
||||
const sizes = {
|
||||
sm: "h-8 px-3 text-xs",
|
||||
md: "h-10 px-4 text-sm",
|
||||
lg: "h-12 px-6 text-base",
|
||||
};
|
||||
|
||||
return (
|
||||
<button
|
||||
className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
};
|
||||
300
frontend/src/components/Dashboard.tsx
Normal file
300
frontend/src/components/Dashboard.tsx
Normal file
@@ -0,0 +1,300 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
|
||||
import { Badge } from './Badge'
|
||||
import { Button } from './Button'
|
||||
import { UploadModal } from './UploadModal'
|
||||
import { useDocuments, useCategories } from '../hooks/useDocuments'
|
||||
import type { DocumentItem } from '../api/types'
|
||||
|
||||
interface DashboardProps {
|
||||
onNavigate: (view: string, docId?: string) => void
|
||||
}
|
||||
|
||||
const getStatusForBadge = (status: string): string => {
|
||||
const statusMap: Record<string, string> = {
|
||||
pending: 'Pending',
|
||||
labeled: 'Labeled',
|
||||
verified: 'Verified',
|
||||
exported: 'Exported',
|
||||
}
|
||||
return statusMap[status] || status
|
||||
}
|
||||
|
||||
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
|
||||
if (doc.auto_label_status === 'running') {
|
||||
return 45
|
||||
}
|
||||
if (doc.auto_label_status === 'completed') {
|
||||
return 100
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
|
||||
const [isUploadOpen, setIsUploadOpen] = useState(false)
|
||||
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
|
||||
const [statusFilter, setStatusFilter] = useState<string>('')
|
||||
const [categoryFilter, setCategoryFilter] = useState<string>('')
|
||||
const [limit] = useState(20)
|
||||
const [offset] = useState(0)
|
||||
|
||||
const { categories } = useCategories()
|
||||
|
||||
const { documents, total, isLoading, error, refetch } = useDocuments({
|
||||
status: statusFilter || undefined,
|
||||
category: categoryFilter || undefined,
|
||||
limit,
|
||||
offset,
|
||||
})
|
||||
|
||||
const toggleSelection = (id: string) => {
|
||||
const newSet = new Set(selectedDocs)
|
||||
if (newSet.has(id)) {
|
||||
newSet.delete(id)
|
||||
} else {
|
||||
newSet.add(id)
|
||||
}
|
||||
setSelectedDocs(newSet)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
|
||||
Error loading documents. Please check your connection to the backend API.
|
||||
<button
|
||||
onClick={() => refetch()}
|
||||
className="ml-4 underline hover:no-underline"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
<div className="flex items-center justify-between mb-8">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Documents
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
{isLoading ? 'Loading...' : `${total} documents total`}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" disabled={selectedDocs.size === 0}>
|
||||
Export Selection ({selectedDocs.size})
|
||||
</Button>
|
||||
<Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
|
||||
<div className="relative flex-1 min-w-[200px]">
|
||||
<Search
|
||||
className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
|
||||
size={16}
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search documents..."
|
||||
className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-3">
|
||||
<div className="relative">
|
||||
<select
|
||||
value={categoryFilter}
|
||||
onChange={(e) => setCategoryFilter(e.target.value)}
|
||||
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
|
||||
>
|
||||
<option value="">All Categories</option>
|
||||
{categories.map((cat) => (
|
||||
<option key={cat} value={cat}>
|
||||
{cat.charAt(0).toUpperCase() + cat.slice(1)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
<div className="relative">
|
||||
<select
|
||||
value={statusFilter}
|
||||
onChange={(e) => setStatusFilter(e.target.value)}
|
||||
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
|
||||
>
|
||||
<option value="">All Statuses</option>
|
||||
<option value="pending">Pending</option>
|
||||
<option value="labeled">Labeled</option>
|
||||
<option value="verified">Verified</option>
|
||||
<option value="exported">Exported</option>
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<table className="w-full text-left border-collapse">
|
||||
<thead>
|
||||
<tr className="border-b border-warm-border bg-white">
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
|
||||
/>
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Document Name
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Date
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Status
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Annotations
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Category
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
|
||||
Group
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
|
||||
Auto-label
|
||||
</th>
|
||||
<th className="py-3 px-4 w-12"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{isLoading ? (
|
||||
<tr>
|
||||
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
|
||||
Loading documents...
|
||||
</td>
|
||||
</tr>
|
||||
) : documents.length === 0 ? (
|
||||
<tr>
|
||||
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
|
||||
No documents found. Upload your first document to get started.
|
||||
</td>
|
||||
</tr>
|
||||
) : (
|
||||
documents.map((doc) => {
|
||||
const isSelected = selectedDocs.has(doc.document_id)
|
||||
const progress = getAutoLabelProgress(doc)
|
||||
|
||||
return (
|
||||
<tr
|
||||
key={doc.document_id}
|
||||
onClick={() => onNavigate('detail', doc.document_id)}
|
||||
className={`
|
||||
group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
|
||||
${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
|
||||
`}
|
||||
>
|
||||
<td
|
||||
className="py-4 pl-6 pr-4 relative"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
toggleSelection(doc.document_id)
|
||||
}}
|
||||
>
|
||||
{isSelected && (
|
||||
<div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
|
||||
)}
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={isSelected}
|
||||
readOnly
|
||||
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
|
||||
/>
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
|
||||
<FileText size={16} />
|
||||
</div>
|
||||
<span className="font-medium text-warm-text-secondary">
|
||||
{doc.filename}
|
||||
</span>
|
||||
</div>
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
|
||||
{new Date(doc.created_at).toLocaleDateString()}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
<Badge status={getStatusForBadge(doc.status)} />
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary">
|
||||
{doc.annotation_count || 0} annotations
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-secondary capitalize">
|
||||
{doc.category || 'invoice'}
|
||||
</td>
|
||||
<td className="py-4 px-4 text-sm text-warm-text-muted">
|
||||
{doc.group_key || '-'}
|
||||
</td>
|
||||
<td className="py-4 px-4">
|
||||
{doc.auto_label_status === 'running' && progress && (
|
||||
<div className="w-full">
|
||||
<div className="flex justify-between text-xs mb-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
Running
|
||||
</span>
|
||||
<span className="text-warm-text-muted">{progress}%</span>
|
||||
</div>
|
||||
<div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-warm-state-info transition-all duration-500 ease-out"
|
||||
style={{ width: `${progress}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{doc.auto_label_status === 'completed' && (
|
||||
<span className="text-sm font-medium text-warm-state-success">
|
||||
Completed
|
||||
</span>
|
||||
)}
|
||||
{doc.auto_label_status === 'failed' && (
|
||||
<span className="text-sm font-medium text-warm-state-error">
|
||||
Failed
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="py-4 px-4 text-right">
|
||||
<button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
|
||||
<MoreHorizontal size={18} />
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<UploadModal
|
||||
isOpen={isUploadOpen}
|
||||
onClose={() => {
|
||||
setIsUploadOpen(false)
|
||||
refetch()
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
148
frontend/src/components/DashboardOverview.tsx
Normal file
148
frontend/src/components/DashboardOverview.tsx
Normal file
@@ -0,0 +1,148 @@
|
||||
import React from 'react'
|
||||
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { useTraining } from '../hooks/useTraining'
|
||||
|
||||
interface DashboardOverviewProps {
|
||||
onNavigate: (view: string) => void
|
||||
}
|
||||
|
||||
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
|
||||
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
|
||||
const { models, isLoadingModels } = useTraining()
|
||||
|
||||
const stats = [
|
||||
{
|
||||
label: 'Total Documents',
|
||||
value: docsLoading ? '...' : totalDocs.toString(),
|
||||
icon: FileText,
|
||||
color: 'text-warm-text-primary',
|
||||
bgColor: 'bg-warm-bg',
|
||||
},
|
||||
{
|
||||
label: 'Labeled',
|
||||
value: '0',
|
||||
icon: CheckCircle,
|
||||
color: 'text-warm-state-success',
|
||||
bgColor: 'bg-green-50',
|
||||
},
|
||||
{
|
||||
label: 'Pending',
|
||||
value: '0',
|
||||
icon: Clock,
|
||||
color: 'text-warm-state-warning',
|
||||
bgColor: 'bg-yellow-50',
|
||||
},
|
||||
{
|
||||
label: 'Training Models',
|
||||
value: isLoadingModels ? '...' : models.length.toString(),
|
||||
icon: TrendingUp,
|
||||
color: 'text-warm-state-info',
|
||||
bgColor: 'bg-blue-50',
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
|
||||
{/* Header */}
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
|
||||
Dashboard
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mt-1">
|
||||
Overview of your document annotation system
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Stats Grid */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
|
||||
{stats.map((stat) => (
|
||||
<div
|
||||
key={stat.label}
|
||||
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
|
||||
>
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
|
||||
<stat.icon className={stat.color} size={24} />
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-2xl font-bold text-warm-text-primary mb-1">
|
||||
{stat.value}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">{stat.label}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Quick Actions */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
Quick Actions
|
||||
</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<Button onClick={() => onNavigate('documents')} className="justify-start">
|
||||
<FileText size={18} className="mr-2" />
|
||||
Manage Documents
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
|
||||
<Activity size={18} className="mr-2" />
|
||||
Start Training
|
||||
</Button>
|
||||
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
|
||||
<TrendingUp size={18} className="mr-2" />
|
||||
View Models
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Recent Activity */}
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
|
||||
<div className="p-6 border-b border-warm-border">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary">
|
||||
Recent Activity
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6">
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Activity size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No recent activity</p>
|
||||
<p className="text-xs mt-1">
|
||||
Start by uploading documents or creating training jobs
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* System Status */}
|
||||
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
|
||||
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
|
||||
System Status
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Backend API</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Online
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">Database</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Connected
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-warm-text-secondary">GPU</span>
|
||||
<span className="flex items-center text-sm text-warm-state-success">
|
||||
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
|
||||
Available
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
176
frontend/src/components/DatasetDetail.tsx
Normal file
176
frontend/src/components/DatasetDetail.tsx
Normal file
@@ -0,0 +1,176 @@
|
||||
import React from 'react'
|
||||
import { ArrowLeft, Loader2, Play, AlertCircle, Check, Award } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDatasetDetail } from '../hooks/useDatasets'
|
||||
|
||||
interface DatasetDetailProps {
|
||||
datasetId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
const SPLIT_STYLES: Record<string, string> = {
|
||||
train: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
val: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
test: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
|
||||
building: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Building' },
|
||||
ready: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Ready' },
|
||||
trained: { bg: 'bg-purple-100', text: 'text-purple-700', label: 'Trained' },
|
||||
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
|
||||
archived: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Archived' },
|
||||
}
|
||||
|
||||
const TRAINING_STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
|
||||
pending: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Pending' },
|
||||
scheduled: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Scheduled' },
|
||||
running: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Training' },
|
||||
completed: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Completed' },
|
||||
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
|
||||
cancelled: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Cancelled' },
|
||||
}
|
||||
|
||||
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
|
||||
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-20 text-warm-text-muted">
|
||||
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error || !dataset) {
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back
|
||||
</button>
|
||||
<p className="text-warm-state-error">Failed to load dataset.</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const statusConfig = STATUS_STYLES[dataset.status] || STATUS_STYLES.ready
|
||||
const trainingStatusConfig = dataset.training_status
|
||||
? TRAINING_STATUS_STYLES[dataset.training_status]
|
||||
: null
|
||||
|
||||
// Determine if training button should be shown and enabled
|
||||
const isTrainingInProgress = dataset.training_status === 'running' || dataset.training_status === 'pending'
|
||||
const canStartTraining = dataset.status === 'ready' && !isTrainingInProgress
|
||||
|
||||
// Determine status icon
|
||||
const statusIcon = dataset.status === 'trained'
|
||||
? <Award size={14} className="text-purple-700" />
|
||||
: dataset.status === 'ready'
|
||||
? <Check size={14} className="text-warm-state-success" />
|
||||
: dataset.status === 'failed'
|
||||
? <AlertCircle size={14} className="text-warm-state-error" />
|
||||
: dataset.status === 'building'
|
||||
? <Loader2 size={14} className="animate-spin text-warm-state-info" />
|
||||
: null
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
{/* Header */}
|
||||
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
|
||||
<ArrowLeft size={16} />Back to Datasets
|
||||
</button>
|
||||
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<div className="flex items-center gap-3 mb-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
|
||||
{dataset.name} {statusIcon}
|
||||
</h2>
|
||||
{/* Status Badge */}
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${statusConfig.bg} ${statusConfig.text}`}>
|
||||
{statusConfig.label}
|
||||
</span>
|
||||
{/* Training Status Badge */}
|
||||
{trainingStatusConfig && (
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${trainingStatusConfig.bg} ${trainingStatusConfig.text}`}>
|
||||
{isTrainingInProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
|
||||
{trainingStatusConfig.label}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{dataset.description && (
|
||||
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
|
||||
)}
|
||||
</div>
|
||||
{/* Training Button */}
|
||||
{(dataset.status === 'ready' || dataset.status === 'trained') && (
|
||||
<Button
|
||||
disabled={isTrainingInProgress}
|
||||
className={isTrainingInProgress ? 'opacity-50 cursor-not-allowed' : ''}
|
||||
>
|
||||
{isTrainingInProgress ? (
|
||||
<><Loader2 size={14} className="mr-1 animate-spin" />Training...</>
|
||||
) : (
|
||||
<><Play size={14} className="mr-1" />Start Training</>
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{dataset.error_message && (
|
||||
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
|
||||
{dataset.error_message}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Stats */}
|
||||
<div className="grid grid-cols-4 gap-4 mb-8">
|
||||
{[
|
||||
['Documents', dataset.total_documents],
|
||||
['Images', dataset.total_images],
|
||||
['Annotations', dataset.total_annotations],
|
||||
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
|
||||
].map(([label, value]) => (
|
||||
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
|
||||
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
|
||||
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Document list */}
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{dataset.documents.map(doc => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4">
|
||||
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{doc.split}
|
||||
</span>
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<p className="text-xs text-warm-text-muted mt-4">
|
||||
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
|
||||
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
567
frontend/src/components/DocumentDetail.tsx
Normal file
567
frontend/src/components/DocumentDetail.tsx
Normal file
@@ -0,0 +1,567 @@
|
||||
import React, { useState, useRef, useEffect } from 'react'
|
||||
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocumentDetail } from '../hooks/useDocumentDetail'
|
||||
import { useAnnotations } from '../hooks/useAnnotations'
|
||||
import { useDocuments } from '../hooks/useDocuments'
|
||||
import { documentsApi } from '../api/endpoints/documents'
|
||||
import type { AnnotationItem } from '../api/types'
|
||||
|
||||
interface DocumentDetailProps {
|
||||
docId: string
|
||||
onBack: () => void
|
||||
}
|
||||
|
||||
// Field class mapping from backend
|
||||
const FIELD_CLASSES: Record<number, string> = {
|
||||
0: 'invoice_number',
|
||||
1: 'invoice_date',
|
||||
2: 'invoice_due_date',
|
||||
3: 'ocr_number',
|
||||
4: 'bankgiro',
|
||||
5: 'plusgiro',
|
||||
6: 'amount',
|
||||
7: 'supplier_organisation_number',
|
||||
8: 'payment_line',
|
||||
9: 'customer_number',
|
||||
}
|
||||
|
||||
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
|
||||
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
|
||||
const {
|
||||
createAnnotation,
|
||||
updateAnnotation,
|
||||
deleteAnnotation,
|
||||
isCreating,
|
||||
isDeleting,
|
||||
} = useAnnotations(docId)
|
||||
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
|
||||
|
||||
const [selectedId, setSelectedId] = useState<string | null>(null)
|
||||
const [zoom, setZoom] = useState(100)
|
||||
const [isDrawing, setIsDrawing] = useState(false)
|
||||
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
|
||||
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
|
||||
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
|
||||
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
|
||||
const [selectedClassId, setSelectedClassId] = useState<number>(0)
|
||||
const [currentPage, setCurrentPage] = useState(1)
|
||||
const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
|
||||
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)
|
||||
|
||||
const canvasRef = useRef<HTMLDivElement>(null)
|
||||
const imageRef = useRef<HTMLImageElement>(null)
|
||||
|
||||
const [isMarkingComplete, setIsMarkingComplete] = useState(false)
|
||||
|
||||
const selectedAnnotation = annotations?.find((a) => a.annotation_id === selectedId)
|
||||
|
||||
// Handle mark as complete
|
||||
const handleMarkComplete = async () => {
|
||||
if (!annotations || annotations.length === 0) {
|
||||
alert('Please add at least one annotation before marking as complete.')
|
||||
return
|
||||
}
|
||||
|
||||
if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
|
||||
return
|
||||
}
|
||||
|
||||
setIsMarkingComplete(true)
|
||||
try {
|
||||
const result = await documentsApi.updateStatus(docId, 'labeled')
|
||||
alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
|
||||
onBack() // Return to document list
|
||||
} catch (error) {
|
||||
console.error('Failed to mark document as complete:', error)
|
||||
alert('Failed to mark document as complete. Please try again.')
|
||||
} finally {
|
||||
setIsMarkingComplete(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Load image via fetch with authentication header
|
||||
useEffect(() => {
|
||||
let objectUrl: string | null = null
|
||||
|
||||
const loadImage = async () => {
|
||||
if (!docId) return
|
||||
|
||||
const token = localStorage.getItem('admin_token')
|
||||
const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`
|
||||
|
||||
try {
|
||||
const response = await fetch(imageUrl, {
|
||||
headers: {
|
||||
'X-Admin-Token': token || '',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load image: ${response.status}`)
|
||||
}
|
||||
|
||||
const blob = await response.blob()
|
||||
objectUrl = URL.createObjectURL(blob)
|
||||
setImageBlobUrl(objectUrl)
|
||||
} catch (error) {
|
||||
console.error('Failed to load image:', error)
|
||||
}
|
||||
}
|
||||
|
||||
loadImage()
|
||||
|
||||
// Cleanup: revoke object URL when component unmounts or page changes
|
||||
return () => {
|
||||
if (objectUrl) {
|
||||
URL.revokeObjectURL(objectUrl)
|
||||
}
|
||||
}
|
||||
}, [currentPage, docId])
|
||||
|
||||
// Load image size
|
||||
useEffect(() => {
|
||||
if (imageRef.current && imageRef.current.complete) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}, [imageBlobUrl])
|
||||
|
||||
const handleImageLoad = () => {
|
||||
if (imageRef.current) {
|
||||
setImageSize({
|
||||
width: imageRef.current.naturalWidth,
|
||||
height: imageRef.current.naturalHeight,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setIsDrawing(true)
|
||||
setDrawStart({ x, y })
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (!isDrawing || !canvasRef.current || !imageSize) return
|
||||
const rect = canvasRef.current.getBoundingClientRect()
|
||||
const x = (e.clientX - rect.left) / (zoom / 100)
|
||||
const y = (e.clientY - rect.top) / (zoom / 100)
|
||||
setDrawEnd({ x, y })
|
||||
}
|
||||
|
||||
const handleMouseUp = () => {
|
||||
if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
|
||||
setIsDrawing(false)
|
||||
return
|
||||
}
|
||||
|
||||
const bbox_x = Math.min(drawStart.x, drawEnd.x)
|
||||
const bbox_y = Math.min(drawStart.y, drawEnd.y)
|
||||
const bbox_width = Math.abs(drawEnd.x - drawStart.x)
|
||||
const bbox_height = Math.abs(drawEnd.y - drawStart.y)
|
||||
|
||||
// Only create if box is large enough (min 10x10 pixels)
|
||||
if (bbox_width > 10 && bbox_height > 10) {
|
||||
createAnnotation({
|
||||
page_number: currentPage,
|
||||
class_id: selectedClassId,
|
||||
bbox: {
|
||||
x: Math.round(bbox_x),
|
||||
y: Math.round(bbox_y),
|
||||
width: Math.round(bbox_width),
|
||||
height: Math.round(bbox_height),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
setIsDrawing(false)
|
||||
setDrawStart(null)
|
||||
setDrawEnd(null)
|
||||
}
|
||||
|
||||
const handleDeleteAnnotation = (annotationId: string) => {
|
||||
if (confirm('Are you sure you want to delete this annotation?')) {
|
||||
deleteAnnotation(annotationId)
|
||||
setSelectedId(null)
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading || !document) {
|
||||
return (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<div className="text-warm-text-muted">Loading...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Get current page annotations
|
||||
const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []
|
||||
|
||||
return (
|
||||
<div className="flex h-[calc(100vh-56px)] overflow-hidden">
|
||||
{/* Main Canvas Area */}
|
||||
<div className="flex-1 bg-warm-bg flex flex-col relative">
|
||||
{/* Toolbar */}
|
||||
<div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
|
||||
<div className="flex items-center gap-4">
|
||||
<button
|
||||
onClick={onBack}
|
||||
className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
|
||||
>
|
||||
<ChevronLeft size={20} />
|
||||
</button>
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Page {currentPage} of {document.page_count}
|
||||
</p>
|
||||
</div>
|
||||
<div className="h-6 w-px bg-warm-divider mx-2" />
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.max(50, z - 10))}
|
||||
>
|
||||
<ZoomOut size={16} />
|
||||
</button>
|
||||
<span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
|
||||
{zoom}%
|
||||
</span>
|
||||
<button
|
||||
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
|
||||
onClick={() => setZoom((z) => Math.min(200, z + 10))}
|
||||
>
|
||||
<ZoomIn size={16} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="secondary" size="sm">
|
||||
Auto-label
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="sm"
|
||||
onClick={handleMarkComplete}
|
||||
disabled={isMarkingComplete || document.status === 'labeled'}
|
||||
>
|
||||
<CheckCircle size={16} className="mr-1" />
|
||||
{isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
|
||||
</Button>
|
||||
{document.page_count > 1 && (
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
|
||||
disabled={currentPage === 1}
|
||||
>
|
||||
Prev
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
|
||||
disabled={currentPage === document.page_count}
|
||||
>
|
||||
Next
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Canvas Scroll Area */}
|
||||
<div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
|
||||
<div
|
||||
ref={canvasRef}
|
||||
className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
|
||||
style={{
|
||||
width: imageSize?.width || 800,
|
||||
height: imageSize?.height || 1132,
|
||||
transform: `scale(${zoom / 100})`,
|
||||
marginBottom: '100px',
|
||||
cursor: isDrawing ? 'crosshair' : 'default',
|
||||
}}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onClick={() => setSelectedId(null)}
|
||||
>
|
||||
{/* Document Image */}
|
||||
{imageBlobUrl ? (
|
||||
<img
|
||||
ref={imageRef}
|
||||
src={imageBlobUrl}
|
||||
alt={`Page ${currentPage}`}
|
||||
className="w-full h-full object-contain select-none pointer-events-none"
|
||||
onLoad={handleImageLoad}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-full">
|
||||
<div className="text-warm-text-muted">Loading image...</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Annotation Overlays */}
|
||||
{pageAnnotations.map((ann) => {
|
||||
const isSelected = selectedId === ann.annotation_id
|
||||
return (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setSelectedId(ann.annotation_id)
|
||||
}}
|
||||
className={`
|
||||
absolute group cursor-pointer transition-all duration-100
|
||||
${
|
||||
ann.source === 'auto'
|
||||
? 'border border-dashed border-warm-text-muted bg-transparent'
|
||||
: 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
|
||||
}
|
||||
${
|
||||
isSelected
|
||||
? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
|
||||
: 'hover:bg-warm-state-info/5 z-10'
|
||||
}
|
||||
`}
|
||||
style={{
|
||||
left: ann.bbox.x,
|
||||
top: ann.bbox.y,
|
||||
width: ann.bbox.width,
|
||||
height: ann.bbox.height,
|
||||
}}
|
||||
>
|
||||
{/* Label Tag */}
|
||||
<div
|
||||
className={`
|
||||
absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
|
||||
${
|
||||
isSelected
|
||||
? 'bg-warm-state-info text-white'
|
||||
: 'bg-white text-warm-text-secondary border border-warm-border'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{ann.class_name}
|
||||
</div>
|
||||
|
||||
{/* Resize Handles (Visual only) */}
|
||||
{isSelected && (
|
||||
<>
|
||||
<div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
<div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
{/* Drawing Box Preview */}
|
||||
{isDrawing && drawStart && drawEnd && (
|
||||
<div
|
||||
className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
|
||||
style={{
|
||||
left: Math.min(drawStart.x, drawEnd.x),
|
||||
top: Math.min(drawStart.y, drawEnd.y),
|
||||
width: Math.abs(drawEnd.x - drawStart.x),
|
||||
height: Math.abs(drawEnd.y - drawStart.y),
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right Sidebar */}
|
||||
<div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
|
||||
{/* Field Selector */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
|
||||
<div className="space-y-2">
|
||||
<label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
|
||||
<select
|
||||
value={selectedClassId}
|
||||
onChange={(e) => setSelectedClassId(Number(e.target.value))}
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
{Object.entries(FIELD_CLASSES).map(([id, name]) => (
|
||||
<option key={id} value={id}>
|
||||
{name.replace(/_/g, ' ')}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-2">
|
||||
Click and drag on the document to create a bounding box
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Document Info Card */}
|
||||
<div className="p-4 border-b border-warm-border">
|
||||
<div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Status</span>
|
||||
<span className="text-warm-text-secondary font-medium capitalize">
|
||||
{document.status}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Size</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{(document.file_size / 1024 / 1024).toFixed(2)} MB
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs">
|
||||
<span className="text-warm-text-muted">Uploaded</span>
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{new Date(document.created_at).toLocaleDateString()}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-warm-text-muted">Group</span>
|
||||
{isEditingGroupKey ? (
|
||||
<div className="flex items-center gap-1">
|
||||
<input
|
||||
type="text"
|
||||
value={editGroupKeyValue}
|
||||
onChange={(e) => setEditGroupKeyValue(e.target.value)}
|
||||
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
placeholder="group key"
|
||||
autoFocus
|
||||
/>
|
||||
<button
|
||||
onClick={() => {
|
||||
updateGroupKey(
|
||||
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
|
||||
{
|
||||
onSuccess: () => {
|
||||
setIsEditingGroupKey(false)
|
||||
refetch()
|
||||
},
|
||||
onError: () => {
|
||||
alert('Failed to update group key. Please try again.')
|
||||
},
|
||||
}
|
||||
)
|
||||
}}
|
||||
disabled={isUpdatingGroupKey}
|
||||
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Check size={14} />
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setIsEditingGroupKey(false)
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
}}
|
||||
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
|
||||
>
|
||||
<X size={14} />
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
{document.group_key || '-'}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => {
|
||||
setEditGroupKeyValue(document.group_key || '')
|
||||
setIsEditingGroupKey(true)
|
||||
}}
|
||||
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
|
||||
>
|
||||
<Edit2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Annotations List */}
|
||||
<div className="flex-1 overflow-y-auto p-4">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
|
||||
<span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
|
||||
</div>
|
||||
|
||||
{pageAnnotations.length === 0 ? (
|
||||
<div className="text-center py-8 text-warm-text-muted">
|
||||
<Tag size={48} className="mx-auto mb-3 opacity-20" />
|
||||
<p className="text-sm">No annotations yet</p>
|
||||
<p className="text-xs mt-1">Draw on the document to add annotations</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
{pageAnnotations.map((ann) => (
|
||||
<div
|
||||
key={ann.annotation_id}
|
||||
onClick={() => setSelectedId(ann.annotation_id)}
|
||||
className={`
|
||||
group p-3 rounded-md border transition-all duration-150 cursor-pointer
|
||||
${
|
||||
selectedId === ann.annotation_id
|
||||
? 'bg-warm-bg border-warm-state-info shadow-sm'
|
||||
: 'bg-white border-warm-border hover:border-warm-text-muted'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-1">
|
||||
<span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
|
||||
{ann.class_name.replace(/_/g, ' ')}
|
||||
</span>
|
||||
{selectedId === ann.annotation_id && (
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={() => handleDeleteAnnotation(ann.annotation_id)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error"
|
||||
disabled={isDeleting}
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<p className="text-sm text-warm-text-muted font-mono truncate">
|
||||
{ann.text_value || '(no text)'}
|
||||
</p>
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<span
|
||||
className={`text-[10px] px-1.5 py-0.5 rounded ${
|
||||
ann.source === 'auto'
|
||||
? 'bg-blue-50 text-blue-700'
|
||||
: 'bg-green-50 text-green-700'
|
||||
}`}
|
||||
>
|
||||
{ann.source}
|
||||
</span>
|
||||
{ann.confidence && (
|
||||
<span className="text-[10px] text-warm-text-muted">
|
||||
{(ann.confidence * 100).toFixed(0)}%
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
466
frontend/src/components/InferenceDemo.tsx
Normal file
466
frontend/src/components/InferenceDemo.tsx
Normal file
@@ -0,0 +1,466 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { inferenceApi } from '../api/endpoints'
|
||||
import type { InferenceResult } from '../api/types'
|
||||
|
||||
export const InferenceDemo: React.FC = () => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null)
|
||||
const [isProcessing, setIsProcessing] = useState(false)
|
||||
const [result, setResult] = useState<InferenceResult | null>(null)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const handleFileSelect = (file: File | null) => {
|
||||
if (!file) return
|
||||
|
||||
const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
|
||||
if (!validTypes.includes(file.type)) {
|
||||
setError('Please upload a PDF, PNG, or JPG file')
|
||||
return
|
||||
}
|
||||
|
||||
if (file.size > 50 * 1024 * 1024) {
|
||||
setError('File size must be less than 50MB')
|
||||
return
|
||||
}
|
||||
|
||||
setSelectedFile(file)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
if (e.dataTransfer.files.length > 0) {
|
||||
handleFileSelect(e.dataTransfer.files[0])
|
||||
}
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const handleProcess = async () => {
|
||||
if (!selectedFile) return
|
||||
|
||||
setIsProcessing(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const response = await inferenceApi.processDocument(selectedFile)
|
||||
console.log('API Response:', response)
|
||||
console.log('Visualization URL:', response.result?.visualization_url)
|
||||
setResult(response.result)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Processing failed')
|
||||
} finally {
|
||||
setIsProcessing(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleReset = () => {
|
||||
setSelectedFile(null)
|
||||
setResult(null)
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const formatFieldName = (field: string): string => {
|
||||
const fieldNames: Record<string, string> = {
|
||||
InvoiceNumber: 'Invoice Number',
|
||||
InvoiceDate: 'Invoice Date',
|
||||
InvoiceDueDate: 'Due Date',
|
||||
OCR: 'OCR Number',
|
||||
Amount: 'Amount',
|
||||
Bankgiro: 'Bankgiro',
|
||||
Plusgiro: 'Plusgiro',
|
||||
supplier_org_number: 'Supplier Org Number',
|
||||
customer_number: 'Customer Number',
|
||||
payment_line: 'Payment Line',
|
||||
}
|
||||
return fieldNames[field] || field
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
|
||||
{/* Header */}
|
||||
<div className="text-center">
|
||||
<h2 className="text-3xl font-bold text-warm-text-primary mb-2">
|
||||
Invoice Extraction Demo
|
||||
</h2>
|
||||
<p className="text-warm-text-muted">
|
||||
Upload a Swedish invoice to see our AI-powered field extraction in action
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Upload Area */}
|
||||
{!result && (
|
||||
<div className="max-w-2xl mx-auto">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
|
||||
<div
|
||||
className={`
|
||||
relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
|
||||
${isDragging
|
||||
? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
|
||||
: 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
|
||||
}
|
||||
${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
|
||||
`}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(true)
|
||||
}}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
|
||||
{isProcessing ? (
|
||||
<>
|
||||
<Loader2 size={56} className="text-warm-text-secondary animate-spin" />
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
Processing invoice...
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
This may take a few moments
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : selectedFile ? (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<FileText size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-1">
|
||||
{selectedFile.name}
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted">
|
||||
{(selectedFile.size / 1024 / 1024).toFixed(2)} MB
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
|
||||
<UploadCloud size={40} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center px-4">
|
||||
<p className="text-lg font-semibold text-warm-text-primary mb-2">
|
||||
Drag & drop invoice here
|
||||
</p>
|
||||
<p className="text-sm text-warm-text-muted mb-3">
|
||||
or{' '}
|
||||
<span className="text-warm-text-secondary font-medium">
|
||||
browse files
|
||||
</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Supports PDF, PNG, JPG (up to 50MB)
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
|
||||
/>
|
||||
|
||||
{error && (
|
||||
<div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
|
||||
<AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-sm text-red-800 font-medium">{error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedFile && !isProcessing && (
|
||||
<div className="mt-6 flex gap-3 justify-end">
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleProcess}>Process Invoice</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Results */}
|
||||
{result && (
|
||||
<div className="space-y-6">
|
||||
{/* Status Header */}
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
|
||||
<div className="p-6 flex items-center justify-between border-b border-warm-divider">
|
||||
<div className="flex items-center gap-4">
|
||||
{result.success ? (
|
||||
<div className="p-3 bg-green-100 rounded-xl">
|
||||
<CheckCircle2 size={28} className="text-green-600" />
|
||||
</div>
|
||||
) : (
|
||||
<div className="p-3 bg-yellow-100 rounded-xl">
|
||||
<AlertCircle size={28} className="text-yellow-600" />
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">
|
||||
{result.success ? 'Extraction Complete' : 'Partial Results'}
|
||||
</h3>
|
||||
<p className="text-sm text-warm-text-muted mt-0.5">
|
||||
Document ID: <span className="font-mono">{result.document_id}</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button variant="secondary" onClick={handleReset}>
|
||||
Process Another
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
|
||||
<div className="flex items-center gap-2 text-warm-text-secondary">
|
||||
<Clock size={16} />
|
||||
<span className="font-medium">
|
||||
{result.processing_time_ms.toFixed(0)}ms
|
||||
</span>
|
||||
</div>
|
||||
{result.fallback_used && (
|
||||
<span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
|
||||
Fallback OCR Used
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main Content Grid */}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
{/* Left Column: Extracted Fields */}
|
||||
<div className="lg:col-span-2 space-y-6">
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Extracted Fields
|
||||
</h3>
|
||||
<div className="flex flex-wrap gap-4">
|
||||
{Object.entries(result.fields).map(([field, value]) => {
|
||||
const confidence = result.confidence[field]
|
||||
return (
|
||||
<div
|
||||
key={field}
|
||||
className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
|
||||
>
|
||||
<div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
|
||||
{formatFieldName(field)}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
|
||||
{value || <span className="text-warm-text-muted italic">N/A</span>}
|
||||
</div>
|
||||
{confidence && (
|
||||
<div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
|
||||
<CheckCircle2 size={13} />
|
||||
<span>{(confidence * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Visualization */}
|
||||
{result.visualization_url && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Detection Visualization
|
||||
</h3>
|
||||
<div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
|
||||
<img
|
||||
src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
|
||||
alt="Detection visualization"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right Column: Cross-Validation & Errors */}
|
||||
<div className="space-y-6">
|
||||
{/* Cross-Validation */}
|
||||
{result.cross_validation && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
|
||||
Payment Line Validation
|
||||
</h3>
|
||||
|
||||
<div
|
||||
className={`
|
||||
p-4 rounded-lg mb-4 flex items-center gap-3
|
||||
${result.cross_validation.is_valid
|
||||
? 'bg-green-50 border border-green-200'
|
||||
: 'bg-yellow-50 border border-yellow-200'
|
||||
}
|
||||
`}
|
||||
>
|
||||
{result.cross_validation.is_valid ? (
|
||||
<>
|
||||
<CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
|
||||
<span className="font-bold text-green-800">All Fields Match</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
|
||||
<span className="font-bold text-yellow-800">Mismatch Detected</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2.5">
|
||||
{result.cross_validation.payment_line_ocr && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.ocr_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.ocr_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
OCR NUMBER
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_ocr}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.ocr_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.ocr_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_amount && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${result.cross_validation.amount_match === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: result.cross_validation.amount_match === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
AMOUNT
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_amount}
|
||||
</div>
|
||||
</div>
|
||||
{result.cross_validation.amount_match === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{result.cross_validation.amount_match === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.cross_validation.payment_line_account && (
|
||||
<div
|
||||
className={`
|
||||
p-3 rounded-lg border transition-colors
|
||||
${(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true
|
||||
? 'bg-green-50 border-green-200'
|
||||
: (result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false
|
||||
? 'bg-red-50 border-red-200'
|
||||
: 'bg-warm-bg border-warm-divider'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1">
|
||||
<div className="text-xs font-semibold text-warm-text-muted mb-1">
|
||||
{result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? 'BANKGIRO'
|
||||
: 'PLUSGIRO'}
|
||||
</div>
|
||||
<div className="text-sm font-bold text-warm-text-primary font-mono">
|
||||
{result.cross_validation.payment_line_account}
|
||||
</div>
|
||||
</div>
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === true && (
|
||||
<CheckCircle2 size={16} className="text-green-600" />
|
||||
)}
|
||||
{(result.cross_validation.payment_line_account_type === 'bankgiro'
|
||||
? result.cross_validation.bankgiro_match
|
||||
: result.cross_validation.plusgiro_match) === false && (
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{result.cross_validation.details.length > 0 && (
|
||||
<div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
|
||||
{result.cross_validation.details[result.cross_validation.details.length - 1]}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Errors */}
|
||||
{result.errors.length > 0 && (
|
||||
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
|
||||
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
|
||||
<span className="w-1 h-5 bg-red-500 rounded-full"></span>
|
||||
Issues
|
||||
</h3>
|
||||
<div className="space-y-2.5">
|
||||
{result.errors.map((err, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
|
||||
>
|
||||
<AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
102
frontend/src/components/Layout.tsx
Normal file
102
frontend/src/components/Layout.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';
|
||||
|
||||
interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
activeView: string;
|
||||
onNavigate: (view: string) => void;
|
||||
onLogout?: () => void;
|
||||
}
|
||||
|
||||
export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
|
||||
const [showDropdown, setShowDropdown] = useState(false);
|
||||
const navItems = [
|
||||
{ id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
|
||||
{ id: 'demo', label: 'Demo', icon: Sparkles },
|
||||
{ id: 'training', label: 'Training', icon: Box }, // Mapped to Compliants visually in prompt, using logical name
|
||||
{ id: 'documents', label: 'Documents', icon: BookOpen },
|
||||
{ id: 'models', label: 'Models', icon: Users }, // Contacts in prompt, mapped to models for this use case
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
|
||||
{/* Top Navigation */}
|
||||
<nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
|
||||
<div className="flex items-center gap-8">
|
||||
{/* Logo */}
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
|
||||
<path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Nav Links */}
|
||||
<div className="flex h-14">
|
||||
{navItems.map(item => {
|
||||
const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
|
||||
return (
|
||||
<button
|
||||
key={item.id}
|
||||
onClick={() => onNavigate(item.id)}
|
||||
className={`
|
||||
relative px-4 h-full flex items-center text-sm font-medium transition-colors
|
||||
${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
|
||||
`}
|
||||
>
|
||||
{item.label}
|
||||
{isActive && (
|
||||
<div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* User Profile */}
|
||||
<div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
|
||||
<button
|
||||
onClick={() => setShowDropdown(!showDropdown)}
|
||||
className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
|
||||
>
|
||||
AD
|
||||
</button>
|
||||
|
||||
{showDropdown && (
|
||||
<>
|
||||
<div
|
||||
className="fixed inset-0 z-10"
|
||||
onClick={() => setShowDropdown(false)}
|
||||
/>
|
||||
<div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
|
||||
<div className="p-3 border-b border-warm-border">
|
||||
<p className="text-sm font-medium text-warm-text-primary">Admin User</p>
|
||||
<p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
|
||||
</div>
|
||||
{onLogout && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowDropdown(false)
|
||||
onLogout()
|
||||
}}
|
||||
className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
|
||||
>
|
||||
<LogOut size={14} />
|
||||
Sign Out
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex-1 overflow-auto">
|
||||
{children}
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
188
frontend/src/components/Login.tsx
Normal file
188
frontend/src/components/Login.tsx
Normal file
@@ -0,0 +1,188 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Button } from './Button'
|
||||
|
||||
interface LoginProps {
|
||||
onLogin: (token: string) => void
|
||||
}
|
||||
|
||||
export const Login: React.FC<LoginProps> = ({ onLogin }) => {
|
||||
const [token, setToken] = useState('')
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [isCreating, setIsCreating] = useState(false)
|
||||
const [error, setError] = useState('')
|
||||
const [createdToken, setCreatedToken] = useState('')
|
||||
|
||||
const handleLoginWithToken = () => {
|
||||
if (!token.trim()) {
|
||||
setError('Please enter a token')
|
||||
return
|
||||
}
|
||||
localStorage.setItem('admin_token', token.trim())
|
||||
onLogin(token.trim())
|
||||
}
|
||||
|
||||
const handleCreateToken = async () => {
|
||||
if (!name.trim()) {
|
||||
setError('Please enter a token name')
|
||||
return
|
||||
}
|
||||
|
||||
setIsCreating(true)
|
||||
setError('')
|
||||
|
||||
try {
|
||||
const response = await fetch('http://localhost:8000/api/v1/admin/auth/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: name.trim(),
|
||||
description: description.trim() || undefined,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to create token')
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
setCreatedToken(data.token)
|
||||
setToken(data.token)
|
||||
setError('')
|
||||
} catch (err) {
|
||||
setError('Failed to create token. Please check your connection.')
|
||||
console.error(err)
|
||||
} finally {
|
||||
setIsCreating(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleUseCreatedToken = () => {
|
||||
if (createdToken) {
|
||||
localStorage.setItem('admin_token', createdToken)
|
||||
onLogin(createdToken)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
|
||||
<h1 className="text-2xl font-bold text-warm-text-primary mb-2">
|
||||
Admin Authentication
|
||||
</h1>
|
||||
<p className="text-sm text-warm-text-muted mb-6">
|
||||
Sign in with an admin token to access the document management system
|
||||
</p>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{createdToken && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
|
||||
<p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
|
||||
<div className="bg-white border border-green-300 rounded p-2 mb-3">
|
||||
<code className="text-xs font-mono text-warm-text-primary break-all">
|
||||
{createdToken}
|
||||
</code>
|
||||
</div>
|
||||
<p className="text-xs text-green-700 mb-3">
|
||||
Save this token securely. You won't be able to see it again.
|
||||
</p>
|
||||
<Button onClick={handleUseCreatedToken} className="w-full">
|
||||
Use This Token
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-6">
|
||||
{/* Login with existing token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Sign in with existing token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Admin Token
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={token}
|
||||
onChange={(e) => setToken(e.target.value)}
|
||||
placeholder="Enter your admin token"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
|
||||
onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
|
||||
/>
|
||||
</div>
|
||||
<Button onClick={handleLoginWithToken} className="w-full">
|
||||
Sign In
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="relative">
|
||||
<div className="absolute inset-0 flex items-center">
|
||||
<div className="w-full border-t border-warm-border"></div>
|
||||
</div>
|
||||
<div className="relative flex justify-center text-xs">
|
||||
<span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Create new token */}
|
||||
<div>
|
||||
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
|
||||
Create new admin token
|
||||
</h2>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Token Name <span className="text-red-500">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="e.g., my-laptop"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-warm-text-secondary mb-1">
|
||||
Description (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="e.g., Personal laptop access"
|
||||
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleCreateToken}
|
||||
variant="secondary"
|
||||
disabled={isCreating}
|
||||
className="w-full"
|
||||
>
|
||||
{isCreating ? 'Creating...' : 'Create Token'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-6 pt-4 border-t border-warm-border">
|
||||
<p className="text-xs text-warm-text-muted">
|
||||
Admin tokens are used to authenticate with the document management API.
|
||||
Keep your tokens secure and never share them.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
208
frontend/src/components/Models.tsx
Normal file
208
frontend/src/components/Models.tsx
Normal file
@@ -0,0 +1,208 @@
|
||||
import React, { useState } from 'react';
|
||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
|
||||
import { Loader2, Power, CheckCircle } from 'lucide-react';
|
||||
import { Button } from './Button';
|
||||
import { useModels, useModelDetail } from '../hooks';
|
||||
import type { ModelVersionItem } from '../api/types';
|
||||
|
||||
const formatDate = (dateString: string | null): string => {
|
||||
if (!dateString) return 'N/A';
|
||||
return new Date(dateString).toLocaleString();
|
||||
};
|
||||
|
||||
export const Models: React.FC = () => {
|
||||
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
|
||||
const { models, isLoading, activateModel, isActivating } = useModels();
|
||||
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
|
||||
|
||||
// Build chart data from selected model's metrics
|
||||
const metricsData = modelDetail ? [
|
||||
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
|
||||
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
|
||||
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
|
||||
] : [
|
||||
{ name: 'Precision', value: 0 },
|
||||
{ name: 'Recall', value: 0 },
|
||||
{ name: 'mAP', value: 0 },
|
||||
];
|
||||
|
||||
// Build comparison chart from all models (with placeholder if empty)
|
||||
const chartData = models.length > 0
|
||||
? models.slice(0, 4).map(m => ({
|
||||
name: m.version,
|
||||
value: (m.metrics_mAP ?? 0) * 100,
|
||||
}))
|
||||
: [
|
||||
{ name: 'Model A', value: 0 },
|
||||
{ name: 'Model B', value: 0 },
|
||||
{ name: 'Model C', value: 0 },
|
||||
{ name: 'Model D', value: 0 },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto flex gap-8">
|
||||
{/* Left: Job History */}
|
||||
<div className="flex-1">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
|
||||
</div>
|
||||
) : models.length === 0 ? (
|
||||
<div className="text-center py-12 text-warm-text-muted">
|
||||
No model versions found. Complete a training task to create a model version.
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
{models.map(model => (
|
||||
<div
|
||||
key={model.version_id}
|
||||
onClick={() => setSelectedModel(model)}
|
||||
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
|
||||
selectedModel?.version_id === model.version_id
|
||||
? 'border-warm-text-secondary'
|
||||
: 'border-warm-border hover:border-warm-divider'
|
||||
}`}
|
||||
>
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div>
|
||||
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
|
||||
{model.name}
|
||||
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
|
||||
</h4>
|
||||
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
|
||||
</div>
|
||||
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
|
||||
model.is_active
|
||||
? 'bg-warm-state-info/10 text-warm-state-info'
|
||||
: 'bg-warm-selected text-warm-state-success'
|
||||
}`}>
|
||||
{model.is_active ? 'Active' : model.status}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mt-4 flex gap-8">
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">
|
||||
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
|
||||
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right: Model Detail */}
|
||||
<div className="w-[400px]">
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
|
||||
<span className={`text-sm font-medium ${
|
||||
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
|
||||
}`}>
|
||||
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mb-8">
|
||||
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
|
||||
<p className="font-medium text-warm-text-primary">
|
||||
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-8">
|
||||
{/* Chart 1 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={chartData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chart 2 */}
|
||||
<div>
|
||||
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
|
||||
<div className="h-40">
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<BarChart data={metricsData}>
|
||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
|
||||
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
|
||||
<YAxis hide domain={[0, 100]} />
|
||||
<Tooltip
|
||||
cursor={{fill: '#F1F0ED'}}
|
||||
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
|
||||
/>
|
||||
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-8 space-y-3">
|
||||
{selectedModel && !selectedModel.is_active ? (
|
||||
<Button
|
||||
className="w-full"
|
||||
onClick={() => activateModel(selectedModel.version_id)}
|
||||
disabled={isActivating}
|
||||
>
|
||||
{isActivating ? (
|
||||
<>
|
||||
<Loader2 size={16} className="mr-2 animate-spin" />
|
||||
Activating...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Power size={16} className="mr-2" />
|
||||
Activate for Inference
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
) : (
|
||||
<Button className="w-full" disabled={!selectedModel}>
|
||||
{selectedModel?.is_active ? (
|
||||
<>
|
||||
<CheckCircle size={16} className="mr-2" />
|
||||
Currently Active
|
||||
</>
|
||||
) : (
|
||||
'Select a Model'
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
<div className="flex gap-3">
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
|
||||
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
487
frontend/src/components/Training.tsx
Normal file
487
frontend/src/components/Training.tsx
Normal file
@@ -0,0 +1,487 @@
|
||||
import React, { useState, useMemo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { AugmentationConfig } from './AugmentationConfig'
|
||||
import { useDatasets } from '../hooks/useDatasets'
|
||||
import { useTrainingDocuments } from '../hooks/useTraining'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { DatasetListItem } from '../api/types'
|
||||
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
|
||||
|
||||
type Tab = 'datasets' | 'create'
|
||||
|
||||
interface TrainingProps {
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<string, string> = {
|
||||
ready: 'bg-warm-state-success/10 text-warm-state-success',
|
||||
building: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
training: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
failed: 'bg-warm-state-error/10 text-warm-state-error',
|
||||
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
|
||||
running: 'bg-warm-state-info/10 text-warm-state-info',
|
||||
}
|
||||
|
||||
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
|
||||
// If there's an active training task, show training status
|
||||
const displayStatus = trainingStatus === 'running'
|
||||
? 'training'
|
||||
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
|
||||
? 'pending'
|
||||
: status
|
||||
|
||||
return (
|
||||
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
|
||||
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
|
||||
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
|
||||
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
|
||||
{displayStatus}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Train Dialog ---
|
||||
|
||||
interface TrainDialogProps {
|
||||
dataset: DatasetListItem
|
||||
onClose: () => void
|
||||
onSubmit: (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => void
|
||||
isPending: boolean
|
||||
}
|
||||
|
||||
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
|
||||
const [name, setName] = useState(`train-${dataset.name}`)
|
||||
const [epochs, setEpochs] = useState(100)
|
||||
const [batchSize, setBatchSize] = useState(16)
|
||||
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
|
||||
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
|
||||
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
|
||||
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
|
||||
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
|
||||
|
||||
// Fetch available trained models (active or inactive, not archived)
|
||||
const { data: modelsData } = useQuery({
|
||||
queryKey: ['training', 'models', 'available'],
|
||||
queryFn: () => trainingApi.getModels(),
|
||||
})
|
||||
// Filter out archived models - only show active/inactive models for base model selection
|
||||
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
|
||||
|
||||
const handleSubmit = () => {
|
||||
onSubmit({
|
||||
name,
|
||||
config: {
|
||||
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
|
||||
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
|
||||
epochs,
|
||||
batch_size: batchSize,
|
||||
augmentation: augmentationEnabled
|
||||
? (augmentationConfig as AugmentationConfigType)
|
||||
: undefined,
|
||||
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
|
||||
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
|
||||
<p className="text-sm text-warm-text-muted mb-4">
|
||||
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
|
||||
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
|
||||
</p>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
|
||||
{/* Base Model Selection */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
|
||||
<select
|
||||
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
|
||||
onChange={e => {
|
||||
if (e.target.value === 'pretrained') {
|
||||
setBaseModelType('pretrained')
|
||||
setBaseModelVersionId(null)
|
||||
} else {
|
||||
setBaseModelType('existing')
|
||||
setBaseModelVersionId(e.target.value)
|
||||
}
|
||||
}}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
>
|
||||
<option value="pretrained">yolo11n.pt (Pretrained)</option>
|
||||
{availableModels.map(m => (
|
||||
<option key={m.version_id} value={m.version_id}>
|
||||
{m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
{baseModelType === 'pretrained'
|
||||
? 'Start from pretrained YOLO model'
|
||||
: 'Continue training from an existing model (incremental training)'}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-4">
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
|
||||
<input
|
||||
id="train-epochs"
|
||||
type="number"
|
||||
min={1}
|
||||
max={1000}
|
||||
value={epochs}
|
||||
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
|
||||
<input
|
||||
id="train-batch-size"
|
||||
type="number"
|
||||
min={1}
|
||||
max={128}
|
||||
value={batchSize}
|
||||
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Augmentation Configuration */}
|
||||
<AugmentationConfig
|
||||
enabled={augmentationEnabled}
|
||||
onEnabledChange={setAugmentationEnabled}
|
||||
config={augmentationConfig}
|
||||
onConfigChange={setAugmentationConfig}
|
||||
/>
|
||||
|
||||
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
|
||||
{augmentationEnabled && (
|
||||
<div>
|
||||
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
|
||||
Augmentation Multiplier
|
||||
</label>
|
||||
<input
|
||||
id="aug-multiplier"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10}
|
||||
value={augmentationMultiplier}
|
||||
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Number of augmented copies per original image (1-10)
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
|
||||
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
|
||||
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Dataset List ---
|
||||
|
||||
const DatasetList: React.FC<{
|
||||
onNavigate?: (view: string, id?: string) => void
|
||||
onSwitchTab: (tab: Tab) => void
|
||||
}> = ({ onNavigate, onSwitchTab }) => {
|
||||
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
|
||||
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
|
||||
|
||||
const handleTrain = (config: {
|
||||
name: string
|
||||
config: {
|
||||
model_name?: string
|
||||
base_model_version_id?: string | null
|
||||
epochs: number
|
||||
batch_size: number
|
||||
augmentation?: AugmentationConfigType
|
||||
augmentation_multiplier?: number
|
||||
}
|
||||
}) => {
|
||||
if (!trainTarget) return
|
||||
// Pass config to the training API
|
||||
const trainRequest = {
|
||||
name: config.name,
|
||||
config: config.config,
|
||||
}
|
||||
trainFromDataset(
|
||||
{ datasetId: trainTarget.dataset_id, req: trainRequest },
|
||||
{ onSuccess: () => setTrainTarget(null) },
|
||||
)
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
|
||||
}
|
||||
|
||||
if (datasets.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
|
||||
<Database size={48} className="mb-4 opacity-40" />
|
||||
<p className="text-lg mb-2">No datasets yet</p>
|
||||
<p className="text-sm mb-4">Create a dataset to start training</p>
|
||||
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
|
||||
<table className="w-full text-left">
|
||||
<thead className="bg-white border-b border-warm-border">
|
||||
<tr>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{datasets.map(ds => (
|
||||
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
|
||||
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
|
||||
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
|
||||
<td className="py-3 px-4">
|
||||
<div className="flex gap-1">
|
||||
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
|
||||
<Eye size={14} />
|
||||
</button>
|
||||
{ds.status === 'ready' && (
|
||||
<button title="Train" onClick={() => setTrainTarget(ds)}
|
||||
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
|
||||
<Play size={14} />
|
||||
</button>
|
||||
)}
|
||||
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
|
||||
disabled={isDeleting || ds.status === 'pending' || ds.status === 'building'}
|
||||
className={`p-1.5 rounded transition-colors ${
|
||||
ds.status === 'pending' || ds.status === 'building'
|
||||
? 'text-warm-text-muted/40 cursor-not-allowed'
|
||||
: 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
|
||||
}`}>
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
{trainTarget && (
|
||||
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Create Dataset ---
|
||||
|
||||
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
|
||||
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
|
||||
const { createDatasetAsync, isCreating } = useDatasets()
|
||||
|
||||
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [trainRatio, setTrainRatio] = useState(0.7)
|
||||
const [valRatio, setValRatio] = useState(0.2)
|
||||
|
||||
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
|
||||
|
||||
const toggleDoc = (id: string) => {
|
||||
setSelectedIds(prev => {
|
||||
const next = new Set(prev)
|
||||
if (next.has(id)) { next.delete(id) } else { next.add(id) }
|
||||
return next
|
||||
})
|
||||
}
|
||||
|
||||
const toggleAll = () => {
|
||||
if (selectedIds.size === documents.length) {
|
||||
setSelectedIds(new Set())
|
||||
} else {
|
||||
setSelectedIds(new Set(documents.map((d) => d.document_id)))
|
||||
}
|
||||
}
|
||||
|
||||
const handleCreate = async () => {
|
||||
await createDatasetAsync({
|
||||
name,
|
||||
description: description || undefined,
|
||||
document_ids: [...selectedIds],
|
||||
train_ratio: trainRatio,
|
||||
val_ratio: valRatio,
|
||||
})
|
||||
onSwitchTab('datasets')
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex gap-8">
|
||||
{/* Document selection */}
|
||||
<div className="flex-1 flex flex-col">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
|
||||
{isLoadingDocs ? (
|
||||
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
|
||||
) : (
|
||||
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
|
||||
<div className="overflow-auto max-h-[calc(100vh-240px)]">
|
||||
<table className="w-full text-left">
|
||||
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
|
||||
<tr>
|
||||
<th className="py-3 pl-6 pr-4 w-12">
|
||||
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
|
||||
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
|
||||
</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
|
||||
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{documents.map((doc) => (
|
||||
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
|
||||
onClick={() => toggleDoc(doc.document_id)}>
|
||||
<td className="py-3 pl-6 pr-4">
|
||||
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
|
||||
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
|
||||
</td>
|
||||
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
|
||||
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
|
||||
</div>
|
||||
|
||||
{/* Config panel */}
|
||||
<div className="w-80">
|
||||
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
|
||||
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
|
||||
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
|
||||
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
|
||||
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
|
||||
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
|
||||
<div className="flex gap-2 text-sm">
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Train</span>
|
||||
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Val</span>
|
||||
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<span className="text-xs text-warm-text-muted">Test</span>
|
||||
<input type="number" value={testRatio} readOnly
|
||||
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="pt-4 border-t border-warm-border">
|
||||
{selectedIds.size > 0 && selectedIds.size < 10 && (
|
||||
<p className="text-xs text-warm-state-warning mb-2">
|
||||
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
|
||||
</p>
|
||||
)}
|
||||
<Button className="w-full h-11" onClick={handleCreate}
|
||||
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
|
||||
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Main Training Component ---
|
||||
|
||||
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
|
||||
const [activeTab, setActiveTab] = useState<Tab>('datasets')
|
||||
|
||||
return (
|
||||
<div className="p-8 max-w-7xl mx-auto">
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
|
||||
</div>
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="flex gap-1 mb-6 border-b border-warm-border">
|
||||
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
|
||||
<button key={key} onClick={() => setActiveTab(key)}
|
||||
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
|
||||
activeTab === key
|
||||
? 'border-warm-state-info text-warm-state-info'
|
||||
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
|
||||
}`}>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
|
||||
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
276
frontend/src/components/UploadModal.tsx
Normal file
276
frontend/src/components/UploadModal.tsx
Normal file
@@ -0,0 +1,276 @@
|
||||
import React, { useState, useRef } from 'react'
|
||||
import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react'
|
||||
import { Button } from './Button'
|
||||
import { useDocuments, useCategories } from '../hooks/useDocuments'
|
||||
|
||||
interface UploadModalProps {
|
||||
isOpen: boolean
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
|
||||
const [groupKey, setGroupKey] = useState('')
|
||||
const [category, setCategory] = useState('invoice')
|
||||
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
|
||||
const [errorMessage, setErrorMessage] = useState('')
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const { uploadDocument, isUploading } = useDocuments({})
|
||||
const { categories } = useCategories()
|
||||
|
||||
if (!isOpen) return null
|
||||
|
||||
const handleFileSelect = (files: FileList | null) => {
|
||||
if (!files) return
|
||||
|
||||
const pdfFiles = Array.from(files).filter(file => {
|
||||
const isPdf = file.type === 'application/pdf'
|
||||
const isImage = file.type.startsWith('image/')
|
||||
const isUnder25MB = file.size <= 25 * 1024 * 1024
|
||||
return (isPdf || isImage) && isUnder25MB
|
||||
})
|
||||
|
||||
setSelectedFiles(prev => [...prev, ...pdfFiles])
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
}
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(false)
|
||||
handleFileSelect(e.dataTransfer.files)
|
||||
}
|
||||
|
||||
const handleBrowseClick = () => {
|
||||
fileInputRef.current?.click()
|
||||
}
|
||||
|
||||
const removeFile = (index: number) => {
|
||||
setSelectedFiles(prev => prev.filter((_, i) => i !== index))
|
||||
}
|
||||
|
||||
const handleUpload = async () => {
|
||||
if (selectedFiles.length === 0) {
|
||||
setErrorMessage('Please select at least one file')
|
||||
return
|
||||
}
|
||||
|
||||
setUploadStatus('uploading')
|
||||
setErrorMessage('')
|
||||
|
||||
try {
|
||||
// Upload files one by one
|
||||
for (const file of selectedFiles) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
uploadDocument(
|
||||
{ file, groupKey: groupKey || undefined, category: category || 'invoice' },
|
||||
{
|
||||
onSuccess: () => resolve(),
|
||||
onError: (error: Error) => reject(error),
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
setUploadStatus('success')
|
||||
setTimeout(() => {
|
||||
onClose()
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setCategory('invoice')
|
||||
setUploadStatus('idle')
|
||||
}, 1500)
|
||||
} catch (error) {
|
||||
setUploadStatus('error')
|
||||
setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
|
||||
}
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
if (uploadStatus === 'uploading') {
|
||||
return // Prevent closing during upload
|
||||
}
|
||||
setSelectedFiles([])
|
||||
setGroupKey('')
|
||||
setCategory('invoice')
|
||||
setUploadStatus('idle')
|
||||
setErrorMessage('')
|
||||
onClose()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
|
||||
<div
|
||||
className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
|
||||
<button
|
||||
onClick={handleClose}
|
||||
className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={20} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Drop Zone */}
|
||||
<div
|
||||
className={`
|
||||
w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
|
||||
${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
|
||||
${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
|
||||
`}
|
||||
onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
|
||||
onDragLeave={() => setIsDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={handleBrowseClick}
|
||||
>
|
||||
<div className="p-3 bg-white rounded-full shadow-sm">
|
||||
<UploadCloud size={24} className="text-warm-text-secondary" />
|
||||
</div>
|
||||
<div className="text-center">
|
||||
<p className="text-sm font-medium text-warm-text-primary">
|
||||
Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
|
||||
</p>
|
||||
<p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
accept=".pdf,image/*"
|
||||
className="hidden"
|
||||
onChange={(e) => handleFileSelect(e.target.files)}
|
||||
/>
|
||||
|
||||
{/* Selected Files */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6 max-h-40 overflow-y-auto">
|
||||
<p className="text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Selected Files ({selectedFiles.length})
|
||||
</p>
|
||||
<div className="space-y-2">
|
||||
{selectedFiles.map((file, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
|
||||
>
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0">
|
||||
<File size={16} className="text-warm-text-muted flex-shrink-0" />
|
||||
<span className="text-sm text-warm-text-secondary truncate">
|
||||
{file.name}
|
||||
</span>
|
||||
<span className="text-xs text-warm-text-muted flex-shrink-0">
|
||||
({(file.size / 1024 / 1024).toFixed(2)} MB)
|
||||
</span>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => removeFile(index)}
|
||||
className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<X size={16} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Category Select */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Category
|
||||
</label>
|
||||
<div className="relative">
|
||||
<select
|
||||
value={category}
|
||||
onChange={(e) => setCategory(e.target.value)}
|
||||
className="w-full h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none cursor-pointer"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
<option value="invoice">Invoice</option>
|
||||
<option value="letter">Letter</option>
|
||||
<option value="receipt">Receipt</option>
|
||||
<option value="contract">Contract</option>
|
||||
{categories
|
||||
.filter((cat) => !['invoice', 'letter', 'receipt', 'contract'].includes(cat))
|
||||
.map((cat) => (
|
||||
<option key={cat} value={cat}>
|
||||
{cat.charAt(0).toUpperCase() + cat.slice(1)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown
|
||||
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
|
||||
size={14}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Select document type for training different models
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Group Key Input */}
|
||||
{selectedFiles.length > 0 && (
|
||||
<div className="mb-6">
|
||||
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
|
||||
Group Key (optional)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={groupKey}
|
||||
onChange={(e) => setGroupKey(e.target.value)}
|
||||
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
|
||||
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
/>
|
||||
<p className="text-xs text-warm-text-muted mt-1">
|
||||
Use group keys to organize documents into logical groups
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Messages */}
|
||||
{uploadStatus === 'success' && (
|
||||
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
|
||||
<CheckCircle size={16} className="text-green-600" />
|
||||
<span className="text-sm text-green-800">Upload successful!</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus === 'error' && errorMessage && (
|
||||
<div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
|
||||
<AlertCircle size={16} className="text-red-600" />
|
||||
<span className="text-sm text-red-800">{errorMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="mt-8 flex justify-end gap-3">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={handleClose}
|
||||
disabled={uploadStatus === 'uploading'}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleUpload}
|
||||
disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
|
||||
>
|
||||
{uploadStatus === 'uploading' ? 'Uploading...' : `Upload ${selectedFiles.length > 0 ? `(${selectedFiles.length})` : ''}`}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
7
frontend/src/hooks/index.ts
Normal file
7
frontend/src/hooks/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export { useDocuments, useCategories } from './useDocuments'
|
||||
export { useDocumentDetail } from './useDocumentDetail'
|
||||
export { useAnnotations } from './useAnnotations'
|
||||
export { useTraining, useTrainingDocuments } from './useTraining'
|
||||
export { useDatasets, useDatasetDetail } from './useDatasets'
|
||||
export { useAugmentation } from './useAugmentation'
|
||||
export { useModels, useModelDetail, useActiveModel } from './useModels'
|
||||
70
frontend/src/hooks/useAnnotations.ts
Normal file
70
frontend/src/hooks/useAnnotations.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import { useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { annotationsApi } from '../api/endpoints'
|
||||
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'
|
||||
|
||||
export const useAnnotations = (documentId: string) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (annotation: CreateAnnotationRequest) =>
|
||||
annotationsApi.create(documentId, annotation),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
updates,
|
||||
}: {
|
||||
annotationId: string
|
||||
updates: Partial<CreateAnnotationRequest>
|
||||
}) => annotationsApi.update(documentId, annotationId, updates),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.delete(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const verifyMutation = useMutation({
|
||||
mutationFn: (annotationId: string) =>
|
||||
annotationsApi.verify(documentId, annotationId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
const overrideMutation = useMutation({
|
||||
mutationFn: ({
|
||||
annotationId,
|
||||
overrideData,
|
||||
}: {
|
||||
annotationId: string
|
||||
overrideData: AnnotationOverrideRequest
|
||||
}) => annotationsApi.override(documentId, annotationId, overrideData),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
createAnnotation: createMutation.mutate,
|
||||
isCreating: createMutation.isPending,
|
||||
updateAnnotation: updateMutation.mutate,
|
||||
isUpdating: updateMutation.isPending,
|
||||
deleteAnnotation: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
verifyAnnotation: verifyMutation.mutate,
|
||||
isVerifying: verifyMutation.isPending,
|
||||
overrideAnnotation: overrideMutation.mutate,
|
||||
isOverriding: overrideMutation.isPending,
|
||||
}
|
||||
}
|
||||
226
frontend/src/hooks/useAugmentation.test.tsx
Normal file
226
frontend/src/hooks/useAugmentation.test.tsx
Normal file
@@ -0,0 +1,226 @@
|
||||
/**
|
||||
* Tests for useAugmentation hook.
|
||||
*
|
||||
* TDD Phase 1: RED - Write tests first, then implement to pass.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { renderHook, waitFor } from '@testing-library/react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { augmentationApi } from '../api/endpoints/augmentation'
|
||||
import { useAugmentation } from './useAugmentation'
|
||||
import type { ReactNode } from 'react'
|
||||
|
||||
// Mock the API
|
||||
vi.mock('../api/endpoints/augmentation', () => ({
|
||||
augmentationApi: {
|
||||
getTypes: vi.fn(),
|
||||
getPresets: vi.fn(),
|
||||
preview: vi.fn(),
|
||||
previewConfig: vi.fn(),
|
||||
createBatch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Test wrapper with QueryClient
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('useAugmentation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getTypes', () => {
|
||||
it('should fetch augmentation types', async () => {
|
||||
const mockTypes = {
|
||||
augmentation_types: [
|
||||
{
|
||||
name: 'gaussian_noise',
|
||||
description: 'Adds Gaussian noise',
|
||||
affects_geometry: false,
|
||||
stage: 'noise',
|
||||
default_params: { mean: 0, std: 15 },
|
||||
},
|
||||
{
|
||||
name: 'perspective_warp',
|
||||
description: 'Applies perspective warp',
|
||||
affects_geometry: true,
|
||||
stage: 'geometric',
|
||||
default_params: { max_warp: 0.02 },
|
||||
},
|
||||
],
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.augmentationTypes).toHaveLength(2)
|
||||
expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
|
||||
})
|
||||
|
||||
it('should handle error when fetching types', async () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.typesError).toBeTruthy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getPresets', () => {
|
||||
it('should fetch augmentation presets', async () => {
|
||||
const mockPresets = {
|
||||
presets: [
|
||||
{ name: 'conservative', description: 'Safe augmentations' },
|
||||
{ name: 'moderate', description: 'Balanced augmentations' },
|
||||
{ name: 'aggressive', description: 'Strong augmentations' },
|
||||
],
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingPresets).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.presets).toHaveLength(3)
|
||||
expect(result.current.presets[0].name).toBe('conservative')
|
||||
})
|
||||
})
|
||||
|
||||
describe('preview', () => {
|
||||
it('should preview single augmentation', async () => {
|
||||
const mockPreview = {
|
||||
preview_url: '',
|
||||
original_url: '',
|
||||
applied_params: { std: 15 },
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
// Call preview mutation
|
||||
result.current.preview({
|
||||
documentId: 'doc-123',
|
||||
augmentationType: 'gaussian_noise',
|
||||
params: { std: 15 },
|
||||
page: 1,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(augmentationApi.preview).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
{ augmentation_type: 'gaussian_noise', params: { std: 15 } },
|
||||
1
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should track preview loading state', async () => {
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.preview).mockImplementation(
|
||||
() => new Promise((resolve) => setTimeout(resolve, 100))
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.isPreviewing).toBe(false)
|
||||
|
||||
result.current.preview({
|
||||
documentId: 'doc-123',
|
||||
augmentationType: 'gaussian_noise',
|
||||
params: {},
|
||||
page: 1,
|
||||
})
|
||||
|
||||
// State update happens asynchronously
|
||||
await waitFor(() => {
|
||||
expect(result.current.isPreviewing).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('createBatch', () => {
|
||||
it('should create augmented dataset', async () => {
|
||||
const mockResponse = {
|
||||
task_id: 'task-123',
|
||||
status: 'pending',
|
||||
message: 'Augmentation task queued',
|
||||
estimated_images: 100,
|
||||
}
|
||||
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
|
||||
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
|
||||
vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
|
||||
|
||||
const { result } = renderHook(() => useAugmentation(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoadingTypes).toBe(false)
|
||||
})
|
||||
|
||||
result.current.createBatch({
|
||||
dataset_id: 'dataset-123',
|
||||
config: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
|
||||
},
|
||||
output_name: 'augmented-dataset',
|
||||
multiplier: 2,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(augmentationApi.createBatch).toHaveBeenCalledWith({
|
||||
dataset_id: 'dataset-123',
|
||||
config: {
|
||||
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
|
||||
},
|
||||
output_name: 'augmented-dataset',
|
||||
multiplier: 2,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
121
frontend/src/hooks/useAugmentation.ts
Normal file
121
frontend/src/hooks/useAugmentation.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Hook for managing augmentation operations.
|
||||
*
|
||||
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
|
||||
*/
|
||||
|
||||
import { useQuery, useMutation } from '@tanstack/react-query'
|
||||
import {
|
||||
augmentationApi,
|
||||
type AugmentationTypesResponse,
|
||||
type PresetsResponse,
|
||||
type PreviewResponse,
|
||||
type BatchRequest,
|
||||
type BatchResponse,
|
||||
type AugmentationConfig,
|
||||
} from '../api/endpoints/augmentation'
|
||||
|
||||
interface PreviewParams {
|
||||
documentId: string
|
||||
augmentationType: string
|
||||
params: Record<string, unknown>
|
||||
page?: number
|
||||
}
|
||||
|
||||
interface PreviewConfigParams {
|
||||
documentId: string
|
||||
config: AugmentationConfig
|
||||
page?: number
|
||||
}
|
||||
|
||||
export const useAugmentation = () => {
|
||||
// Fetch augmentation types
|
||||
const {
|
||||
data: typesData,
|
||||
isLoading: isLoadingTypes,
|
||||
error: typesError,
|
||||
} = useQuery<AugmentationTypesResponse>({
|
||||
queryKey: ['augmentation', 'types'],
|
||||
queryFn: () => augmentationApi.getTypes(),
|
||||
staleTime: 5 * 60 * 1000, // Cache for 5 minutes
|
||||
})
|
||||
|
||||
// Fetch presets
|
||||
const {
|
||||
data: presetsData,
|
||||
isLoading: isLoadingPresets,
|
||||
error: presetsError,
|
||||
} = useQuery<PresetsResponse>({
|
||||
queryKey: ['augmentation', 'presets'],
|
||||
queryFn: () => augmentationApi.getPresets(),
|
||||
staleTime: 5 * 60 * 1000,
|
||||
})
|
||||
|
||||
// Preview single augmentation mutation
|
||||
const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
|
||||
mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
|
||||
augmentationApi.preview(
|
||||
documentId,
|
||||
{ augmentation_type: augmentationType, params },
|
||||
page
|
||||
),
|
||||
onError: (error) => {
|
||||
console.error('Preview augmentation failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
// Preview full config mutation
|
||||
const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
|
||||
mutationFn: ({ documentId, config, page = 1 }) =>
|
||||
augmentationApi.previewConfig(documentId, config, page),
|
||||
onError: (error) => {
|
||||
console.error('Preview config failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
// Create augmented dataset mutation
|
||||
const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
|
||||
mutationFn: (request) => augmentationApi.createBatch(request),
|
||||
onError: (error) => {
|
||||
console.error('Create augmented dataset failed:', error)
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
// Types data
|
||||
augmentationTypes: typesData?.augmentation_types || [],
|
||||
isLoadingTypes,
|
||||
typesError,
|
||||
|
||||
// Presets data
|
||||
presets: presetsData?.presets || [],
|
||||
isLoadingPresets,
|
||||
presetsError,
|
||||
|
||||
// Preview single augmentation
|
||||
preview: previewMutation.mutate,
|
||||
previewAsync: previewMutation.mutateAsync,
|
||||
isPreviewing: previewMutation.isPending,
|
||||
previewData: previewMutation.data,
|
||||
previewError: previewMutation.error,
|
||||
|
||||
// Preview full config
|
||||
previewConfig: previewConfigMutation.mutate,
|
||||
previewConfigAsync: previewConfigMutation.mutateAsync,
|
||||
isPreviewingConfig: previewConfigMutation.isPending,
|
||||
previewConfigData: previewConfigMutation.data,
|
||||
previewConfigError: previewConfigMutation.error,
|
||||
|
||||
// Create batch
|
||||
createBatch: createBatchMutation.mutate,
|
||||
createBatchAsync: createBatchMutation.mutateAsync,
|
||||
isCreatingBatch: createBatchMutation.isPending,
|
||||
batchData: createBatchMutation.data,
|
||||
batchError: createBatchMutation.error,
|
||||
|
||||
// Reset functions for clearing stale mutation state
|
||||
resetPreview: previewMutation.reset,
|
||||
resetPreviewConfig: previewConfigMutation.reset,
|
||||
resetBatch: createBatchMutation.reset,
|
||||
}
|
||||
}
|
||||
84
frontend/src/hooks/useDatasets.ts
Normal file
84
frontend/src/hooks/useDatasets.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { datasetsApi } from '../api/endpoints'
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetDetailResponse,
|
||||
DatasetListResponse,
|
||||
DatasetTrainRequest,
|
||||
} from '../api/types'
|
||||
|
||||
export const useDatasets = (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
|
||||
queryKey: ['datasets', params],
|
||||
queryFn: () => datasetsApi.list(params),
|
||||
staleTime: 30000,
|
||||
// Poll every 5 seconds when there's an active training task
|
||||
refetchInterval: (query) => {
|
||||
const datasets = query.state.data?.datasets ?? []
|
||||
const hasActiveTraining = datasets.some(
|
||||
d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
|
||||
)
|
||||
return hasActiveTraining ? 5000 : false
|
||||
},
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
},
|
||||
})
|
||||
|
||||
const trainMutation = useMutation({
|
||||
mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
|
||||
datasetsApi.trainFromDataset(datasetId, req),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['datasets'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
datasets: data?.datasets ?? [],
|
||||
total: data?.total ?? 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
createDataset: createMutation.mutate,
|
||||
createDatasetAsync: createMutation.mutateAsync,
|
||||
isCreating: createMutation.isPending,
|
||||
deleteDataset: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
trainFromDataset: trainMutation.mutate,
|
||||
trainFromDatasetAsync: trainMutation.mutateAsync,
|
||||
isTraining: trainMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useDatasetDetail = (datasetId: string | null) => {
|
||||
const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
|
||||
queryKey: ['datasets', datasetId],
|
||||
queryFn: () => datasetsApi.getDetail(datasetId!),
|
||||
enabled: !!datasetId,
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
dataset: data ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
25
frontend/src/hooks/useDocumentDetail.ts
Normal file
25
frontend/src/hooks/useDocumentDetail.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { documentsApi } from '../api/endpoints'
|
||||
import type { DocumentDetailResponse } from '../api/types'
|
||||
|
||||
export const useDocumentDetail = (documentId: string | null) => {
|
||||
const { data, isLoading, error, refetch } = useQuery<DocumentDetailResponse>({
|
||||
queryKey: ['document', documentId],
|
||||
queryFn: () => {
|
||||
if (!documentId) {
|
||||
throw new Error('Document ID is required')
|
||||
}
|
||||
return documentsApi.getDetail(documentId)
|
||||
},
|
||||
enabled: !!documentId,
|
||||
staleTime: 10000,
|
||||
})
|
||||
|
||||
return {
|
||||
document: data || null,
|
||||
annotations: data?.annotations || [],
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
}
|
||||
}
|
||||
120
frontend/src/hooks/useDocuments.ts
Normal file
120
frontend/src/hooks/useDocuments.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { documentsApi } from '../api/endpoints'
|
||||
import type { DocumentListResponse, DocumentCategoriesResponse } from '../api/types'
|
||||
|
||||
interface UseDocumentsParams {
|
||||
status?: string
|
||||
category?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
export const useDocuments = (params: UseDocumentsParams = {}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<DocumentListResponse>({
|
||||
queryKey: ['documents', params],
|
||||
queryFn: () => documentsApi.list(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
const uploadMutation = useMutation({
|
||||
mutationFn: ({ file, groupKey, category }: { file: File; groupKey?: string; category?: string }) =>
|
||||
documentsApi.upload(file, { groupKey, category }),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['categories'] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateGroupKeyMutation = useMutation({
|
||||
mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
|
||||
documentsApi.updateGroupKey(documentId, groupKey),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const batchUploadMutation = useMutation({
|
||||
mutationFn: ({ files, csvFile }: { files: File[]; csvFile?: File }) =>
|
||||
documentsApi.batchUpload(files, csvFile),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (documentId: string) => documentsApi.delete(documentId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateStatusMutation = useMutation({
|
||||
mutationFn: ({ documentId, status }: { documentId: string; status: string }) =>
|
||||
documentsApi.updateStatus(documentId, status),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const triggerAutoLabelMutation = useMutation({
|
||||
mutationFn: (documentId: string) => documentsApi.triggerAutoLabel(documentId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
},
|
||||
})
|
||||
|
||||
const updateCategoryMutation = useMutation({
|
||||
mutationFn: ({ documentId, category }: { documentId: string; category: string }) =>
|
||||
documentsApi.updateCategory(documentId, category),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['documents'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['categories'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
documents: data?.documents || [],
|
||||
total: data?.total || 0,
|
||||
limit: data?.limit || params.limit || 20,
|
||||
offset: data?.offset || params.offset || 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
uploadDocument: uploadMutation.mutate,
|
||||
uploadDocumentAsync: uploadMutation.mutateAsync,
|
||||
isUploading: uploadMutation.isPending,
|
||||
batchUpload: batchUploadMutation.mutate,
|
||||
batchUploadAsync: batchUploadMutation.mutateAsync,
|
||||
isBatchUploading: batchUploadMutation.isPending,
|
||||
deleteDocument: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
updateStatus: updateStatusMutation.mutate,
|
||||
isUpdatingStatus: updateStatusMutation.isPending,
|
||||
triggerAutoLabel: triggerAutoLabelMutation.mutate,
|
||||
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
|
||||
updateGroupKey: updateGroupKeyMutation.mutate,
|
||||
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
|
||||
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
|
||||
updateCategory: updateCategoryMutation.mutate,
|
||||
updateCategoryAsync: updateCategoryMutation.mutateAsync,
|
||||
isUpdatingCategory: updateCategoryMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useCategories = () => {
|
||||
const { data, isLoading, error, refetch } = useQuery<DocumentCategoriesResponse>({
|
||||
queryKey: ['categories'],
|
||||
queryFn: () => documentsApi.getCategories(),
|
||||
staleTime: 60000,
|
||||
})
|
||||
|
||||
return {
|
||||
categories: data?.categories || [],
|
||||
total: data?.total || 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
}
|
||||
}
|
||||
98
frontend/src/hooks/useModels.ts
Normal file
98
frontend/src/hooks/useModels.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { modelsApi } from '../api/endpoints'
|
||||
import type {
|
||||
ModelVersionListResponse,
|
||||
ModelVersionDetailResponse,
|
||||
ActiveModelResponse,
|
||||
} from '../api/types'
|
||||
|
||||
export const useModels = (params?: {
|
||||
status?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
|
||||
queryKey: ['models', params],
|
||||
queryFn: () => modelsApi.list(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
const activateMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.activate(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deactivateMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
|
||||
},
|
||||
})
|
||||
|
||||
const archiveMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.archive(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (versionId: string) => modelsApi.delete(versionId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['models'] })
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
models: data?.models ?? [],
|
||||
total: data?.total ?? 0,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
activateModel: activateMutation.mutate,
|
||||
activateModelAsync: activateMutation.mutateAsync,
|
||||
isActivating: activateMutation.isPending,
|
||||
deactivateModel: deactivateMutation.mutate,
|
||||
isDeactivating: deactivateMutation.isPending,
|
||||
archiveModel: archiveMutation.mutate,
|
||||
isArchiving: archiveMutation.isPending,
|
||||
deleteModel: deleteMutation.mutate,
|
||||
isDeleting: deleteMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useModelDetail = (versionId: string | null) => {
|
||||
const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
|
||||
queryKey: ['models', versionId],
|
||||
queryFn: () => modelsApi.getDetail(versionId!),
|
||||
enabled: !!versionId,
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
model: data ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
export const useActiveModel = () => {
|
||||
const { data, isLoading, error } = useQuery<ActiveModelResponse>({
|
||||
queryKey: ['models', 'active'],
|
||||
queryFn: () => modelsApi.getActive(),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
hasActiveModel: data?.has_active_model ?? false,
|
||||
activeModel: data?.model ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
83
frontend/src/hooks/useTraining.ts
Normal file
83
frontend/src/hooks/useTraining.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { trainingApi } from '../api/endpoints'
|
||||
import type { TrainingModelsResponse } from '../api/types'
|
||||
|
||||
export const useTraining = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data: modelsData, isLoading: isLoadingModels } =
|
||||
useQuery<TrainingModelsResponse>({
|
||||
queryKey: ['training', 'models'],
|
||||
queryFn: () => trainingApi.getModels(),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
const startTrainingMutation = useMutation({
|
||||
mutationFn: (config: {
|
||||
name: string
|
||||
description?: string
|
||||
document_ids: string[]
|
||||
epochs?: number
|
||||
batch_size?: number
|
||||
model_base?: string
|
||||
}) => trainingApi.startTraining(config),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const cancelTaskMutation = useMutation({
|
||||
mutationFn: (taskId: string) => trainingApi.cancelTask(taskId),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
|
||||
},
|
||||
})
|
||||
|
||||
const downloadModelMutation = useMutation({
|
||||
mutationFn: (taskId: string) => trainingApi.downloadModel(taskId),
|
||||
onSuccess: (blob, taskId) => {
|
||||
const url = window.URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = `model-${taskId}.pt`
|
||||
document.body.appendChild(a)
|
||||
a.click()
|
||||
window.URL.revokeObjectURL(url)
|
||||
document.body.removeChild(a)
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
models: modelsData?.models || [],
|
||||
total: modelsData?.total || 0,
|
||||
isLoadingModels,
|
||||
startTraining: startTrainingMutation.mutate,
|
||||
startTrainingAsync: startTrainingMutation.mutateAsync,
|
||||
isStartingTraining: startTrainingMutation.isPending,
|
||||
cancelTask: cancelTaskMutation.mutate,
|
||||
isCancelling: cancelTaskMutation.isPending,
|
||||
downloadModel: downloadModelMutation.mutate,
|
||||
isDownloading: downloadModelMutation.isPending,
|
||||
}
|
||||
}
|
||||
|
||||
export const useTrainingDocuments = (params?: {
|
||||
has_annotations?: boolean
|
||||
min_annotation_count?: number
|
||||
exclude_used_in_training?: boolean
|
||||
limit?: number
|
||||
offset?: number
|
||||
}) => {
|
||||
const { data, isLoading, error } = useQuery({
|
||||
queryKey: ['training', 'documents', params],
|
||||
queryFn: () => trainingApi.getDocumentsForTraining(params),
|
||||
staleTime: 30000,
|
||||
})
|
||||
|
||||
return {
|
||||
documents: data?.documents || [],
|
||||
total: data?.total || 0,
|
||||
isLoading,
|
||||
error,
|
||||
}
|
||||
}
|
||||
23
frontend/src/main.tsx
Normal file
23
frontend/src/main.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import React from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import App from './App'
|
||||
import './styles/index.css'
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: 1,
|
||||
refetchOnWindowFocus: false,
|
||||
staleTime: 30000,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||
<React.StrictMode>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<App />
|
||||
</QueryClientProvider>
|
||||
</React.StrictMode>
|
||||
)
|
||||
26
frontend/src/styles/index.css
Normal file
26
frontend/src/styles/index.css
Normal file
@@ -0,0 +1,26 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
@layer base {
|
||||
body {
|
||||
@apply bg-warm-bg text-warm-text-primary;
|
||||
}
|
||||
|
||||
/* Custom scrollbar */
|
||||
::-webkit-scrollbar {
|
||||
@apply w-2 h-2;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
@apply bg-transparent;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
@apply bg-warm-divider rounded;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
@apply bg-warm-text-disabled;
|
||||
}
|
||||
}
|
||||
48
frontend/src/types/index.ts
Normal file
48
frontend/src/types/index.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
// Legacy types for backward compatibility with old components
|
||||
// These will be gradually replaced with API types
|
||||
|
||||
export enum DocumentStatus {
|
||||
PENDING = 'Pending',
|
||||
LABELED = 'Labeled',
|
||||
VERIFIED = 'Verified',
|
||||
PARTIAL = 'Partial'
|
||||
}
|
||||
|
||||
export interface Document {
|
||||
id: string
|
||||
name: string
|
||||
date: string
|
||||
status: DocumentStatus
|
||||
exported: boolean
|
||||
autoLabelProgress?: number
|
||||
autoLabelStatus?: 'Running' | 'Completed' | 'Failed'
|
||||
}
|
||||
|
||||
export interface Annotation {
|
||||
id: string
|
||||
text: string
|
||||
label: string
|
||||
x: number
|
||||
y: number
|
||||
width: number
|
||||
height: number
|
||||
isAuto?: boolean
|
||||
}
|
||||
|
||||
export interface TrainingJob {
|
||||
id: string
|
||||
name: string
|
||||
startDate: string
|
||||
status: 'Running' | 'Completed' | 'Failed'
|
||||
progress: number
|
||||
metrics?: {
|
||||
accuracy: number
|
||||
precision: number
|
||||
recall: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface ModelMetric {
|
||||
name: string
|
||||
value: number
|
||||
}
|
||||
47
frontend/tailwind.config.js
Normal file
47
frontend/tailwind.config.js
Normal file
@@ -0,0 +1,47 @@
|
||||
export default {
|
||||
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
|
||||
theme: {
|
||||
extend: {
|
||||
fontFamily: {
|
||||
sans: ['Inter', 'SF Pro', 'system-ui', 'sans-serif'],
|
||||
mono: ['JetBrains Mono', 'SF Mono', 'monospace'],
|
||||
},
|
||||
colors: {
|
||||
warm: {
|
||||
bg: '#FAFAF8',
|
||||
card: '#FFFFFF',
|
||||
hover: '#F1F0ED',
|
||||
selected: '#ECEAE6',
|
||||
border: '#E6E4E1',
|
||||
divider: '#D8D6D2',
|
||||
text: {
|
||||
primary: '#121212',
|
||||
secondary: '#2A2A2A',
|
||||
muted: '#6B6B6B',
|
||||
disabled: '#9A9A9A',
|
||||
},
|
||||
state: {
|
||||
success: '#3E4A3A',
|
||||
error: '#4A3A3A',
|
||||
warning: '#4A4A3A',
|
||||
info: '#3A3A3A',
|
||||
}
|
||||
}
|
||||
},
|
||||
boxShadow: {
|
||||
'card': '0 1px 3px rgba(0,0,0,0.08)',
|
||||
'modal': '0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)',
|
||||
},
|
||||
animation: {
|
||||
'fade-in': 'fadeIn 0.3s ease-out',
|
||||
},
|
||||
keyframes: {
|
||||
fadeIn: {
|
||||
'0%': { opacity: '0', transform: 'translateY(10px)' },
|
||||
'100%': { opacity: '1', transform: 'translateY(0)' },
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
1
frontend/tests/setup.ts
Normal file
1
frontend/tests/setup.ts
Normal file
@@ -0,0 +1 @@
|
||||
import '@testing-library/jest-dom';
|
||||
29
frontend/tsconfig.json
Normal file
29
frontend/tsconfig.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"experimentalDecorators": true,
|
||||
"useDefineForClassFields": false,
|
||||
"module": "ESNext",
|
||||
"lib": [
|
||||
"ES2022",
|
||||
"DOM",
|
||||
"DOM.Iterable"
|
||||
],
|
||||
"skipLibCheck": true,
|
||||
"types": [
|
||||
"node"
|
||||
],
|
||||
"moduleResolution": "bundler",
|
||||
"isolatedModules": true,
|
||||
"moduleDetection": "force",
|
||||
"allowJs": true,
|
||||
"jsx": "react-jsx",
|
||||
"paths": {
|
||||
"@/*": [
|
||||
"./*"
|
||||
]
|
||||
},
|
||||
"allowImportingTsExtensions": true,
|
||||
"noEmit": true
|
||||
}
|
||||
}
|
||||
16
frontend/vite.config.ts
Normal file
16
frontend/vite.config.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
|
||||
export default defineConfig({
|
||||
server: {
|
||||
port: 3000,
|
||||
host: '0.0.0.0',
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:8000',
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [react()],
|
||||
});
|
||||
19
frontend/vitest.config.ts
Normal file
19
frontend/vitest.config.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
/// <reference types="vitest/config" />
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
setupFiles: ['./tests/setup.ts'],
|
||||
include: ['src/**/*.test.{ts,tsx}', 'tests/**/*.test.{ts,tsx}'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'lcov'],
|
||||
include: ['src/**/*.{ts,tsx}'],
|
||||
exclude: ['src/**/*.test.{ts,tsx}', 'src/main.tsx'],
|
||||
},
|
||||
},
|
||||
});
|
||||
18
migrations/003_training_tasks.sql
Normal file
18
migrations/003_training_tasks.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
-- Training tasks table for async training job management.
|
||||
-- Inference service writes pending tasks; training service polls and executes.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS training_tasks (
|
||||
task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
config JSONB,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
scheduled_at TIMESTAMP WITH TIME ZONE,
|
||||
started_at TIMESTAMP WITH TIME ZONE,
|
||||
completed_at TIMESTAMP WITH TIME ZONE,
|
||||
error_message TEXT,
|
||||
model_path TEXT,
|
||||
metrics JSONB
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_status ON training_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_created ON training_tasks(created_at);
|
||||
39
migrations/004_training_datasets.sql
Normal file
39
migrations/004_training_datasets.sql
Normal file
@@ -0,0 +1,39 @@
|
||||
-- Training Datasets Management
|
||||
-- Tracks dataset-document relationships and train/val/test splits
|
||||
|
||||
CREATE TABLE IF NOT EXISTS training_datasets (
|
||||
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'building',
|
||||
train_ratio FLOAT NOT NULL DEFAULT 0.8,
|
||||
val_ratio FLOAT NOT NULL DEFAULT 0.1,
|
||||
seed INTEGER NOT NULL DEFAULT 42,
|
||||
total_documents INTEGER NOT NULL DEFAULT 0,
|
||||
total_images INTEGER NOT NULL DEFAULT 0,
|
||||
total_annotations INTEGER NOT NULL DEFAULT 0,
|
||||
dataset_path VARCHAR(512),
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS dataset_documents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
|
||||
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
|
||||
split VARCHAR(10) NOT NULL,
|
||||
page_count INTEGER NOT NULL DEFAULT 0,
|
||||
annotation_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(dataset_id, document_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
|
||||
|
||||
-- Add dataset_id to training_tasks
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);
|
||||
8
migrations/005_add_group_key.sql
Normal file
8
migrations/005_add_group_key.sql
Normal file
@@ -0,0 +1,8 @@
|
||||
-- Add group_key column to admin_documents
|
||||
-- Allows users to organize documents into logical groups
|
||||
|
||||
-- Add the column (nullable, VARCHAR 255)
|
||||
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
|
||||
|
||||
-- Add index for filtering/grouping queries
|
||||
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);
|
||||
49
migrations/006_model_versions.sql
Normal file
49
migrations/006_model_versions.sql
Normal file
@@ -0,0 +1,49 @@
|
||||
-- Model versions table for tracking trained model deployments.
|
||||
-- Each training run can produce a model version for inference.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS model_versions (
|
||||
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
version VARCHAR(50) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
model_path VARCHAR(512) NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
|
||||
is_active BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
|
||||
-- Training association
|
||||
task_id UUID REFERENCES training_tasks(task_id) ON DELETE SET NULL,
|
||||
dataset_id UUID REFERENCES training_datasets(dataset_id) ON DELETE SET NULL,
|
||||
|
||||
-- Training metrics
|
||||
metrics_mAP DOUBLE PRECISION,
|
||||
metrics_precision DOUBLE PRECISION,
|
||||
metrics_recall DOUBLE PRECISION,
|
||||
document_count INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- Training configuration snapshot
|
||||
training_config JSONB,
|
||||
|
||||
-- File info
|
||||
file_size BIGINT,
|
||||
|
||||
-- Timestamps
|
||||
trained_at TIMESTAMP WITH TIME ZONE,
|
||||
activated_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Indexes
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_version ON model_versions(version);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_status ON model_versions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_is_active ON model_versions(is_active);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_task_id ON model_versions(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_dataset_id ON model_versions(dataset_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_versions_created ON model_versions(created_at);
|
||||
|
||||
-- Ensure only one active model at a time
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
|
||||
ON model_versions(is_active) WHERE is_active = TRUE;
|
||||
|
||||
-- Comment
|
||||
COMMENT ON TABLE model_versions IS 'Trained model versions for inference deployment';
|
||||
46
migrations/007_training_tasks_extra_columns.sql
Normal file
46
migrations/007_training_tasks_extra_columns.sql
Normal file
@@ -0,0 +1,46 @@
|
||||
-- Add missing columns to training_tasks table
|
||||
|
||||
-- Add name column
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
|
||||
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
|
||||
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
|
||||
|
||||
-- Add description column
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
|
||||
|
||||
-- Add admin_token column (for multi-tenant support)
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
|
||||
|
||||
-- Add task_type column
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
|
||||
|
||||
-- Add recurring schedule columns
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
|
||||
|
||||
-- Add result metrics columns (for display without parsing JSONB)
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
|
||||
|
||||
-- Rename metrics to config if exists
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'training_tasks' AND column_name = 'metrics'
|
||||
AND NOT EXISTS (SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'training_tasks' AND column_name = 'config')) THEN
|
||||
ALTER TABLE training_tasks RENAME COLUMN metrics TO config;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Add updated_at column
|
||||
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
|
||||
|
||||
-- Create index on name
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
|
||||
|
||||
-- Create index on metrics_mAP
|
||||
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);
|
||||
14
migrations/008_fix_model_versions_fk.sql
Normal file
14
migrations/008_fix_model_versions_fk.sql
Normal file
@@ -0,0 +1,14 @@
|
||||
-- Fix foreign key constraints on model_versions table to allow CASCADE delete
|
||||
|
||||
-- Drop existing constraints
|
||||
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
|
||||
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
|
||||
|
||||
-- Add constraints with ON DELETE SET NULL
|
||||
ALTER TABLE model_versions
|
||||
ADD CONSTRAINT model_versions_dataset_id_fkey
|
||||
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE model_versions
|
||||
ADD CONSTRAINT model_versions_task_id_fkey
|
||||
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;
|
||||
13
migrations/009_add_document_category.sql
Normal file
13
migrations/009_add_document_category.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- Add category column to admin_documents table
|
||||
-- Allows categorizing documents for training different models (e.g., invoice, letter, receipt)
|
||||
|
||||
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
|
||||
|
||||
-- Update existing NULL values to default
|
||||
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
|
||||
|
||||
-- Make it NOT NULL after setting defaults
|
||||
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
|
||||
|
||||
-- Create index for category filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);
|
||||
28
migrations/010_add_dataset_training_status.sql
Normal file
28
migrations/010_add_dataset_training_status.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
-- Migration: Add training_status and active_training_task_id to training_datasets
|
||||
-- Description: Track training status separately from dataset build status
|
||||
|
||||
-- Add training_status column
|
||||
ALTER TABLE training_datasets
|
||||
ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
|
||||
|
||||
-- Add active_training_task_id column
|
||||
ALTER TABLE training_datasets
|
||||
ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
|
||||
|
||||
-- Create index for training_status
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status
|
||||
ON training_datasets(training_status);
|
||||
|
||||
-- Create index for active_training_task_id
|
||||
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id
|
||||
ON training_datasets(active_training_task_id);
|
||||
|
||||
-- Update existing datasets that have been used in completed training tasks to 'trained' status
|
||||
UPDATE training_datasets d
|
||||
SET status = 'trained'
|
||||
WHERE d.status = 'ready'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM training_tasks t
|
||||
WHERE t.dataset_id = d.dataset_id
|
||||
AND t.status = 'completed'
|
||||
);
|
||||
25
packages/inference/Dockerfile
Normal file
25
packages/inference/Dockerfile
Normal file
@@ -0,0 +1,25 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install shared package
|
||||
COPY packages/shared /app/packages/shared
|
||||
RUN pip install --no-cache-dir -e /app/packages/shared
|
||||
|
||||
# Install inference package
|
||||
COPY packages/inference /app/packages/inference
|
||||
RUN pip install --no-cache-dir -e /app/packages/inference
|
||||
|
||||
# Copy frontend (if needed)
|
||||
COPY frontend /app/frontend
|
||||
|
||||
WORKDIR /app/packages/inference
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]
|
||||
0
packages/inference/inference/__init__.py
Normal file
0
packages/inference/inference/__init__.py
Normal file
0
packages/inference/inference/azure/__init__.py
Normal file
0
packages/inference/inference/azure/__init__.py
Normal file
105
packages/inference/inference/azure/aci_trigger.py
Normal file
105
packages/inference/inference/azure/aci_trigger.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Trigger training jobs on Azure Container Instances."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Azure SDK is optional; only needed if using ACI trigger
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
|
||||
from azure.mgmt.containerinstance.models import (
|
||||
Container,
|
||||
ContainerGroup,
|
||||
EnvironmentVariable,
|
||||
GpuResource,
|
||||
ResourceRequests,
|
||||
ResourceRequirements,
|
||||
)
|
||||
|
||||
_AZURE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
_AZURE_SDK_AVAILABLE = False
|
||||
|
||||
|
||||
def start_training_container(task_id: str) -> str | None:
|
||||
"""
|
||||
Start an Azure Container Instance for a training task.
|
||||
|
||||
Returns the container group name if successful, None otherwise.
|
||||
Requires environment variables:
|
||||
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
|
||||
"""
|
||||
if not _AZURE_SDK_AVAILABLE:
|
||||
logger.warning(
|
||||
"Azure SDK not installed. Install azure-mgmt-containerinstance "
|
||||
"and azure-identity to use ACI trigger."
|
||||
)
|
||||
return None
|
||||
|
||||
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
|
||||
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
|
||||
image = os.environ.get(
|
||||
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
|
||||
)
|
||||
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
|
||||
location = os.environ.get("AZURE_LOCATION", "eastus")
|
||||
|
||||
if not subscription_id:
|
||||
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
|
||||
return None
|
||||
|
||||
credential = DefaultAzureCredential()
|
||||
client = ContainerInstanceManagementClient(credential, subscription_id)
|
||||
|
||||
container_name = f"training-{task_id[:8]}"
|
||||
|
||||
env_vars = [
|
||||
EnvironmentVariable(name="TASK_ID", value=task_id),
|
||||
]
|
||||
|
||||
# Pass DB connection securely
|
||||
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
|
||||
val = os.environ.get(var, "")
|
||||
if val:
|
||||
env_vars.append(EnvironmentVariable(name=var, value=val))
|
||||
|
||||
db_password = os.environ.get("DB_PASSWORD", "")
|
||||
if db_password:
|
||||
env_vars.append(
|
||||
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
|
||||
)
|
||||
|
||||
container = Container(
|
||||
name=container_name,
|
||||
image=image,
|
||||
resources=ResourceRequirements(
|
||||
requests=ResourceRequests(
|
||||
cpu=4,
|
||||
memory_in_gb=16,
|
||||
gpu=GpuResource(count=1, sku=gpu_sku),
|
||||
)
|
||||
),
|
||||
environment_variables=env_vars,
|
||||
command=[
|
||||
"python",
|
||||
"run_training.py",
|
||||
"--task-id",
|
||||
task_id,
|
||||
],
|
||||
)
|
||||
|
||||
group = ContainerGroup(
|
||||
location=location,
|
||||
containers=[container],
|
||||
os_type="Linux",
|
||||
restart_policy="Never",
|
||||
)
|
||||
|
||||
logger.info("Creating ACI container group: %s", container_name)
|
||||
client.container_groups.begin_create_or_update(
|
||||
resource_group, container_name, group
|
||||
)
|
||||
|
||||
return container_name
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user