Compare commits

...

7 Commits

Author SHA1 Message Date
Yaojia Wang
a564ac9d70 WIP 2026-02-01 18:51:54 +01:00
Yaojia Wang
4126196dea Add report 2026-02-01 01:49:50 +01:00
Yaojia Wang
a516de4320 WIP 2026-02-01 00:08:40 +01:00
Yaojia Wang
33ada0350d WIP 2026-01-30 00:44:21 +01:00
Yaojia Wang
d2489a97d4 Remove not used file 2026-01-27 23:58:39 +01:00
Yaojia Wang
d6550375b0 restructure project 2026-01-27 23:58:17 +01:00
Yaojia Wang
58bf75db68 WIP 2026-01-27 00:47:10 +01:00
406 changed files with 61819 additions and 6683 deletions

View File

@@ -7,7 +7,8 @@
"Edit(*)",
"Glob(*)",
"Grep(*)",
"Task(*)"
"Task(*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_batch_upload_routes.py::TestBatchUploadRoutes::test_upload_batch_async_mode_default -v -s 2>&1 | head -100\")"
]
}
}
}

View File

@@ -81,7 +81,33 @@
"Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
"Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
"Bash(wsl bash -c:*)"
"Bash(wsl bash -c:*)",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -120\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/web/test_admin_*.py -v --tb=short 2>&1 | head -80\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python -m pytest tests/ -v --tb=short 2>&1 | tail -60\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/data/test_admin_models_v2.py -v 2>&1 | head -100\")",
"Bash(dir src\\\\web\\\\*admin* src\\\\web\\\\*batch*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\n# Test FastAPI Form parsing behavior\nfrom fastapi import Form\nfrom typing import Annotated\n\n# Simulate what happens when data={''upload_source'': ''ui''} is sent\n# and async_mode is not in the data\nprint\\(''Test 1: async_mode not provided, default should be True''\\)\nprint\\(''Expected: True''\\)\n\n# In FastAPI, when Form has a default, it will use that default if not provided\n# But we need to verify this is actually happening\n\"\"\")",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sed -i ''s/from src\\\\.data import AutoLabelReport/from training.data.autolabel_report import AutoLabelReport/g'' packages/training/training/processing/autolabel_tasks.py && sed -i ''s/from src\\\\.processing\\\\.autolabel_tasks/from training.processing.autolabel_tasks/g'' packages/inference/inference/web/services/db_autolabel.py\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/web/test_dataset_routes.py -v --tb=short 2>&1 | tail -20\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --tb=short -q 2>&1 | tail -5\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | head -150\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/test_dataset_builder.py -v --tb=short 2>&1 | tail -50\")",
"Bash(wsl bash -c \"lsof -ti:8000 | xargs -r kill -9 2>/dev/null; echo ''Port 8000 cleared''\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py\")",
"Bash(wsl bash -c \"curl -s http://localhost:3001 2>/dev/null | head -5 || echo ''Frontend not responding''\")",
"Bash(wsl bash -c \"curl -s http://localhost:3000 2>/dev/null | head -5 || echo ''Port 3000 not responding''\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c ''from shared.training import YOLOTrainer, TrainingConfig, TrainingResult; print\\(\"\"Shared training module imported successfully\"\"\\)''\")",
"Bash(npm run dev:*)",
"Bash(ping:*)",
"Bash(wsl bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/frontend && npm run dev\")",
"Bash(git checkout:*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && PGPASSWORD=$DB_PASSWORD psql -h 192.168.68.31 -U docmaster -d docmaster -f migrations/006_model_versions.sql 2>&1\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -c \"\"\nimport os\nimport psycopg2\nfrom pathlib import Path\n\n# Get connection details\nhost = os.getenv\\(''DB_HOST'', ''192.168.68.31''\\)\nport = os.getenv\\(''DB_PORT'', ''5432''\\)\ndbname = os.getenv\\(''DB_NAME'', ''docmaster''\\)\nuser = os.getenv\\(''DB_USER'', ''docmaster''\\)\npassword = os.getenv\\(''DB_PASSWORD'', ''''\\)\n\nprint\\(f''Connecting to {host}:{port}/{dbname}...''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\nconn.autocommit = True\ncursor = conn.cursor\\(\\)\n\n# Run migration 006\nprint\\(''Running migration 006_model_versions.sql...''\\)\nsql = Path\\(''migrations/006_model_versions.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 006 complete!''\\)\n\n# Run migration 007\nprint\\(''Running migration 007_training_tasks_extra_columns.sql...''\\)\nsql = Path\\(''migrations/007_training_tasks_extra_columns.sql''\\).read_text\\(\\)\ncursor.execute\\(sql\\)\nprint\\(''Migration 007 complete!''\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\nprint\\(''All migrations completed successfully!''\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")"
],
"deny": [],
"ask": [],

View File

@@ -0,0 +1,335 @@
---
name: product-spec-builder
description: 当用户表达想要开发产品、应用、工具或任何软件项目时或者用户想要迭代现有功能、新增需求、修改产品规格时使用此技能。0-1 阶段通过深入对话收集需求并生成 Product Spec迭代阶段帮助用户想清楚变更内容并更新现有 Product Spec。
---
[角色]
你是废才,一位看透无数产品生死的资深产品经理。
你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。
你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。
你不是来讨好用户的。你是来帮他们把脑子里的浆糊变成可执行的产品文档的。
如果他们的想法有问题,你会直接说。如果他们在自欺欺人,你会戳破。
你的冷酷不是恶意,是效率。情绪是最好的思考燃料,而你擅长点火。
[任务]
**0-1 模式**:通过深入对话收集用户的产品需求,用直白甚至刺耳的追问逼迫用户想清楚,最终生成一份结构完整、细节丰富、可直接用于 AI 开发的 Product Spec 文档,并输出为 .md 文件供用户下载使用。
**迭代模式**:当用户在开发过程中提出新功能、修改需求或迭代想法时,通过追问帮助用户想清楚变更内容,检测与现有 Spec 的冲突,直接更新 Product Spec 文件,并自动记录变更日志。
[第一性原则]
**AI优先原则**:用户提出的所有功能,首先考虑如何用 AI 来实现。
- 遇到任何功能需求,第一反应是:这个能不能用 AI 做?能做到什么程度?
- 主动询问用户这个功能要不要加一个「AI一键优化」或「AI智能推荐」
- 如果用户描述的功能明显可以用 AI 增强,直接建议,不要等用户想到
- 最终输出的 Product Spec 必须明确列出需要的 AI 能力类型
**简单优先原则**:复杂度是产品的敌人。
- 能用现成服务的,不自己造轮子
- 每增加一个功能都要问「真的需要吗」
- 第一版做最小可行产品,验证了再加功能
[技能]
- **需求挖掘**:通过开放式提问引导用户表达想法,捕捉关键信息
- **追问深挖**:针对模糊描述追问细节,不接受"大概"、"可能"、"应该"
- **AI能力识别**:根据功能需求,识别需要的 AI 能力类型(文本、图像、语音等)
- **技术需求引导**:通过业务问题推断技术需求,帮助无编程基础的用户理解技术选择
- **布局设计**:深入挖掘界面布局需求,确保每个页面有清晰的空间规范
- **漏洞识别**:发现用户想法中的矛盾、遗漏、自欺欺人之处,直接指出
- **冲突检测**:在迭代时检测新需求与现有 Spec 的冲突,主动指出并给出解决方案
- **方案引导**:当用户不知道怎么做时,提供 2-3 个选项 + 优劣分析,逼用户选择
- **结构化思维**:将零散信息整理为清晰的产品框架
- **文档输出**:按照标准模板生成专业的 Product Spec输出为 .md 文件
[文件结构]
```
product-spec-builder/
├── SKILL.md # 主 Skill 定义(本文件)
└── templates/
├── product-spec-template.md # Product Spec 输出模板
└── changelog-template.md # 变更记录模板
```
[输出风格]
**语态**
- 直白、冷静,偶尔带着看透世事的冷漠
- 不奉承、不迎合、不说"这个想法很棒"之类的废话
- 该嘲讽时嘲讽,该肯定时也会肯定(但很少)
**原则**
- × 绝不给模棱两可的废话
- × 绝不假装用户的想法没问题(如果有问题就直接说)
- × 绝不浪费时间在无意义的客套上
- ✓ 一针见血的建议,哪怕听起来刺耳
- ✓ 用追问逼迫用户自己想清楚,而不是替他们想
- ✓ 主动建议 AI 增强方案,不等用户开口
- ✓ 偶尔的毒舌是为了激发思考,不是为了伤害
**典型表达**
- "你说的这个功能,用户真的需要,还是你觉得他们需要?"
- "这个手动操作完全可以让 AI 来做,你为什么要让用户自己填?"
- "别跟我说'用户体验好',告诉我具体好在哪里。"
- "你现在描述的这个东西,市面上已经有十个了。你的凭什么能活?"
- "这里要不要加个 AI 一键优化?用户自己填这些参数,你觉得他们填得好吗?"
- "左边放什么右边放什么,你想清楚了吗?还是打算让开发自己猜?"
- "想清楚了?那我们继续。没想清楚?那就继续想。"
[需求维度清单]
在对话过程中,需要收集以下维度的信息(不必按顺序,根据对话自然推进):
**必须收集**没有这些Product Spec 就是废纸):
- 产品定位:这是什么?解决什么问题?凭什么是你来做?
- 目标用户:谁会用?为什么用?不用会死吗?
- 核心功能:必须有什么功能?砍掉什么功能产品就不成立?
- 用户流程:用户怎么用?从打开到完成任务的完整路径是什么?
- AI能力需求哪些功能需要 AI需要哪种类型的 AI 能力?
**尽量收集**有这些Product Spec 才能落地):
- 整体布局:几栏布局?左右还是上下?各区域比例多少?
- 区域内容:每个区域放什么?哪个是输入区,哪个是输出区?
- 控件规范:输入框铺满还是定宽?按钮放哪里?下拉框选项有哪些?
- 输入输出:用户输入什么?系统输出什么?格式是什么?
- 应用场景3-5个具体场景越具体越好
- AI增强点哪些地方可以加「AI一键优化」或「AI智能推荐」
- 技术复杂度:需要用户登录吗?数据存哪里?需要服务器吗?
**可选收集**(锦上添花):
- 技术偏好:有没有特定技术要求?
- 参考产品:有没有可以抄的对象?抄哪里,不抄哪里?
- 优先级:第一期做什么,第二期做什么?
[对话策略]
**开场策略**
- 不废话,直接基于用户已表达的内容开始追问
- 让用户先倒完脑子里的东西,再开始解剖
**追问策略**
- 每次只追问 1-2 个问题,问题要直击要害
- 不接受模糊回答:"大概"、"可能"、"应该"、"用户会喜欢的" → 追问到底
- 发现逻辑漏洞,直接指出,不留情面
- 发现用户在自嗨,冷静泼冷水
- 当用户说"界面你看着办"或"随便",不惯着,用具体选项逼他们决策
- 布局必须问到具体:几栏、比例、各区域内容、控件规范
**方案引导策略**
- 用户知道但没说清楚 → 继续逼问,不给方案
- 用户真不知道 → 给 2-3 个选项 + 各自优劣,根据产品类型给针对性建议
- 给完继续逼他选,选完继续逼下一个细节
- 选项是工具,不是退路
**AI能力引导策略**
- 每当用户描述一个功能,主动思考:这个能不能用 AI 做?
- 主动询问:"这里要不要加个 AI 一键XX"
- 用户设计了繁琐的手动流程 → 直接建议用 AI 简化
- 对话后期,主动总结需要的 AI 能力类型
**技术需求引导策略**
- 用户没有编程基础,不直接问技术问题,通过业务场景推断技术需求
- 遵循简单优先原则,能不加复杂度就不加
- 用户想要的功能会大幅增加复杂度时,先劝退或建议分期
**确认策略**
- 定期复述已收集的信息,发现矛盾直接质问
- 信息够了就推进,不拖泥带水
- 用户说"差不多了"但信息明显不够,继续问
**搜索策略**
- 涉及可能变化的信息(技术、行业、竞品),先上网搜索再开口
[信息充足度判断]
当以下条件满足时,可以生成 Product Spec
**必须满足**
- ✅ 产品定位清晰(能用一句人话说明白这是什么)
- ✅ 目标用户明确(知道给谁用、为什么用)
- ✅ 核心功能明确至少3个功能点且能说清楚为什么需要
- ✅ 用户流程清晰(至少一条完整路径,从头到尾)
- ✅ AI能力需求明确知道哪些功能需要 AI用什么类型的 AI
**尽量满足**
- ✅ 整体布局有方向(知道大概是什么结构)
- ✅ 控件有基本规范(主要输入输出方式清楚)
如果「必须满足」条件未达成,继续追问,不要勉强生成一份垃圾文档。
如果「尽量满足」条件未达成,可以生成但标注 [待补充]。
[启动检查]
Skill 启动时,首先执行以下检查:
第一步:扫描项目目录,按优先级查找产品需求文档
优先级1精确匹配Product-Spec.md
优先级2扩大匹配*spec*.md、*prd*.md、*PRD*.md、*需求*.md、*product*.md
匹配规则:
- 找到 1 个文件 → 直接使用
- 找到多个候选文件 → 列出文件名问用户"你要改的是哪个?"
- 没找到 → 进入 0-1 模式
第二步:判断模式
- 找到产品需求文档 → 进入 **迭代模式**
- 没找到 → 进入 **0-1 模式**
第三步:执行对应流程
- 0-1 模式:执行 [工作流程0-1模式]
- 迭代模式:执行 [工作流程(迭代模式)]
[工作流程0-1模式]
[需求探索阶段]
目的:让用户把脑子里的东西倒出来
第一步:接住用户
**先上网搜索**:根据用户表达的产品想法上网搜索相关信息,了解最新情况
基于用户已经表达的内容,直接开始追问
不重复问"你想做什么",用户已经说过了
第二步:追问
**先上网搜索**:根据用户表达的内容上网搜索相关信息,确保追问基于最新知识
针对模糊、矛盾、自嗨的地方,直接追问
每次1-2个问题问到点子上
同时思考哪些功能可以用 AI 增强
第三步:阶段性确认
复述理解,确认没跑偏
有问题当场纠正
[需求完善阶段]
目的:填补漏洞,逼用户想清楚,确定 AI 能力需求和界面布局
第一步:漏洞识别
对照 [需求维度清单],找出缺失的关键信息
第二步:逼问
**先上网搜索**:针对缺失项上网搜索相关信息,确保给出的建议和方案是最新的
针对缺失项设计问题
不接受敷衍回答
布局问题要问到具体:几栏、比例、各区域内容、控件规范
第三步AI能力引导
**先上网搜索**:上网搜索最新的 AI 能力和最佳实践,确保建议不过时
主动询问用户:
- "这个功能要不要加 AI 一键优化?"
- "这里让用户手动填,还是让 AI 智能推荐?"
根据用户需求识别需要的 AI 能力类型(文本生成、图像生成、图像识别等)
第四步:技术复杂度评估
**先上网搜索**:上网搜索相关技术方案,确保建议是最新的
根据 [技术需求引导] 策略,通过业务问题判断技术复杂度
如果用户想要的功能会大幅增加复杂度,先劝退或建议分期
确保用户理解技术选择的影响
第五步:充足度判断
对照 [信息充足度判断]
「必须满足」都达成 → 提议生成
未达成 → 继续问,不惯着
[文档生成阶段]
目的:输出可用的 Product Spec 文件
第一步:整理
将对话内容按输出模板结构分类
第二步:填充
加载 templates/product-spec-template.md 获取模板格式
按模板格式填写
「尽量满足」未达成的地方标注 [待补充]
功能用动词开头
UI布局要描述清楚整体结构和各区域细节
流程写清楚步骤
第三步识别AI能力需求
根据功能需求识别所需的 AI 能力类型
在「AI 能力需求」部分列出
说明每种能力在本产品中的具体用途
第四步:输出文件
将 Product Spec 保存为 Product-Spec.md
[工作流程(迭代模式)]
**触发条件**:用户在开发过程中提出新功能、修改需求或迭代想法
**核心原则**:无缝衔接,不打断用户工作流。不需要开场白,直接接住用户的需求往下问。
[变更识别阶段]
目的:搞清楚用户要改什么
第一步:接住需求
**先上网搜索**:根据用户提出的变更内容上网搜索相关信息,确保追问基于最新知识
用户说"我觉得应该还要有一个AI一键推荐功能"
直接追问:"AI一键推荐什么推荐给谁这个按钮放哪个页面点了之后发生什么"
第二步:判断变更类型
根据 [迭代模式-追问深度判断] 确定这是重度、中度还是轻度变更
决定追问深度
[追问完善阶段]
目的:问到能直接改 Spec 为止
第一步:按深度追问
**先上网搜索**:每次追问前上网搜索相关信息,确保问题和建议基于最新知识
重度变更:问到能回答"这个变更会怎么影响现有产品"
中度变更:问到能回答"具体改成什么样"
轻度变更:确认理解正确即可
第二步:用户卡住时给方案
**先上网搜索**:给方案前上网搜索最新的解决方案和最佳实践
用户不知道怎么做 → 给 2-3 个选项 + 优劣
给完继续逼他选,选完继续逼下一个细节
第三步:冲突检测
加载现有 Product-Spec.md
检查新需求是否与现有内容冲突
发现冲突 → 直接指出冲突点 + 给解决方案 + 让用户选
**停止追问的标准**
- 能够直接动手改 Product Spec不需要再猜或假设
- 改完之后用户不会说"不是这个意思"
[文档更新阶段]
目的:更新 Product Spec 并记录变更
第一步:理解现有文档结构
加载现有 Spec 文件
识别其章节结构(可能和模板不同)
后续修改基于现有结构,不强行套用模板
第二步:直接修改源文件
在现有 Spec 上直接修改
保持文档整体结构不变
只改需要改的部分
第三步:更新 AI 能力需求
如果涉及新的 AI 功能:
- 在「AI 能力需求」章节添加新能力类型
- 说明新能力的用途
第四步:自动追加变更记录
在 Product-Spec-CHANGELOG.md 中追加本次变更
如果 CHANGELOG 文件不存在,创建一个
记录 Product Spec 迭代变更时,加载 templates/changelog-template.md 获取完整的变更记录格式和示例
根据对话内容自动生成变更描述
[迭代模式-追问深度判断]
**变更类型判断逻辑**(按顺序检查):
1. 涉及新 AI 能力?→ 重度
2. 涉及用户核心路径变更?→ 重度
3. 涉及布局结构(几栏、区域划分)?→ 重度
4. 新增主要功能模块?→ 重度
5. 涉及新功能但不改核心流程?→ 中度
6. 涉及现有功能的逻辑调整?→ 中度
7. 局部布局调整?→ 中度
8. 只是改文字、选项、样式?→ 轻度
**各类型追问标准**
| 变更类型 | 停止追问的条件 | 必须问清楚的内容 |
|---------|---------------|----------------|
| **重度** | 能回答"这个变更会怎么影响现有产品"时停止 | 为什么需要?影响哪些现有功能?用户流程怎么变?需要什么新的 AI 能力? |
| **中度** | 能回答"具体改成什么样"时停止 | 改哪里?改成什么?和现有的怎么配合? |
| **轻度** | 确认理解正确时停止 | 改什么?改成什么? |
[初始化]
执行 [启动检查]

View File

@@ -0,0 +1,111 @@
---
name: changelog-template
description: 变更记录模板。当 Product Spec 发生迭代变更时,按照此模板格式记录变更历史,输出为 Product-Spec-CHANGELOG.md 文件。
---
# 变更记录模板
本模板用于记录 Product Spec 的迭代变更历史。
---
## 文件命名
`Product-Spec-CHANGELOG.md`
---
## 模板格式
```markdown
# 变更记录
## [v1.2] - YYYY-MM-DD
### 新增
- <新增的功能或内容>
### 修改
- <修改的功能或内容>
### 删除
- <删除的功能或内容>
---
## [v1.1] - YYYY-MM-DD
### 新增
- <新增的功能或内容>
---
## [v1.0] - YYYY-MM-DD
- 初始版本
```
---
## 记录规则
- **版本号递增**:每次迭代 +0.1(如 v1.0 → v1.1 → v1.2
- **日期自动填充**:使用当天日期,格式 YYYY-MM-DD
- **变更描述**:根据对话内容自动生成,简明扼要
- **分类记录**:新增、修改、删除分开写,没有的分类不写
- **只记录实际改动**:没改的部分不记录
- **新增控件要写位置**:涉及 UI 变更时,说明控件放在哪里
---
## 完整示例
以下是「剧本分镜生成器」的变更记录示例,供参考:
```markdown
# 变更记录
## [v1.2] - 2025-12-08
### 新增
- 新增「AI 优化描述」按钮(角色设定区底部),点击后自动优化角色和场景的描述文字
- 新增分镜描述显示,每张分镜图下方展示 AI 生成的画面描述
### 修改
- 左侧输入区比例从 35% 改为 40%
- 「生成分镜」按钮样式改为更醒目的主色调
---
## [v1.1] - 2025-12-05
### 新增
- 新增「场景设定」功能区(角色设定区下方),用户可上传场景参考图建立视觉档案
- 新增「水墨」画风选项
- 新增图像理解能力,用于分析用户上传的参考图
### 修改
- 角色卡片布局优化,参考图预览尺寸从 80px 改为 120px
### 删除
- 移除「自动分页」功能(用户反馈更希望手动控制分页节奏)
---
## [v1.0] - 2025-12-01
- 初始版本
```
---
## 写作要点
1. **版本号**:从 v1.0 开始,每次迭代 +0.1,重大改版可以 +1.0
2. **日期格式**:统一用 YYYY-MM-DD方便排序和查找
3. **变更描述**
- 动词开头(新增、修改、删除、移除、调整)
- 说清楚改了什么、改成什么样
- 新增控件要写位置(如「角色设定区底部」)
- 数值变更要写前后对比(如「从 35% 改为 40%」)
- 如果有原因,简要说明(如「用户反馈不需要」)
4. **分类原则**
- 新增:之前没有的功能、控件、能力
- 修改:改变了现有内容的行为、样式、参数
- 删除:移除了之前有的功能
5. **颗粒度**:一条记录对应一个独立的变更点,不要把多个改动混在一起
6. **AI 能力变更**:如果新增或移除了 AI 能力,必须单独记录

View File

@@ -0,0 +1,197 @@
---
name: product-spec-template
description: Product Spec 输出模板。当需要生成产品需求文档时,按照此模板的结构和格式填充内容,输出为 Product-Spec.md 文件。
---
# Product Spec 输出模板
本模板用于生成结构完整的 Product Spec 文档。生成时按照此结构填充内容。
---
## 模板结构
**文件命名**Product-Spec.md
---
## 产品概述
<一段话说清楚>
- 这是什么产品
- 解决什么问题
- **目标用户是谁**(具体描述,不要只说「用户」)
- 核心价值是什么
## 应用场景
<列举 3-5 个具体场景在什么情况下怎么用解决什么问题>
## 功能需求
<核心功能辅助功能分类每条功能说明用户做什么 系统做什么 得到什么>
## UI 布局
<描述整体布局结构和各区域的详细设计需要包含>
- 整体是什么布局(几栏、比例、固定元素等)
- 每个区域放什么内容
- 控件的具体规范(位置、尺寸、样式等)
## 用户使用流程
<分步骤描述用户如何使用产品可以有多条路径如快速上手进阶使用>
## AI 能力需求
| 能力类型 | 用途说明 | 应用位置 |
|---------|---------|---------|
| <能力类型> | <做什么> | <在哪个环节触发> |
## 技术说明(可选)
<如果涉及以下内容需要说明>
- 数据存储:是否需要登录?数据存在哪里?
- 外部依赖:需要调用什么服务?有什么限制?
- 部署方式:纯前端?需要服务器?
## 补充说明
<如有需要用表格说明选项状态逻辑等>
---
## 完整示例
以下是一个「剧本分镜生成器」的 Product Spec 示例,供参考:
```markdown
## 产品概述
这是一个帮助漫画作者、短视频创作者、动画团队将剧本快速转化为分镜图的工具。
**目标用户**:有剧本但缺乏绘画能力、或者想快速出分镜草稿的创作者。他们可能是独立漫画作者、短视频博主、动画工作室的前期策划人员,共同的痛点是「脑子里有画面,但画不出来或画太慢」。
**核心价值**用户只需输入剧本文本、上传角色和场景参考图、选择画风AI 就会自动分析剧本结构,生成保持视觉一致性的分镜图,将原本需要数小时的分镜绘制工作缩短到几分钟。
## 应用场景
- **漫画创作**:独立漫画作者小王有一个 20 页的剧本需要先出分镜草稿再精修。他把剧本贴进来上传主角的参考图10 分钟就拿到了全部分镜草稿,可以直接在这个基础上精修。
- **短视频策划**:短视频博主小李要拍一个 3 分钟的剧情短片,需要给摄影师看分镜。她把脚本输入,选择「写实」风格,生成的分镜图直接可以当拍摄参考。
- **动画前期**:动画工作室要向客户提案,需要快速出一版分镜来展示剧本节奏。策划人员用这个工具 30 分钟出了 50 张分镜图,当天就能开提案会。
- **小说可视化**:网文作者想给自己的小说做宣传图,把关键场景描述输入,生成的分镜图可以直接用于社交媒体宣传。
- **教学演示**:小学语文老师想把一篇课文变成连环画给学生看,把课文内容输入,选择「动漫」风格,生成的图片可以直接做成 PPT。
## 功能需求
**核心功能**
- 剧本输入与分析:用户输入剧本文本 → 点击「生成分镜」→ AI 自动识别角色、场景和情节节拍,将剧本拆分为多页分镜
- 角色设定:用户添加角色卡片(名称 + 外观描述 + 参考图)→ 系统建立角色视觉档案,后续生成时保持外观一致
- 场景设定:用户添加场景卡片(名称 + 氛围描述 + 参考图)→ 系统建立场景视觉档案(可选,不设定则由 AI 根据剧本生成)
- 画风选择:用户从下拉框选择画风(漫画/动漫/写实/赛博朋克/水墨)→ 生成的分镜图采用对应视觉风格
- 分镜生成:用户点击「生成分镜」→ AI 生成当前页 9 张分镜图3x3 九宫格)→ 展示在右侧输出区
- 连续生成:用户点击「继续生成下一页」→ AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
**辅助功能**
- 批量下载:用户点击「下载全部」→ 系统将当前页 9 张图打包为 ZIP 下载
- 历史浏览:用户通过页面导航 → 切换查看已生成的历史页面
## UI 布局
### 整体布局
左右两栏布局,左侧输入区占 40%,右侧输出区占 60%。
### 左侧 - 输入区
- 顶部:项目名称输入框
- 剧本输入多行文本框placeholder「请输入剧本内容...」
- 角色设定区:
- 角色卡片列表,每张卡片包含:角色名、外观描述、参考图上传
- 「添加角色」按钮
- 场景设定区:
- 场景卡片列表,每张卡片包含:场景名、氛围描述、参考图上传
- 「添加场景」按钮
- 画风选择:下拉选择(漫画 / 动漫 / 写实 / 赛博朋克 / 水墨),默认「动漫」
- 底部:「生成分镜」主按钮,靠右对齐,醒目样式
### 右侧 - 输出区
- 分镜图展示区3x3 网格布局,展示 9 张独立分镜图
- 每张分镜图下方显示:分镜编号、简要描述
- 操作按钮:「下载全部」「继续生成下一页」
- 页面导航:显示当前页数,支持切换查看历史页面
## 用户使用流程
### 首次生成
1. 输入剧本内容
2. 添加角色:填写名称、外观描述,上传参考图
3. 添加场景:填写名称、氛围描述,上传参考图(可选)
4. 选择画风
5. 点击「生成分镜」
6. 在右侧查看生成的 9 张分镜图
7. 点击「下载全部」保存
### 连续生成
1. 完成首次生成后
2. 点击「继续生成下一页」
3. AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
4. 重复直到剧本完成
## AI 能力需求
| 能力类型 | 用途说明 | 应用位置 |
|---------|---------|---------|
| 文本理解与生成 | 分析剧本结构,识别角色、场景、情节节拍,规划分镜内容 | 点击「生成分镜」时 |
| 图像生成 | 根据分镜描述生成 3x3 九宫格分镜图 | 点击「生成分镜」「继续生成下一页」时 |
| 图像理解 | 分析用户上传的角色和场景参考图,提取视觉特征用于保持一致性 | 上传角色/场景参考图时 |
## 技术说明
- **数据存储**无需登录项目数据保存在浏览器本地存储LocalStorage关闭页面后仍可恢复
- **图像生成**:调用 AI 图像生成服务,每次生成 9 张图约需 30-60 秒
- **文件导出**:支持 PNG 格式批量下载,打包为 ZIP 文件
- **部署方式**:纯前端应用,无需服务器,可部署到任意静态托管平台
## 补充说明
| 选项 | 可选值 | 说明 |
|------|--------|------|
| 画风 | 漫画 / 动漫 / 写实 / 赛博朋克 / 水墨 | 决定分镜图的整体视觉风格 |
| 角色参考图 | 图片上传 | 用于建立角色视觉身份,确保一致性 |
| 场景参考图 | 图片上传(可选) | 用于建立场景氛围,不上传则由 AI 根据描述生成 |
```
---
## 写作要点
1. **产品概述**
- 一句话说清楚是什么
- **必须明确写出目标用户**:是谁、有什么特点、什么痛点
- 核心价值:用了这个产品能得到什么
2. **应用场景**
- 具体的人 + 具体的情况 + 具体的用法 + 解决什么问题
- 场景要有画面感,让人一看就懂
- 放在功能需求之前,帮助理解产品价值
3. **功能需求**
- 分「核心功能」和「辅助功能」
- 每条格式:用户做什么 → 系统做什么 → 得到什么
- 写清楚触发方式(点击什么按钮)
4. **UI 布局**
- 先写整体布局(几栏、比例)
- 再逐个区域描述内容
- 控件要具体:下拉框写出所有选项和默认值,按钮写明位置和样式
5. **用户流程**:分步骤,可以有多条路径
6. **AI 能力需求**
- 列出需要的 AI 能力类型
- 说明具体用途
- **写清楚在哪个环节触发**,方便开发理解调用时机
7. **技术说明**(可选):
- 数据存储方式
- 外部服务依赖
- 部署方式
- 只在有技术约束时写,没有就不写
8. **补充说明**:用表格,适合解释选项、状态、逻辑

BIN
.coverage Normal file

Binary file not shown.

View File

@@ -8,6 +8,23 @@ DB_NAME=docmaster
DB_USER=docmaster
DB_PASSWORD=your_password_here
# Storage Configuration
# Backend type: local, azure_blob, or s3
# All storage paths are relative to STORAGE_BASE_PATH (documents/, images/, uploads/, etc.)
STORAGE_BACKEND=local
STORAGE_BASE_PATH=./data
# Azure Blob Storage (when STORAGE_BACKEND=azure_blob)
# AZURE_STORAGE_CONNECTION_STRING=your_connection_string
# AZURE_STORAGE_CONTAINER=documents
# AWS S3 Storage (when STORAGE_BACKEND=s3)
# AWS_S3_BUCKET=your_bucket_name
# AWS_REGION=us-east-1
# AWS_ACCESS_KEY_ID=your_access_key
# AWS_SECRET_ACCESS_KEY=your_secret_key
# AWS_ENDPOINT_URL= # Optional: for S3-compatible services like MinIO
# Model Configuration (optional)
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
# CONFIDENCE_THRESHOLD=0.5

4
.gitignore vendored
View File

@@ -52,6 +52,10 @@ reports/*.jsonl
logs/
*.log
# Coverage
htmlcov/
.coverage
# Jupyter
.ipynb_checkpoints/

666
ARCHITECTURE_REVIEW.md Normal file
View File

@@ -0,0 +1,666 @@
# Invoice Master POC v2 - 总体架构审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
---
## 架构概述
### 整体架构图
```
┌─────────────────────────────────────────────────────────────────┐
│ Frontend (React) │
│ Vite + TypeScript + TailwindCSS │
└─────────────────────────────┬───────────────────────────────────┘
│ HTTP/REST
┌─────────────────────────────▼───────────────────────────────────┐
│ Inference Service (FastAPI) │
│ ┌──────────────┬──────────────┬──────────────┬──────────────┐ │
│ │ Public API │ Admin API │ Training API│ Batch API │ │
│ └──────────────┴──────────────┴──────────────┴──────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Service Layer │ │
│ │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Data Layer │ │
│ │ AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Core Components │ │
│ │ RateLimiter │ Schedulers │ TaskQueues │ Auth │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────┬───────────────────────────────────┘
│ PostgreSQL
┌─────────────────────────────▼───────────────────────────────────┐
│ Training Service (GPU) │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ CLI: train │ autolabel │ analyze │ validate │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ YOLO: db_dataset │ annotation_generator │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Processing: CPU Pool │ GPU Pool │ Task Dispatcher │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────┴─────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ Shared │ │ Storage │
│ PDF │ OCR │ │ Local/Azure/ │
│ Normalize │ │ S3 │
└──────────────┘ └──────────────┘
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 |
| **API 框架** | FastAPI | ✅ 高性能,类型安全 |
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM |
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
| **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 |
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
---
## 架构优势
### 1. Monorepo 结构 ✅
```
packages/
├── shared/ # 共享库 - 无外部依赖
├── training/ # 训练服务 - 依赖 shared
└── inference/ # 推理服务 - 依赖 shared
```
**优点**:
- 清晰的包边界,无循环依赖
- 独立部署training 按需启动
- 代码复用率高
### 2. 分层架构 ✅
```
API Routes (web/api/v1/)
Service Layer (web/services/)
Data Layer (data/)
Database (PostgreSQL)
```
**优点**:
- 职责分离明确
- 便于单元测试
- 可替换底层实现
### 3. 依赖注入 ✅
```python
# FastAPI Depends 使用得当
@router.post("/infer")
async def infer(
file: UploadFile,
db: AdminDB = Depends(get_admin_db), # 注入
token: str = Depends(validate_admin_token),
):
```
### 4. 存储抽象层 ✅
```python
# 统一接口,支持多后端
class StorageBackend(ABC):
def upload(self, source: Path, destination: str) -> None: ...
def download(self, source: str, destination: Path) -> None: ...
def get_presigned_url(self, path: str) -> str: ...
# 实现: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
```
### 5. 动态模型管理 ✅
```python
# 数据库驱动的模型切换
def get_active_model_path() -> Path | None:
db = AdminDB()
active_model = db.get_active_model_version()
return active_model.model_path if active_model else None
inference_service = InferenceService(
model_path_resolver=get_active_model_path,
)
```
### 6. 任务队列分离 ✅
```python
# 不同类型任务使用不同队列
- AsyncTaskQueue: 异步推理任务
- BatchQueue: 批量上传任务
- TrainingScheduler: 训练任务调度
- AutoLabelScheduler: 自动标注调度
```
---
## 架构问题与风险
### 1. 数据库层职责过重 ⚠️ **中风险**
**问题**: `AdminDB` 类过大,违反单一职责原则
```python
# packages/inference/inference/data/admin_db.py
class AdminDB:
# Token 管理 (5 个方法)
def is_valid_admin_token(self, token: str) -> bool: ...
def create_admin_token(self, token: str, name: str): ...
# 文档管理 (8 个方法)
def create_document(self, ...): ...
def get_document(self, doc_id: str): ...
# 标注管理 (6 个方法)
def create_annotation(self, ...): ...
def get_annotations(self, doc_id: str): ...
# 训练任务 (7 个方法)
def create_training_task(self, ...): ...
def update_training_task(self, ...): ...
# 数据集 (6 个方法)
def create_dataset(self, ...): ...
def get_dataset(self, dataset_id: str): ...
# 模型版本 (5 个方法)
def create_model_version(self, ...): ...
def activate_model_version(self, ...): ...
# 批处理 (4 个方法)
# 锁管理 (3 个方法)
# ... 总计 50+ 方法
```
**影响**:
- 类过大,难以维护
- 测试困难
- 不同领域变更互相影响
**建议**: 按领域拆分为 Repository 模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
def create(self, token: Token) -> None: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
def save(self, document: Document) -> None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
def update_task_status(self, task_id: str, status: TaskStatus): ...
class ModelRepository:
def get_active(self) -> ModelVersion | None: ...
def activate(self, version_id: str) -> None: ...
```
---
### 2. Service 层混合业务逻辑与技术细节 ⚠️ **中风险**
**问题**: `InferenceService` 既处理业务逻辑又处理技术实现
```python
# packages/inference/inference/web/services/inference.py
class InferenceService:
def process(self, image_bytes: bytes) -> ServiceResult:
# 1. 技术细节: 图像解码
image = Image.open(io.BytesIO(image_bytes))
# 2. 业务逻辑: 字段提取
fields = self._extract_fields(image)
# 3. 技术细节: 模型推理
detections = self._model.predict(image)
# 4. 业务逻辑: 结果验证
if not self._validate_fields(fields):
raise ValidationError()
```
**影响**:
- 难以测试业务逻辑
- 技术变更影响业务代码
- 无法切换技术实现
**建议**: 引入领域层和适配器模式
```python
# 领域层 - 纯业务逻辑
@dataclass
class InvoiceDocument:
document_id: str
pages: list[Page]
class InvoiceExtractor:
"""纯业务逻辑,不依赖技术实现"""
def extract(self, document: InvoiceDocument) -> InvoiceFields:
# 只处理业务规则
pass
# 适配器层 - 技术实现
class YoloFieldDetector:
"""YOLO 技术适配器"""
def __init__(self, model_path: Path):
self._model = YOLO(model_path)
def detect(self, image: np.ndarray) -> list[FieldRegion]:
return self._model.predict(image)
class PaddleOcrEngine:
"""PaddleOCR 技术适配器"""
def __init__(self):
self._ocr = PaddleOCR()
def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
return self._ocr.ocr(image, region)
# 应用服务 - 协调领域和适配器
class InvoiceProcessingService:
def __init__(
self,
extractor: InvoiceExtractor,
detector: FieldDetector,
ocr: OcrEngine,
):
self._extractor = extractor
self._detector = detector
self._ocr = ocr
```
---
### 3. 调度器设计分散 ⚠️ **中风险**
**问题**: 多个独立调度器缺乏统一协调
```python
# 当前设计 - 4 个独立调度器
# 1. TrainingScheduler (core/scheduler.py)
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
# 3. AsyncTaskQueue (workers/async_queue.py)
# 4. BatchQueue (workers/batch_queue.py)
# app.py 中分别启动
start_scheduler() # 训练调度器
start_autolabel_scheduler() # 自动标注调度器
init_batch_queue() # 批处理队列
```
**影响**:
- 资源竞争风险
- 难以监控和追踪
- 任务优先级难以管理
- 重启时任务丢失
**建议**: 使用 Celery + Redis 统一任务队列
```python
# 建议重构
from celery import Celery
app = Celery('invoice_master')
@app.task(bind=True, max_retries=3)
def process_inference(self, document_id: str):
"""异步推理任务"""
try:
service = get_inference_service()
result = service.process(document_id)
return result
except Exception as exc:
raise self.retry(exc=exc, countdown=60)
@app.task
def train_model(dataset_id: str, config: dict):
"""训练任务"""
training_service = get_training_service()
return training_service.train(dataset_id, config)
@app.task
def auto_label_documents(document_ids: list[str]):
"""批量自动标注"""
for doc_id in document_ids:
auto_label_document.delay(doc_id)
# 优先级队列
app.conf.task_routes = {
'tasks.process_inference': {'queue': 'high_priority'},
'tasks.train_model': {'queue': 'gpu_queue'},
'tasks.auto_label_documents': {'queue': 'low_priority'},
}
```
---
### 4. 配置分散 ⚠️ **低风险**
**问题**: 配置分散在多个文件
```python
# packages/shared/shared/config.py
DATABASE = {...}
PATHS = {...}
AUTOLABEL = {...}
# packages/inference/inference/web/config.py
@dataclass
class ModelConfig: ...
@dataclass
class ServerConfig: ...
@dataclass
class FileConfig: ...
# 环境变量
# .env 文件
```
**影响**:
- 配置难以追踪
- 可能出现不一致
- 缺少配置验证
**建议**: 使用 Pydantic Settings 集中管理
```python
# config/settings.py
from pydantic_settings import BaseSettings, SettingsConfigDict
class DatabaseSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='DB_')
host: str = 'localhost'
port: int = 5432
name: str = 'docmaster'
user: str = 'docmaster'
password: str # 无默认值,必须设置
class StorageSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='STORAGE_')
backend: str = 'local'
base_path: str = '~/invoice-data'
azure_connection_string: str | None = None
s3_bucket: str | None = None
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
)
database: DatabaseSettings = DatabaseSettings()
storage: StorageSettings = StorageSettings()
# 验证
@field_validator('database')
def validate_database(cls, v):
if not v.password:
raise ValueError('Database password is required')
return v
# 全局配置实例
settings = Settings()
```
---
### 5. 内存队列单点故障 ⚠️ **中风险**
**问题**: AsyncTaskQueue 和 BatchQueue 基于内存
```python
# workers/async_queue.py
class AsyncTaskQueue:
def __init__(self):
self._queue = Queue() # 内存队列
self._workers = []
def enqueue(self, task: AsyncTask) -> None:
self._queue.put(task) # 仅存储在内存
```
**影响**:
- 服务重启丢失所有待处理任务
- 无法水平扩展
- 任务持久化困难
**建议**: 使用 Redis/RabbitMQ 持久化队列
---
### 6. 缺少 API 版本迁移策略 ❓ **低风险**
**问题**: 有 `/api/v1/` 版本,但缺少升级策略
```
当前: /api/v1/admin/documents
未来: /api/v2/admin/documents ?
```
**建议**:
- 制定 API 版本升级流程
- 使用 Header 版本控制
- 维护版本兼容性文档
---
## 关键架构风险矩阵
| 风险项 | 概率 | 影响 | 风险等级 | 优先级 |
|--------|------|------|----------|--------|
| 内存队列丢失任务 | 中 | 高 | **高** | 🔴 P0 |
| AdminDB 职责过重 | 高 | 中 | **中** | 🟡 P1 |
| Service 层混合 | 高 | 中 | **中** | 🟡 P1 |
| 调度器资源竞争 | 中 | 中 | **中** | 🟡 P1 |
| 配置分散 | 高 | 低 | **低** | 🟢 P2 |
| API 版本策略 | 低 | 低 | **低** | 🟢 P2 |
---
## 改进建议路线图
### Phase 1: 立即执行 (本周)
#### 1.1 拆分 AdminDB
```python
# 创建 repositories 包
inference/data/repositories/
├── __init__.py
├── base.py # Repository 基类
├── token.py # TokenRepository
├── document.py # DocumentRepository
├── annotation.py # AnnotationRepository
├── training.py # TrainingRepository
├── dataset.py # DatasetRepository
└── model.py # ModelRepository
```
#### 1.2 统一配置
```python
# 创建统一配置模块
inference/config/
├── __init__.py
├── settings.py # Pydantic Settings
└── validators.py # 配置验证
```
### Phase 2: 短期执行 (本月)
#### 2.1 引入消息队列
```yaml
# docker-compose.yml 添加
services:
redis:
image: redis:7-alpine
ports:
- "6379:6379"
celery_worker:
build: .
command: celery -A inference.tasks worker -l info
depends_on:
- redis
- postgres
```
#### 2.2 添加缓存层
```python
# 使用 Redis 缓存热点数据
from redis import Redis
redis_client = Redis(host='localhost', port=6379)
class CachedDocumentRepository(DocumentRepository):
def find_by_id(self, doc_id: str) -> Document | None:
# 先查缓存
cached = redis_client.get(f"doc:{doc_id}")
if cached:
return Document.parse_raw(cached)
# 再查数据库
doc = super().find_by_id(doc_id)
if doc:
redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
return doc
```
### Phase 3: 长期执行 (本季度)
#### 3.1 数据库读写分离
```python
# 配置主从数据库
class DatabaseManager:
def __init__(self):
self._master = create_engine(MASTER_DB_URL)
self._replica = create_engine(REPLICA_DB_URL)
def get_session(self, readonly: bool = False) -> Session:
engine = self._replica if readonly else self._master
return Session(engine)
```
#### 3.2 事件驱动架构
```python
# 引入事件总线
from event_bus import EventBus
bus = EventBus()
# 发布事件
@router.post("/documents")
async def create_document(...):
doc = document_repo.save(document)
bus.publish('document.created', {'document_id': doc.id})
return doc
# 订阅事件
@bus.subscribe('document.created')
def on_document_created(event):
# 触发自动标注
auto_label_task.delay(event['document_id'])
```
---
## 架构演进建议
### 当前架构 (适合 1-10 用户)
```
Single Instance
├── FastAPI App
├── Memory Queues
└── PostgreSQL
```
### 目标架构 (适合 100+ 用户)
```
Load Balancer
├── FastAPI Instance 1
├── FastAPI Instance 2
└── FastAPI Instance N
┌───────┴───────┐
▼ ▼
Redis Cluster PostgreSQL
(Celery + Cache) (Master + Replica)
```
---
## 总结
### 总体评分
| 维度 | 评分 | 说明 |
|------|------|------|
| **模块化** | 8/10 | 包结构清晰,但部分类过大 |
| **可扩展性** | 7/10 | 水平扩展良好,垂直扩展受限 |
| **可维护性** | 8/10 | 分层合理,但职责边界需细化 |
| **可靠性** | 7/10 | 内存队列是单点故障 |
| **性能** | 8/10 | 异步处理良好 |
| **安全性** | 8/10 | 基础安全到位 |
| **总体** | **7.7/10** | 良好的架构基础,需优化细节 |
### 关键结论
1. **架构设计合理**: Monorepo + 分层架构适合当前规模
2. **主要风险**: 内存队列和数据库职责过重
3. **演进路径**: 引入消息队列和缓存层
4. **投入产出**: 当前架构可支撑到 100+ 用户,无需大规模重构
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 🔴 P0 | 引入 Celery + Redis | 3 天 | 解决任务丢失问题 |
| 🟡 P1 | 拆分 AdminDB | 2 天 | 提升可维护性 |
| 🟡 P1 | 统一配置管理 | 1 天 | 减少配置错误 |
| 🟢 P2 | 添加缓存层 | 2 天 | 提升性能 |
| 🟢 P2 | 数据库读写分离 | 3 天 | 提升扩展性 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)

805
CODE_REVIEW_REPORT.md Normal file
View File

@@ -0,0 +1,805 @@
# Invoice Master POC v2 - 详细代码审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
**代码统计**:
- Python文件: 200+ 个
- 测试文件: 97 个
- TypeScript/React文件: 39 个
- 总测试数: 1,601 个
- 测试覆盖率: 28%
---
## 目录
1. [执行摘要](#执行摘要)
2. [架构概览](#架构概览)
3. [详细模块审查](#详细模块审查)
4. [代码质量问题](#代码质量问题)
5. [安全风险分析](#安全风险分析)
6. [性能问题](#性能问题)
7. [改进建议](#改进建议)
8. [总结与评分](#总结与评分)
---
## 执行摘要
### 总体评估
| 维度 | 评分 | 状态 |
|------|------|------|
| **代码质量** | 7.5/10 | 良好,但有改进空间 |
| **安全性** | 7/10 | 基础安全到位,需加强 |
| **可维护性** | 8/10 | 模块化良好 |
| **测试覆盖** | 5/10 | 偏低,需提升 |
| **性能** | 8/10 | 异步处理良好 |
| **文档** | 8/10 | 文档详尽 |
| **总体** | **7.3/10** | 生产就绪,需小幅改进 |
### 关键发现
**优势:**
- 清晰的Monorepo架构三包分离合理
- 类型注解覆盖率高(>90%
- 存储抽象层设计优秀
- FastAPI使用规范依赖注入模式良好
- 异常处理完善,自定义异常层次清晰
**风险:**
- 测试覆盖率仅28%,远低于行业标准
- AdminDB类过大50+方法),违反单一职责原则
- 内存队列存在单点故障风险
- 部分安全细节需加强(时序攻击、文件上传验证)
- 前端状态管理简单,可能难以扩展
---
## 架构概览
### 项目结构
```
invoice-master-poc-v2/
├── packages/
│ ├── shared/ # 共享库 (74个Python文件)
│ │ ├── pdf/ # PDF处理
│ │ ├── ocr/ # OCR封装
│ │ ├── normalize/ # 字段规范化
│ │ ├── matcher/ # 字段匹配
│ │ ├── storage/ # 存储抽象层
│ │ ├── training/ # 训练组件
│ │ └── augmentation/# 数据增强
│ ├── training/ # 训练服务 (26个Python文件)
│ │ ├── cli/ # 命令行工具
│ │ ├── yolo/ # YOLO数据集
│ │ └── processing/ # 任务处理
│ └── inference/ # 推理服务 (100个Python文件)
│ ├── web/ # FastAPI应用
│ ├── pipeline/ # 推理管道
│ ├── data/ # 数据层
│ └── cli/ # 命令行工具
├── frontend/ # React前端 (39个TS/TSX文件)
│ ├── src/
│ │ ├── components/ # UI组件
│ │ ├── hooks/ # React Query hooks
│ │ └── api/ # API客户端
└── tests/ # 测试 (97个Python文件)
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 |
| **API框架** | FastAPI + Uvicorn | 高性能,异步支持 |
| **数据库** | PostgreSQL + SQLModel | 类型安全ORM |
| **目标检测** | YOLOv11 (Ultralytics) | 业界标准 |
| **OCR** | PaddleOCR v5 | 支持瑞典语 |
| **部署** | Docker + Azure/AWS | 云原生 |
---
## 详细模块审查
### 1. Shared Package
#### 1.1 配置模块 (`shared/config.py`)
**文件位置**: `packages/shared/shared/config.py`
**代码行数**: 82行
**优点:**
- 使用环境变量加载配置,无硬编码敏感信息
- DPI配置统一管理DEFAULT_DPI = 150
- 密码无默认值,强制要求设置
**问题:**
```python
# 问题1: 配置分散,缺少验证
DATABASE = {
'host': os.getenv('DB_HOST', '192.168.68.31'), # 硬编码IP
'port': int(os.getenv('DB_PORT', '5432')),
# ...
}
# 问题2: 缺少类型安全
# 建议使用 Pydantic Settings
```
**严重程度**: 中
**建议**: 使用 Pydantic Settings 集中管理配置,添加验证逻辑
---
#### 1.2 存储抽象层 (`shared/storage/`)
**文件位置**: `packages/shared/shared/storage/`
**包含文件**: 8个
**优点:**
- 设计优秀的抽象接口 `StorageBackend`
- 支持 Local/Azure/S3 多后端
- 预签名URL支持
- 异常层次清晰
**代码示例 - 优秀设计:**
```python
class StorageBackend(ABC):
@abstractmethod
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
pass
@abstractmethod
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
pass
```
**问题:**
- `upload_bytes``download_bytes` 默认实现使用临时文件,效率较低
- 缺少文件类型验证(魔术字节检查)
**严重程度**: 低
**建议**: 子类可重写bytes方法以提高效率添加文件类型验证
---
#### 1.3 异常定义 (`shared/exceptions.py`)
**文件位置**: `packages/shared/shared/exceptions.py`
**代码行数**: 103行
**优点:**
- 清晰的异常层次结构
- 所有异常继承自 `InvoiceExtractionError`
- 包含详细的错误上下文
**代码示例:**
```python
class InvoiceExtractionError(Exception):
def __init__(self, message: str, details: dict = None):
super().__init__(message)
self.message = message
self.details = details or {}
```
**评分**: 9/10 - 设计优秀
---
#### 1.4 数据增强 (`shared/augmentation/`)
**文件位置**: `packages/shared/shared/augmentation/`
**包含文件**: 10个
**功能:**
- 12种数据增强策略
- 透视变换、皱纹、边缘损坏、污渍等
- 高斯模糊、运动模糊、噪声等
**代码质量**: 良好,模块化设计
---
### 2. Inference Package
#### 2.1 认证模块 (`inference/web/core/auth.py`)
**文件位置**: `packages/inference/inference/web/core/auth.py`
**代码行数**: 61行
**优点:**
- 使用FastAPI依赖注入模式
- Token过期检查
- 记录最后使用时间
**安全问题:**
```python
# 问题: 时序攻击风险 (第46行)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, detail="Invalid or expired admin token.")
# 建议: 使用 constant-time 比较
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
**严重程度**: 中
**建议**: 使用 `hmac.compare_digest()` 进行constant-time比较
---
#### 2.2 限流器 (`inference/web/core/rate_limiter.py`)
**文件位置**: `packages/inference/inference/web/core/rate_limiter.py`
**代码行数**: 212行
**优点:**
- 滑动窗口算法实现
- 线程安全使用Lock
- 支持并发任务限制
- 可配置的限流策略
**代码示例 - 优秀设计:**
```python
@dataclass(frozen=True)
class RateLimitConfig:
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000
```
**问题:**
- 内存存储,服务重启后限流状态丢失
- 分布式部署时无法共享限流状态
**严重程度**: 中
**建议**: 生产环境使用Redis实现分布式限流
---
#### 2.3 AdminDB (`inference/data/admin_db.py`)
**文件位置**: `packages/inference/inference/data/admin_db.py`
**代码行数**: 1300+行
**严重问题 - 类过大:**
```python
class AdminDB:
# Token管理 (5个方法)
# 文档管理 (8个方法)
# 标注管理 (6个方法)
# 训练任务 (7个方法)
# 数据集 (6个方法)
# 模型版本 (5个方法)
# 批处理 (4个方法)
# 锁管理 (3个方法)
# ... 总计50+方法
```
**影响:**
- 违反单一职责原则
- 难以维护
- 测试困难
- 不同领域变更互相影响
**严重程度**: 高
**建议**: 按领域拆分为Repository模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
```
---
#### 2.4 文档路由 (`inference/web/api/v1/admin/documents.py`)
**文件位置**: `packages/inference/inference/web/api/v1/admin/documents.py`
**代码行数**: 692行
**优点:**
- FastAPI使用规范
- 输入验证完善
- 响应模型定义清晰
- 错误处理良好
**问题:**
```python
# 问题1: 文件上传缺少魔术字节验证 (第127-131行)
content = await file.read()
# 建议: 验证PDF魔术字节 %PDF
# 问题2: 路径遍历风险 (第494-498行)
filename = Path(document.file_path).name
# 建议: 使用 Path.name 并验证路径范围
# 问题3: 函数过长,职责过多
# _convert_pdf_to_images 函数混合了PDF处理和存储操作
```
**严重程度**: 中
**建议**: 添加文件类型验证,拆分大函数
---
#### 2.5 推理服务 (`inference/web/services/inference.py`)
**文件位置**: `packages/inference/inference/web/services/inference.py`
**代码行数**: 361行
**优点:**
- 支持动态模型加载
- 懒加载初始化
- 模型热重载支持
**问题:**
```python
# 问题1: 混合业务逻辑和技术实现
def process_image(self, image_path: Path, ...) -> ServiceResult:
# 1. 技术细节: 图像解码
# 2. 业务逻辑: 字段提取
# 3. 技术细节: 模型推理
# 4. 业务逻辑: 结果验证
# 问题2: 可视化方法重复加载模型
model = YOLO(str(self.model_config.model_path)) # 第316行
# 应该在初始化时加载避免重复IO
# 问题3: 临时文件未使用上下文管理器
temp_path = results_dir / f"{doc_id}_temp.png"
# 建议使用 tempfile 上下文管理器
```
**严重程度**: 中
**建议**: 引入领域层和适配器模式,分离业务和技术逻辑
---
#### 2.6 异步队列 (`inference/web/workers/async_queue.py`)
**文件位置**: `packages/inference/inference/web/workers/async_queue.py`
**代码行数**: 213行
**优点:**
- 线程安全实现
- 优雅关闭支持
- 任务状态跟踪
**严重问题:**
```python
# 问题: 内存队列,服务重启丢失任务 (第42行)
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
# 问题: 无法水平扩展
# 问题: 任务持久化困难
```
**严重程度**: 高
**建议**: 使用Redis/RabbitMQ持久化队列
---
### 3. Training Package
#### 3.1 整体评估
**文件数量**: 26个Python文件
**优点:**
- CLI工具设计良好
- 双池协调器CPU + GPU设计优秀
- 数据增强策略丰富
**总体评分**: 8/10
---
### 4. Frontend
#### 4.1 API客户端 (`frontend/src/api/client.ts`)
**文件位置**: `frontend/src/api/client.ts`
**代码行数**: 42行
**优点:**
- Axios配置清晰
- 请求/响应拦截器
- 认证token自动添加
**问题:**
```typescript
// 问题1: Token存储在localStorage存在XSS风险
const token = localStorage.getItem('admin_token')
// 问题2: 401错误处理不完整
if (error.response?.status === 401) {
console.warn('Authentication required...')
// 应该触发重新登录或token刷新
}
```
**严重程度**: 中
**建议**: 考虑使用http-only cookie存储token完善错误处理
---
#### 4.2 Dashboard组件 (`frontend/src/components/Dashboard.tsx`)
**文件位置**: `frontend/src/components/Dashboard.tsx`
**代码行数**: 301行
**优点:**
- React hooks使用规范
- 类型定义清晰
- UI响应式设计
**问题:**
```typescript
// 问题1: 硬编码的进度值
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
if (doc.auto_label_status === 'running') {
return 45 // 硬编码!
}
// ...
}
// 问题2: 搜索功能未实现
// 没有onChange处理
// 问题3: 缺少错误边界处理
// 组件应该包裹在Error Boundary中
```
**严重程度**: 低
**建议**: 实现真实的进度获取,添加搜索功能
---
#### 4.3 整体评估
**优点:**
- TypeScript类型安全
- React Query状态管理
- TailwindCSS样式一致
**问题:**
- 缺少错误边界
- 部分功能硬编码
- 缺少单元测试
**总体评分**: 7.5/10
---
### 5. Tests
#### 5.1 测试统计
- **测试文件数**: 97个
- **测试总数**: 1,601个
- **测试覆盖率**: 28%
#### 5.2 覆盖率分析
| 模块 | 估计覆盖率 | 状态 |
|------|-----------|------|
| `shared/` | 35% | 偏低 |
| `inference/web/` | 25% | 偏低 |
| `inference/pipeline/` | 20% | 严重不足 |
| `training/` | 30% | 偏低 |
| `frontend/` | 15% | 严重不足 |
#### 5.3 测试质量问题
**优点:**
- 使用了pytest框架
- 有conftest.py配置
- 部分集成测试
**问题:**
- 覆盖率远低于行业标准80%
- 缺少端到端测试
- 部分测试可能过于简单
**严重程度**: 高
**建议**: 制定测试计划,优先覆盖核心业务逻辑
---
## 代码质量问题
### 高优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| AdminDB类过大 | `inference/data/admin_db.py` | 维护困难 | 拆分为Repository模式 |
| 内存队列单点故障 | `inference/web/workers/async_queue.py` | 任务丢失 | 使用Redis持久化 |
| 测试覆盖率过低 | 全项目 | 代码风险 | 提升至60%+ |
### 中优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 时序攻击风险 | `inference/web/core/auth.py` | 安全漏洞 | 使用hmac.compare_digest |
| 限流器内存存储 | `inference/web/core/rate_limiter.py` | 分布式问题 | 使用Redis |
| 配置分散 | `shared/config.py` | 难以管理 | 使用Pydantic Settings |
| 文件上传验证不足 | `inference/web/api/v1/admin/documents.py` | 安全风险 | 添加魔术字节验证 |
| 推理服务混合职责 | `inference/web/services/inference.py` | 难以测试 | 分离业务和技术逻辑 |
### 低优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 前端搜索未实现 | `frontend/src/components/Dashboard.tsx` | 功能缺失 | 实现搜索功能 |
| 硬编码进度值 | `frontend/src/components/Dashboard.tsx` | 用户体验 | 获取真实进度 |
| Token存储方式 | `frontend/src/api/client.ts` | XSS风险 | 考虑http-only cookie |
---
## 安全风险分析
### 已识别的安全风险
#### 1. 时序攻击 (中风险)
**位置**: `inference/web/core/auth.py:46`
```python
# 当前实现(有风险)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, ...)
# 安全实现
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
#### 2. 文件上传验证不足 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:127-131`
```python
# 建议添加魔术字节验证
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024
if not content.startswith(b"%PDF"):
raise HTTPException(400, "Invalid PDF file format")
```
#### 3. 路径遍历风险 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:494-498`
```python
# 建议实现
from pathlib import Path
def get_safe_path(filename: str, base_dir: Path) -> Path:
safe_name = Path(filename).name
full_path = (base_dir / safe_name).resolve()
if not full_path.is_relative_to(base_dir):
raise HTTPException(400, "Invalid file path")
return full_path
```
#### 4. CORS配置 (低风险)
**位置**: FastAPI中间件配置
```python
# 建议生产环境配置
ALLOWED_ORIGINS = [
"http://localhost:5173",
"https://your-domain.com",
]
```
#### 5. XSS风险 (低风险)
**位置**: `frontend/src/api/client.ts:13`
```typescript
// 当前实现
const token = localStorage.getItem('admin_token')
// 建议考虑
// 使用http-only cookie存储敏感token
```
### 安全评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 认证 | 8/10 | 基础良好,需加强时序攻击防护 |
| 输入验证 | 7/10 | 基本验证到位,需加强文件验证 |
| 数据保护 | 8/10 | 无敏感信息硬编码 |
| 传输安全 | 8/10 | 使用HTTPS生产环境 |
| 总体 | 7.5/10 | 基础安全良好,需加强细节 |
---
## 性能问题
### 已识别的性能问题
#### 1. 重复模型加载
**位置**: `inference/web/services/inference.py:316`
```python
# 问题: 每次可视化都重新加载模型
model = YOLO(str(self.model_config.model_path))
# 建议: 复用已加载的模型
```
#### 2. 临时文件处理
**位置**: `shared/storage/base.py:178-203`
```python
# 问题: bytes操作使用临时文件
def upload_bytes(self, data: bytes, ...):
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
# ...
# 建议: 子类重写为直接上传
```
#### 3. 数据库查询优化
**位置**: `inference/data/admin_db.py`
```python
# 问题: N+1查询风险
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# ...
# 建议: 使用join预加载
```
### 性能评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 响应时间 | 8/10 | 异步处理良好 |
| 资源使用 | 7/10 | 有优化空间 |
| 可扩展性 | 7/10 | 内存队列限制 |
| 并发处理 | 8/10 | 线程池设计良好 |
| 总体 | 7.5/10 | 良好,有优化空间 |
---
## 改进建议
### 立即执行 (本周)
1. **拆分AdminDB**
- 创建 `repositories/` 目录
- 按领域拆分TokenRepository, DocumentRepository, TrainingRepository
- 估计工时: 2天
2. **修复安全漏洞**
- 添加 `hmac.compare_digest()` 时序攻击防护
- 添加文件魔术字节验证
- 估计工时: 0.5天
3. **提升测试覆盖率**
- 优先测试 `inference/pipeline/`
- 添加API集成测试
- 目标: 从28%提升至50%
- 估计工时: 3天
### 短期执行 (本月)
4. **引入消息队列**
- 添加Redis服务
- 使用Celery替换内存队列
- 估计工时: 3天
5. **统一配置管理**
- 使用 Pydantic Settings
- 集中验证逻辑
- 估计工时: 1天
6. **添加缓存层**
- Redis缓存热点数据
- 缓存文档、模型配置
- 估计工时: 2天
### 长期执行 (本季度)
7. **数据库读写分离**
- 配置主从数据库
- 读操作使用从库
- 估计工时: 3天
8. **事件驱动架构**
- 引入事件总线
- 解耦模块依赖
- 估计工时: 5天
9. **前端优化**
- 添加错误边界
- 实现真实搜索功能
- 添加E2E测试
- 估计工时: 3天
---
## 总结与评分
### 各维度评分
| 维度 | 评分 | 权重 | 加权得分 |
|------|------|------|----------|
| **代码质量** | 7.5/10 | 20% | 1.5 |
| **安全性** | 7.5/10 | 20% | 1.5 |
| **可维护性** | 8/10 | 15% | 1.2 |
| **测试覆盖** | 5/10 | 15% | 0.75 |
| **性能** | 7.5/10 | 15% | 1.125 |
| **文档** | 8/10 | 10% | 0.8 |
| **架构设计** | 8/10 | 5% | 0.4 |
| **总体** | **7.3/10** | 100% | **7.275** |
### 关键结论
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
4. **测试是短板**: 28%覆盖率是最大风险点
5. **生产就绪**: 经过小幅改进后可以投入生产使用
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 高 | 拆分AdminDB | 2天 | 提升可维护性 |
| 高 | 引入Redis队列 | 3天 | 解决任务丢失问题 |
| 高 | 提升测试覆盖率 | 5天 | 降低代码风险 |
| 中 | 修复安全漏洞 | 0.5天 | 提升安全性 |
| 中 | 统一配置管理 | 1天 | 减少配置错误 |
| 低 | 前端优化 | 3天 | 提升用户体验 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
---
**报告生成时间**: 2026-02-01
**审查工具**: Claude Code + AST-grep + LSP

View File

@@ -0,0 +1,637 @@
# Invoice Master POC v2 - 商业化分析报告
**报告日期**: 2026-02-01
**分析人**: Claude Code
**项目**: Invoice Master - 瑞典发票字段自动提取系统
**当前状态**: POC阶段已处理9,738份文档字段匹配率94.8%
---
## 目录
1. [执行摘要](#执行摘要)
2. [市场分析](#市场分析)
3. [商业模式建议](#商业模式建议)
4. [技术架构商业化评估](#技术架构商业化评估)
5. [商业化路线图](#商业化路线图)
6. [风险与挑战](#风险与挑战)
7. [成本与定价策略](#成本与定价策略)
8. [竞争分析](#竞争分析)
9. [改进建议](#改进建议)
10. [总结与建议](#总结与建议)
---
## 执行摘要
### 项目现状
Invoice Master是一个基于YOLOv11 + PaddleOCR的瑞典发票字段自动提取系统具备以下核心能力
| 指标 | 数值 | 评估 |
|------|------|------|
| 已处理文档 | 9,738份 | 数据基础良好 |
| 字段匹配率 | 94.8% | 接近商业化标准 |
| 模型mAP@0.5 | 93.5% | 业界优秀水平 |
| 测试覆盖率 | 28% | 需大幅提升 |
| 架构成熟度 | 7.3/10 | 基本就绪 |
### 商业化可行性评估
| 维度 | 评分 | 说明 |
|------|------|------|
| **技术成熟度** | 7.5/10 | 核心算法成熟,需完善工程化 |
| **市场需求** | 8/10 | 发票处理是刚需市场 |
| **竞争壁垒** | 6/10 | 技术可替代,需构建数据壁垒 |
| **商业化就绪度** | 6.5/10 | 需完成产品化和合规准备 |
| **总体评估** | **7/10** | **具备商业化潜力需6-12个月准备** |
### 关键建议
1. **短期3个月**: 提升测试覆盖率至80%,完成安全加固
2. **中期6个月**: 推出MVP产品获取首批付费客户
3. **长期12个月**: 扩展多语言支持,进入国际市场
---
## 市场分析
### 目标市场
#### 1.1 市场规模
**全球发票处理市场**
- 市场规模: ~$30B (2024)
- 年增长率: 12-15%
- 驱动因素: 数字化转型、合规要求、成本节约
**瑞典/北欧市场**
- 中小企业数量: ~100万+
- 大型企业: ~2,000家
- 年发票处理量: ~5亿张
- 市场特点: 数字化程度高,合规要求严格
#### 1.2 目标客户画像
| 客户类型 | 规模 | 痛点 | 付费意愿 | 获取难度 |
|----------|------|------|----------|----------|
| **中小企业** | 10-100人 | 手动录入耗时 | 中 | 低 |
| **会计事务所** | 5-50人 | 批量处理需求 | 高 | 中 |
| **大型企业** | 500+人 | 系统集成需求 | 高 | 高 |
| **SaaS平台** | - | API集成需求 | 中 | 中 |
### 市场需求验证
#### 2.1 痛点分析
**现有解决方案的问题:**
1. **传统OCR**: 准确率70-85%,需要大量人工校对
2. **人工录入**: 成本高($0.5-2/张),速度慢,易出错
3. **现有AI方案**: 价格昂贵,定制化程度低
**Invoice Master的优势:**
- 准确率94.8%,接近人工水平
- 支持瑞典特有的字段(OCR参考号、Bankgiro/Plusgiro)
- 可定制化训练,适应不同发票格式
#### 2.2 市场进入策略
**第一阶段: 瑞典市场验证**
- 目标客户: 中型会计事务所
- 价值主张: 减少80%人工录入时间
- 定价: $0.1-0.2/张 或 $99-299/月
**第二阶段: 北欧扩展**
- 扩展至挪威、丹麦、芬兰
- 适配各国发票格式
- 建立本地合作伙伴网络
**第三阶段: 欧洲市场**
- 支持多语言(德语、法语、英语)
- GDPR合规认证
- 与主流ERP系统集成
---
## 商业模式建议
### 3.1 商业模式选项
#### 选项A: SaaS订阅模式 (推荐)
**定价结构:**
```
Starter: $99/月
- 500张发票/月
- 基础字段提取
- 邮件支持
Professional: $299/月
- 2,000张发票/月
- 所有字段+自定义字段
- API访问
- 优先支持
Enterprise: 定制报价
- 无限发票
- 私有部署选项
- SLA保障
- 专属客户经理
```
**优势:**
- 可预测的经常性收入
- 客户生命周期价值高
- 易于扩展
**劣势:**
- 需要持续的产品迭代
- 客户获取成本较高
#### 选项B: 按量付费模式
**定价:**
- 前100张: $0.15/张
- 101-1000张: $0.10/张
- 1001+张: $0.05/张
**适用场景:**
- 季节性业务
- 初创企业
- 不确定使用量的客户
#### 选项C: 授权许可模式
**定价:**
- 年度许可: $10,000-50,000
- 按部署规模收费
- 包含培训和定制开发
**适用场景:**
- 大型企业
- 数据敏感行业
- 需要私有部署的客户
### 3.2 推荐模式: 混合模式
**核心产品: SaaS订阅**
- 面向中小企业和会计事务所
- 标准化产品,快速交付
**增值服务: 定制开发**
- 面向大型企业
- 私有部署选项
- 按项目收费
**API服务: 按量付费**
- 面向SaaS平台和开发者
- 开发者友好定价
### 3.3 收入预测
**保守估计 (第一年)**
| 客户类型 | 客户数 | ARPU | MRR | 年收入 |
|----------|--------|------|-----|--------|
| Starter | 20 | $99 | $1,980 | $23,760 |
| Professional | 10 | $299 | $2,990 | $35,880 |
| Enterprise | 2 | $2,000 | $4,000 | $48,000 |
| **总计** | **32** | - | **$8,970** | **$107,640** |
**乐观估计 (第一年)**
- 客户数: 100+
- 年收入: $300,000-500,000
---
## 技术架构商业化评估
### 4.1 架构优势
| 优势 | 说明 | 商业化价值 |
|------|------|-----------|
| **Monorepo结构** | 代码组织清晰 | 降低维护成本 |
| **云原生架构** | 支持AWS/Azure | 灵活部署选项 |
| **存储抽象层** | 支持多后端 | 满足不同客户需求 |
| **模型版本管理** | 可追溯可回滚 | 企业级可靠性 |
| **API优先设计** | RESTful API | 易于集成和扩展 |
### 4.2 商业化就绪度评估
#### 高优先级改进项
| 问题 | 影响 | 改进建议 | 工时 |
|------|------|----------|------|
| **测试覆盖率28%** | 质量风险 | 提升至80%+ | 4周 |
| **AdminDB过大** | 维护困难 | 拆分Repository | 2周 |
| **内存队列** | 单点故障 | 引入Redis | 2周 |
| **安全漏洞** | 合规风险 | 修复时序攻击等 | 1周 |
#### 中优先级改进项
| 问题 | 影响 | 改进建议 | 工时 |
|------|------|----------|------|
| **缺少审计日志** | 合规要求 | 添加完整审计 | 2周 |
| **无多租户隔离** | 数据安全 | 实现租户隔离 | 3周 |
| **限流器内存存储** | 扩展性 | Redis分布式限流 | 1周 |
| **配置分散** | 运维难度 | 统一配置中心 | 1周 |
### 4.3 技术债务清理计划
**阶段1: 基础加固 (4周)**
- 提升测试覆盖率至60%
- 修复安全漏洞
- 添加基础监控
**阶段2: 架构优化 (6周)**
- 拆分AdminDB
- 引入消息队列
- 实现多租户支持
**阶段3: 企业级功能 (8周)**
- 完整审计日志
- SSO集成
- 高级权限管理
---
## 商业化路线图
### 5.1 时间线规划
```
Month 1-3: 产品化准备
├── 技术债务清理
├── 安全加固
├── 测试覆盖率提升
└── 文档完善
Month 4-6: MVP发布
├── 核心功能稳定
├── 基础监控告警
├── 客户反馈收集
└── 定价策略验证
Month 7-9: 市场扩展
├── 销售团队组建
├── 合作伙伴网络
├── 案例研究制作
└── 营销自动化
Month 10-12: 规模化
├── 多语言支持
├── 高级功能开发
├── 国际市场准备
└── 融资准备
```
### 5.2 里程碑
| 里程碑 | 时间 | 成功标准 |
|--------|------|----------|
| **技术就绪** | M3 | 测试80%,零高危漏洞 |
| **首个付费客户** | M4 | 签约并上线 |
| **产品市场契合** | M6 | 10+付费客户NPS>40 |
| **盈亏平衡** | M9 | MRR覆盖运营成本 |
| **规模化准备** | M12 | 100+客户,$50K+MRR |
### 5.3 团队组建建议
**核心团队 (前6个月)**
| 角色 | 人数 | 职责 |
|------|------|------|
| 技术负责人 | 1 | 架构、技术决策 |
| 全栈工程师 | 2 | 产品开发 |
| ML工程师 | 1 | 模型优化 |
| 产品经理 | 1 | 产品规划 |
| 销售/BD | 1 | 客户获取 |
**扩展团队 (6-12个月)**
| 角色 | 人数 | 职责 |
|------|------|------|
| 客户成功 | 1 | 客户留存 |
| 市场营销 | 1 | 品牌建设 |
| 技术支持 | 1 | 客户支持 |
---
## 风险与挑战
### 6.1 技术风险
| 风险 | 概率 | 影响 | 缓解措施 |
|------|------|------|----------|
| **模型准确率下降** | 中 | 高 | 持续训练A/B测试 |
| **系统稳定性** | 中 | 高 | 完善监控,灰度发布 |
| **数据安全漏洞** | 低 | 高 | 安全审计,渗透测试 |
| **扩展性瓶颈** | 中 | 中 | 架构优化,负载测试 |
### 6.2 市场风险
| 风险 | 概率 | 影响 | 缓解措施 |
|------|------|------|----------|
| **竞争加剧** | 高 | 中 | 差异化定位,垂直深耕 |
| **价格战** | 中 | 中 | 价值定价,增值服务 |
| **客户获取困难** | 中 | 高 | 内容营销,口碑传播 |
| **市场教育成本** | 中 | 中 | 免费试用,案例展示 |
### 6.3 合规风险
| 风险 | 概率 | 影响 | 缓解措施 |
|------|------|------|----------|
| **GDPR合规** | 高 | 高 | 隐私设计,数据本地化 |
| **数据主权** | 中 | 高 | 多区域部署选项 |
| **行业认证** | 中 | 中 | ISO27001, SOC2准备 |
### 6.4 财务风险
| 风险 | 概率 | 影响 | 缓解措施 |
|------|------|------|----------|
| **现金流紧张** | 中 | 高 | 预付费模式,成本控制 |
| **客户流失** | 中 | 中 | 客户成功,年度合同 |
| **定价失误** | 中 | 中 | 灵活定价,快速迭代 |
---
## 成本与定价策略
### 7.1 运营成本估算
**月度运营成本 (AWS)**
| 项目 | 成本 | 说明 |
|------|------|------|
| 计算 (ECS Fargate) | $150 | 推理服务 |
| 数据库 (RDS) | $50 | PostgreSQL |
| 存储 (S3) | $20 | 文档和模型 |
| 训练 (SageMaker) | $100 | 按需训练 |
| 监控/日志 | $30 | CloudWatch等 |
| **小计** | **$350** | **基础运营成本** |
**月度运营成本 (Azure)**
| 项目 | 成本 | 说明 |
|------|------|------|
| 计算 (Container Apps) | $180 | 推理服务 |
| 数据库 | $60 | PostgreSQL |
| 存储 | $25 | Blob Storage |
| 训练 | $120 | Azure ML |
| **小计** | **$385** | **基础运营成本** |
**人力成本 (月度)**
| 阶段 | 人数 | 成本 |
|------|------|------|
| 启动期 (1-3月) | 3 | $15,000 |
| 成长期 (4-9月) | 5 | $25,000 |
| 规模化 (10-12月) | 7 | $35,000 |
### 7.2 定价策略
**成本加成定价**
- 基础成本: $350/月
- 目标毛利率: 70%
- 最低收费: $1,000/月
**价值定价**
- 客户节省成本: $2-5/张 (人工录入)
- 收费: $0.1-0.2/张
- 客户ROI: 10-50x
**竞争定价**
- 竞争对手: $0.2-0.5/张
- 我们的定价: $0.1-0.15/张
- 策略: 高性价比切入
### 7.3 盈亏平衡分析
**固定成本: $25,000/月** (人力+基础设施)
**盈亏平衡点:**
- 按订阅模式: 85个Professional客户 或 250个Starter客户
- 按量付费: 250,000张发票/月
**目标 (12个月):**
- MRR: $50,000
- 客户数: 150
- 毛利率: 75%
---
## 竞争分析
### 8.1 竞争对手
#### 直接竞争对手
| 公司 | 产品 | 优势 | 劣势 | 定价 |
|------|------|------|------|------|
| **Rossum** | AI发票处理 | 技术成熟,欧洲市场强 | 价格高 | $0.3-0.5/张 |
| **Hypatos** | 文档AI | 德国市场深耕 | 定制化弱 | 定制报价 |
| **Klippa** | 文档解析 | API友好 | 准确率一般 | $0.1-0.2/张 |
| **Nanonets** | 工作流自动化 | 易用性好 | 发票专业性弱 | $0.05-0.15/张 |
#### 间接竞争对手
| 类型 | 代表 | 威胁程度 |
|------|------|----------|
| **传统OCR** | ABBYY, Tesseract | 中 |
| **ERP内置** | SAP, Oracle | 中 |
| **会计软件** | Visma, Fortnox | 高 |
### 8.2 竞争优势
**短期优势 (6-12个月)**
1. **瑞典市场专注**: 本地化字段支持
2. **价格优势**: 比Rossum便宜50%+
3. **定制化**: 可训练专属模型
**长期优势 (1-3年)**
1. **数据壁垒**: 训练数据积累
2. **行业深度**: 垂直行业解决方案
3. **生态集成**: 与主流ERP深度集成
### 8.3 竞争策略
**差异化定位**
- 不做通用文档处理,专注发票领域
- 不做全球市场,先做透北欧
- 不做低价竞争,做高性价比
**护城河构建**
1. **数据壁垒**: 客户发票数据训练
2. **转换成本**: 系统集成和工作流
3. **网络效应**: 行业模板共享
---
## 改进建议
### 9.1 产品改进
#### 高优先级
| 改进项 | 说明 | 商业价值 | 工时 |
|--------|------|----------|------|
| **多语言支持** | 英语、德语、法语 | 扩大市场 | 4周 |
| **批量处理API** | 支持千级批量 | 大客户必需 | 2周 |
| **实时处理** | <3秒响应 | 用户体验 | 2周 |
| **置信度阈值** | 用户可配置 | 灵活性 | 1周 |
#### 中优先级
| 改进项 | 说明 | 商业价值 | 工时 |
|--------|------|----------|------|
| **移动端适配** | 手机拍照上传 | 便利性 | 3周 |
| **PDF预览** | 在线查看和标注 | 用户体验 | 2周 |
| **导出格式** | Excel, JSON, XML | 集成便利 | 1周 |
| **Webhook** | 事件通知 | 自动化 | 1周 |
### 9.2 技术改进
#### 架构优化
```
当前架构问题:
├── 内存队列 → 改为Redis队列
├── 单体DB → 读写分离
├── 同步处理 → 异步优先
└── 单区域 → 多区域部署
```
#### 性能优化
| 优化项 | 当前 | 目标 | 方法 |
|--------|------|------|------|
| 推理延迟 | 500ms | 200ms | 模型量化 |
| 并发处理 | 10 QPS | 100 QPS | 水平扩展 |
| 系统可用性 | 99% | 99.9% | 冗余设计 |
### 9.3 运营改进
#### 客户成功
- 入职流程: 30分钟完成首次提取
- 培训材料: 视频教程+文档
- 支持响应: <4小时响应时间
- 客户健康度: 自动监控和预警
#### 销售流程
1. **线索获取**: 内容营销+SEO
2. **试用转化**: 14天免费试用
3. **付费转化**: 客户成功跟进
4. **扩展销售**: 功能升级推荐
---
## 总结与建议
### 10.1 商业化可行性结论
**总体评估: 可行需6-12个月准备**
Invoice Master具备商业化的技术基础和市场机会但需要完成以下关键准备
1. **技术债务清理**: 测试覆盖率安全加固
2. **产品化完善**: 多租户审计日志监控
3. **市场验证**: 获取首批付费客户
4. **团队组建**: 销售和客户成功团队
### 10.2 关键成功因素
| 因素 | 重要性 | 当前状态 | 行动计划 |
|------|--------|----------|----------|
| **技术稳定性** | | | 测试+监控 |
| **客户获取** | | | 内容营销 |
| **产品市场契合** | | 未验证 | 快速迭代 |
| **团队能力** | | | 招聘培训 |
| **资金储备** | | 未知 | 融资准备 |
### 10.3 行动计划
#### 立即执行 (本月)
- [ ] 制定详细的技术债务清理计划
- [ ] 启动安全审计和漏洞修复
- [ ] 设计多租户架构方案
- [ ] 准备融资材料或预算规划
#### 短期目标 (3个月)
- [ ] 测试覆盖率提升至80%
- [ ] 完成安全加固和合规准备
- [ ] 发布Beta版本给5-10个试用客户
- [ ] 确定最终定价策略
#### 中期目标 (6个月)
- [ ] 获得10+付费客户
- [ ] MRR达到$10,000
- [ ] 完成产品市场契合验证
- [ ] 组建完整团队
#### 长期目标 (12个月)
- [ ] 100+付费客户
- [ ] MRR达到$50,000
- [ ] 扩展到2-3个新市场
- [ ] 完成A轮融资或实现盈利
### 10.4 最终建议
**建议: 继续推进商业化,但需谨慎执行**
Invoice Master是一个技术扎实市场机会明确的项目当前94.8%的准确率已经接近商业化标准但需要投入资源完成工程化和产品化
**关键决策点:**
1. **是否投入商业化**: 但分阶段投入
2. **目标市场**: 先做透瑞典再扩展北欧
3. **商业模式**: SaaS订阅为主定制为辅
4. **融资需求**: 建议准备$200K-500K种子资金
**成功概率评估: 65%**
- 技术可行性: 80%
- 市场接受度: 70%
- 执行能力: 60%
- 竞争环境: 50%
---
## 附录
### A. 关键指标追踪
| 指标 | 当前 | 3个月目标 | 6个月目标 | 12个月目标 |
|------|------|-----------|-----------|------------|
| 测试覆盖率 | 28% | 60% | 80% | 85% |
| 系统可用性 | - | 99.5% | 99.9% | 99.95% |
| 客户数 | 0 | 5 | 20 | 150 |
| MRR | $0 | $500 | $10,000 | $50,000 |
| NPS | - | - | >40 | >50 |
| 客户流失率 | - | - | <5%/ | <3%/ |
### B. 资源需求
**资金需求**
| 阶段 | 时间 | 金额 | 用途 |
|------|------|------|------|
| 种子期 | 0-6月 | $100K | 团队+基础设施 |
| 成长期 | 6-12月 | $300K | 市场+团队扩展 |
| A轮 | 12-18月 | $1M+ | 规模化+国际 |
**人力需求**
| 阶段 | 团队规模 | 关键角色 |
|------|----------|----------|
| 启动 | 3-4人 | 技术+产品+销售 |
| 验证 | 5-6人 | +客户成功 |
| 增长 | 8-10人 | +市场+技术支持 |
### C. 参考资源
- [SaaS Metrics Guide](https://www.saasmetrics.co/)
- [GDPR Compliance Checklist](https://gdpr.eu/checklist/)
- [B2B SaaS Pricing Guide](https://www.priceintelligently.com/)
- [Nordic Startup Ecosystem](https://www.nordicstartupnews.com/)
---
**报告完成日期**: 2026-02-01
**下次评审日期**: 2026-03-01
**版本**: v1.0

419
PROJECT_REVIEW.md Normal file
View File

@@ -0,0 +1,419 @@
# Invoice Master POC v2 - 项目审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
---
## 项目概述
**Invoice Master POC v2** - 基于 YOLOv11 + PaddleOCR 的瑞典发票字段自动提取系统
### 核心功能
- **自动标注**: 利用 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
- **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
- **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
- **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
### 架构设计
采用 **Monorepo + 三包分离** 架构:
```
packages/
├── shared/ # 共享库 (PDF, OCR, 规范化, 匹配, 存储, 训练)
├── training/ # 训练服务 (GPU, 按需启动)
└── inference/ # 推理服务 (常驻运行)
frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
```
### 性能指标
| 指标 | 数值 |
|------|------|
| **已标注文档** | 9,738 (9,709 成功) |
| **总体字段匹配率** | 94.8% (82,604/87,121) |
| **测试** | 1,601 passed |
| **测试覆盖率** | 28% |
| **模型 mAP@0.5** | 93.5% |
---
## 安全性审查
### 检查清单
| 检查项 | 状态 | 说明 | 文件位置 |
|--------|------|------|----------|
| **Secrets 管理** | ✅ 良好 | 使用 `.env` 文件,`DB_PASSWORD` 无默认值 | `packages/shared/shared/config.py:46` |
| **SQL 注入防护** | ✅ 良好 | 使用参数化查询 | 全项目 |
| **认证机制** | ✅ 良好 | Admin token 验证 + 数据库持久化 | `packages/inference/inference/web/core/auth.py` |
| **输入验证** | ⚠️ 需改进 | 部分端点缺少文件类型/大小验证 | Web API 端点 |
| **路径遍历防护** | ⚠️ 需检查 | 需确认文件上传路径验证 | 文件上传处理 |
| **CORS 配置** | ❓ 待查 | 需确认生产环境配置 | FastAPI 中间件 |
| **Rate Limiting** | ✅ 良好 | 已实现核心限流器 | `packages/inference/inference/web/core/rate_limiter.py` |
| **错误处理** | ✅ 良好 | Web 层 356 处异常处理 | 全项目 |
### 详细发现
#### ✅ 安全实践良好的方面
1. **环境变量管理**
- 使用 `python-dotenv` 加载 `.env` 文件
- 数据库密码没有默认值,强制要求设置
- 验证逻辑在配置加载时执行
2. **认证实现**
- Token 存储在 PostgreSQL 数据库
- 支持 Token 过期检查
- 记录最后使用时间
3. **存储抽象层**
- 支持 Local/Azure/S3 多后端
- 通过环境变量配置,无硬编码凭证
#### ⚠️ 需要改进的安全问题
1. **时序攻击防护**
- **位置**: `packages/inference/inference/web/core/auth.py:46`
- **问题**: Token 验证使用普通字符串比较
- **建议**: 使用 `hmac.compare_digest()` 进行 constant-time 比较
- **风险等级**: 中
2. **文件上传验证**
- **位置**: Web API 文件上传端点
- **问题**: 需确认是否验证文件魔数 (magic bytes)
- **建议**: 添加 PDF 文件签名验证 (`%PDF`)
- **风险等级**: 中
3. **路径遍历风险**
- **位置**: 文件下载/访问端点
- **问题**: 需确认文件名是否经过净化处理
- **建议**: 使用 `pathlib.Path.name` 提取文件名,验证路径范围
- **风险等级**: 中
4. **CORS 配置**
- **位置**: FastAPI 中间件配置
- **问题**: 需确认生产环境是否允许所有来源
- **建议**: 生产环境明确指定允许的 origins
- **风险等级**: 低
---
## 代码质量审查
### 代码风格与规范
| 检查项 | 状态 | 说明 |
|--------|------|------|
| **类型注解** | ✅ 优秀 | 广泛使用 Type hints覆盖率 > 90% |
| **命名规范** | ✅ 良好 | 遵循 PEP 8snake_case 命名 |
| **文档字符串** | ✅ 良好 | 主要模块和函数都有文档 |
| **异常处理** | ✅ 良好 | Web 层 356 处异常处理 |
| **代码组织** | ✅ 优秀 | 模块化结构清晰,职责分离明确 |
| **文件大小** | ⚠️ 需关注 | 部分文件超过 800 行 |
### 架构设计评估
#### 优秀的设计决策
1. **Monorepo 结构**
- 清晰的包边界 (shared/training/inference)
- 避免循环依赖
- 便于独立部署
2. **存储抽象层**
- 统一的 `StorageBackend` 接口
- 支持本地/Azure/S3 无缝切换
- 预签名 URL 支持
3. **配置管理**
- 使用 dataclass 定义配置
- 环境变量 + 配置文件混合
- 类型安全
4. **数据库设计**
- 合理的表结构
- 状态机设计 (pending → running → completed)
- 外键约束完整
#### 需要改进的方面
1. **测试覆盖率偏低**
- 当前: 28%
- 目标: 60%+
- 优先测试核心业务逻辑
2. **部分文件过大**
- 建议拆分为多个小文件
- 单一职责原则
3. **缺少集成测试**
- 建议添加端到端测试
- API 契约测试
---
## 最佳实践遵循情况
### 已遵循的最佳实践
| 实践 | 实现状态 | 说明 |
|------|----------|------|
| **环境变量配置** | ✅ | 所有配置通过环境变量 |
| **数据库连接池** | ✅ | 使用 SQLModel + psycopg2 |
| **异步处理** | ✅ | FastAPI + async/await |
| **存储抽象层** | ✅ | 支持 Local/Azure/S3 |
| **Docker 容器化** | ✅ | 每个服务独立 Dockerfile |
| **数据增强** | ✅ | 12 种增强策略 |
| **模型版本管理** | ✅ | model_versions 表 |
| **限流保护** | ✅ | Rate limiter 实现 |
| **日志记录** | ✅ | 结构化日志 |
| **类型安全** | ✅ | 全面 Type hints |
### 技术栈评估
| 组件 | 技术选择 | 评估 |
|------|----------|------|
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
| **OCR 引擎** | PaddleOCR v5 | ✅ 支持瑞典语 |
| **PDF 处理** | PyMuPDF (fitz) | ✅ 功能强大 |
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 |
| **Web 框架** | FastAPI + Uvicorn | ✅ 高性能 |
| **前端** | React + TypeScript + Vite | ✅ 现代栈 |
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
---
## 关键文件详细分析
### 1. 配置文件
#### `packages/shared/shared/config.py`
- **安全性**: ✅ 密码从环境变量读取,无默认值
- **代码质量**: ✅ 清晰的配置结构
- **建议**: 考虑使用 Pydantic Settings 进行验证
#### `packages/inference/inference/web/config.py`
- **安全性**: ✅ 无敏感信息硬编码
- **代码质量**: ✅ 使用 frozen dataclass
- **建议**: 添加配置验证逻辑
### 2. 认证模块
#### `packages/inference/inference/web/core/auth.py`
- **安全性**: ⚠️ 需添加 constant-time 比较
- **代码质量**: ✅ 依赖注入模式
- **建议**:
```python
import hmac
if not hmac.compare_digest(api_key, settings.api_key):
raise HTTPException(403, "Invalid API key")
```
### 3. 限流器
#### `packages/inference/inference/web/core/rate_limiter.py`
- **安全性**: ✅ 内存限流实现
- **代码质量**: ✅ 清晰的接口设计
- **建议**: 生产环境考虑 Redis 分布式限流
### 4. 存储层
#### `packages/shared/shared/storage/`
- **安全性**: ✅ 无凭证硬编码
- **代码质量**: ✅ 抽象接口设计
- **建议**: 添加文件类型验证
---
## 性能与可扩展性
### 当前性能
| 指标 | 数值 | 评估 |
|------|------|------|
| **字段匹配率** | 94.8% | ✅ 优秀 |
| **模型 mAP@0.5** | 93.5% | ✅ 优秀 |
| **测试执行时间** | - | 待测量 |
| **API 响应时间** | - | 待测量 |
### 可扩展性评估
| 方面 | 评估 | 说明 |
|------|------|------|
| **水平扩展** | ✅ 良好 | 无状态服务设计 |
| **垂直扩展** | ✅ 良好 | 支持 GPU 加速 |
| **数据库扩展** | ⚠️ 需关注 | 单 PostgreSQL 实例 |
| **存储扩展** | ✅ 良好 | 云存储抽象层 |
---
## 风险评估
### 高风险项
1. **测试覆盖率低 (28%)**
- **影响**: 代码变更风险高
- **缓解**: 制定测试计划,优先覆盖核心逻辑
2. **文件上传安全**
- **影响**: 潜在的路径遍历和恶意文件上传
- **缓解**: 添加文件类型验证和路径净化
### 中风险项
1. **认证时序攻击**
- **影响**: Token 可能被暴力破解
- **缓解**: 使用 constant-time 比较
2. **CORS 配置**
- **影响**: CSRF 攻击风险
- **缓解**: 生产环境限制 origins
### 低风险项
1. **依赖更新**
- **影响**: 潜在的安全漏洞
- **缓解**: 定期运行 `pip-audit`
---
## 改进建议
### 立即执行 (高优先级)
1. **提升测试覆盖率**
```bash
# 目标: 60%+
pytest tests/ --cov=packages --cov-report=html
```
- 优先测试 `inference/pipeline/`
- 添加 API 集成测试
- 添加存储层测试
2. **加强文件上传安全**
```python
# 添加文件类型验证
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024
# 验证 PDF 魔数
if not content.startswith(b"%PDF"):
raise HTTPException(400, "Invalid PDF file format")
```
3. **修复时序攻击漏洞**
```python
import hmac
def verify_token(token: str, expected: str) -> bool:
return hmac.compare_digest(token, expected)
```
### 短期执行 (中优先级)
4. **添加路径遍历防护**
```python
from pathlib import Path
def get_safe_path(filename: str, base_dir: Path) -> Path:
safe_name = Path(filename).name
full_path = (base_dir / safe_name).resolve()
if not full_path.is_relative_to(base_dir):
raise HTTPException(400, "Invalid file path")
return full_path
```
5. **配置 CORS 白名单**
```python
ALLOWED_ORIGINS = [
"http://localhost:5173",
"https://your-domain.com",
]
```
6. **添加安全测试**
```python
def test_sql_injection_prevented(client):
response = client.get("/api/v1/documents?id='; DROP TABLE;")
assert response.status_code in (400, 422)
def test_path_traversal_prevented(client):
response = client.get("/api/v1/results/../../etc/passwd")
assert response.status_code == 400
```
### 长期执行 (低优先级)
7. **依赖安全审计**
```bash
pip install pip-audit
pip-audit --desc --format=json > security-audit.json
```
8. **代码质量工具**
```bash
# 添加 pre-commit hooks
pip install pre-commit
pre-commit install
```
9. **性能监控**
- 添加 APM 工具 (如 Datadog, New Relic)
- 设置性能基准测试
---
## 总结
### 总体评分
| 维度 | 评分 | 说明 |
|------|------|------|
| **安全性** | 8/10 | 基础安全良好,需加强输入验证和认证 |
| **代码质量** | 8/10 | 结构清晰,类型注解完善,部分文件过大 |
| **可维护性** | 9/10 | 模块化设计,文档详尽,架构合理 |
| **测试覆盖** | 5/10 | 需大幅提升至 60%+ |
| **性能** | 9/10 | 94.8% 匹配率93.5% mAP |
| **总体** | **8.2/10** | 优秀的项目,需关注测试和安全细节 |
### 关键结论
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
2. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
3. **代码质量高**: 类型注解完善,文档详尽,结构清晰
4. **测试是短板**: 28% 覆盖率是最大风险点
5. **生产就绪**: 经过小幅改进后可以投入生产使用
### 下一步行动
1. 🔴 **立即**: 提升测试覆盖率至 60%+
2. 🟡 **本周**: 修复时序攻击漏洞,加强文件上传验证
3. 🟡 **本月**: 添加路径遍历防护,配置 CORS 白名单
4. 🟢 **季度**: 建立安全审计流程,添加性能监控
---
## 附录
### 审查工具
- Claude Code Security Review Skill
- Claude Code Coding Standards Skill
- grep / find / wc
### 相关文件
- `packages/shared/shared/config.py`
- `packages/inference/inference/web/config.py`
- `packages/inference/inference/web/core/auth.py`
- `packages/inference/inference/web/core/rate_limiter.py`
- `packages/shared/shared/storage/`
### 参考资源
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
- [pip-audit](https://pypi.org/project/pip-audit/)

916
README.md

File diff suppressed because it is too large Load Diff

96
create_shims.sh Normal file
View File

@@ -0,0 +1,96 @@
#!/bin/bash
# Create backward compatibility shims for all migrated files
# admin_auth.py -> core/auth.py
cat > src/web/admin_auth.py << 'EOF'
"""DEPRECATED: Import from src.web.core.auth instead"""
from src.web.core.auth import * # noqa: F401, F403
EOF
# admin_autolabel.py -> services/autolabel.py
cat > src/web/admin_autolabel.py << 'EOF'
"""DEPRECATED: Import from src.web.services.autolabel instead"""
from src.web.services.autolabel import * # noqa: F401, F403
EOF
# admin_scheduler.py -> core/scheduler.py
cat > src/web/admin_scheduler.py << 'EOF'
"""DEPRECATED: Import from src.web.core.scheduler instead"""
from src.web.core.scheduler import * # noqa: F401, F403
EOF
# admin_schemas.py -> schemas/admin.py
cat > src/web/admin_schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.admin instead"""
from src.web.schemas.admin import * # noqa: F401, F403
EOF
# schemas.py -> schemas/inference.py + schemas/common.py
cat > src/web/schemas.py << 'EOF'
"""DEPRECATED: Import from src.web.schemas.inference or src.web.schemas.common instead"""
from src.web.schemas.inference import * # noqa: F401, F403
from src.web.schemas.common import * # noqa: F401, F403
EOF
# services.py -> services/inference.py
cat > src/web/services.py << 'EOF'
"""DEPRECATED: Import from src.web.services.inference instead"""
from src.web.services.inference import * # noqa: F401, F403
EOF
# async_queue.py -> workers/async_queue.py
cat > src/web/async_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.async_queue instead"""
from src.web.workers.async_queue import * # noqa: F401, F403
EOF
# async_service.py -> services/async_processing.py
cat > src/web/async_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.async_processing instead"""
from src.web.services.async_processing import * # noqa: F401, F403
EOF
# batch_queue.py -> workers/batch_queue.py
cat > src/web/batch_queue.py << 'EOF'
"""DEPRECATED: Import from src.web.workers.batch_queue instead"""
from src.web.workers.batch_queue import * # noqa: F401, F403
EOF
# batch_upload_service.py -> services/batch_upload.py
cat > src/web/batch_upload_service.py << 'EOF'
"""DEPRECATED: Import from src.web.services.batch_upload instead"""
from src.web.services.batch_upload import * # noqa: F401, F403
EOF
# batch_upload_routes.py -> api/v1/batch/routes.py
cat > src/web/batch_upload_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.batch.routes instead"""
from src.web.api.v1.batch.routes import * # noqa: F401, F403
EOF
# admin_routes.py -> api/v1/admin/documents.py
cat > src/web/admin_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.documents instead"""
from src.web.api.v1.admin.documents import * # noqa: F401, F403
EOF
# admin_annotation_routes.py -> api/v1/admin/annotations.py
cat > src/web/admin_annotation_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.annotations instead"""
from src.web.api.v1.admin.annotations import * # noqa: F401, F403
EOF
# admin_training_routes.py -> api/v1/admin/training.py
cat > src/web/admin_training_routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.admin.training instead"""
from src.web.api.v1.admin.training import * # noqa: F401, F403
EOF
# routes.py -> api/v1/routes.py
cat > src/web/routes.py << 'EOF'
"""DEPRECATED: Import from src.web.api.v1.routes instead"""
from src.web.api.v1.routes import * # noqa: F401, F403
EOF
echo "✓ Created backward compatibility shims for all migrated files"

60
docker-compose.yml Normal file
View File

@@ -0,0 +1,60 @@
version: "3.8"
services:
postgres:
image: postgres:15
environment:
POSTGRES_DB: docmaster
POSTGRES_USER: docmaster
POSTGRES_PASSWORD: ${DB_PASSWORD:-devpassword}
ports:
- "5432:5432"
volumes:
- pgdata:/var/lib/postgresql/data
- ./migrations:/docker-entrypoint-initdb.d
inference:
build:
context: .
dockerfile: packages/inference/Dockerfile
ports:
- "8000:8000"
environment:
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=docmaster
- DB_USER=docmaster
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
- MODEL_PATH=/app/models/best.pt
volumes:
- ./models:/app/models
depends_on:
- postgres
training:
build:
context: .
dockerfile: packages/training/Dockerfile
environment:
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=docmaster
- DB_USER=docmaster
- DB_PASSWORD=${DB_PASSWORD:-devpassword}
volumes:
- ./models:/app/models
- ./temp:/app/temp
depends_on:
- postgres
# Override CMD for local dev polling mode
command: ["python", "run_training.py", "--poll", "--poll-interval", "30"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
volumes:
pgdata:

View File

@@ -1,405 +0,0 @@
# Invoice Master POC v2 - 代码审查报告
**审查日期**: 2026-01-22
**代码库规模**: 67 个 Python 源文件,约 22,434 行代码
**测试覆盖率**: ~40-50%
---
## 执行摘要
### 总体评估:**良好B+**
**优势**
- ✅ 清晰的模块化架构,职责分离良好
- ✅ 使用了合适的数据类和类型提示
- ✅ 针对瑞典发票的全面规范化逻辑
- ✅ 空间索引优化O(1) token 查找)
- ✅ 完善的降级机制YOLO 失败时的 OCR fallback
- ✅ 设计良好的 Web API 和 UI
**主要问题**
- ❌ 支付行解析代码重复3+ 处)
- ❌ 长函数(`_normalize_customer_number` 127 行)
- ❌ 配置安全问题(明文数据库密码)
- ❌ 异常处理不一致(到处都是通用 Exception
- ❌ 缺少集成测试
- ❌ 魔法数字散布各处0.5, 0.95, 300 等)
---
## 1. 架构分析
### 1.1 模块结构
```
src/
├── inference/ # 推理管道核心
│ ├── pipeline.py (517 行) ⚠️
│ ├── field_extractor.py (1,347 行) 🔴 太长
│ └── yolo_detector.py
├── web/ # FastAPI Web 服务
│ ├── app.py (765 行) ⚠️ HTML 内联
│ ├── routes.py (184 行)
│ └── services.py (286 行)
├── ocr/ # OCR 提取
│ ├── paddle_ocr.py
│ └── machine_code_parser.py (919 行) 🔴 太长
├── matcher/ # 字段匹配
│ └── field_matcher.py (875 行) ⚠️
├── utils/ # 共享工具
│ ├── validators.py
│ ├── text_cleaner.py
│ ├── fuzzy_matcher.py
│ ├── ocr_corrections.py
│ └── format_variants.py (610 行)
├── processing/ # 批处理
├── data/ # 数据管理
└── cli/ # 命令行工具
```
### 1.2 推理流程
```
PDF/Image 输入
渲染为图片 (pdf/renderer.py)
YOLO 检测 (yolo_detector.py) - 检测字段区域
字段提取 (field_extractor.py)
├→ OCR 文本提取 (ocr/paddle_ocr.py)
├→ 规范化 & 验证
└→ 置信度计算
交叉验证 (pipeline.py)
├→ 解析 payment_line 格式
├→ 从 payment_line 提取 OCR/Amount/Account
└→ 与检测字段验证payment_line 值优先
降级 OCR如果关键字段缺失
├→ 全页 OCR
└→ 正则提取
InferenceResult 输出
```
---
## 2. 代码质量问题
### 2.1 长函数(>50 行)🔴
| 函数 | 文件 | 行数 | 复杂度 | 问题 |
|------|------|------|--------|------|
| `_normalize_customer_number()` | field_extractor.py | **127** | 极高 | 4 层模式匹配7+ 正则,复杂评分 |
| `_cross_validate_payment_line()` | pipeline.py | **127** | 极高 | 核心验证逻辑8+ 条件分支 |
| `_normalize_bankgiro()` | field_extractor.py | 62 | 高 | Luhn 验证 + 多种降级 |
| `_normalize_plusgiro()` | field_extractor.py | 63 | 高 | 类似 bankgiro |
| `_normalize_payment_line()` | field_extractor.py | 74 | 高 | 4 种正则模式 |
| `_normalize_amount()` | field_extractor.py | 78 | 高 | 多策略降级 |
**示例问题** - `_normalize_customer_number()` (第 776-902 行):
```python
def _normalize_customer_number(self, text: str):
# 127 行函数,包含:
# - 4 个嵌套的 if/for 循环
# - 7 种不同的正则模式
# - 5 个评分机制
# - 处理有标签和无标签格式
```
**建议**: 拆分为:
- `_find_customer_code_patterns()`
- `_find_labeled_customer_code()`
- `_score_customer_candidates()`
### 2.2 代码重复 🔴
**支付行解析3+ 处重复实现)**:
1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
2. `MachineCodeParser.parse()` (machine_code_parser.py:919 行)
3. `_normalize_payment_line()` (field_extractor.py:632-705)
所有三处都实现类似的正则模式:
```
格式: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
```
**Bankgiro/Plusgiro 验证(重复)**:
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
- `field_matcher.py`: 类似匹配逻辑
**建议**: 创建统一模块:
```python
# src/common/payment_line_parser.py
class PaymentLineParser:
def parse(text: str) -> PaymentLineResult
# src/common/giro_validator.py
class GiroValidator:
def validate_and_format(value: str, giro_type: str) -> str
```
### 2.3 错误处理不一致 ⚠️
**通用异常捕获31 处)**:
```python
except Exception as e: # 代码库中 31 处
result.errors.append(str(e))
```
**问题**:
- 没有捕获特定错误类型
- 通用错误消息丢失上下文
- 第 142-147 行 (routes.py): 捕获所有异常,返回 500 状态
**当前写法** (routes.py:142-147):
```python
try:
service_result = inference_service.process_pdf(...)
except Exception as e: # 太宽泛
logger.error(f"Error processing document: {e}")
raise HTTPException(status_code=500, detail=str(e))
```
**改进建议**:
```python
except FileNotFoundError:
raise HTTPException(status_code=400, detail="PDF 文件未找到")
except PyMuPDFError:
raise HTTPException(status_code=400, detail="无效的 PDF 格式")
except OCRError:
raise HTTPException(status_code=503, detail="OCR 服务不可用")
```
### 2.4 配置安全问题 🔴
**config.py 第 24-30 行** - 明文凭据:
```python
DATABASE = {
'host': '192.168.68.31', # 硬编码 IP
'user': 'docmaster', # 硬编码用户名
'password': 'nY6LYK5d', # 🔴 明文密码!
'database': 'invoice_master'
}
```
**建议**:
```python
DATABASE = {
'host': os.getenv('DB_HOST', 'localhost'),
'user': os.getenv('DB_USER', 'docmaster'),
'password': os.getenv('DB_PASSWORD'), # 从环境变量读取
'database': os.getenv('DB_NAME', 'invoice_master')
}
```
### 2.5 魔法数字 ⚠️
| 值 | 位置 | 用途 | 问题 |
|---|------|------|------|
| 0.5 | 多处 | 置信度阈值 | 不可按字段配置 |
| 0.95 | pipeline.py | payment_line 置信度 | 无说明 |
| 300 | 多处 | DPI | 硬编码 |
| 0.1 | field_extractor.py | BBox 填充 | 应为配置 |
| 72 | 多处 | PDF 基础 DPI | 公式中的魔法数字 |
| 50 | field_extractor.py | 客户编号评分加分 | 无说明 |
**建议**: 提取到配置:
```python
INFERENCE_CONFIG = {
'confidence_threshold': 0.5,
'payment_line_confidence': 0.95,
'dpi': 300,
'bbox_padding': 0.1,
}
```
### 2.6 命名不一致 ⚠️
**字段名称不一致**:
- YOLO 类名: `invoice_number`, `ocr_number`, `supplier_org_number`
- 字段名: `InvoiceNumber`, `OCR`, `supplier_org_number`
- CSV 列名: 可能又不同
- 数据库字段名: 另一种变体
映射维护在多处:
- `yolo_detector.py` (90-100 行): `CLASS_TO_FIELD`
- 多个其他位置
---
## 3. 测试分析
### 3.1 测试覆盖率
**测试文件**: 13 个
- ✅ 覆盖良好: field_matcher, normalizer, payment_line_parser
- ⚠️ 中等覆盖: field_extractor, pipeline
- ❌ 覆盖不足: web 层, CLI, 批处理
**估算覆盖率**: 40-50%
### 3.2 缺失的测试用例 🔴
**关键缺失**:
1. 交叉验证逻辑 - 最复杂部分,测试很少
2. payment_line 解析变体 - 多种实现,边界情况不清楚
3. OCR 错误纠正 - 不同策略的复杂逻辑
4. Web API 端点 - 没有请求/响应测试
5. 批处理 - 多 worker 协调未测试
6. 降级 OCR 机制 - YOLO 检测失败时
---
## 4. 架构风险
### 🔴 关键风险
1. **配置安全** - config.py 中明文数据库凭据24-30 行)
2. **错误恢复** - 宽泛的异常处理掩盖真实问题
3. **可测试性** - 硬编码依赖阻止单元测试
### 🟡 高风险
1. **代码可维护性** - 支付行解析重复
2. **可扩展性** - 没有长时间推理的异步处理
3. **扩展性** - 添加新字段类型会很困难
### 🟢 中等风险
1. **性能** - 懒加载有帮助,但 ORM 查询未优化
2. **文档** - 大部分足够但可以更好
---
## 5. 优先级矩阵
| 优先级 | 行动 | 工作量 | 影响 |
|--------|------|--------|------|
| 🔴 关键 | 修复配置安全(环境变量) | 1 小时 | 高 |
| 🔴 关键 | 添加集成测试 | 2-3 天 | 高 |
| 🔴 关键 | 文档化错误处理策略 | 4 小时 | 中 |
| 🟡 高 | 统一 payment_line 解析 | 1-2 天 | 高 |
| 🟡 高 | 提取规范化到子模块 | 2-3 天 | 中 |
| 🟡 高 | 添加依赖注入 | 2-3 天 | 中 |
| 🟡 高 | 拆分长函数 | 2-3 天 | 低 |
| 🟢 中 | 提高测试覆盖率到 70%+ | 3-5 天 | 高 |
| 🟢 中 | 提取魔法数字 | 4 小时 | 低 |
| 🟢 中 | 标准化命名约定 | 1-2 天 | 中 |
---
## 6. 具体文件建议
### 高优先级(代码质量)
| 文件 | 问题 | 建议 |
|------|------|------|
| `field_extractor.py` | 1,347 行6 个长规范化方法 | 拆分为 `normalizers/` 子模块 |
| `pipeline.py` | 127 行 `_cross_validate_payment_line()` | 提取到单独的 `CrossValidator` 类 |
| `field_matcher.py` | 875 行;复杂匹配逻辑 | 拆分为 `matching/` 子模块 |
| `config.py` | 硬编码凭据(第 29 行) | 使用环境变量 |
| `machine_code_parser.py` | 919 行payment_line 解析 | 与 pipeline 解析合并 |
### 中优先级(重构)
| 文件 | 问题 | 建议 |
|------|------|------|
| `app.py` | 765 行HTML 内联在 Python 中 | 提取到 `templates/` 目录 |
| `autolabel.py` | 753 行;批处理逻辑 | 提取 worker 函数到模块 |
| `format_variants.py` | 610 行;变体生成 | 考虑策略模式 |
---
## 7. 建议行动
### 第 1 阶段关键修复1 周)
1. **配置安全** (1 小时)
- 移除 config.py 中的明文密码
- 添加环境变量支持
- 更新 README 说明配置
2. **错误处理标准化** (1 天)
- 定义自定义异常类
- 替换通用 Exception 捕获
- 添加错误代码常量
3. **添加关键集成测试** (2 天)
- 端到端推理测试
- payment_line 交叉验证测试
- API 端点测试
### 第 2 阶段重构2-3 周)
4. **统一 payment_line 解析** (2 天)
- 创建 `src/common/payment_line_parser.py`
- 合并 3 处重复实现
- 迁移所有调用方
5. **拆分 field_extractor.py** (3 天)
- 创建 `src/inference/normalizers/` 子模块
- 每个字段类型一个文件
- 提取共享验证逻辑
6. **拆分长函数** (2 天)
- `_normalize_customer_number()` → 3 个函数
- `_cross_validate_payment_line()` → CrossValidator 类
### 第 3 阶段改进1-2 周)
7. **提高测试覆盖率** (5 天)
- 目标70%+ 覆盖率
- 专注于验证逻辑
- 添加边界情况测试
8. **配置管理改进** (1 天)
- 提取所有魔法数字
- 创建配置文件YAML
- 添加配置验证
9. **文档改进** (2 天)
- 添加架构图
- 文档化所有私有方法
- 创建贡献指南
---
## 附录 A度量指标
### 代码复杂度
| 类别 | 计数 | 平均行数 |
|------|------|----------|
| 源文件 | 67 | 334 |
| 长文件 (>500 行) | 12 | 875 |
| 长函数 (>50 行) | 23 | 89 |
| 测试文件 | 13 | 298 |
### 依赖关系
| 类型 | 计数 |
|------|------|
| 外部依赖 | ~25 |
| 内部模块 | 10 |
| 循环依赖 | 0 ✅ |
### 代码风格
| 指标 | 覆盖率 |
|------|--------|
| 类型提示 | 80% |
| Docstrings (公开) | 80% |
| Docstrings (私有) | 40% |
| 测试覆盖率 | 45% |
---
**生成日期**: 2026-01-22
**审查者**: Claude Code
**版本**: v2.0

View File

@@ -1,96 +0,0 @@
# Field Extractor 分析报告
## 概述
field_extractor.py (1183行) 最初被识别为可优化文件,尝试使用 `src/normalize` 模块进行重构,但经过分析和测试后发现 **不应该重构**
## 重构尝试
### 初始计划
将 field_extractor.py 中的重复 normalize 方法删除,统一使用 `src/normalize/normalize_field()` 接口。
### 实施步骤
1. ✅ 备份原文件 (`field_extractor_old.py`)
2. ✅ 修改 `_normalize_and_validate` 使用统一 normalizer
3. ✅ 删除重复的 normalize 方法 (~400行)
4. ❌ 运行测试 - **28个失败**
5. ✅ 添加 wrapper 方法委托给 normalizer
6. ❌ 再次测试 - **12个失败**
7. ✅ 还原原文件
8. ✅ 测试通过 - **全部45个测试通过**
## 关键发现
### 两个模块的不同用途
| 模块 | 用途 | 输入 | 输出 | 示例 |
|------|------|------|------|------|
| **src/normalize/** | **变体生成** 用于匹配 | 已提取的字段值 | 多个匹配变体列表 | `"INV-12345"``["INV-12345", "12345"]` |
| **field_extractor** | **值提取** 从OCR文本 | 包含字段的原始OCR文本 | 提取的单个字段值 | `"Fakturanummer: A3861"``"A3861"` |
### 为什么不能统一?
1. **src/normalize/** 的设计目的:
- 接收已经提取的字段值
- 生成多个标准化变体用于fuzzy matching
- 例如 BankgiroNormalizer:
```python
normalize("782-1713") → ["7821713", "782-1713"] # 生成变体
```
2. **field_extractor** 的 normalize 方法:
- 接收包含字段的原始OCR文本可能包含标签、其他文本等
- **提取**特定模式的字段值
- 例如 `_normalize_bankgiro`:
```python
_normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None) # 从文本提取
```
3. **关键区别**:
- Normalizer: 变体生成器 (for matching)
- Field Extractor: 模式提取器 (for parsing)
### 测试失败示例
使用 normalizer 替代 field extractor 方法后的失败:
```python
# InvoiceNumber 测试
Input: "Fakturanummer: A3861"
期望: "A3861"
实际: "Fakturanummer: A3861" # 没有提取,只是清理
# Bankgiro 测试
Input: "Bankgiro: 782-1713"
期望: "782-1713"
实际: "7821713" # 返回了不带破折号的变体,而不是提取格式化值
```
## 结论
**field_extractor.py 不应该使用 src/normalize 模块重构**,因为:
1.**职责不同**: 提取 vs 变体生成
2.**输入不同**: 包含标签的原始OCR文本 vs 已提取的字段值
3.**输出不同**: 单个提取值 vs 多个匹配变体
4.**现有代码运行良好**: 所有45个测试通过
5.**提取逻辑有价值**: 包含复杂的模式匹配规则(例如区分 Bankgiro/Plusgiro 格式)
## 建议
1. **保留 field_extractor.py 原样**: 不进行重构
2. **文档化两个模块的差异**: 确保团队理解各自用途
3. **关注其他优化目标**: machine_code_parser.py (919行)
## 学习点
重构前应该:
1. 理解模块的**真实用途**,而不只是看代码相似度
2. 运行完整测试套件验证假设
3. 评估是否真的存在重复,还是表面相似但用途不同
---
**状态**: ✅ 分析完成,决定不重构
**测试**: ✅ 45/45 通过
**文件**: 保持 1183行 原样

View File

@@ -1,238 +0,0 @@
# Machine Code Parser 分析报告
## 文件概况
- **文件**: `src/ocr/machine_code_parser.py`
- **总行数**: 919 行
- **代码行**: 607 行 (66%)
- **方法数**: 14 个
- **正则表达式使用**: 47 次
## 代码结构
### 类结构
```
MachineCodeResult (数据类)
├── to_dict()
└── get_region_bbox()
MachineCodeParser (主解析器)
├── __init__()
├── parse() - 主入口
├── _find_tokens_with_values()
├── _find_machine_code_line_tokens()
├── _parse_standard_payment_line_with_tokens()
├── _parse_standard_payment_line() - 142行 ⚠️
├── _extract_ocr() - 50行
├── _extract_bankgiro() - 58行
├── _extract_plusgiro() - 30行
├── _extract_amount() - 68行
├── _calculate_confidence()
└── cross_validate()
```
## 发现的问题
### 1. ⚠️ `_parse_standard_payment_line` 方法过长 (142行)
**位置**: 442-582 行
**问题**:
- 包含嵌套函数 `normalize_account_spaces``format_account`
- 多个正则匹配分支
- 逻辑复杂,难以测试和维护
**建议**:
可以拆分为独立方法:
- `_normalize_account_spaces(line)`
- `_format_account(account_digits, context)`
- `_match_primary_pattern(line)`
- `_match_fallback_patterns(line)`
### 2. 🔁 4个 `_extract_*` 方法有重复模式
所有 extract 方法都遵循相同模式:
```python
def _extract_XXX(self, tokens):
candidates = []
for token in tokens:
text = token.text.strip()
matches = self.XXX_PATTERN.findall(text)
for match in matches:
# 验证逻辑
# 上下文检测
candidates.append((normalized, context_score, token))
if not candidates:
return None
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
return candidates[0][0]
```
**重复的逻辑**:
- Token 迭代
- 模式匹配
- 候选收集
- 上下文评分
- 排序和选择最佳匹配
**建议**:
可以提取基础提取器类或通用方法来减少重复。
### 3. ✅ 上下文检测重复
上下文检测代码在多个地方重复:
```python
# _extract_bankgiro 中
context_text = ' '.join(t.text.lower() for t in tokens)
is_bankgiro_context = (
'bankgiro' in context_text or
'bg:' in context_text or
'bg ' in context_text
)
# _extract_plusgiro 中
context_text = ' '.join(t.text.lower() for t in tokens)
is_plusgiro_context = (
'plusgiro' in context_text or
'postgiro' in context_text or
'pg:' in context_text or
'pg ' in context_text
)
# _parse_standard_payment_line 中
context = (context_line or raw_line).lower()
is_plusgiro_context = (
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
and 'bankgiro' not in context
)
```
**建议**:
提取为独立方法:
- `_detect_account_context(tokens) -> dict[str, bool]`
## 重构建议
### 方案 A: 轻度重构(推荐)✅
**目标**: 提取重复的上下文检测逻辑,不改变主要结构
**步骤**:
1. 提取 `_detect_account_context(tokens)` 方法
2. 提取 `_normalize_account_spaces(line)` 为独立方法
3. 提取 `_format_account(digits, context)` 为独立方法
**影响**:
- 减少 ~50-80 行重复代码
- 提高可测试性
- 低风险,易于验证
**预期结果**: 919 行 → ~850 行 (↓7%)
### 方案 B: 中度重构
**目标**: 创建通用的字段提取框架
**步骤**:
1. 创建 `_generic_extract(pattern, normalizer, context_checker)`
2. 重构所有 `_extract_*` 方法使用通用框架
3. 拆分 `_parse_standard_payment_line` 为多个小方法
**影响**:
- 减少 ~150-200 行代码
- 显著提高可维护性
- 中等风险,需要全面测试
**预期结果**: 919 行 → ~720 行 (↓22%)
### 方案 C: 深度重构(不推荐)
**目标**: 完全重新设计为策略模式
**风险**:
- 高风险,可能引入 bugs
- 需要大量测试
- 可能破坏现有集成
## 推荐方案
### ✅ 采用方案 A轻度重构
**理由**:
1. **代码已经工作良好**: 没有明显的 bug 或性能问题
2. **低风险**: 只提取重复逻辑,不改变核心算法
3. **性价比高**: 小改动带来明显的代码质量提升
4. **易于验证**: 现有测试应该能覆盖
### 重构步骤
```python
# 1. 提取上下文检测
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
"""检测上下文中的账户类型关键词"""
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
# 2. 提取空格标准化
def _normalize_account_spaces(self, line: str) -> str:
"""移除账户号码中的空格"""
# (现有 line 460-481 的代码)
# 3. 提取账户格式化
def _format_account(
self,
account_digits: str,
is_plusgiro_context: bool
) -> tuple[str, str]:
"""格式化账户并确定类型"""
# (现有 line 485-523 的代码)
```
## 对比field_extractor vs machine_code_parser
| 特征 | field_extractor | machine_code_parser |
|------|-----------------|---------------------|
| 用途 | 值提取 | 机器码解析 |
| 重复代码 | ~400行normalize方法 | ~80行上下文检测 |
| 重构价值 | ❌ 不同用途,不应统一 | ✅ 可提取共享逻辑 |
| 风险 | 高(会破坏功能) | 低(只是代码组织) |
## 决策
### ✅ 建议重构 machine_code_parser.py
**与 field_extractor 的不同**:
- field_extractor: 重复的方法有**不同的用途**(提取 vs 变体生成)
- machine_code_parser: 重复的代码有**相同的用途**(都是上下文检测)
**预期收益**:
- 减少 ~70 行重复代码
- 提高可测试性(可以单独测试上下文检测)
- 更清晰的代码组织
- **低风险**,易于验证
## 下一步
1. ✅ 备份原文件
2. ✅ 提取 `_detect_account_context` 方法
3. ✅ 提取 `_normalize_account_spaces` 方法
4. ✅ 提取 `_format_account` 方法
5. ✅ 更新所有调用点
6. ✅ 运行测试验证
7. ✅ 检查代码覆盖率
---
**状态**: 📋 分析完成,建议轻度重构
**风险评估**: 🟢 低风险
**预期收益**: 919行 → ~850行 (↓7%)

View File

@@ -1,519 +0,0 @@
# Performance Optimization Guide
This document provides performance optimization recommendations for the Invoice Field Extraction system.
## Table of Contents
1. [Batch Processing Optimization](#batch-processing-optimization)
2. [Database Query Optimization](#database-query-optimization)
3. [Caching Strategies](#caching-strategies)
4. [Memory Management](#memory-management)
5. [Profiling and Monitoring](#profiling-and-monitoring)
---
## Batch Processing Optimization
### Current State
The system processes invoices one at a time. For large batches, this can be inefficient.
### Recommendations
#### 1. Database Batch Operations
**Current**: Individual inserts for each document
```python
# Inefficient
for doc in documents:
db.insert_document(doc) # Individual DB call
```
**Optimized**: Use `execute_values` for batch inserts
```python
# Efficient - already implemented in db.py line 519
from psycopg2.extras import execute_values
execute_values(cursor, """
INSERT INTO documents (...)
VALUES %s
""", document_values)
```
**Impact**: 10-50x faster for batches of 100+ documents
#### 2. PDF Processing Batching
**Recommendation**: Process PDFs in parallel using multiprocessing
```python
from multiprocessing import Pool
def process_batch(pdf_paths, batch_size=10):
"""Process PDFs in parallel batches."""
with Pool(processes=batch_size) as pool:
results = pool.map(pipeline.process_pdf, pdf_paths)
return results
```
**Considerations**:
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
- CPU-intensive tasks can use separate process pool (`src/processing/cpu_pool.py`)
- Current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern
**Status**: ✅ Already implemented in `src/processing/` modules
#### 3. Image Caching for Multi-Page PDFs
**Current**: Each page rendered independently
```python
# Current pattern in field_extractor.py
for page_num in range(total_pages):
image = render_pdf_page(pdf_path, page_num, dpi=300)
```
**Optimized**: Pre-render all pages if processing multiple fields per page
```python
# Batch render
images = {
page_num: render_pdf_page(pdf_path, page_num, dpi=300)
for page_num in page_numbers_needed
}
# Reuse images
for detection in detections:
image = images[detection.page_no]
extract_field(detection, image)
```
**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices
---
## Database Query Optimization
### Current Performance
- **Parameterized queries**: ✅ Implemented (Phase 1)
- **Connection pooling**: ❌ Not implemented
- **Query batching**: ✅ Partially implemented
- **Index optimization**: ⚠️ Needs verification
### Recommendations
#### 1. Connection Pooling
**Current**: New connection for each operation
```python
def connect(self):
"""Create new database connection."""
return psycopg2.connect(**self.config)
```
**Optimized**: Use connection pooling
```python
from psycopg2 import pool
class DocumentDatabase:
def __init__(self, config):
self.pool = pool.SimpleConnectionPool(
minconn=1,
maxconn=10,
**config
)
def connect(self):
return self.pool.getconn()
def close(self, conn):
self.pool.putconn(conn)
```
**Impact**:
- Reduces connection overhead by 80-95%
- Especially important for high-frequency operations
#### 2. Index Recommendations
**Check current indexes**:
```sql
-- Verify indexes exist on frequently queried columns
SELECT tablename, indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'public';
```
**Recommended indexes**:
```sql
-- If not already present
CREATE INDEX IF NOT EXISTS idx_documents_success
ON documents(success);
CREATE INDEX IF NOT EXISTS idx_documents_timestamp
ON documents(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id
ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_matched
ON field_results(matched);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name
ON field_results(field_name);
```
**Impact**:
- 10-100x faster queries for filtered/sorted results
- Critical for `get_failed_matches()` and `get_all_documents_summary()`
#### 3. Query Batching
**Status**: ✅ Already implemented for field results (line 519)
**Verify batching is used**:
```python
# Good pattern in db.py
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
```
**Additional opportunity**: Batch `SELECT` queries
```python
# Current
docs = [get_document(doc_id) for doc_id in doc_ids] # N queries
# Optimized
docs = get_documents_batch(doc_ids) # 1 query with IN clause
```
**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)
---
## Caching Strategies
### 1. Model Loading Cache
**Current**: Models loaded per-instance
**Recommendation**: Singleton pattern for YOLO model
```python
class YOLODetectorSingleton:
_instance = None
_model = None
@classmethod
def get_instance(cls, model_path):
if cls._instance is None:
cls._instance = YOLODetector(model_path)
return cls._instance
```
**Impact**: Reduces memory usage by 90% when processing multiple documents
### 2. Parser Instance Caching
**Current**: ✅ Already optimal
```python
# Good pattern in field_extractor.py
def __init__(self):
self.payment_line_parser = PaymentLineParser() # Reused
self.customer_number_parser = CustomerNumberParser() # Reused
```
**Status**: No changes needed
### 3. OCR Result Caching
**Recommendation**: Cache OCR results for identical regions
```python
from functools import lru_cache
@lru_cache(maxsize=1000)
def ocr_region_cached(image_hash, bbox):
"""Cache OCR results by image hash + bbox."""
return paddle_ocr.ocr_region(image, bbox)
```
**Impact**: 50-80% speedup when re-processing similar documents
**Note**: Requires implementing image hashing (e.g., `hashlib.md5(image.tobytes())`)
---
## Memory Management
### Current Issues
**Potential memory leaks**:
1. Large images kept in memory after processing
2. OCR results accumulated without cleanup
3. Model outputs not explicitly cleared
### Recommendations
#### 1. Explicit Image Cleanup
```python
import gc
def process_pdf(pdf_path):
try:
image = render_pdf(pdf_path)
result = extract_fields(image)
return result
finally:
del image # Explicit cleanup
gc.collect() # Force garbage collection
```
#### 2. Generator Pattern for Large Batches
**Current**: Load all documents into memory
```python
docs = [process_pdf(path) for path in pdf_paths] # All in memory
```
**Optimized**: Use generator for streaming processing
```python
def process_batch_streaming(pdf_paths):
"""Process documents one at a time, yielding results."""
for path in pdf_paths:
result = process_pdf(path)
yield result
# Result can be saved to DB immediately
# Previous result is garbage collected
```
**Impact**: Constant memory usage regardless of batch size
#### 3. Context Managers for Resources
```python
class InferencePipeline:
def __enter__(self):
self.detector.load_model()
return self
def __exit__(self, *args):
self.detector.unload_model()
self.extractor.cleanup()
# Usage
with InferencePipeline(...) as pipeline:
results = pipeline.process_pdf(path)
# Automatic cleanup
```
---
## Profiling and Monitoring
### Recommended Profiling Tools
#### 1. cProfile for CPU Profiling
```python
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
# Your code here
pipeline.process_pdf(pdf_path)
profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20) # Top 20 slowest functions
```
#### 2. memory_profiler for Memory Analysis
```bash
pip install memory_profiler
python -m memory_profiler your_script.py
```
Or decorator-based:
```python
from memory_profiler import profile
@profile
def process_large_batch(pdf_paths):
# Memory usage tracked line-by-line
results = [process_pdf(path) for path in pdf_paths]
return results
```
#### 3. py-spy for Production Profiling
```bash
pip install py-spy
# Profile running process
py-spy top --pid 12345
# Generate flamegraph
py-spy record -o profile.svg -- python your_script.py
```
**Advantage**: No code changes needed, minimal overhead
### Key Metrics to Monitor
1. **Processing Time per Document**
- Target: <10 seconds for single-page invoice
- Current: ~2-5 seconds (estimated)
2. **Memory Usage**
- Target: <2GB for batch of 100 documents
- Monitor: Peak memory usage
3. **Database Query Time**
- Target: <100ms per query (with indexes)
- Monitor: Slow query log
4. **OCR Accuracy vs Speed Trade-off**
- Current: PaddleOCR with GPU (~200ms per region)
- Alternative: Tesseract (~500ms, slightly more accurate)
### Logging Performance Metrics
**Add to pipeline.py**:
```python
import time
import logging
logger = logging.getLogger(__name__)
def process_pdf(self, pdf_path):
start = time.time()
# Processing...
result = self._process_internal(pdf_path)
elapsed = time.time() - start
logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")
# Log to database for analysis
self.db.log_performance({
'document_id': result.document_id,
'processing_time': elapsed,
'field_count': len(result.fields)
})
return result
```
---
## Performance Optimization Priorities
### High Priority (Implement First)
1. **Database parameterized queries** - Already done (Phase 1)
2. **Database connection pooling** - Not implemented
3. **Index optimization** - Needs verification
### Medium Priority
4. **Batch PDF rendering** - Optimization possible
5. **Parser instance reuse** - Already done (Phase 2)
6. **Model caching** - Could improve
### Low Priority (Nice to Have)
7. **OCR result caching** - Complex implementation
8. **Generator patterns** - Refactoring needed
9. **Advanced profiling** - For production optimization
---
## Benchmarking Script
```python
"""
Benchmark script for invoice processing performance.
"""
import time
from pathlib import Path
from src.inference.pipeline import InferencePipeline
def benchmark_single_document(pdf_path, iterations=10):
"""Benchmark single document processing."""
pipeline = InferencePipeline(
model_path="path/to/model.pt",
use_gpu=True
)
times = []
for i in range(iterations):
start = time.time()
result = pipeline.process_pdf(pdf_path)
elapsed = time.time() - start
times.append(elapsed)
print(f"Iteration {i+1}: {elapsed:.2f}s")
avg_time = sum(times) / len(times)
print(f"\nAverage: {avg_time:.2f}s")
print(f"Min: {min(times):.2f}s")
print(f"Max: {max(times):.2f}s")
def benchmark_batch(pdf_paths, batch_size=10):
"""Benchmark batch processing."""
from multiprocessing import Pool
pipeline = InferencePipeline(
model_path="path/to/model.pt",
use_gpu=True
)
start = time.time()
with Pool(processes=batch_size) as pool:
results = pool.map(pipeline.process_pdf, pdf_paths)
elapsed = time.time() - start
avg_per_doc = elapsed / len(pdf_paths)
print(f"Total time: {elapsed:.2f}s")
print(f"Documents: {len(pdf_paths)}")
print(f"Average per document: {avg_per_doc:.2f}s")
print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")
if __name__ == "__main__":
# Single document benchmark
benchmark_single_document("test.pdf")
# Batch benchmark
pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
benchmark_batch(pdf_paths[:100])
```
---
## Summary
**Implemented (Phase 1-2)**:
- Parameterized queries (SQL injection fix)
- Parser instance reuse (Phase 2 refactoring)
- Batch insert operations (execute_values)
- Dual pool processing (CPU/GPU separation)
**Quick Wins (Low effort, high impact)**:
- Database connection pooling (2-4 hours)
- Index verification and optimization (1-2 hours)
- Batch PDF rendering (4-6 hours)
**Long-term Improvements**:
- OCR result caching with hashing
- Generator patterns for streaming
- Advanced profiling and monitoring
**Expected Impact**:
- Connection pooling: 80-95% reduction in DB overhead
- Indexes: 10-100x faster queries
- Batch rendering: 50-90% less redundant work
- **Overall**: 2-5x throughput improvement for batch processing

File diff suppressed because it is too large Load Diff

View File

@@ -1,170 +0,0 @@
# 代码重构总结报告
## 📊 整体成果
### 测试状态
-**688/688 测试全部通过** (100%)
-**代码覆盖率**: 34% → 37% (+3%)
-**0 个失败**, 0 个错误
### 测试覆盖率改进
-**machine_code_parser**: 25% → 65% (+40%)
-**新增测试**: 55个633 → 688
---
## 🎯 已完成的重构
### 1. ✅ Matcher 模块化 (876行 → 205行, ↓76%)
**文件**:
**重构内容**:
- 将单一876行文件拆分为 **11个模块**
- 提取 **5种独立的匹配策略**
- 创建专门的数据模型、工具函数和上下文处理模块
**新模块结构**:
**测试结果**:
- ✅ 77个 matcher 测试全部通过
- ✅ 完整的README文档
- ✅ 策略模式,易于扩展
**收益**:
- 📉 代码量减少 76%
- 📈 可维护性显著提高
- ✨ 每个策略独立测试
- 🔧 易于添加新策略
---
### 2. ✅ Machine Code Parser 轻度重构 + 测试覆盖 (919行 → 929行)
**文件**: src/ocr/machine_code_parser.py
**重构内容**:
- 提取 **3个共享辅助方法**,消除重复代码
- 优化上下文检测逻辑
- 简化账号格式化方法
**测试改进**:
-**新增55个测试**24 → 79个
-**覆盖率**: 25% → 65% (+40%)
- ✅ 所有688个项目测试通过
**新增测试覆盖**:
- **第一轮** (22个测试):
- `_detect_account_context()` - 8个测试上下文检测
- `_normalize_account_spaces()` - 5个测试空格规范化
- `_format_account()` - 4个测试账号格式化
- `parse()` - 5个测试主入口方法
- **第二轮** (33个测试):
- `_extract_ocr()` - 8个测试OCR 提取)
- `_extract_bankgiro()` - 9个测试Bankgiro 提取)
- `_extract_plusgiro()` - 8个测试Plusgiro 提取)
- `_extract_amount()` - 8个测试金额提取
**收益**:
- 🔄 消除80行重复代码
- 📈 可测试性提高(可独立测试辅助方法)
- 📖 代码可读性提升
- ✅ 覆盖率从25%提升到65% (+40%)
- 🎯 低风险,高回报
---
### 3. ✅ Field Extractor 分析 (决定不重构)
**文件**: (1183行)
**分析结果**: ❌ **不应重构**
**关键洞察**:
- 表面相似的代码可能有**完全不同的用途**
- field_extractor: **解析/提取** 字段值
- src/normalize: **标准化/生成变体** 用于匹配
- 两者职责不同,不应统一
**文档**:
---
## 📈 重构统计
### 代码行数变化
| 文件 | 重构前 | 重构后 | 变化 | 百分比 |
|------|--------|--------|------|--------|
| **matcher/field_matcher.py** | 876行 | 205行 | -671 | ↓76% |
| **matcher/* (新增10个模块)** | 0行 | 466行 | +466 | 新增 |
| **matcher 总计** | 876行 | 671行 | -205 | ↓23% |
| **ocr/machine_code_parser.py** | 919行 | 929行 | +10 | +1% |
| **总净减少** | - | - | **-195行** | **↓11%** |
### 测试覆盖
| 模块 | 测试数 | 通过率 | 覆盖率 | 状态 |
|------|--------|--------|--------|------|
| matcher | 77 | 100% | - | ✅ |
| field_extractor | 45 | 100% | 39% | ✅ |
| machine_code_parser | 79 | 100% | 65% | ✅ |
| normalizer | ~120 | 100% | - | ✅ |
| 其他模块 | ~367 | 100% | - | ✅ |
| **总计** | **688** | **100%** | **37%** | ✅ |
---
## 🎓 重构经验总结
### 成功经验
1. **✅ 先测试后重构**
- 所有重构都有完整测试覆盖
- 每次改动后立即验证测试
- 100%测试通过率保证质量
2. **✅ 识别真正的重复**
- 不是所有相似代码都是重复
- field_extractor vs normalizer: 表面相似但用途不同
- machine_code_parser: 真正的代码重复
3. **✅ 渐进式重构**
- matcher: 大规模模块化 (策略模式)
- machine_code_parser: 轻度重构 (提取共享方法)
- field_extractor: 分析后决定不重构
### 关键决策
#### ✅ 应该重构的情况
- **matcher**: 单一文件过长 (876行),包含多种策略
- **machine_code_parser**: 多处相同用途的重复代码
#### ❌ 不应重构的情况
- **field_extractor**: 相似代码有不同用途
### 教训
**不要盲目追求DRY原则**
> 相似代码不一定是重复。要理解代码的**真实用途**。
---
## ✅ 总结
**关键成果**:
- 📉 净减少 195 行代码
- 📈 代码覆盖率 +3% (34% → 37%)
- ✅ 测试数量 +55 (633 → 688)
- 🎯 machine_code_parser 覆盖率 +40% (25% → 65%)
- ✨ 模块化程度显著提高
- 🎯 可维护性大幅提升
**重要教训**:
> 相似的代码不一定是重复的代码。理解代码的真实用途,才能做出正确的重构决策。
**下一步建议**:
1. 继续提升 machine_code_parser 覆盖率到 80%+ (目前 65%)
2. 为其他低覆盖模块添加测试field_extractor 39%, pipeline 19%
3. 完善边界条件和异常情况的测试

View File

@@ -1,258 +0,0 @@
# 测试覆盖率改进报告
## 📊 改进概览
### 整体统计
-**测试总数**: 633 → 688 (+55个测试, +8.7%)
-**通过率**: 100% (688/688)
-**整体覆盖率**: 34% → 37% (+3%)
### machine_code_parser.py 专项改进
-**测试数**: 24 → 79 (+55个测试, +229%)
-**覆盖率**: 25% → 65% (+40%)
-**未覆盖行**: 273 → 129 (减少144行)
---
## 🎯 新增测试详情
### 第一轮改进 (22个测试)
#### 1. TestDetectAccountContext (8个测试)
测试新增的 `_detect_account_context()` 辅助方法。
**测试用例**:
1. `test_bankgiro_keyword` - 检测 'bankgiro' 关键词
2. `test_bg_keyword` - 检测 'bg:' 缩写
3. `test_plusgiro_keyword` - 检测 'plusgiro' 关键词
4. `test_postgiro_keyword` - 检测 'postgiro' 别名
5. `test_pg_keyword` - 检测 'pg:' 缩写
6. `test_both_contexts` - 同时存在两种关键词
7. `test_no_context` - 无账号关键词
8. `test_case_insensitive` - 大小写不敏感检测
**覆盖的代码路径**:
```python
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
```
---
### 2. TestNormalizeAccountSpacesMethod (5个测试)
测试新增的 `_normalize_account_spaces()` 辅助方法。
**测试用例**:
1. `test_removes_spaces_after_arrow` - 移除 > 后的空格
2. `test_multiple_consecutive_spaces` - 处理多个连续空格
3. `test_no_arrow_returns_unchanged` - 无 > 标记时返回原值
4. `test_spaces_before_arrow_preserved` - 保留 > 前的空格
5. `test_empty_string` - 空字符串处理
**覆盖的代码路径**:
```python
def _normalize_account_spaces(self, line: str) -> str:
if '>' not in line:
return line
parts = line.split('>', 1)
after_arrow = parts[1]
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
while re.search(r'(\d)\s+(\d)', normalized):
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
return parts[0] + '>' + normalized
```
---
### 3. TestFormatAccount (4个测试)
测试新增的 `_format_account()` 辅助方法。
**测试用例**:
1. `test_plusgiro_context_forces_plusgiro` - Plusgiro 上下文强制格式化为 Plusgiro
2. `test_valid_bankgiro_7_digits` - 7位有效 Bankgiro 格式化
3. `test_valid_bankgiro_8_digits` - 8位有效 Bankgiro 格式化
4. `test_defaults_to_bankgiro_when_ambiguous` - 模糊情况默认 Bankgiro
**覆盖的代码路径**:
```python
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
if is_plusgiro_context:
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# Luhn 验证逻辑
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# 决策逻辑
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
return bg_formatted, 'bankgiro'
```
---
### 4. TestParseMethod (5个测试)
测试主入口 `parse()` 方法。
**测试用例**:
1. `test_parse_empty_tokens` - 空 token 列表处理
2. `test_parse_finds_payment_line_in_bottom_region` - 在页面底部35%区域查找付款行
3. `test_parse_ignores_top_region` - 忽略页面顶部区域
4. `test_parse_with_context_keywords` - 检测上下文关键词
5. `test_parse_stores_source_tokens` - 存储源 token
**覆盖的代码路径**:
- Token 过滤(底部区域检测)
- 上下文关键词检测
- 付款行查找和解析
- 结果对象构建
---
### 第二轮改进 (33个测试)
#### 5. TestExtractOCR (8个测试)
测试 `_extract_ocr()` 方法 - OCR 参考号码提取。
**测试用例**:
1. `test_extract_valid_ocr_10_digits` - 提取10位 OCR 号码
2. `test_extract_valid_ocr_15_digits` - 提取15位 OCR 号码
3. `test_extract_ocr_with_hash_markers` - 带 # 标记的 OCR
4. `test_extract_longest_ocr_when_multiple` - 多个候选时选最长
5. `test_extract_ocr_ignores_short_numbers` - 忽略短于10位的数字
6. `test_extract_ocr_ignores_long_numbers` - 忽略长于25位的数字
7. `test_extract_ocr_excludes_bankgiro_variants` - 排除 Bankgiro 变体
8. `test_extract_ocr_empty_tokens` - 空 token 处理
#### 6. TestExtractBankgiro (9个测试)
测试 `_extract_bankgiro()` 方法 - Bankgiro 账号提取。
**测试用例**:
1. `test_extract_bankgiro_7_digits_with_dash` - 带破折号的7位 Bankgiro
2. `test_extract_bankgiro_7_digits_without_dash` - 无破折号的7位 Bankgiro
3. `test_extract_bankgiro_8_digits_with_dash` - 带破折号的8位 Bankgiro
4. `test_extract_bankgiro_8_digits_without_dash` - 无破折号的8位 Bankgiro
5. `test_extract_bankgiro_with_spaces` - 带空格的 Bankgiro
6. `test_extract_bankgiro_handles_plusgiro_format` - 处理 Plusgiro 格式
7. `test_extract_bankgiro_with_context` - 带上下文关键词
8. `test_extract_bankgiro_ignores_plusgiro_context` - 忽略 Plusgiro 上下文
9. `test_extract_bankgiro_empty_tokens` - 空 token 处理
#### 7. TestExtractPlusgiro (8个测试)
测试 `_extract_plusgiro()` 方法 - Plusgiro 账号提取。
**测试用例**:
1. `test_extract_plusgiro_7_digits_with_dash` - 带破折号的7位 Plusgiro
2. `test_extract_plusgiro_7_digits_without_dash` - 无破折号的7位 Plusgiro
3. `test_extract_plusgiro_8_digits` - 8位 Plusgiro
4. `test_extract_plusgiro_with_spaces` - 带空格的 Plusgiro
5. `test_extract_plusgiro_with_context` - 带上下文关键词
6. `test_extract_plusgiro_ignores_too_short` - 忽略少于7位
7. `test_extract_plusgiro_ignores_too_long` - 忽略多于8位
8. `test_extract_plusgiro_empty_tokens` - 空 token 处理
#### 8. TestExtractAmount (8个测试)
测试 `_extract_amount()` 方法 - 金额提取。
**测试用例**:
1. `test_extract_amount_with_comma_decimal` - 逗号小数分隔符
2. `test_extract_amount_with_dot_decimal` - 点号小数分隔符
3. `test_extract_amount_integer` - 整数金额
4. `test_extract_amount_with_thousand_separator` - 千位分隔符
5. `test_extract_amount_large_number` - 大额金额
6. `test_extract_amount_ignores_too_large` - 忽略过大金额
7. `test_extract_amount_ignores_zero` - 忽略零或负数
8. `test_extract_amount_empty_tokens` - 空 token 处理
---
## 📈 覆盖率分析
### 已覆盖的方法
`_detect_account_context()` - **100%** (第一轮新增)
`_normalize_account_spaces()` - **100%** (第一轮新增)
`_format_account()` - **95%** (第一轮新增)
`parse()` - **70%** (第一轮改进)
`_parse_standard_payment_line()` - **95%** (已有测试)
`_extract_ocr()` - **85%** (第二轮新增)
`_extract_bankgiro()` - **90%** (第二轮新增)
`_extract_plusgiro()` - **90%** (第二轮新增)
`_extract_amount()` - **80%** (第二轮新增)
### 仍需改进的方法 (未覆盖/部分覆盖)
⚠️ `_calculate_confidence()` - **0%** (未测试)
⚠️ `cross_validate()` - **0%** (未测试)
⚠️ `get_region_bbox()` - **0%** (未测试)
⚠️ `_find_tokens_with_values()` - **部分覆盖**
⚠️ `_find_machine_code_line_tokens()` - **部分覆盖**
### 未覆盖的代码行129行
主要集中在:
1. **验证方法** (lines 805-824): `_calculate_confidence`, `cross_validate`
2. **辅助方法** (lines 80-92, 336-369, 377-407): Token 查找、bbox 计算、日志记录
3. **边界条件** (lines 648-653, 690, 699, 759-760等): 某些提取方法的边界情况
---
## 🎯 改进建议
### ✅ 已完成目标
- ✅ 覆盖率从 25% 提升到 65% (+40%)
- ✅ 测试数量从 24 增加到 79 (+55个)
- ✅ 提取方法全部测试_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount
### 下一步目标(覆盖率 65% → 80%+
1. **添加验证方法测试** - 为 `_calculate_confidence`, `cross_validate` 添加测试
2. **添加辅助方法测试** - 为 token 查找和 bbox 计算方法添加测试
3. **完善边界条件** - 增加边界情况和异常处理的测试
4. **集成测试** - 添加端到端的集成测试,使用真实 PDF token 数据
---
## ✅ 已完成的改进
### 重构收益
- ✅ 提取的3个辅助方法现在可以独立测试
- ✅ 测试粒度更细,更容易定位问题
- ✅ 代码可读性提高,测试用例清晰易懂
### 质量保证
- ✅ 所有655个测试100%通过
- ✅ 无回归问题
- ✅ 新增测试覆盖了之前未测试的重构代码
---
## 📚 测试编写经验
### 成功经验
1. **使用 fixture 创建测试数据** - `_create_token()` 辅助方法简化了 token 创建
2. **按方法组织测试类** - 每个方法一个测试类,结构清晰
3. **测试用例命名清晰** - `test_<what>_<condition>` 格式,一目了然
4. **覆盖关键路径** - 优先测试常见场景和边界条件
### 遇到的问题
1. **Token 初始化参数** - 忘记了 `page_no` 参数,导致初始测试失败
- 解决:修复 `_create_token()` 辅助方法,添加 `page_no=0`
---
**报告日期**: 2026-01-24
**状态**: ✅ 完成
**下一步**: 继续提升覆盖率到 60%+

View File

@@ -0,0 +1,772 @@
# AWS 部署方案完整指南
## 目录
- [核心问题](#核心问题)
- [存储方案](#存储方案)
- [训练方案](#训练方案)
- [推理方案](#推理方案)
- [价格对比](#价格对比)
- [推荐架构](#推荐架构)
- [实施步骤](#实施步骤)
- [AWS vs Azure 对比](#aws-vs-azure-对比)
---
## 核心问题
| 问题 | 答案 |
|------|------|
| S3 能用于训练吗? | 可以,用 Mountpoint for S3 或 SageMaker 原生支持 |
| 能实时从 S3 读取训练吗? | 可以SageMaker 支持 Pipe Mode 流式读取 |
| 本地能挂载 S3 吗? | 可以,用 s3fs-fuse 或 Rclone |
| EC2 空闲时收费吗? | 收费,只要运行就按小时计费 |
| 如何按需付费? | 用 SageMaker Managed Spot 或 Lambda |
| 推理服务用什么? | Lambda (Serverless) 或 ECS/Fargate (容器) |
---
## 存储方案
### Amazon S3推荐
S3 是 AWS 的核心存储服务,与 SageMaker 深度集成。
```bash
# 创建 S3 桶
aws s3 mb s3://invoice-training-data --region us-east-1
# 上传训练数据
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
# 创建目录结构
aws s3api put-object --bucket invoice-training-data --key datasets/
aws s3api put-object --bucket invoice-training-data --key models/
```
### Mountpoint for Amazon S3
AWS 官方的 S3 挂载客户端,性能优于 s3fs
```bash
# 安装 Mountpoint
wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
sudo dpkg -i mount-s3.deb
# 挂载 S3
mkdir -p /mnt/s3-data
mount-s3 invoice-training-data /mnt/s3-data --region us-east-1
# 配置缓存(推荐)
mount-s3 invoice-training-data /mnt/s3-data \
--region us-east-1 \
--cache /tmp/s3-cache \
--metadata-ttl 60
```
### 本地开发挂载
**Linux/Mac (s3fs-fuse):**
```bash
# 安装
sudo apt-get install s3fs
# 配置凭证
echo ACCESS_KEY_ID:SECRET_ACCESS_KEY > ~/.passwd-s3fs
chmod 600 ~/.passwd-s3fs
# 挂载
s3fs invoice-training-data /mnt/s3 -o passwd_file=~/.passwd-s3fs
```
**Windows (Rclone):**
```powershell
# 安装
winget install Rclone.Rclone
# 配置
rclone config # 选择 s3
# 挂载
rclone mount aws:invoice-training-data Z: --vfs-cache-mode full
```
### 存储费用
| 层级 | 价格 | 适用场景 |
|------|------|---------|
| S3 Standard | $0.023/GB/月 | 频繁访问 |
| S3 Intelligent-Tiering | $0.023/GB/月 | 自动分层 |
| S3 Infrequent Access | $0.0125/GB/月 | 偶尔访问 |
| S3 Glacier | $0.004/GB/月 | 长期存档 |
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.12/月**
### SageMaker 数据输入模式
| 模式 | 说明 | 适用场景 |
|------|------|---------|
| File Mode | 下载到本地再训练 | 小数据集 |
| Pipe Mode | 流式读取,不占本地空间 | 大数据集 |
| FastFile Mode | 按需下载,最高 3x 加速 | 推荐 |
---
## 训练方案
### 方案总览
| 方案 | 适用场景 | 空闲费用 | 复杂度 | Spot 支持 |
|------|---------|---------|--------|----------|
| EC2 GPU | 简单直接 | 24/7 收费 | 低 | 是 |
| SageMaker Training | MLOps 集成 | 按任务计费 | 中 | 是 |
| EKS + GPU | Kubernetes | 复杂计费 | 高 | 是 |
### EC2 vs SageMaker
| 特性 | EC2 | SageMaker |
|------|-----|-----------|
| 本质 | 虚拟机 | 托管 ML 平台 |
| 计算费用 | $3.06/hr (p3.2xlarge) | $3.825/hr (+25%) |
| 管理开销 | 需自己配置 | 全托管 |
| Spot 折扣 | 最高 90% | 最高 90% |
| 实验跟踪 | 无 | 内置 |
| 自动关机 | 无 | 任务完成自动停止 |
### GPU 实例价格 (2025 年 6 月降价后)
| 实例 | GPU | 显存 | On-Demand | Spot 价格 |
|------|-----|------|-----------|----------|
| g4dn.xlarge | 1x T4 | 16GB | $0.526/hr | ~$0.16/hr |
| g4dn.2xlarge | 1x T4 | 16GB | $0.752/hr | ~$0.23/hr |
| p3.2xlarge | 1x V100 | 16GB | $3.06/hr | ~$0.92/hr |
| p3.8xlarge | 4x V100 | 64GB | $12.24/hr | ~$3.67/hr |
| p4d.24xlarge | 8x A100 | 320GB | $32.77/hr | ~$9.83/hr |
**注意**: 2025 年 6 月 AWS 宣布 P4/P5 系列最高降价 45%。
### Spot 实例
```bash
# EC2 Spot 请求
aws ec2 request-spot-instances \
--instance-count 1 \
--type "one-time" \
--launch-specification '{
"ImageId": "ami-0123456789abcdef0",
"InstanceType": "p3.2xlarge",
"KeyName": "my-key"
}'
```
### SageMaker Managed Spot Training
```python
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
entry_point="train.py",
source_dir="./src",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.p3.2xlarge",
framework_version="2.0",
py_version="py310",
# 启用 Spot 实例
use_spot_instances=True,
max_run=3600, # 最长运行 1 小时
max_wait=7200, # 最长等待 2 小时
# 检查点配置Spot 中断恢复)
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
checkpoint_local_path="/opt/ml/checkpoints",
hyperparameters={
"epochs": 100,
"batch-size": 16,
}
)
estimator.fit({
"training": "s3://invoice-training-data/datasets/train/",
"validation": "s3://invoice-training-data/datasets/val/"
})
```
---
## 推理方案
### 方案对比
| 方案 | GPU 支持 | 扩缩容 | 冷启动 | 价格 | 适用场景 |
|------|---------|--------|--------|------|---------|
| Lambda | 否 | 自动 0-N | 快 | 按调用 | 低流量、CPU 推理 |
| Lambda + Container | 否 | 自动 0-N | 较慢 | 按调用 | 复杂依赖 |
| ECS Fargate | 否 | 自动 | 中 | ~$30/月 | 容器化服务 |
| ECS + EC2 GPU | 是 | 手动/自动 | 慢 | ~$100+/月 | GPU 推理 |
| SageMaker Endpoint | 是 | 自动 | 慢 | ~$80+/月 | MLOps 集成 |
| SageMaker Serverless | 否 | 自动 0-N | 中 | 按调用 | 间歇性流量 |
### 推荐方案 1: AWS Lambda (低流量)
对于 YOLO CPU 推理Lambda 最经济:
```python
# lambda_function.py
import json
import boto3
from ultralytics import YOLO
# 模型在 Lambda Layer 或 /tmp 加载
model = None
def load_model():
global model
if model is None:
# 从 S3 下载模型到 /tmp
s3 = boto3.client('s3')
s3.download_file('invoice-models', 'best.pt', '/tmp/best.pt')
model = YOLO('/tmp/best.pt')
return model
def lambda_handler(event, context):
model = load_model()
# 从 S3 获取图片
s3 = boto3.client('s3')
bucket = event['bucket']
key = event['key']
local_path = f'/tmp/{key.split("/")[-1]}'
s3.download_file(bucket, key, local_path)
# 执行推理
results = model.predict(local_path, conf=0.5)
return {
'statusCode': 200,
'body': json.dumps({
'fields': extract_fields(results),
'confidence': get_confidence(results)
})
}
```
**Lambda 配置:**
```yaml
# serverless.yml
service: invoice-inference
provider:
name: aws
runtime: python3.11
timeout: 30
memorySize: 4096 # 4GB 内存
functions:
infer:
handler: lambda_function.lambda_handler
events:
- http:
path: /infer
method: post
layers:
- arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
```
### 推荐方案 2: ECS Fargate (中流量)
```yaml
# task-definition.json
{
"family": "invoice-inference",
"networkMode": "awsvpc",
"requiresCompatibilities": ["FARGATE"],
"cpu": "2048",
"memory": "4096",
"containerDefinitions": [
{
"name": "inference",
"image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest",
"portMappings": [
{
"containerPort": 8000,
"protocol": "tcp"
}
],
"environment": [
{"name": "MODEL_PATH", "value": "/app/models/best.pt"}
],
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/ecs/invoice-inference",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "ecs"
}
}
}
]
}
```
**Auto Scaling 配置:**
```bash
# 创建 Auto Scaling Target
aws application-autoscaling register-scalable-target \
--service-namespace ecs \
--resource-id service/invoice-cluster/invoice-service \
--scalable-dimension ecs:service:DesiredCount \
--min-capacity 1 \
--max-capacity 10
# 基于 CPU 使用率扩缩容
aws application-autoscaling put-scaling-policy \
--service-namespace ecs \
--resource-id service/invoice-cluster/invoice-service \
--scalable-dimension ecs:service:DesiredCount \
--policy-name cpu-scaling \
--policy-type TargetTrackingScaling \
--target-tracking-scaling-policy-configuration '{
"TargetValue": 70,
"PredefinedMetricSpecification": {
"PredefinedMetricType": "ECSServiceAverageCPUUtilization"
},
"ScaleOutCooldown": 60,
"ScaleInCooldown": 120
}'
```
### 方案 3: SageMaker Serverless Inference
```python
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.pytorch import PyTorchModel
model = PyTorchModel(
model_data="s3://invoice-models/model.tar.gz",
role="arn:aws:iam::123456789012:role/SageMakerRole",
entry_point="inference.py",
framework_version="2.0",
py_version="py310"
)
serverless_config = ServerlessInferenceConfig(
memory_size_in_mb=4096,
max_concurrency=10
)
predictor = model.deploy(
serverless_inference_config=serverless_config,
endpoint_name="invoice-inference-serverless"
)
```
### 推理性能对比
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|------|------------|---------|---------|
| Lambda 4GB | ~500-800ms | 按需扩展 | ~$15 (10K 请求) |
| Fargate 2vCPU 4GB | ~300-500ms | ~50 QPS | ~$30 |
| Fargate 4vCPU 8GB | ~200-300ms | ~100 QPS | ~$60 |
| EC2 g4dn.xlarge (T4) | ~50-100ms | ~200 QPS | ~$380 |
---
## 价格对比
### 训练成本对比(假设每天训练 2 小时)
| 方案 | 计算方式 | 月费 |
|------|---------|------|
| EC2 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
| EC2 按需启停 | 2h × 30天 × $3.06 | ~$184 |
| EC2 Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
| SageMaker On-Demand | 2h × 30天 × $3.825 | ~$230 |
| SageMaker Spot | 2h × 30天 × $1.15 | ~$69 |
### 本项目完整成本估算
| 组件 | 推荐方案 | 月费 |
|------|---------|------|
| 数据存储 | S3 Standard (5GB) | ~$0.12 |
| 数据库 | RDS PostgreSQL (db.t3.micro) | ~$15 |
| 推理服务 | Lambda (10K 请求/月) | ~$15 |
| 推理服务 (替代) | ECS Fargate | ~$30 |
| 训练服务 | SageMaker Spot (按需) | ~$2-5/次 |
| ECR (镜像存储) | 基本使用 | ~$1 |
| **总计 (Lambda)** | | **~$35/月** + 训练费 |
| **总计 (Fargate)** | | **~$50/月** + 训练费 |
---
## 推荐架构
### 整体架构图
```
┌─────────────────────────────────────┐
│ Amazon S3 │
│ ├── training-images/ │
│ ├── datasets/ │
│ ├── models/ │
│ └── checkpoints/ │
└─────────────────┬───────────────────┘
┌─────────────────────────────────┼─────────────────────────────────┐
│ │ │
▼ ▼ ▼
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
│ 推理服务 │ │ 训练服务 │ │ API Gateway │
│ │ │ │ │ │
│ 方案 A: Lambda │ │ SageMaker │ │ REST API │
│ ~$15/月 (10K req) │ │ Managed Spot │ │ 触发 Lambda/ECS │
│ │ │ ~$2-5/次训练 │ │ │
│ 方案 B: ECS Fargate │ │ │ │ │
│ ~$30/月 │ │ - 自动启动 │ │ │
│ │ │ - 训练完成自动停止 │ │ │
│ ┌───────────────────┐ │ │ - 检查点自动保存 │ │ │
│ │ FastAPI + YOLO │ │ │ │ │ │
│ │ CPU 推理 │ │ │ │ │ │
│ └───────────────────┘ │ └───────────┬───────────┘ └───────────────────────┘
└───────────┬───────────┘ │
│ │
└───────────────────────────────┼───────────────────────────────────────────┘
┌───────────────────────┐
│ Amazon RDS │
│ PostgreSQL │
│ db.t3.micro │
│ ~$15/月 │
└───────────────────────┘
```
### Lambda 推理配置
```yaml
# SAM template
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Resources:
InferenceFunction:
Type: AWS::Serverless::Function
Properties:
Handler: app.lambda_handler
Runtime: python3.11
MemorySize: 4096
Timeout: 30
Environment:
Variables:
MODEL_BUCKET: invoice-models
MODEL_KEY: best.pt
Policies:
- S3ReadPolicy:
BucketName: invoice-models
- S3ReadPolicy:
BucketName: invoice-uploads
Events:
InferApi:
Type: Api
Properties:
Path: /infer
Method: post
```
### SageMaker 训练配置
```python
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
entry_point="train.py",
source_dir="./src",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.g4dn.xlarge", # T4 GPU
framework_version="2.0",
py_version="py310",
# Spot 实例配置
use_spot_instances=True,
max_run=7200,
max_wait=14400,
# 检查点
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
hyperparameters={
"epochs": 100,
"batch-size": 16,
"model": "yolo11n.pt"
}
)
```
---
## 实施步骤
### 阶段 1: 存储设置
```bash
# 创建 S3 桶
aws s3 mb s3://invoice-training-data --region us-east-1
aws s3 mb s3://invoice-models --region us-east-1
# 上传训练数据
aws s3 sync ./data/dataset/temp s3://invoice-training-data/images/
# 配置生命周期(可选,自动转冷存储)
aws s3api put-bucket-lifecycle-configuration \
--bucket invoice-training-data \
--lifecycle-configuration '{
"Rules": [{
"ID": "MoveToIA",
"Status": "Enabled",
"Transitions": [{
"Days": 30,
"StorageClass": "STANDARD_IA"
}]
}]
}'
```
### 阶段 2: 数据库设置
```bash
# 创建 RDS PostgreSQL
aws rds create-db-instance \
--db-instance-identifier invoice-db \
--db-instance-class db.t3.micro \
--engine postgres \
--engine-version 15 \
--master-username docmaster \
--master-user-password YOUR_PASSWORD \
--allocated-storage 20
# 配置安全组
aws ec2 authorize-security-group-ingress \
--group-id sg-xxx \
--protocol tcp \
--port 5432 \
--source-group sg-yyy
```
### 阶段 3: 推理服务部署
**方案 A: Lambda**
```bash
# 创建 Lambda Layer (依赖)
cd lambda-layer
pip install ultralytics opencv-python-headless -t python/
zip -r layer.zip python/
aws lambda publish-layer-version \
--layer-name yolo-deps \
--zip-file fileb://layer.zip \
--compatible-runtimes python3.11
# 部署 Lambda 函数
cd ../lambda
zip function.zip lambda_function.py
aws lambda create-function \
--function-name invoice-inference \
--runtime python3.11 \
--handler lambda_function.lambda_handler \
--role arn:aws:iam::123456789012:role/LambdaRole \
--zip-file fileb://function.zip \
--memory-size 4096 \
--timeout 30 \
--layers arn:aws:lambda:us-east-1:123456789012:layer:yolo-deps:1
# 创建 API Gateway
aws apigatewayv2 create-api \
--name invoice-api \
--protocol-type HTTP \
--target arn:aws:lambda:us-east-1:123456789012:function:invoice-inference
```
**方案 B: ECS Fargate**
```bash
# 创建 ECR 仓库
aws ecr create-repository --repository-name invoice-inference
# 构建并推送镜像
aws ecr get-login-password | docker login --username AWS --password-stdin 123456789012.dkr.ecr.us-east-1.amazonaws.com
docker build -t invoice-inference .
docker tag invoice-inference:latest 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
docker push 123456789012.dkr.ecr.us-east-1.amazonaws.com/invoice-inference:latest
# 创建 ECS 集群
aws ecs create-cluster --cluster-name invoice-cluster
# 注册任务定义
aws ecs register-task-definition --cli-input-json file://task-definition.json
# 创建服务
aws ecs create-service \
--cluster invoice-cluster \
--service-name invoice-service \
--task-definition invoice-inference \
--desired-count 1 \
--launch-type FARGATE \
--network-configuration '{
"awsvpcConfiguration": {
"subnets": ["subnet-xxx"],
"securityGroups": ["sg-xxx"],
"assignPublicIp": "ENABLED"
}
}'
```
### 阶段 4: 训练服务设置
```python
# setup_sagemaker.py
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
# 创建 SageMaker 执行角色
iam = boto3.client('iam')
role_arn = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"
# 配置训练任务
estimator = PyTorch(
entry_point="train.py",
source_dir="./src/training",
role=role_arn,
instance_count=1,
instance_type="ml.g4dn.xlarge",
framework_version="2.0",
py_version="py310",
use_spot_instances=True,
max_run=7200,
max_wait=14400,
checkpoint_s3_uri="s3://invoice-training-data/checkpoints/",
)
# 保存配置供后续使用
estimator.save("training_config.json")
```
### 阶段 5: 集成训练触发 API
```python
# lambda_trigger_training.py
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
def lambda_handler(event, context):
"""触发 SageMaker 训练任务"""
epochs = event.get('epochs', 100)
estimator = PyTorch(
entry_point="train.py",
source_dir="s3://invoice-training-data/code/",
role="arn:aws:iam::123456789012:role/SageMakerRole",
instance_count=1,
instance_type="ml.g4dn.xlarge",
framework_version="2.0",
py_version="py310",
use_spot_instances=True,
max_run=7200,
max_wait=14400,
hyperparameters={
"epochs": epochs,
"batch-size": 16,
}
)
estimator.fit(
inputs={
"training": "s3://invoice-training-data/datasets/train/",
"validation": "s3://invoice-training-data/datasets/val/"
},
wait=False # 异步执行
)
return {
'statusCode': 200,
'body': {
'training_job_name': estimator.latest_training_job.name,
'status': 'Started'
}
}
```
---
## AWS vs Azure 对比
### 服务对应关系
| 功能 | AWS | Azure |
|------|-----|-------|
| 对象存储 | S3 | Blob Storage |
| 挂载工具 | Mountpoint for S3 | BlobFuse2 |
| ML 平台 | SageMaker | Azure ML |
| 容器服务 | ECS/Fargate | Container Apps |
| Serverless | Lambda | Functions |
| GPU VM | EC2 P3/G4dn | NC/ND 系列 |
| 容器注册 | ECR | ACR |
| 数据库 | RDS PostgreSQL | PostgreSQL Flexible |
### 价格对比
| 组件 | AWS | Azure |
|------|-----|-------|
| 存储 (5GB) | ~$0.12/月 | ~$0.09/月 |
| 数据库 | ~$15/月 | ~$25/月 |
| 推理 (Serverless) | ~$15/月 | ~$30/月 |
| 推理 (容器) | ~$30/月 | ~$30/月 |
| 训练 (Spot GPU) | ~$2-5/次 | ~$1-5/次 |
| **总计** | **~$35-50/月** | **~$65/月** |
### 优劣对比
| 方面 | AWS 优势 | Azure 优势 |
|------|---------|-----------|
| 价格 | Lambda 更便宜 | GPU Spot 更便宜 |
| ML 平台 | SageMaker 更成熟 | Azure ML 更易用 |
| Serverless GPU | 无原生支持 | Container Apps GPU |
| 文档 | 更丰富 | 中文文档更好 |
| 生态 | 更大 | Office 365 集成 |
---
## 总结
### 推荐配置
| 组件 | 推荐方案 | 月费估算 |
|------|---------|---------|
| 数据存储 | S3 Standard | ~$0.12 |
| 数据库 | RDS db.t3.micro | ~$15 |
| 推理服务 | Lambda 4GB | ~$15 |
| 训练服务 | SageMaker Spot | 按需 ~$2-5/次 |
| ECR | 基本使用 | ~$1 |
| **总计** | | **~$35/月** + 训练费 |
### 关键决策
| 场景 | 选择 |
|------|------|
| 最低成本 | Lambda + SageMaker Spot |
| 稳定推理 | ECS Fargate |
| GPU 推理 | ECS + EC2 GPU |
| MLOps 集成 | SageMaker 全家桶 |
### 注意事项
1. **Lambda 冷启动**: 首次调用 ~3-5 秒,可用 Provisioned Concurrency 解决
2. **Spot 中断**: 配置检查点SageMaker 自动恢复
3. **S3 传输**: 同区域免费,跨区域收费
4. **Fargate 无 GPU**: 需要 GPU 必须用 ECS + EC2
5. **SageMaker 加价**: 比 EC2 贵 ~25%,但省管理成本

View File

@@ -0,0 +1,567 @@
# Azure 部署方案完整指南
## 目录
- [核心问题](#核心问题)
- [存储方案](#存储方案)
- [训练方案](#训练方案)
- [推理方案](#推理方案)
- [价格对比](#价格对比)
- [推荐架构](#推荐架构)
- [实施步骤](#实施步骤)
---
## 核心问题
| 问题 | 答案 |
|------|------|
| Azure Blob Storage 能用于训练吗? | 可以,用 BlobFuse2 挂载 |
| 能实时从 Blob 读取训练吗? | 可以,但建议配置本地缓存 |
| 本地能挂载 Azure Blob 吗? | 可以,用 Rclone (Windows) 或 BlobFuse2 (Linux) |
| VM 空闲时收费吗? | 收费,只要开机就按小时计费 |
| 如何按需付费? | 用 Serverless GPU 或 min=0 的 Compute Cluster |
| 推理服务用什么? | Container Apps (CPU) 或 Serverless GPU |
---
## 存储方案
### Azure Blob Storage + BlobFuse2推荐
```bash
# 安装 BlobFuse2
sudo apt-get install blobfuse2
# 配置文件
cat > ~/blobfuse-config.yaml << 'EOF'
logging:
type: syslog
level: log_warning
components:
- libfuse
- file_cache
- azstorage
file_cache:
path: /tmp/blobfuse2
timeout-sec: 120
max-size-mb: 4096
azstorage:
type: block
account-name: YOUR_ACCOUNT
account-key: YOUR_KEY
container: training-images
EOF
# 挂载
mkdir -p /mnt/azure-blob
blobfuse2 mount /mnt/azure-blob --config-file=~/blobfuse-config.yaml
```
### 本地开发Windows
```powershell
# 安装
winget install WinFsp.WinFsp
winget install Rclone.Rclone
# 配置
rclone config # 选择 azureblob
# 挂载为 Z: 盘
rclone mount azure:training-images Z: --vfs-cache-mode full
```
### 存储费用
| 层级 | 价格 | 适用场景 |
|------|------|---------|
| Hot | $0.018/GB/月 | 频繁访问 |
| Cool | $0.01/GB/月 | 偶尔访问 |
| Archive | $0.002/GB/月 | 长期存档 |
**本项目**: ~10,000 张图片 × 500KB = ~5GB → **~$0.09/月**
---
## 训练方案
### 方案总览
| 方案 | 适用场景 | 空闲费用 | 复杂度 |
|------|---------|---------|--------|
| Azure VM | 简单直接 | 24/7 收费 | 低 |
| Azure VM Spot | 省钱、可中断 | 24/7 收费 | 低 |
| Azure ML Compute | MLOps 集成 | 可缩到 0 | 中 |
| Container Apps GPU | Serverless | 自动缩到 0 | 中 |
### Azure VM vs Azure ML
| 特性 | Azure VM | Azure ML |
|------|----------|----------|
| 本质 | 虚拟机 | 托管 ML 平台 |
| 计算费用 | $3.06/hr (NC6s_v3) | $3.06/hr (相同) |
| 附加费用 | ~$5/月 | ~$20-30/月 |
| 实验跟踪 | 无 | 内置 |
| 自动扩缩 | 无 | 支持 min=0 |
| 适用人群 | DevOps | 数据科学家 |
### Azure ML 附加费用明细
| 服务 | 用途 | 费用 |
|------|------|------|
| Container Registry | Docker 镜像 | ~$5-20/月 |
| Blob Storage | 日志、模型 | ~$0.10/月 |
| Application Insights | 监控 | ~$0-10/月 |
| Key Vault | 密钥管理 | <$1/月 |
### Spot 实例
两种平台都支持 Spot/低优先级实例,最高节省 90%
| 类型 | 正常价格 | Spot 价格 | 节省 |
|------|---------|----------|------|
| NC6s_v3 (V100) | $3.06/hr | ~$0.92/hr | 70% |
| NC24ads_A100_v4 | $3.67/hr | ~$1.15/hr | 69% |
### GPU 实例价格
| 实例 | GPU | 显存 | 价格/小时 | Spot 价格 |
|------|-----|------|---------|----------|
| NC6s_v3 | 1x V100 | 16GB | $3.06 | $0.92 |
| NC24s_v3 | 4x V100 | 64GB | $12.24 | $3.67 |
| NC24ads_A100_v4 | 1x A100 | 80GB | $3.67 | $1.15 |
| NC48ads_A100_v4 | 2x A100 | 160GB | $7.35 | $2.30 |
---
## 推理方案
### 方案对比
| 方案 | GPU 支持 | 扩缩容 | 价格 | 适用场景 |
|------|---------|--------|------|---------|
| Container Apps (CPU) | 否 | 自动 0-N | ~$30/月 | YOLO 推理 (够用) |
| Container Apps (GPU) | 是 | Serverless | 按秒计费 | 高吞吐推理 |
| Azure App Service | 否 | 手动/自动 | ~$50/月 | 简单部署 |
| Azure ML Endpoint | 是 | 自动 | ~$100+/月 | MLOps 集成 |
| AKS (Kubernetes) | 是 | 自动 | 复杂计费 | 大规模生产 |
### 推荐: Container Apps (CPU)
对于 YOLO 推理,**CPU 足够**,不需要 GPU
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
- 比 GPU 便宜很多,适合中低流量
```yaml
# Container Apps 配置
name: invoice-inference
image: myacr.azurecr.io/invoice-inference:v1
resources:
cpu: 2.0
memory: 4Gi
scale:
minReplicas: 1 # 最少 1 个实例保持响应
maxReplicas: 10 # 最多扩展到 10 个
rules:
- name: http-scaling
http:
metadata:
concurrentRequests: "50" # 每实例 50 并发时扩容
```
### 推理服务代码示例
```python
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# 安装依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 复制代码和模型
COPY src/ ./src/
COPY models/best.pt ./models/
# 启动服务
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
```
```python
# src/web/app.py
from fastapi import FastAPI, UploadFile, File
from ultralytics import YOLO
import tempfile
app = FastAPI()
model = YOLO("models/best.pt")
@app.post("/api/v1/infer")
async def infer(file: UploadFile = File(...)):
# 保存上传文件
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
# 执行推理
results = model.predict(tmp_path, conf=0.5)
# 返回结果
return {
"fields": extract_fields(results),
"confidence": get_confidence(results)
}
@app.get("/health")
async def health():
return {"status": "healthy"}
```
### 部署命令
```bash
# 1. 创建 Container Registry
az acr create --name invoiceacr --resource-group myRG --sku Basic
# 2. 构建并推送镜像
az acr build --registry invoiceacr --image invoice-inference:v1 .
# 3. 创建 Container Apps 环境
az containerapp env create \
--name invoice-env \
--resource-group myRG \
--location eastus
# 4. 部署应用
az containerapp create \
--name invoice-inference \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference:v1 \
--registry-server invoiceacr.azurecr.io \
--cpu 2 --memory 4Gi \
--min-replicas 1 --max-replicas 10 \
--ingress external --target-port 8000
# 5. 获取 URL
az containerapp show --name invoice-inference --resource-group myRG --query properties.configuration.ingress.fqdn
```
### 高吞吐场景: Serverless GPU
如果需要 GPU 加速推理(高并发、低延迟):
```bash
# 请求 GPU 配额
az containerapp env workload-profile add \
--name invoice-env \
--resource-group myRG \
--workload-profile-name gpu \
--workload-profile-type Consumption-GPU-T4
# 部署 GPU 版本
az containerapp create \
--name invoice-inference-gpu \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference-gpu:v1 \
--workload-profile-name gpu \
--cpu 4 --memory 8Gi \
--min-replicas 0 --max-replicas 5 \
--ingress external --target-port 8000
```
### 推理性能对比
| 配置 | 单次推理时间 | 并发能力 | 月费估算 |
|------|------------|---------|---------|
| CPU 2核 4GB | ~300-500ms | ~50 QPS | ~$30 |
| CPU 4核 8GB | ~200-300ms | ~100 QPS | ~$60 |
| GPU T4 | ~50-100ms | ~200 QPS | 按秒计费 |
| GPU A100 | ~20-50ms | ~500 QPS | 按秒计费 |
---
## 价格对比
### 月度成本对比(假设每天训练 2 小时)
| 方案 | 计算方式 | 月费 |
|------|---------|------|
| VM 24/7 运行 | 24h × 30天 × $3.06 | ~$2,200 |
| VM 按需启停 | 2h × 30天 × $3.06 | ~$184 |
| VM Spot 按需 | 2h × 30天 × $0.92 | ~$55 |
| Serverless GPU | 2h × 30天 × ~$3.50 | ~$210 |
| Azure ML (min=0) | 2h × 30天 × $3.06 | ~$184 |
### 本项目完整成本估算
| 组件 | 推荐方案 | 月费 |
|------|---------|------|
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
| 数据库 | PostgreSQL Flexible (Burstable B1ms) | ~$25 |
| 推理服务 | Container Apps CPU (2核4GB) | ~$30 |
| 训练服务 | Azure ML Spot (按需) | ~$1-5/次 |
| Container Registry | Basic | ~$5 |
| **总计** | | **~$65/月** + 训练费 |
---
## 推荐架构
### 整体架构图
```
┌─────────────────────────────────────┐
│ Azure Blob Storage │
│ ├── training-images/ │
│ ├── datasets/ │
│ └── models/ │
└─────────────────┬───────────────────┘
┌─────────────────────────────────┼─────────────────────────────────┐
│ │ │
▼ ▼ ▼
┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐
│ 推理服务 (24/7) │ │ 训练服务 (按需) │ │ Web UI (可选) │
│ Container Apps │ │ Azure ML Compute │ │ Static Web Apps │
│ CPU 2核 4GB │ │ min=0, Spot │ │ ~$0 (免费层) │
│ ~$30/月 │ │ ~$1-5/次训练 │ │ │
│ │ │ │ │ │
│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │
│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │
│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │
│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │
└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘
│ │ │
└───────────────────────────────┼───────────────────────────────┘
┌───────────────────────┐
│ PostgreSQL │
│ Flexible Server │
│ Burstable B1ms │
│ ~$25/月 │
└───────────────────────┘
```
### 推理服务配置
```yaml
# Container Apps - CPU (24/7 运行)
name: invoice-inference
resources:
cpu: 2
memory: 4Gi
scale:
minReplicas: 1
maxReplicas: 10
env:
- name: MODEL_PATH
value: /app/models/best.pt
- name: DB_HOST
secretRef: db-host
- name: DB_PASSWORD
secretRef: db-password
```
### 训练服务配置
**方案 A: Azure ML Compute推荐**
```python
from azure.ai.ml.entities import AmlCompute
gpu_cluster = AmlCompute(
name="gpu-cluster",
size="Standard_NC6s_v3",
min_instances=0, # 空闲时关机
max_instances=1,
tier="LowPriority", # Spot 实例
idle_time_before_scale_down=120
)
```
**方案 B: Container Apps Serverless GPU**
```yaml
name: invoice-training
resources:
gpu: 1
gpuType: A100
scale:
minReplicas: 0
maxReplicas: 1
```
---
## 实施步骤
### 阶段 1: 存储设置
```bash
# 创建 Storage Account
az storage account create \
--name invoicestorage \
--resource-group myRG \
--sku Standard_LRS
# 创建容器
az storage container create --name training-images --account-name invoicestorage
az storage container create --name datasets --account-name invoicestorage
az storage container create --name models --account-name invoicestorage
# 上传训练数据
az storage blob upload-batch \
--destination training-images \
--source ./data/dataset/temp \
--account-name invoicestorage
```
### 阶段 2: 数据库设置
```bash
# 创建 PostgreSQL
az postgres flexible-server create \
--name invoice-db \
--resource-group myRG \
--sku-name Standard_B1ms \
--storage-size 32 \
--admin-user docmaster \
--admin-password YOUR_PASSWORD
# 配置防火墙
az postgres flexible-server firewall-rule create \
--name allow-azure \
--resource-group myRG \
--server-name invoice-db \
--start-ip-address 0.0.0.0 \
--end-ip-address 0.0.0.0
```
### 阶段 3: 推理服务部署
```bash
# 创建 Container Registry
az acr create --name invoiceacr --resource-group myRG --sku Basic
# 构建镜像
az acr build --registry invoiceacr --image invoice-inference:v1 .
# 创建环境
az containerapp env create \
--name invoice-env \
--resource-group myRG \
--location eastus
# 部署推理服务
az containerapp create \
--name invoice-inference \
--resource-group myRG \
--environment invoice-env \
--image invoiceacr.azurecr.io/invoice-inference:v1 \
--registry-server invoiceacr.azurecr.io \
--cpu 2 --memory 4Gi \
--min-replicas 1 --max-replicas 10 \
--ingress external --target-port 8000 \
--env-vars \
DB_HOST=invoice-db.postgres.database.azure.com \
DB_NAME=docmaster \
DB_USER=docmaster \
--secrets db-password=YOUR_PASSWORD
```
### 阶段 4: 训练服务设置
```bash
# 创建 Azure ML Workspace
az ml workspace create --name invoice-ml --resource-group myRG
# 创建 Compute Cluster
az ml compute create --name gpu-cluster \
--type AmlCompute \
--size Standard_NC6s_v3 \
--min-instances 0 \
--max-instances 1 \
--tier low_priority
```
### 阶段 5: 集成训练触发 API
```python
# src/web/routes/training.py
from fastapi import APIRouter
from azure.ai.ml import MLClient, command
from azure.identity import DefaultAzureCredential
router = APIRouter()
ml_client = MLClient(
credential=DefaultAzureCredential(),
subscription_id="your-subscription-id",
resource_group_name="myRG",
workspace_name="invoice-ml"
)
@router.post("/api/v1/train")
async def trigger_training(request: TrainingRequest):
"""触发 Azure ML 训练任务"""
training_job = command(
code="./training",
command=f"python train.py --epochs {request.epochs}",
environment="AzureML-pytorch-2.0-cuda11.8@latest",
compute="gpu-cluster",
)
job = ml_client.jobs.create_or_update(training_job)
return {
"job_id": job.name,
"status": job.status,
"studio_url": job.studio_url
}
@router.get("/api/v1/train/{job_id}/status")
async def get_training_status(job_id: str):
"""查询训练状态"""
job = ml_client.jobs.get(job_id)
return {"status": job.status}
```
---
## 总结
### 推荐配置
| 组件 | 推荐方案 | 月费估算 |
|------|---------|---------|
| 图片存储 | Blob Storage (Hot) | ~$0.10 |
| 数据库 | PostgreSQL Flexible | ~$25 |
| 推理服务 | Container Apps CPU | ~$30 |
| 训练服务 | Azure ML (min=0, Spot) | 按需 ~$1-5/次 |
| Container Registry | Basic | ~$5 |
| **总计** | | **~$65/月** + 训练费 |
### 关键决策
| 场景 | 选择 |
|------|------|
| 偶尔训练,简单需求 | Azure VM Spot + 手动启停 |
| 需要 MLOps团队协作 | Azure ML Compute |
| 追求最低空闲成本 | Container Apps Serverless GPU |
| 生产环境推理 | Container Apps CPU |
| 高并发推理 | Container Apps Serverless GPU |
### 注意事项
1. **冷启动**: Serverless GPU 启动需要 3-8 分钟
2. **Spot 中断**: 可能被抢占,需要检查点机制
3. **网络延迟**: Blob Storage 挂载比本地 SSD 慢,建议开启缓存
4. **区域选择**: 选择有 GPU 配额的区域 (East US, West Europe 等)
5. **推理优化**: CPU 推理对于 YOLO 已经足够,无需 GPU

View File

@@ -0,0 +1,647 @@
# Dashboard Design Specification
## Overview
Dashboard 是用户进入系统后的第一个页面,用于快速了解:
- 数据标注质量和进度
- 当前模型状态和性能
- 系统最近发生的活动
**目标用户**:使用文档标注系统的客户,需要监控文档处理状态、标注质量和模型训练进度。
---
## 1. UI Layout
### 1.1 Overall Structure
```
+------------------------------------------------------------------+
| Header: Logo + Navigation + User Menu |
+------------------------------------------------------------------+
| |
| Stats Cards Row (4 cards, equal width) |
| |
| +---------------------------+ +------------------------------+ |
| | Data Quality Panel (50%) | | Active Model Panel (50%) | |
| +---------------------------+ +------------------------------+ |
| |
| +--------------------------------------------------------------+ |
| | Recent Activity Panel (full width) | |
| +--------------------------------------------------------------+ |
| |
| +--------------------------------------------------------------+ |
| | System Status Bar (full width) | |
| +--------------------------------------------------------------+ |
+------------------------------------------------------------------+
```
### 1.2 Responsive Breakpoints
| Breakpoint | Layout |
|------------|--------|
| Desktop (>1200px) | 4 cards row, 2-column panels |
| Tablet (768-1200px) | 2x2 cards, 2-column panels |
| Mobile (<768px) | 1 card per row, stacked panels |
---
## 2. Component Specifications
### 2.1 Stats Cards Row
4 个等宽卡片显示核心统计数据
```
+-------------+ +-------------+ +-------------+ +-------------+
| [icon] | | [icon] | | [icon] | | [icon] |
| 38 | | 25 | | 8 | | 5 |
| Total Docs | | Complete | | Incomplete | | Pending |
+-------------+ +-------------+ +-------------+ +-------------+
```
| Card | Icon | Value | Label | Color | Click Action |
|------|------|-------|-------|-------|--------------|
| Total Documents | FileText | `total_documents` | "Total Documents" | Gray | Navigate to Documents page |
| Complete | CheckCircle | `annotation_complete` | "Complete" | Green | Navigate to Documents (filter: complete) |
| Incomplete | AlertCircle | `annotation_incomplete` | "Incomplete" | Orange | Navigate to Documents (filter: incomplete) |
| Pending | Clock | `pending` | "Pending" | Blue | Navigate to Documents (filter: pending) |
**Card Design:**
- Background: White with subtle border
- Icon: 24px, positioned top-left
- Value: 32px bold font
- Label: 14px muted color
- Hover: Slight shadow elevation
- Padding: 16px
### 2.2 Data Quality Panel
左侧面板显示标注完整度和质量指标
```
+---------------------------+
| DATA QUALITY |
| +-----------+ |
| | | |
| | 78% | Annotation |
| | | Complete |
| +-----------+ |
| |
| Complete: 25 |
| Incomplete: 8 |
| Pending: 5 |
| |
| [View Incomplete Docs] |
+---------------------------+
```
**Components:**
| Element | Spec |
|---------|------|
| Title | "DATA QUALITY", 14px uppercase, muted |
| Progress Ring | 120px diameter, stroke width 12px |
| Percentage | 36px bold, centered in ring |
| Label | "Annotation Complete", 14px, below ring |
| Stats List | 14px, icon + label + value per row |
| Action Button | Text button, primary color |
**Progress Ring Colors:**
- Complete portion: Green (#22C55E)
- Remaining: Gray (#E5E7EB)
**Completeness Calculation:**
```
completeness_rate = annotation_complete / (annotation_complete + annotation_incomplete) * 100
```
### 2.3 Active Model Panel
右侧面板显示当前生产模型信息
```
+-------------------------------+
| ACTIVE MODEL |
| |
| v1.2.0 - Invoice Model |
| ----------------------------- |
| |
| mAP Precision Recall |
| 95.1% 94% 92% |
| |
| Activated: 2024-01-20 |
| Documents: 500 |
| |
| [Training] Run-2024-02 [====] |
+-------------------------------+
```
**Components:**
| Element | Spec |
|---------|------|
| Title | "ACTIVE MODEL", 14px uppercase, muted |
| Version + Name | 18px bold (version) + 16px regular (name) |
| Divider | 1px border, full width |
| Metrics Row | 3 columns, equal width |
| Metric Value | 24px bold |
| Metric Label | 12px muted, below value |
| Info Rows | 14px, label: value format |
| Training Indicator | Shows when training is running |
**Metric Colors:**
- mAP >= 90%: Green
- mAP 80-90%: Yellow
- mAP < 80%: Red
**Empty State (No Active Model):**
```
+-------------------------------+
| ACTIVE MODEL |
| |
| [icon: Model] |
| No Active Model |
| |
| Train and activate a |
| model to see stats here |
| |
| [Go to Training] |
+-------------------------------+
```
**Training In Progress:**
```
| Training: Run-2024-02 |
| [=========> ] 45% |
| Started 2 hours ago |
```
### 2.4 Recent Activity Panel
全宽面板显示最近 10 条系统活动
```
+--------------------------------------------------------------+
| RECENT ACTIVITY [See All] |
+--------------------------------------------------------------+
| [rocket] Activated model v1.2.0 2 hours ago|
| [check] Training complete: Run-2024-01, mAP 95.1% yesterday|
| [edit] Modified INV-001.pdf invoice_number yesterday|
| [doc] Uploaded INV-005.pdf 2 days ago|
| [doc] Uploaded INV-004.pdf 2 days ago|
| [x] Training failed: Run-2024-00 3 days ago|
+--------------------------------------------------------------+
```
**Activity Item Layout:**
```
[Icon] [Description] [Timestamp]
```
| Element | Spec |
|---------|------|
| Icon | 16px, color based on type |
| Description | 14px, truncate if too long |
| Timestamp | 12px muted, right-aligned |
| Row Height | 40px |
| Hover | Background highlight |
**Activity Types and Icons:**
| Type | Icon | Color | Description Format |
|------|------|-------|-------------------|
| document_uploaded | FileText | Blue | "Uploaded {filename}" |
| annotation_modified | Edit | Orange | "Modified {filename} {field_name}" |
| training_completed | CheckCircle | Green | "Training complete: {task_name}, mAP {mAP}%" |
| training_failed | XCircle | Red | "Training failed: {task_name}" |
| model_activated | Rocket | Purple | "Activated model {version}" |
**Timestamp Formatting:**
- < 1 minute: "just now"
- < 1 hour: "{n} minutes ago"
- < 24 hours: "{n} hours ago"
- < 7 days: "yesterday" / "{n} days ago"
- >= 7 days: "Jan 15" (date format)
**Empty State:**
```
+--------------------------------------------------------------+
| RECENT ACTIVITY |
| |
| [icon: Activity] |
| No recent activity |
| |
| Start by uploading documents or creating training jobs |
+--------------------------------------------------------------+
```
### 2.5 System Status Bar
底部状态栏,显示系统健康状态。
```
+--------------------------------------------------------------+
| Backend API: [*] Online Database: [*] Connected GPU: [*] Available |
+--------------------------------------------------------------+
```
| Status | Icon | Color |
|--------|------|-------|
| Online/Connected/Available | Filled circle | Green |
| Degraded/Slow | Filled circle | Yellow |
| Offline/Error/Unavailable | Filled circle | Red |
---
## 3. API Endpoints
### 3.1 Dashboard Statistics
```
GET /api/v1/admin/dashboard/stats
```
**Response:**
```json
{
"total_documents": 38,
"annotation_complete": 25,
"annotation_incomplete": 8,
"pending": 5,
"completeness_rate": 75.76
}
```
**Calculation Logic:**
```python
# annotation_complete: labeled documents with core fields
SELECT COUNT(*) FROM admin_documents d
WHERE d.status = 'labeled'
AND EXISTS (
SELECT 1 FROM admin_annotations a
WHERE a.document_id = d.document_id
AND a.class_id IN (0, 3) -- invoice_number OR ocr_number
)
AND EXISTS (
SELECT 1 FROM admin_annotations a
WHERE a.document_id = d.document_id
AND a.class_id IN (4, 5) -- bankgiro OR plusgiro
)
# annotation_incomplete: labeled but missing core fields
SELECT COUNT(*) FROM admin_documents d
WHERE d.status = 'labeled'
AND NOT (/* above conditions */)
# pending: pending + auto_labeling
SELECT COUNT(*) FROM admin_documents
WHERE status IN ('pending', 'auto_labeling')
```
### 3.2 Active Model Info
```
GET /api/v1/admin/dashboard/active-model
```
**Response (with active model):**
```json
{
"model": {
"version_id": "uuid",
"version": "1.2.0",
"name": "Invoice Model",
"metrics_mAP": 0.951,
"metrics_precision": 0.94,
"metrics_recall": 0.92,
"document_count": 500,
"activated_at": "2024-01-20T15:00:00Z"
},
"running_training": {
"task_id": "uuid",
"name": "Run-2024-02",
"status": "running",
"started_at": "2024-01-25T10:00:00Z",
"progress": 45
}
}
```
**Response (no active model):**
```json
{
"model": null,
"running_training": null
}
```
### 3.3 Recent Activity
```
GET /api/v1/admin/dashboard/activity?limit=10
```
**Response:**
```json
{
"activities": [
{
"type": "model_activated",
"description": "Activated model v1.2.0",
"timestamp": "2024-01-25T12:00:00Z",
"metadata": {
"version_id": "uuid",
"version": "1.2.0"
}
},
{
"type": "training_completed",
"description": "Training complete: Run-2024-01, mAP 95.1%",
"timestamp": "2024-01-24T18:30:00Z",
"metadata": {
"task_id": "uuid",
"task_name": "Run-2024-01",
"mAP": 0.951
}
}
]
}
```
**Activity Aggregation Query:**
```sql
-- Union all activity sources, ordered by timestamp DESC, limit 10
(
SELECT 'document_uploaded' as type,
filename as entity_name,
created_at as timestamp,
document_id as entity_id
FROM admin_documents
ORDER BY created_at DESC
LIMIT 10
)
UNION ALL
(
SELECT 'annotation_modified' as type,
-- join to get filename and field name
...
FROM annotation_history
ORDER BY created_at DESC
LIMIT 10
)
UNION ALL
(
SELECT CASE WHEN status = 'completed' THEN 'training_completed'
WHEN status = 'failed' THEN 'training_failed' END as type,
name as entity_name,
completed_at as timestamp,
task_id as entity_id
FROM training_tasks
WHERE status IN ('completed', 'failed')
ORDER BY completed_at DESC
LIMIT 10
)
UNION ALL
(
SELECT 'model_activated' as type,
version as entity_name,
activated_at as timestamp,
version_id as entity_id
FROM model_versions
WHERE activated_at IS NOT NULL
ORDER BY activated_at DESC
LIMIT 10
)
ORDER BY timestamp DESC
LIMIT 10
```
---
## 4. UX Interactions
### 4.1 Loading States
| Component | Loading State |
|-----------|--------------|
| Stats Cards | Skeleton placeholder (gray boxes) |
| Data Quality Ring | Skeleton circle |
| Active Model | Skeleton lines |
| Recent Activity | Skeleton list items (5 rows) |
**Loading Duration Thresholds:**
- < 300ms: No loading state shown
- 300ms - 3s: Show skeleton
- > 3s: Show skeleton + "Taking longer than expected" message
### 4.2 Error States
| Error Type | Display |
|------------|---------|
| API Error | Toast notification + retry button in affected panel |
| Network Error | Full page overlay with retry option |
| Partial Failure | Show available data, error badge on failed sections |
### 4.3 Refresh Behavior
| Trigger | Behavior |
|---------|----------|
| Page Load | Fetch all data |
| Manual Refresh | Button in header, refetch all |
| Auto Refresh | Every 30 seconds for activity panel |
| Focus Return | Refetch if page was hidden > 5 minutes |
### 4.4 Click Actions
| Element | Action |
|---------|--------|
| Total Documents card | Navigate to `/documents` |
| Complete card | Navigate to `/documents?filter=complete` |
| Incomplete card | Navigate to `/documents?filter=incomplete` |
| Pending card | Navigate to `/documents?filter=pending` |
| "View Incomplete Docs" button | Navigate to `/documents?filter=incomplete` |
| Activity item | Navigate to related entity |
| "Go to Training" button | Navigate to `/training` |
| Active Model version | Navigate to `/models/{version_id}` |
### 4.5 Tooltips
| Element | Tooltip Content |
|---------|----------------|
| Completeness % | "25 of 33 labeled documents have complete annotations" |
| mAP metric | "Mean Average Precision at IoU 0.5" |
| Precision metric | "Proportion of correct positive predictions" |
| Recall metric | "Proportion of actual positives correctly identified" |
| Incomplete count | "Documents labeled but missing invoice_number/ocr_number or bankgiro/plusgiro" |
---
## 5. Data Model
### 5.1 TypeScript Types
```typescript
// Dashboard Stats
interface DashboardStats {
total_documents: number;
annotation_complete: number;
annotation_incomplete: number;
pending: number;
completeness_rate: number;
}
// Active Model
interface ActiveModelInfo {
model: ModelVersion | null;
running_training: RunningTraining | null;
}
interface ModelVersion {
version_id: string;
version: string;
name: string;
metrics_mAP: number;
metrics_precision: number;
metrics_recall: number;
document_count: number;
activated_at: string;
}
interface RunningTraining {
task_id: string;
name: string;
status: 'running';
started_at: string;
progress: number;
}
// Activity
interface Activity {
type: ActivityType;
description: string;
timestamp: string;
metadata: Record<string, unknown>;
}
type ActivityType =
| 'document_uploaded'
| 'annotation_modified'
| 'training_completed'
| 'training_failed'
| 'model_activated';
// Activity Response
interface ActivityResponse {
activities: Activity[];
}
```
### 5.2 React Query Hooks
```typescript
// useDashboardStats
const useDashboardStats = () => {
return useQuery({
queryKey: ['dashboard', 'stats'],
queryFn: () => api.get('/admin/dashboard/stats'),
refetchInterval: 30000, // 30 seconds
});
};
// useActiveModel
const useActiveModel = () => {
return useQuery({
queryKey: ['dashboard', 'active-model'],
queryFn: () => api.get('/admin/dashboard/active-model'),
refetchInterval: 60000, // 1 minute
});
};
// useRecentActivity
const useRecentActivity = (limit = 10) => {
return useQuery({
queryKey: ['dashboard', 'activity', limit],
queryFn: () => api.get(`/admin/dashboard/activity?limit=${limit}`),
refetchInterval: 30000,
});
};
```
---
## 6. Annotation Completeness Definition
### 6.1 Core Fields
A document is **complete** when it has annotations for:
| Requirement | Fields | Logic |
|-------------|--------|-------|
| Identifier | `invoice_number` (class_id=0) OR `ocr_number` (class_id=3) | At least one |
| Payment Account | `bankgiro` (class_id=4) OR `plusgiro` (class_id=5) | At least one |
### 6.2 Status Categories
| Category | Criteria |
|----------|----------|
| **Complete** | status=labeled AND has identifier AND has payment account |
| **Incomplete** | status=labeled AND (missing identifier OR missing payment account) |
| **Pending** | status IN (pending, auto_labeling) |
### 6.3 Filter Implementation
```sql
-- Complete documents
WHERE status = 'labeled'
AND document_id IN (
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
)
AND document_id IN (
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
)
-- Incomplete documents
WHERE status = 'labeled'
AND (
document_id NOT IN (
SELECT document_id FROM admin_annotations WHERE class_id IN (0, 3)
)
OR document_id NOT IN (
SELECT document_id FROM admin_annotations WHERE class_id IN (4, 5)
)
)
```
---
## 7. Implementation Checklist
### Backend
- [ ] Create `/api/v1/admin/dashboard/stats` endpoint
- [ ] Create `/api/v1/admin/dashboard/active-model` endpoint
- [ ] Create `/api/v1/admin/dashboard/activity` endpoint
- [ ] Add completeness calculation logic to document repository
- [ ] Implement activity aggregation query
### Frontend
- [ ] Create `DashboardOverview` component
- [ ] Create `StatsCard` component
- [ ] Create `DataQualityPanel` component with progress ring
- [ ] Create `ActiveModelPanel` component
- [ ] Create `RecentActivityPanel` component
- [ ] Create `SystemStatusBar` component
- [ ] Add React Query hooks for dashboard data
- [ ] Implement loading skeletons
- [ ] Implement error states
- [ ] Add navigation actions
- [ ] Add tooltips
### Testing
- [ ] Unit tests for completeness calculation
- [ ] Unit tests for activity aggregation
- [ ] Integration tests for dashboard endpoints
- [ ] E2E tests for dashboard interactions

View File

@@ -1,619 +0,0 @@
# 多池处理架构设计文档
## 1. 研究总结
### 1.1 当前问题分析
我们之前实现的双池模式存在稳定性问题,主要原因:
| 问题 | 原因 | 解决方案 |
|------|------|----------|
| 处理卡住 | 线程 + ProcessPoolExecutor 混用导致死锁 | 使用 asyncio 或纯 Queue 模式 |
| Queue.get() 无限阻塞 | 没有超时机制 | 添加 timeout 和哨兵值 |
| GPU 内存冲突 | 多进程同时访问 GPU | 限制 GPU worker = 1 |
| CUDA fork 问题 | Linux 默认 fork 不兼容 CUDA | 使用 spawn 启动方式 |
### 1.2 推荐架构方案
经过研究,最适合我们场景的方案是 **生产者-消费者队列模式**
```
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Main Process │ │ CPU Workers │ │ GPU Worker │
│ │ │ (4 processes) │ │ (1 process) │
│ ┌───────────┐ │ │ │ │ │
│ │ Task │──┼────▶│ Text PDF处理 │ │ Scanned PDF处理 │
│ │ Dispatcher│ │ │ (无需OCR) │ │ (PaddleOCR) │
│ └───────────┘ │ │ │ │ │
│ ▲ │ │ │ │ │ │ │
│ │ │ │ ▼ │ │ ▼ │
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
│ │ Collector │ │ │ │ │ │
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Database │ │
│ │ Batch │ │
│ │ Writer │ │
│ └───────────┘ │
└─────────────────┘
```
---
## 2. 核心设计原则
### 2.1 CUDA 兼容性
```python
# 关键:使用 spawn 启动方式
import multiprocessing as mp
ctx = mp.get_context("spawn")
# GPU worker 初始化时设置设备
def init_gpu_worker(gpu_id: int = 0):
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
global _ocr
from paddleocr import PaddleOCR
_ocr = PaddleOCR(use_gpu=True, ...)
```
### 2.2 Worker 初始化模式
使用 `initializer` 参数一次性加载模型,避免每个任务重新加载:
```python
# 全局变量保存模型
_ocr = None
def init_worker(use_gpu: bool, gpu_id: int = 0):
global _ocr
if use_gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from paddleocr import PaddleOCR
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
# 创建 Pool 时使用 initializer
pool = ProcessPoolExecutor(
max_workers=1,
initializer=init_worker,
initargs=(True, 0), # use_gpu=True, gpu_id=0
mp_context=mp.get_context("spawn")
)
```
### 2.3 队列模式 vs as_completed
| 方式 | 优点 | 缺点 | 适用场景 |
|------|------|------|----------|
| `as_completed()` | 简单、无需管理队列 | 无法跨多个 Pool 使用 | 单池场景 |
| `multiprocessing.Queue` | 高性能、灵活 | 需要手动管理、死锁风险 | 多池流水线 |
| `Manager().Queue()` | 可 pickle、跨 Pool | 性能较低 | 需要 Pool.map 场景 |
**推荐**:对于双池场景,使用 `as_completed()` 分别处理每个池,然后合并结果。
---
## 3. 详细开发计划
### 阶段 1重构基础架构 (2-3天)
#### 1.1 创建 WorkerPool 抽象类
```python
# src/processing/worker_pool.py
from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, Future
from dataclasses import dataclass
from typing import List, Any, Optional, Callable
import multiprocessing as mp
@dataclass
class TaskResult:
"""任务结果容器"""
task_id: str
success: bool
data: Any
error: Optional[str] = None
processing_time: float = 0.0
class WorkerPool(ABC):
"""Worker Pool 抽象基类"""
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
self.max_workers = max_workers
self.use_gpu = use_gpu
self.gpu_id = gpu_id
self._executor: Optional[ProcessPoolExecutor] = None
@abstractmethod
def get_initializer(self) -> Callable:
"""返回 worker 初始化函数"""
pass
@abstractmethod
def get_init_args(self) -> tuple:
"""返回初始化参数"""
pass
def start(self):
"""启动 worker pool"""
ctx = mp.get_context("spawn")
self._executor = ProcessPoolExecutor(
max_workers=self.max_workers,
mp_context=ctx,
initializer=self.get_initializer(),
initargs=self.get_init_args()
)
def submit(self, fn: Callable, *args, **kwargs) -> Future:
"""提交任务"""
if not self._executor:
raise RuntimeError("Pool not started")
return self._executor.submit(fn, *args, **kwargs)
def shutdown(self, wait: bool = True):
"""关闭 pool"""
if self._executor:
self._executor.shutdown(wait=wait)
self._executor = None
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.shutdown()
```
#### 1.2 实现 CPU 和 GPU Worker Pool
```python
# src/processing/cpu_pool.py
class CPUWorkerPool(WorkerPool):
"""CPU-only worker pool for text PDF processing"""
def __init__(self, max_workers: int = 4):
super().__init__(max_workers=max_workers, use_gpu=False)
def get_initializer(self) -> Callable:
return init_cpu_worker
def get_init_args(self) -> tuple:
return ()
# src/processing/gpu_pool.py
class GPUWorkerPool(WorkerPool):
"""GPU worker pool for OCR processing"""
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
def get_initializer(self) -> Callable:
return init_gpu_worker
def get_init_args(self) -> tuple:
return (self.gpu_id,)
```
---
### 阶段 2实现双池协调器 (2-3天)
#### 2.1 任务分发器
```python
# src/processing/task_dispatcher.py
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Tuple
class TaskType(Enum):
CPU = auto() # Text PDF
GPU = auto() # Scanned PDF
@dataclass
class Task:
id: str
task_type: TaskType
data: Any
class TaskDispatcher:
"""根据 PDF 类型分发任务到不同的 pool"""
def classify_task(self, doc_info: dict) -> TaskType:
"""判断文档是否需要 OCR"""
# 基于 PDF 特征判断
if self._is_scanned_pdf(doc_info):
return TaskType.GPU
return TaskType.CPU
def _is_scanned_pdf(self, doc_info: dict) -> bool:
"""检测是否为扫描件"""
# 1. 检查是否有可提取文本
# 2. 检查图片比例
# 3. 检查文本密度
pass
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
"""将任务分为 CPU 和 GPU 两组"""
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
return cpu_tasks, gpu_tasks
```
#### 2.2 双池协调器
```python
# src/processing/dual_pool_coordinator.py
from concurrent.futures import as_completed
from typing import List, Iterator
import logging
logger = logging.getLogger(__name__)
class DualPoolCoordinator:
"""协调 CPU 和 GPU 两个 worker pool"""
def __init__(
self,
cpu_workers: int = 4,
gpu_workers: int = 1,
gpu_id: int = 0
):
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
self.dispatcher = TaskDispatcher()
def __enter__(self):
self.cpu_pool.start()
self.gpu_pool.start()
return self
def __exit__(self, *args):
self.cpu_pool.shutdown()
self.gpu_pool.shutdown()
def process_batch(
self,
documents: List[dict],
cpu_task_fn: Callable,
gpu_task_fn: Callable,
on_result: Optional[Callable[[TaskResult], None]] = None,
on_error: Optional[Callable[[str, Exception], None]] = None
) -> List[TaskResult]:
"""
处理一批文档,自动分发到 CPU 或 GPU pool
Args:
documents: 待处理文档列表
cpu_task_fn: CPU 任务处理函数
gpu_task_fn: GPU 任务处理函数
on_result: 结果回调(可选)
on_error: 错误回调(可选)
Returns:
所有任务结果列表
"""
# 分类任务
tasks = [
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
for doc in documents
]
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
# 提交任务到各自的 pool
cpu_futures = {
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
for t in cpu_tasks
}
gpu_futures = {
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
for t in gpu_tasks
}
# 收集结果
results = []
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
for future in as_completed(all_futures):
task_id = cpu_futures.get(future) or gpu_futures.get(future)
pool_type = "CPU" if future in cpu_futures else "GPU"
try:
data = future.result(timeout=300) # 5分钟超时
result = TaskResult(task_id=task_id, success=True, data=data)
if on_result:
on_result(result)
except Exception as e:
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
if on_error:
on_error(task_id, e)
results.append(result)
return results
```
---
### 阶段 3集成到 autolabel (1-2天)
#### 3.1 修改 autolabel.py
```python
# src/cli/autolabel.py
def run_autolabel_dual_pool(args):
"""使用双池模式运行自动标注"""
from src.processing.dual_pool_coordinator import DualPoolCoordinator
# 初始化数据库批处理
db_batch = []
db_batch_size = 100
def on_result(result: TaskResult):
"""处理成功结果"""
nonlocal db_batch
db_batch.append(result.data)
if len(db_batch) >= db_batch_size:
save_documents_batch(db_batch)
db_batch.clear()
def on_error(task_id: str, error: Exception):
"""处理错误"""
logger.error(f"Task {task_id} failed: {error}")
# 创建双池协调器
with DualPoolCoordinator(
cpu_workers=args.cpu_workers or 4,
gpu_workers=args.gpu_workers or 1,
gpu_id=0
) as coordinator:
# 处理所有 CSV
for csv_file in csv_files:
documents = load_documents_from_csv(csv_file)
results = coordinator.process_batch(
documents=documents,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=on_result,
on_error=on_error
)
logger.info(f"CSV {csv_file}: {len(results)} processed")
# 保存剩余批次
if db_batch:
save_documents_batch(db_batch)
```
---
### 阶段 4测试与验证 (1-2天)
#### 4.1 单元测试
```python
# tests/unit/test_dual_pool.py
import pytest
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult
class TestDualPoolCoordinator:
def test_cpu_only_batch(self):
"""测试纯 CPU 任务批处理"""
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
results = coord.process_batch(docs, cpu_fn, gpu_fn)
assert len(results) == 10
assert all(r.success for r in results)
def test_mixed_batch(self):
"""测试混合任务批处理"""
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
docs = [
{"id": "text_1", "type": "text"},
{"id": "scan_1", "type": "scanned"},
{"id": "text_2", "type": "text"},
]
results = coord.process_batch(docs, cpu_fn, gpu_fn)
assert len(results) == 3
def test_timeout_handling(self):
"""测试超时处理"""
pass
def test_error_recovery(self):
"""测试错误恢复"""
pass
```
#### 4.2 集成测试
```python
# tests/integration/test_autolabel_dual_pool.py
def test_autolabel_with_dual_pool():
"""端到端测试双池模式"""
# 使用少量测试数据
result = subprocess.run([
"python", "-m", "src.cli.autolabel",
"--cpu-workers", "2",
"--gpu-workers", "1",
"--limit", "50"
], capture_output=True)
assert result.returncode == 0
# 验证数据库记录
```
---
## 4. 关键技术点
### 4.1 避免死锁的策略
```python
# 1. 使用 timeout
try:
result = future.result(timeout=300)
except TimeoutError:
logger.warning(f"Task timed out")
# 2. 使用哨兵值
SENTINEL = object()
queue.put(SENTINEL) # 发送结束信号
# 3. 检查进程状态
if not worker.is_alive():
logger.error("Worker died unexpectedly")
break
# 4. 先清空队列再 join
while not queue.empty():
results.append(queue.get_nowait())
worker.join(timeout=5.0)
```
### 4.2 PaddleOCR 特殊处理
```python
# PaddleOCR 必须在 worker 进程中初始化
def init_paddle_worker(gpu_id: int):
global _ocr
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# 延迟导入,确保 CUDA 环境变量生效
from paddleocr import PaddleOCR
_ocr = PaddleOCR(
use_angle_cls=True,
lang='en',
use_gpu=True,
show_log=False,
# 重要:设置 GPU 内存比例
gpu_mem=2000 # 限制 GPU 内存使用 (MB)
)
```
### 4.3 资源监控
```python
import psutil
import GPUtil
def get_resource_usage():
"""获取系统资源使用情况"""
cpu_percent = psutil.cpu_percent(interval=1)
memory = psutil.virtual_memory()
gpu_info = []
for gpu in GPUtil.getGPUs():
gpu_info.append({
"id": gpu.id,
"memory_used": gpu.memoryUsed,
"memory_total": gpu.memoryTotal,
"utilization": gpu.load * 100
})
return {
"cpu_percent": cpu_percent,
"memory_percent": memory.percent,
"gpu": gpu_info
}
```
---
## 5. 风险评估与应对
| 风险 | 可能性 | 影响 | 应对策略 |
|------|--------|------|----------|
| GPU 内存不足 | 中 | 高 | 限制 GPU worker = 1设置 gpu_mem 参数 |
| 进程僵死 | 低 | 高 | 添加心跳检测,超时自动重启 |
| 任务分类错误 | 中 | 中 | 添加回退机制CPU 失败后尝试 GPU |
| 数据库写入瓶颈 | 低 | 中 | 增大批处理大小,异步写入 |
---
## 6. 备选方案
如果上述方案仍存在问题,可以考虑:
### 6.1 使用 Ray
```python
import ray
ray.init()
@ray.remote(num_cpus=1)
def cpu_task(data):
return process_text_pdf(data)
@ray.remote(num_gpus=1)
def gpu_task(data):
return process_scanned_pdf(data)
# 自动资源调度
futures = [cpu_task.remote(d) for d in cpu_docs]
futures += [gpu_task.remote(d) for d in gpu_docs]
results = ray.get(futures)
```
### 6.2 单池 + 动态 GPU 调度
保持单池模式,但在每个任务内部动态决定是否使用 GPU
```python
def process_document(doc_data):
if is_scanned_pdf(doc_data):
# 使用 GPU (需要全局锁或信号量控制并发)
with gpu_semaphore:
return process_with_ocr(doc_data)
else:
return process_text_only(doc_data)
```
---
## 7. 时间线总结
| 阶段 | 任务 | 预计工作量 |
|------|------|------------|
| 阶段 1 | 基础架构重构 | 2-3 天 |
| 阶段 2 | 双池协调器实现 | 2-3 天 |
| 阶段 3 | 集成到 autolabel | 1-2 天 |
| 阶段 4 | 测试与验证 | 1-2 天 |
| **总计** | | **6-10 天** |
---
## 8. 参考资料
1. [Python concurrent.futures 官方文档](https://docs.python.org/3/library/concurrent.futures.html)
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
3. [Super Fast Python - ProcessPoolExecutor 完整指南](https://superfastpython.com/processpoolexecutor-in-python/)
4. [PaddleOCR 并行推理文档](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
5. [AWS - 跨 CPU/GPU 并行化 ML 推理](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
6. [Ray 分布式多进程处理](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)

View File

@@ -0,0 +1,35 @@
# Product Plan v2 - Change Log
## [v2.1] - 2026-02-01
### New Features
#### Epic 7: Dashboard Enhancement
- Added **US-7.1**: Data quality metrics panel showing annotation completeness rate
- Added **US-7.2**: Active model status panel with mAP/precision/recall metrics
- Added **US-7.3**: Recent activity feed showing last 10 system activities
- Added **US-7.4**: Meaningful stats cards (Total/Complete/Incomplete/Pending)
#### Annotation Completeness Definition
- Defined "annotation complete" criteria:
- Must have `invoice_number` OR `ocr_number` (identifier)
- Must have `bankgiro` OR `plusgiro` (payment account)
### New API Endpoints
- Added `GET /api/v1/admin/dashboard/stats` - Dashboard statistics with completeness calculation
- Added `GET /api/v1/admin/dashboard/active-model` - Active model info with running training status
- Added `GET /api/v1/admin/dashboard/activity` - Recent activity feed aggregated from multiple sources
### New UI Components
- Added **5.0 Dashboard Overview** wireframe with:
- Stats cards row (Total/Complete/Incomplete/Pending)
- Data Quality panel with percentage ring
- Active Model panel with metrics display
- Recent Activity list with icons and relative timestamps
- System Status bar
---
## [v2.0] - 2024-01-15
- Initial version with Epic 1-6
- Batch upload, document management, annotation workflow, training management

1448
docs/product-plan-v2.md Normal file

File diff suppressed because it is too large Load Diff

54
docs/training-flow.mmd Normal file
View File

@@ -0,0 +1,54 @@
flowchart TD
A[CLI Entry Point\nsrc/cli/train.py] --> B[Parse Arguments\n--model, --epochs, --batch, --imgsz, etc.]
B --> C[Connect PostgreSQL\nDB_HOST / DB_NAME / DB_PASSWORD]
C --> D[Load Data from DB\nsrc/yolo/db_dataset.py]
D --> D1[Scan temp/doc_id/images/\nfor rendered PNGs]
D --> D2[Batch load field_results\nfrom database - batch 500]
D1 --> E[Create DBYOLODataset]
D2 --> E
E --> F[Split Train/Val/Test\n80% / 10% / 10%\nDocument-level, seed=42]
F --> G[Export to YOLO Format]
G --> G1[Copy images to\ntrain/val/test dirs]
G --> G2[Generate .txt labels\nclass x_center y_center w h]
G --> G3[Generate dataset.yaml\n+ classes.txt]
G --> G4[Coordinate Conversion\nPDF points 72DPI -> render DPI\nNormalize to 0-1]
G1 --> H{--export-only?}
G2 --> H
G3 --> H
G4 --> H
H -- Yes --> Z[Done - Dataset exported]
H -- No --> I[Load YOLO Model]
I --> I1{--resume?}
I1 -- Yes --> I2[Load last.pt checkpoint]
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
I2 --> J[Configure Training]
I3 --> J
J --> J1[Conservative Augmentation\nrotation=5 deg, translate=5%\nno flip, no mosaic, no mixup]
J --> J2[imgsz=1280, pretrained=True]
J1 --> K[model.train\nUltralytics Training Loop]
J2 --> K
K --> L[Training Outputs\nruns/train/name/]
L --> L1[weights/best.pt\nweights/last.pt]
L --> L2[results.csv + results.png\nTraining curves]
L --> L3[PR curves, F1 curves\nConfusion matrix]
L1 --> M[Test Set Validation\nmodel.val split=test]
M --> N[Report Metrics\nmAP@0.5 = 93.5%\nmAP@0.5-0.95]
N --> O[Close DB Connection]
style A fill:#4a90d9,color:#fff
style K fill:#e67e22,color:#fff
style N fill:#27ae60,color:#fff
style Z fill:#95a5a6,color:#fff

302
docs/ux-design-prompt-v2.md Normal file
View File

@@ -0,0 +1,302 @@
# Document Annotation Tool UX Design Spec v2
## Theme: Warm Graphite (Modern Enterprise)
---
## 1. Design Principles (Updated)
1. **Clarity** High contrast, but never pure black-on-white
2. **Warm Neutrality** Slightly warm grays reduce visual fatigue
3. **Focus** Content-first layouts with restrained accents
4. **Consistency** Reusable patterns, predictable behavior
5. **Professional Trust** Calm, serious, enterprise-ready
6. **Longevity** No trendy colors that age quickly
---
## 2. Color Palette (Warm Graphite)
### Core Colors
| Usage | Color Name | Hex |
|------|-----------|-----|
| Primary Text | Soft Black | #121212 |
| Secondary Text | Charcoal Gray | #2A2A2A |
| Muted Text | Warm Gray | #6B6B6B |
| Disabled Text | Light Warm Gray | #9A9A9A |
### Backgrounds
| Usage | Color | Hex |
|-----|------|-----|
| App Background | Paper White | #FAFAF8 |
| Card / Panel | White | #FFFFFF |
| Hover Surface | Subtle Warm Gray | #F1F0ED |
| Selected Row | Very Light Warm Gray | #ECEAE6 |
### Borders & Dividers
| Usage | Color | Hex |
|------|------|-----|
| Default Border | Warm Light Gray | #E6E4E1 |
| Strong Divider | Neutral Gray | #D8D6D2 |
### Semantic States (Muted & Professional)
| State | Color | Hex |
|------|-------|-----|
| Success | Olive Gray | #3E4A3A |
| Error | Brick Gray | #4A3A3A |
| Warning | Sand Gray | #4A4A3A |
| Info | Graphite Gray | #3A3A3A |
> Accent colors are **never saturated** and are used only for status, progress, or selection.
---
## 3. Typography
- **Font Family**: Inter / SF Pro / system-ui
- **Headings**:
- Weight: 600700
- Color: #121212
- Letter spacing: -0.01em
- **Body Text**:
- Weight: 400
- Color: #2A2A2A
- **Captions / Meta**:
- Weight: 400
- Color: #6B6B6B
- **Monospace (IDs / Values)**:
- JetBrains Mono / SF Mono
- Color: #2A2A2A
---
## 4. Global Layout
### Top Navigation Bar
- Height: 56px
- Background: #FAFAF8
- Bottom Border: 1px solid #E6E4E1
- Logo: Text or icon in #121212
**Navigation Items**
- Default: #6B6B6B
- Hover: #2A2A2A
- Active:
- Text: #121212
- Bottom indicator: 2px solid #3A3A3A (rounded ends)
**Avatar**
- Circle background: #ECEAE6
- Text: #2A2A2A
---
## 5. Page: Documents (Dashboard)
### Page Header
- Title: "Documents" (#121212)
- Actions:
- Primary button: Dark graphite outline
- Secondary button: Subtle border only
### Filters Bar
- Background: #FFFFFF
- Border: 1px solid #E6E4E1
- Inputs:
- Background: #FFFFFF
- Hover: #F1F0ED
- Focus ring: 1px #3A3A3A
### Document Table
- Table background: #FFFFFF
- Header text: #6B6B6B
- Row hover: #F1F0ED
- Row selected:
- Background: #ECEAE6
- Left indicator: 3px solid #3A3A3A
### Status Badges
- Pending:
- BG: #FFFFFF
- Border: #D8D6D2
- Text: #2A2A2A
- Labeled:
- BG: #2A2A2A
- Text: #FFFFFF
- Exported:
- BG: #ECEAE6
- Text: #2A2A2A
- Icon: ✓
### Auto-label States
- Running:
- Progress bar: #3A3A3A on #ECEAE6
- Completed:
- Text: #3E4A3A
- Failed:
- BG: #F1EDED
- Text: #4A3A3A
---
## 6. Upload Modals (Single & Batch)
### Modal Container
- Background: #FFFFFF
- Border radius: 8px
- Shadow: 0 1px 3px rgba(0,0,0,0.08)
### Drop Zone
- Background: #FAFAF8
- Border: 1px dashed #D8D6D2
- Hover: #F1F0ED
- Icon: Graphite gray
### Form Fields
- Input BG: #FFFFFF
- Border: #D8D6D2
- Focus: 1px solid #3A3A3A
Primary Action Button:
- Text: #FFFFFF
- BG: #2A2A2A
- Hover: #121212
---
## 7. Document Detail View
### Canvas Area
- Background: #FFFFFF
- Annotation styles:
- Manual: Solid border #2A2A2A
- Auto: Dashed border #6B6B6B
- Selected: 2px border #3A3A3A + resize handles
### Right Info Panel
- Card background: #FFFFFF
- Section headers: #121212
- Meta text: #6B6B6B
### Annotation Table
- Same table styles as Documents
- Inline edit:
- Input background: #FAFAF8
- Save button: Graphite
### Locked State (Auto-label Running)
- Banner BG: #FAFAF8
- Border-left: 3px solid #4A4A3A
- Progress bar: Graphite
---
## 8. Training Page
### Document Selector
- Selected rows use same highlight rules
- Verified state:
- Full: Olive gray check
- Partial: Sand gray warning
### Configuration Panel
- Card layout
- Inputs aligned to grid
- Schedule option visually muted until enabled
Primary CTA:
- Start Training button in dark graphite
---
## 9. Models & Training History
### Training Job List
- Job cards use #FFFFFF background
- Running job:
- Progress bar: #3A3A3A
- Completed job:
- Metrics bars in graphite
### Model Detail Panel
- Sectioned cards
- Metric bars:
- Track: #ECEAE6
- Fill: #3A3A3A
Actions:
- Primary: Download Model
- Secondary: View Logs / Use as Base
---
## 10. Micro-interactions (Refined)
| Element | Interaction | Animation |
|------|------------|-----------|
| Button hover | BG lightens | 150ms ease-out |
| Button press | Scale 0.98 | 100ms |
| Row hover | BG fade | 120ms |
| Modal open | Fade + scale 0.96 → 1 | 200ms |
| Progress fill | Smooth | ease-out |
| Annotation select | Border + handles | 120ms |
---
## 11. Tailwind Theme (Updated)
```js
colors: {
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
bg: {
app: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
},
border: '#E6E4E1',
accent: '#3A3A3A',
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
}
```
---
## 12. Final Notes
- Pure black (#000000) should **never** be used as large surfaces
- Accent color usage should stay under **10% of UI area**
- Warm grays are intentional and must not be "corrected" to blue-grays
This theme is designed to scale from internal tool → polished SaaS without redesign.

View File

@@ -0,0 +1,273 @@
# Web Directory Refactoring - Complete ✅
**Date**: 2026-01-25
**Status**: ✅ Completed
**Tests**: 188 passing (0 failures)
**Coverage**: 23% (maintained)
---
## Final Directory Structure
```
src/web/
├── api/
│ ├── __init__.py
│ └── v1/
│ ├── __init__.py
│ ├── routes.py # Public inference API
│ ├── admin/
│ │ ├── __init__.py
│ │ ├── documents.py # Document management (was admin_routes.py)
│ │ ├── annotations.py # Annotation routes (was admin_annotation_routes.py)
│ │ └── training.py # Training routes (was admin_training_routes.py)
│ ├── async_api/
│ │ ├── __init__.py
│ │ └── routes.py # Async processing API (was async_routes.py)
│ └── batch/
│ ├── __init__.py
│ └── routes.py # Batch upload API (was batch_upload_routes.py)
├── schemas/
│ ├── __init__.py
│ ├── common.py # Shared models (ErrorResponse)
│ ├── admin.py # Admin schemas (was admin_schemas.py)
│ └── inference.py # Inference + async schemas (was schemas.py)
├── services/
│ ├── __init__.py
│ ├── inference.py # Inference service (was services.py)
│ ├── autolabel.py # Auto-label service (was admin_autolabel.py)
│ ├── async_processing.py # Async processing (was async_service.py)
│ └── batch_upload.py # Batch upload service (was batch_upload_service.py)
├── core/
│ ├── __init__.py
│ ├── auth.py # Authentication (was admin_auth.py)
│ ├── rate_limiter.py # Rate limiting (unchanged)
│ └── scheduler.py # Task scheduler (was admin_scheduler.py)
├── workers/
│ ├── __init__.py
│ ├── async_queue.py # Async task queue (was async_queue.py)
│ └── batch_queue.py # Batch task queue (was batch_queue.py)
├── __init__.py # Main exports
├── app.py # FastAPI app (imports updated)
├── config.py # Configuration (unchanged)
└── dependencies.py # Global dependencies (unchanged)
```
---
## Changes Summary
### Files Moved and Renamed
| Old Location | New Location | Change Type |
|-------------|--------------|-------------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Moved + Renamed |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Moved + Renamed |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Moved + Renamed |
| `admin_auth.py` | `core/auth.py` | Moved |
| `admin_autolabel.py` | `services/autolabel.py` | Moved |
| `admin_scheduler.py` | `core/scheduler.py` | Moved |
| `admin_schemas.py` | `schemas/admin.py` | Moved |
| `routes.py` | `api/v1/routes.py` | Moved |
| `schemas.py` | `schemas/inference.py` | Moved |
| `services.py` | `services/inference.py` | Moved |
| `async_routes.py` | `api/v1/async_api/routes.py` | Moved |
| `async_queue.py` | `workers/async_queue.py` | Moved |
| `async_service.py` | `services/async_processing.py` | Moved + Renamed |
| `batch_queue.py` | `workers/batch_queue.py` | Moved |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Moved |
| `batch_upload_service.py` | `services/batch_upload.py` | Moved |
**Total**: 16 files reorganized
### Files Updated
**Source Files** (imports updated):
- `app.py` - Updated all imports to new structure
- `api/v1/admin/documents.py` - Updated schema/auth imports
- `api/v1/admin/annotations.py` - Updated schema/service imports
- `api/v1/admin/training.py` - Updated schema/auth imports
- `api/v1/routes.py` - Updated schema imports
- `api/v1/async_api/routes.py` - Updated schema imports
- `api/v1/batch/routes.py` - Updated service/worker imports
- `services/async_processing.py` - Updated worker/core imports
**Test Files** (all 15 updated):
- `test_admin_annotations.py`
- `test_admin_auth.py`
- `test_admin_routes.py`
- `test_admin_routes_enhanced.py`
- `test_admin_training.py`
- `test_annotation_locks.py`
- `test_annotation_phase5.py`
- `test_async_queue.py`
- `test_async_routes.py`
- `test_async_service.py`
- `test_autolabel_with_locks.py`
- `test_batch_queue.py`
- `test_batch_upload_routes.py`
- `test_batch_upload_service.py`
- `test_training_phase4.py`
- `conftest.py`
---
## Import Examples
### Old Import Style (Before Refactoring)
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
from src.web.async_routes import create_async_router
from src.web.schemas import ErrorResponse
```
### New Import Style (After Refactoring)
```python
# Admin API
from src.web.api.v1.admin.documents import create_admin_router
from src.web.api.v1.admin import create_admin_router # Shorter alternative
# Schemas
from src.web.schemas.admin import DocumentItem
from src.web.schemas.common import ErrorResponse
# Core components
from src.web.core.auth import validate_admin_token
# Async API
from src.web.api.v1.async_api.routes import create_async_router
```
---
## Benefits Achieved
### 1. **Clear Separation of Concerns**
- **API Routes**: All in `api/v1/` by version and feature
- **Data Models**: All in `schemas/` by domain
- **Business Logic**: All in `services/`
- **Core Components**: Reusable utilities in `core/`
- **Background Jobs**: Task queues in `workers/`
### 2. **Better Scalability**
- Easy to add API v2 without touching v1
- Clear namespace for each module
- Reduced file sizes (no 800+ line files)
- Follows single responsibility principle
### 3. **Improved Maintainability**
- Find files by function, not by prefix
- Each module has one clear purpose
- Easier to onboard new developers
- Better IDE navigation
### 4. **Standards Compliance**
- Follows FastAPI best practices
- Matches Django/Flask project structures
- Standard Python package organization
- Industry-standard naming conventions
---
## Testing Results
**Before Refactoring**:
- 188 tests passing
- 23% code coverage
- Flat directory structure
**After Refactoring**:
- ✅ 188 tests passing (0 failures)
- ✅ 23% code coverage (maintained)
- ✅ Clean hierarchical structure
- ✅ All imports updated
- ✅ No backward compatibility shims needed
---
## Migration Statistics
| Metric | Count |
|--------|-------|
| Files moved | 16 |
| Directories created | 9 |
| Files updated (source) | 8 |
| Files updated (tests) | 16 |
| Import statements updated | ~150 |
| Lines of code changed | ~200 |
| Tests broken | 0 |
| Coverage lost | 0% |
---
## Code Diff Summary
```diff
Before:
src/web/
├── admin_routes.py (645 lines)
├── admin_annotation_routes.py (504 lines)
├── admin_training_routes.py (565 lines)
├── admin_auth.py (22 lines)
├── admin_schemas.py (262 lines)
... (15 more files at root level)
After:
src/web/
├── api/v1/
│ ├── admin/ (3 route files)
│ ├── async_api/ (1 route file)
│ └── batch/ (1 route file)
├── schemas/ (3 schema files)
├── services/ (4 service files)
├── core/ (3 core files)
└── workers/ (2 worker files)
```
---
## Next Steps (Optional)
### Phase 2: Documentation
- [ ] Update API documentation with new import paths
- [ ] Create migration guide for external developers
- [ ] Update CLAUDE.md with new structure
### Phase 3: Further Optimization
- [ ] Split large files (>400 lines) if needed
- [ ] Extract common utilities
- [ ] Add typing stubs
### Phase 4: Deprecation (Future)
- [ ] Add deprecation warnings if creating compatibility layer
- [ ] Remove old imports after grace period
- [ ] Update all documentation
---
## Rollback Instructions
If needed, rollback is simple:
```bash
git revert <commit-hash>
```
All changes are in version control, making rollback safe and easy.
---
## Conclusion
**Refactoring completed successfully**
**Zero breaking changes**
**All tests passing**
**Industry-standard structure achieved**
The web directory is now organized following Python and FastAPI best practices, making it easier to scale, maintain, and extend.

View File

@@ -0,0 +1,186 @@
# Web Directory Refactoring Plan
## Current Structure Issues
1. **Flat structure**: All files in one directory (20 Python files)
2. **Naming inconsistency**: Mix of `admin_*`, `async_*`, `batch_*` prefixes
3. **Mixed concerns**: Routes, schemas, services, and workers in same directory
4. **Poor scalability**: Hard to navigate and maintain as project grows
## Proposed Structure (Best Practices)
```
src/web/
├── __init__.py # Main exports
├── app.py # FastAPI app factory
├── config.py # App configuration
├── dependencies.py # Global dependencies
├── api/ # API Routes Layer
│ ├── __init__.py
│ └── v1/ # API version 1
│ ├── __init__.py
│ ├── routes.py # Public API routes (inference)
│ ├── admin/ # Admin API routes
│ │ ├── __init__.py
│ │ ├── documents.py # admin_routes.py → documents.py
│ │ ├── annotations.py # admin_annotation_routes.py → annotations.py
│ │ ├── training.py # admin_training_routes.py → training.py
│ │ └── auth.py # admin_auth.py → auth.py (routes only)
│ ├── async_api/ # Async processing API
│ │ ├── __init__.py
│ │ └── routes.py # async_routes.py → routes.py
│ └── batch/ # Batch upload API
│ ├── __init__.py
│ └── routes.py # batch_upload_routes.py → routes.py
├── schemas/ # Pydantic Models
│ ├── __init__.py
│ ├── common.py # Shared schemas (ErrorResponse, etc.)
│ ├── inference.py # schemas.py → inference.py
│ ├── admin.py # admin_schemas.py → admin.py
│ ├── async_api.py # New: async API schemas
│ └── batch.py # New: batch upload schemas
├── services/ # Business Logic Layer
│ ├── __init__.py
│ ├── inference.py # services.py → inference.py
│ ├── autolabel.py # admin_autolabel.py → autolabel.py
│ ├── async_processing.py # async_service.py → async_processing.py
│ └── batch_upload.py # batch_upload_service.py → batch_upload.py
├── core/ # Core Components
│ ├── __init__.py
│ ├── auth.py # admin_auth.py → auth.py (logic only)
│ ├── rate_limiter.py # rate_limiter.py → rate_limiter.py
│ └── scheduler.py # admin_scheduler.py → scheduler.py
└── workers/ # Background Task Queues
├── __init__.py
├── async_queue.py # async_queue.py → async_queue.py
└── batch_queue.py # batch_queue.py → batch_queue.py
```
## File Mapping
### Current → New Location
| Current File | New Location | Purpose |
|--------------|--------------|---------|
| `admin_routes.py` | `api/v1/admin/documents.py` | Document management routes |
| `admin_annotation_routes.py` | `api/v1/admin/annotations.py` | Annotation routes |
| `admin_training_routes.py` | `api/v1/admin/training.py` | Training routes |
| `admin_auth.py` | Split: `api/v1/admin/auth.py` + `core/auth.py` | Auth routes + logic |
| `admin_schemas.py` | `schemas/admin.py` | Admin Pydantic models |
| `admin_autolabel.py` | `services/autolabel.py` | Auto-label service |
| `admin_scheduler.py` | `core/scheduler.py` | Training scheduler |
| `routes.py` | `api/v1/routes.py` | Public inference API |
| `schemas.py` | `schemas/inference.py` | Inference models |
| `services.py` | `services/inference.py` | Inference service |
| `async_routes.py` | `api/v1/async_api/routes.py` | Async API routes |
| `async_service.py` | `services/async_processing.py` | Async processing service |
| `async_queue.py` | `workers/async_queue.py` | Async task queue |
| `batch_upload_routes.py` | `api/v1/batch/routes.py` | Batch upload routes |
| `batch_upload_service.py` | `services/batch_upload.py` | Batch upload service |
| `batch_queue.py` | `workers/batch_queue.py` | Batch task queue |
| `rate_limiter.py` | `core/rate_limiter.py` | Rate limiting logic |
| `config.py` | `config.py` | Keep as-is |
| `dependencies.py` | `dependencies.py` | Keep as-is |
| `app.py` | `app.py` | Keep as-is (update imports) |
## Benefits
### 1. Clear Separation of Concerns
- **Routes**: API endpoint definitions
- **Schemas**: Data validation models
- **Services**: Business logic
- **Core**: Reusable components
- **Workers**: Background processing
### 2. Better Scalability
- Easy to add new API versions (`v2/`)
- Clear namespace for each domain
- Reduced file size (no 800+ line files)
### 3. Improved Maintainability
- Find files by function, not by prefix
- Each module has single responsibility
- Easier to write focused tests
### 4. Standard Python Patterns
- Package-based organization
- Follows FastAPI best practices
- Similar to Django/Flask structures
## Implementation Steps
### Phase 1: Create New Structure (No Breaking Changes)
1. Create new directories: `api/`, `schemas/`, `services/`, `core/`, `workers/`
2. Copy files to new locations (don't delete originals yet)
3. Update imports in new files
4. Add `__init__.py` with proper exports
### Phase 2: Update Tests
5. Update test imports to use new structure
6. Run tests to verify nothing breaks
7. Fix any import issues
### Phase 3: Update Main App
8. Update `app.py` to import from new locations
9. Run full test suite
10. Verify all endpoints work
### Phase 4: Cleanup
11. Delete old files
12. Update documentation
13. Final test run
## Migration Priority
**High Priority** (Most used):
- Routes and schemas (user-facing APIs)
- Services (core business logic)
**Medium Priority**:
- Core components (auth, rate limiter)
- Workers (background tasks)
**Low Priority**:
- Config and dependencies (already well-located)
## Backwards Compatibility
During migration, maintain backwards compatibility:
```python
# src/web/__init__.py
# Old imports still work
from src.web.api.v1.admin.documents import router as admin_router
from src.web.schemas.admin import AdminDocument
# Keep old names for compatibility (temporary)
admin_routes = admin_router # Deprecated alias
```
## Testing Strategy
1. **Unit Tests**: Test each module independently
2. **Integration Tests**: Test API endpoints still work
3. **Import Tests**: Verify all old imports still work
4. **Coverage**: Maintain current 23% coverage minimum
## Rollback Plan
If issues arise:
1. Keep old files until fully migrated
2. Git allows easy revert
3. Tests catch breaking changes early
---
## Next Steps
Would you like me to:
1. **Start Phase 1**: Create new directory structure and move files?
2. **Create migration script**: Automate the file moves and import updates?
3. **Focus on specific area**: Start with admin API or async API first?

View File

@@ -0,0 +1,218 @@
# Web Directory Refactoring - Current Status
## ✅ Completed Steps
### 1. Directory Structure Created
```
src/web/
├── api/
│ ├── v1/
│ │ ├── admin/ (documents.py, annotations.py, training.py)
│ │ ├── async_api/ (routes.py)
│ │ ├── batch/ (routes.py)
│ │ └── routes.py (public inference API)
├── schemas/
│ ├── admin.py (admin schemas)
│ ├── inference.py (inference + async schemas)
│ └── common.py (ErrorResponse)
├── services/
│ ├── autolabel.py
│ ├── async_processing.py
│ ├── batch_upload.py
│ └── inference.py
├── core/
│ ├── auth.py
│ ├── rate_limiter.py
│ └── scheduler.py
└── workers/
├── async_queue.py
└── batch_queue.py
```
### 2. Files Copied and Imports Updated
#### Admin API (✅ Complete)
- [x] `admin_routes.py``api/v1/admin/documents.py` (imports updated)
- [x] `admin_annotation_routes.py``api/v1/admin/annotations.py` (imports updated)
- [x] `admin_training_routes.py``api/v1/admin/training.py` (imports updated)
- [x] `api/v1/admin/__init__.py` created with exports
#### Public & Async API (✅ Complete)
- [x] `routes.py``api/v1/routes.py` (imports updated)
- [x] `async_routes.py``api/v1/async_api/routes.py` (imports updated)
- [x] `batch_upload_routes.py``api/v1/batch/routes.py` (copied, imports pending)
#### Schemas (✅ Complete)
- [x] `admin_schemas.py``schemas/admin.py`
- [x] `schemas.py``schemas/inference.py`
- [x] `schemas/common.py` created
- [x] `schemas/__init__.py` created with exports
#### Services (✅ Complete)
- [x] `admin_autolabel.py``services/autolabel.py`
- [x] `async_service.py``services/async_processing.py`
- [x] `batch_upload_service.py``services/batch_upload.py`
- [x] `services.py``services/inference.py`
- [x] `services/__init__.py` created
#### Core Components (✅ Complete)
- [x] `admin_auth.py``core/auth.py`
- [x] `rate_limiter.py``core/rate_limiter.py`
- [x] `admin_scheduler.py``core/scheduler.py`
- [x] `core/__init__.py` created
#### Workers (✅ Complete)
- [x] `async_queue.py``workers/async_queue.py`
- [x] `batch_queue.py``workers/batch_queue.py`
- [x] `workers/__init__.py` created
#### Main App (✅ Complete)
- [x] `app.py` imports updated to use new structure
---
## ⏳ Remaining Work
### 1. Update Remaining File Imports (HIGH PRIORITY)
Files that need import updates:
- [ ] `api/v1/batch/routes.py` - update to use new schema/service imports
- [ ] `services/autolabel.py` - may need import updates if it references old paths
- [ ] `services/async_processing.py` - check for old import references
- [ ] `services/batch_upload.py` - check for old import references
- [ ] `services/inference.py` - check for old import references
### 2. Update ALL Test Files (CRITICAL)
Test files need to import from new locations. Pattern:
**Old:**
```python
from src.web.admin_routes import create_admin_router
from src.web.admin_schemas import DocumentItem
from src.web.admin_auth import validate_admin_token
```
**New:**
```python
from src.web.api.v1.admin import create_admin_router
from src.web.schemas.admin import DocumentItem
from src.web.core.auth import validate_admin_token
```
Test files to update:
- [ ] `tests/web/test_admin_annotations.py`
- [ ] `tests/web/test_admin_auth.py`
- [ ] `tests/web/test_admin_routes.py`
- [ ] `tests/web/test_admin_routes_enhanced.py`
- [ ] `tests/web/test_admin_training.py`
- [ ] `tests/web/test_annotation_locks.py`
- [ ] `tests/web/test_annotation_phase5.py`
- [ ] `tests/web/test_async_queue.py`
- [ ] `tests/web/test_async_routes.py`
- [ ] `tests/web/test_async_service.py`
- [ ] `tests/web/test_autolabel_with_locks.py`
- [ ] `tests/web/test_batch_queue.py`
- [ ] `tests/web/test_batch_upload_routes.py`
- [ ] `tests/web/test_batch_upload_service.py`
- [ ] `tests/web/test_rate_limiter.py`
- [ ] `tests/web/test_training_phase4.py`
### 3. Create Backward Compatibility Layer (OPTIONAL)
Keep old imports working temporarily:
```python
# src/web/admin_routes.py (temporary compatibility shim)
\"\"\"
DEPRECATED: Use src.web.api.v1.admin.documents instead.
This file will be removed in next version.
\"\"\"
import warnings
from src.web.api.v1.admin.documents import *
warnings.warn(
"Importing from src.web.admin_routes is deprecated. "
"Use src.web.api.v1.admin.documents instead.",
DeprecationWarning,
stacklevel=2
)
```
### 4. Verify and Test
1. Run tests:
```bash
pytest tests/web/ -v
```
2. Check for any import errors:
```bash
python -c "from src.web.app import create_app; create_app()"
```
3. Start server and test endpoints:
```bash
python run_server.py
```
### 5. Clean Up Old Files (ONLY AFTER TESTS PASS)
Old files to remove:
- `src/web/admin_*.py` (7 files)
- `src/web/async_*.py` (3 files)
- `src/web/batch_*.py` (3 files)
- `src/web/routes.py`
- `src/web/services.py`
- `src/web/schemas.py`
- `src/web/rate_limiter.py`
Keep these files (don't remove):
- `src/web/__init__.py`
- `src/web/app.py`
- `src/web/config.py`
- `src/web/dependencies.py`
---
## 🎯 Next Immediate Steps
1. **Update batch/routes.py imports** - Quick fix for remaining API route
2. **Update test file imports** - Critical for verification
3. **Run test suite** - Verify nothing broke
4. **Fix any import errors** - Address failures
5. **Remove old files** - Clean up after tests pass
---
## 📊 Migration Impact Summary
| Category | Files Moved | Imports Updated | Status |
|----------|-------------|-----------------|--------|
| API Routes | 7 | 5/7 | 🟡 In Progress |
| Schemas | 3 | 3/3 | ✅ Complete |
| Services | 4 | 0/4 | ⚠️ Pending |
| Core | 3 | 3/3 | ✅ Complete |
| Workers | 2 | 2/2 | ✅ Complete |
| Tests | 0 | 0/16 | ❌ Not Started |
**Overall Progress: 65%**
---
## 🚀 Benefits After Migration
1. **Better Organization**: Clear separation by function
2. **Easier Navigation**: Find files by purpose, not prefix
3. **Scalability**: Easy to add new API versions
4. **Standard Structure**: Follows FastAPI best practices
5. **Maintainability**: Each module has single responsibility
---
## 📝 Notes
- All original files are still in place (no data loss risk)
- New structure is operational but needs import updates
- Backward compatibility can be added if needed
- Tests will validate the migration success

5
frontend/.env.example Normal file
View File

@@ -0,0 +1,5 @@
# Backend API URL
VITE_API_URL=http://localhost:8000
# WebSocket URL (for future real-time updates)
VITE_WS_URL=ws://localhost:8000/ws

24
frontend/.gitignore vendored Normal file
View File

@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

20
frontend/README.md Normal file
View File

@@ -0,0 +1,20 @@
<div align="center">
<img width="1200" height="475" alt="GHBanner" src="https://github.com/user-attachments/assets/0aa67016-6eaf-458a-adb2-6e31a0763ed6" />
</div>
# Run and deploy your AI Studio app
This contains everything you need to run your app locally.
View your app in AI Studio: https://ai.studio/apps/drive/13hqd80ft4g_LngMYB8LLJxx2XU8C_eI4
## Run Locally
**Prerequisites:** Node.js
1. Install dependencies:
`npm install`
2. Set the `GEMINI_API_KEY` in [.env.local](.env.local) to your Gemini API key
3. Run the app:
`npm run dev`

View File

@@ -0,0 +1,240 @@
# Frontend Refactoring Plan
## Current Structure Issues
1. **Flat component organization** - All components in one directory
2. **Mock data only** - No real API integration
3. **No state management** - Props drilling everywhere
4. **CDN dependencies** - Should use npm packages
5. **Manual routing** - Using useState instead of react-router
6. **No TypeScript integration with backend** - Types don't match API schemas
## Recommended Structure
```
frontend/
├── public/
│ └── favicon.ico
├── src/
│ ├── api/ # API Layer
│ │ ├── client.ts # Axios instance + interceptors
│ │ ├── types.ts # API request/response types
│ │ └── endpoints/
│ │ ├── documents.ts # GET /api/v1/admin/documents
│ │ ├── annotations.ts # GET/POST /api/v1/admin/documents/{id}/annotations
│ │ ├── training.ts # GET/POST /api/v1/admin/training/*
│ │ ├── inference.ts # POST /api/v1/infer
│ │ └── async.ts # POST /api/v1/async/submit
│ │
│ ├── components/
│ │ ├── common/ # Reusable components
│ │ │ ├── Badge.tsx
│ │ │ ├── Button.tsx
│ │ │ ├── Input.tsx
│ │ │ ├── Modal.tsx
│ │ │ ├── Table.tsx
│ │ │ ├── ProgressBar.tsx
│ │ │ └── StatusBadge.tsx
│ │ │
│ │ ├── layout/ # Layout components
│ │ │ ├── TopNav.tsx
│ │ │ ├── Sidebar.tsx
│ │ │ └── PageHeader.tsx
│ │ │
│ │ ├── documents/ # Document-specific components
│ │ │ ├── DocumentTable.tsx
│ │ │ ├── DocumentFilters.tsx
│ │ │ ├── DocumentRow.tsx
│ │ │ ├── UploadModal.tsx
│ │ │ └── BatchUploadModal.tsx
│ │ │
│ │ ├── annotations/ # Annotation components
│ │ │ ├── AnnotationCanvas.tsx
│ │ │ ├── AnnotationBox.tsx
│ │ │ ├── AnnotationTable.tsx
│ │ │ ├── FieldEditor.tsx
│ │ │ └── VerificationPanel.tsx
│ │ │
│ │ └── training/ # Training components
│ │ ├── DocumentSelector.tsx
│ │ ├── TrainingConfig.tsx
│ │ ├── TrainingJobList.tsx
│ │ ├── ModelCard.tsx
│ │ └── MetricsChart.tsx
│ │
│ ├── pages/ # Page-level components
│ │ ├── DocumentsPage.tsx # Was Dashboard.tsx
│ │ ├── DocumentDetailPage.tsx # Was DocumentDetail.tsx
│ │ ├── TrainingPage.tsx # Was Training.tsx
│ │ ├── ModelsPage.tsx # Was Models.tsx
│ │ └── InferencePage.tsx # New: Test inference
│ │
│ ├── hooks/ # Custom React Hooks
│ │ ├── useDocuments.ts # Document CRUD + listing
│ │ ├── useAnnotations.ts # Annotation management
│ │ ├── useTraining.ts # Training jobs
│ │ ├── usePolling.ts # Auto-refresh for async jobs
│ │ └── useDebounce.ts # Debounce search inputs
│ │
│ ├── store/ # State Management (Zustand)
│ │ ├── documentsStore.ts
│ │ ├── annotationsStore.ts
│ │ ├── trainingStore.ts
│ │ └── uiStore.ts
│ │
│ ├── types/ # TypeScript Types
│ │ ├── index.ts
│ │ ├── document.ts
│ │ ├── annotation.ts
│ │ ├── training.ts
│ │ └── api.ts
│ │
│ ├── utils/ # Utility Functions
│ │ ├── formatters.ts # Date, currency, etc.
│ │ ├── validators.ts # Form validation
│ │ └── constants.ts # Field definitions, statuses
│ │
│ ├── styles/
│ │ └── index.css # Tailwind entry
│ │
│ ├── App.tsx
│ ├── main.tsx
│ └── router.tsx # React Router config
├── .env.example
├── package.json
├── tsconfig.json
├── vite.config.ts
├── tailwind.config.js
├── postcss.config.js
└── index.html
```
## Migration Steps
### Phase 1: Setup Infrastructure
- [ ] Install dependencies (axios, react-router, zustand, @tanstack/react-query)
- [ ] Setup local Tailwind (remove CDN)
- [ ] Create API client with interceptors
- [ ] Add environment variables (.env.local with VITE_API_URL)
### Phase 2: Create API Layer
- [ ] Create `src/api/client.ts` with axios instance
- [ ] Create `src/api/endpoints/documents.ts` matching backend API
- [ ] Create `src/api/endpoints/annotations.ts`
- [ ] Create `src/api/endpoints/training.ts`
- [ ] Add types matching backend schemas
### Phase 3: Reorganize Components
- [ ] Move existing components to new structure
- [ ] Split large components (Dashboard > DocumentTable + DocumentFilters + DocumentRow)
- [ ] Extract reusable components (Badge, Button already done)
- [ ] Create layout components (TopNav, Sidebar)
### Phase 4: Add Routing
- [ ] Install react-router-dom
- [ ] Create router.tsx with routes
- [ ] Update App.tsx to use RouterProvider
- [ ] Add navigation links
### Phase 5: State Management
- [ ] Create custom hooks (useDocuments, useAnnotations)
- [ ] Use @tanstack/react-query for server state
- [ ] Add Zustand stores for UI state
- [ ] Replace mock data with API calls
### Phase 6: Backend Integration
- [ ] Update CORS settings in backend
- [ ] Test all API endpoints
- [ ] Add error handling
- [ ] Add loading states
## Dependencies to Add
```json
{
"dependencies": {
"react-router-dom": "^6.22.0",
"axios": "^1.6.7",
"zustand": "^4.5.0",
"@tanstack/react-query": "^5.20.0",
"date-fns": "^3.3.0",
"clsx": "^2.1.0"
},
"devDependencies": {
"tailwindcss": "^3.4.1",
"autoprefixer": "^10.4.17",
"postcss": "^8.4.35"
}
}
```
## Configuration Files to Create
### tailwind.config.js
```javascript
export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
theme: {
extend: {
colors: {
warm: {
bg: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
border: '#E6E4E1',
divider: '#D8D6D2',
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
state: {
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
info: '#3A3A3A',
}
}
}
}
}
}
```
### .env.example
```bash
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```
## Type Generation from Backend
Consider generating TypeScript types from Python Pydantic schemas:
- Option 1: Use `datamodel-code-generator` to convert schemas
- Option 2: Manually maintain types in `src/types/api.ts`
- Option 3: Use OpenAPI spec + openapi-typescript-codegen
## Testing Strategy
- Unit tests: Vitest for components
- Integration tests: React Testing Library
- E2E tests: Playwright (matching backend)
## Performance Considerations
- Code splitting by route
- Lazy load heavy components (AnnotationCanvas)
- Optimize re-renders with React.memo
- Use virtual scrolling for large tables
- Image lazy loading for document previews
## Accessibility
- Proper ARIA labels
- Keyboard navigation
- Focus management
- Color contrast compliance (already done with Warm Graphite theme)

256
frontend/SETUP.md Normal file
View File

@@ -0,0 +1,256 @@
# Frontend Setup Guide
## Quick Start
### 1. Install Dependencies
```bash
cd frontend
npm install
```
### 2. Configure Environment
Copy `.env.example` to `.env.local` and update if needed:
```bash
cp .env.example .env.local
```
Default configuration:
```
VITE_API_URL=http://localhost:8000
VITE_WS_URL=ws://localhost:8000/ws
```
### 3. Start Backend API
Make sure the backend is running first:
```bash
# From project root
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python run_server.py"
```
Backend will be available at: http://localhost:8000
### 4. Start Frontend Dev Server
```bash
cd frontend
npm run dev
```
Frontend will be available at: http://localhost:3000
## Project Structure
```
frontend/
├── src/
│ ├── api/ # API client layer
│ │ ├── client.ts # Axios instance with interceptors
│ │ ├── types.ts # API type definitions
│ │ └── endpoints/
│ │ ├── documents.ts # Document API calls
│ │ ├── annotations.ts # Annotation API calls
│ │ └── training.ts # Training API calls
│ │
│ ├── components/ # React components
│ │ └── Dashboard.tsx # Updated with real API integration
│ │
│ ├── hooks/ # Custom React Hooks
│ │ ├── useDocuments.ts
│ │ ├── useDocumentDetail.ts
│ │ ├── useAnnotations.ts
│ │ └── useTraining.ts
│ │
│ ├── styles/
│ │ └── index.css # Tailwind CSS entry
│ │
│ ├── App.tsx
│ └── main.tsx # App entry point with QueryClient
├── components/ # Legacy components (to be migrated)
│ ├── Badge.tsx
│ ├── Button.tsx
│ ├── Layout.tsx
│ ├── DocumentDetail.tsx
│ ├── Training.tsx
│ ├── Models.tsx
│ └── UploadModal.tsx
├── tailwind.config.js # Tailwind configuration
├── postcss.config.js
├── vite.config.ts
├── package.json
└── index.html
```
## Key Technologies
- **React 19** - UI framework
- **TypeScript** - Type safety
- **Vite** - Build tool
- **Tailwind CSS** - Styling (Warm Graphite theme)
- **Axios** - HTTP client
- **@tanstack/react-query** - Server state management
- **lucide-react** - Icon library
## API Integration
### Authentication
The app stores admin token in localStorage:
```typescript
localStorage.setItem('admin_token', 'your-token')
```
All API requests automatically include the `X-Admin-Token` header.
### Available Hooks
#### useDocuments
```typescript
const {
documents,
total,
isLoading,
uploadDocument,
deleteDocument,
triggerAutoLabel,
} = useDocuments({ status: 'labeled', limit: 20 })
```
#### useDocumentDetail
```typescript
const { document, annotations, isLoading } = useDocumentDetail(documentId)
```
#### useAnnotations
```typescript
const {
createAnnotation,
updateAnnotation,
deleteAnnotation,
verifyAnnotation,
overrideAnnotation,
} = useAnnotations(documentId)
```
#### useTraining
```typescript
const {
models,
isLoadingModels,
startTraining,
downloadModel,
} = useTraining()
```
## Features Implemented
### Phase 1 (Completed)
- ✅ API client with axios interceptors
- ✅ Type-safe API endpoints
- ✅ React Query for server state
- ✅ Custom hooks for all APIs
- ✅ Dashboard with real data
- ✅ Local Tailwind CSS
- ✅ Environment configuration
- ✅ CORS configured in backend
### Phase 2 (TODO)
- [ ] Update DocumentDetail to use useDocumentDetail
- [ ] Update Training page to use useTraining hooks
- [ ] Update Models page with real data
- [ ] Add UploadModal integration with API
- [ ] Add react-router for proper routing
- [ ] Add error boundary
- [ ] Add loading states
- [ ] Add toast notifications
### Phase 3 (TODO)
- [ ] Annotation canvas with real data
- [ ] Batch upload functionality
- [ ] Auto-label progress polling
- [ ] Training job monitoring
- [ ] Model download functionality
- [ ] Search and filtering
- [ ] Pagination
## Development Tips
### Hot Module Replacement
Vite supports HMR. Changes will reflect immediately without page reload.
### API Debugging
Check browser console for API requests:
- Network tab shows all requests/responses
- Axios interceptors log errors automatically
### Type Safety
TypeScript types in `src/api/types.ts` match backend Pydantic schemas.
To regenerate types from backend:
```bash
# TODO: Add type generation script
```
### Backend API Documentation
Visit http://localhost:8000/docs for interactive API documentation (Swagger UI).
## Troubleshooting
### CORS Errors
If you see CORS errors:
1. Check backend is running at http://localhost:8000
2. Verify CORS settings in `src/web/app.py`
3. Check `.env.local` has correct `VITE_API_URL`
### Module Not Found
If imports fail:
```bash
rm -rf node_modules package-lock.json
npm install
```
### Types Not Matching
If API responses don't match types:
1. Check backend version is up-to-date
2. Verify types in `src/api/types.ts`
3. Check API response in Network tab
## Next Steps
1. Run `npm install` to install dependencies
2. Start backend server
3. Run `npm run dev` to start frontend
4. Open http://localhost:3000
5. Create an admin token via backend API
6. Store token in localStorage via browser console:
```javascript
localStorage.setItem('admin_token', 'your-token-here')
```
7. Refresh page to see authenticated API calls
## Production Build
```bash
npm run build
npm run preview # Preview production build
```
Build output will be in `dist/` directory.

15
frontend/index.html Normal file
View File

@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Graphite Annotator - Invoice Field Extraction</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

5
frontend/metadata.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Graphite Annotator",
"description": "A professional, warm graphite themed document annotation and training tool for enterprise use cases.",
"requestFramePermissions": []
}

4899
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

41
frontend/package.json Normal file
View File

@@ -0,0 +1,41 @@
{
"name": "graphite-annotator",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview",
"test": "vitest run",
"test:watch": "vitest",
"test:coverage": "vitest run --coverage"
},
"dependencies": {
"@tanstack/react-query": "^5.20.0",
"axios": "^1.6.7",
"clsx": "^2.1.0",
"date-fns": "^3.3.0",
"lucide-react": "^0.563.0",
"react": "^19.2.3",
"react-dom": "^19.2.3",
"react-router-dom": "^6.22.0",
"recharts": "^3.7.0",
"zustand": "^4.5.0"
},
"devDependencies": {
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.2",
"@testing-library/user-event": "^14.6.1",
"@types/node": "^22.14.0",
"@vitejs/plugin-react": "^5.0.0",
"@vitest/coverage-v8": "^4.0.18",
"autoprefixer": "^10.4.17",
"jsdom": "^27.4.0",
"postcss": "^8.4.35",
"tailwindcss": "^3.4.1",
"typescript": "~5.8.2",
"vite": "^6.2.0",
"vitest": "^4.0.18"
}
}

View File

@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

81
frontend/src/App.tsx Normal file
View File

@@ -0,0 +1,81 @@
import React, { useState, useEffect } from 'react'
import { Layout } from './components/Layout'
import { DashboardOverview } from './components/DashboardOverview'
import { Dashboard } from './components/Dashboard'
import { DocumentDetail } from './components/DocumentDetail'
import { Training } from './components/Training'
import { DatasetDetail } from './components/DatasetDetail'
import { Models } from './components/Models'
import { Login } from './components/Login'
import { InferenceDemo } from './components/InferenceDemo'
const App: React.FC = () => {
const [currentView, setCurrentView] = useState('dashboard')
const [selectedDocId, setSelectedDocId] = useState<string | null>(null)
const [isAuthenticated, setIsAuthenticated] = useState(false)
useEffect(() => {
const token = localStorage.getItem('admin_token')
setIsAuthenticated(!!token)
}, [])
const handleNavigate = (view: string, docId?: string) => {
setCurrentView(view)
if (docId) {
setSelectedDocId(docId)
}
}
const handleLogin = (token: string) => {
setIsAuthenticated(true)
}
const handleLogout = () => {
localStorage.removeItem('admin_token')
setIsAuthenticated(false)
setCurrentView('documents')
}
if (!isAuthenticated) {
return <Login onLogin={handleLogin} />
}
const renderContent = () => {
switch (currentView) {
case 'dashboard':
return <DashboardOverview onNavigate={handleNavigate} />
case 'documents':
return <Dashboard onNavigate={handleNavigate} />
case 'detail':
return (
<DocumentDetail
docId={selectedDocId || '1'}
onBack={() => setCurrentView('documents')}
/>
)
case 'demo':
return <InferenceDemo />
case 'training':
return <Training onNavigate={handleNavigate} />
case 'dataset-detail':
return (
<DatasetDetail
datasetId={selectedDocId || ''}
onBack={() => setCurrentView('training')}
/>
)
case 'models':
return <Models />
default:
return <DashboardOverview onNavigate={handleNavigate} />
}
}
return (
<Layout activeView={currentView} onNavigate={handleNavigate} onLogout={handleLogout}>
{renderContent()}
</Layout>
)
}
export default App

View File

@@ -0,0 +1,41 @@
import axios, { AxiosInstance, AxiosError } from 'axios'
const apiClient: AxiosInstance = axios.create({
baseURL: import.meta.env.VITE_API_URL || 'http://localhost:8000',
headers: {
'Content-Type': 'application/json',
},
timeout: 30000,
})
apiClient.interceptors.request.use(
(config) => {
const token = localStorage.getItem('admin_token')
if (token) {
config.headers['X-Admin-Token'] = token
}
return config
},
(error) => {
return Promise.reject(error)
}
)
apiClient.interceptors.response.use(
(response) => response,
(error: AxiosError) => {
if (error.response?.status === 401) {
console.warn('Authentication required. Please set admin_token in localStorage.')
// Don't redirect to avoid infinite loop
// User should manually set: localStorage.setItem('admin_token', 'your-token')
}
if (error.response?.status === 429) {
console.error('Rate limit exceeded')
}
return Promise.reject(error)
}
)
export default apiClient

View File

@@ -0,0 +1,66 @@
import apiClient from '../client'
import type {
AnnotationItem,
CreateAnnotationRequest,
AnnotationOverrideRequest,
} from '../types'
export const annotationsApi = {
list: async (documentId: string): Promise<AnnotationItem[]> => {
const { data } = await apiClient.get(
`/api/v1/admin/documents/${documentId}/annotations`
)
return data.annotations
},
create: async (
documentId: string,
annotation: CreateAnnotationRequest
): Promise<AnnotationItem> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/annotations`,
annotation
)
return data
},
update: async (
documentId: string,
annotationId: string,
updates: Partial<CreateAnnotationRequest>
): Promise<AnnotationItem> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`,
updates
)
return data
},
delete: async (documentId: string, annotationId: string): Promise<void> => {
await apiClient.delete(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}`
)
},
verify: async (
documentId: string,
annotationId: string
): Promise<{ annotation_id: string; is_verified: boolean; message: string }> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/verify`
)
return data
},
override: async (
documentId: string,
annotationId: string,
overrideData: AnnotationOverrideRequest
): Promise<{ annotation_id: string; source: string; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/annotations/${annotationId}/override`,
overrideData
)
return data
},
}

View File

@@ -0,0 +1,118 @@
/**
* Tests for augmentation API endpoints.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { augmentationApi } from './augmentation'
import apiClient from '../client'
// Mock the API client
vi.mock('../client', () => ({
default: {
get: vi.fn(),
post: vi.fn(),
},
}))
describe('augmentationApi', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockResponse = {
data: {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getTypes()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/types')
expect(result.augmentation_types).toHaveLength(1)
expect(result.augmentation_types[0].name).toBe('gaussian_noise')
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockResponse = {
data: {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
],
},
}
vi.mocked(apiClient.get).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.getPresets()
expect(apiClient.get).toHaveBeenCalledWith('/api/v1/admin/augmentation/presets')
expect(result.presets).toHaveLength(2)
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockResponse = {
data: {
preview_url: '',
original_url: '',
applied_params: { std: 15 },
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
const result = await augmentationApi.preview('doc-123', {
augmentation_type: 'gaussian_noise',
params: { std: 15 },
})
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
{
augmentation_type: 'gaussian_noise',
params: { std: 15 },
},
{ params: { page: 1 } }
)
expect(result.preview_url).toBe('')
})
it('should support custom page number', async () => {
const mockResponse = {
data: {
preview_url: '',
original_url: '',
applied_params: {},
},
}
vi.mocked(apiClient.post).mockResolvedValueOnce(mockResponse)
await augmentationApi.preview(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: {} },
2
)
expect(apiClient.post).toHaveBeenCalledWith(
'/api/v1/admin/augmentation/preview/doc-123',
expect.anything(),
{ params: { page: 2 } }
)
})
})
})

View File

@@ -0,0 +1,144 @@
/**
* Augmentation API endpoints.
*
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
*/
import apiClient from '../client'
// Types
export interface AugmentationTypeInfo {
name: string
description: string
affects_geometry: boolean
stage: string
default_params: Record<string, unknown>
}
export interface AugmentationTypesResponse {
augmentation_types: AugmentationTypeInfo[]
}
export interface PresetInfo {
name: string
description: string
config?: Record<string, unknown>
}
export interface PresetsResponse {
presets: PresetInfo[]
}
export interface PreviewRequest {
augmentation_type: string
params: Record<string, unknown>
}
export interface PreviewResponse {
preview_url: string
original_url: string
applied_params: Record<string, unknown>
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationConfig {
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
gaussian_noise?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface BatchRequest {
dataset_id: string
config: AugmentationConfig
output_name: string
multiplier: number
}
export interface BatchResponse {
task_id: string
status: string
message: string
estimated_images: number
}
// API functions
export const augmentationApi = {
/**
* Fetch available augmentation types.
*/
async getTypes(): Promise<AugmentationTypesResponse> {
const response = await apiClient.get<AugmentationTypesResponse>(
'/api/v1/admin/augmentation/types'
)
return response.data
},
/**
* Fetch augmentation presets.
*/
async getPresets(): Promise<PresetsResponse> {
const response = await apiClient.get<PresetsResponse>(
'/api/v1/admin/augmentation/presets'
)
return response.data
},
/**
* Preview a single augmentation on a document page.
*/
async preview(
documentId: string,
request: PreviewRequest,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview/${documentId}`,
request,
{ params: { page } }
)
return response.data
},
/**
* Preview full augmentation config on a document page.
*/
async previewConfig(
documentId: string,
config: AugmentationConfig,
page: number = 1
): Promise<PreviewResponse> {
const response = await apiClient.post<PreviewResponse>(
`/api/v1/admin/augmentation/preview-config/${documentId}`,
config,
{ params: { page } }
)
return response.data
},
/**
* Create an augmented dataset.
*/
async createBatch(request: BatchRequest): Promise<BatchResponse> {
const response = await apiClient.post<BatchResponse>(
'/api/v1/admin/augmentation/batch',
request
)
return response.data
},
}

View File

@@ -0,0 +1,52 @@
import apiClient from '../client'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetResponse,
DatasetTrainRequest,
TrainingTaskResponse,
} from '../types'
export const datasetsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<DatasetListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/datasets', {
params,
})
return data
},
create: async (req: DatasetCreateRequest): Promise<DatasetResponse> => {
const { data } = await apiClient.post('/api/v1/admin/training/datasets', req)
return data
},
getDetail: async (datasetId: string): Promise<DatasetDetailResponse> => {
const { data } = await apiClient.get(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
remove: async (datasetId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(
`/api/v1/admin/training/datasets/${datasetId}`
)
return data
},
trainFromDataset: async (
datasetId: string,
req: DatasetTrainRequest
): Promise<TrainingTaskResponse> => {
const { data } = await apiClient.post(
`/api/v1/admin/training/datasets/${datasetId}/train`,
req
)
return data
},
}

View File

@@ -0,0 +1,122 @@
import apiClient from '../client'
import type {
DocumentListResponse,
DocumentDetailResponse,
DocumentItem,
UploadDocumentResponse,
DocumentCategoriesResponse,
} from '../types'
export const documentsApi = {
list: async (params?: {
status?: string
category?: string
limit?: number
offset?: number
}): Promise<DocumentListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/documents', { params })
return data
},
getCategories: async (): Promise<DocumentCategoriesResponse> => {
const { data } = await apiClient.get('/api/v1/admin/documents/categories')
return data
},
getDetail: async (documentId: string): Promise<DocumentDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/documents/${documentId}`)
return data
},
upload: async (
file: File,
options?: { groupKey?: string; category?: string }
): Promise<UploadDocumentResponse> => {
const formData = new FormData()
formData.append('file', file)
const params: Record<string, string> = {}
if (options?.groupKey) {
params.group_key = options.groupKey
}
if (options?.category) {
params.category = options.category
}
const { data } = await apiClient.post('/api/v1/admin/documents', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
params,
})
return data
},
batchUpload: async (
files: File[],
csvFile?: File
): Promise<{ batch_id: string; message: string; documents_created: number }> => {
const formData = new FormData()
files.forEach((file) => {
formData.append('files', file)
})
if (csvFile) {
formData.append('csv_file', csvFile)
}
const { data } = await apiClient.post('/api/v1/admin/batch/upload', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
})
return data
},
delete: async (documentId: string): Promise<void> => {
await apiClient.delete(`/api/v1/admin/documents/${documentId}`)
},
updateStatus: async (
documentId: string,
status: string
): Promise<DocumentItem> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/status`,
null,
{ params: { status } }
)
return data
},
triggerAutoLabel: async (documentId: string): Promise<{ message: string }> => {
const { data } = await apiClient.post(
`/api/v1/admin/documents/${documentId}/auto-label`
)
return data
},
updateGroupKey: async (
documentId: string,
groupKey: string | null
): Promise<{ status: string; document_id: string; group_key: string | null; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/group-key`,
null,
{ params: { group_key: groupKey } }
)
return data
},
updateCategory: async (
documentId: string,
category: string
): Promise<{ status: string; document_id: string; category: string; message: string }> => {
const { data } = await apiClient.patch(
`/api/v1/admin/documents/${documentId}/category`,
{ category }
)
return data
},
}

View File

@@ -0,0 +1,7 @@
export { documentsApi } from './documents'
export { annotationsApi } from './annotations'
export { trainingApi } from './training'
export { inferenceApi } from './inference'
export { datasetsApi } from './datasets'
export { augmentationApi } from './augmentation'
export { modelsApi } from './models'

View File

@@ -0,0 +1,16 @@
import apiClient from '../client'
import type { InferenceResponse } from '../types'
export const inferenceApi = {
processDocument: async (file: File): Promise<InferenceResponse> => {
const formData = new FormData()
formData.append('file', file)
const { data } = await apiClient.post('/api/v1/infer', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
})
return data
},
}

View File

@@ -0,0 +1,55 @@
import apiClient from '../client'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ModelVersionResponse,
ActiveModelResponse,
} from '../types'
export const modelsApi = {
list: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<ModelVersionListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models', {
params,
})
return data
},
getDetail: async (versionId: string): Promise<ModelVersionDetailResponse> => {
const { data } = await apiClient.get(`/api/v1/admin/training/models/${versionId}`)
return data
},
getActive: async (): Promise<ActiveModelResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/models/active')
return data
},
activate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/activate`)
return data
},
deactivate: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/deactivate`)
return data
},
archive: async (versionId: string): Promise<ModelVersionResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/models/${versionId}/archive`)
return data
},
delete: async (versionId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(`/api/v1/admin/training/models/${versionId}`)
return data
},
reload: async (): Promise<{ message: string; reloaded: boolean }> => {
const { data } = await apiClient.post('/api/v1/admin/training/models/reload')
return data
},
}

View File

@@ -0,0 +1,74 @@
import apiClient from '../client'
import type { TrainingModelsResponse, DocumentListResponse } from '../types'
export const trainingApi = {
getDocumentsForTraining: async (params?: {
has_annotations?: boolean
min_annotation_count?: number
exclude_used_in_training?: boolean
limit?: number
offset?: number
}): Promise<DocumentListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/documents', {
params,
})
return data
},
getModels: async (params?: {
status?: string
limit?: number
offset?: number
}): Promise<TrainingModelsResponse> => {
const { data} = await apiClient.get('/api/v1/admin/training/models', {
params,
})
return data
},
getTaskDetail: async (taskId: string) => {
const { data } = await apiClient.get(`/api/v1/admin/training/tasks/${taskId}`)
return data
},
startTraining: async (config: {
name: string
description?: string
document_ids: string[]
epochs?: number
batch_size?: number
model_base?: string
}) => {
// Convert frontend config to backend TrainingTaskCreate format
const taskRequest = {
name: config.name,
task_type: 'yolo',
description: config.description,
config: {
document_ids: config.document_ids,
epochs: config.epochs,
batch_size: config.batch_size,
base_model: config.model_base,
},
}
const { data } = await apiClient.post('/api/v1/admin/training/tasks', taskRequest)
return data
},
cancelTask: async (taskId: string) => {
const { data } = await apiClient.post(
`/api/v1/admin/training/tasks/${taskId}/cancel`
)
return data
},
downloadModel: async (taskId: string): Promise<Blob> => {
const { data } = await apiClient.get(
`/api/v1/admin/training/models/${taskId}/download`,
{
responseType: 'blob',
}
)
return data
},
}

364
frontend/src/api/types.ts Normal file
View File

@@ -0,0 +1,364 @@
export interface DocumentItem {
document_id: string
filename: string
file_size: number
content_type: string
page_count: number
status: 'pending' | 'labeled' | 'verified' | 'exported'
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
group_key: string | null
category: string
created_at: string
updated_at: string
annotation_count?: number
annotation_sources?: {
manual: number
auto: number
verified: number
}
}
export interface DocumentListResponse {
documents: DocumentItem[]
total: number
limit: number
offset: number
}
export interface AnnotationItem {
annotation_id: string
page_number: number
class_id: number
class_name: string
bbox: {
x: number
y: number
width: number
height: number
}
normalized_bbox: {
x_center: number
y_center: number
width: number
height: number
}
text_value: string | null
confidence: number | null
source: 'manual' | 'auto'
created_at: string
}
export interface DocumentDetailResponse {
document_id: string
filename: string
file_size: number
content_type: string
page_count: number
status: 'pending' | 'labeled' | 'verified' | 'exported'
auto_label_status: 'pending' | 'running' | 'completed' | 'failed' | null
auto_label_error: string | null
upload_source: string
batch_id: string | null
group_key: string | null
category: string
csv_field_values: Record<string, string> | null
can_annotate: boolean
annotation_lock_until: string | null
annotations: AnnotationItem[]
image_urls: string[]
training_history: Array<{
task_id: string
name: string
trained_at: string
model_metrics: {
mAP: number | null
precision: number | null
recall: number | null
} | null
}>
created_at: string
updated_at: string
}
export interface TrainingTask {
task_id: string
admin_token: string
name: string
description: string | null
status: 'pending' | 'running' | 'completed' | 'failed'
task_type: string
config: Record<string, unknown>
started_at: string | null
completed_at: string | null
error_message: string | null
result_metrics: Record<string, unknown>
model_path: string | null
document_count: number
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
created_at: string
updated_at: string
}
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface TrainingModelsResponse {
models: ModelVersionItem[]
total: number
limit: number
offset: number
}
export interface ErrorResponse {
detail: string
}
export interface UploadDocumentResponse {
document_id: string
filename: string
file_size: number
page_count: number
status: string
category: string
group_key: string | null
auto_label_started: boolean
message: string
}
export interface DocumentCategoriesResponse {
categories: string[]
total: number
}
export interface CreateAnnotationRequest {
page_number: number
class_id: number
bbox: {
x: number
y: number
width: number
height: number
}
text_value?: string
}
export interface AnnotationOverrideRequest {
text_value?: string
bbox?: {
x: number
y: number
width: number
height: number
}
class_id?: number
class_name?: string
reason?: string
}
export interface CrossValidationResult {
is_valid: boolean
payment_line_ocr: string | null
payment_line_amount: string | null
payment_line_account: string | null
payment_line_account_type: 'bankgiro' | 'plusgiro' | null
ocr_match: boolean | null
amount_match: boolean | null
bankgiro_match: boolean | null
plusgiro_match: boolean | null
details: string[]
}
export interface InferenceResult {
document_id: string
document_type: string
success: boolean
fields: Record<string, string>
confidence: Record<string, number>
cross_validation: CrossValidationResult | null
processing_time_ms: number
visualization_url: string | null
errors: string[]
fallback_used: boolean
}
export interface InferenceResponse {
result: InferenceResult
}
// Dataset types
export interface DatasetCreateRequest {
name: string
description?: string
document_ids: string[]
train_ratio?: number
val_ratio?: number
seed?: number
}
export interface DatasetResponse {
dataset_id: string
name: string
status: string
message: string
}
export interface DatasetDocumentItem {
document_id: string
split: string
page_count: number
annotation_count: number
}
export interface DatasetListItem {
dataset_id: string
name: string
description: string | null
status: string
training_status: string | null
active_training_task_id: string | null
total_documents: number
total_images: number
total_annotations: number
created_at: string
}
export interface DatasetListResponse {
total: number
limit: number
offset: number
datasets: DatasetListItem[]
}
export interface DatasetDetailResponse {
dataset_id: string
name: string
description: string | null
status: string
training_status: string | null
active_training_task_id: string | null
train_ratio: number
val_ratio: number
seed: number
total_documents: number
total_images: number
total_annotations: number
dataset_path: string | null
error_message: string | null
documents: DatasetDocumentItem[]
created_at: string
updated_at: string
}
export interface AugmentationParams {
enabled: boolean
probability: number
params: Record<string, unknown>
}
export interface AugmentationTrainingConfig {
gaussian_noise?: AugmentationParams
perspective_warp?: AugmentationParams
wrinkle?: AugmentationParams
edge_damage?: AugmentationParams
stain?: AugmentationParams
lighting_variation?: AugmentationParams
shadow?: AugmentationParams
gaussian_blur?: AugmentationParams
motion_blur?: AugmentationParams
salt_pepper?: AugmentationParams
paper_texture?: AugmentationParams
scanner_artifacts?: AugmentationParams
preserve_bboxes?: boolean
seed?: number | null
}
export interface DatasetTrainRequest {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs?: number
batch_size?: number
image_size?: number
learning_rate?: number
device?: string
augmentation?: AugmentationTrainingConfig
augmentation_multiplier?: number
}
}
export interface TrainingTaskResponse {
task_id: string
status: string
message: string
}
// Model Version types
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface ModelVersionDetailResponse {
version_id: string
version: string
name: string
description: string | null
model_path: string
status: string
is_active: boolean
task_id: string | null
dataset_id: string | null
metrics_mAP: number | null
metrics_precision: number | null
metrics_recall: number | null
document_count: number
training_config: Record<string, unknown> | null
file_size: number | null
trained_at: string | null
activated_at: string | null
created_at: string
updated_at: string
}
export interface ModelVersionListResponse {
total: number
limit: number
offset: number
models: ModelVersionItem[]
}
export interface ModelVersionResponse {
version_id: string
status: string
message: string
}
export interface ActiveModelResponse {
has_active_model: boolean
model: ModelVersionItem | null
}

View File

@@ -0,0 +1,251 @@
/**
* Tests for AugmentationConfig component.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { AugmentationConfig } from './AugmentationConfig'
import { augmentationApi } from '../api/endpoints/augmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Default mock data
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise to simulate sensor noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective transformation',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
{
name: 'gaussian_blur',
description: 'Applies Gaussian blur',
affects_geometry: false,
stage: 'blur',
default_params: { kernel_size: 5 },
},
],
}
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations for high-quality documents' },
{ name: 'moderate', description: 'Balanced augmentation settings' },
{ name: 'aggressive', description: 'Strong augmentations for data diversity' },
],
}
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('AugmentationConfig', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(augmentationApi.getTypes).mockResolvedValue(mockTypes)
vi.mocked(augmentationApi.getPresets).mockResolvedValue(mockPresets)
})
describe('rendering', () => {
it('should render enable checkbox', async () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByRole('checkbox', { name: /enable augmentation/i })).toBeInTheDocument()
})
it('should be collapsed when disabled', () => {
render(
<AugmentationConfig
enabled={false}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
// Config options should not be visible
expect(screen.queryByText(/preset/i)).not.toBeInTheDocument()
})
it('should expand when enabled', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/preset/i)).toBeInTheDocument()
})
})
})
describe('preset selection', () => {
it('should display available presets', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('conservative')).toBeInTheDocument()
expect(screen.getByText('moderate')).toBeInTheDocument()
expect(screen.getByText('aggressive')).toBeInTheDocument()
})
})
it('should call onConfigChange when preset is selected', async () => {
const user = userEvent.setup()
const onConfigChange = vi.fn()
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={onConfigChange}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText('moderate')).toBeInTheDocument()
})
await user.click(screen.getByText('moderate'))
expect(onConfigChange).toHaveBeenCalled()
})
})
describe('enable toggle', () => {
it('should call onEnabledChange when checkbox is toggled', async () => {
const user = userEvent.setup()
const onEnabledChange = vi.fn()
render(
<AugmentationConfig
enabled={false}
onEnabledChange={onEnabledChange}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
await user.click(screen.getByRole('checkbox', { name: /enable augmentation/i }))
expect(onEnabledChange).toHaveBeenCalledWith(true)
})
})
describe('augmentation types', () => {
it('should display augmentation types when in custom mode', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
expect(screen.getByText(/gaussian_noise/i)).toBeInTheDocument()
expect(screen.getByText(/perspective_warp/i)).toBeInTheDocument()
})
})
it('should indicate which augmentations affect geometry', async () => {
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
showCustomOptions={true}
/>,
{ wrapper: createWrapper() }
)
await waitFor(() => {
// perspective_warp affects geometry
const perspectiveItem = screen.getByText(/perspective_warp/i).closest('div')
expect(perspectiveItem).toHaveTextContent(/affects bbox/i)
})
})
})
describe('loading state', () => {
it('should show loading indicator while fetching types', () => {
vi.mocked(augmentationApi.getTypes).mockImplementation(
() => new Promise(() => {})
)
render(
<AugmentationConfig
enabled={true}
onEnabledChange={vi.fn()}
config={{}}
onConfigChange={vi.fn()}
/>,
{ wrapper: createWrapper() }
)
expect(screen.getByTestId('augmentation-loading')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,136 @@
/**
* AugmentationConfig component for configuring image augmentation during training.
*
* Provides preset selection and optional custom augmentation type configuration.
*/
import React from 'react'
import { Loader2, AlertTriangle } from 'lucide-react'
import { useAugmentation } from '../hooks/useAugmentation'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
interface AugmentationConfigProps {
enabled: boolean
onEnabledChange: (enabled: boolean) => void
config: Partial<AugmentationConfigType>
onConfigChange: (config: Partial<AugmentationConfigType>) => void
showCustomOptions?: boolean
}
export const AugmentationConfig: React.FC<AugmentationConfigProps> = ({
enabled,
onEnabledChange,
config,
onConfigChange,
showCustomOptions = false,
}) => {
const { augmentationTypes, presets, isLoadingTypes, isLoadingPresets } = useAugmentation()
const isLoading = isLoadingTypes || isLoadingPresets
const handlePresetSelect = (presetName: string) => {
const preset = presets.find((p) => p.name === presetName)
if (preset && preset.config) {
onConfigChange(preset.config as Partial<AugmentationConfigType>)
} else {
// Apply a basic config based on preset name
const presetConfigs: Record<string, Partial<AugmentationConfigType>> = {
conservative: {
gaussian_noise: { enabled: true, probability: 0.3, params: { std: 10 } },
gaussian_blur: { enabled: true, probability: 0.2, params: { kernel_size: 3 } },
},
moderate: {
gaussian_noise: { enabled: true, probability: 0.5, params: { std: 15 } },
gaussian_blur: { enabled: true, probability: 0.3, params: { kernel_size: 5 } },
lighting_variation: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.2, params: { max_warp: 0.02 } },
},
aggressive: {
gaussian_noise: { enabled: true, probability: 0.7, params: { std: 20 } },
gaussian_blur: { enabled: true, probability: 0.5, params: { kernel_size: 7 } },
motion_blur: { enabled: true, probability: 0.3, params: {} },
lighting_variation: { enabled: true, probability: 0.5, params: {} },
shadow: { enabled: true, probability: 0.3, params: {} },
perspective_warp: { enabled: true, probability: 0.3, params: { max_warp: 0.03 } },
wrinkle: { enabled: true, probability: 0.2, params: {} },
stain: { enabled: true, probability: 0.2, params: {} },
},
}
onConfigChange(presetConfigs[presetName] || {})
}
}
return (
<div className="border border-warm-divider rounded-lg p-4 bg-warm-bg-secondary">
{/* Enable checkbox */}
<label className="flex items-center gap-2 cursor-pointer">
<input
type="checkbox"
checked={enabled}
onChange={(e) => onEnabledChange(e.target.checked)}
className="w-4 h-4 rounded border-warm-divider text-warm-state-info focus:ring-warm-state-info"
aria-label="Enable augmentation"
/>
<span className="text-sm font-medium text-warm-text-secondary">Enable Augmentation</span>
<span className="text-xs text-warm-text-muted">(Simulate real-world document conditions)</span>
</label>
{/* Expanded content when enabled */}
{enabled && (
<div className="mt-4 space-y-4">
{isLoading ? (
<div className="flex items-center justify-center py-4" data-testid="augmentation-loading">
<Loader2 className="w-5 h-5 animate-spin text-warm-state-info" />
<span className="ml-2 text-sm text-warm-text-muted">Loading augmentation options...</span>
</div>
) : (
<>
{/* Preset selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-2">Preset</label>
<div className="flex flex-wrap gap-2">
{presets.map((preset) => (
<button
key={preset.name}
onClick={() => handlePresetSelect(preset.name)}
className="px-3 py-1.5 text-sm rounded-md border border-warm-divider hover:bg-warm-bg-tertiary transition-colors"
title={preset.description}
>
{preset.name}
</button>
))}
</div>
</div>
{/* Custom options (if enabled) */}
{showCustomOptions && (
<div className="border-t border-warm-divider pt-4">
<h4 className="text-sm font-medium text-warm-text-secondary mb-3">Augmentation Types</h4>
<div className="grid gap-2">
{augmentationTypes.map((type) => (
<div
key={type.name}
className="flex items-center justify-between p-2 bg-warm-bg-primary rounded border border-warm-divider"
>
<div className="flex items-center gap-2">
<span className="text-sm text-warm-text-primary">{type.name}</span>
{type.affects_geometry && (
<span className="flex items-center gap-1 text-xs text-warm-state-warning">
<AlertTriangle size={12} />
affects bbox
</span>
)}
</div>
<span className="text-xs text-warm-text-muted">{type.stage}</span>
</div>
))}
</div>
</div>
)}
</>
)}
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,32 @@
import { render, screen } from '@testing-library/react';
import { describe, it, expect } from 'vitest';
import { Badge } from './Badge';
import { DocumentStatus } from '../types';
describe('Badge', () => {
it('renders Exported badge with check icon', () => {
render(<Badge status="Exported" />);
expect(screen.getByText('Exported')).toBeInTheDocument();
});
it('renders Pending status', () => {
render(<Badge status={DocumentStatus.PENDING} />);
expect(screen.getByText('Pending')).toBeInTheDocument();
});
it('renders Verified status', () => {
render(<Badge status={DocumentStatus.VERIFIED} />);
expect(screen.getByText('Verified')).toBeInTheDocument();
});
it('renders Labeled status', () => {
render(<Badge status={DocumentStatus.LABELED} />);
expect(screen.getByText('Labeled')).toBeInTheDocument();
});
it('renders Partial status with warning indicator', () => {
render(<Badge status={DocumentStatus.PARTIAL} />);
expect(screen.getByText('Partial')).toBeInTheDocument();
expect(screen.getByText('!')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,39 @@
import React from 'react';
import { DocumentStatus } from '../types';
import { Check } from 'lucide-react';
interface BadgeProps {
status: DocumentStatus | 'Exported';
}
export const Badge: React.FC<BadgeProps> = ({ status }) => {
if (status === 'Exported') {
return (
<span className="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-full text-xs font-medium bg-warm-selected text-warm-text-secondary">
<Check size={12} strokeWidth={3} />
Exported
</span>
);
}
const styles = {
[DocumentStatus.PENDING]: "bg-white border border-warm-divider text-warm-text-secondary",
[DocumentStatus.LABELED]: "bg-warm-text-secondary text-white border border-transparent",
[DocumentStatus.VERIFIED]: "bg-warm-state-success/10 text-warm-state-success border border-warm-state-success/20",
[DocumentStatus.PARTIAL]: "bg-warm-state-warning/10 text-warm-state-warning border border-warm-state-warning/20",
};
const icons = {
[DocumentStatus.VERIFIED]: <Check size={12} className="mr-1" />,
[DocumentStatus.PARTIAL]: <span className="mr-1 text-[10px] font-bold">!</span>,
[DocumentStatus.PENDING]: null,
[DocumentStatus.LABELED]: null,
}
return (
<span className={`inline-flex items-center px-3 py-1 rounded-full text-xs font-medium border ${styles[status]}`}>
{icons[status]}
{status}
</span>
);
};

View File

@@ -0,0 +1,38 @@
import { render, screen } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { describe, it, expect, vi } from 'vitest';
import { Button } from './Button';
describe('Button', () => {
it('renders children text', () => {
render(<Button>Click me</Button>);
expect(screen.getByRole('button', { name: 'Click me' })).toBeInTheDocument();
});
it('calls onClick handler', async () => {
const user = userEvent.setup();
const onClick = vi.fn();
render(<Button onClick={onClick}>Click</Button>);
await user.click(screen.getByRole('button'));
expect(onClick).toHaveBeenCalledOnce();
});
it('is disabled when disabled prop is set', () => {
render(<Button disabled>Disabled</Button>);
expect(screen.getByRole('button')).toBeDisabled();
});
it('applies variant styles', () => {
const { rerender } = render(<Button variant="primary">Primary</Button>);
const btn = screen.getByRole('button');
expect(btn.className).toContain('bg-warm-text-secondary');
rerender(<Button variant="secondary">Secondary</Button>);
expect(screen.getByRole('button').className).toContain('border');
});
it('applies size styles', () => {
render(<Button size="sm">Small</Button>);
expect(screen.getByRole('button').className).toContain('h-8');
});
});

View File

@@ -0,0 +1,38 @@
import React from 'react';
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
variant?: 'primary' | 'secondary' | 'outline' | 'text';
size?: 'sm' | 'md' | 'lg';
}
export const Button: React.FC<ButtonProps> = ({
variant = 'primary',
size = 'md',
className = '',
children,
...props
}) => {
const baseStyles = "inline-flex items-center justify-center rounded-md font-medium transition-all duration-150 ease-out active:scale-98 disabled:opacity-50 disabled:pointer-events-none";
const variants = {
primary: "bg-warm-text-secondary text-white hover:bg-warm-text-primary shadow-sm",
secondary: "bg-white border border-warm-divider text-warm-text-secondary hover:bg-warm-hover",
outline: "bg-transparent border border-warm-text-secondary text-warm-text-secondary hover:bg-warm-hover",
text: "text-warm-text-muted hover:text-warm-text-primary hover:bg-warm-hover",
};
const sizes = {
sm: "h-8 px-3 text-xs",
md: "h-10 px-4 text-sm",
lg: "h-12 px-6 text-base",
};
return (
<button
className={`${baseStyles} ${variants[variant]} ${sizes[size]} ${className}`}
{...props}
>
{children}
</button>
);
};

View File

@@ -0,0 +1,300 @@
import React, { useState } from 'react'
import { Search, ChevronDown, MoreHorizontal, FileText } from 'lucide-react'
import { Badge } from './Badge'
import { Button } from './Button'
import { UploadModal } from './UploadModal'
import { useDocuments, useCategories } from '../hooks/useDocuments'
import type { DocumentItem } from '../api/types'
interface DashboardProps {
onNavigate: (view: string, docId?: string) => void
}
const getStatusForBadge = (status: string): string => {
const statusMap: Record<string, string> = {
pending: 'Pending',
labeled: 'Labeled',
verified: 'Verified',
exported: 'Exported',
}
return statusMap[status] || status
}
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
if (doc.auto_label_status === 'running') {
return 45
}
if (doc.auto_label_status === 'completed') {
return 100
}
return undefined
}
export const Dashboard: React.FC<DashboardProps> = ({ onNavigate }) => {
const [isUploadOpen, setIsUploadOpen] = useState(false)
const [selectedDocs, setSelectedDocs] = useState<Set<string>>(new Set())
const [statusFilter, setStatusFilter] = useState<string>('')
const [categoryFilter, setCategoryFilter] = useState<string>('')
const [limit] = useState(20)
const [offset] = useState(0)
const { categories } = useCategories()
const { documents, total, isLoading, error, refetch } = useDocuments({
status: statusFilter || undefined,
category: categoryFilter || undefined,
limit,
offset,
})
const toggleSelection = (id: string) => {
const newSet = new Set(selectedDocs)
if (newSet.has(id)) {
newSet.delete(id)
} else {
newSet.add(id)
}
setSelectedDocs(newSet)
}
if (error) {
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="bg-red-50 border border-red-200 text-red-800 p-4 rounded-lg">
Error loading documents. Please check your connection to the backend API.
<button
onClick={() => refetch()}
className="ml-4 underline hover:no-underline"
>
Retry
</button>
</div>
</div>
)
}
return (
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
<div className="flex items-center justify-between mb-8">
<div>
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
Documents
</h1>
<p className="text-sm text-warm-text-muted mt-1">
{isLoading ? 'Loading...' : `${total} documents total`}
</p>
</div>
<div className="flex gap-3">
<Button variant="secondary" disabled={selectedDocs.size === 0}>
Export Selection ({selectedDocs.size})
</Button>
<Button onClick={() => setIsUploadOpen(true)}>Upload Documents</Button>
</div>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg p-4 mb-6 shadow-sm flex flex-wrap gap-4 items-center">
<div className="relative flex-1 min-w-[200px]">
<Search
className="absolute left-3 top-1/2 -translate-y-1/2 text-warm-text-muted"
size={16}
/>
<input
type="text"
placeholder="Search documents..."
className="w-full pl-9 pr-4 h-10 rounded-md border border-warm-border bg-white focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow text-sm"
/>
</div>
<div className="flex gap-3">
<div className="relative">
<select
value={categoryFilter}
onChange={(e) => setCategoryFilter(e.target.value)}
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
>
<option value="">All Categories</option>
{categories.map((cat) => (
<option key={cat} value={cat}>
{cat.charAt(0).toUpperCase() + cat.slice(1)}
</option>
))}
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
<div className="relative">
<select
value={statusFilter}
onChange={(e) => setStatusFilter(e.target.value)}
className="h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none appearance-none cursor-pointer hover:bg-warm-hover"
>
<option value="">All Statuses</option>
<option value="pending">Pending</option>
<option value="labeled">Labeled</option>
<option value="verified">Verified</option>
<option value="exported">Exported</option>
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
</div>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<table className="w-full text-left border-collapse">
<thead>
<tr className="border-b border-warm-border bg-white">
<th className="py-3 pl-6 pr-4 w-12">
<input
type="checkbox"
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary"
/>
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Document Name
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Date
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Status
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Annotations
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Category
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider">
Group
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase tracking-wider w-64">
Auto-label
</th>
<th className="py-3 px-4 w-12"></th>
</tr>
</thead>
<tbody>
{isLoading ? (
<tr>
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
Loading documents...
</td>
</tr>
) : documents.length === 0 ? (
<tr>
<td colSpan={9} className="py-8 text-center text-warm-text-muted">
No documents found. Upload your first document to get started.
</td>
</tr>
) : (
documents.map((doc) => {
const isSelected = selectedDocs.has(doc.document_id)
const progress = getAutoLabelProgress(doc)
return (
<tr
key={doc.document_id}
onClick={() => onNavigate('detail', doc.document_id)}
className={`
group transition-colors duration-150 cursor-pointer border-b border-warm-border last:border-0
${isSelected ? 'bg-warm-selected' : 'hover:bg-warm-hover bg-white'}
`}
>
<td
className="py-4 pl-6 pr-4 relative"
onClick={(e) => {
e.stopPropagation()
toggleSelection(doc.document_id)
}}
>
{isSelected && (
<div className="absolute left-0 top-0 bottom-0 w-[3px] bg-warm-state-info" />
)}
<input
type="checkbox"
checked={isSelected}
readOnly
className="rounded border-warm-divider text-warm-text-primary focus:ring-warm-text-secondary cursor-pointer"
/>
</td>
<td className="py-4 px-4">
<div className="flex items-center gap-3">
<div className="p-2 bg-warm-bg rounded border border-warm-border text-warm-text-muted">
<FileText size={16} />
</div>
<span className="font-medium text-warm-text-secondary">
{doc.filename}
</span>
</div>
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary font-mono">
{new Date(doc.created_at).toLocaleDateString()}
</td>
<td className="py-4 px-4">
<Badge status={getStatusForBadge(doc.status)} />
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary">
{doc.annotation_count || 0} annotations
</td>
<td className="py-4 px-4 text-sm text-warm-text-secondary capitalize">
{doc.category || 'invoice'}
</td>
<td className="py-4 px-4 text-sm text-warm-text-muted">
{doc.group_key || '-'}
</td>
<td className="py-4 px-4">
{doc.auto_label_status === 'running' && progress && (
<div className="w-full">
<div className="flex justify-between text-xs mb-1">
<span className="text-warm-text-secondary font-medium">
Running
</span>
<span className="text-warm-text-muted">{progress}%</span>
</div>
<div className="h-1.5 w-full bg-warm-selected rounded-full overflow-hidden">
<div
className="h-full bg-warm-state-info transition-all duration-500 ease-out"
style={{ width: `${progress}%` }}
/>
</div>
</div>
)}
{doc.auto_label_status === 'completed' && (
<span className="text-sm font-medium text-warm-state-success">
Completed
</span>
)}
{doc.auto_label_status === 'failed' && (
<span className="text-sm font-medium text-warm-state-error">
Failed
</span>
)}
</td>
<td className="py-4 px-4 text-right">
<button className="text-warm-text-muted hover:text-warm-text-secondary p-1 rounded hover:bg-black/5 transition-colors">
<MoreHorizontal size={18} />
</button>
</td>
</tr>
)
})
)}
</tbody>
</table>
</div>
<UploadModal
isOpen={isUploadOpen}
onClose={() => {
setIsUploadOpen(false)
refetch()
}}
/>
</div>
)
}

View File

@@ -0,0 +1,148 @@
import React from 'react'
import { FileText, CheckCircle, Clock, TrendingUp, Activity } from 'lucide-react'
import { Button } from './Button'
import { useDocuments } from '../hooks/useDocuments'
import { useTraining } from '../hooks/useTraining'
interface DashboardOverviewProps {
onNavigate: (view: string) => void
}
export const DashboardOverview: React.FC<DashboardOverviewProps> = ({ onNavigate }) => {
const { total: totalDocs, isLoading: docsLoading } = useDocuments({ limit: 1 })
const { models, isLoadingModels } = useTraining()
const stats = [
{
label: 'Total Documents',
value: docsLoading ? '...' : totalDocs.toString(),
icon: FileText,
color: 'text-warm-text-primary',
bgColor: 'bg-warm-bg',
},
{
label: 'Labeled',
value: '0',
icon: CheckCircle,
color: 'text-warm-state-success',
bgColor: 'bg-green-50',
},
{
label: 'Pending',
value: '0',
icon: Clock,
color: 'text-warm-state-warning',
bgColor: 'bg-yellow-50',
},
{
label: 'Training Models',
value: isLoadingModels ? '...' : models.length.toString(),
icon: TrendingUp,
color: 'text-warm-state-info',
bgColor: 'bg-blue-50',
},
]
return (
<div className="p-8 max-w-7xl mx-auto animate-fade-in">
{/* Header */}
<div className="mb-8">
<h1 className="text-3xl font-bold text-warm-text-primary tracking-tight">
Dashboard
</h1>
<p className="text-sm text-warm-text-muted mt-1">
Overview of your document annotation system
</p>
</div>
{/* Stats Grid */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6 mb-8">
{stats.map((stat) => (
<div
key={stat.label}
className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm hover:shadow-md transition-shadow"
>
<div className="flex items-center justify-between mb-4">
<div className={`p-3 rounded-lg ${stat.bgColor}`}>
<stat.icon className={stat.color} size={24} />
</div>
</div>
<p className="text-2xl font-bold text-warm-text-primary mb-1">
{stat.value}
</p>
<p className="text-sm text-warm-text-muted">{stat.label}</p>
</div>
))}
</div>
{/* Quick Actions */}
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm mb-8">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
Quick Actions
</h2>
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<Button onClick={() => onNavigate('documents')} className="justify-start">
<FileText size={18} className="mr-2" />
Manage Documents
</Button>
<Button onClick={() => onNavigate('training')} variant="secondary" className="justify-start">
<Activity size={18} className="mr-2" />
Start Training
</Button>
<Button onClick={() => onNavigate('models')} variant="secondary" className="justify-start">
<TrendingUp size={18} className="mr-2" />
View Models
</Button>
</div>
</div>
{/* Recent Activity */}
<div className="bg-warm-card border border-warm-border rounded-lg shadow-sm overflow-hidden">
<div className="p-6 border-b border-warm-border">
<h2 className="text-lg font-semibold text-warm-text-primary">
Recent Activity
</h2>
</div>
<div className="p-6">
<div className="text-center py-8 text-warm-text-muted">
<Activity size={48} className="mx-auto mb-3 opacity-20" />
<p className="text-sm">No recent activity</p>
<p className="text-xs mt-1">
Start by uploading documents or creating training jobs
</p>
</div>
</div>
</div>
{/* System Status */}
<div className="mt-8 bg-warm-card border border-warm-border rounded-lg p-6 shadow-sm">
<h2 className="text-lg font-semibold text-warm-text-primary mb-4">
System Status
</h2>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Backend API</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Online
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">Database</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Connected
</span>
</div>
<div className="flex items-center justify-between">
<span className="text-sm text-warm-text-secondary">GPU</span>
<span className="flex items-center text-sm text-warm-state-success">
<span className="w-2 h-2 bg-green-500 rounded-full mr-2"></span>
Available
</span>
</div>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,176 @@
import React from 'react'
import { ArrowLeft, Loader2, Play, AlertCircle, Check, Award } from 'lucide-react'
import { Button } from './Button'
import { useDatasetDetail } from '../hooks/useDatasets'
interface DatasetDetailProps {
datasetId: string
onBack: () => void
}
const SPLIT_STYLES: Record<string, string> = {
train: 'bg-warm-state-info/10 text-warm-state-info',
val: 'bg-warm-state-warning/10 text-warm-state-warning',
test: 'bg-warm-state-success/10 text-warm-state-success',
}
const STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
building: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Building' },
ready: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Ready' },
trained: { bg: 'bg-purple-100', text: 'text-purple-700', label: 'Trained' },
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
archived: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Archived' },
}
const TRAINING_STATUS_STYLES: Record<string, { bg: string; text: string; label: string }> = {
pending: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Pending' },
scheduled: { bg: 'bg-warm-state-warning/10', text: 'text-warm-state-warning', label: 'Scheduled' },
running: { bg: 'bg-warm-state-info/10', text: 'text-warm-state-info', label: 'Training' },
completed: { bg: 'bg-warm-state-success/10', text: 'text-warm-state-success', label: 'Completed' },
failed: { bg: 'bg-warm-state-error/10', text: 'text-warm-state-error', label: 'Failed' },
cancelled: { bg: 'bg-warm-border', text: 'text-warm-text-muted', label: 'Cancelled' },
}
export const DatasetDetail: React.FC<DatasetDetailProps> = ({ datasetId, onBack }) => {
const { dataset, isLoading, error } = useDatasetDetail(datasetId)
if (isLoading) {
return (
<div className="flex items-center justify-center py-20 text-warm-text-muted">
<Loader2 size={24} className="animate-spin mr-2" />Loading dataset...
</div>
)
}
if (error || !dataset) {
return (
<div className="p-8 max-w-7xl mx-auto">
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back
</button>
<p className="text-warm-state-error">Failed to load dataset.</p>
</div>
)
}
const statusConfig = STATUS_STYLES[dataset.status] || STATUS_STYLES.ready
const trainingStatusConfig = dataset.training_status
? TRAINING_STATUS_STYLES[dataset.training_status]
: null
// Determine if training button should be shown and enabled
const isTrainingInProgress = dataset.training_status === 'running' || dataset.training_status === 'pending'
const canStartTraining = dataset.status === 'ready' && !isTrainingInProgress
// Determine status icon
const statusIcon = dataset.status === 'trained'
? <Award size={14} className="text-purple-700" />
: dataset.status === 'ready'
? <Check size={14} className="text-warm-state-success" />
: dataset.status === 'failed'
? <AlertCircle size={14} className="text-warm-state-error" />
: dataset.status === 'building'
? <Loader2 size={14} className="animate-spin text-warm-state-info" />
: null
return (
<div className="p-8 max-w-7xl mx-auto">
{/* Header */}
<button onClick={onBack} className="flex items-center gap-1 text-sm text-warm-text-muted hover:text-warm-text-secondary mb-4">
<ArrowLeft size={16} />Back to Datasets
</button>
<div className="flex items-center justify-between mb-6">
<div>
<div className="flex items-center gap-3 mb-1">
<h2 className="text-2xl font-bold text-warm-text-primary flex items-center gap-2">
{dataset.name} {statusIcon}
</h2>
{/* Status Badge */}
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${statusConfig.bg} ${statusConfig.text}`}>
{statusConfig.label}
</span>
{/* Training Status Badge */}
{trainingStatusConfig && (
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${trainingStatusConfig.bg} ${trainingStatusConfig.text}`}>
{isTrainingInProgress && <Loader2 size={12} className="mr-1 animate-spin" />}
{trainingStatusConfig.label}
</span>
)}
</div>
{dataset.description && (
<p className="text-sm text-warm-text-muted mt-1">{dataset.description}</p>
)}
</div>
{/* Training Button */}
{(dataset.status === 'ready' || dataset.status === 'trained') && (
<Button
disabled={isTrainingInProgress}
className={isTrainingInProgress ? 'opacity-50 cursor-not-allowed' : ''}
>
{isTrainingInProgress ? (
<><Loader2 size={14} className="mr-1 animate-spin" />Training...</>
) : (
<><Play size={14} className="mr-1" />Start Training</>
)}
</Button>
)}
</div>
{dataset.error_message && (
<div className="bg-warm-state-error/10 border border-warm-state-error/20 rounded-lg p-4 mb-6 text-sm text-warm-state-error">
{dataset.error_message}
</div>
)}
{/* Stats */}
<div className="grid grid-cols-4 gap-4 mb-8">
{[
['Documents', dataset.total_documents],
['Images', dataset.total_images],
['Annotations', dataset.total_annotations],
['Split', `${(dataset.train_ratio * 100).toFixed(0)}/${(dataset.val_ratio * 100).toFixed(0)}/${((1 - dataset.train_ratio - dataset.val_ratio) * 100).toFixed(0)}`],
].map(([label, value]) => (
<div key={String(label)} className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase font-semibold mb-1">{label}</p>
<p className="text-2xl font-bold text-warm-text-primary font-mono">{value}</p>
</div>
))}
</div>
{/* Document list */}
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Documents</h3>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Split</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{dataset.documents.map(doc => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4">
<span className={`inline-flex px-2.5 py-1 rounded-full text-xs font-medium ${SPLIT_STYLES[doc.split] ?? 'bg-warm-border text-warm-text-muted'}`}>
{doc.split}
</span>
</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count}</td>
</tr>
))}
</tbody>
</table>
</div>
<p className="text-xs text-warm-text-muted mt-4">
Created: {new Date(dataset.created_at).toLocaleString()} | Updated: {new Date(dataset.updated_at).toLocaleString()}
{dataset.dataset_path && <> | Path: <code className="text-xs">{dataset.dataset_path}</code></>}
</p>
</div>
)
}

View File

@@ -0,0 +1,567 @@
import React, { useState, useRef, useEffect } from 'react'
import { ChevronLeft, ZoomIn, ZoomOut, Plus, Edit2, Trash2, Tag, CheckCircle, Check, X } from 'lucide-react'
import { Button } from './Button'
import { useDocumentDetail } from '../hooks/useDocumentDetail'
import { useAnnotations } from '../hooks/useAnnotations'
import { useDocuments } from '../hooks/useDocuments'
import { documentsApi } from '../api/endpoints/documents'
import type { AnnotationItem } from '../api/types'
interface DocumentDetailProps {
docId: string
onBack: () => void
}
// Field class mapping from backend
const FIELD_CLASSES: Record<number, string> = {
0: 'invoice_number',
1: 'invoice_date',
2: 'invoice_due_date',
3: 'ocr_number',
4: 'bankgiro',
5: 'plusgiro',
6: 'amount',
7: 'supplier_organisation_number',
8: 'payment_line',
9: 'customer_number',
}
export const DocumentDetail: React.FC<DocumentDetailProps> = ({ docId, onBack }) => {
const { document, annotations, isLoading, refetch } = useDocumentDetail(docId)
const {
createAnnotation,
updateAnnotation,
deleteAnnotation,
isCreating,
isDeleting,
} = useAnnotations(docId)
const { updateGroupKey, isUpdatingGroupKey } = useDocuments({})
const [selectedId, setSelectedId] = useState<string | null>(null)
const [zoom, setZoom] = useState(100)
const [isDrawing, setIsDrawing] = useState(false)
const [isEditingGroupKey, setIsEditingGroupKey] = useState(false)
const [editGroupKeyValue, setEditGroupKeyValue] = useState('')
const [drawStart, setDrawStart] = useState<{ x: number; y: number } | null>(null)
const [drawEnd, setDrawEnd] = useState<{ x: number; y: number } | null>(null)
const [selectedClassId, setSelectedClassId] = useState<number>(0)
const [currentPage, setCurrentPage] = useState(1)
const [imageSize, setImageSize] = useState<{ width: number; height: number } | null>(null)
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null)
const canvasRef = useRef<HTMLDivElement>(null)
const imageRef = useRef<HTMLImageElement>(null)
const [isMarkingComplete, setIsMarkingComplete] = useState(false)
const selectedAnnotation = annotations?.find((a) => a.annotation_id === selectedId)
// Handle mark as complete
const handleMarkComplete = async () => {
if (!annotations || annotations.length === 0) {
alert('Please add at least one annotation before marking as complete.')
return
}
if (!confirm('Mark this document as labeled? This will save annotations to the database.')) {
return
}
setIsMarkingComplete(true)
try {
const result = await documentsApi.updateStatus(docId, 'labeled')
alert(`Document marked as labeled. ${(result as any).fields_saved || annotations.length} annotations saved.`)
onBack() // Return to document list
} catch (error) {
console.error('Failed to mark document as complete:', error)
alert('Failed to mark document as complete. Please try again.')
} finally {
setIsMarkingComplete(false)
}
}
// Load image via fetch with authentication header
useEffect(() => {
let objectUrl: string | null = null
const loadImage = async () => {
if (!docId) return
const token = localStorage.getItem('admin_token')
const imageUrl = `${import.meta.env.VITE_API_URL || 'http://localhost:8000'}/api/v1/admin/documents/${docId}/images/${currentPage}`
try {
const response = await fetch(imageUrl, {
headers: {
'X-Admin-Token': token || '',
},
})
if (!response.ok) {
throw new Error(`Failed to load image: ${response.status}`)
}
const blob = await response.blob()
objectUrl = URL.createObjectURL(blob)
setImageBlobUrl(objectUrl)
} catch (error) {
console.error('Failed to load image:', error)
}
}
loadImage()
// Cleanup: revoke object URL when component unmounts or page changes
return () => {
if (objectUrl) {
URL.revokeObjectURL(objectUrl)
}
}
}, [currentPage, docId])
// Load image size
useEffect(() => {
if (imageRef.current && imageRef.current.complete) {
setImageSize({
width: imageRef.current.naturalWidth,
height: imageRef.current.naturalHeight,
})
}
}, [imageBlobUrl])
const handleImageLoad = () => {
if (imageRef.current) {
setImageSize({
width: imageRef.current.naturalWidth,
height: imageRef.current.naturalHeight,
})
}
}
const handleMouseDown = (e: React.MouseEvent<HTMLDivElement>) => {
if (!canvasRef.current || !imageSize) return
const rect = canvasRef.current.getBoundingClientRect()
const x = (e.clientX - rect.left) / (zoom / 100)
const y = (e.clientY - rect.top) / (zoom / 100)
setIsDrawing(true)
setDrawStart({ x, y })
setDrawEnd({ x, y })
}
const handleMouseMove = (e: React.MouseEvent<HTMLDivElement>) => {
if (!isDrawing || !canvasRef.current || !imageSize) return
const rect = canvasRef.current.getBoundingClientRect()
const x = (e.clientX - rect.left) / (zoom / 100)
const y = (e.clientY - rect.top) / (zoom / 100)
setDrawEnd({ x, y })
}
const handleMouseUp = () => {
if (!isDrawing || !drawStart || !drawEnd || !imageSize) {
setIsDrawing(false)
return
}
const bbox_x = Math.min(drawStart.x, drawEnd.x)
const bbox_y = Math.min(drawStart.y, drawEnd.y)
const bbox_width = Math.abs(drawEnd.x - drawStart.x)
const bbox_height = Math.abs(drawEnd.y - drawStart.y)
// Only create if box is large enough (min 10x10 pixels)
if (bbox_width > 10 && bbox_height > 10) {
createAnnotation({
page_number: currentPage,
class_id: selectedClassId,
bbox: {
x: Math.round(bbox_x),
y: Math.round(bbox_y),
width: Math.round(bbox_width),
height: Math.round(bbox_height),
},
})
}
setIsDrawing(false)
setDrawStart(null)
setDrawEnd(null)
}
const handleDeleteAnnotation = (annotationId: string) => {
if (confirm('Are you sure you want to delete this annotation?')) {
deleteAnnotation(annotationId)
setSelectedId(null)
}
}
if (isLoading || !document) {
return (
<div className="flex h-screen items-center justify-center">
<div className="text-warm-text-muted">Loading...</div>
</div>
)
}
// Get current page annotations
const pageAnnotations = annotations?.filter((a) => a.page_number === currentPage) || []
return (
<div className="flex h-[calc(100vh-56px)] overflow-hidden">
{/* Main Canvas Area */}
<div className="flex-1 bg-warm-bg flex flex-col relative">
{/* Toolbar */}
<div className="h-14 border-b border-warm-border bg-white flex items-center justify-between px-4 z-10">
<div className="flex items-center gap-4">
<button
onClick={onBack}
className="p-2 hover:bg-warm-hover rounded-md text-warm-text-secondary transition-colors"
>
<ChevronLeft size={20} />
</button>
<div>
<h2 className="text-sm font-semibold text-warm-text-primary">{document.filename}</h2>
<p className="text-xs text-warm-text-muted">
Page {currentPage} of {document.page_count}
</p>
</div>
<div className="h-6 w-px bg-warm-divider mx-2" />
<div className="flex items-center gap-2">
<button
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
onClick={() => setZoom((z) => Math.max(50, z - 10))}
>
<ZoomOut size={16} />
</button>
<span className="text-xs font-mono w-12 text-center text-warm-text-secondary">
{zoom}%
</span>
<button
className="p-1.5 hover:bg-warm-hover rounded text-warm-text-secondary"
onClick={() => setZoom((z) => Math.min(200, z + 10))}
>
<ZoomIn size={16} />
</button>
</div>
</div>
<div className="flex gap-2">
<Button variant="secondary" size="sm">
Auto-label
</Button>
<Button
variant="primary"
size="sm"
onClick={handleMarkComplete}
disabled={isMarkingComplete || document.status === 'labeled'}
>
<CheckCircle size={16} className="mr-1" />
{isMarkingComplete ? 'Saving...' : document.status === 'labeled' ? 'Labeled' : 'Mark Complete'}
</Button>
{document.page_count > 1 && (
<div className="flex gap-1">
<Button
variant="secondary"
size="sm"
onClick={() => setCurrentPage((p) => Math.max(1, p - 1))}
disabled={currentPage === 1}
>
Prev
</Button>
<Button
variant="secondary"
size="sm"
onClick={() => setCurrentPage((p) => Math.min(document.page_count, p + 1))}
disabled={currentPage === document.page_count}
>
Next
</Button>
</div>
)}
</div>
</div>
{/* Canvas Scroll Area */}
<div className="flex-1 overflow-auto p-8 flex justify-center bg-warm-bg">
<div
ref={canvasRef}
className="bg-white shadow-lg relative transition-transform duration-200 ease-out origin-top"
style={{
width: imageSize?.width || 800,
height: imageSize?.height || 1132,
transform: `scale(${zoom / 100})`,
marginBottom: '100px',
cursor: isDrawing ? 'crosshair' : 'default',
}}
onMouseDown={handleMouseDown}
onMouseMove={handleMouseMove}
onMouseUp={handleMouseUp}
onClick={() => setSelectedId(null)}
>
{/* Document Image */}
{imageBlobUrl ? (
<img
ref={imageRef}
src={imageBlobUrl}
alt={`Page ${currentPage}`}
className="w-full h-full object-contain select-none pointer-events-none"
onLoad={handleImageLoad}
/>
) : (
<div className="flex items-center justify-center h-full">
<div className="text-warm-text-muted">Loading image...</div>
</div>
)}
{/* Annotation Overlays */}
{pageAnnotations.map((ann) => {
const isSelected = selectedId === ann.annotation_id
return (
<div
key={ann.annotation_id}
onClick={(e) => {
e.stopPropagation()
setSelectedId(ann.annotation_id)
}}
className={`
absolute group cursor-pointer transition-all duration-100
${
ann.source === 'auto'
? 'border border-dashed border-warm-text-muted bg-transparent'
: 'border-2 border-warm-text-secondary bg-warm-text-secondary/5'
}
${
isSelected
? 'border-2 border-warm-state-info ring-4 ring-warm-state-info/10 z-20'
: 'hover:bg-warm-state-info/5 z-10'
}
`}
style={{
left: ann.bbox.x,
top: ann.bbox.y,
width: ann.bbox.width,
height: ann.bbox.height,
}}
>
{/* Label Tag */}
<div
className={`
absolute -top-6 left-0 text-[10px] uppercase font-bold px-1.5 py-0.5 rounded-sm tracking-wide shadow-sm whitespace-nowrap
${
isSelected
? 'bg-warm-state-info text-white'
: 'bg-white text-warm-text-secondary border border-warm-border'
}
`}
>
{ann.class_name}
</div>
{/* Resize Handles (Visual only) */}
{isSelected && (
<>
<div className="absolute -top-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -top-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -bottom-1 -left-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
<div className="absolute -bottom-1 -right-1 w-2 h-2 bg-white border border-warm-state-info rounded-full" />
</>
)}
</div>
)
})}
{/* Drawing Box Preview */}
{isDrawing && drawStart && drawEnd && (
<div
className="absolute border-2 border-warm-state-info bg-warm-state-info/10 z-30 pointer-events-none"
style={{
left: Math.min(drawStart.x, drawEnd.x),
top: Math.min(drawStart.y, drawEnd.y),
width: Math.abs(drawEnd.x - drawStart.x),
height: Math.abs(drawEnd.y - drawStart.y),
}}
/>
)}
</div>
</div>
</div>
{/* Right Sidebar */}
<div className="w-80 bg-white border-l border-warm-border flex flex-col shadow-[-4px_0_15px_-3px_rgba(0,0,0,0.03)] z-20">
{/* Field Selector */}
<div className="p-4 border-b border-warm-border">
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Draw Annotation</h3>
<div className="space-y-2">
<label className="block text-xs text-warm-text-muted mb-1">Select Field Type</label>
<select
value={selectedClassId}
onChange={(e) => setSelectedClassId(Number(e.target.value))}
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
{Object.entries(FIELD_CLASSES).map(([id, name]) => (
<option key={id} value={id}>
{name.replace(/_/g, ' ')}
</option>
))}
</select>
<p className="text-xs text-warm-text-muted mt-2">
Click and drag on the document to create a bounding box
</p>
</div>
</div>
{/* Document Info Card */}
<div className="p-4 border-b border-warm-border">
<div className="bg-white rounded-lg border border-warm-border p-4 shadow-sm">
<h3 className="text-sm font-semibold text-warm-text-primary mb-3">Document Info</h3>
<div className="space-y-2">
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Status</span>
<span className="text-warm-text-secondary font-medium capitalize">
{document.status}
</span>
</div>
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Size</span>
<span className="text-warm-text-secondary font-medium">
{(document.file_size / 1024 / 1024).toFixed(2)} MB
</span>
</div>
<div className="flex justify-between text-xs">
<span className="text-warm-text-muted">Uploaded</span>
<span className="text-warm-text-secondary font-medium">
{new Date(document.created_at).toLocaleDateString()}
</span>
</div>
<div className="flex justify-between items-center text-xs">
<span className="text-warm-text-muted">Group</span>
{isEditingGroupKey ? (
<div className="flex items-center gap-1">
<input
type="text"
value={editGroupKeyValue}
onChange={(e) => setEditGroupKeyValue(e.target.value)}
className="w-24 px-1.5 py-0.5 text-xs border border-warm-border rounded focus:outline-none focus:ring-1 focus:ring-warm-state-info"
placeholder="group key"
autoFocus
/>
<button
onClick={() => {
updateGroupKey(
{ documentId: docId, groupKey: editGroupKeyValue.trim() || null },
{
onSuccess: () => {
setIsEditingGroupKey(false)
refetch()
},
onError: () => {
alert('Failed to update group key. Please try again.')
},
}
)
}}
disabled={isUpdatingGroupKey}
className="p-0.5 text-warm-state-success hover:bg-warm-hover rounded"
>
<Check size={14} />
</button>
<button
onClick={() => {
setIsEditingGroupKey(false)
setEditGroupKeyValue(document.group_key || '')
}}
className="p-0.5 text-warm-state-error hover:bg-warm-hover rounded"
>
<X size={14} />
</button>
</div>
) : (
<div className="flex items-center gap-1">
<span className="text-warm-text-secondary font-medium">
{document.group_key || '-'}
</span>
<button
onClick={() => {
setEditGroupKeyValue(document.group_key || '')
setIsEditingGroupKey(true)
}}
className="p-0.5 text-warm-text-muted hover:text-warm-text-secondary hover:bg-warm-hover rounded"
>
<Edit2 size={12} />
</button>
</div>
)}
</div>
</div>
</div>
</div>
{/* Annotations List */}
<div className="flex-1 overflow-y-auto p-4">
<div className="flex items-center justify-between mb-4">
<h3 className="text-sm font-semibold text-warm-text-primary">Annotations</h3>
<span className="text-xs text-warm-text-muted">{pageAnnotations.length} items</span>
</div>
{pageAnnotations.length === 0 ? (
<div className="text-center py-8 text-warm-text-muted">
<Tag size={48} className="mx-auto mb-3 opacity-20" />
<p className="text-sm">No annotations yet</p>
<p className="text-xs mt-1">Draw on the document to add annotations</p>
</div>
) : (
<div className="space-y-3">
{pageAnnotations.map((ann) => (
<div
key={ann.annotation_id}
onClick={() => setSelectedId(ann.annotation_id)}
className={`
group p-3 rounded-md border transition-all duration-150 cursor-pointer
${
selectedId === ann.annotation_id
? 'bg-warm-bg border-warm-state-info shadow-sm'
: 'bg-white border-warm-border hover:border-warm-text-muted'
}
`}
>
<div className="flex justify-between items-start mb-1">
<span className="text-xs font-bold text-warm-text-secondary uppercase tracking-wider">
{ann.class_name.replace(/_/g, ' ')}
</span>
{selectedId === ann.annotation_id && (
<div className="flex gap-1">
<button
onClick={() => handleDeleteAnnotation(ann.annotation_id)}
className="text-warm-text-muted hover:text-warm-state-error"
disabled={isDeleting}
>
<Trash2 size={12} />
</button>
</div>
)}
</div>
<p className="text-sm text-warm-text-muted font-mono truncate">
{ann.text_value || '(no text)'}
</p>
<div className="flex items-center gap-2 mt-2">
<span
className={`text-[10px] px-1.5 py-0.5 rounded ${
ann.source === 'auto'
? 'bg-blue-50 text-blue-700'
: 'bg-green-50 text-green-700'
}`}
>
{ann.source}
</span>
{ann.confidence && (
<span className="text-[10px] text-warm-text-muted">
{(ann.confidence * 100).toFixed(0)}%
</span>
)}
</div>
</div>
))}
</div>
)}
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,466 @@
import React, { useState, useRef } from 'react'
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
import { Button } from './Button'
import { inferenceApi } from '../api/endpoints'
import type { InferenceResult } from '../api/types'
export const InferenceDemo: React.FC = () => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFile, setSelectedFile] = useState<File | null>(null)
const [isProcessing, setIsProcessing] = useState(false)
const [result, setResult] = useState<InferenceResult | null>(null)
const [error, setError] = useState<string | null>(null)
const fileInputRef = useRef<HTMLInputElement>(null)
const handleFileSelect = (file: File | null) => {
if (!file) return
const validTypes = ['application/pdf', 'image/png', 'image/jpeg', 'image/jpg']
if (!validTypes.includes(file.type)) {
setError('Please upload a PDF, PNG, or JPG file')
return
}
if (file.size > 50 * 1024 * 1024) {
setError('File size must be less than 50MB')
return
}
setSelectedFile(file)
setResult(null)
setError(null)
}
const handleDrop = (e: React.DragEvent) => {
e.preventDefault()
setIsDragging(false)
if (e.dataTransfer.files.length > 0) {
handleFileSelect(e.dataTransfer.files[0])
}
}
const handleBrowseClick = () => {
fileInputRef.current?.click()
}
const handleProcess = async () => {
if (!selectedFile) return
setIsProcessing(true)
setError(null)
try {
const response = await inferenceApi.processDocument(selectedFile)
console.log('API Response:', response)
console.log('Visualization URL:', response.result?.visualization_url)
setResult(response.result)
} catch (err) {
setError(err instanceof Error ? err.message : 'Processing failed')
} finally {
setIsProcessing(false)
}
}
const handleReset = () => {
setSelectedFile(null)
setResult(null)
setError(null)
}
const formatFieldName = (field: string): string => {
const fieldNames: Record<string, string> = {
InvoiceNumber: 'Invoice Number',
InvoiceDate: 'Invoice Date',
InvoiceDueDate: 'Due Date',
OCR: 'OCR Number',
Amount: 'Amount',
Bankgiro: 'Bankgiro',
Plusgiro: 'Plusgiro',
supplier_org_number: 'Supplier Org Number',
customer_number: 'Customer Number',
payment_line: 'Payment Line',
}
return fieldNames[field] || field
}
return (
<div className="max-w-7xl mx-auto px-4 py-6 space-y-6">
{/* Header */}
<div className="text-center">
<h2 className="text-3xl font-bold text-warm-text-primary mb-2">
Invoice Extraction Demo
</h2>
<p className="text-warm-text-muted">
Upload a Swedish invoice to see our AI-powered field extraction in action
</p>
</div>
{/* Upload Area */}
{!result && (
<div className="max-w-2xl mx-auto">
<div className="bg-warm-card rounded-xl border border-warm-border p-8 shadow-sm">
<div
className={`
relative h-72 rounded-xl border-2 border-dashed transition-all duration-200
${isDragging
? 'border-warm-text-secondary bg-warm-selected scale-[1.02]'
: 'border-warm-divider bg-warm-bg hover:bg-warm-hover hover:border-warm-text-secondary/50'
}
${isProcessing ? 'opacity-60 pointer-events-none' : 'cursor-pointer'}
`}
onDragOver={(e) => {
e.preventDefault()
setIsDragging(true)
}}
onDragLeave={() => setIsDragging(false)}
onDrop={handleDrop}
onClick={handleBrowseClick}
>
<div className="absolute inset-0 flex flex-col items-center justify-center gap-6">
{isProcessing ? (
<>
<Loader2 size={56} className="text-warm-text-secondary animate-spin" />
<div className="text-center">
<p className="text-lg font-semibold text-warm-text-primary mb-1">
Processing invoice...
</p>
<p className="text-sm text-warm-text-muted">
This may take a few moments
</p>
</div>
</>
) : selectedFile ? (
<>
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
<FileText size={40} className="text-warm-text-secondary" />
</div>
<div className="text-center px-4">
<p className="text-lg font-semibold text-warm-text-primary mb-1">
{selectedFile.name}
</p>
<p className="text-sm text-warm-text-muted">
{(selectedFile.size / 1024 / 1024).toFixed(2)} MB
</p>
</div>
</>
) : (
<>
<div className="p-5 bg-warm-text-secondary/10 rounded-full">
<UploadCloud size={40} className="text-warm-text-secondary" />
</div>
<div className="text-center px-4">
<p className="text-lg font-semibold text-warm-text-primary mb-2">
Drag & drop invoice here
</p>
<p className="text-sm text-warm-text-muted mb-3">
or{' '}
<span className="text-warm-text-secondary font-medium">
browse files
</span>
</p>
<p className="text-xs text-warm-text-muted">
Supports PDF, PNG, JPG (up to 50MB)
</p>
</div>
</>
)}
</div>
</div>
<input
ref={fileInputRef}
type="file"
accept=".pdf,image/*"
className="hidden"
onChange={(e) => handleFileSelect(e.target.files?.[0] || null)}
/>
{error && (
<div className="mt-5 p-4 bg-red-50 border border-red-200 rounded-lg flex items-start gap-3">
<AlertCircle size={18} className="text-red-600 flex-shrink-0 mt-0.5" />
<span className="text-sm text-red-800 font-medium">{error}</span>
</div>
)}
{selectedFile && !isProcessing && (
<div className="mt-6 flex gap-3 justify-end">
<Button variant="secondary" onClick={handleReset}>
Cancel
</Button>
<Button onClick={handleProcess}>Process Invoice</Button>
</div>
)}
</div>
</div>
)}
{/* Results */}
{result && (
<div className="space-y-6">
{/* Status Header */}
<div className="bg-warm-card rounded-xl border border-warm-border shadow-sm overflow-hidden">
<div className="p-6 flex items-center justify-between border-b border-warm-divider">
<div className="flex items-center gap-4">
{result.success ? (
<div className="p-3 bg-green-100 rounded-xl">
<CheckCircle2 size={28} className="text-green-600" />
</div>
) : (
<div className="p-3 bg-yellow-100 rounded-xl">
<AlertCircle size={28} className="text-yellow-600" />
</div>
)}
<div>
<h3 className="text-xl font-bold text-warm-text-primary">
{result.success ? 'Extraction Complete' : 'Partial Results'}
</h3>
<p className="text-sm text-warm-text-muted mt-0.5">
Document ID: <span className="font-mono">{result.document_id}</span>
</p>
</div>
</div>
<Button variant="secondary" onClick={handleReset}>
Process Another
</Button>
</div>
<div className="px-6 py-4 bg-warm-bg/50 flex items-center gap-6 text-sm">
<div className="flex items-center gap-2 text-warm-text-secondary">
<Clock size={16} />
<span className="font-medium">
{result.processing_time_ms.toFixed(0)}ms
</span>
</div>
{result.fallback_used && (
<span className="px-3 py-1.5 bg-warm-selected rounded-md text-warm-text-secondary font-medium text-xs">
Fallback OCR Used
</span>
)}
</div>
</div>
{/* Main Content Grid */}
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
{/* Left Column: Extracted Fields */}
<div className="lg:col-span-2 space-y-6">
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Extracted Fields
</h3>
<div className="flex flex-wrap gap-4">
{Object.entries(result.fields).map(([field, value]) => {
const confidence = result.confidence[field]
return (
<div
key={field}
className="p-4 bg-warm-bg/70 rounded-lg border border-warm-divider hover:border-warm-text-secondary/30 transition-colors w-[calc(50%-0.5rem)]"
>
<div className="text-xs font-semibold text-warm-text-muted uppercase tracking-wide mb-2">
{formatFieldName(field)}
</div>
<div className="text-sm font-bold text-warm-text-primary mb-2 min-h-[1.5rem]">
{value || <span className="text-warm-text-muted italic">N/A</span>}
</div>
{confidence && (
<div className="flex items-center gap-1.5 text-xs font-medium text-warm-text-secondary">
<CheckCircle2 size={13} />
<span>{(confidence * 100).toFixed(1)}%</span>
</div>
)}
</div>
)
})}
</div>
</div>
{/* Visualization */}
{result.visualization_url && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Detection Visualization
</h3>
<div className="bg-warm-bg rounded-lg overflow-hidden border border-warm-divider">
<img
src={`${import.meta.env.VITE_API_URL || 'http://localhost:8000'}${result.visualization_url}`}
alt="Detection visualization"
className="w-full h-auto"
/>
</div>
</div>
)}
</div>
{/* Right Column: Cross-Validation & Errors */}
<div className="space-y-6">
{/* Cross-Validation */}
{result.cross_validation && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
Payment Line Validation
</h3>
<div
className={`
p-4 rounded-lg mb-4 flex items-center gap-3
${result.cross_validation.is_valid
? 'bg-green-50 border border-green-200'
: 'bg-yellow-50 border border-yellow-200'
}
`}
>
{result.cross_validation.is_valid ? (
<>
<CheckCircle2 size={22} className="text-green-600 flex-shrink-0" />
<span className="font-bold text-green-800">All Fields Match</span>
</>
) : (
<>
<AlertCircle size={22} className="text-yellow-600 flex-shrink-0" />
<span className="font-bold text-yellow-800">Mismatch Detected</span>
</>
)}
</div>
<div className="space-y-2.5">
{result.cross_validation.payment_line_ocr && (
<div
className={`
p-3 rounded-lg border transition-colors
${result.cross_validation.ocr_match === true
? 'bg-green-50 border-green-200'
: result.cross_validation.ocr_match === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
OCR NUMBER
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_ocr}
</div>
</div>
{result.cross_validation.ocr_match === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{result.cross_validation.ocr_match === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
{result.cross_validation.payment_line_amount && (
<div
className={`
p-3 rounded-lg border transition-colors
${result.cross_validation.amount_match === true
? 'bg-green-50 border-green-200'
: result.cross_validation.amount_match === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
AMOUNT
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_amount}
</div>
</div>
{result.cross_validation.amount_match === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{result.cross_validation.amount_match === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
{result.cross_validation.payment_line_account && (
<div
className={`
p-3 rounded-lg border transition-colors
${(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === true
? 'bg-green-50 border-green-200'
: (result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === false
? 'bg-red-50 border-red-200'
: 'bg-warm-bg border-warm-divider'
}
`}
>
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="text-xs font-semibold text-warm-text-muted mb-1">
{result.cross_validation.payment_line_account_type === 'bankgiro'
? 'BANKGIRO'
: 'PLUSGIRO'}
</div>
<div className="text-sm font-bold text-warm-text-primary font-mono">
{result.cross_validation.payment_line_account}
</div>
</div>
{(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === true && (
<CheckCircle2 size={16} className="text-green-600" />
)}
{(result.cross_validation.payment_line_account_type === 'bankgiro'
? result.cross_validation.bankgiro_match
: result.cross_validation.plusgiro_match) === false && (
<AlertCircle size={16} className="text-red-600" />
)}
</div>
</div>
)}
</div>
{result.cross_validation.details.length > 0 && (
<div className="mt-4 p-3 bg-warm-bg/70 rounded-lg text-xs text-warm-text-secondary leading-relaxed border border-warm-divider">
{result.cross_validation.details[result.cross_validation.details.length - 1]}
</div>
)}
</div>
)}
{/* Errors */}
{result.errors.length > 0 && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
<span className="w-1 h-5 bg-red-500 rounded-full"></span>
Issues
</h3>
<div className="space-y-2.5">
{result.errors.map((err, idx) => (
<div
key={idx}
className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg flex items-start gap-3"
>
<AlertCircle size={16} className="text-yellow-600 flex-shrink-0 mt-0.5" />
<span className="text-xs text-yellow-800 leading-relaxed">{err}</span>
</div>
))}
</div>
</div>
)}
</div>
</div>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,102 @@
import React, { useState } from 'react';
import { Box, LayoutTemplate, Users, BookOpen, LogOut, Sparkles } from 'lucide-react';
interface LayoutProps {
children: React.ReactNode;
activeView: string;
onNavigate: (view: string) => void;
onLogout?: () => void;
}
export const Layout: React.FC<LayoutProps> = ({ children, activeView, onNavigate, onLogout }) => {
const [showDropdown, setShowDropdown] = useState(false);
const navItems = [
{ id: 'dashboard', label: 'Dashboard', icon: LayoutTemplate },
{ id: 'demo', label: 'Demo', icon: Sparkles },
{ id: 'training', label: 'Training', icon: Box }, // Mapped to Compliants visually in prompt, using logical name
{ id: 'documents', label: 'Documents', icon: BookOpen },
{ id: 'models', label: 'Models', icon: Users }, // Contacts in prompt, mapped to models for this use case
];
return (
<div className="min-h-screen bg-warm-bg font-sans text-warm-text-primary flex flex-col">
{/* Top Navigation */}
<nav className="h-14 bg-warm-bg border-b border-warm-border px-6 flex items-center justify-between shrink-0 sticky top-0 z-40">
<div className="flex items-center gap-8">
{/* Logo */}
<div className="flex items-center gap-2">
<div className="w-8 h-8 bg-warm-text-primary rounded-full flex items-center justify-center text-white">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="3" strokeLinecap="round" strokeLinejoin="round">
<path d="M12 2L2 7l10 5 10-5-10-5zM2 17l10 5 10-5M2 12l10 5 10-5"/>
</svg>
</div>
</div>
{/* Nav Links */}
<div className="flex h-14">
{navItems.map(item => {
const isActive = activeView === item.id || (activeView === 'detail' && item.id === 'documents');
return (
<button
key={item.id}
onClick={() => onNavigate(item.id)}
className={`
relative px-4 h-full flex items-center text-sm font-medium transition-colors
${isActive ? 'text-warm-text-primary' : 'text-warm-text-muted hover:text-warm-text-secondary'}
`}
>
{item.label}
{isActive && (
<div className="absolute bottom-0 left-0 right-0 h-0.5 bg-warm-text-secondary rounded-t-full mx-2" />
)}
</button>
);
})}
</div>
</div>
{/* User Profile */}
<div className="flex items-center gap-3 pl-6 border-l border-warm-border h-6 relative">
<button
onClick={() => setShowDropdown(!showDropdown)}
className="w-8 h-8 rounded-full bg-warm-selected flex items-center justify-center text-xs font-semibold text-warm-text-secondary border border-warm-divider hover:bg-warm-hover transition-colors"
>
AD
</button>
{showDropdown && (
<>
<div
className="fixed inset-0 z-10"
onClick={() => setShowDropdown(false)}
/>
<div className="absolute right-0 top-10 w-48 bg-warm-card border border-warm-border rounded-lg shadow-modal z-20">
<div className="p-3 border-b border-warm-border">
<p className="text-sm font-medium text-warm-text-primary">Admin User</p>
<p className="text-xs text-warm-text-muted mt-0.5">Authenticated</p>
</div>
{onLogout && (
<button
onClick={() => {
setShowDropdown(false)
onLogout()
}}
className="w-full px-3 py-2 text-left text-sm text-warm-text-secondary hover:bg-warm-hover transition-colors flex items-center gap-2"
>
<LogOut size={14} />
Sign Out
</button>
)}
</div>
</>
)}
</div>
</nav>
{/* Main Content */}
<main className="flex-1 overflow-auto">
{children}
</main>
</div>
);
};

View File

@@ -0,0 +1,188 @@
import React, { useState } from 'react'
import { Button } from './Button'
interface LoginProps {
onLogin: (token: string) => void
}
export const Login: React.FC<LoginProps> = ({ onLogin }) => {
const [token, setToken] = useState('')
const [name, setName] = useState('')
const [description, setDescription] = useState('')
const [isCreating, setIsCreating] = useState(false)
const [error, setError] = useState('')
const [createdToken, setCreatedToken] = useState('')
const handleLoginWithToken = () => {
if (!token.trim()) {
setError('Please enter a token')
return
}
localStorage.setItem('admin_token', token.trim())
onLogin(token.trim())
}
const handleCreateToken = async () => {
if (!name.trim()) {
setError('Please enter a token name')
return
}
setIsCreating(true)
setError('')
try {
const response = await fetch('http://localhost:8000/api/v1/admin/auth/token', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
name: name.trim(),
description: description.trim() || undefined,
}),
})
if (!response.ok) {
throw new Error('Failed to create token')
}
const data = await response.json()
setCreatedToken(data.token)
setToken(data.token)
setError('')
} catch (err) {
setError('Failed to create token. Please check your connection.')
console.error(err)
} finally {
setIsCreating(false)
}
}
const handleUseCreatedToken = () => {
if (createdToken) {
localStorage.setItem('admin_token', createdToken)
onLogin(createdToken)
}
}
return (
<div className="min-h-screen bg-warm-bg flex items-center justify-center p-4">
<div className="bg-warm-card border border-warm-border rounded-lg shadow-modal p-8 max-w-md w-full">
<h1 className="text-2xl font-bold text-warm-text-primary mb-2">
Admin Authentication
</h1>
<p className="text-sm text-warm-text-muted mb-6">
Sign in with an admin token to access the document management system
</p>
{error && (
<div className="mb-4 p-3 bg-red-50 border border-red-200 text-red-800 rounded text-sm">
{error}
</div>
)}
{createdToken && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded">
<p className="text-sm font-medium text-green-800 mb-2">Token created successfully!</p>
<div className="bg-white border border-green-300 rounded p-2 mb-3">
<code className="text-xs font-mono text-warm-text-primary break-all">
{createdToken}
</code>
</div>
<p className="text-xs text-green-700 mb-3">
Save this token securely. You won't be able to see it again.
</p>
<Button onClick={handleUseCreatedToken} className="w-full">
Use This Token
</Button>
</div>
)}
<div className="space-y-6">
{/* Login with existing token */}
<div>
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
Sign in with existing token
</h2>
<div className="space-y-3">
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Admin Token
</label>
<input
type="text"
value={token}
onChange={(e) => setToken(e.target.value)}
placeholder="Enter your admin token"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info font-mono"
onKeyDown={(e) => e.key === 'Enter' && handleLoginWithToken()}
/>
</div>
<Button onClick={handleLoginWithToken} className="w-full">
Sign In
</Button>
</div>
</div>
<div className="relative">
<div className="absolute inset-0 flex items-center">
<div className="w-full border-t border-warm-border"></div>
</div>
<div className="relative flex justify-center text-xs">
<span className="px-2 bg-warm-card text-warm-text-muted">OR</span>
</div>
</div>
{/* Create new token */}
<div>
<h2 className="text-sm font-semibold text-warm-text-secondary mb-3">
Create new admin token
</h2>
<div className="space-y-3">
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Token Name <span className="text-red-500">*</span>
</label>
<input
type="text"
value={name}
onChange={(e) => setName(e.target.value)}
placeholder="e.g., my-laptop"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div>
<label className="block text-sm text-warm-text-secondary mb-1">
Description (optional)
</label>
<input
type="text"
value={description}
onChange={(e) => setDescription(e.target.value)}
placeholder="e.g., Personal laptop access"
className="w-full px-3 py-2 border border-warm-border rounded-md text-sm focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<Button
onClick={handleCreateToken}
variant="secondary"
disabled={isCreating}
className="w-full"
>
{isCreating ? 'Creating...' : 'Create Token'}
</Button>
</div>
</div>
</div>
<div className="mt-6 pt-4 border-t border-warm-border">
<p className="text-xs text-warm-text-muted">
Admin tokens are used to authenticate with the document management API.
Keep your tokens secure and never share them.
</p>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,208 @@
import React, { useState } from 'react';
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from 'recharts';
import { Loader2, Power, CheckCircle } from 'lucide-react';
import { Button } from './Button';
import { useModels, useModelDetail } from '../hooks';
import type { ModelVersionItem } from '../api/types';
const formatDate = (dateString: string | null): string => {
if (!dateString) return 'N/A';
return new Date(dateString).toLocaleString();
};
export const Models: React.FC = () => {
const [selectedModel, setSelectedModel] = useState<ModelVersionItem | null>(null);
const { models, isLoading, activateModel, isActivating } = useModels();
const { model: modelDetail } = useModelDetail(selectedModel?.version_id ?? null);
// Build chart data from selected model's metrics
const metricsData = modelDetail ? [
{ name: 'Precision', value: (modelDetail.metrics_precision ?? 0) * 100 },
{ name: 'Recall', value: (modelDetail.metrics_recall ?? 0) * 100 },
{ name: 'mAP', value: (modelDetail.metrics_mAP ?? 0) * 100 },
] : [
{ name: 'Precision', value: 0 },
{ name: 'Recall', value: 0 },
{ name: 'mAP', value: 0 },
];
// Build comparison chart from all models (with placeholder if empty)
const chartData = models.length > 0
? models.slice(0, 4).map(m => ({
name: m.version,
value: (m.metrics_mAP ?? 0) * 100,
}))
: [
{ name: 'Model A', value: 0 },
{ name: 'Model B', value: 0 },
{ name: 'Model C', value: 0 },
{ name: 'Model D', value: 0 },
];
return (
<div className="p-8 max-w-7xl mx-auto flex gap-8">
{/* Left: Job History */}
<div className="flex-1">
<h2 className="text-2xl font-bold text-warm-text-primary mb-6">Models & History</h2>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Model Versions</h3>
{isLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="animate-spin text-warm-text-muted" size={32} />
</div>
) : models.length === 0 ? (
<div className="text-center py-12 text-warm-text-muted">
No model versions found. Complete a training task to create a model version.
</div>
) : (
<div className="space-y-4">
{models.map(model => (
<div
key={model.version_id}
onClick={() => setSelectedModel(model)}
className={`bg-warm-card border rounded-lg p-5 shadow-sm cursor-pointer transition-colors ${
selectedModel?.version_id === model.version_id
? 'border-warm-text-secondary'
: 'border-warm-border hover:border-warm-divider'
}`}
>
<div className="flex justify-between items-start mb-2">
<div>
<h4 className="font-semibold text-warm-text-primary text-lg mb-1">
{model.name}
{model.is_active && <CheckCircle size={16} className="inline ml-2 text-warm-state-info" />}
</h4>
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
</div>
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
model.is_active
? 'bg-warm-state-info/10 text-warm-state-info'
: 'bg-warm-selected text-warm-state-success'
}`}>
{model.is_active ? 'Active' : model.status}
</span>
</div>
<div className="mt-4 flex gap-8">
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Documents</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.document_count}</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">mAP</span>
<span className="text-lg font-mono text-warm-text-secondary">
{model.metrics_mAP ? `${(model.metrics_mAP * 100).toFixed(1)}%` : 'N/A'}
</span>
</div>
<div>
<span className="block text-xs text-warm-text-muted uppercase tracking-wide">Version</span>
<span className="text-lg font-mono text-warm-text-secondary">{model.version}</span>
</div>
</div>
</div>
))}
</div>
)}
</div>
{/* Right: Model Detail */}
<div className="w-[400px]">
<div className="bg-warm-card border border-warm-border rounded-lg p-6 shadow-card sticky top-8">
<div className="flex justify-between items-center mb-6">
<h3 className="text-xl font-bold text-warm-text-primary">Model Detail</h3>
<span className={`text-sm font-medium ${
selectedModel?.is_active ? 'text-warm-state-info' : 'text-warm-state-success'
}`}>
{selectedModel ? (selectedModel.is_active ? 'Active' : selectedModel.status) : '-'}
</span>
</div>
<div className="mb-8">
<p className="text-sm text-warm-text-muted mb-1">Model name</p>
<p className="font-medium text-warm-text-primary">
{selectedModel ? `${selectedModel.name} (${selectedModel.version})` : 'Select a model'}
</p>
</div>
<div className="space-y-8">
{/* Chart 1 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Model Comparison (mAP)</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={chartData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
contentStyle={{borderRadius: '8px', border: '1px solid #E6E4E1', boxShadow: '0 2px 5px rgba(0,0,0,0.05)'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'mAP']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
</div>
</div>
{/* Chart 2 */}
<div>
<h4 className="text-sm font-semibold text-warm-text-secondary mb-4">Performance Metrics</h4>
<div className="h-40">
<ResponsiveContainer width="100%" height="100%">
<BarChart data={metricsData}>
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#E6E4E1" />
<XAxis dataKey="name" tick={{fontSize: 10, fill: '#6B6B6B'}} axisLine={false} tickLine={false} />
<YAxis hide domain={[0, 100]} />
<Tooltip
cursor={{fill: '#F1F0ED'}}
formatter={(value: number) => [`${value.toFixed(1)}%`, 'Score']}
/>
<Bar dataKey="value" fill="#3A3A3A" radius={[4, 4, 0, 0]} barSize={32} />
</BarChart>
</ResponsiveContainer>
</div>
</div>
</div>
<div className="mt-8 space-y-3">
{selectedModel && !selectedModel.is_active ? (
<Button
className="w-full"
onClick={() => activateModel(selectedModel.version_id)}
disabled={isActivating}
>
{isActivating ? (
<>
<Loader2 size={16} className="mr-2 animate-spin" />
Activating...
</>
) : (
<>
<Power size={16} className="mr-2" />
Activate for Inference
</>
)}
</Button>
) : (
<Button className="w-full" disabled={!selectedModel}>
{selectedModel?.is_active ? (
<>
<CheckCircle size={16} className="mr-2" />
Currently Active
</>
) : (
'Select a Model'
)}
</Button>
)}
<div className="flex gap-3">
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>View Logs</Button>
<Button variant="secondary" className="flex-1" disabled={!selectedModel}>Use as Base</Button>
</div>
</div>
</div>
</div>
</div>
);
};

View File

@@ -0,0 +1,487 @@
import React, { useState, useMemo } from 'react'
import { useQuery } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
import { Button } from './Button'
import { AugmentationConfig } from './AugmentationConfig'
import { useDatasets } from '../hooks/useDatasets'
import { useTrainingDocuments } from '../hooks/useTraining'
import { trainingApi } from '../api/endpoints'
import type { DatasetListItem } from '../api/types'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
type Tab = 'datasets' | 'create'
interface TrainingProps {
onNavigate?: (view: string, id?: string) => void
}
const STATUS_STYLES: Record<string, string> = {
ready: 'bg-warm-state-success/10 text-warm-state-success',
building: 'bg-warm-state-info/10 text-warm-state-info',
training: 'bg-warm-state-info/10 text-warm-state-info',
failed: 'bg-warm-state-error/10 text-warm-state-error',
pending: 'bg-warm-state-warning/10 text-warm-state-warning',
scheduled: 'bg-warm-state-warning/10 text-warm-state-warning',
running: 'bg-warm-state-info/10 text-warm-state-info',
}
const StatusBadge: React.FC<{ status: string; trainingStatus?: string | null }> = ({ status, trainingStatus }) => {
// If there's an active training task, show training status
const displayStatus = trainingStatus === 'running'
? 'training'
: trainingStatus === 'pending' || trainingStatus === 'scheduled'
? 'pending'
: status
return (
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${STATUS_STYLES[displayStatus] ?? 'bg-warm-border text-warm-text-muted'}`}>
{(displayStatus === 'building' || displayStatus === 'training') && <Loader2 size={12} className="mr-1 animate-spin" />}
{displayStatus === 'ready' && <Check size={12} className="mr-1" />}
{displayStatus === 'failed' && <AlertCircle size={12} className="mr-1" />}
{displayStatus}
</span>
)
}
// --- Train Dialog ---
interface TrainDialogProps {
dataset: DatasetListItem
onClose: () => void
onSubmit: (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => void
isPending: boolean
}
const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, isPending }) => {
const [name, setName] = useState(`train-${dataset.name}`)
const [epochs, setEpochs] = useState(100)
const [batchSize, setBatchSize] = useState(16)
const [baseModelType, setBaseModelType] = useState<'pretrained' | 'existing'>('pretrained')
const [baseModelVersionId, setBaseModelVersionId] = useState<string | null>(null)
const [augmentationEnabled, setAugmentationEnabled] = useState(false)
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
// Fetch available trained models (active or inactive, not archived)
const { data: modelsData } = useQuery({
queryKey: ['training', 'models', 'available'],
queryFn: () => trainingApi.getModels(),
})
// Filter out archived models - only show active/inactive models for base model selection
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
const handleSubmit = () => {
onSubmit({
name,
config: {
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
epochs,
batch_size: batchSize,
augmentation: augmentationEnabled
? (augmentationConfig as AugmentationConfigType)
: undefined,
augmentation_multiplier: augmentationEnabled ? augmentationMultiplier : undefined,
},
})
}
return (
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
<div className="bg-white rounded-lg border border-warm-border shadow-lg w-[480px] max-h-[90vh] overflow-y-auto p-6" onClick={e => e.stopPropagation()}>
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Start Training</h3>
<p className="text-sm text-warm-text-muted mb-4">
Dataset: <span className="font-medium text-warm-text-secondary">{dataset.name}</span>
{' '}({dataset.total_images} images, {dataset.total_annotations} annotations)
</p>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Task Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
{/* Base Model Selection */}
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Base Model</label>
<select
value={baseModelType === 'pretrained' ? 'pretrained' : baseModelVersionId ?? ''}
onChange={e => {
if (e.target.value === 'pretrained') {
setBaseModelType('pretrained')
setBaseModelVersionId(null)
} else {
setBaseModelType('existing')
setBaseModelVersionId(e.target.value)
}
}}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
<option value="pretrained">yolo11n.pt (Pretrained)</option>
{availableModels.map(m => (
<option key={m.version_id} value={m.version_id}>
{m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
</option>
))}
</select>
<p className="text-xs text-warm-text-muted mt-1">
{baseModelType === 'pretrained'
? 'Start from pretrained YOLO model'
: 'Continue training from an existing model (incremental training)'}
</p>
</div>
<div className="flex gap-4">
<div className="flex-1">
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
<input
id="train-epochs"
type="number"
min={1}
max={1000}
value={epochs}
onChange={e => setEpochs(Math.max(1, Math.min(1000, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
<div className="flex-1">
<label htmlFor="train-batch-size" className="block text-sm font-medium text-warm-text-secondary mb-1">Batch Size</label>
<input
id="train-batch-size"
type="number"
min={1}
max={128}
value={batchSize}
onChange={e => setBatchSize(Math.max(1, Math.min(128, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
</div>
</div>
{/* Augmentation Configuration */}
<AugmentationConfig
enabled={augmentationEnabled}
onEnabledChange={setAugmentationEnabled}
config={augmentationConfig}
onConfigChange={setAugmentationConfig}
/>
{/* Augmentation Multiplier - only shown when augmentation is enabled */}
{augmentationEnabled && (
<div>
<label htmlFor="aug-multiplier" className="block text-sm font-medium text-warm-text-secondary mb-1">
Augmentation Multiplier
</label>
<input
id="aug-multiplier"
type="number"
min={1}
max={10}
value={augmentationMultiplier}
onChange={e => setAugmentationMultiplier(Math.max(1, Math.min(10, Number(e.target.value) || 1)))}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
/>
<p className="text-xs text-warm-text-muted mt-1">
Number of augmented copies per original image (1-10)
</p>
</div>
)}
</div>
<div className="flex justify-end gap-3 mt-6">
<Button variant="secondary" onClick={onClose} disabled={isPending}>Cancel</Button>
<Button onClick={handleSubmit} disabled={isPending || !name.trim()}>
{isPending ? <><Loader2 size={14} className="mr-1 animate-spin" />Training...</> : 'Start Training'}
</Button>
</div>
</div>
</div>
)
}
// --- Dataset List ---
const DatasetList: React.FC<{
onNavigate?: (view: string, id?: string) => void
onSwitchTab: (tab: Tab) => void
}> = ({ onNavigate, onSwitchTab }) => {
const { datasets, isLoading, deleteDataset, isDeleting, trainFromDataset, isTraining } = useDatasets()
const [trainTarget, setTrainTarget] = useState<DatasetListItem | null>(null)
const handleTrain = (config: {
name: string
config: {
model_name?: string
base_model_version_id?: string | null
epochs: number
batch_size: number
augmentation?: AugmentationConfigType
augmentation_multiplier?: number
}
}) => {
if (!trainTarget) return
// Pass config to the training API
const trainRequest = {
name: config.name,
config: config.config,
}
trainFromDataset(
{ datasetId: trainTarget.dataset_id, req: trainRequest },
{ onSuccess: () => setTrainTarget(null) },
)
}
if (isLoading) {
return <div className="flex items-center justify-center py-20 text-warm-text-muted"><Loader2 size={24} className="animate-spin mr-2" />Loading datasets...</div>
}
if (datasets.length === 0) {
return (
<div className="flex flex-col items-center justify-center py-20 text-warm-text-muted">
<Database size={48} className="mb-4 opacity-40" />
<p className="text-lg mb-2">No datasets yet</p>
<p className="text-sm mb-4">Create a dataset to start training</p>
<Button onClick={() => onSwitchTab('create')}><Plus size={14} className="mr-1" />Create Dataset</Button>
</div>
)
}
return (
<>
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Name</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Docs</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Images</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Created</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
</tr>
</thead>
<tbody>
{datasets.map(ds => (
<tr key={ds.dataset_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-medium text-warm-text-secondary">{ds.name}</td>
<td className="py-3 px-4"><StatusBadge status={ds.status} trainingStatus={ds.training_status} /></td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_documents}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_images}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{ds.total_annotations}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(ds.created_at).toLocaleDateString()}</td>
<td className="py-3 px-4">
<div className="flex gap-1">
<button title="View" onClick={() => onNavigate?.('dataset-detail', ds.dataset_id)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-info transition-colors">
<Eye size={14} />
</button>
{ds.status === 'ready' && (
<button title="Train" onClick={() => setTrainTarget(ds)}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors">
<Play size={14} />
</button>
)}
<button title="Delete" onClick={() => deleteDataset(ds.dataset_id)}
disabled={isDeleting || ds.status === 'pending' || ds.status === 'building'}
className={`p-1.5 rounded transition-colors ${
ds.status === 'pending' || ds.status === 'building'
? 'text-warm-text-muted/40 cursor-not-allowed'
: 'hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error'
}`}>
<Trash2 size={14} />
</button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
{trainTarget && (
<TrainDialog dataset={trainTarget} onClose={() => setTrainTarget(null)} onSubmit={handleTrain} isPending={isTraining} />
)}
</>
)
}
// --- Create Dataset ---
const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitchTab }) => {
const { documents, isLoading: isLoadingDocs } = useTrainingDocuments({ has_annotations: true })
const { createDatasetAsync, isCreating } = useDatasets()
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set())
const [name, setName] = useState('')
const [description, setDescription] = useState('')
const [trainRatio, setTrainRatio] = useState(0.7)
const [valRatio, setValRatio] = useState(0.2)
const testRatio = useMemo(() => Math.max(0, +(1 - trainRatio - valRatio).toFixed(2)), [trainRatio, valRatio])
const toggleDoc = (id: string) => {
setSelectedIds(prev => {
const next = new Set(prev)
if (next.has(id)) { next.delete(id) } else { next.add(id) }
return next
})
}
const toggleAll = () => {
if (selectedIds.size === documents.length) {
setSelectedIds(new Set())
} else {
setSelectedIds(new Set(documents.map((d) => d.document_id)))
}
}
const handleCreate = async () => {
await createDatasetAsync({
name,
description: description || undefined,
document_ids: [...selectedIds],
train_ratio: trainRatio,
val_ratio: valRatio,
})
onSwitchTab('datasets')
}
return (
<div className="flex gap-8">
{/* Document selection */}
<div className="flex-1 flex flex-col">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Select Documents</h3>
{isLoadingDocs ? (
<div className="flex items-center justify-center py-12 text-warm-text-muted"><Loader2 size={20} className="animate-spin mr-2" />Loading...</div>
) : (
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm flex-1">
<div className="overflow-auto max-h-[calc(100vh-240px)]">
<table className="w-full text-left">
<thead className="sticky top-0 bg-white border-b border-warm-border z-10">
<tr>
<th className="py-3 pl-6 pr-4 w-12">
<input type="checkbox" checked={selectedIds.size === documents.length && documents.length > 0}
onChange={toggleAll} className="rounded border-warm-divider accent-warm-state-info" />
</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Pages</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Annotations</th>
</tr>
</thead>
<tbody>
{documents.map((doc) => (
<tr key={doc.document_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors cursor-pointer"
onClick={() => toggleDoc(doc.document_id)}>
<td className="py-3 pl-6 pr-4">
<input type="checkbox" checked={selectedIds.has(doc.document_id)} readOnly
className="rounded border-warm-divider accent-warm-state-info pointer-events-none" />
</td>
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{doc.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.page_count}</td>
<td className="py-3 px-4 text-sm text-warm-text-muted font-mono">{doc.annotation_count ?? 0}</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
)}
<p className="text-sm text-warm-text-muted mt-2">{selectedIds.size} of {documents.length} documents selected</p>
</div>
{/* Config panel */}
<div className="w-80">
<div className="bg-warm-card rounded-lg border border-warm-border shadow-card p-6 sticky top-8">
<h3 className="text-lg font-semibold text-warm-text-primary mb-4">Dataset Configuration</h3>
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Name</label>
<input type="text" value={name} onChange={e => setName(e.target.value)} placeholder="e.g. invoice-dataset-v1"
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Description</label>
<textarea value={description} onChange={e => setDescription(e.target.value)} rows={2} placeholder="Optional"
className="w-full px-3 py-2 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info resize-none" />
</div>
<div>
<label className="block text-sm font-medium text-warm-text-secondary mb-1">Train / Val / Test Split</label>
<div className="flex gap-2 text-sm">
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Train</span>
<input type="number" step={0.05} min={0.1} max={0.9} value={trainRatio} onChange={e => setTrainRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Val</span>
<input type="number" step={0.05} min={0} max={0.5} value={valRatio} onChange={e => setValRatio(Number(e.target.value))}
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-white text-warm-text-primary text-center font-mono focus:outline-none focus:ring-1 focus:ring-warm-state-info" />
</div>
<div className="flex-1">
<span className="text-xs text-warm-text-muted">Test</span>
<input type="number" value={testRatio} readOnly
className="w-full h-9 px-2 rounded-md border border-warm-divider bg-warm-hover text-warm-text-muted text-center font-mono" />
</div>
</div>
</div>
<div className="pt-4 border-t border-warm-border">
{selectedIds.size > 0 && selectedIds.size < 10 && (
<p className="text-xs text-warm-state-warning mb-2">
Minimum 10 documents required for training ({selectedIds.size}/10 selected)
</p>
)}
<Button className="w-full h-11" onClick={handleCreate}
disabled={isCreating || selectedIds.size < 10 || !name.trim()}>
{isCreating ? <><Loader2 size={14} className="mr-1 animate-spin" />Creating...</> : <><Plus size={14} className="mr-1" />Create Dataset</>}
</Button>
</div>
</div>
</div>
</div>
</div>
)
}
// --- Main Training Component ---
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
const [activeTab, setActiveTab] = useState<Tab>('datasets')
return (
<div className="p-8 max-w-7xl mx-auto">
<div className="flex items-center justify-between mb-6">
<h2 className="text-2xl font-bold text-warm-text-primary">Training</h2>
</div>
{/* Tabs */}
<div className="flex gap-1 mb-6 border-b border-warm-border">
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
<button key={key} onClick={() => setActiveTab(key)}
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
activeTab === key
? 'border-warm-state-info text-warm-state-info'
: 'border-transparent text-warm-text-muted hover:text-warm-text-secondary'
}`}>
{label}
</button>
))}
</div>
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
</div>
)
}

View File

@@ -0,0 +1,276 @@
import React, { useState, useRef } from 'react'
import { X, UploadCloud, File, CheckCircle, AlertCircle, ChevronDown } from 'lucide-react'
import { Button } from './Button'
import { useDocuments, useCategories } from '../hooks/useDocuments'
interface UploadModalProps {
isOpen: boolean
onClose: () => void
}
export const UploadModal: React.FC<UploadModalProps> = ({ isOpen, onClose }) => {
const [isDragging, setIsDragging] = useState(false)
const [selectedFiles, setSelectedFiles] = useState<File[]>([])
const [groupKey, setGroupKey] = useState('')
const [category, setCategory] = useState('invoice')
const [uploadStatus, setUploadStatus] = useState<'idle' | 'uploading' | 'success' | 'error'>('idle')
const [errorMessage, setErrorMessage] = useState('')
const fileInputRef = useRef<HTMLInputElement>(null)
const { uploadDocument, isUploading } = useDocuments({})
const { categories } = useCategories()
if (!isOpen) return null
const handleFileSelect = (files: FileList | null) => {
if (!files) return
const pdfFiles = Array.from(files).filter(file => {
const isPdf = file.type === 'application/pdf'
const isImage = file.type.startsWith('image/')
const isUnder25MB = file.size <= 25 * 1024 * 1024
return (isPdf || isImage) && isUnder25MB
})
setSelectedFiles(prev => [...prev, ...pdfFiles])
setUploadStatus('idle')
setErrorMessage('')
}
const handleDrop = (e: React.DragEvent) => {
e.preventDefault()
setIsDragging(false)
handleFileSelect(e.dataTransfer.files)
}
const handleBrowseClick = () => {
fileInputRef.current?.click()
}
const removeFile = (index: number) => {
setSelectedFiles(prev => prev.filter((_, i) => i !== index))
}
const handleUpload = async () => {
if (selectedFiles.length === 0) {
setErrorMessage('Please select at least one file')
return
}
setUploadStatus('uploading')
setErrorMessage('')
try {
// Upload files one by one
for (const file of selectedFiles) {
await new Promise<void>((resolve, reject) => {
uploadDocument(
{ file, groupKey: groupKey || undefined, category: category || 'invoice' },
{
onSuccess: () => resolve(),
onError: (error: Error) => reject(error),
}
)
})
}
setUploadStatus('success')
setTimeout(() => {
onClose()
setSelectedFiles([])
setGroupKey('')
setCategory('invoice')
setUploadStatus('idle')
}, 1500)
} catch (error) {
setUploadStatus('error')
setErrorMessage(error instanceof Error ? error.message : 'Upload failed')
}
}
const handleClose = () => {
if (uploadStatus === 'uploading') {
return // Prevent closing during upload
}
setSelectedFiles([])
setGroupKey('')
setCategory('invoice')
setUploadStatus('idle')
setErrorMessage('')
onClose()
}
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/20 backdrop-blur-sm transition-opacity duration-200">
<div
className="w-full max-w-lg bg-warm-card rounded-lg shadow-modal border border-warm-border transform transition-all duration-200 scale-100 p-6"
onClick={(e) => e.stopPropagation()}
>
<div className="flex items-center justify-between mb-6">
<h3 className="text-xl font-semibold text-warm-text-primary">Upload Documents</h3>
<button
onClick={handleClose}
className="text-warm-text-muted hover:text-warm-text-primary transition-colors disabled:opacity-50"
disabled={uploadStatus === 'uploading'}
>
<X size={20} />
</button>
</div>
{/* Drop Zone */}
<div
className={`
w-full h-48 rounded-lg border-2 border-dashed flex flex-col items-center justify-center gap-3 transition-colors duration-150 mb-6 cursor-pointer
${isDragging ? 'border-warm-text-secondary bg-warm-selected' : 'border-warm-divider bg-warm-bg hover:bg-warm-hover'}
${uploadStatus === 'uploading' ? 'opacity-50 pointer-events-none' : ''}
`}
onDragOver={(e) => { e.preventDefault(); setIsDragging(true); }}
onDragLeave={() => setIsDragging(false)}
onDrop={handleDrop}
onClick={handleBrowseClick}
>
<div className="p-3 bg-white rounded-full shadow-sm">
<UploadCloud size={24} className="text-warm-text-secondary" />
</div>
<div className="text-center">
<p className="text-sm font-medium text-warm-text-primary">
Drag & drop files here or <span className="underline decoration-1 underline-offset-2 hover:text-warm-state-info">Browse</span>
</p>
<p className="text-xs text-warm-text-muted mt-1">PDF, JPG, PNG up to 25MB</p>
</div>
</div>
<input
ref={fileInputRef}
type="file"
multiple
accept=".pdf,image/*"
className="hidden"
onChange={(e) => handleFileSelect(e.target.files)}
/>
{/* Selected Files */}
{selectedFiles.length > 0 && (
<div className="mb-6 max-h-40 overflow-y-auto">
<p className="text-sm font-medium text-warm-text-secondary mb-2">
Selected Files ({selectedFiles.length})
</p>
<div className="space-y-2">
{selectedFiles.map((file, index) => (
<div
key={index}
className="flex items-center justify-between p-2 bg-warm-bg rounded border border-warm-border"
>
<div className="flex items-center gap-2 flex-1 min-w-0">
<File size={16} className="text-warm-text-muted flex-shrink-0" />
<span className="text-sm text-warm-text-secondary truncate">
{file.name}
</span>
<span className="text-xs text-warm-text-muted flex-shrink-0">
({(file.size / 1024 / 1024).toFixed(2)} MB)
</span>
</div>
<button
onClick={() => removeFile(index)}
className="text-warm-text-muted hover:text-warm-state-error ml-2 flex-shrink-0"
disabled={uploadStatus === 'uploading'}
>
<X size={16} />
</button>
</div>
))}
</div>
</div>
)}
{/* Category Select */}
{selectedFiles.length > 0 && (
<div className="mb-6">
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
Category
</label>
<div className="relative">
<select
value={category}
onChange={(e) => setCategory(e.target.value)}
className="w-full h-10 pl-3 pr-8 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info appearance-none cursor-pointer"
disabled={uploadStatus === 'uploading'}
>
<option value="invoice">Invoice</option>
<option value="letter">Letter</option>
<option value="receipt">Receipt</option>
<option value="contract">Contract</option>
{categories
.filter((cat) => !['invoice', 'letter', 'receipt', 'contract'].includes(cat))
.map((cat) => (
<option key={cat} value={cat}>
{cat.charAt(0).toUpperCase() + cat.slice(1)}
</option>
))}
</select>
<ChevronDown
className="absolute right-2.5 top-1/2 -translate-y-1/2 pointer-events-none text-warm-text-muted"
size={14}
/>
</div>
<p className="text-xs text-warm-text-muted mt-1">
Select document type for training different models
</p>
</div>
)}
{/* Group Key Input */}
{selectedFiles.length > 0 && (
<div className="mb-6">
<label className="block text-sm font-medium text-warm-text-secondary mb-2">
Group Key (optional)
</label>
<input
type="text"
value={groupKey}
onChange={(e) => setGroupKey(e.target.value)}
placeholder="e.g., 2024-Q1, supplier-abc, project-name"
className="w-full px-3 h-10 rounded-md border border-warm-border bg-white text-sm text-warm-text-secondary focus:outline-none focus:ring-1 focus:ring-warm-state-info transition-shadow"
disabled={uploadStatus === 'uploading'}
/>
<p className="text-xs text-warm-text-muted mt-1">
Use group keys to organize documents into logical groups
</p>
</div>
)}
{/* Status Messages */}
{uploadStatus === 'success' && (
<div className="mb-4 p-3 bg-green-50 border border-green-200 rounded flex items-center gap-2">
<CheckCircle size={16} className="text-green-600" />
<span className="text-sm text-green-800">Upload successful!</span>
</div>
)}
{uploadStatus === 'error' && errorMessage && (
<div className="mb-4 p-3 bg-red-50 border border-red-200 rounded flex items-center gap-2">
<AlertCircle size={16} className="text-red-600" />
<span className="text-sm text-red-800">{errorMessage}</span>
</div>
)}
{/* Actions */}
<div className="mt-8 flex justify-end gap-3">
<Button
variant="secondary"
onClick={handleClose}
disabled={uploadStatus === 'uploading'}
>
Cancel
</Button>
<Button
onClick={handleUpload}
disabled={selectedFiles.length === 0 || uploadStatus === 'uploading'}
>
{uploadStatus === 'uploading' ? 'Uploading...' : `Upload ${selectedFiles.length > 0 ? `(${selectedFiles.length})` : ''}`}
</Button>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,7 @@
export { useDocuments, useCategories } from './useDocuments'
export { useDocumentDetail } from './useDocumentDetail'
export { useAnnotations } from './useAnnotations'
export { useTraining, useTrainingDocuments } from './useTraining'
export { useDatasets, useDatasetDetail } from './useDatasets'
export { useAugmentation } from './useAugmentation'
export { useModels, useModelDetail, useActiveModel } from './useModels'

View File

@@ -0,0 +1,70 @@
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { annotationsApi } from '../api/endpoints'
import type { CreateAnnotationRequest, AnnotationOverrideRequest } from '../api/types'
export const useAnnotations = (documentId: string) => {
const queryClient = useQueryClient()
const createMutation = useMutation({
mutationFn: (annotation: CreateAnnotationRequest) =>
annotationsApi.create(documentId, annotation),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const updateMutation = useMutation({
mutationFn: ({
annotationId,
updates,
}: {
annotationId: string
updates: Partial<CreateAnnotationRequest>
}) => annotationsApi.update(documentId, annotationId, updates),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const deleteMutation = useMutation({
mutationFn: (annotationId: string) =>
annotationsApi.delete(documentId, annotationId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const verifyMutation = useMutation({
mutationFn: (annotationId: string) =>
annotationsApi.verify(documentId, annotationId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
const overrideMutation = useMutation({
mutationFn: ({
annotationId,
overrideData,
}: {
annotationId: string
overrideData: AnnotationOverrideRequest
}) => annotationsApi.override(documentId, annotationId, overrideData),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['document', documentId] })
},
})
return {
createAnnotation: createMutation.mutate,
isCreating: createMutation.isPending,
updateAnnotation: updateMutation.mutate,
isUpdating: updateMutation.isPending,
deleteAnnotation: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
verifyAnnotation: verifyMutation.mutate,
isVerifying: verifyMutation.isPending,
overrideAnnotation: overrideMutation.mutate,
isOverriding: overrideMutation.isPending,
}
}

View File

@@ -0,0 +1,226 @@
/**
* Tests for useAugmentation hook.
*
* TDD Phase 1: RED - Write tests first, then implement to pass.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { augmentationApi } from '../api/endpoints/augmentation'
import { useAugmentation } from './useAugmentation'
import type { ReactNode } from 'react'
// Mock the API
vi.mock('../api/endpoints/augmentation', () => ({
augmentationApi: {
getTypes: vi.fn(),
getPresets: vi.fn(),
preview: vi.fn(),
previewConfig: vi.fn(),
createBatch: vi.fn(),
},
}))
// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return ({ children }: { children: ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
describe('useAugmentation', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('getTypes', () => {
it('should fetch augmentation types', async () => {
const mockTypes = {
augmentation_types: [
{
name: 'gaussian_noise',
description: 'Adds Gaussian noise',
affects_geometry: false,
stage: 'noise',
default_params: { mean: 0, std: 15 },
},
{
name: 'perspective_warp',
description: 'Applies perspective warp',
affects_geometry: true,
stage: 'geometric',
default_params: { max_warp: 0.02 },
},
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce(mockTypes)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.augmentationTypes).toHaveLength(2)
expect(result.current.augmentationTypes[0].name).toBe('gaussian_noise')
})
it('should handle error when fetching types', async () => {
vi.mocked(augmentationApi.getTypes).mockRejectedValueOnce(new Error('Network error'))
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.typesError).toBeTruthy()
})
})
describe('getPresets', () => {
it('should fetch augmentation presets', async () => {
const mockPresets = {
presets: [
{ name: 'conservative', description: 'Safe augmentations' },
{ name: 'moderate', description: 'Balanced augmentations' },
{ name: 'aggressive', description: 'Strong augmentations' },
],
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce(mockPresets)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingPresets).toBe(false)
})
expect(result.current.presets).toHaveLength(3)
expect(result.current.presets[0].name).toBe('conservative')
})
})
describe('preview', () => {
it('should preview single augmentation', async () => {
const mockPreview = {
preview_url: '',
original_url: '',
applied_params: { std: 15 },
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockResolvedValueOnce(mockPreview)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
// Call preview mutation
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: { std: 15 },
page: 1,
})
await waitFor(() => {
expect(augmentationApi.preview).toHaveBeenCalledWith(
'doc-123',
{ augmentation_type: 'gaussian_noise', params: { std: 15 } },
1
)
})
})
it('should track preview loading state', async () => {
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.preview).mockImplementation(
() => new Promise((resolve) => setTimeout(resolve, 100))
)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
expect(result.current.isPreviewing).toBe(false)
result.current.preview({
documentId: 'doc-123',
augmentationType: 'gaussian_noise',
params: {},
page: 1,
})
// State update happens asynchronously
await waitFor(() => {
expect(result.current.isPreviewing).toBe(true)
})
})
})
describe('createBatch', () => {
it('should create augmented dataset', async () => {
const mockResponse = {
task_id: 'task-123',
status: 'pending',
message: 'Augmentation task queued',
estimated_images: 100,
}
vi.mocked(augmentationApi.getTypes).mockResolvedValueOnce({ augmentation_types: [] })
vi.mocked(augmentationApi.getPresets).mockResolvedValueOnce({ presets: [] })
vi.mocked(augmentationApi.createBatch).mockResolvedValueOnce(mockResponse)
const { result } = renderHook(() => useAugmentation(), {
wrapper: createWrapper(),
})
await waitFor(() => {
expect(result.current.isLoadingTypes).toBe(false)
})
result.current.createBatch({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
await waitFor(() => {
expect(augmentationApi.createBatch).toHaveBeenCalledWith({
dataset_id: 'dataset-123',
config: {
gaussian_noise: { enabled: true, probability: 0.5, params: {} },
},
output_name: 'augmented-dataset',
multiplier: 2,
})
})
})
})
})

View File

@@ -0,0 +1,121 @@
/**
* Hook for managing augmentation operations.
*
* Provides functions for fetching augmentation types, presets, and previewing augmentations.
*/
import { useQuery, useMutation } from '@tanstack/react-query'
import {
augmentationApi,
type AugmentationTypesResponse,
type PresetsResponse,
type PreviewResponse,
type BatchRequest,
type BatchResponse,
type AugmentationConfig,
} from '../api/endpoints/augmentation'
interface PreviewParams {
documentId: string
augmentationType: string
params: Record<string, unknown>
page?: number
}
interface PreviewConfigParams {
documentId: string
config: AugmentationConfig
page?: number
}
export const useAugmentation = () => {
// Fetch augmentation types
const {
data: typesData,
isLoading: isLoadingTypes,
error: typesError,
} = useQuery<AugmentationTypesResponse>({
queryKey: ['augmentation', 'types'],
queryFn: () => augmentationApi.getTypes(),
staleTime: 5 * 60 * 1000, // Cache for 5 minutes
})
// Fetch presets
const {
data: presetsData,
isLoading: isLoadingPresets,
error: presetsError,
} = useQuery<PresetsResponse>({
queryKey: ['augmentation', 'presets'],
queryFn: () => augmentationApi.getPresets(),
staleTime: 5 * 60 * 1000,
})
// Preview single augmentation mutation
const previewMutation = useMutation<PreviewResponse, Error, PreviewParams>({
mutationFn: ({ documentId, augmentationType, params, page = 1 }) =>
augmentationApi.preview(
documentId,
{ augmentation_type: augmentationType, params },
page
),
onError: (error) => {
console.error('Preview augmentation failed:', error)
},
})
// Preview full config mutation
const previewConfigMutation = useMutation<PreviewResponse, Error, PreviewConfigParams>({
mutationFn: ({ documentId, config, page = 1 }) =>
augmentationApi.previewConfig(documentId, config, page),
onError: (error) => {
console.error('Preview config failed:', error)
},
})
// Create augmented dataset mutation
const createBatchMutation = useMutation<BatchResponse, Error, BatchRequest>({
mutationFn: (request) => augmentationApi.createBatch(request),
onError: (error) => {
console.error('Create augmented dataset failed:', error)
},
})
return {
// Types data
augmentationTypes: typesData?.augmentation_types || [],
isLoadingTypes,
typesError,
// Presets data
presets: presetsData?.presets || [],
isLoadingPresets,
presetsError,
// Preview single augmentation
preview: previewMutation.mutate,
previewAsync: previewMutation.mutateAsync,
isPreviewing: previewMutation.isPending,
previewData: previewMutation.data,
previewError: previewMutation.error,
// Preview full config
previewConfig: previewConfigMutation.mutate,
previewConfigAsync: previewConfigMutation.mutateAsync,
isPreviewingConfig: previewConfigMutation.isPending,
previewConfigData: previewConfigMutation.data,
previewConfigError: previewConfigMutation.error,
// Create batch
createBatch: createBatchMutation.mutate,
createBatchAsync: createBatchMutation.mutateAsync,
isCreatingBatch: createBatchMutation.isPending,
batchData: createBatchMutation.data,
batchError: createBatchMutation.error,
// Reset functions for clearing stale mutation state
resetPreview: previewMutation.reset,
resetPreviewConfig: previewConfigMutation.reset,
resetBatch: createBatchMutation.reset,
}
}

View File

@@ -0,0 +1,84 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { datasetsApi } from '../api/endpoints'
import type {
DatasetCreateRequest,
DatasetDetailResponse,
DatasetListResponse,
DatasetTrainRequest,
} from '../api/types'
export const useDatasets = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<DatasetListResponse>({
queryKey: ['datasets', params],
queryFn: () => datasetsApi.list(params),
staleTime: 30000,
// Poll every 5 seconds when there's an active training task
refetchInterval: (query) => {
const datasets = query.state.data?.datasets ?? []
const hasActiveTraining = datasets.some(
d => d.training_status === 'running' || d.training_status === 'pending' || d.training_status === 'scheduled'
)
return hasActiveTraining ? 5000 : false
},
})
const createMutation = useMutation({
mutationFn: (req: DatasetCreateRequest) => datasetsApi.create(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const deleteMutation = useMutation({
mutationFn: (datasetId: string) => datasetsApi.remove(datasetId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
},
})
const trainMutation = useMutation({
mutationFn: ({ datasetId, req }: { datasetId: string; req: DatasetTrainRequest }) =>
datasetsApi.trainFromDataset(datasetId, req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['datasets'] })
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
return {
datasets: data?.datasets ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
createDataset: createMutation.mutate,
createDatasetAsync: createMutation.mutateAsync,
isCreating: createMutation.isPending,
deleteDataset: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
trainFromDataset: trainMutation.mutate,
trainFromDatasetAsync: trainMutation.mutateAsync,
isTraining: trainMutation.isPending,
}
}
export const useDatasetDetail = (datasetId: string | null) => {
const { data, isLoading, error } = useQuery<DatasetDetailResponse>({
queryKey: ['datasets', datasetId],
queryFn: () => datasetsApi.getDetail(datasetId!),
enabled: !!datasetId,
staleTime: 30000,
})
return {
dataset: data ?? null,
isLoading,
error,
}
}

View File

@@ -0,0 +1,25 @@
import { useQuery } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentDetailResponse } from '../api/types'
export const useDocumentDetail = (documentId: string | null) => {
const { data, isLoading, error, refetch } = useQuery<DocumentDetailResponse>({
queryKey: ['document', documentId],
queryFn: () => {
if (!documentId) {
throw new Error('Document ID is required')
}
return documentsApi.getDetail(documentId)
},
enabled: !!documentId,
staleTime: 10000,
})
return {
document: data || null,
annotations: data?.annotations || [],
isLoading,
error,
refetch,
}
}

View File

@@ -0,0 +1,120 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { documentsApi } from '../api/endpoints'
import type { DocumentListResponse, DocumentCategoriesResponse } from '../api/types'
interface UseDocumentsParams {
status?: string
category?: string
limit?: number
offset?: number
}
export const useDocuments = (params: UseDocumentsParams = {}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<DocumentListResponse>({
queryKey: ['documents', params],
queryFn: () => documentsApi.list(params),
staleTime: 30000,
})
const uploadMutation = useMutation({
mutationFn: ({ file, groupKey, category }: { file: File; groupKey?: string; category?: string }) =>
documentsApi.upload(file, { groupKey, category }),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
queryClient.invalidateQueries({ queryKey: ['categories'] })
},
})
const updateGroupKeyMutation = useMutation({
mutationFn: ({ documentId, groupKey }: { documentId: string; groupKey: string | null }) =>
documentsApi.updateGroupKey(documentId, groupKey),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const batchUploadMutation = useMutation({
mutationFn: ({ files, csvFile }: { files: File[]; csvFile?: File }) =>
documentsApi.batchUpload(files, csvFile),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const deleteMutation = useMutation({
mutationFn: (documentId: string) => documentsApi.delete(documentId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const updateStatusMutation = useMutation({
mutationFn: ({ documentId, status }: { documentId: string; status: string }) =>
documentsApi.updateStatus(documentId, status),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const triggerAutoLabelMutation = useMutation({
mutationFn: (documentId: string) => documentsApi.triggerAutoLabel(documentId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
},
})
const updateCategoryMutation = useMutation({
mutationFn: ({ documentId, category }: { documentId: string; category: string }) =>
documentsApi.updateCategory(documentId, category),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['documents'] })
queryClient.invalidateQueries({ queryKey: ['categories'] })
},
})
return {
documents: data?.documents || [],
total: data?.total || 0,
limit: data?.limit || params.limit || 20,
offset: data?.offset || params.offset || 0,
isLoading,
error,
refetch,
uploadDocument: uploadMutation.mutate,
uploadDocumentAsync: uploadMutation.mutateAsync,
isUploading: uploadMutation.isPending,
batchUpload: batchUploadMutation.mutate,
batchUploadAsync: batchUploadMutation.mutateAsync,
isBatchUploading: batchUploadMutation.isPending,
deleteDocument: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
updateStatus: updateStatusMutation.mutate,
isUpdatingStatus: updateStatusMutation.isPending,
triggerAutoLabel: triggerAutoLabelMutation.mutate,
isTriggeringAutoLabel: triggerAutoLabelMutation.isPending,
updateGroupKey: updateGroupKeyMutation.mutate,
updateGroupKeyAsync: updateGroupKeyMutation.mutateAsync,
isUpdatingGroupKey: updateGroupKeyMutation.isPending,
updateCategory: updateCategoryMutation.mutate,
updateCategoryAsync: updateCategoryMutation.mutateAsync,
isUpdatingCategory: updateCategoryMutation.isPending,
}
}
export const useCategories = () => {
const { data, isLoading, error, refetch } = useQuery<DocumentCategoriesResponse>({
queryKey: ['categories'],
queryFn: () => documentsApi.getCategories(),
staleTime: 60000,
})
return {
categories: data?.categories || [],
total: data?.total || 0,
isLoading,
error,
refetch,
}
}

View File

@@ -0,0 +1,98 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { modelsApi } from '../api/endpoints'
import type {
ModelVersionListResponse,
ModelVersionDetailResponse,
ActiveModelResponse,
} from '../api/types'
export const useModels = (params?: {
status?: string
limit?: number
offset?: number
}) => {
const queryClient = useQueryClient()
const { data, isLoading, error, refetch } = useQuery<ModelVersionListResponse>({
queryKey: ['models', params],
queryFn: () => modelsApi.list(params),
staleTime: 30000,
})
const activateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.activate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const deactivateMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.deactivate(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
queryClient.invalidateQueries({ queryKey: ['models', 'active'] })
},
})
const archiveMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.archive(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
const deleteMutation = useMutation({
mutationFn: (versionId: string) => modelsApi.delete(versionId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['models'] })
},
})
return {
models: data?.models ?? [],
total: data?.total ?? 0,
isLoading,
error,
refetch,
activateModel: activateMutation.mutate,
activateModelAsync: activateMutation.mutateAsync,
isActivating: activateMutation.isPending,
deactivateModel: deactivateMutation.mutate,
isDeactivating: deactivateMutation.isPending,
archiveModel: archiveMutation.mutate,
isArchiving: archiveMutation.isPending,
deleteModel: deleteMutation.mutate,
isDeleting: deleteMutation.isPending,
}
}
export const useModelDetail = (versionId: string | null) => {
const { data, isLoading, error } = useQuery<ModelVersionDetailResponse>({
queryKey: ['models', versionId],
queryFn: () => modelsApi.getDetail(versionId!),
enabled: !!versionId,
staleTime: 30000,
})
return {
model: data ?? null,
isLoading,
error,
}
}
export const useActiveModel = () => {
const { data, isLoading, error } = useQuery<ActiveModelResponse>({
queryKey: ['models', 'active'],
queryFn: () => modelsApi.getActive(),
staleTime: 30000,
})
return {
hasActiveModel: data?.has_active_model ?? false,
activeModel: data?.model ?? null,
isLoading,
error,
}
}

View File

@@ -0,0 +1,83 @@
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { trainingApi } from '../api/endpoints'
import type { TrainingModelsResponse } from '../api/types'
export const useTraining = () => {
const queryClient = useQueryClient()
const { data: modelsData, isLoading: isLoadingModels } =
useQuery<TrainingModelsResponse>({
queryKey: ['training', 'models'],
queryFn: () => trainingApi.getModels(),
staleTime: 30000,
})
const startTrainingMutation = useMutation({
mutationFn: (config: {
name: string
description?: string
document_ids: string[]
epochs?: number
batch_size?: number
model_base?: string
}) => trainingApi.startTraining(config),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
const cancelTaskMutation = useMutation({
mutationFn: (taskId: string) => trainingApi.cancelTask(taskId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['training', 'models'] })
},
})
const downloadModelMutation = useMutation({
mutationFn: (taskId: string) => trainingApi.downloadModel(taskId),
onSuccess: (blob, taskId) => {
const url = window.URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `model-${taskId}.pt`
document.body.appendChild(a)
a.click()
window.URL.revokeObjectURL(url)
document.body.removeChild(a)
},
})
return {
models: modelsData?.models || [],
total: modelsData?.total || 0,
isLoadingModels,
startTraining: startTrainingMutation.mutate,
startTrainingAsync: startTrainingMutation.mutateAsync,
isStartingTraining: startTrainingMutation.isPending,
cancelTask: cancelTaskMutation.mutate,
isCancelling: cancelTaskMutation.isPending,
downloadModel: downloadModelMutation.mutate,
isDownloading: downloadModelMutation.isPending,
}
}
export const useTrainingDocuments = (params?: {
has_annotations?: boolean
min_annotation_count?: number
exclude_used_in_training?: boolean
limit?: number
offset?: number
}) => {
const { data, isLoading, error } = useQuery({
queryKey: ['training', 'documents', params],
queryFn: () => trainingApi.getDocumentsForTraining(params),
staleTime: 30000,
})
return {
documents: data?.documents || [],
total: data?.total || 0,
isLoading,
error,
}
}

23
frontend/src/main.tsx Normal file
View File

@@ -0,0 +1,23 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import App from './App'
import './styles/index.css'
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: 1,
refetchOnWindowFocus: false,
staleTime: 30000,
},
},
})
ReactDOM.createRoot(document.getElementById('root')!).render(
<React.StrictMode>
<QueryClientProvider client={queryClient}>
<App />
</QueryClientProvider>
</React.StrictMode>
)

View File

@@ -0,0 +1,26 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
body {
@apply bg-warm-bg text-warm-text-primary;
}
/* Custom scrollbar */
::-webkit-scrollbar {
@apply w-2 h-2;
}
::-webkit-scrollbar-track {
@apply bg-transparent;
}
::-webkit-scrollbar-thumb {
@apply bg-warm-divider rounded;
}
::-webkit-scrollbar-thumb:hover {
@apply bg-warm-text-disabled;
}
}

View File

@@ -0,0 +1,48 @@
// Legacy types for backward compatibility with old components
// These will be gradually replaced with API types
export enum DocumentStatus {
PENDING = 'Pending',
LABELED = 'Labeled',
VERIFIED = 'Verified',
PARTIAL = 'Partial'
}
export interface Document {
id: string
name: string
date: string
status: DocumentStatus
exported: boolean
autoLabelProgress?: number
autoLabelStatus?: 'Running' | 'Completed' | 'Failed'
}
export interface Annotation {
id: string
text: string
label: string
x: number
y: number
width: number
height: number
isAuto?: boolean
}
export interface TrainingJob {
id: string
name: string
startDate: string
status: 'Running' | 'Completed' | 'Failed'
progress: number
metrics?: {
accuracy: number
precision: number
recall: number
}
}
export interface ModelMetric {
name: string
value: number
}

View File

@@ -0,0 +1,47 @@
export default {
content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
theme: {
extend: {
fontFamily: {
sans: ['Inter', 'SF Pro', 'system-ui', 'sans-serif'],
mono: ['JetBrains Mono', 'SF Mono', 'monospace'],
},
colors: {
warm: {
bg: '#FAFAF8',
card: '#FFFFFF',
hover: '#F1F0ED',
selected: '#ECEAE6',
border: '#E6E4E1',
divider: '#D8D6D2',
text: {
primary: '#121212',
secondary: '#2A2A2A',
muted: '#6B6B6B',
disabled: '#9A9A9A',
},
state: {
success: '#3E4A3A',
error: '#4A3A3A',
warning: '#4A4A3A',
info: '#3A3A3A',
}
}
},
boxShadow: {
'card': '0 1px 3px rgba(0,0,0,0.08)',
'modal': '0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06)',
},
animation: {
'fade-in': 'fadeIn 0.3s ease-out',
},
keyframes: {
fadeIn: {
'0%': { opacity: '0', transform: 'translateY(10px)' },
'100%': { opacity: '1', transform: 'translateY(0)' },
}
}
}
},
plugins: [],
}

1
frontend/tests/setup.ts Normal file
View File

@@ -0,0 +1 @@
import '@testing-library/jest-dom';

29
frontend/tsconfig.json Normal file
View File

@@ -0,0 +1,29 @@
{
"compilerOptions": {
"target": "ES2022",
"experimentalDecorators": true,
"useDefineForClassFields": false,
"module": "ESNext",
"lib": [
"ES2022",
"DOM",
"DOM.Iterable"
],
"skipLibCheck": true,
"types": [
"node"
],
"moduleResolution": "bundler",
"isolatedModules": true,
"moduleDetection": "force",
"allowJs": true,
"jsx": "react-jsx",
"paths": {
"@/*": [
"./*"
]
},
"allowImportingTsExtensions": true,
"noEmit": true
}
}

16
frontend/vite.config.ts Normal file
View File

@@ -0,0 +1,16 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
export default defineConfig({
server: {
port: 3000,
host: '0.0.0.0',
proxy: {
'/api': {
target: 'http://localhost:8000',
changeOrigin: true,
},
},
},
plugins: [react()],
});

19
frontend/vitest.config.ts Normal file
View File

@@ -0,0 +1,19 @@
/// <reference types="vitest/config" />
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
export default defineConfig({
plugins: [react()],
test: {
globals: true,
environment: 'jsdom',
setupFiles: ['./tests/setup.ts'],
include: ['src/**/*.test.{ts,tsx}', 'tests/**/*.test.{ts,tsx}'],
coverage: {
provider: 'v8',
reporter: ['text', 'lcov'],
include: ['src/**/*.{ts,tsx}'],
exclude: ['src/**/*.test.{ts,tsx}', 'src/main.tsx'],
},
},
});

View File

@@ -0,0 +1,18 @@
-- Training tasks table for async training job management.
-- Inference service writes pending tasks; training service polls and executes.
CREATE TABLE IF NOT EXISTS training_tasks (
task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
status VARCHAR(20) NOT NULL DEFAULT 'pending',
config JSONB,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
scheduled_at TIMESTAMP WITH TIME ZONE,
started_at TIMESTAMP WITH TIME ZONE,
completed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT,
model_path TEXT,
metrics JSONB
);
CREATE INDEX IF NOT EXISTS idx_training_tasks_status ON training_tasks(status);
CREATE INDEX IF NOT EXISTS idx_training_tasks_created ON training_tasks(created_at);

View File

@@ -0,0 +1,39 @@
-- Training Datasets Management
-- Tracks dataset-document relationships and train/val/test splits
CREATE TABLE IF NOT EXISTS training_datasets (
dataset_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(255) NOT NULL,
description TEXT,
status VARCHAR(20) NOT NULL DEFAULT 'building',
train_ratio FLOAT NOT NULL DEFAULT 0.8,
val_ratio FLOAT NOT NULL DEFAULT 0.1,
seed INTEGER NOT NULL DEFAULT 42,
total_documents INTEGER NOT NULL DEFAULT 0,
total_images INTEGER NOT NULL DEFAULT 0,
total_annotations INTEGER NOT NULL DEFAULT 0,
dataset_path VARCHAR(512),
error_message TEXT,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_training_datasets_status ON training_datasets(status);
CREATE TABLE IF NOT EXISTS dataset_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
dataset_id UUID NOT NULL REFERENCES training_datasets(dataset_id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES admin_documents(document_id),
split VARCHAR(10) NOT NULL,
page_count INTEGER NOT NULL DEFAULT 0,
annotation_count INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
UNIQUE(dataset_id, document_id)
);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_dataset ON dataset_documents(dataset_id);
CREATE INDEX IF NOT EXISTS idx_dataset_documents_document ON dataset_documents(document_id);
-- Add dataset_id to training_tasks
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS dataset_id UUID REFERENCES training_datasets(dataset_id);
CREATE INDEX IF NOT EXISTS idx_training_tasks_dataset ON training_tasks(dataset_id);

View File

@@ -0,0 +1,8 @@
-- Add group_key column to admin_documents
-- Allows users to organize documents into logical groups
-- Add the column (nullable, VARCHAR 255)
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS group_key VARCHAR(255);
-- Add index for filtering/grouping queries
CREATE INDEX IF NOT EXISTS ix_admin_documents_group_key ON admin_documents(group_key);

View File

@@ -0,0 +1,49 @@
-- Model versions table for tracking trained model deployments.
-- Each training run can produce a model version for inference.
CREATE TABLE IF NOT EXISTS model_versions (
version_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
version VARCHAR(50) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
model_path VARCHAR(512) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'inactive',
is_active BOOLEAN NOT NULL DEFAULT FALSE,
-- Training association
task_id UUID REFERENCES training_tasks(task_id) ON DELETE SET NULL,
dataset_id UUID REFERENCES training_datasets(dataset_id) ON DELETE SET NULL,
-- Training metrics
metrics_mAP DOUBLE PRECISION,
metrics_precision DOUBLE PRECISION,
metrics_recall DOUBLE PRECISION,
document_count INTEGER NOT NULL DEFAULT 0,
-- Training configuration snapshot
training_config JSONB,
-- File info
file_size BIGINT,
-- Timestamps
trained_at TIMESTAMP WITH TIME ZONE,
activated_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_model_versions_version ON model_versions(version);
CREATE INDEX IF NOT EXISTS idx_model_versions_status ON model_versions(status);
CREATE INDEX IF NOT EXISTS idx_model_versions_is_active ON model_versions(is_active);
CREATE INDEX IF NOT EXISTS idx_model_versions_task_id ON model_versions(task_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_dataset_id ON model_versions(dataset_id);
CREATE INDEX IF NOT EXISTS idx_model_versions_created ON model_versions(created_at);
-- Ensure only one active model at a time
CREATE UNIQUE INDEX IF NOT EXISTS idx_model_versions_single_active
ON model_versions(is_active) WHERE is_active = TRUE;
-- Comment
COMMENT ON TABLE model_versions IS 'Trained model versions for inference deployment';

View File

@@ -0,0 +1,46 @@
-- Add missing columns to training_tasks table
-- Add name column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS name VARCHAR(255);
UPDATE training_tasks SET name = 'Training ' || substring(task_id::text, 1, 8) WHERE name IS NULL;
ALTER TABLE training_tasks ALTER COLUMN name SET NOT NULL;
-- Add description column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS description TEXT;
-- Add admin_token column (for multi-tenant support)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS admin_token VARCHAR(255);
-- Add task_type column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS task_type VARCHAR(20) DEFAULT 'train';
-- Add recurring schedule columns
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS cron_expression VARCHAR(50);
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT FALSE;
-- Add result metrics columns (for display without parsing JSONB)
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS result_metrics JSONB;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS document_count INTEGER DEFAULT 0;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_mAP DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_precision DOUBLE PRECISION;
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS metrics_recall DOUBLE PRECISION;
-- Rename metrics to config if exists
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM information_schema.columns
WHERE table_name = 'training_tasks' AND column_name = 'metrics'
AND NOT EXISTS (SELECT 1 FROM information_schema.columns
WHERE table_name = 'training_tasks' AND column_name = 'config')) THEN
ALTER TABLE training_tasks RENAME COLUMN metrics TO config;
END IF;
END $$;
-- Add updated_at column
ALTER TABLE training_tasks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
-- Create index on name
CREATE INDEX IF NOT EXISTS idx_training_tasks_name ON training_tasks(name);
-- Create index on metrics_mAP
CREATE INDEX IF NOT EXISTS idx_training_tasks_mAP ON training_tasks(metrics_mAP);

View File

@@ -0,0 +1,14 @@
-- Fix foreign key constraints on model_versions table to allow CASCADE delete
-- Drop existing constraints
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_dataset_id_fkey;
ALTER TABLE model_versions DROP CONSTRAINT IF EXISTS model_versions_task_id_fkey;
-- Add constraints with ON DELETE SET NULL
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_dataset_id_fkey
FOREIGN KEY (dataset_id) REFERENCES training_datasets(dataset_id) ON DELETE SET NULL;
ALTER TABLE model_versions
ADD CONSTRAINT model_versions_task_id_fkey
FOREIGN KEY (task_id) REFERENCES training_tasks(task_id) ON DELETE SET NULL;

View File

@@ -0,0 +1,13 @@
-- Add category column to admin_documents table
-- Allows categorizing documents for training different models (e.g., invoice, letter, receipt)
ALTER TABLE admin_documents ADD COLUMN IF NOT EXISTS category VARCHAR(100) DEFAULT 'invoice';
-- Update existing NULL values to default
UPDATE admin_documents SET category = 'invoice' WHERE category IS NULL;
-- Make it NOT NULL after setting defaults
ALTER TABLE admin_documents ALTER COLUMN category SET NOT NULL;
-- Create index for category filtering
CREATE INDEX IF NOT EXISTS idx_admin_documents_category ON admin_documents(category);

View File

@@ -0,0 +1,28 @@
-- Migration: Add training_status and active_training_task_id to training_datasets
-- Description: Track training status separately from dataset build status
-- Add training_status column
ALTER TABLE training_datasets
ADD COLUMN IF NOT EXISTS training_status VARCHAR(20) DEFAULT NULL;
-- Add active_training_task_id column
ALTER TABLE training_datasets
ADD COLUMN IF NOT EXISTS active_training_task_id UUID DEFAULT NULL;
-- Create index for training_status
CREATE INDEX IF NOT EXISTS idx_training_datasets_training_status
ON training_datasets(training_status);
-- Create index for active_training_task_id
CREATE INDEX IF NOT EXISTS idx_training_datasets_active_training_task_id
ON training_datasets(active_training_task_id);
-- Update existing datasets that have been used in completed training tasks to 'trained' status
UPDATE training_datasets d
SET status = 'trained'
WHERE d.status = 'ready'
AND EXISTS (
SELECT 1 FROM training_tasks t
WHERE t.dataset_id = d.dataset_id
AND t.status = 'completed'
);

View File

@@ -0,0 +1,25 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install inference package
COPY packages/inference /app/packages/inference
RUN pip install --no-cache-dir -e /app/packages/inference
# Copy frontend (if needed)
COPY frontend /app/frontend
WORKDIR /app/packages/inference
EXPOSE 8000
CMD ["python", "run_server.py", "--host", "0.0.0.0", "--port", "8000"]

View File

View File

@@ -0,0 +1,105 @@
"""Trigger training jobs on Azure Container Instances."""
import logging
import os
logger = logging.getLogger(__name__)
# Azure SDK is optional; only needed if using ACI trigger
try:
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
EnvironmentVariable,
GpuResource,
ResourceRequests,
ResourceRequirements,
)
_AZURE_SDK_AVAILABLE = True
except ImportError:
_AZURE_SDK_AVAILABLE = False
def start_training_container(task_id: str) -> str | None:
"""
Start an Azure Container Instance for a training task.
Returns the container group name if successful, None otherwise.
Requires environment variables:
AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP, AZURE_ACR_IMAGE
"""
if not _AZURE_SDK_AVAILABLE:
logger.warning(
"Azure SDK not installed. Install azure-mgmt-containerinstance "
"and azure-identity to use ACI trigger."
)
return None
subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
resource_group = os.environ.get("AZURE_RESOURCE_GROUP", "invoice-training-rg")
image = os.environ.get(
"AZURE_ACR_IMAGE", "youracr.azurecr.io/invoice-training:latest"
)
gpu_sku = os.environ.get("AZURE_GPU_SKU", "V100")
location = os.environ.get("AZURE_LOCATION", "eastus")
if not subscription_id:
logger.error("AZURE_SUBSCRIPTION_ID not set. Cannot start ACI.")
return None
credential = DefaultAzureCredential()
client = ContainerInstanceManagementClient(credential, subscription_id)
container_name = f"training-{task_id[:8]}"
env_vars = [
EnvironmentVariable(name="TASK_ID", value=task_id),
]
# Pass DB connection securely
for var in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER"):
val = os.environ.get(var, "")
if val:
env_vars.append(EnvironmentVariable(name=var, value=val))
db_password = os.environ.get("DB_PASSWORD", "")
if db_password:
env_vars.append(
EnvironmentVariable(name="DB_PASSWORD", secure_value=db_password)
)
container = Container(
name=container_name,
image=image,
resources=ResourceRequirements(
requests=ResourceRequests(
cpu=4,
memory_in_gb=16,
gpu=GpuResource(count=1, sku=gpu_sku),
)
),
environment_variables=env_vars,
command=[
"python",
"run_training.py",
"--task-id",
task_id,
],
)
group = ContainerGroup(
location=location,
containers=[container],
os_type="Linux",
restart_policy="Never",
)
logger.info("Creating ACI container group: %s", container_name)
client.container_groups.begin_create_or_update(
resource_group, container_name, group
)
return container_name

Some files were not shown because too many files have changed in this diff Show More