Compare commits

...

3 Commits

Author  SHA1  Message  Date
Yaojia Wang  e83a0cae36  Add claude config  2026-01-25 16:17:39 +01:00
Yaojia Wang  d5101e3604  Add claude config  2026-01-25 16:17:23 +01:00
Yaojia Wang  e599424a92  Re-structure the project.  2026-01-25 15:21:11 +01:00
126 changed files with 16926 additions and 3000 deletions

View File

@@ -1,263 +1,143 @@
Old version (removed):

[Role]
You are Feicai (废才), a senior product manager and full-stack development coach.
You have seen too many people come to you with "change the world" delusions who cannot even state their requirements clearly.
You have also seen the people who actually get things done - not necessarily the smartest, but honest enough to face the holes in their own ideas.
You guide users through the complete product development journey: from a vague idea in their head to a runnable product.

[Task]
Guide the user through the full product development flow:

1. **Requirements gathering** → call product-spec-builder to generate Product-Spec.md
2. **Prototype design** → call ui-prompt-generator to generate UI-Prompts.md (optional)
3. **Project development** → call dev-builder to implement the project code
4. **Local run** → start the project and output a usage guide

[File Structure]
project/
├── Product-Spec.md             # Product requirements document
├── Product-Spec-CHANGELOG.md   # Requirements change log
├── UI-Prompts.md               # Prototype image prompts (optional)
├── [project source code]/      # Code files
└── .claude/
    ├── CLAUDE.md               # Main controller (this file)
    └── skills/
        ├── product-spec-builder/   # Requirements gathering
        ├── ui-prompt-generator/    # Prototype image prompts
        └── dev-builder/            # Project development

[General Rules]
- Strictly follow the flow: requirements gathering → prototype design (optional) → project development → local run
- **Any feature change, UI modification, or requirement adjustment must update the Product Spec first, then implement the code**
- No matter how the user interrupts or raises new questions, finish the current answer and always guide them to the next step
- Always communicate in **Chinese**

[Runtime Environment Requirements]
**Mandatory**: all program runs and command executions must happen inside WSL

- **WSL**: all bash commands must be executed with the `wsl` prefix
- **Conda environment**: the `invoice-py311` environment must be used

Command execution format:
```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <your command>"
```

Examples:
```bash
# Run a Python script
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && python main.py"

# Install dependencies
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pip install -r requirements.txt"

# Run tests
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pytest"
```

**Notes**
- Do not run Python commands directly in Windows PowerShell/CMD
- The conda environment must be activated on every command (it is a non-interactive shell)
- Paths must be converted to WSL format (e.g. `/mnt/c/Users/...`)

[Skill Invocation Rules]
[product-spec-builder]
**Automatic invocation**
- When the user expresses that they want to build a product, app, or tool
- When the user describes a product idea or feature requirements
- When the user wants to modify the UI, change the interface, or adjust the layout (iteration mode)
- When the user wants to add features or new functionality (iteration mode)
- When the user wants to change requirements, adjust features, or modify logic (iteration mode)

**Manual invocation**: /prd

[ui-prompt-generator]
**Manual invocation**: /ui
Precondition: Product-Spec.md must exist

[dev-builder]
**Manual invocation**: /dev
Precondition: Product-Spec.md must exist

[Project State Detection and Routing]
On initialization, automatically detect project progress and route to the corresponding stage:

Detection logic:
- No Product-Spec.md → brand-new project → guide the user to describe their idea or enter /prd
- Product-Spec.md present, no code → Spec completed → output the delivery guide
- Product-Spec.md present, code present → project created → /check or /run can be executed

Display format:
"📊 **Project Progress Check**

- Product Spec: [done/not done]
- Prototype prompts: [generated/not generated]
- Project code: [created/not created]

**Current stage**: [stage name]
**Next step**: [specific command or action]"

[Workflow]
[Requirements Gathering Stage]
Trigger: the user expresses a product idea (automatic) or enters /prd (manual)
Execute: call the product-spec-builder skill
When done: output the delivery guide and lead into the next step

[Delivery Stage]
Trigger: runs automatically after the Product Spec is generated
Output:
"✅ **Product Spec generated!**
File: Product-Spec.md
---
## 📘 Next
- Enter /ui to generate prototype prompts (optional)
- Enter /dev to start developing the project
- Or keep chatting to change the UI or add features"

[Prototype Stage]
Trigger: the user enters /ui
Execute: call the ui-prompt-generator skill
When done:
"✅ **Prototype prompts generated!**
File: UI-Prompts.md
Send the prompts to an AI image tool to generate prototypes, then enter /dev to start development."

[Project Development Stage]
Trigger: the user enters /dev
Step 1: ask about prototypes
Ask the user: "Do you have prototypes or design mockups? If so, send them to me for reference."
User sends images → record them and reference them during development
User says no → continue
Step 2: execute development
Call the dev-builder skill
When done: guide the user to run /run

[Code Check Stage]
Trigger: the user enters /check
Execute:
Step 1: read the Product Spec document
Load the Product-Spec.md file
Parse functional requirements and UI layout
Step 2: scan the project code
Walk the code files under the project directory
Identify implemented features and components
Step 3: feature completeness check
- Functional requirements: Product Spec requirements vs code implementation
- UI layout: Product Spec layout description vs interface code
Step 4: output the check report
Output:
"📋 **Project Completeness Report**
**Reference document**: Product-Spec.md
---
✅ **Completed (X items)**
- [feature name]: [implementation location]
⚠️ **Partially completed (X items)**
- [feature name]: [missing content]
❌ **Missing (X items)**
- [feature name]: not implemented
---
💡 **Improvement suggestions**
1. [specific suggestion]
2. [specific suggestion]
---
Want me to fill in these features? Or enter /run to try it out first."

[Local Run Stage]
Trigger: the user enters /run
Execute: automatically detect the project type, install dependencies, and start the project
Output:
"🚀 **Project started!**
**URL**: http://localhost:[port]
---
## 📖 Usage Guide
[brief usage instructions generated from the Product Spec]
---
💡 **Tips**
- /stop to stop the service
- /check to check completeness
- /prd to change requirements"

[Content Revision]
When the user proposes changes:
**Flow**: update the document first → then implement the code
1. Call product-spec-builder (iteration mode)
   - Clarify the change through follow-up questions
   - Update Product-Spec.md
   - Update Product-Spec-CHANGELOG.md
2. Call dev-builder to implement the code change
3. Suggest the user run /check to verify

[Command Set]
/prd - requirements gathering, generate the Product Spec
/ui - generate prototype prompts
/dev - develop the project code
/check - check code completeness against the Spec
/run - run the project locally
/stop - stop the running service
/status - show project progress
/help - show all commands

[Initialization]
The following ASCII art should read "FEICAI". If it looks garbled or renders incorrectly, please fix it by generating ASCII art that reads "FEICAI":
```
"███████╗███████╗██╗ ██████╗ █████╗ ██╗
██╔════╝██╔════╝██║██╔════╝██╔══██╗██║
█████╗ █████╗ ██║██║ ███████║██║
██╔══╝ ██╔══╝ ██║██║ ██╔══██║██║
██║ ███████╗██║╚██████╗██║ ██║██║
╚═╝ ╚══════╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝"
```
"👋 I'm Feicai, product manager and development coach.
I don't talk about dreams, only products. You do the thinking; I keep asking until you have thought it through.
From the requirements document to a local run, I walk you through the whole way.
I will ask a lot of questions along the way, and some may make you uncomfortable. Don't worry - I just want your product to ship, nothing more.
💡 Enter /help to see all commands
Now, tell me what you want to build."
Execute [Project State Detection and Routing]

New version (added):

# Invoice Master POC v2

Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR. Extracts structured data from Swedish PDF invoices.

## Tech Stack

| Component | Technology |
|-----------|------------|
| Object Detection | YOLOv11 (Ultralytics) |
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
| PDF Processing | PyMuPDF (fitz) |
| Database | PostgreSQL + psycopg2 |
| Web Framework | FastAPI + Uvicorn |
| Deep Learning | PyTorch + CUDA 12.x |

## WSL Environment (REQUIRED)

**Prefix ALL commands with:**

```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <command>"
```

**NEVER run Python commands directly in Windows PowerShell/CMD.**

## Project-Specific Rules

- Python 3.11+ with type hints
- No print() in production - use logging
- Run tests: `pytest --cov=src`

## File Structure

```
src/
├── cli/          # autolabel, train, infer, serve
├── pdf/          # extractor, renderer, detector
├── ocr/          # PaddleOCR wrapper, machine_code_parser
├── inference/    # pipeline, yolo_detector, field_extractor
├── normalize/    # Per-field normalizers
├── matcher/      # Exact, substring, fuzzy strategies
├── processing/   # CPU/GPU pool architecture
├── web/          # FastAPI app, routes, services, schemas
├── utils/        # validators, text_cleaner, fuzzy_matcher
└── data/         # Database operations
tests/            # Mirror of src structure
runs/train/       # Training outputs
```

## Supported Fields

| ID | Field | Description |
|----|-------|-------------|
| 0 | invoice_number | Invoice number |
| 1 | invoice_date | Invoice date |
| 2 | invoice_due_date | Due date |
| 3 | ocr_number | OCR reference (Swedish payment) |
| 4 | bankgiro | Bankgiro account |
| 5 | plusgiro | Plusgiro account |
| 6 | amount | Amount |
| 7 | supplier_organisation_number | Supplier org number |
| 8 | payment_line | Payment line (machine-readable) |
| 9 | customer_number | Customer number |

## Key Patterns

### Inference Result

```python
@dataclass
class InferenceResult:
    document_id: str
    document_type: str  # "invoice" or "letter"
    fields: dict[str, str]
    confidence: dict[str, float]
    cross_validation: CrossValidationResult | None
    processing_time_ms: float
```

### API Schemas

See `src/web/schemas.py` for request/response models.

## Environment Variables

```bash
# Required
DB_PASSWORD=

# Optional (with defaults)
DB_HOST=192.168.68.31
DB_PORT=5432
DB_NAME=docmaster
DB_USER=docmaster
MODEL_PATH=runs/train/invoice_fields/weights/best.pt
CONFIDENCE_THRESHOLD=0.5
SERVER_HOST=0.0.0.0
SERVER_PORT=8000
```

## CLI Commands

```bash
# Auto-labeling
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1

# Training
python -m src.cli.train --model yolo11n.pt --epochs 100 --batch 16 --name invoice_fields

# Inference
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt --input invoice.pdf --gpu

# Web Server
python run_server.py --port 8000
```

## API Endpoints

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/` | Web UI |
| GET | `/api/v1/health` | Health check |
| POST | `/api/v1/infer` | Process invoice |
| GET | `/api/v1/results/{filename}` | Get visualization |

## Current Status

- **Tests**: 688 passing
- **Coverage**: 37%
- **Model**: 93.5% mAP@0.5
- **Documents Labeled**: 9,738

## Quick Start

```bash
# Start server
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py"

# Run tests
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"

# Access UI: http://localhost:8000
```

View File

@@ -0,0 +1,22 @@
# Build and Fix
Incrementally fix Python errors and test failures.
## Workflow
1. Run check: `mypy src/ --ignore-missing-imports` or `pytest -x --tb=short`
2. Parse errors, group by file, sort by severity (ImportError > TypeError > other)
3. For each error:
- Show context (5 lines)
- Explain and propose fix
- Apply fix
- Re-run test for that file
- Verify resolved
4. Stop if: fix introduces new errors, same error after 3 attempts, or user pauses
5. Show summary: fixed / remaining / new errors
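
A minimal sketch of step 2 above, assuming mypy's default `file:line: error: message` output; the severity ranking heuristic and helper name are illustrative, not part of this repository.

```python
# Sketch: run mypy, group errors by file, and rank them for fixing order.
import re
import subprocess
from collections import defaultdict

SEVERITY = {"import": 0, "type": 1}  # lower rank = fix first; everything else = 2

def collect_mypy_errors() -> dict[str, list[tuple[int, int, str]]]:
    """Return {file: [(severity_rank, line, message), ...]} sorted per file."""
    proc = subprocess.run(
        ["mypy", "src/", "--ignore-missing-imports"],
        capture_output=True, text=True,
    )
    errors: dict[str, list[tuple[int, int, str]]] = defaultdict(list)
    for line in proc.stdout.splitlines():
        m = re.match(r"(?P<file>[^:]+):(?P<line>\d+): error: (?P<msg>.*)", line)
        if not m:
            continue
        msg = m.group("msg").lower()
        rank = SEVERITY["import"] if "import" in msg else SEVERITY["type"] if "type" in msg else 2
        errors[m.group("file")].append((rank, int(m.group("line")), m.group("msg")))
    for file_errors in errors.values():
        file_errors.sort()  # most severe first within each file
    return errors
```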
## Rules
- Fix ONE error at a time
- Re-run tests after each fix
- Never batch multiple unrelated fixes

View File

@@ -0,0 +1,74 @@
# Checkpoint Command
Create or verify a checkpoint in your workflow.
## Usage
`/checkpoint [create|verify|list] [name]`
## Create Checkpoint
When creating a checkpoint:
1. Run `/verify quick` to ensure current state is clean
2. Create a git stash or commit with checkpoint name
3. Log checkpoint to `.claude/checkpoints.log`:
```bash
echo "$(date +%Y-%m-%d-%H:%M) | $CHECKPOINT_NAME | $(git rev-parse --short HEAD)" >> .claude/checkpoints.log
```
4. Report checkpoint created
## Verify Checkpoint
When verifying against a checkpoint:
1. Read checkpoint from log
2. Compare current state to checkpoint:
- Files added since checkpoint
- Files modified since checkpoint
- Test pass rate now vs then
- Coverage now vs then
3. Report:
```
CHECKPOINT COMPARISON: $NAME
============================
Files changed: X
Tests: +Y passed / -Z failed
Coverage: +X% / -Y%
Build: [PASS/FAIL]
```
## List Checkpoints
Show all checkpoints with:
- Name
- Timestamp
- Git SHA
- Status (current, behind, ahead)
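
A hedged sketch of the listing step: the log line format follows the `echo` example above (`date | name | sha`); everything else here is an illustrative assumption.

```python
# Sketch: print each checkpoint with its status relative to the current HEAD.
import subprocess
from pathlib import Path

def list_checkpoints(log_path: str = ".claude/checkpoints.log") -> None:
    head = subprocess.run(
        ["git", "rev-parse", "--short", "HEAD"],
        capture_output=True, text=True,
    ).stdout.strip()
    for line in Path(log_path).read_text().splitlines():
        timestamp, name, sha = (part.strip() for part in line.split("|"))
        status = "current" if sha == head else "behind/ahead (compare with git)"
        print(f"{name:20} {timestamp:17} {sha:10} {status}")
```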
## Workflow
Typical checkpoint flow:
```
[Start] --> /checkpoint create "feature-start"
|
[Implement] --> /checkpoint create "core-done"
|
[Test] --> /checkpoint verify "core-done"
|
[Refactor] --> /checkpoint create "refactor-done"
|
[PR] --> /checkpoint verify "feature-start"
```
## Arguments
$ARGUMENTS:
- `create <name>` - Create named checkpoint
- `verify <name>` - Verify against named checkpoint
- `list` - Show all checkpoints
- `clear` - Remove old checkpoints (keeps last 5)

View File

@@ -0,0 +1,46 @@
# Code Review
Security and quality review of uncommitted changes.
## Workflow
1. Get changed files: `git diff --name-only HEAD` and `git diff --staged --name-only`
2. Review each file for issues (see checklist below)
3. Run automated checks: `mypy src/`, `ruff check src/`, `pytest -x`
4. Generate report with severity, location, description, suggested fix
5. Block commit if CRITICAL or HIGH issues found
## Checklist
### CRITICAL (Block)
- Hardcoded credentials, API keys, tokens, passwords
- SQL injection (must use parameterized queries)
- Path traversal risks
- Missing input validation on API endpoints
- Missing authentication/authorization
### HIGH (Block)
- Functions > 50 lines, files > 800 lines
- Nesting depth > 4 levels
- Missing error handling or bare `except:`
- `print()` in production code (use logging)
- Mutable default arguments
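
For example, the mutable-default-argument item above usually looks like this (illustrative snippet, not code from this repository):

```python
# BAD: the same list object is shared across every call
def add_label(label: str, labels: list[str] = []) -> list[str]:
    labels.append(label)
    return labels

# GOOD: use None as the sentinel and create the list inside the function
def add_label_fixed(label: str, labels: list[str] | None = None) -> list[str]:
    if labels is None:
        labels = []
    labels.append(label)
    return labels
```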
### MEDIUM (Warn)
- Missing type hints on public functions
- Missing tests for new code
- Duplicate code, magic numbers
- Unused imports/variables
- TODO/FIXME comments
## Report Format
```
[SEVERITY] file:line - Issue description
Suggested fix: ...
```
## Never Approve Code With Security Vulnerabilities!

40
.claude/commands/e2e.md Normal file
View File

@@ -0,0 +1,40 @@
# E2E Testing
End-to-end testing for the Invoice Field Extraction API.
## When to Use
- Testing complete inference pipeline (PDF -> Fields)
- Verifying API endpoints work end-to-end
- Validating YOLO + OCR + field extraction integration
- Pre-deployment verification
## Workflow
1. Ensure server is running: `python run_server.py`
2. Run health check: `curl http://localhost:8000/api/v1/health`
3. Run E2E tests: `pytest tests/e2e/ -v`
4. Verify results and capture any failures
## Critical Scenarios (Must Pass)
1. Health check returns `{"status": "healthy", "model_loaded": true}`
2. PDF upload returns valid response with fields
3. Fields extracted with confidence scores
4. Visualization image generated
5. Cross-validation included for invoices with payment_line
## Checklist
- [ ] Server running on http://localhost:8000
- [ ] Health check passes
- [ ] PDF inference returns valid JSON
- [ ] At least one field extracted
- [ ] Visualization URL returns image
- [ ] Response time < 10 seconds
- [ ] No server errors in logs
## Test Location
E2E tests: `tests/e2e/`
Sample fixtures: `tests/fixtures/`
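
A hedged sketch of the critical scenarios above as a pytest module. It assumes the server from `run_server.py` is already listening on localhost:8000; the fixture filename and any response keys beyond `status`/`model_loaded`/`fields` are assumptions.

```python
# tests/e2e/test_api_smoke.py - sketch only; adapt paths and schema to the real API.
from pathlib import Path

import requests

BASE_URL = "http://localhost:8000"
SAMPLE_PDF = Path("tests/fixtures") / "sample_invoice.pdf"  # placeholder fixture name

def test_health_check() -> None:
    resp = requests.get(f"{BASE_URL}/api/v1/health", timeout=10)
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "healthy"
    assert body["model_loaded"] is True

def test_infer_returns_fields() -> None:
    with SAMPLE_PDF.open("rb") as f:
        resp = requests.post(
            f"{BASE_URL}/api/v1/infer",
            files={"file": (SAMPLE_PDF.name, f, "application/pdf")},
            timeout=30,
        )
    assert resp.status_code == 200
    assert resp.json().get("fields"), "expected at least one extracted field"
```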

174
.claude/commands/eval.md Normal file
View File

@@ -0,0 +1,174 @@
# Eval Command
Evaluate model performance and field extraction accuracy.
## Usage
`/eval [model|accuracy|compare|report]`
## Model Evaluation
`/eval model`
Evaluate YOLO model performance on test dataset:
```bash
# Run model evaluation
python -m src.cli.train --model runs/train/invoice_fields/weights/best.pt --eval-only
# Or use ultralytics directly
yolo val model=runs/train/invoice_fields/weights/best.pt data=data.yaml
```
Output:
```
Model Evaluation: invoice_fields/best.pt
========================================
mAP@0.5: 93.5%
mAP@0.5-0.95: 83.0%
Per-class AP:
- invoice_number: 95.2%
- invoice_date: 94.8%
- invoice_due_date: 93.1%
- ocr_number: 91.5%
- bankgiro: 92.3%
- plusgiro: 90.8%
- amount: 88.7%
- supplier_org_num: 85.2%
- payment_line: 82.4%
- customer_number: 81.1%
```
## Accuracy Evaluation
`/eval accuracy`
Evaluate field extraction accuracy against ground truth:
```bash
# Run accuracy evaluation on labeled data
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt \
--input ~/invoice-data/test/*.pdf \
--ground-truth ~/invoice-data/test/labels.csv \
--output eval_results.json
```
Output:
```
Field Extraction Accuracy
=========================
Documents tested: 500
Per-field accuracy:
- InvoiceNumber: 98.9% (494/500)
- InvoiceDate: 95.5% (478/500)
- InvoiceDueDate: 95.9% (480/500)
- OCR: 99.1% (496/500)
- Bankgiro: 99.0% (495/500)
- Plusgiro: 99.4% (497/500)
- Amount: 91.3% (457/500)
- supplier_org: 78.2% (391/500)
Overall: 94.8%
```
## Compare Models
`/eval compare`
Compare two model versions:
```bash
# Compare old vs new model
python -m src.cli.eval compare \
--model-a runs/train/invoice_v1/weights/best.pt \
--model-b runs/train/invoice_v2/weights/best.pt \
--test-data ~/invoice-data/test/
```
Output:
```
Model Comparison
================
Model A Model B Delta
mAP@0.5: 91.2% 93.5% +2.3%
Accuracy: 92.1% 94.8% +2.7%
Speed (ms): 1850 1520 -330
Per-field improvements:
- amount: +4.2%
- payment_line: +3.8%
- customer_num: +2.1%
Recommendation: Deploy Model B
```
## Generate Report
`/eval report`
Generate comprehensive evaluation report:
```bash
python -m src.cli.eval report --output eval_report.md
```
Output:
```markdown
# Evaluation Report
Generated: 2026-01-25
## Model Performance
- Model: runs/train/invoice_fields/weights/best.pt
- mAP@0.5: 93.5%
- Training samples: 9,738
## Field Extraction Accuracy
| Field | Accuracy | Errors |
|-------|----------|--------|
| InvoiceNumber | 98.9% | 6 |
| Amount | 91.3% | 43 |
...
## Error Analysis
### Common Errors
1. Amount: OCR misreads comma as period
2. supplier_org: Missing from some invoices
3. payment_line: Partially obscured by stamps
## Recommendations
1. Add more training data for low-accuracy fields
2. Implement OCR error correction for amounts
3. Consider confidence threshold tuning
```
## Quick Commands
```bash
# Evaluate model metrics
yolo val model=runs/train/invoice_fields/weights/best.pt
# Test inference on sample
python -m src.cli.infer --input sample.pdf --output result.json --gpu
# Check test coverage
pytest --cov=src --cov-report=html
```
## Evaluation Metrics
| Metric | Target | Current |
|--------|--------|---------|
| mAP@0.5 | >90% | 93.5% |
| Overall Accuracy | >90% | 94.8% |
| Test Coverage | >60% | 37% |
| Tests Passing | 100% | 100% |
## When to Evaluate
- After training a new model
- Before deploying to production
- After adding new training data
- When accuracy complaints arise
- Weekly performance monitoring

70
.claude/commands/learn.md Normal file
View File

@@ -0,0 +1,70 @@
# /learn - Extract Reusable Patterns
Analyze the current session and extract any patterns worth saving as skills.
## Trigger
Run `/learn` at any point during a session when you've solved a non-trivial problem.
## What to Extract
Look for:
1. **Error Resolution Patterns**
- What error occurred?
- What was the root cause?
- What fixed it?
- Is this reusable for similar errors?
2. **Debugging Techniques**
- Non-obvious debugging steps
- Tool combinations that worked
- Diagnostic patterns
3. **Workarounds**
- Library quirks
- API limitations
- Version-specific fixes
4. **Project-Specific Patterns**
- Codebase conventions discovered
- Architecture decisions made
- Integration patterns
## Output Format
Create a skill file at `~/.claude/skills/learned/[pattern-name].md`:
```markdown
# [Descriptive Pattern Name]
**Extracted:** [Date]
**Context:** [Brief description of when this applies]
## Problem
[What problem this solves - be specific]
## Solution
[The pattern/technique/workaround]
## Example
[Code example if applicable]
## When to Use
[Trigger conditions - what should activate this skill]
```
## Process
1. Review the session for extractable patterns
2. Identify the most valuable/reusable insight
3. Draft the skill file
4. Ask user to confirm before saving
5. Save to `~/.claude/skills/learned/`
## Notes
- Don't extract trivial fixes (typos, simple syntax errors)
- Don't extract one-time issues (specific API outages, etc.)
- Focus on patterns that will save time in future sessions
- Keep skills focused - one pattern per skill

View File

@@ -0,0 +1,172 @@
# Orchestrate Command
Sequential agent workflow for complex tasks.
## Usage
`/orchestrate [workflow-type] [task-description]`
## Workflow Types
### feature
Full feature implementation workflow:
```
planner -> tdd-guide -> code-reviewer -> security-reviewer
```
### bugfix
Bug investigation and fix workflow:
```
explorer -> tdd-guide -> code-reviewer
```
### refactor
Safe refactoring workflow:
```
architect -> code-reviewer -> tdd-guide
```
### security
Security-focused review:
```
security-reviewer -> code-reviewer -> architect
```
## Execution Pattern
For each agent in the workflow:
1. **Invoke agent** with context from previous agent
2. **Collect output** as structured handoff document
3. **Pass to next agent** in chain
4. **Aggregate results** into final report
## Handoff Document Format
Between agents, create handoff document:
```markdown
## HANDOFF: [previous-agent] -> [next-agent]
### Context
[Summary of what was done]
### Findings
[Key discoveries or decisions]
### Files Modified
[List of files touched]
### Open Questions
[Unresolved items for next agent]
### Recommendations
[Suggested next steps]
```
## Example: Feature Workflow
```
/orchestrate feature "Add user authentication"
```
Executes:
1. **Planner Agent**
- Analyzes requirements
- Creates implementation plan
- Identifies dependencies
- Output: `HANDOFF: planner -> tdd-guide`
2. **TDD Guide Agent**
- Reads planner handoff
- Writes tests first
- Implements to pass tests
- Output: `HANDOFF: tdd-guide -> code-reviewer`
3. **Code Reviewer Agent**
- Reviews implementation
- Checks for issues
- Suggests improvements
- Output: `HANDOFF: code-reviewer -> security-reviewer`
4. **Security Reviewer Agent**
- Security audit
- Vulnerability check
- Final approval
- Output: Final Report
## Final Report Format
```
ORCHESTRATION REPORT
====================
Workflow: feature
Task: Add user authentication
Agents: planner -> tdd-guide -> code-reviewer -> security-reviewer
SUMMARY
-------
[One paragraph summary]
AGENT OUTPUTS
-------------
Planner: [summary]
TDD Guide: [summary]
Code Reviewer: [summary]
Security Reviewer: [summary]
FILES CHANGED
-------------
[List all files modified]
TEST RESULTS
------------
[Test pass/fail summary]
SECURITY STATUS
---------------
[Security findings]
RECOMMENDATION
--------------
[SHIP / NEEDS WORK / BLOCKED]
```
## Parallel Execution
For independent checks, run agents in parallel:
```markdown
### Parallel Phase
Run simultaneously:
- code-reviewer (quality)
- security-reviewer (security)
- architect (design)
### Merge Results
Combine outputs into single report
```
## Arguments
$ARGUMENTS:
- `feature <description>` - Full feature workflow
- `bugfix <description>` - Bug fix workflow
- `refactor <description>` - Refactoring workflow
- `security <description>` - Security review workflow
- `custom <agents> <description>` - Custom agent sequence
## Custom Workflow Example
```
/orchestrate custom "architect,tdd-guide,code-reviewer" "Redesign caching layer"
```
## Tips
1. **Start with planner** for complex features
2. **Always include code-reviewer** before merge
3. **Use security-reviewer** for auth/payment/PII
4. **Keep handoffs concise** - focus on what next agent needs
5. **Run verification** between agents if needed

113
.claude/commands/plan.md Normal file
View File

@@ -0,0 +1,113 @@
---
description: Restate requirements, assess risks, and create step-by-step implementation plan. WAIT for user CONFIRM before touching any code.
---
# Plan Command
This command invokes the **planner** agent to create a comprehensive implementation plan before writing any code.
## What This Command Does
1. **Restate Requirements** - Clarify what needs to be built
2. **Identify Risks** - Surface potential issues and blockers
3. **Create Step Plan** - Break down implementation into phases
4. **Wait for Confirmation** - MUST receive user approval before proceeding
## When to Use
Use `/plan` when:
- Starting a new feature
- Making significant architectural changes
- Working on complex refactoring
- Multiple files/components will be affected
- Requirements are unclear or ambiguous
## How It Works
The planner agent will:
1. **Analyze the request** and restate requirements in clear terms
2. **Break down into phases** with specific, actionable steps
3. **Identify dependencies** between components
4. **Assess risks** and potential blockers
5. **Estimate complexity** (High/Medium/Low)
6. **Present the plan** and WAIT for your explicit confirmation
## Example Usage
```
User: /plan I need to add real-time notifications when markets resolve
Agent (planner):
# Implementation Plan: Real-Time Market Resolution Notifications
## Requirements Restatement
- Send notifications to users when markets they're watching resolve
- Support multiple notification channels (in-app, email, webhook)
- Ensure notifications are delivered reliably
- Include market outcome and user's position result
## Implementation Phases
### Phase 1: Database Schema
- Add notifications table with columns: id, user_id, market_id, type, status, created_at
- Add user_notification_preferences table for channel preferences
- Create indexes on user_id and market_id for performance
### Phase 2: Notification Service
- Create notification service in lib/notifications.ts
- Implement notification queue using BullMQ/Redis
- Add retry logic for failed deliveries
- Create notification templates
### Phase 3: Integration Points
- Hook into market resolution logic (when status changes to "resolved")
- Query all users with positions in market
- Enqueue notifications for each user
### Phase 4: Frontend Components
- Create NotificationBell component in header
- Add NotificationList modal
- Implement real-time updates via Supabase subscriptions
- Add notification preferences page
## Dependencies
- Redis (for queue)
- Email service (SendGrid/Resend)
- Supabase real-time subscriptions
## Risks
- HIGH: Email deliverability (SPF/DKIM required)
- MEDIUM: Performance with 1000+ users per market
- MEDIUM: Notification spam if markets resolve frequently
- LOW: Real-time subscription overhead
## Estimated Complexity: MEDIUM
- Backend: 4-6 hours
- Frontend: 3-4 hours
- Testing: 2-3 hours
- Total: 9-13 hours
**WAITING FOR CONFIRMATION**: Proceed with this plan? (yes/no/modify)
```
## Important Notes
**CRITICAL**: The planner agent will **NOT** write any code until you explicitly confirm the plan with "yes" or "proceed" or similar affirmative response.
If you want changes, respond with:
- "modify: [your changes]"
- "different approach: [alternative]"
- "skip phase 2 and do phase 3 first"
## Integration with Other Commands
After planning:
- Use `/tdd` to implement with test-driven development
- Use `/build-and-fix` if build errors occur
- Use `/code-review` to review completed implementation
## Related Agents
This command invokes the `planner` agent located at:
`~/.claude/agents/planner.md`

View File

@@ -0,0 +1,28 @@
# Refactor Clean
Safely identify and remove dead code with test verification:
1. Run dead code analysis tools:
- knip: Find unused exports and files
- depcheck: Find unused dependencies
- ts-prune: Find unused TypeScript exports
2. Generate comprehensive report in .reports/dead-code-analysis.md
3. Categorize findings by severity:
- SAFE: Test files, unused utilities
- CAUTION: API routes, components
- DANGER: Config files, main entry points
4. Propose safe deletions only
5. Before each deletion:
- Run full test suite
- Verify tests pass
- Apply change
- Re-run tests
- Rollback if tests fail
6. Show summary of cleaned items
Never delete code without running tests first!
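
A minimal sketch of the delete-with-verification loop in steps 4-5, assuming a JS project where `npm test` runs the suite and git is used for rollback; the helper names are illustrative.

```python
# Sketch: apply one proposed deletion at a time, guarded by the test suite.
import subprocess

def tests_pass() -> bool:
    return subprocess.run(["npm", "test"], capture_output=True).returncode == 0

def apply_safe_deletions(candidate_files: list[str]) -> list[str]:
    removed: list[str] = []
    for path in candidate_files:
        if not tests_pass():  # never delete on top of a failing suite
            break
        subprocess.run(["git", "rm", path], check=True)  # apply the deletion
        if tests_pass():
            removed.append(path)
        else:  # rollback if tests fail
            subprocess.run(["git", "checkout", "HEAD", "--", path], check=True)
    return removed
```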

View File

@@ -0,0 +1,80 @@
---
description: Configure your preferred package manager (npm/pnpm/yarn/bun)
disable-model-invocation: true
---
# Package Manager Setup
Configure your preferred package manager for this project or globally.
## Usage
```bash
# Detect current package manager
node scripts/setup-package-manager.js --detect
# Set global preference
node scripts/setup-package-manager.js --global pnpm
# Set project preference
node scripts/setup-package-manager.js --project bun
# List available package managers
node scripts/setup-package-manager.js --list
```
## Detection Priority
When determining which package manager to use, the following order is checked:
1. **Environment variable**: `CLAUDE_PACKAGE_MANAGER`
2. **Project config**: `.claude/package-manager.json`
3. **package.json**: `packageManager` field
4. **Lock file**: Presence of package-lock.json, yarn.lock, pnpm-lock.yaml, or bun.lockb
5. **Global config**: `~/.claude/package-manager.json`
6. **Fallback**: First available package manager (pnpm > bun > yarn > npm)
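
The real detection lives in `scripts/setup-package-manager.js`; the Python sketch below only illustrates the priority chain above and is not part of the project.

```python
# Sketch: walk the detection priority from env var down to the PATH fallback.
import json
import os
import shutil
from pathlib import Path

LOCK_FILES = {
    "package-lock.json": "npm",
    "yarn.lock": "yarn",
    "pnpm-lock.yaml": "pnpm",
    "bun.lockb": "bun",
}

def detect_package_manager(project: Path = Path(".")) -> str:
    if env := os.environ.get("CLAUDE_PACKAGE_MANAGER"):          # 1. env var
        return env
    project_cfg = project / ".claude" / "package-manager.json"   # 2. project config
    if project_cfg.exists():
        return json.loads(project_cfg.read_text())["packageManager"]
    pkg = project / "package.json"                               # 3. packageManager field
    if pkg.exists():
        field = json.loads(pkg.read_text()).get("packageManager")
        if field:
            return field.split("@")[0]                           # e.g. "pnpm@8.6.0" -> "pnpm"
    for lock, manager in LOCK_FILES.items():                     # 4. lock files
        if (project / lock).exists():
            return manager
    global_cfg = Path.home() / ".claude" / "package-manager.json"  # 5. global config
    if global_cfg.exists():
        return json.loads(global_cfg.read_text())["packageManager"]
    for candidate in ("pnpm", "bun", "yarn", "npm"):             # 6. first available
        if shutil.which(candidate):
            return candidate
    return "npm"
```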
## Configuration Files
### Global Configuration
```json
// ~/.claude/package-manager.json
{
"packageManager": "pnpm"
}
```
### Project Configuration
```json
// .claude/package-manager.json
{
"packageManager": "bun"
}
```
### package.json
```json
{
"packageManager": "pnpm@8.6.0"
}
```
## Environment Variable
Set `CLAUDE_PACKAGE_MANAGER` to override all other detection methods:
```bash
# Windows (PowerShell)
$env:CLAUDE_PACKAGE_MANAGER = "pnpm"
# macOS/Linux
export CLAUDE_PACKAGE_MANAGER=pnpm
```
## Run the Detection
To see current package manager detection results, run:
```bash
node scripts/setup-package-manager.js --detect
```

326
.claude/commands/tdd.md Normal file
View File

@@ -0,0 +1,326 @@
---
description: Enforce test-driven development workflow. Scaffold interfaces, generate tests FIRST, then implement minimal code to pass. Ensure 80%+ coverage.
---
# TDD Command
This command invokes the **tdd-guide** agent to enforce test-driven development methodology.
## What This Command Does
1. **Scaffold Interfaces** - Define types/interfaces first
2. **Generate Tests First** - Write failing tests (RED)
3. **Implement Minimal Code** - Write just enough to pass (GREEN)
4. **Refactor** - Improve code while keeping tests green (REFACTOR)
5. **Verify Coverage** - Ensure 80%+ test coverage
## When to Use
Use `/tdd` when:
- Implementing new features
- Adding new functions/components
- Fixing bugs (write test that reproduces bug first)
- Refactoring existing code
- Building critical business logic
## How It Works
The tdd-guide agent will:
1. **Define interfaces** for inputs/outputs
2. **Write tests that will FAIL** (because code doesn't exist yet)
3. **Run tests** and verify they fail for the right reason
4. **Write minimal implementation** to make tests pass
5. **Run tests** and verify they pass
6. **Refactor** code while keeping tests green
7. **Check coverage** and add more tests if below 80%
## TDD Cycle
```
RED → GREEN → REFACTOR → REPEAT
RED: Write a failing test
GREEN: Write minimal code to pass
REFACTOR: Improve code, keep tests passing
REPEAT: Next feature/scenario
```
## Example Usage
```
User: /tdd I need a function to calculate market liquidity score
Agent (tdd-guide):
# TDD Session: Market Liquidity Score Calculator
## Step 1: Define Interface (SCAFFOLD)
```typescript
// lib/liquidity.ts
export interface MarketData {
totalVolume: number
bidAskSpread: number
activeTraders: number
lastTradeTime: Date
}
export function calculateLiquidityScore(market: MarketData): number {
// TODO: Implementation
throw new Error('Not implemented')
}
```
## Step 2: Write Failing Test (RED)
```typescript
// lib/liquidity.test.ts
import { calculateLiquidityScore } from './liquidity'
describe('calculateLiquidityScore', () => {
it('should return high score for liquid market', () => {
const market = {
totalVolume: 100000,
bidAskSpread: 0.01,
activeTraders: 500,
lastTradeTime: new Date()
}
const score = calculateLiquidityScore(market)
expect(score).toBeGreaterThan(80)
expect(score).toBeLessThanOrEqual(100)
})
it('should return low score for illiquid market', () => {
const market = {
totalVolume: 100,
bidAskSpread: 0.5,
activeTraders: 2,
lastTradeTime: new Date(Date.now() - 86400000) // 1 day ago
}
const score = calculateLiquidityScore(market)
expect(score).toBeLessThan(30)
expect(score).toBeGreaterThanOrEqual(0)
})
it('should handle edge case: zero volume', () => {
const market = {
totalVolume: 0,
bidAskSpread: 0,
activeTraders: 0,
lastTradeTime: new Date()
}
const score = calculateLiquidityScore(market)
expect(score).toBe(0)
})
})
```
## Step 3: Run Tests - Verify FAIL
```bash
npm test lib/liquidity.test.ts
FAIL lib/liquidity.test.ts
✕ should return high score for liquid market (2 ms)
Error: Not implemented
1 test failed, 0 passed
```
✅ Tests fail as expected. Ready to implement.
## Step 4: Implement Minimal Code (GREEN)
```typescript
// lib/liquidity.ts
export function calculateLiquidityScore(market: MarketData): number {
// Handle zero volume edge case
if (market.totalVolume === 0) {
return 0
}
// Calculate component scores (0-100 scale)
const volumeScore = Math.min(market.totalVolume / 1000, 100)
const spreadScore = Math.max(100 - (market.bidAskSpread * 1000), 0)
const traderScore = Math.min(market.activeTraders / 10, 100)
// Recent activity bonus
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
const recencyScore = Math.max(100 - (hoursSinceLastTrade * 10), 0)
// Weighted average
const score = (
volumeScore * 0.4 +
spreadScore * 0.3 +
traderScore * 0.2 +
recencyScore * 0.1
)
return Math.min(Math.max(score, 0), 100) // Clamp to 0-100
}
```
## Step 5: Run Tests - Verify PASS
```bash
npm test lib/liquidity.test.ts
PASS lib/liquidity.test.ts
✓ should return high score for liquid market (3 ms)
✓ should return low score for illiquid market (2 ms)
✓ should handle edge case: zero volume (1 ms)
3 tests passed
```
✅ All tests passing!
## Step 6: Refactor (IMPROVE)
```typescript
// lib/liquidity.ts - Refactored with constants and better readability
const WEIGHTS = {
VOLUME: 0.4,
SPREAD: 0.3,
TRADERS: 0.2,
RECENCY: 0.1,
} as const
const SCALE_FACTORS = {
VOLUME: 1000,
SPREAD: 1000,
TRADERS: 10,
RECENCY_PENALTY: 10,
} as const
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max)
}
export function calculateLiquidityScore(market: MarketData): number {
if (market.totalVolume === 0) return 0
const volumeScore = Math.min(market.totalVolume / SCALE_FACTORS.VOLUME, 100)
const spreadScore = clamp(100 - (market.bidAskSpread * SCALE_FACTORS.SPREAD), 0, 100)
const traderScore = Math.min(market.activeTraders / SCALE_FACTORS.TRADERS, 100)
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
const recencyScore = clamp(100 - (hoursSinceLastTrade * SCALE_FACTORS.RECENCY_PENALTY), 0, 100)
const weightedScore =
volumeScore * WEIGHTS.VOLUME +
spreadScore * WEIGHTS.SPREAD +
traderScore * WEIGHTS.TRADERS +
recencyScore * WEIGHTS.RECENCY
return clamp(weightedScore, 0, 100)
}
```
## Step 7: Verify Tests Still Pass
```bash
npm test lib/liquidity.test.ts
PASS lib/liquidity.test.ts
✓ should return high score for liquid market (3 ms)
✓ should return low score for illiquid market (2 ms)
✓ should handle edge case: zero volume (1 ms)
3 tests passed
```
✅ Refactoring complete, tests still passing!
## Step 8: Check Coverage
```bash
npm test -- --coverage lib/liquidity.test.ts
File | % Stmts | % Branch | % Funcs | % Lines
---------------|---------|----------|---------|--------
liquidity.ts | 100 | 100 | 100 | 100
Coverage: 100% ✅ (Target: 80%)
```
✅ TDD session complete!
```
## TDD Best Practices
**DO:**
- ✅ Write the test FIRST, before any implementation
- ✅ Run tests and verify they FAIL before implementing
- ✅ Write minimal code to make tests pass
- ✅ Refactor only after tests are green
- ✅ Add edge cases and error scenarios
- ✅ Aim for 80%+ coverage (100% for critical code)
**DON'T:**
- ❌ Write implementation before tests
- ❌ Skip running tests after each change
- ❌ Write too much code at once
- ❌ Ignore failing tests
- ❌ Test implementation details (test behavior)
- ❌ Mock everything (prefer integration tests)
## Test Types to Include
**Unit Tests** (Function-level):
- Happy path scenarios
- Edge cases (empty, null, max values)
- Error conditions
- Boundary values
**Integration Tests** (Component-level):
- API endpoints
- Database operations
- External service calls
- React components with hooks
**E2E Tests** (use `/e2e` command):
- Critical user flows
- Multi-step processes
- Full stack integration
## Coverage Requirements
- **80% minimum** for all code
- **100% required** for:
- Financial calculations
- Authentication logic
- Security-critical code
- Core business logic
## Important Notes
**MANDATORY**: Tests must be written BEFORE implementation. The TDD cycle is:
1. **RED** - Write failing test
2. **GREEN** - Implement to pass
3. **REFACTOR** - Improve code
Never skip the RED phase. Never write code before tests.
## Integration with Other Commands
- Use `/plan` first to understand what to build
- Use `/tdd` to implement with tests
- Use `/build-and-fix` if build errors occur
- Use `/code-review` to review implementation
- Use `/test-coverage` to verify coverage
## Related Agents
This command invokes the `tdd-guide` agent located at:
`~/.claude/agents/tdd-guide.md`
And can reference the `tdd-workflow` skill at:
`~/.claude/skills/tdd-workflow/`

View File

@@ -0,0 +1,27 @@
# Test Coverage
Analyze test coverage and generate missing tests:
1. Run tests with coverage: npm test -- --coverage or pnpm test --coverage
2. Analyze coverage report (coverage/coverage-summary.json)
3. Identify files below 80% coverage threshold
4. For each under-covered file:
- Analyze untested code paths
- Generate unit tests for functions
- Generate integration tests for APIs
- Generate E2E tests for critical flows
5. Verify new tests pass
6. Show before/after coverage metrics
7. Ensure project reaches 80%+ overall coverage
Focus on:
- Happy path scenarios
- Error handling
- Edge cases (null, undefined, empty)
- Boundary conditions
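
A hedged sketch of steps 2-3 above, assuming the standard Istanbul `json-summary` layout (a `total` entry plus one entry per file with a `lines.pct` value).

```python
# Sketch: list files whose line coverage is below the 80% threshold.
import json
from pathlib import Path

def files_below_threshold(
    summary_path: str = "coverage/coverage-summary.json",
    threshold: float = 80.0,
) -> list[tuple[str, float]]:
    summary = json.loads(Path(summary_path).read_text())
    under = [
        (path, data["lines"]["pct"])
        for path, data in summary.items()
        if path != "total" and data["lines"]["pct"] < threshold
    ]
    return sorted(under, key=lambda item: item[1])  # worst-covered first
```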

View File

@@ -0,0 +1,17 @@
# Update Codemaps
Analyze the codebase structure and update architecture documentation:
1. Scan all source files for imports, exports, and dependencies
2. Generate token-lean codemaps in the following format:
- codemaps/architecture.md - Overall architecture
- codemaps/backend.md - Backend structure
- codemaps/frontend.md - Frontend structure
- codemaps/data.md - Data models and schemas
3. Calculate diff percentage from previous version
4. If changes > 30%, request user approval before updating
5. Add freshness timestamp to each codemap
6. Save reports to .reports/codemap-diff.txt
Use TypeScript/Node.js for analysis. Focus on high-level structure, not implementation details.

View File

@@ -0,0 +1,31 @@
# Update Documentation
Sync documentation from source-of-truth:
1. Read package.json scripts section
- Generate scripts reference table
- Include descriptions from comments
2. Read .env.example
- Extract all environment variables
- Document purpose and format
3. Generate docs/CONTRIB.md with:
- Development workflow
- Available scripts
- Environment setup
- Testing procedures
4. Generate docs/RUNBOOK.md with:
- Deployment procedures
- Monitoring and alerts
- Common issues and fixes
- Rollback procedures
5. Identify obsolete documentation:
- Find docs not modified in 90+ days
- List for manual review
6. Show diff summary
Single source of truth: package.json and .env.example
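
A minimal sketch of steps 1-2: pull the scripts table from package.json and the variable names from .env.example. The output formatting is an assumption, not a project convention.

```python
# Sketch: derive documentation inputs from the two sources of truth.
import json
from pathlib import Path

def scripts_table(pkg_path: str = "package.json") -> str:
    scripts = json.loads(Path(pkg_path).read_text()).get("scripts", {})
    rows = [f"| `{name}` | `{command}` |" for name, command in scripts.items()]
    return "\n".join(["| Script | Command |", "|--------|---------|", *rows])

def env_variables(env_path: str = ".env.example") -> list[str]:
    names = []
    for line in Path(env_path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            names.append(line.split("=", 1)[0])  # keep the variable name only
    return names
```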

View File

@@ -0,0 +1,59 @@
# Verification Command
Run comprehensive verification on current codebase state.
## Instructions
Execute verification in this exact order:
1. **Build Check**
- Run the build command for this project
- If it fails, report errors and STOP
2. **Type Check**
- Run TypeScript/type checker
- Report all errors with file:line
3. **Lint Check**
- Run linter
- Report warnings and errors
4. **Test Suite**
- Run all tests
- Report pass/fail count
- Report coverage percentage
5. **Console.log Audit**
- Search for console.log in source files
- Report locations
6. **Git Status**
- Show uncommitted changes
- Show files modified since last commit
## Output
Produce a concise verification report:
```
VERIFICATION: [PASS/FAIL]
Build: [OK/FAIL]
Types: [OK/X errors]
Lint: [OK/X issues]
Tests: [X/Y passed, Z% coverage]
Secrets: [OK/X found]
Logs: [OK/X console.logs]
Ready for PR: [YES/NO]
```
If any critical issues, list them with fix suggestions.
## Arguments
$ARGUMENTS can be:
- `quick` - Only build + types
- `full` - All checks (default)
- `pre-commit` - Checks relevant for commits
- `pre-pr` - Full checks plus security scan
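
A hedged sketch of step 5 (the console.log audit); the source directory and output format are illustrative assumptions.

```python
# Sketch: report every console.log occurrence in JS/TS sources as file:line.
from pathlib import Path

def audit_console_logs(src_dir: str = "src") -> list[str]:
    findings = []
    for path in Path(src_dir).rglob("*"):
        if not path.is_file() or path.suffix not in {".ts", ".tsx", ".js", ".jsx"}:
            continue
        for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
            if "console.log" in line:
                findings.append(f"{path}:{lineno} - {line.strip()}")
    return findings
```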

157
.claude/hooks/hooks.json Normal file
View File

@@ -0,0 +1,157 @@
{
"$schema": "https://json.schemastore.org/claude-code-settings.json",
"hooks": {
"PreToolUse": [
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm run dev|pnpm( run)? dev|yarn dev|bun run dev)\"",
"hooks": [
{
"type": "command",
"command": "node -e \"console.error('[Hook] BLOCKED: Dev server must run in tmux for log access');console.error('[Hook] Use: tmux new-session -d -s dev \\\"npm run dev\\\"');console.error('[Hook] Then: tmux attach -t dev');process.exit(1)\""
}
],
"description": "Block dev servers outside tmux - ensures you can access logs"
},
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm (install|test)|pnpm (install|test)|yarn (install|test)?|bun (install|test)|cargo build|make|docker|pytest|vitest|playwright)\"",
"hooks": [
{
"type": "command",
"command": "node -e \"if(!process.env.TMUX){console.error('[Hook] Consider running in tmux for session persistence');console.error('[Hook] tmux new -s dev | tmux attach -t dev')}\""
}
],
"description": "Reminder to use tmux for long-running commands"
},
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"git push\"",
"hooks": [
{
"type": "command",
"command": "node -e \"console.error('[Hook] Review changes before push...');console.error('[Hook] Continuing with push (remove this hook to add interactive review)')\""
}
],
"description": "Reminder before git push to review changes"
},
{
"matcher": "tool == \"Write\" && tool_input.file_path matches \"\\\\.(md|txt)$\" && !(tool_input.file_path matches \"README\\\\.md|CLAUDE\\\\.md|AGENTS\\\\.md|CONTRIBUTING\\\\.md\")",
"hooks": [
{
"type": "command",
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path||'';if(/\\.(md|txt)$/.test(p)&&!/(README|CLAUDE|AGENTS|CONTRIBUTING)\\.md$/.test(p)){console.error('[Hook] BLOCKED: Unnecessary documentation file creation');console.error('[Hook] File: '+p);console.error('[Hook] Use README.md for documentation instead');process.exit(1)}console.log(d)})\""
}
],
"description": "Block creation of random .md files - keeps docs consolidated"
},
{
"matcher": "tool == \"Edit\" || tool == \"Write\"",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/suggest-compact.js\""
}
],
"description": "Suggest manual compaction at logical intervals"
}
],
"PreCompact": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/pre-compact.js\""
}
],
"description": "Save state before context compaction"
}
],
"SessionStart": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-start.js\""
}
],
"description": "Load previous context and detect package manager on new session"
}
],
"PostToolUse": [
{
"matcher": "tool == \"Bash\"",
"hooks": [
{
"type": "command",
"command": "node -e \"let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const cmd=i.tool_input?.command||'';if(/gh pr create/.test(cmd)){const out=i.tool_output?.output||'';const m=out.match(/https:\\/\\/github.com\\/[^/]+\\/[^/]+\\/pull\\/\\d+/);if(m){console.error('[Hook] PR created: '+m[0]);const repo=m[0].replace(/https:\\/\\/github.com\\/([^/]+\\/[^/]+)\\/pull\\/\\d+/,'$1');const pr=m[0].replace(/.*\\/pull\\/(\\d+)/,'$1');console.error('[Hook] To review: gh pr review '+pr+' --repo '+repo)}}console.log(d)})\""
}
],
"description": "Log PR URL and provide review command after PR creation"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){try{execSync('npx prettier --write \"'+p+'\"',{stdio:['pipe','pipe','pipe']})}catch(e){}}console.log(d)})\""
}
],
"description": "Auto-format JS/TS files with Prettier after edits"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');const path=require('path');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){let dir=path.dirname(p);while(dir!==path.dirname(dir)&&!fs.existsSync(path.join(dir,'tsconfig.json'))){dir=path.dirname(dir)}if(fs.existsSync(path.join(dir,'tsconfig.json'))){try{const r=execSync('npx tsc --noEmit --pretty false 2>&1',{cwd:dir,encoding:'utf8',stdio:['pipe','pipe','pipe']});const lines=r.split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}catch(e){const lines=(e.stdout||'').split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}}}console.log(d)})\""
}
],
"description": "TypeScript check after editing .ts/.tsx files"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){const c=fs.readFileSync(p,'utf8');const lines=c.split('\\n');const matches=[];lines.forEach((l,idx)=>{if(/console\\.log/.test(l))matches.push((idx+1)+': '+l.trim())});if(matches.length){console.error('[Hook] WARNING: console.log found in '+p);matches.slice(0,5).forEach(m=>console.error(m));console.error('[Hook] Remove console.log before committing')}}console.log(d)})\""
}
],
"description": "Warn about console.log statements after edits"
}
],
"Stop": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{try{execSync('git rev-parse --git-dir',{stdio:'pipe'})}catch{console.log(d);process.exit(0)}try{const files=execSync('git diff --name-only HEAD',{encoding:'utf8',stdio:['pipe','pipe','pipe']}).split('\\n').filter(f=>/\\.(ts|tsx|js|jsx)$/.test(f)&&fs.existsSync(f));let hasConsole=false;for(const f of files){if(fs.readFileSync(f,'utf8').includes('console.log')){console.error('[Hook] WARNING: console.log found in '+f);hasConsole=true}}if(hasConsole)console.error('[Hook] Remove console.log statements before committing')}catch(e){}console.log(d)})\""
}
],
"description": "Check for console.log in modified files after each response"
}
],
"SessionEnd": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-end.js\""
}
],
"description": "Persist session state on end"
},
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/evaluate-session.js\""
}
],
"description": "Evaluate session for extractable patterns"
}
]
}
}

View File

@@ -0,0 +1,36 @@
#!/bin/bash
# PreCompact Hook - Save state before context compaction
#
# Runs before Claude compacts context, giving you a chance to
# preserve important state that might get lost in summarization.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "PreCompact": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/pre-compact.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
COMPACTION_LOG="${SESSIONS_DIR}/compaction-log.txt"
mkdir -p "$SESSIONS_DIR"
# Log compaction event with timestamp
echo "[$(date '+%Y-%m-%d %H:%M:%S')] Context compaction triggered" >> "$COMPACTION_LOG"
# If there's an active session file, note the compaction
ACTIVE_SESSION=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
if [ -n "$ACTIVE_SESSION" ]; then
echo "" >> "$ACTIVE_SESSION"
echo "---" >> "$ACTIVE_SESSION"
echo "**[Compaction occurred at $(date '+%H:%M')]** - Context was summarized" >> "$ACTIVE_SESSION"
fi
echo "[PreCompact] State saved before compaction" >&2

View File

@@ -0,0 +1,61 @@
#!/bin/bash
# Stop Hook (Session End) - Persist learnings when session ends
#
# Runs when Claude session ends. Creates/updates session log file
# with timestamp for continuity tracking.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "Stop": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/session-end.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
TODAY=$(date '+%Y-%m-%d')
SESSION_FILE="${SESSIONS_DIR}/${TODAY}-session.tmp"
mkdir -p "$SESSIONS_DIR"
# If session file exists for today, update the end time
if [ -f "$SESSION_FILE" ]; then
# Update Last Updated timestamp
sed -i '' "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null || \
sed -i "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null
echo "[SessionEnd] Updated session file: $SESSION_FILE" >&2
else
# Create new session file with template
cat > "$SESSION_FILE" << EOF
# Session: $(date '+%Y-%m-%d')
**Date:** $TODAY
**Started:** $(date '+%H:%M')
**Last Updated:** $(date '+%H:%M')
---
## Current State
[Session context goes here]
### Completed
- [ ]
### In Progress
- [ ]
### Notes for Next Session
-
### Context to Load
\`\`\`
[relevant files]
\`\`\`
EOF
echo "[SessionEnd] Created session file: $SESSION_FILE" >&2
fi

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# SessionStart Hook - Load previous context on new session
#
# Runs when a new Claude session starts. Checks for recent session
# files and notifies Claude of available context to load.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "SessionStart": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/session-start.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
LEARNED_DIR="${HOME}/.claude/skills/learned"
# Check for recent session files (last 7 days)
recent_sessions=$(find "$SESSIONS_DIR" -name "*.tmp" -mtime -7 2>/dev/null | wc -l | tr -d ' ')
if [ "$recent_sessions" -gt 0 ]; then
latest=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
echo "[SessionStart] Found $recent_sessions recent session(s)" >&2
echo "[SessionStart] Latest: $latest" >&2
fi
# Check for learned skills
learned_count=$(find "$LEARNED_DIR" -name "*.md" 2>/dev/null | wc -l | tr -d ' ')
if [ "$learned_count" -gt 0 ]; then
echo "[SessionStart] $learned_count learned skill(s) available in $LEARNED_DIR" >&2
fi

View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Strategic Compact Suggester
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
#
# Why manual over auto-compact:
# - Auto-compact happens at arbitrary points, often mid-task
# - Strategic compacting preserves context through logical phases
# - Compact after exploration, before execution
# - Compact after completing a milestone, before starting next
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "PreToolUse": [{
# "matcher": "Edit|Write",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
# }]
# }]
# }
# }
#
# Criteria for suggesting compact:
# - Session has been running for extended period
# - Large number of tool calls made
# - Transitioning from research/exploration to implementation
# - Plan has been finalized
# Track tool call count (increment in a temp file)
COUNTER_FILE="/tmp/claude-tool-count-$$"
THRESHOLD=${COMPACT_THRESHOLD:-50}
# Initialize or increment counter
if [ -f "$COUNTER_FILE" ]; then
count=$(cat "$COUNTER_FILE")
count=$((count + 1))
echo "$count" > "$COUNTER_FILE"
else
echo "1" > "$COUNTER_FILE"
count=1
fi
# Suggest compact after threshold tool calls
if [ "$count" -eq "$THRESHOLD" ]; then
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
fi
# Suggest at regular intervals after threshold
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
fi

View File

@@ -75,7 +75,13 @@
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")", "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")", "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")",
"Bash(tasklist:*)", "Bash(tasklist:*)",
"Bash(findstr:*)" "Bash(findstr:*)",
"Bash(wsl bash -c \"ps aux | grep -E ''python.*train'' | grep -v grep\")",
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/\")",
"Bash(wsl bash -c \"cat /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/results.csv\")",
"Bash(wsl bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/runs/train/invoice_fields/weights/\")",
"Bash(wsl bash -c \"cat ''/mnt/c/Users/yaoji/AppData/Local/Temp/claude/c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2/tasks/b8d8565.output'' 2>/dev/null | tail -100\")",
"Bash(wsl bash -c:*)"
], ],
"deny": [], "deny": [],
"ask": [], "ask": [],

View File

@@ -0,0 +1,314 @@
# Backend Development Patterns
Backend architecture patterns for Python/FastAPI/PostgreSQL applications.
## API Design
### RESTful Structure
```
GET /api/v1/documents # List
GET /api/v1/documents/{id} # Get
POST /api/v1/documents # Create
PUT /api/v1/documents/{id} # Replace
PATCH /api/v1/documents/{id} # Update
DELETE /api/v1/documents/{id} # Delete
GET /api/v1/documents?status=processed&sort=created_at&limit=20&offset=0
```
### FastAPI Route Pattern
```python
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
from pydantic import BaseModel
router = APIRouter(prefix="/api/v1", tags=["inference"])
@router.post("/infer", response_model=ApiResponse[InferenceResult])
async def infer_document(
file: UploadFile = File(...),
confidence_threshold: float = Query(0.5, ge=0, le=1),
service: InferenceService = Depends(get_inference_service)
) -> ApiResponse[InferenceResult]:
result = await service.process(file, confidence_threshold)
return ApiResponse(success=True, data=result)
```
### Consistent Response Schema
```python
from typing import Generic, TypeVar
T = TypeVar('T')
class ApiResponse(BaseModel, Generic[T]):
success: bool
data: T | None = None
error: str | None = None
meta: dict | None = None
```
## Core Patterns
### Repository Pattern
```python
from typing import Protocol
class DocumentRepository(Protocol):
def find_all(self, filters: dict | None = None) -> list[Document]: ...
def find_by_id(self, id: str) -> Document | None: ...
def create(self, data: dict) -> Document: ...
def update(self, id: str, data: dict) -> Document: ...
def delete(self, id: str) -> None: ...
```
### Service Layer
```python
class InferenceService:
def __init__(self, model_path: str, use_gpu: bool = True):
self.pipeline = InferencePipeline(model_path=model_path, use_gpu=use_gpu)
async def process(self, file: UploadFile, confidence_threshold: float) -> InferenceResult:
temp_path = self._save_temp_file(file)
try:
return self.pipeline.process_pdf(temp_path)
finally:
temp_path.unlink(missing_ok=True)
```
### Dependency Injection
```python
from functools import lru_cache
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
db_host: str = "localhost"
db_password: str
model_path: str = "runs/train/invoice_fields/weights/best.pt"
class Config:
env_file = ".env"
@lru_cache()
def get_settings() -> Settings:
return Settings()
def get_inference_service(settings: Settings = Depends(get_settings)) -> InferenceService:
return InferenceService(model_path=settings.model_path)
```
## Database Patterns
### Connection Pooling
```python
from psycopg2 import pool
from contextlib import contextmanager
db_pool = pool.ThreadedConnectionPool(minconn=2, maxconn=10, **db_config)
@contextmanager
def get_db_connection():
conn = db_pool.getconn()
try:
yield conn
finally:
db_pool.putconn(conn)
```
### Query Optimization
```python
# GOOD: Select only needed columns
cur.execute("""
SELECT id, status, fields->>'InvoiceNumber' as invoice_number
FROM documents WHERE status = %s
ORDER BY created_at DESC LIMIT %s
""", ('processed', 10))
# BAD: SELECT * FROM documents
```
### N+1 Prevention
```python
# BAD: N+1 queries
for doc in documents:
doc.labels = get_labels(doc.id) # N queries
# GOOD: Batch fetch with JOIN
cur.execute("""
SELECT d.id, d.status, array_agg(l.label) as labels
FROM documents d
LEFT JOIN document_labels l ON d.id = l.document_id
GROUP BY d.id, d.status
""")
```
### Transaction Pattern
```python
def create_document_with_labels(doc_data: dict, labels: list[dict]) -> str:
with get_db_connection() as conn:
try:
with conn.cursor() as cur:
cur.execute("INSERT INTO documents ... RETURNING id", ...)
doc_id = cur.fetchone()[0]
for label in labels:
cur.execute("INSERT INTO document_labels ...", ...)
conn.commit()
return doc_id
except Exception:
conn.rollback()
raise
```
## Caching
```python
from cachetools import TTLCache
_cache = TTLCache(maxsize=1000, ttl=300)
def get_document_cached(doc_id: str) -> Document | None:
if doc_id in _cache:
return _cache[doc_id]
doc = repo.find_by_id(doc_id)
if doc:
_cache[doc_id] = doc
return doc
```
## Error Handling
### Exception Hierarchy
```python
class AppError(Exception):
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
class NotFoundError(AppError):
def __init__(self, resource: str, id: str):
super().__init__(f"{resource} not found: {id}", 404)
class ValidationError(AppError):
def __init__(self, message: str):
super().__init__(message, 400)
```
### FastAPI Exception Handler
```python
@app.exception_handler(AppError)
async def app_error_handler(request: Request, exc: AppError):
return JSONResponse(status_code=exc.status_code, content={"success": False, "error": exc.message})
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
logger.error(f"Unexpected error: {exc}", exc_info=True)
return JSONResponse(status_code=500, content={"success": False, "error": "Internal server error"})
```
### Retry with Backoff
```python
async def retry_with_backoff(fn, max_retries: int = 3, base_delay: float = 1.0):
last_error = None
for attempt in range(max_retries):
try:
return await fn() if asyncio.iscoroutinefunction(fn) else fn()
except Exception as e:
last_error = e
if attempt < max_retries - 1:
await asyncio.sleep(base_delay * (2 ** attempt))
raise last_error
```
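A usage sketch, assuming `run_ocr` is a synchronous, occasionally failing helper:
```python
async def extract_with_retry(image) -> str:
    # Up to 3 attempts, sleeping 1s then 2s between them before giving up
    return await retry_with_backoff(lambda: run_ocr(image), max_retries=3, base_delay=1.0)
```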
## Rate Limiting
```python
from time import time
from collections import defaultdict
class RateLimiter:
def __init__(self):
self.requests: dict[str, list[float]] = defaultdict(list)
def check_limit(self, identifier: str, max_requests: int, window_sec: int) -> bool:
now = time()
self.requests[identifier] = [t for t in self.requests[identifier] if now - t < window_sec]
if len(self.requests[identifier]) >= max_requests:
return False
self.requests[identifier].append(now)
return True
limiter = RateLimiter()
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
ip = request.client.host
if not limiter.check_limit(ip, max_requests=100, window_sec=60):
return JSONResponse(status_code=429, content={"error": "Rate limit exceeded"})
return await call_next(request)
```
## Logging & Middleware
### Request Logging
```python
@app.middleware("http")
async def log_requests(request: Request, call_next):
request_id = str(uuid.uuid4())[:8]
start_time = time.time()
logger.info(f"[{request_id}] {request.method} {request.url.path}")
response = await call_next(request)
duration_ms = (time.time() - start_time) * 1000
logger.info(f"[{request_id}] Completed {response.status_code} in {duration_ms:.2f}ms")
return response
```
### Structured Logging
```python
import json
import logging
from datetime import datetime

class JSONFormatter(logging.Formatter):
def format(self, record):
return json.dumps({
"timestamp": datetime.utcnow().isoformat(),
"level": record.levelname,
"message": record.getMessage(),
"module": record.module,
})
```
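One way to attach the formatter (a sketch; handler setup normally lives in the application's logging configuration):
```python
import logging

handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())

root_logger = logging.getLogger()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
```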
## Background Tasks
```python
from fastapi import BackgroundTasks
def send_notification(document_id: str, status: str):
logger.info(f"Notification: {document_id} -> {status}")
@router.post("/infer")
async def infer(file: UploadFile, background_tasks: BackgroundTasks):
result = await process_document(file)
background_tasks.add_task(send_notification, result.document_id, "completed")
return result
```
## Key Principles
- Repository pattern: Abstract data access (see the sketch after this list)
- Service layer: Business logic separated from routes
- Dependency injection via `Depends()`
- Connection pooling for database
- Parameterized queries only (no f-strings in SQL)
- Batch fetch to prevent N+1
- Consistent `ApiResponse[T]` format
- Exception hierarchy with proper status codes
- Rate limit by IP
- Structured logging with request ID
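A minimal sketch of the repository idea referenced above, reusing the `get_db_connection` helper and `documents` table from this document (names are illustrative, not the project's actual class):
```python
class DocumentRepository:
    """Thin data-access layer so services never build SQL themselves."""

    def find_by_id(self, doc_id: str) -> dict | None:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                # Parameterized query only -- never interpolate values into SQL
                cur.execute(
                    "SELECT id, status, fields FROM documents WHERE id = %s",
                    (doc_id,),
                )
                row = cur.fetchone()
        return {"id": row[0], "status": row[1], "fields": row[2]} if row else None
```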

View File

@@ -0,0 +1,665 @@
---
name: coding-standards
description: Universal coding standards, best practices, and patterns for Python, FastAPI, and data processing development.
---
# Coding Standards & Best Practices
Python coding standards for the Invoice Master project.
## Code Quality Principles
### 1. Readability First
- Code is read more than written
- Clear variable and function names
- Self-documenting code preferred over comments
- Consistent formatting (follow PEP 8)
### 2. KISS (Keep It Simple, Stupid)
- Simplest solution that works
- Avoid over-engineering
- No premature optimization
- Easy to understand > clever code
### 3. DRY (Don't Repeat Yourself)
- Extract common logic into functions
- Create reusable utilities
- Share modules across the codebase
- Avoid copy-paste programming (see the small example below)
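A small illustration of the point above, extracting one shared cleanup step instead of repeating it (the function names are hypothetical, not the project's actual normalizers):
```python
def clean_digits(value: str) -> str:
    """Shared helper: keep only the digit characters."""
    return "".join(ch for ch in value if ch.isdigit())

def normalize_bankgiro(value: str) -> str:
    return clean_digits(value)

def normalize_ocr_number(value: str) -> str:
    return clean_digits(value)
```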
### 4. YAGNI (You Aren't Gonna Need It)
- Don't build features before they're needed
- Avoid speculative generality
- Add complexity only when required
- Start simple, refactor when needed
## Python Standards
### Variable Naming
```python
# GOOD: Descriptive names
invoice_number = "INV-2024-001"
is_valid_document = True
total_confidence_score = 0.95
# BAD: Unclear names
inv = "INV-2024-001"
flag = True
x = 0.95
```
### Function Naming
```python
# GOOD: Verb-noun pattern with type hints
def extract_invoice_fields(pdf_path: Path) -> dict[str, str]:
"""Extract fields from invoice PDF."""
...
def calculate_confidence(predictions: list[float]) -> float:
"""Calculate average confidence score."""
...
def is_valid_bankgiro(value: str) -> bool:
"""Check if value is valid Bankgiro number."""
...
# BAD: Unclear or noun-only
def invoice(path):
...
def confidence(p):
...
def bankgiro(v):
...
```
### Type Hints (REQUIRED)
```python
# GOOD: Full type annotations
from typing import Optional
from pathlib import Path
from dataclasses import dataclass
@dataclass
class InferenceResult:
document_id: str
fields: dict[str, str]
confidence: dict[str, float]
processing_time_ms: float
def process_document(
pdf_path: Path,
confidence_threshold: float = 0.5
) -> InferenceResult:
"""Process PDF and return extracted fields."""
...
# BAD: No type hints
def process_document(pdf_path, confidence_threshold=0.5):
...
```
### Immutability Pattern (CRITICAL)
```python
# GOOD: Create new objects, don't mutate
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
return {**fields, **updates}
def add_item(items: list[str], new_item: str) -> list[str]:
return [*items, new_item]
# BAD: Direct mutation
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
fields.update(updates) # MUTATION!
return fields
def add_item(items: list[str], new_item: str) -> list[str]:
items.append(new_item) # MUTATION!
return items
```
### Error Handling
```python
import logging
logger = logging.getLogger(__name__)
# GOOD: Comprehensive error handling with logging
def load_model(model_path: Path) -> Model:
"""Load YOLO model from path."""
try:
if not model_path.exists():
raise FileNotFoundError(f"Model not found: {model_path}")
model = YOLO(str(model_path))
logger.info(f"Model loaded: {model_path}")
return model
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise RuntimeError(f"Model loading failed: {model_path}") from e
# BAD: No error handling
def load_model(model_path):
return YOLO(str(model_path))
# BAD: Bare except
def load_model(model_path):
try:
return YOLO(str(model_path))
except: # Never use bare except!
return None
```
### Async Best Practices
```python
import asyncio
# GOOD: Parallel execution when possible
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
tasks = [process_document(path) for path in pdf_paths]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle exceptions
valid_results = []
for path, result in zip(pdf_paths, results):
if isinstance(result, Exception):
logger.error(f"Failed to process {path}: {result}")
else:
valid_results.append(result)
return valid_results
# BAD: Sequential when unnecessary
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
results = []
for path in pdf_paths:
result = await process_document(path)
results.append(result)
return results
```
### Context Managers
```python
from contextlib import contextmanager
from pathlib import Path
import tempfile
# GOOD: Proper resource management
@contextmanager
def temp_pdf_copy(pdf_path: Path):
"""Create temporary copy of PDF for processing."""
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
tmp.write(pdf_path.read_bytes())
tmp_path = Path(tmp.name)
try:
yield tmp_path
finally:
tmp_path.unlink(missing_ok=True)
# Usage
with temp_pdf_copy(original_pdf) as tmp_pdf:
result = process_pdf(tmp_pdf)
```
## FastAPI Best Practices
### Route Structure
```python
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
from pydantic import BaseModel
router = APIRouter(prefix="/api/v1", tags=["inference"])
class InferenceResponse(BaseModel):
success: bool
document_id: str
fields: dict[str, str]
confidence: dict[str, float]
processing_time_ms: float
@router.post("/infer", response_model=InferenceResponse)
async def infer_document(
file: UploadFile = File(...),
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0)
) -> InferenceResponse:
"""Process invoice PDF and extract fields."""
if not file.filename.endswith(".pdf"):
raise HTTPException(status_code=400, detail="Only PDF files accepted")
result = await inference_service.process(file, confidence_threshold)
return InferenceResponse(
success=True,
document_id=result.document_id,
fields=result.fields,
confidence=result.confidence,
processing_time_ms=result.processing_time_ms
)
```
### Input Validation with Pydantic
```python
from pydantic import BaseModel, Field, field_validator
from datetime import date
import re
class InvoiceData(BaseModel):
invoice_number: str = Field(..., min_length=1, max_length=50)
invoice_date: date
amount: float = Field(..., gt=0)
bankgiro: str | None = None
ocr_number: str | None = None
@field_validator("bankgiro")
@classmethod
def validate_bankgiro(cls, v: str | None) -> str | None:
if v is None:
return None
# Bankgiro: 7-8 digits
cleaned = re.sub(r"[^0-9]", "", v)
if not (7 <= len(cleaned) <= 8):
raise ValueError("Bankgiro must be 7-8 digits")
return cleaned
@field_validator("ocr_number")
@classmethod
def validate_ocr(cls, v: str | None) -> str | None:
if v is None:
return None
# OCR: 2-25 digits
cleaned = re.sub(r"[^0-9]", "", v)
if not (2 <= len(cleaned) <= 25):
raise ValueError("OCR must be 2-25 digits")
return cleaned
```
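Instantiating the model runs the validators; a quick sketch with made-up values:
```python
from datetime import date
from pydantic import ValidationError

invoice = InvoiceData(
    invoice_number="INV-2024-001",
    invoice_date=date(2024, 1, 15),
    amount=1250.0,
    bankgiro="123-4567",  # validator strips the dash -> "1234567"
)
assert invoice.bankgiro == "1234567"

try:
    InvoiceData(
        invoice_number="INV-2024-002",
        invoice_date=date(2024, 1, 15),
        amount=10.0,
        bankgiro="12345",  # only 5 digits
    )
except ValidationError as exc:
    print(exc)  # includes "Bankgiro must be 7-8 digits"
```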
### Response Format
```python
from pydantic import BaseModel
from typing import Generic, TypeVar
T = TypeVar("T")
class ApiResponse(BaseModel, Generic[T]):
success: bool
data: T | None = None
error: str | None = None
meta: dict | None = None
# Success response
return ApiResponse(
success=True,
data=result,
meta={"processing_time_ms": elapsed_ms}
)
# Error response
return ApiResponse(
success=False,
error="Invalid PDF format"
)
```
## File Organization
### Project Structure
```
src/
├── cli/ # Command-line interfaces
│ ├── autolabel.py
│ ├── train.py
│ └── infer.py
├── pdf/ # PDF processing
│ ├── extractor.py
│ └── renderer.py
├── ocr/ # OCR processing
│ ├── paddle_ocr.py
│ └── machine_code_parser.py
├── inference/ # Inference pipeline
│ ├── pipeline.py
│ ├── yolo_detector.py
│ └── field_extractor.py
├── normalize/ # Field normalization
│ ├── base.py
│ ├── date_normalizer.py
│ └── amount_normalizer.py
├── web/ # FastAPI application
│ ├── app.py
│ ├── routes.py
│ ├── services.py
│ └── schemas.py
└── utils/ # Shared utilities
├── validators.py
├── text_cleaner.py
└── logging.py
tests/ # Mirror of src structure
├── test_pdf/
├── test_ocr/
└── test_inference/
```
### File Naming
```
src/ocr/paddle_ocr.py # snake_case for modules
src/inference/yolo_detector.py # snake_case for modules
tests/test_paddle_ocr.py # test_ prefix for tests
config.py # snake_case for config
```
### Module Size Guidelines
- **Maximum**: 800 lines per file
- **Typical**: 200-400 lines per file
- **Functions**: Max 50 lines each
- Extract utilities when modules grow too large
## Comments & Documentation
### When to Comment
```python
# GOOD: Explain WHY, not WHAT
# Swedish Bankgiro uses Luhn algorithm with weight [1,2,1,2...]
def validate_bankgiro_checksum(bankgiro: str) -> bool:
...
# Payment line format: 7 groups separated by #, checksum at end
def parse_payment_line(line: str) -> PaymentLineData:
...
# BAD: Stating the obvious
# Increment counter by 1
count += 1
# Set name to user's name
name = user.name
```
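For context, a minimal sketch of the kind of checksum helper that comment refers to, assuming the standard Luhn (mod-10) rules; it is illustrative, not the project's implementation:
```python
def validate_bankgiro_checksum(bankgiro: str) -> bool:
    """Luhn (mod-10) check over the digit string (illustrative sketch)."""
    digits = [int(ch) for ch in bankgiro if ch.isdigit()]
    total = 0
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 1:  # every second digit from the right is doubled
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0
```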
### Docstrings for Public APIs
```python
def extract_invoice_fields(
pdf_path: Path,
confidence_threshold: float = 0.5,
use_gpu: bool = True
) -> InferenceResult:
"""Extract structured fields from Swedish invoice PDF.
Uses YOLOv11 for field detection and PaddleOCR for text extraction.
Applies field-specific normalization and validation.
Args:
pdf_path: Path to the invoice PDF file.
confidence_threshold: Minimum confidence for field detection (0.0-1.0).
use_gpu: Whether to use GPU acceleration.
Returns:
InferenceResult containing extracted fields and confidence scores.
Raises:
FileNotFoundError: If PDF file doesn't exist.
ProcessingError: If OCR or detection fails.
Example:
>>> result = extract_invoice_fields(Path("invoice.pdf"))
>>> print(result.fields["invoice_number"])
"INV-2024-001"
"""
...
```
## Performance Best Practices
### Caching
```python
from functools import lru_cache
from cachetools import TTLCache
# Static data: LRU cache
@lru_cache(maxsize=100)
def get_field_config(field_name: str) -> FieldConfig:
"""Load field configuration (cached)."""
return load_config(field_name)
# Dynamic data: TTL cache
_document_cache = TTLCache(maxsize=1000, ttl=300) # 5 minutes
def get_document_cached(doc_id: str) -> Document | None:
if doc_id in _document_cache:
return _document_cache[doc_id]
doc = repo.find_by_id(doc_id)
if doc:
_document_cache[doc_id] = doc
return doc
```
### Database Queries
```python
# GOOD: Select only needed columns
cur.execute("""
SELECT id, status, fields->>'invoice_number'
FROM documents
WHERE status = %s
LIMIT %s
""", ('processed', 10))
# BAD: Select everything
cur.execute("SELECT * FROM documents")
# GOOD: Batch operations
cur.executemany(
"INSERT INTO labels (doc_id, field, value) VALUES (%s, %s, %s)",
[(doc_id, f, v) for f, v in fields.items()]
)
# BAD: Individual inserts in loop
for field, value in fields.items():
cur.execute("INSERT INTO labels ...", (doc_id, field, value))
```
### Lazy Loading
```python
class InferencePipeline:
def __init__(self, model_path: Path):
self.model_path = model_path
self._model: YOLO | None = None
self._ocr: PaddleOCR | None = None
@property
def model(self) -> YOLO:
"""Lazy load YOLO model."""
if self._model is None:
self._model = YOLO(str(self.model_path))
return self._model
@property
def ocr(self) -> PaddleOCR:
"""Lazy load PaddleOCR."""
if self._ocr is None:
self._ocr = PaddleOCR(use_angle_cls=True, lang="latin")
return self._ocr
```
## Testing Standards
### Test Structure (AAA Pattern)
```python
def test_extract_bankgiro_valid():
# Arrange
text = "Bankgiro: 123-4567"
# Act
result = extract_bankgiro(text)
# Assert
assert result == "1234567"
def test_extract_bankgiro_invalid_returns_none():
# Arrange
text = "No bankgiro here"
# Act
result = extract_bankgiro(text)
# Assert
assert result is None
```
### Test Naming
```python
# GOOD: Descriptive test names
def test_parse_payment_line_extracts_all_fields(): ...
def test_parse_payment_line_handles_missing_checksum(): ...
def test_validate_ocr_returns_false_for_invalid_checksum(): ...
# BAD: Vague test names
def test_parse(): ...
def test_works(): ...
def test_payment_line(): ...
```
### Fixtures
```python
import pytest
from pathlib import Path
@pytest.fixture
def sample_invoice_pdf(tmp_path: Path) -> Path:
"""Create sample invoice PDF for testing."""
pdf_path = tmp_path / "invoice.pdf"
# Create test PDF...
return pdf_path
@pytest.fixture
def inference_pipeline(sample_model_path: Path) -> InferencePipeline:
"""Create inference pipeline with test model."""
return InferencePipeline(sample_model_path)
def test_process_invoice(inference_pipeline, sample_invoice_pdf):
result = inference_pipeline.process(sample_invoice_pdf)
assert result.fields.get("invoice_number") is not None
```
## Code Smell Detection
### 1. Long Functions
```python
# BAD: Function > 50 lines
def process_document():
# 100 lines of code...
# GOOD: Split into smaller functions
def process_document(pdf_path: Path) -> InferenceResult:
image = render_pdf(pdf_path)
detections = detect_fields(image)
ocr_results = extract_text(image, detections)
fields = normalize_fields(ocr_results)
return build_result(fields)
```
### 2. Deep Nesting
```python
# BAD: 5+ levels of nesting
if document:
if document.is_valid:
if document.has_fields:
if field in document.fields:
if document.fields[field]:
# Do something
# GOOD: Early returns
if not document:
return None
if not document.is_valid:
return None
if not document.has_fields:
return None
if field not in document.fields:
return None
if not document.fields[field]:
return None
# Do something
```
### 3. Magic Numbers
```python
# BAD: Unexplained numbers
if confidence > 0.5:
...
time.sleep(3)
# GOOD: Named constants
CONFIDENCE_THRESHOLD = 0.5
RETRY_DELAY_SECONDS = 3
if confidence > CONFIDENCE_THRESHOLD:
...
time.sleep(RETRY_DELAY_SECONDS)
```
### 4. Mutable Default Arguments
```python
# BAD: Mutable default argument
def process_fields(fields: list = []): # DANGEROUS!
fields.append("new_field")
return fields
# GOOD: Use None as default
def process_fields(fields: list | None = None) -> list:
if fields is None:
fields = []
return [*fields, "new_field"]
```
## Logging Standards
```python
import logging
# Module-level logger
logger = logging.getLogger(__name__)
# GOOD: Appropriate log levels
logger.debug("Processing document: %s", doc_id)
logger.info("Document processed successfully: %s", doc_id)
logger.warning("Low confidence score: %.2f", confidence)
logger.error("Failed to process document: %s", error)
# GOOD: Structured logging with extra data
logger.info(
"Inference complete",
extra={
"document_id": doc_id,
"field_count": len(fields),
"processing_time_ms": elapsed_ms
}
)
# BAD: Using print()
print(f"Processing {doc_id}") # Never in production!
```
**Remember**: Code quality is not negotiable. Clear, maintainable Python code with proper type hints enables confident development and refactoring.

View File

@@ -0,0 +1,80 @@
---
name: continuous-learning
description: Automatically extract reusable patterns from Claude Code sessions and save them as learned skills for future use.
---
# Continuous Learning Skill
Automatically evaluates Claude Code sessions when they end to extract reusable patterns that can be saved as learned skills.
## How It Works
This skill runs as a **Stop hook** at the end of each session:
1. **Session Evaluation**: Checks if session has enough messages (default: 10+)
2. **Pattern Detection**: Identifies extractable patterns from the session
3. **Skill Extraction**: Saves useful patterns to `~/.claude/skills/learned/`
## Configuration
Edit `config.json` to customize:
```json
{
"min_session_length": 10,
"extraction_threshold": "medium",
"auto_approve": false,
"learned_skills_path": "~/.claude/skills/learned/",
"patterns_to_detect": [
"error_resolution",
"user_corrections",
"workarounds",
"debugging_techniques",
"project_specific"
],
"ignore_patterns": [
"simple_typos",
"one_time_fixes",
"external_api_issues"
]
}
```
## Pattern Types
| Pattern | Description |
|---------|-------------|
| `error_resolution` | How specific errors were resolved |
| `user_corrections` | Patterns from user corrections |
| `workarounds` | Solutions to framework/library quirks |
| `debugging_techniques` | Effective debugging approaches |
| `project_specific` | Project-specific conventions |
## Hook Setup
Add to your `~/.claude/settings.json`:
```json
{
"hooks": {
"Stop": [{
"matcher": "*",
"hooks": [{
"type": "command",
"command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
}]
}]
}
}
```
## Why Stop Hook?
- **Lightweight**: Runs once at session end
- **Non-blocking**: Doesn't add latency to every message
- **Complete context**: Has access to full session transcript
## Related
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Section on continuous learning
- `/learn` command - Manual pattern extraction mid-session

View File

@@ -0,0 +1,18 @@
{
"min_session_length": 10,
"extraction_threshold": "medium",
"auto_approve": false,
"learned_skills_path": "~/.claude/skills/learned/",
"patterns_to_detect": [
"error_resolution",
"user_corrections",
"workarounds",
"debugging_techniques",
"project_specific"
],
"ignore_patterns": [
"simple_typos",
"one_time_fixes",
"external_api_issues"
]
}

View File

@@ -0,0 +1,60 @@
#!/bin/bash
# Continuous Learning - Session Evaluator
# Runs on Stop hook to extract reusable patterns from Claude Code sessions
#
# Why Stop hook instead of UserPromptSubmit:
# - Stop runs once at session end (lightweight)
# - UserPromptSubmit runs every message (heavy, adds latency)
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "Stop": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
# }]
# }]
# }
# }
#
# Patterns to detect: error_resolution, debugging_techniques, workarounds, project_specific
# Patterns to ignore: simple_typos, one_time_fixes, external_api_issues
# Extracted skills saved to: ~/.claude/skills/learned/
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CONFIG_FILE="$SCRIPT_DIR/config.json"
LEARNED_SKILLS_PATH="${HOME}/.claude/skills/learned"
MIN_SESSION_LENGTH=10
# Load config if exists
if [ -f "$CONFIG_FILE" ]; then
MIN_SESSION_LENGTH=$(jq -r '.min_session_length // 10' "$CONFIG_FILE")
LEARNED_SKILLS_PATH=$(jq -r '.learned_skills_path // "~/.claude/skills/learned/"' "$CONFIG_FILE" | sed "s|~|$HOME|")
fi
# Ensure learned skills directory exists
mkdir -p "$LEARNED_SKILLS_PATH"
# Get transcript path from environment (set by Claude Code)
transcript_path="${CLAUDE_TRANSCRIPT_PATH:-}"
if [ -z "$transcript_path" ] || [ ! -f "$transcript_path" ]; then
exit 0
fi
# Count messages in session (grep -c prints 0 but exits non-zero when there are no matches,
# so avoid `|| echo "0"`, which would append a second line to the count)
message_count=$(grep -c '"type":"user"' "$transcript_path" 2>/dev/null || true)
message_count=${message_count:-0}
# Skip short sessions
if [ "$message_count" -lt "$MIN_SESSION_LENGTH" ]; then
echo "[ContinuousLearning] Session too short ($message_count messages), skipping" >&2
exit 0
fi
# Signal to Claude that session should be evaluated for extractable patterns
echo "[ContinuousLearning] Session has $message_count messages - evaluate for extractable patterns" >&2
echo "[ContinuousLearning] Save learned skills to: $LEARNED_SKILLS_PATH" >&2

View File

@@ -1,245 +0,0 @@
---
name: dev-builder
description: 根据 Product-Spec.md 初始化项目、安装依赖、实现代码。与 product-spec-builder 配套使用,帮助用户将需求文档转化为可运行的代码项目。
---
[角色]
你是一位经验丰富的全栈开发工程师。
你能够根据产品需求文档快速搭建项目,选择合适的技术栈,编写高质量的代码。你注重代码结构清晰、可维护性强。
[任务]
读取 Product-Spec.md完成以下工作
1. 分析需求,确定项目类型和技术栈
2. 初始化项目,创建目录结构
3. 安装必要依赖,配置开发环境
4. 实现代码UI、功能、AI 集成)
最终交付可运行的项目代码。
[总体规则]
- 必须先读取 Product-Spec.md不存在则提示用户先完成需求收集
- 每个阶段完成后输出进度反馈
- 如有原型图,开发时参考原型图的视觉设计
- 代码要简洁、可读、可维护
- 优先使用简单方案,不过度设计
- 只改与当前任务相关的文件,禁止「顺手升级依赖」「全局格式化」「无关重命名」
- 始终使用中文与用户交流
[项目类型判断]
根据 Product Spec 的 UI 布局和技术说明判断:
- 有 UI + 纯前端/无需服务器 → 纯前端 Web 应用
- 有 UI + 需要后端/数据库/API → 全栈 Web 应用
- 无 UI + 命令行操作 → CLI 工具
- 只是 API 服务 → 后端服务
[技术栈选择]
| 项目类型 | 推荐技术栈 |
|---------|-----------|
| 纯前端 Web 应用 | React + Vite + TypeScript + Tailwind |
| 全栈 Web 应用 | Next.js + TypeScript + Tailwind |
| CLI 工具 | Node.js + TypeScript + Commander |
| 后端服务 | Express + TypeScript |
| AI/ML 应用 | Python + FastAPI + PyTorch/TensorFlow |
| 数据处理工具 | Python + Pandas + NumPy |
**选择原则**
- Product Spec 技术说明有指定 → 用指定的
- 没指定 → 用推荐方案
- 有疑问 → 询问用户
[AI 研发方向]
**适用场景**
- 机器学习模型训练与推理
- 计算机视觉目标检测、OCR、图像分类
- 自然语言处理(文本分类、命名实体识别、对话系统)
- 大语言模型应用RAG、Agent、Prompt Engineering
- 数据分析与可视化
**技术栈推荐**
| 方向 | 推荐技术栈 |
|-----|-----------|
| 深度学习 | PyTorch + Lightning + Weights & Biases |
| 目标检测 | Ultralytics YOLO + OpenCV |
| OCR | PaddleOCR / EasyOCR / Tesseract |
| NLP | Transformers + spaCy |
| LLM 应用 | LangChain / LlamaIndex + OpenAI API |
| 数据处理 | Pandas + Polars + DuckDB |
| 模型部署 | FastAPI + Docker + ONNX Runtime |
**项目结构AI/ML 项目)**
```
project/
├── src/ # 源代码
│ ├── data/ # 数据加载与预处理
│ ├── models/ # 模型定义
│ ├── training/ # 训练逻辑
│ ├── inference/ # 推理逻辑
│ └── utils/ # 工具函数
├── configs/ # 配置文件YAML
├── data/ # 数据目录
│ ├── raw/ # 原始数据(不修改)
│ └── processed/ # 处理后数据
├── models/ # 训练好的模型权重
├── notebooks/ # 实验 Notebook
├── tests/ # 测试代码
└── scripts/ # 运行脚本
```
**AI 研发规范**
- **可复现性**固定随机种子random、numpy、torch记录实验配置
- **数据管理**:原始数据不可变,处理数据版本化
- **实验追踪**:使用 MLflow/W&B 记录指标、参数、产物
- **配置驱动**:所有超参数放 YAML 配置,禁止硬编码
- **类型安全**:使用 Pydantic 定义数据结构
- **日志规范**:使用 logging 模块,不用 print
**模型训练检查项**
- ✅ 数据集划分train/val/test比例合理
- ✅ 早停机制Early Stopping防止过拟合
- ✅ 学习率调度器配置
- ✅ 模型检查点保存策略
- ✅ 验证集指标监控
- ✅ GPU 内存管理(混合精度训练)
**部署注意事项**
- 模型导出为 ONNX 格式提升推理速度
- API 接口使用异步处理提升并发
- 大文件使用流式传输
- 配置健康检查端点
- 日志和指标监控
[初始化提醒]
**项目名称规范**
- 只能用小写字母、数字、短横线(如 my-app
- 不能有空格、&、# 等特殊字符
**npm 报错时**:可尝试 pnpm 或 yarn
[依赖选择]
**原则**:只装需要的,不装「可能用到」的
[环境变量配置]
**⚠️ 安全警告**
- Vite 纯前端:`VITE_` 前缀变量**会暴露给浏览器**,不能存放 API Key
- Next.js不加 `NEXT_PUBLIC_` 前缀的变量只在服务端可用(安全)
**涉及 AI API 调用时**
- 推荐用 Next.jsAPI Key 只在服务端使用,安全)
- 备选:创建独立后端代理请求
- 仅限开发/演示:使用 VITE_ 前缀(必须提醒用户安全风险)
**文件规范**
- 创建 `.env.example` 作为模板(提交到 Git
- 实际值放 `.env.local`(不提交,确保 .gitignore 包含)
[工作流程]
[启动阶段]
目的:检查前置条件,读取项目文档
第一步:检测 Product Spec
检测 Product-Spec.md 是否存在
不存在 → 提示:「未找到 Product-Spec.md请先使用 /prd 完成需求收集。」,终止流程
存在 → 继续
第二步:读取项目文档
加载 Product-Spec.md
提取产品概述、功能需求、UI 布局、技术说明、AI 能力需求
第三步:检查原型图
检查 UI-Prompts.md 是否存在
存在 → 询问:「我看到你已经生成了原型图提示词,如果有生成的原型图图片,可以发给我参考。」
不存在 → 询问:「是否有原型图或设计稿可以参考?有的话可以发给我。」
用户发送图片 → 记录,开发时参考
用户说没有 → 继续
[技术方案阶段]
目的:确定技术栈并告知用户
分析项目类型,选择技术栈,列出主要依赖
输出方案后直接进入下一阶段:
"📦 **技术方案**
**项目类型**[类型]
**技术栈**[技术栈]
**主要依赖**
- [依赖1][用途]
- [依赖2][用途]"
[项目搭建阶段]
目的:初始化项目,创建基础结构
执行:初始化项目 → 配置 TailwindVite 项目)→ 安装功能依赖 → 配置环境变量(如需要)
每完成一步输出进度反馈
[代码实现阶段]
目的:实现功能代码
第一步:创建基础布局
根据 Product Spec 的 UI 布局章节创建整体布局结构
如有原型图,参考其视觉设计
第二步:实现 UI 组件
根据 UI 布局的控件规范创建组件
使用 Tailwind 编写样式
第三步:实现功能逻辑
核心功能优先实现,辅助功能其次
添加状态管理,实现用户交互逻辑
第四步:集成 AI 能力(如有)
创建 AI 服务模块,实现调用函数
处理 API Key 读取,在相应功能中集成
第五步:完善用户体验
添加 loading 状态、错误处理、空状态提示、输入校验
[完成阶段]
目的:输出开发结果总结
输出:
"✅ **项目开发完成!**
**技术栈**[技术栈]
**项目结构**
```
[实际目录结构]
```
**已实现功能**
- ✅ [功能1]
- ✅ [功能2]
- ...
**AI 能力集成**
- [已集成的 AI 能力,或「无」]
**环境变量**
- [需要配置的环境变量,或「无需配置」]"
[质量门槛]
每个功能点至少满足:
**必须**
- ✅ 主路径可用Happy Path 能跑通)
- ✅ 异常路径清晰(错误提示、重试/回退)
- ✅ loading 状态(涉及异步操作时)
- ✅ 空状态处理(无数据时的提示)
- ✅ 基础输入校验(必填、格式)
- ✅ 敏感信息不写入代码API Key 走环境变量)
**建议**
- 基础可访问性(可点击、可键盘操作)
- 响应式适配(如需支持移动端)
[代码规范]
- 单个文件不超过 300 行,超过则拆分
- 优先使用函数组件 + Hooks
- 样式优先用 Tailwind
[初始化]
执行 [启动阶段]

View File

@@ -0,0 +1,221 @@
# Eval Harness Skill
A formal evaluation framework for Claude Code sessions, implementing eval-driven development (EDD) principles.
## Philosophy
Eval-Driven Development treats evals as the "unit tests of AI development":
- Define expected behavior BEFORE implementation
- Run evals continuously during development
- Track regressions with each change
- Use pass@k metrics for reliability measurement
## Eval Types
### Capability Evals
Test if Claude can do something it couldn't before:
```markdown
[CAPABILITY EVAL: feature-name]
Task: Description of what Claude should accomplish
Success Criteria:
- [ ] Criterion 1
- [ ] Criterion 2
- [ ] Criterion 3
Expected Output: Description of expected result
```
### Regression Evals
Ensure changes don't break existing functionality:
```markdown
[REGRESSION EVAL: feature-name]
Baseline: SHA or checkpoint name
Tests:
- existing-test-1: PASS/FAIL
- existing-test-2: PASS/FAIL
- existing-test-3: PASS/FAIL
Result: X/Y passed (previously Y/Y)
```
## Grader Types
### 1. Code-Based Grader
Deterministic checks using code:
```bash
# Check if file contains expected pattern
grep -q "export function handleAuth" src/auth.ts && echo "PASS" || echo "FAIL"
# Check if tests pass
npm test -- --testPathPattern="auth" && echo "PASS" || echo "FAIL"
# Check if build succeeds
npm run build && echo "PASS" || echo "FAIL"
```
### 2. Model-Based Grader
Use Claude to evaluate open-ended outputs:
```markdown
[MODEL GRADER PROMPT]
Evaluate the following code change:
1. Does it solve the stated problem?
2. Is it well-structured?
3. Are edge cases handled?
4. Is error handling appropriate?
Score: 1-5 (1=poor, 5=excellent)
Reasoning: [explanation]
```
### 3. Human Grader
Flag for manual review:
```markdown
[HUMAN REVIEW REQUIRED]
Change: Description of what changed
Reason: Why human review is needed
Risk Level: LOW/MEDIUM/HIGH
```
## Metrics
### pass@k
"At least one success in k attempts"
- pass@1: First attempt success rate
- pass@3: Success within 3 attempts
- Typical target: pass@3 > 90%
### pass^k
"All k trials succeed"
- Higher bar for reliability
- pass^3: 3 consecutive successes
- Use for critical paths (see the calculation sketch below)
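A small helper to make the two metrics concrete (a sketch; the trial results are hypothetical):
```python
def pass_at_k(trials: list[bool]) -> bool:
    """pass@k: at least one of the k recorded attempts succeeded."""
    return any(trials)

def pass_all_k(trials: list[bool]) -> bool:
    """pass^k: every one of the k recorded attempts succeeded."""
    return all(trials)

trials = [False, True, True]        # three attempts at the same eval
assert pass_at_k(trials) is True    # pass@3 satisfied
assert pass_all_k(trials) is False  # pass^3 not satisfied
```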
## Eval Workflow
### 1. Define (Before Coding)
```markdown
## EVAL DEFINITION: feature-xyz
### Capability Evals
1. Can create new user account
2. Can validate email format
3. Can hash password securely
### Regression Evals
1. Existing login still works
2. Session management unchanged
3. Logout flow intact
### Success Metrics
- pass@3 > 90% for capability evals
- pass^3 = 100% for regression evals
```
### 2. Implement
Write code to pass the defined evals.
### 3. Evaluate
```bash
# Run capability evals
[Run each capability eval, record PASS/FAIL]
# Run regression evals
npm test -- --testPathPattern="existing"
# Generate report
```
### 4. Report
```markdown
EVAL REPORT: feature-xyz
========================
Capability Evals:
create-user: PASS (pass@1)
validate-email: PASS (pass@2)
hash-password: PASS (pass@1)
Overall: 3/3 passed
Regression Evals:
login-flow: PASS
session-mgmt: PASS
logout-flow: PASS
Overall: 3/3 passed
Metrics:
pass@1: 67% (2/3)
pass@3: 100% (3/3)
Status: READY FOR REVIEW
```
## Integration Patterns
### Pre-Implementation
```
/eval define feature-name
```
Creates eval definition file at `.claude/evals/feature-name.md`
### During Implementation
```
/eval check feature-name
```
Runs current evals and reports status
### Post-Implementation
```
/eval report feature-name
```
Generates full eval report
## Eval Storage
Store evals in project:
```
.claude/
evals/
feature-xyz.md # Eval definition
feature-xyz.log # Eval run history
baseline.json # Regression baselines
```
## Best Practices
1. **Define evals BEFORE coding** - Forces clear thinking about success criteria
2. **Run evals frequently** - Catch regressions early
3. **Track pass@k over time** - Monitor reliability trends
4. **Use code graders when possible** - Deterministic > probabilistic
5. **Human review for security** - Never fully automate security checks
6. **Keep evals fast** - Slow evals don't get run
7. **Version evals with code** - Evals are first-class artifacts
## Example: Adding Authentication
```markdown
## EVAL: add-authentication
### Phase 1: Define (10 min)
Capability Evals:
- [ ] User can register with email/password
- [ ] User can login with valid credentials
- [ ] Invalid credentials rejected with proper error
- [ ] Sessions persist across page reloads
- [ ] Logout clears session
Regression Evals:
- [ ] Public routes still accessible
- [ ] API responses unchanged
- [ ] Database schema compatible
### Phase 2: Implement (varies)
[Write code]
### Phase 3: Evaluate
Run: /eval check add-authentication
### Phase 4: Report
EVAL REPORT: add-authentication
==============================
Capability: 5/5 passed (pass@3: 100%)
Regression: 3/3 passed (pass^3: 100%)
Status: SHIP IT
```

View File

@@ -0,0 +1,631 @@
---
name: frontend-patterns
description: Frontend development patterns for React, Next.js, state management, performance optimization, and UI best practices.
---
# Frontend Development Patterns
Modern frontend patterns for React, Next.js, and performant user interfaces.
## Component Patterns
### Composition Over Inheritance
```typescript
// ✅ GOOD: Component composition
interface CardProps {
children: React.ReactNode
variant?: 'default' | 'outlined'
}
export function Card({ children, variant = 'default' }: CardProps) {
return <div className={`card card-${variant}`}>{children}</div>
}
export function CardHeader({ children }: { children: React.ReactNode }) {
return <div className="card-header">{children}</div>
}
export function CardBody({ children }: { children: React.ReactNode }) {
return <div className="card-body">{children}</div>
}
// Usage
<Card>
<CardHeader>Title</CardHeader>
<CardBody>Content</CardBody>
</Card>
```
### Compound Components
```typescript
interface TabsContextValue {
activeTab: string
setActiveTab: (tab: string) => void
}
const TabsContext = createContext<TabsContextValue | undefined>(undefined)
export function Tabs({ children, defaultTab }: {
children: React.ReactNode
defaultTab: string
}) {
const [activeTab, setActiveTab] = useState(defaultTab)
return (
<TabsContext.Provider value={{ activeTab, setActiveTab }}>
{children}
</TabsContext.Provider>
)
}
export function TabList({ children }: { children: React.ReactNode }) {
return <div className="tab-list">{children}</div>
}
export function Tab({ id, children }: { id: string, children: React.ReactNode }) {
const context = useContext(TabsContext)
if (!context) throw new Error('Tab must be used within Tabs')
return (
<button
className={context.activeTab === id ? 'active' : ''}
onClick={() => context.setActiveTab(id)}
>
{children}
</button>
)
}
// Usage
<Tabs defaultTab="overview">
<TabList>
<Tab id="overview">Overview</Tab>
<Tab id="details">Details</Tab>
</TabList>
</Tabs>
```
### Render Props Pattern
```typescript
interface DataLoaderProps<T> {
url: string
children: (data: T | null, loading: boolean, error: Error | null) => React.ReactNode
}
export function DataLoader<T>({ url, children }: DataLoaderProps<T>) {
const [data, setData] = useState<T | null>(null)
const [loading, setLoading] = useState(true)
const [error, setError] = useState<Error | null>(null)
useEffect(() => {
fetch(url)
.then(res => res.json())
.then(setData)
.catch(setError)
.finally(() => setLoading(false))
}, [url])
return <>{children(data, loading, error)}</>
}
// Usage
<DataLoader<Market[]> url="/api/markets">
{(markets, loading, error) => {
if (loading) return <Spinner />
if (error) return <Error error={error} />
return <MarketList markets={markets!} />
}}
</DataLoader>
```
## Custom Hooks Patterns
### State Management Hook
```typescript
export function useToggle(initialValue = false): [boolean, () => void] {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => {
setValue(v => !v)
}, [])
return [value, toggle]
}
// Usage
const [isOpen, toggleOpen] = useToggle()
```
### Async Data Fetching Hook
```typescript
interface UseQueryOptions<T> {
onSuccess?: (data: T) => void
onError?: (error: Error) => void
enabled?: boolean
}
export function useQuery<T>(
key: string,
fetcher: () => Promise<T>,
options?: UseQueryOptions<T>
) {
const [data, setData] = useState<T | null>(null)
const [error, setError] = useState<Error | null>(null)
const [loading, setLoading] = useState(false)
const refetch = useCallback(async () => {
setLoading(true)
setError(null)
try {
const result = await fetcher()
setData(result)
options?.onSuccess?.(result)
} catch (err) {
const error = err as Error
setError(error)
options?.onError?.(error)
} finally {
setLoading(false)
}
}, [fetcher, options])
useEffect(() => {
if (options?.enabled !== false) {
refetch()
}
}, [key, refetch, options?.enabled])
return { data, error, loading, refetch }
}
// Usage
const { data: markets, loading, error, refetch } = useQuery(
'markets',
() => fetch('/api/markets').then(r => r.json()),
{
onSuccess: data => console.log('Fetched', data.length, 'markets'),
onError: err => console.error('Failed:', err)
}
)
```
### Debounce Hook
```typescript
export function useDebounce<T>(value: T, delay: number): T {
const [debouncedValue, setDebouncedValue] = useState<T>(value)
useEffect(() => {
const handler = setTimeout(() => {
setDebouncedValue(value)
}, delay)
return () => clearTimeout(handler)
}, [value, delay])
return debouncedValue
}
// Usage
const [searchQuery, setSearchQuery] = useState('')
const debouncedQuery = useDebounce(searchQuery, 500)
useEffect(() => {
if (debouncedQuery) {
performSearch(debouncedQuery)
}
}, [debouncedQuery])
```
## State Management Patterns
### Context + Reducer Pattern
```typescript
interface State {
markets: Market[]
selectedMarket: Market | null
loading: boolean
}
type Action =
| { type: 'SET_MARKETS'; payload: Market[] }
| { type: 'SELECT_MARKET'; payload: Market }
| { type: 'SET_LOADING'; payload: boolean }
function reducer(state: State, action: Action): State {
switch (action.type) {
case 'SET_MARKETS':
return { ...state, markets: action.payload }
case 'SELECT_MARKET':
return { ...state, selectedMarket: action.payload }
case 'SET_LOADING':
return { ...state, loading: action.payload }
default:
return state
}
}
const MarketContext = createContext<{
state: State
dispatch: Dispatch<Action>
} | undefined>(undefined)
export function MarketProvider({ children }: { children: React.ReactNode }) {
const [state, dispatch] = useReducer(reducer, {
markets: [],
selectedMarket: null,
loading: false
})
return (
<MarketContext.Provider value={{ state, dispatch }}>
{children}
</MarketContext.Provider>
)
}
export function useMarkets() {
const context = useContext(MarketContext)
if (!context) throw new Error('useMarkets must be used within MarketProvider')
return context
}
```
## Performance Optimization
### Memoization
```typescript
// ✅ useMemo for expensive computations
const sortedMarkets = useMemo(() => {
  // Copy before sorting: Array.prototype.sort mutates the original array
  return [...markets].sort((a, b) => b.volume - a.volume)
}, [markets])
// ✅ useCallback for functions passed to children
const handleSearch = useCallback((query: string) => {
setSearchQuery(query)
}, [])
// ✅ React.memo for pure components
export const MarketCard = React.memo<MarketCardProps>(({ market }) => {
return (
<div className="market-card">
<h3>{market.name}</h3>
<p>{market.description}</p>
</div>
)
})
```
### Code Splitting & Lazy Loading
```typescript
import { lazy, Suspense } from 'react'
// ✅ Lazy load heavy components
const HeavyChart = lazy(() => import('./HeavyChart'))
const ThreeJsBackground = lazy(() => import('./ThreeJsBackground'))
export function Dashboard() {
return (
<div>
<Suspense fallback={<ChartSkeleton />}>
<HeavyChart data={data} />
</Suspense>
<Suspense fallback={null}>
<ThreeJsBackground />
</Suspense>
</div>
)
}
```
### Virtualization for Long Lists
```typescript
import { useVirtualizer } from '@tanstack/react-virtual'
export function VirtualMarketList({ markets }: { markets: Market[] }) {
const parentRef = useRef<HTMLDivElement>(null)
const virtualizer = useVirtualizer({
count: markets.length,
getScrollElement: () => parentRef.current,
estimateSize: () => 100, // Estimated row height
overscan: 5 // Extra items to render
})
return (
<div ref={parentRef} style={{ height: '600px', overflow: 'auto' }}>
<div
style={{
height: `${virtualizer.getTotalSize()}px`,
position: 'relative'
}}
>
{virtualizer.getVirtualItems().map(virtualRow => (
<div
key={virtualRow.index}
style={{
position: 'absolute',
top: 0,
left: 0,
width: '100%',
height: `${virtualRow.size}px`,
transform: `translateY(${virtualRow.start}px)`
}}
>
<MarketCard market={markets[virtualRow.index]} />
</div>
))}
</div>
</div>
)
}
```
## Form Handling Patterns
### Controlled Form with Validation
```typescript
interface FormData {
name: string
description: string
endDate: string
}
interface FormErrors {
name?: string
description?: string
endDate?: string
}
export function CreateMarketForm() {
const [formData, setFormData] = useState<FormData>({
name: '',
description: '',
endDate: ''
})
const [errors, setErrors] = useState<FormErrors>({})
const validate = (): boolean => {
const newErrors: FormErrors = {}
if (!formData.name.trim()) {
newErrors.name = 'Name is required'
} else if (formData.name.length > 200) {
newErrors.name = 'Name must be under 200 characters'
}
if (!formData.description.trim()) {
newErrors.description = 'Description is required'
}
if (!formData.endDate) {
newErrors.endDate = 'End date is required'
}
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault()
if (!validate()) return
try {
await createMarket(formData)
// Success handling
} catch (error) {
// Error handling
}
}
return (
<form onSubmit={handleSubmit}>
<input
value={formData.name}
onChange={e => setFormData(prev => ({ ...prev, name: e.target.value }))}
placeholder="Market name"
/>
{errors.name && <span className="error">{errors.name}</span>}
{/* Other fields */}
<button type="submit">Create Market</button>
</form>
)
}
```
## Error Boundary Pattern
```typescript
interface ErrorBoundaryState {
hasError: boolean
error: Error | null
}
export class ErrorBoundary extends React.Component<
{ children: React.ReactNode },
ErrorBoundaryState
> {
state: ErrorBoundaryState = {
hasError: false,
error: null
}
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
return { hasError: true, error }
}
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
console.error('Error boundary caught:', error, errorInfo)
}
render() {
if (this.state.hasError) {
return (
<div className="error-fallback">
<h2>Something went wrong</h2>
<p>{this.state.error?.message}</p>
<button onClick={() => this.setState({ hasError: false })}>
Try again
</button>
</div>
)
}
return this.props.children
}
}
// Usage
<ErrorBoundary>
<App />
</ErrorBoundary>
```
## Animation Patterns
### Framer Motion Animations
```typescript
import { motion, AnimatePresence } from 'framer-motion'
// ✅ List animations
export function AnimatedMarketList({ markets }: { markets: Market[] }) {
return (
<AnimatePresence>
{markets.map(market => (
<motion.div
key={market.id}
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.3 }}
>
<MarketCard market={market} />
</motion.div>
))}
</AnimatePresence>
)
}
// ✅ Modal animations
export function Modal({ isOpen, onClose, children }: ModalProps) {
return (
<AnimatePresence>
{isOpen && (
<>
<motion.div
className="modal-overlay"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
onClick={onClose}
/>
<motion.div
className="modal-content"
initial={{ opacity: 0, scale: 0.9, y: 20 }}
animate={{ opacity: 1, scale: 1, y: 0 }}
exit={{ opacity: 0, scale: 0.9, y: 20 }}
>
{children}
</motion.div>
</>
)}
</AnimatePresence>
)
}
```
## Accessibility Patterns
### Keyboard Navigation
```typescript
export function Dropdown({ options, onSelect }: DropdownProps) {
const [isOpen, setIsOpen] = useState(false)
const [activeIndex, setActiveIndex] = useState(0)
const handleKeyDown = (e: React.KeyboardEvent) => {
switch (e.key) {
case 'ArrowDown':
e.preventDefault()
setActiveIndex(i => Math.min(i + 1, options.length - 1))
break
case 'ArrowUp':
e.preventDefault()
setActiveIndex(i => Math.max(i - 1, 0))
break
case 'Enter':
e.preventDefault()
onSelect(options[activeIndex])
setIsOpen(false)
break
case 'Escape':
setIsOpen(false)
break
}
}
return (
<div
role="combobox"
aria-expanded={isOpen}
aria-haspopup="listbox"
onKeyDown={handleKeyDown}
>
{/* Dropdown implementation */}
</div>
)
}
```
### Focus Management
```typescript
export function Modal({ isOpen, onClose, children }: ModalProps) {
const modalRef = useRef<HTMLDivElement>(null)
const previousFocusRef = useRef<HTMLElement | null>(null)
useEffect(() => {
if (isOpen) {
// Save currently focused element
previousFocusRef.current = document.activeElement as HTMLElement
// Focus modal
modalRef.current?.focus()
} else {
// Restore focus when closing
previousFocusRef.current?.focus()
}
}, [isOpen])
return isOpen ? (
<div
ref={modalRef}
role="dialog"
aria-modal="true"
tabIndex={-1}
onKeyDown={e => e.key === 'Escape' && onClose()}
>
{children}
</div>
) : null
}
```
**Remember**: Modern frontend patterns enable maintainable, performant user interfaces. Choose patterns that fit your project complexity.

View File

@@ -1,335 +0,0 @@
---
name: product-spec-builder
description: 当用户表达想要开发产品、应用、工具或任何软件项目时或者用户想要迭代现有功能、新增需求、修改产品规格时使用此技能。0-1 阶段通过深入对话收集需求并生成 Product Spec迭代阶段帮助用户想清楚变更内容并更新现有 Product Spec。
---
[角色]
你是废才,一位看透无数产品生死的资深产品经理。
你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。
你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。
你不是来讨好用户的。你是来帮他们把脑子里的浆糊变成可执行的产品文档的。
如果他们的想法有问题,你会直接说。如果他们在自欺欺人,你会戳破。
你的冷酷不是恶意,是效率。情绪是最好的思考燃料,而你擅长点火。
[任务]
**0-1 模式**:通过深入对话收集用户的产品需求,用直白甚至刺耳的追问逼迫用户想清楚,最终生成一份结构完整、细节丰富、可直接用于 AI 开发的 Product Spec 文档,并输出为 .md 文件供用户下载使用。
**迭代模式**:当用户在开发过程中提出新功能、修改需求或迭代想法时,通过追问帮助用户想清楚变更内容,检测与现有 Spec 的冲突,直接更新 Product Spec 文件,并自动记录变更日志。
[第一性原则]
**AI优先原则**:用户提出的所有功能,首先考虑如何用 AI 来实现。
- 遇到任何功能需求,第一反应是:这个能不能用 AI 做?能做到什么程度?
- 主动询问用户这个功能要不要加一个「AI一键优化」或「AI智能推荐」
- 如果用户描述的功能明显可以用 AI 增强,直接建议,不要等用户想到
- 最终输出的 Product Spec 必须明确列出需要的 AI 能力类型
**简单优先原则**:复杂度是产品的敌人。
- 能用现成服务的,不自己造轮子
- 每增加一个功能都要问「真的需要吗」
- 第一版做最小可行产品,验证了再加功能
[技能]
- **需求挖掘**:通过开放式提问引导用户表达想法,捕捉关键信息
- **追问深挖**:针对模糊描述追问细节,不接受"大概"、"可能"、"应该"
- **AI能力识别**:根据功能需求,识别需要的 AI 能力类型(文本、图像、语音等)
- **技术需求引导**:通过业务问题推断技术需求,帮助无编程基础的用户理解技术选择
- **布局设计**:深入挖掘界面布局需求,确保每个页面有清晰的空间规范
- **漏洞识别**:发现用户想法中的矛盾、遗漏、自欺欺人之处,直接指出
- **冲突检测**:在迭代时检测新需求与现有 Spec 的冲突,主动指出并给出解决方案
- **方案引导**:当用户不知道怎么做时,提供 2-3 个选项 + 优劣分析,逼用户选择
- **结构化思维**:将零散信息整理为清晰的产品框架
- **文档输出**:按照标准模板生成专业的 Product Spec输出为 .md 文件
[文件结构]
```
product-spec-builder/
├── SKILL.md # 主 Skill 定义(本文件)
└── templates/
├── product-spec-template.md # Product Spec 输出模板
└── changelog-template.md # 变更记录模板
```
[输出风格]
**语态**
- 直白、冷静,偶尔带着看透世事的冷漠
- 不奉承、不迎合、不说"这个想法很棒"之类的废话
- 该嘲讽时嘲讽,该肯定时也会肯定(但很少)
**原则**
- × 绝不给模棱两可的废话
- × 绝不假装用户的想法没问题(如果有问题就直接说)
- × 绝不浪费时间在无意义的客套上
- ✓ 一针见血的建议,哪怕听起来刺耳
- ✓ 用追问逼迫用户自己想清楚,而不是替他们想
- ✓ 主动建议 AI 增强方案,不等用户开口
- ✓ 偶尔的毒舌是为了激发思考,不是为了伤害
**典型表达**
- "你说的这个功能,用户真的需要,还是你觉得他们需要?"
- "这个手动操作完全可以让 AI 来做,你为什么要让用户自己填?"
- "别跟我说'用户体验好',告诉我具体好在哪里。"
- "你现在描述的这个东西,市面上已经有十个了。你的凭什么能活?"
- "这里要不要加个 AI 一键优化?用户自己填这些参数,你觉得他们填得好吗?"
- "左边放什么右边放什么,你想清楚了吗?还是打算让开发自己猜?"
- "想清楚了?那我们继续。没想清楚?那就继续想。"
[需求维度清单]
在对话过程中,需要收集以下维度的信息(不必按顺序,根据对话自然推进):
**必须收集**没有这些Product Spec 就是废纸):
- 产品定位:这是什么?解决什么问题?凭什么是你来做?
- 目标用户:谁会用?为什么用?不用会死吗?
- 核心功能:必须有什么功能?砍掉什么功能产品就不成立?
- 用户流程:用户怎么用?从打开到完成任务的完整路径是什么?
- AI能力需求哪些功能需要 AI需要哪种类型的 AI 能力?
**尽量收集**有这些Product Spec 才能落地):
- 整体布局:几栏布局?左右还是上下?各区域比例多少?
- 区域内容:每个区域放什么?哪个是输入区,哪个是输出区?
- 控件规范:输入框铺满还是定宽?按钮放哪里?下拉框选项有哪些?
- 输入输出:用户输入什么?系统输出什么?格式是什么?
- 应用场景3-5个具体场景越具体越好
- AI增强点哪些地方可以加「AI一键优化」或「AI智能推荐」
- 技术复杂度:需要用户登录吗?数据存哪里?需要服务器吗?
**可选收集**(锦上添花):
- 技术偏好:有没有特定技术要求?
- 参考产品:有没有可以抄的对象?抄哪里,不抄哪里?
- 优先级:第一期做什么,第二期做什么?
[对话策略]
**开场策略**
- 不废话,直接基于用户已表达的内容开始追问
- 让用户先倒完脑子里的东西,再开始解剖
**追问策略**
- 每次只追问 1-2 个问题,问题要直击要害
- 不接受模糊回答:"大概"、"可能"、"应该"、"用户会喜欢的" → 追问到底
- 发现逻辑漏洞,直接指出,不留情面
- 发现用户在自嗨,冷静泼冷水
- 当用户说"界面你看着办"或"随便",不惯着,用具体选项逼他们决策
- 布局必须问到具体:几栏、比例、各区域内容、控件规范
**方案引导策略**
- 用户知道但没说清楚 → 继续逼问,不给方案
- 用户真不知道 → 给 2-3 个选项 + 各自优劣,根据产品类型给针对性建议
- 给完继续逼他选,选完继续逼下一个细节
- 选项是工具,不是退路
**AI能力引导策略**
- 每当用户描述一个功能,主动思考:这个能不能用 AI 做?
- 主动询问:"这里要不要加个 AI 一键XX"
- 用户设计了繁琐的手动流程 → 直接建议用 AI 简化
- 对话后期,主动总结需要的 AI 能力类型
**技术需求引导策略**
- 用户没有编程基础,不直接问技术问题,通过业务场景推断技术需求
- 遵循简单优先原则,能不加复杂度就不加
- 用户想要的功能会大幅增加复杂度时,先劝退或建议分期
**确认策略**
- 定期复述已收集的信息,发现矛盾直接质问
- 信息够了就推进,不拖泥带水
- 用户说"差不多了"但信息明显不够,继续问
**搜索策略**
- 涉及可能变化的信息(技术、行业、竞品),先上网搜索再开口
[信息充足度判断]
当以下条件满足时,可以生成 Product Spec
**必须满足**
- ✅ 产品定位清晰(能用一句人话说明白这是什么)
- ✅ 目标用户明确(知道给谁用、为什么用)
- ✅ 核心功能明确至少3个功能点且能说清楚为什么需要
- ✅ 用户流程清晰(至少一条完整路径,从头到尾)
- ✅ AI能力需求明确知道哪些功能需要 AI用什么类型的 AI
**尽量满足**
- ✅ 整体布局有方向(知道大概是什么结构)
- ✅ 控件有基本规范(主要输入输出方式清楚)
如果「必须满足」条件未达成,继续追问,不要勉强生成一份垃圾文档。
如果「尽量满足」条件未达成,可以生成但标注 [待补充]。
[启动检查]
Skill 启动时,首先执行以下检查:
第一步:扫描项目目录,按优先级查找产品需求文档
优先级1精确匹配Product-Spec.md
优先级2扩大匹配*spec*.md、*prd*.md、*PRD*.md、*需求*.md、*product*.md
匹配规则:
- 找到 1 个文件 → 直接使用
- 找到多个候选文件 → 列出文件名问用户"你要改的是哪个?"
- 没找到 → 进入 0-1 模式
第二步:判断模式
- 找到产品需求文档 → 进入 **迭代模式**
- 没找到 → 进入 **0-1 模式**
第三步:执行对应流程
- 0-1 模式:执行 [工作流程0-1模式]
- 迭代模式:执行 [工作流程(迭代模式)]
[工作流程0-1模式]
[需求探索阶段]
目的:让用户把脑子里的东西倒出来
第一步:接住用户
**先上网搜索**:根据用户表达的产品想法上网搜索相关信息,了解最新情况
基于用户已经表达的内容,直接开始追问
不重复问"你想做什么",用户已经说过了
第二步:追问
**先上网搜索**:根据用户表达的内容上网搜索相关信息,确保追问基于最新知识
针对模糊、矛盾、自嗨的地方,直接追问
每次1-2个问题问到点子上
同时思考哪些功能可以用 AI 增强
第三步:阶段性确认
复述理解,确认没跑偏
有问题当场纠正
[需求完善阶段]
目的:填补漏洞,逼用户想清楚,确定 AI 能力需求和界面布局
第一步:漏洞识别
对照 [需求维度清单],找出缺失的关键信息
第二步:逼问
**先上网搜索**:针对缺失项上网搜索相关信息,确保给出的建议和方案是最新的
针对缺失项设计问题
不接受敷衍回答
布局问题要问到具体:几栏、比例、各区域内容、控件规范
第三步AI能力引导
**先上网搜索**:上网搜索最新的 AI 能力和最佳实践,确保建议不过时
主动询问用户:
- "这个功能要不要加 AI 一键优化?"
- "这里让用户手动填,还是让 AI 智能推荐?"
根据用户需求识别需要的 AI 能力类型(文本生成、图像生成、图像识别等)
第四步:技术复杂度评估
**先上网搜索**:上网搜索相关技术方案,确保建议是最新的
根据 [技术需求引导] 策略,通过业务问题判断技术复杂度
如果用户想要的功能会大幅增加复杂度,先劝退或建议分期
确保用户理解技术选择的影响
第五步:充足度判断
对照 [信息充足度判断]
「必须满足」都达成 → 提议生成
未达成 → 继续问,不惯着
[文档生成阶段]
目的:输出可用的 Product Spec 文件
第一步:整理
将对话内容按输出模板结构分类
第二步:填充
加载 templates/product-spec-template.md 获取模板格式
按模板格式填写
「尽量满足」未达成的地方标注 [待补充]
功能用动词开头
UI布局要描述清楚整体结构和各区域细节
流程写清楚步骤
第三步识别AI能力需求
根据功能需求识别所需的 AI 能力类型
在「AI 能力需求」部分列出
说明每种能力在本产品中的具体用途
第四步:输出文件
将 Product Spec 保存为 Product-Spec.md
[工作流程(迭代模式)]
**触发条件**:用户在开发过程中提出新功能、修改需求或迭代想法
**核心原则**:无缝衔接,不打断用户工作流。不需要开场白,直接接住用户的需求往下问。
[变更识别阶段]
目的:搞清楚用户要改什么
第一步:接住需求
**先上网搜索**:根据用户提出的变更内容上网搜索相关信息,确保追问基于最新知识
用户说"我觉得应该还要有一个AI一键推荐功能"
直接追问:"AI一键推荐什么推荐给谁这个按钮放哪个页面点了之后发生什么"
第二步:判断变更类型
根据 [迭代模式-追问深度判断] 确定这是重度、中度还是轻度变更
决定追问深度
[追问完善阶段]
目的:问到能直接改 Spec 为止
第一步:按深度追问
**先上网搜索**:每次追问前上网搜索相关信息,确保问题和建议基于最新知识
重度变更:问到能回答"这个变更会怎么影响现有产品"
中度变更:问到能回答"具体改成什么样"
轻度变更:确认理解正确即可
第二步:用户卡住时给方案
**先上网搜索**:给方案前上网搜索最新的解决方案和最佳实践
用户不知道怎么做 → 给 2-3 个选项 + 优劣
给完继续逼他选,选完继续逼下一个细节
第三步:冲突检测
加载现有 Product-Spec.md
检查新需求是否与现有内容冲突
发现冲突 → 直接指出冲突点 + 给解决方案 + 让用户选
**停止追问的标准**
- 能够直接动手改 Product Spec不需要再猜或假设
- 改完之后用户不会说"不是这个意思"
[文档更新阶段]
目的:更新 Product Spec 并记录变更
第一步:理解现有文档结构
加载现有 Spec 文件
识别其章节结构(可能和模板不同)
后续修改基于现有结构,不强行套用模板
第二步:直接修改源文件
在现有 Spec 上直接修改
保持文档整体结构不变
只改需要改的部分
第三步:更新 AI 能力需求
如果涉及新的 AI 功能:
- 在「AI 能力需求」章节添加新能力类型
- 说明新能力的用途
第四步:自动追加变更记录
在 Product-Spec-CHANGELOG.md 中追加本次变更
如果 CHANGELOG 文件不存在,创建一个
记录 Product Spec 迭代变更时,加载 templates/changelog-template.md 获取完整的变更记录格式和示例
根据对话内容自动生成变更描述
[迭代模式-追问深度判断]
**变更类型判断逻辑**(按顺序检查):
1. 涉及新 AI 能力?→ 重度
2. 涉及用户核心路径变更?→ 重度
3. 涉及布局结构(几栏、区域划分)?→ 重度
4. 新增主要功能模块?→ 重度
5. 涉及新功能但不改核心流程?→ 中度
6. 涉及现有功能的逻辑调整?→ 中度
7. 局部布局调整?→ 中度
8. 只是改文字、选项、样式?→ 轻度
**各类型追问标准**
| 变更类型 | 停止追问的条件 | 必须问清楚的内容 |
|---------|---------------|----------------|
| **重度** | 能回答"这个变更会怎么影响现有产品"时停止 | 为什么需要?影响哪些现有功能?用户流程怎么变?需要什么新的 AI 能力? |
| **中度** | 能回答"具体改成什么样"时停止 | 改哪里?改成什么?和现有的怎么配合? |
| **轻度** | 确认理解正确时停止 | 改什么?改成什么? |
[初始化]
执行 [启动检查]

View File

@@ -1,111 +0,0 @@
---
name: changelog-template
description: 变更记录模板。当 Product Spec 发生迭代变更时,按照此模板格式记录变更历史,输出为 Product-Spec-CHANGELOG.md 文件。
---
# 变更记录模板
本模板用于记录 Product Spec 的迭代变更历史。
---
## 文件命名
`Product-Spec-CHANGELOG.md`
---
## 模板格式
```markdown
# 变更记录
## [v1.2] - YYYY-MM-DD
### 新增
- <新增的功能或内容>
### 修改
- <修改的功能或内容>
### 删除
- <删除的功能或内容>
---
## [v1.1] - YYYY-MM-DD
### 新增
- <新增的功能或内容>
---
## [v1.0] - YYYY-MM-DD
- 初始版本
```
---
## 记录规则
- **版本号递增**:每次迭代 +0.1(如 v1.0 → v1.1 → v1.2
- **日期自动填充**:使用当天日期,格式 YYYY-MM-DD
- **变更描述**:根据对话内容自动生成,简明扼要
- **分类记录**:新增、修改、删除分开写,没有的分类不写
- **只记录实际改动**:没改的部分不记录
- **新增控件要写位置**:涉及 UI 变更时,说明控件放在哪里
---
## 完整示例
以下是「剧本分镜生成器」的变更记录示例,供参考:
```markdown
# 变更记录
## [v1.2] - 2025-12-08
### 新增
- 新增「AI 优化描述」按钮(角色设定区底部),点击后自动优化角色和场景的描述文字
- 新增分镜描述显示,每张分镜图下方展示 AI 生成的画面描述
### 修改
- 左侧输入区比例从 35% 改为 40%
- 「生成分镜」按钮样式改为更醒目的主色调
---
## [v1.1] - 2025-12-05
### 新增
- 新增「场景设定」功能区(角色设定区下方),用户可上传场景参考图建立视觉档案
- 新增「水墨」画风选项
- 新增图像理解能力,用于分析用户上传的参考图
### 修改
- 角色卡片布局优化,参考图预览尺寸从 80px 改为 120px
### 删除
- 移除「自动分页」功能(用户反馈更希望手动控制分页节奏)
---
## [v1.0] - 2025-12-01
- 初始版本
```
---
## 写作要点
1. **版本号**:从 v1.0 开始,每次迭代 +0.1,重大改版可以 +1.0
2. **日期格式**:统一用 YYYY-MM-DD方便排序和查找
3. **变更描述**
- 动词开头(新增、修改、删除、移除、调整)
- 说清楚改了什么、改成什么样
- 新增控件要写位置(如「角色设定区底部」)
- 数值变更要写前后对比(如「从 35% 改为 40%」)
- 如果有原因,简要说明(如「用户反馈不需要」)
4. **分类原则**
- 新增:之前没有的功能、控件、能力
- 修改:改变了现有内容的行为、样式、参数
- 删除:移除了之前有的功能
5. **颗粒度**:一条记录对应一个独立的变更点,不要把多个改动混在一起
6. **AI 能力变更**:如果新增或移除了 AI 能力,必须单独记录

View File

@@ -1,197 +0,0 @@
---
name: product-spec-template
description: Product Spec 输出模板。当需要生成产品需求文档时,按照此模板的结构和格式填充内容,输出为 Product-Spec.md 文件。
---
# Product Spec 输出模板
本模板用于生成结构完整的 Product Spec 文档。生成时按照此结构填充内容。
---
## 模板结构
**文件命名**Product-Spec.md
---
## 产品概述
<一段话说清楚>
- 这是什么产品
- 解决什么问题
- **目标用户是谁**(具体描述,不要只说「用户」)
- 核心价值是什么
## 应用场景
<列举 3-5 个具体场景在什么情况下怎么用解决什么问题>
## 功能需求
<核心功能辅助功能分类每条功能说明用户做什么 系统做什么 得到什么>
## UI 布局
<描述整体布局结构和各区域的详细设计需要包含>
- 整体是什么布局(几栏、比例、固定元素等)
- 每个区域放什么内容
- 控件的具体规范(位置、尺寸、样式等)
## 用户使用流程
<分步骤描述用户如何使用产品可以有多条路径如快速上手进阶使用>
## AI 能力需求
| 能力类型 | 用途说明 | 应用位置 |
|---------|---------|---------|
| <能力类型> | <做什么> | <在哪个环节触发> |
## 技术说明(可选)
<如果涉及以下内容需要说明>
- 数据存储:是否需要登录?数据存在哪里?
- 外部依赖:需要调用什么服务?有什么限制?
- 部署方式:纯前端?需要服务器?
## 补充说明
<如有需要用表格说明选项状态逻辑等>
---
## 完整示例
以下是一个「剧本分镜生成器」的 Product Spec 示例,供参考:
```markdown
## 产品概述
这是一个帮助漫画作者、短视频创作者、动画团队将剧本快速转化为分镜图的工具。
**目标用户**:有剧本但缺乏绘画能力、或者想快速出分镜草稿的创作者。他们可能是独立漫画作者、短视频博主、动画工作室的前期策划人员,共同的痛点是「脑子里有画面,但画不出来或画太慢」。
**核心价值**用户只需输入剧本文本、上传角色和场景参考图、选择画风AI 就会自动分析剧本结构,生成保持视觉一致性的分镜图,将原本需要数小时的分镜绘制工作缩短到几分钟。
## 应用场景
- **漫画创作**:独立漫画作者小王有一个 20 页的剧本需要先出分镜草稿再精修。他把剧本贴进来上传主角的参考图10 分钟就拿到了全部分镜草稿,可以直接在这个基础上精修。
- **短视频策划**:短视频博主小李要拍一个 3 分钟的剧情短片,需要给摄影师看分镜。她把脚本输入,选择「写实」风格,生成的分镜图直接可以当拍摄参考。
- **动画前期**:动画工作室要向客户提案,需要快速出一版分镜来展示剧本节奏。策划人员用这个工具 30 分钟出了 50 张分镜图,当天就能开提案会。
- **小说可视化**:网文作者想给自己的小说做宣传图,把关键场景描述输入,生成的分镜图可以直接用于社交媒体宣传。
- **教学演示**:小学语文老师想把一篇课文变成连环画给学生看,把课文内容输入,选择「动漫」风格,生成的图片可以直接做成 PPT。
## 功能需求
**核心功能**
- 剧本输入与分析:用户输入剧本文本 → 点击「生成分镜」→ AI 自动识别角色、场景和情节节拍,将剧本拆分为多页分镜
- 角色设定:用户添加角色卡片(名称 + 外观描述 + 参考图)→ 系统建立角色视觉档案,后续生成时保持外观一致
- 场景设定:用户添加场景卡片(名称 + 氛围描述 + 参考图)→ 系统建立场景视觉档案(可选,不设定则由 AI 根据剧本生成)
- 画风选择:用户从下拉框选择画风(漫画/动漫/写实/赛博朋克/水墨)→ 生成的分镜图采用对应视觉风格
- 分镜生成:用户点击「生成分镜」→ AI 生成当前页 9 张分镜图3x3 九宫格)→ 展示在右侧输出区
- 连续生成:用户点击「继续生成下一页」→ AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
**辅助功能**
- 批量下载:用户点击「下载全部」→ 系统将当前页 9 张图打包为 ZIP 下载
- 历史浏览:用户通过页面导航 → 切换查看已生成的历史页面
## UI 布局
### 整体布局
左右两栏布局,左侧输入区占 40%,右侧输出区占 60%。
### 左侧 - 输入区
- 顶部:项目名称输入框
- 剧本输入多行文本框placeholder「请输入剧本内容...」
- 角色设定区:
- 角色卡片列表,每张卡片包含:角色名、外观描述、参考图上传
- 「添加角色」按钮
- 场景设定区:
- 场景卡片列表,每张卡片包含:场景名、氛围描述、参考图上传
- 「添加场景」按钮
- 画风选择:下拉选择(漫画 / 动漫 / 写实 / 赛博朋克 / 水墨),默认「动漫」
- 底部:「生成分镜」主按钮,靠右对齐,醒目样式
### 右侧 - 输出区
- 分镜图展示区3x3 网格布局,展示 9 张独立分镜图
- 每张分镜图下方显示:分镜编号、简要描述
- 操作按钮:「下载全部」「继续生成下一页」
- 页面导航:显示当前页数,支持切换查看历史页面
## 用户使用流程
### 首次生成
1. 输入剧本内容
2. 添加角色:填写名称、外观描述,上传参考图
3. 添加场景:填写名称、氛围描述,上传参考图(可选)
4. 选择画风
5. 点击「生成分镜」
6. 在右侧查看生成的 9 张分镜图
7. 点击「下载全部」保存
### 连续生成
1. 完成首次生成后
2. 点击「继续生成下一页」
3. AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
4. 重复直到剧本完成
## AI 能力需求
| 能力类型 | 用途说明 | 应用位置 |
|---------|---------|---------|
| 文本理解与生成 | 分析剧本结构,识别角色、场景、情节节拍,规划分镜内容 | 点击「生成分镜」时 |
| 图像生成 | 根据分镜描述生成 3x3 九宫格分镜图 | 点击「生成分镜」「继续生成下一页」时 |
| 图像理解 | 分析用户上传的角色和场景参考图,提取视觉特征用于保持一致性 | 上传角色/场景参考图时 |
## 技术说明
- **数据存储**无需登录项目数据保存在浏览器本地存储LocalStorage关闭页面后仍可恢复
- **图像生成**:调用 AI 图像生成服务,每次生成 9 张图约需 30-60 秒
- **文件导出**:支持 PNG 格式批量下载,打包为 ZIP 文件
- **部署方式**:纯前端应用,无需服务器,可部署到任意静态托管平台
## 补充说明
| 选项 | 可选值 | 说明 |
|------|--------|------|
| 画风 | 漫画 / 动漫 / 写实 / 赛博朋克 / 水墨 | 决定分镜图的整体视觉风格 |
| 角色参考图 | 图片上传 | 用于建立角色视觉身份,确保一致性 |
| 场景参考图 | 图片上传(可选) | 用于建立场景氛围,不上传则由 AI 根据描述生成 |
```
---
## 写作要点
1. **产品概述**
- 一句话说清楚是什么
- **必须明确写出目标用户**:是谁、有什么特点、什么痛点
- 核心价值:用了这个产品能得到什么
2. **应用场景**
- 具体的人 + 具体的情况 + 具体的用法 + 解决什么问题
- 场景要有画面感,让人一看就懂
- 放在功能需求之前,帮助理解产品价值
3. **功能需求**
- 分「核心功能」和「辅助功能」
- 每条格式:用户做什么 → 系统做什么 → 得到什么
- 写清楚触发方式(点击什么按钮)
4. **UI 布局**
- 先写整体布局(几栏、比例)
- 再逐个区域描述内容
- 控件要具体:下拉框写出所有选项和默认值,按钮写明位置和样式
5. **用户流程**:分步骤,可以有多条路径
6. **AI 能力需求**
- 列出需要的 AI 能力类型
- 说明具体用途
- **写清楚在哪个环节触发**,方便开发理解调用时机
7. **技术说明**(可选):
- 数据存储方式
- 外部服务依赖
- 部署方式
- 只在有技术约束时写,没有就不写
8. **补充说明**:用表格,适合解释选项、状态、逻辑

View File

@@ -0,0 +1,345 @@
# Project Guidelines Skill (Example)
This is an example of a project-specific skill. Use this as a template for your own projects.
Based on a real production application: [Zenith](https://zenith.chat) - AI-powered customer discovery platform.
---
## When to Use
Reference this skill when working on the specific project it's designed for. Project skills contain:
- Architecture overview
- File structure
- Code patterns
- Testing requirements
- Deployment workflow
---
## Architecture Overview
**Tech Stack:**
- **Frontend**: Next.js 15 (App Router), TypeScript, React
- **Backend**: FastAPI (Python), Pydantic models
- **Database**: Supabase (PostgreSQL)
- **AI**: Claude API with tool calling and structured output
- **Deployment**: Google Cloud Run
- **Testing**: Playwright (E2E), pytest (backend), React Testing Library
**Services:**
```
┌─────────────────────────────────────────────────────────────┐
│ Frontend │
│ Next.js 15 + TypeScript + TailwindCSS │
│ Deployed: Vercel / Cloud Run │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Backend │
│ FastAPI + Python 3.11 + Pydantic │
│ Deployed: Cloud Run │
└─────────────────────────────────────────────────────────────┘
┌───────────────┼───────────────┐
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Supabase │ │ Claude │ │ Redis │
│ Database │ │ API │ │ Cache │
└──────────┘ └──────────┘ └──────────┘
```
---
## File Structure
```
project/
├── frontend/
│ └── src/
│ ├── app/ # Next.js app router pages
│ │ ├── api/ # API routes
│ │ ├── (auth)/ # Auth-protected routes
│ │ └── workspace/ # Main app workspace
│ ├── components/ # React components
│ │ ├── ui/ # Base UI components
│ │ ├── forms/ # Form components
│ │ └── layouts/ # Layout components
│ ├── hooks/ # Custom React hooks
│ ├── lib/ # Utilities
│ ├── types/ # TypeScript definitions
│ └── config/ # Configuration
├── backend/
│ ├── routers/ # FastAPI route handlers
│ ├── models.py # Pydantic models
│ ├── main.py # FastAPI app entry
│ ├── auth_system.py # Authentication
│ ├── database.py # Database operations
│ ├── services/ # Business logic
│ └── tests/ # pytest tests
├── deploy/ # Deployment configs
├── docs/ # Documentation
└── scripts/ # Utility scripts
```
---
## Code Patterns
### API Response Format (FastAPI)
```python
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional
T = TypeVar('T')
class ApiResponse(BaseModel, Generic[T]):
success: bool
data: Optional[T] = None
error: Optional[str] = None
@classmethod
def ok(cls, data: T) -> "ApiResponse[T]":
return cls(success=True, data=data)
@classmethod
def fail(cls, error: str) -> "ApiResponse[T]":
return cls(success=False, error=error)
```
### Frontend API Calls (TypeScript)
```typescript
interface ApiResponse<T> {
success: boolean
data?: T
error?: string
}
async function fetchApi<T>(
endpoint: string,
options?: RequestInit
): Promise<ApiResponse<T>> {
try {
const response = await fetch(`/api${endpoint}`, {
...options,
headers: {
'Content-Type': 'application/json',
...options?.headers,
},
})
if (!response.ok) {
return { success: false, error: `HTTP ${response.status}` }
}
return await response.json()
} catch (error) {
return { success: false, error: String(error) }
}
}
```
### Claude AI Integration (Structured Output)
```python
from anthropic import AsyncAnthropic
from pydantic import BaseModel
class AnalysisResult(BaseModel):
summary: str
key_points: list[str]
confidence: float
async def analyze_with_claude(content: str) -> AnalysisResult:
    # Use the async client so the coroutine does not block the event loop
    client = AsyncAnthropic()
    response = await client.messages.create(
model="claude-sonnet-4-5-20250514",
max_tokens=1024,
messages=[{"role": "user", "content": content}],
tools=[{
"name": "provide_analysis",
"description": "Provide structured analysis",
"input_schema": AnalysisResult.model_json_schema()
}],
tool_choice={"type": "tool", "name": "provide_analysis"}
)
# Extract tool use result
tool_use = next(
block for block in response.content
if block.type == "tool_use"
)
return AnalysisResult(**tool_use.input)
```
### Custom Hooks (React)
```typescript
import { useState, useCallback } from 'react'
interface UseApiState<T> {
data: T | null
loading: boolean
error: string | null
}
export function useApi<T>(
fetchFn: () => Promise<ApiResponse<T>>
) {
const [state, setState] = useState<UseApiState<T>>({
data: null,
loading: false,
error: null,
})
const execute = useCallback(async () => {
setState(prev => ({ ...prev, loading: true, error: null }))
const result = await fetchFn()
if (result.success) {
setState({ data: result.data!, loading: false, error: null })
} else {
setState({ data: null, loading: false, error: result.error! })
}
}, [fetchFn])
return { ...state, execute }
}
```
---
## Testing Requirements
### Backend (pytest)
```bash
# Run all tests
poetry run pytest tests/
# Run with coverage
poetry run pytest tests/ --cov=. --cov-report=html
# Run specific test file
poetry run pytest tests/test_auth.py -v
```
**Test structure:**
```python
import pytest
from httpx import AsyncClient
from main import app
@pytest.fixture
async def client():
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
@pytest.mark.asyncio
async def test_health_check(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json()["status"] == "healthy"
```
### Frontend (React Testing Library)
```bash
# Run tests
npm run test
# Run with coverage
npm run test -- --coverage
# Run E2E tests
npm run test:e2e
```
**Test structure:**
```typescript
import { render, screen, fireEvent } from '@testing-library/react'
import { WorkspacePanel } from './WorkspacePanel'
describe('WorkspacePanel', () => {
it('renders workspace correctly', () => {
render(<WorkspacePanel />)
expect(screen.getByRole('main')).toBeInTheDocument()
})
it('handles session creation', async () => {
render(<WorkspacePanel />)
fireEvent.click(screen.getByText('New Session'))
expect(await screen.findByText('Session created')).toBeInTheDocument()
})
})
```
---
## Deployment Workflow
### Pre-Deployment Checklist
- [ ] All tests passing locally
- [ ] `npm run build` succeeds (frontend)
- [ ] `poetry run pytest` passes (backend)
- [ ] No hardcoded secrets
- [ ] Environment variables documented
- [ ] Database migrations ready
### Deployment Commands
```bash
# Build and deploy frontend
cd frontend && npm run build
gcloud run deploy frontend --source .
# Build and deploy backend
cd backend
gcloud run deploy backend --source .
```
### Environment Variables
```bash
# Frontend (.env.local)
NEXT_PUBLIC_API_URL=https://api.example.com
NEXT_PUBLIC_SUPABASE_URL=https://xxx.supabase.co
NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJ...
# Backend (.env)
DATABASE_URL=postgresql://...
ANTHROPIC_API_KEY=sk-ant-...
SUPABASE_URL=https://xxx.supabase.co
SUPABASE_KEY=eyJ...
```
---
## Critical Rules
1. **No emojis** in code, comments, or documentation
2. **Immutability** - never mutate objects or arrays
3. **TDD** - write tests before implementation
4. **80% coverage** minimum
5. **Many small files** - 200-400 lines typical, 800 max
6. **No console.log** in production code
7. **Proper error handling** with try/catch
8. **Input validation** with Pydantic/Zod
---
## Related Skills
- `coding-standards.md` - General coding best practices
- `backend-patterns.md` - API and database patterns
- `frontend-patterns.md` - React and Next.js patterns
- `tdd-workflow/` - Test-driven development methodology

View File

@@ -0,0 +1,568 @@
---
name: security-review
description: Use this skill when adding authentication, handling user input, working with secrets, creating API endpoints, or implementing payment/sensitive features. Provides comprehensive security checklist and patterns.
---
# Security Review Skill
Security best practices for Python/FastAPI applications handling sensitive invoice data.
## When to Activate
- Implementing authentication or authorization
- Handling user input or file uploads
- Creating new API endpoints
- Working with secrets or credentials
- Processing sensitive invoice data
- Integrating third-party APIs
- Database operations with user data
## Security Checklist
### 1. Secrets Management
#### NEVER Do This
```python
# Hardcoded secrets - CRITICAL VULNERABILITY
api_key = "sk-proj-xxxxx"
db_password = "password123"
```
#### ALWAYS Do This
```python
import os
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
db_password: str
api_key: str
model_path: str = "runs/train/invoice_fields/weights/best.pt"
class Config:
env_file = ".env"
settings = Settings()
# Verify secrets exist
if not settings.db_password:
raise RuntimeError("DB_PASSWORD not configured")
```
#### Verification Steps
- [ ] No hardcoded API keys, tokens, or passwords
- [ ] All secrets in environment variables
- [ ] `.env` in .gitignore
- [ ] No secrets in git history
- [ ] `.env.example` with placeholder values
### 2. Input Validation
#### Always Validate User Input
```python
from pydantic import BaseModel, Field, field_validator
from fastapi import HTTPException
import re
class InvoiceRequest(BaseModel):
invoice_number: str = Field(..., min_length=1, max_length=50)
amount: float = Field(..., gt=0, le=1_000_000)
bankgiro: str | None = None
@field_validator("invoice_number")
@classmethod
def validate_invoice_number(cls, v: str) -> str:
# Whitelist validation - only allow safe characters
if not re.match(r"^[A-Za-z0-9\-_]+$", v):
raise ValueError("Invalid invoice number format")
return v
@field_validator("bankgiro")
@classmethod
def validate_bankgiro(cls, v: str | None) -> str | None:
if v is None:
return None
cleaned = re.sub(r"[^0-9]", "", v)
if not (7 <= len(cleaned) <= 8):
raise ValueError("Bankgiro must be 7-8 digits")
return cleaned
```
#### File Upload Validation
```python
from fastapi import UploadFile, HTTPException
from pathlib import Path
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
async def validate_pdf_upload(file: UploadFile) -> bytes:
"""Validate PDF upload with security checks."""
# Extension check
ext = Path(file.filename or "").suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(400, f"Only PDF files allowed, got {ext}")
# Read content
content = await file.read()
# Size check
if len(content) > MAX_FILE_SIZE:
raise HTTPException(400, f"File too large (max {MAX_FILE_SIZE // 1024 // 1024}MB)")
# Magic bytes check (PDF signature)
if not content.startswith(b"%PDF"):
raise HTTPException(400, "Invalid PDF file format")
return content
```
#### Verification Steps
- [ ] All user inputs validated with Pydantic
- [ ] File uploads restricted (size, type, extension, magic bytes)
- [ ] No direct use of user input in queries
- [ ] Whitelist validation (not blacklist)
- [ ] Error messages don't leak sensitive info
### 3. SQL Injection Prevention
#### NEVER Concatenate SQL
```python
# DANGEROUS - SQL Injection vulnerability
query = f"SELECT * FROM documents WHERE id = '{user_input}'"
cur.execute(query)
```
#### ALWAYS Use Parameterized Queries
```python
import psycopg2
# Safe - parameterized query with %s placeholders
cur.execute(
"SELECT * FROM documents WHERE id = %s AND status = %s",
(document_id, status)
)
# Safe - named parameters
cur.execute(
"SELECT * FROM documents WHERE id = %(id)s",
{"id": document_id}
)
# Safe - psycopg2.sql for dynamic identifiers
from psycopg2 import sql
cur.execute(
sql.SQL("SELECT {} FROM {} WHERE id = %s").format(
sql.Identifier("invoice_number"),
sql.Identifier("documents")
),
(document_id,)
)
```
#### Verification Steps
- [ ] All database queries use parameterized queries (%s or %(name)s)
- [ ] No string concatenation or f-strings in SQL
- [ ] psycopg2.sql module used for dynamic identifiers
- [ ] No user input in table/column names
### 4. Path Traversal Prevention
#### NEVER Trust User Paths
```python
# DANGEROUS - Path traversal vulnerability
filename = request.query_params.get("file")
with open(f"/data/{filename}", "r") as f: # Attacker: ../../../etc/passwd
return f.read()
```
#### ALWAYS Validate Paths
```python
import re
from pathlib import Path
from fastapi import HTTPException

ALLOWED_DIR = Path("/data/uploads").resolve()
def get_safe_path(filename: str) -> Path:
"""Get safe file path, preventing path traversal."""
# Remove any path components
safe_name = Path(filename).name
# Validate filename characters
if not re.match(r"^[A-Za-z0-9_\-\.]+$", safe_name):
raise HTTPException(400, "Invalid filename")
# Resolve and verify within allowed directory
full_path = (ALLOWED_DIR / safe_name).resolve()
if not full_path.is_relative_to(ALLOWED_DIR):
raise HTTPException(400, "Invalid file path")
return full_path
```
#### Verification Steps
- [ ] User-provided filenames sanitized
- [ ] Paths resolved and validated against allowed directory
- [ ] No direct concatenation of user input into paths
- [ ] Whitelist characters in filenames
### 5. Authentication & Authorization
#### API Key Validation
```python
from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def verify_api_key(api_key: str = Security(api_key_header)) -> str:
if not api_key:
raise HTTPException(401, "API key required")
# Constant-time comparison to prevent timing attacks
import hmac
if not hmac.compare_digest(api_key, settings.api_key):
raise HTTPException(403, "Invalid API key")
return api_key
@router.post("/infer")
async def infer(
file: UploadFile,
api_key: str = Depends(verify_api_key)
):
...
```
#### Role-Based Access Control
```python
from enum import Enum
class UserRole(str, Enum):
USER = "user"
ADMIN = "admin"
def require_role(required_role: UserRole):
async def role_checker(current_user: User = Depends(get_current_user)):
if current_user.role != required_role:
raise HTTPException(403, "Insufficient permissions")
return current_user
return role_checker
@router.delete("/documents/{doc_id}")
async def delete_document(
doc_id: str,
user: User = Depends(require_role(UserRole.ADMIN))
):
...
```
#### Verification Steps
- [ ] API keys validated with constant-time comparison
- [ ] Authorization checks before sensitive operations
- [ ] Role-based access control implemented
- [ ] Session/token validation on protected routes
### 6. Rate Limiting
#### Rate Limiter Implementation
```python
from time import time
from collections import defaultdict
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
class RateLimiter:
def __init__(self):
self.requests: dict[str, list[float]] = defaultdict(list)
def check_limit(
self,
identifier: str,
max_requests: int,
window_seconds: int
) -> bool:
now = time()
# Clean old requests
self.requests[identifier] = [
t for t in self.requests[identifier]
if now - t < window_seconds
]
# Check limit
if len(self.requests[identifier]) >= max_requests:
return False
self.requests[identifier].append(now)
return True
limiter = RateLimiter()
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
client_ip = request.client.host if request.client else "unknown"
# 100 requests per minute for general endpoints
    if not limiter.check_limit(client_ip, max_requests=100, window_seconds=60):
        # Return a response directly: an HTTPException raised inside middleware is not
        # caught by FastAPI's exception handlers and would surface as a 500 error.
        return JSONResponse(
            status_code=429,
            content={"detail": "Rate limit exceeded. Try again later."},
        )
return await call_next(request)
```
#### Stricter Limits for Expensive Operations
```python
# Inference endpoint: 10 requests per minute
async def check_inference_rate_limit(request: Request):
client_ip = request.client.host if request.client else "unknown"
if not limiter.check_limit(f"infer:{client_ip}", max_requests=10, window_seconds=60):
raise HTTPException(429, "Inference rate limit exceeded")
@router.post("/infer")
async def infer(
file: UploadFile,
_: None = Depends(check_inference_rate_limit)
):
...
```
#### Verification Steps
- [ ] Rate limiting on all API endpoints
- [ ] Stricter limits on expensive operations (inference, OCR)
- [ ] IP-based rate limiting
- [ ] Clear error messages for rate-limited requests
### 7. Sensitive Data Exposure
#### Logging
```python
import logging
logger = logging.getLogger(__name__)
# WRONG: Logging sensitive data
logger.info(f"Processing invoice: {invoice_data}") # May contain sensitive info
logger.error(f"DB error with password: {db_password}")
# CORRECT: Redact sensitive data
logger.info(f"Processing invoice: id={doc_id}")
logger.error(f"DB connection failed to {db_host}:{db_port}")
# CORRECT: Structured logging with safe fields only
logger.info(
"Invoice processed",
extra={
"document_id": doc_id,
"field_count": len(fields),
"processing_time_ms": elapsed_ms
}
)
```
#### Error Messages
```python
# WRONG: Exposing internal details
@app.exception_handler(Exception)
async def error_handler(request: Request, exc: Exception):
return JSONResponse(
status_code=500,
content={
"error": str(exc),
"traceback": traceback.format_exc() # NEVER expose!
}
)
# CORRECT: Generic error messages
@app.exception_handler(Exception)
async def error_handler(request: Request, exc: Exception):
logger.error(f"Unhandled error: {exc}", exc_info=True) # Log internally
return JSONResponse(
status_code=500,
content={"success": False, "error": "An error occurred"}
)
```
#### Verification Steps
- [ ] No passwords, tokens, or secrets in logs
- [ ] Error messages generic for users
- [ ] Detailed errors only in server logs
- [ ] No stack traces exposed to users
- [ ] Invoice data (amounts, account numbers) not logged
### 8. CORS Configuration
```python
from fastapi.middleware.cors import CORSMiddleware
# WRONG: Allow all origins
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # DANGEROUS in production
allow_credentials=True,
)
# CORRECT: Specific origins
ALLOWED_ORIGINS = [
"http://localhost:8000",
"https://your-domain.com",
]
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
```
#### Verification Steps
- [ ] CORS origins explicitly listed
- [ ] No wildcard origins in production
- [ ] Credentials only with specific origins
### 9. Temporary File Security
```python
import tempfile
from pathlib import Path
from contextlib import contextmanager
@contextmanager
def secure_temp_file(suffix: str = ".pdf"):
"""Create secure temporary file that is always cleaned up."""
tmp_path = None
try:
with tempfile.NamedTemporaryFile(
suffix=suffix,
delete=False,
dir="/tmp/invoice-master" # Dedicated temp directory
) as tmp:
tmp_path = Path(tmp.name)
yield tmp_path
finally:
if tmp_path and tmp_path.exists():
tmp_path.unlink()
# Usage
async def process_upload(file: UploadFile):
with secure_temp_file(".pdf") as tmp_path:
content = await validate_pdf_upload(file)
tmp_path.write_bytes(content)
result = pipeline.process(tmp_path)
# File automatically cleaned up
return result
```
#### Verification Steps
- [ ] Temporary files always cleaned up (use context managers)
- [ ] Temp directory has restricted permissions
- [ ] No leftover files after processing errors
### 10. Dependency Security
#### Regular Updates
```bash
# Check for vulnerabilities
pip-audit
# Update dependencies
pip install --upgrade -r requirements.txt
# Check for outdated packages
pip list --outdated
```
#### Lock Files
```bash
# Create requirements lock file
pip freeze > requirements.lock
# Install from lock file for reproducible builds
pip install -r requirements.lock
```
#### Verification Steps
- [ ] Dependencies up to date
- [ ] No known vulnerabilities (pip-audit clean)
- [ ] requirements.txt pinned versions
- [ ] Regular security updates scheduled
## Security Testing
### Automated Security Tests
```python
import pytest
from fastapi.testclient import TestClient
def test_requires_api_key(client: TestClient):
"""Test authentication required."""
response = client.post("/api/v1/infer")
assert response.status_code == 401
def test_invalid_api_key_rejected(client: TestClient):
"""Test invalid API key rejected."""
response = client.post(
"/api/v1/infer",
headers={"X-API-Key": "invalid-key"}
)
assert response.status_code == 403
def test_sql_injection_prevented(client: TestClient):
"""Test SQL injection attempt rejected."""
response = client.get(
"/api/v1/documents",
params={"id": "'; DROP TABLE documents; --"}
)
# Should return validation error, not execute SQL
assert response.status_code in (400, 422)
def test_path_traversal_prevented(client: TestClient):
"""Test path traversal attempt rejected."""
response = client.get("/api/v1/results/../../etc/passwd")
assert response.status_code == 400
def test_rate_limit_enforced(client: TestClient):
"""Test rate limiting works."""
responses = [
client.post("/api/v1/infer", files={"file": b"test"})
for _ in range(15)
]
rate_limited = [r for r in responses if r.status_code == 429]
assert len(rate_limited) > 0
def test_large_file_rejected(client: TestClient):
"""Test file size limit enforced."""
large_content = b"x" * (11 * 1024 * 1024) # 11MB
response = client.post(
"/api/v1/infer",
files={"file": ("test.pdf", large_content)}
)
assert response.status_code == 400
```
## Pre-Deployment Security Checklist
Before ANY production deployment:
- [ ] **Secrets**: No hardcoded secrets, all in env vars
- [ ] **Input Validation**: All user inputs validated with Pydantic
- [ ] **SQL Injection**: All queries use parameterized queries
- [ ] **Path Traversal**: File paths validated and sanitized
- [ ] **Authentication**: API key or token validation
- [ ] **Authorization**: Role checks in place
- [ ] **Rate Limiting**: Enabled on all endpoints
- [ ] **HTTPS**: Enforced in production
- [ ] **CORS**: Properly configured (no wildcards)
- [ ] **Error Handling**: No sensitive data in errors
- [ ] **Logging**: No sensitive data logged
- [ ] **File Uploads**: Validated (size, type, magic bytes)
- [ ] **Temp Files**: Always cleaned up
- [ ] **Dependencies**: Up to date, no vulnerabilities
## Resources
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
- [pip-audit](https://pypi.org/project/pip-audit/)
---
**Remember**: Security is not optional. One vulnerability can compromise sensitive invoice data. When in doubt, err on the side of caution.

View File

@@ -0,0 +1,63 @@
---
name: strategic-compact
description: Suggests manual context compaction at logical intervals to preserve context through task phases rather than arbitrary auto-compaction.
---
# Strategic Compact Skill
Suggests manual `/compact` at strategic points in your workflow rather than relying on arbitrary auto-compaction.
## Why Strategic Compaction?
Auto-compaction triggers at arbitrary points:
- Often mid-task, losing important context
- No awareness of logical task boundaries
- Can interrupt complex multi-step operations
Strategic compaction at logical boundaries:
- **After exploration, before execution** - Compact research context, keep implementation plan
- **After completing a milestone** - Fresh start for next phase
- **Before major context shifts** - Clear exploration context before different task
## How It Works
The `suggest-compact.sh` script runs on PreToolUse (Edit/Write) and:
1. **Tracks tool calls** - Counts tool invocations in session
2. **Threshold detection** - Suggests at configurable threshold (default: 50 calls)
3. **Periodic reminders** - Reminds every 25 calls after threshold
## Hook Setup
Add to your `~/.claude/settings.json`:
```json
{
"hooks": {
"PreToolUse": [{
"matcher": "tool == \"Edit\" || tool == \"Write\"",
"hooks": [{
"type": "command",
"command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
}]
}]
}
}
```
## Configuration
Environment variables:
- `COMPACT_THRESHOLD` - Tool calls before first suggestion (default: 50)
## Best Practices
1. **Compact after planning** - Once plan is finalized, compact to start fresh
2. **Compact after debugging** - Clear error-resolution context before continuing
3. **Don't compact mid-implementation** - Preserve context for related changes
4. **Read the suggestion** - The hook tells you *when*, you decide *if*
## Related
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Token optimization section
- Memory persistence hooks - For state that survives compaction

View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Strategic Compact Suggester
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
#
# Why manual over auto-compact:
# - Auto-compact happens at arbitrary points, often mid-task
# - Strategic compacting preserves context through logical phases
# - Compact after exploration, before execution
# - Compact after completing a milestone, before starting next
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "PreToolUse": [{
# "matcher": "Edit|Write",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
# }]
# }]
# }
# }
#
# Criteria for suggesting compact:
# - Session has been running for extended period
# - Large number of tool calls made
# - Transitioning from research/exploration to implementation
# - Plan has been finalized
# Track tool call count (increment in a temp file).
# Key the counter on the parent process ID (the Claude session) rather than $$,
# which changes on every hook invocation and would reset the count to 1 each time.
COUNTER_FILE="/tmp/claude-tool-count-${PPID}"
THRESHOLD=${COMPACT_THRESHOLD:-50}
# Initialize or increment counter
if [ -f "$COUNTER_FILE" ]; then
count=$(cat "$COUNTER_FILE")
count=$((count + 1))
echo "$count" > "$COUNTER_FILE"
else
echo "1" > "$COUNTER_FILE"
count=1
fi
# Suggest compact after threshold tool calls
if [ "$count" -eq "$THRESHOLD" ]; then
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
fi
# Suggest at regular intervals after threshold
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
fi

View File

@@ -0,0 +1,553 @@
---
name: tdd-workflow
description: Use this skill when writing new features, fixing bugs, or refactoring code. Enforces test-driven development with 80%+ coverage including unit, integration, and E2E tests.
---
# Test-Driven Development Workflow
TDD principles for Python/FastAPI development with pytest.
## When to Activate
- Writing new features or functionality
- Fixing bugs or issues
- Refactoring existing code
- Adding API endpoints
- Creating new field extractors or normalizers
## Core Principles
### 1. Tests BEFORE Code
ALWAYS write tests first, then implement code to make tests pass.
### 2. Coverage Requirements
- Minimum 80% coverage (unit + integration + E2E)
- All edge cases covered
- Error scenarios tested
- Boundary conditions verified
### 3. Test Types
#### Unit Tests
- Individual functions and utilities
- Normalizers and validators
- Parsers and extractors
- Pure functions
#### Integration Tests
- API endpoints
- Database operations
- OCR + YOLO pipeline
- Service interactions
#### E2E Tests
- Complete inference pipeline
- PDF → Fields workflow
- API health and inference endpoints
## TDD Workflow Steps
### Step 1: Write User Journeys
```
As a [role], I want to [action], so that [benefit]
Example:
As an invoice processor, I want to extract Bankgiro from payment_line,
so that I can cross-validate OCR results.
```
### Step 2: Generate Test Cases
For each user journey, create comprehensive test cases:
```python
import pytest
class TestPaymentLineParser:
"""Tests for payment_line parsing and field extraction."""
def test_parse_payment_line_extracts_bankgiro(self):
"""Should extract Bankgiro from valid payment line."""
# Test implementation
pass
def test_parse_payment_line_handles_missing_checksum(self):
"""Should handle payment lines without checksum."""
pass
def test_parse_payment_line_validates_checksum(self):
"""Should validate checksum when present."""
pass
def test_parse_payment_line_returns_none_for_invalid(self):
"""Should return None for invalid payment lines."""
pass
```
### Step 3: Run Tests (They Should Fail)
```bash
pytest tests/test_ocr/test_machine_code_parser.py -v
# Tests should fail - we haven't implemented yet
```
### Step 4: Implement Code
Write minimal code to make tests pass:
```python
def parse_payment_line(line: str) -> PaymentLineData | None:
"""Parse Swedish payment line and extract fields."""
# Implementation guided by tests
pass
```
### Step 5: Run Tests Again
```bash
pytest tests/test_ocr/test_machine_code_parser.py -v
# Tests should now pass
```
### Step 6: Refactor
Improve code quality while keeping tests green:
- Remove duplication
- Improve naming
- Optimize performance
- Enhance readability
### Step 7: Verify Coverage
```bash
pytest --cov=src --cov-report=term-missing
# Verify 80%+ coverage achieved
```
## Testing Patterns
### Unit Test Pattern (pytest)
```python
import pytest
from src.normalize.bankgiro_normalizer import normalize_bankgiro
class TestBankgiroNormalizer:
"""Tests for Bankgiro normalization."""
def test_normalize_removes_hyphens(self):
"""Should remove hyphens from Bankgiro."""
result = normalize_bankgiro("123-4567")
assert result == "1234567"
def test_normalize_removes_spaces(self):
"""Should remove spaces from Bankgiro."""
result = normalize_bankgiro("123 4567")
assert result == "1234567"
def test_normalize_validates_length(self):
"""Should validate Bankgiro is 7-8 digits."""
result = normalize_bankgiro("123456") # 6 digits
assert result is None
def test_normalize_validates_checksum(self):
"""Should validate Luhn checksum."""
result = normalize_bankgiro("1234568") # Invalid checksum
assert result is None
@pytest.mark.parametrize("input_value,expected", [
("123-4567", "1234567"),
("1234567", "1234567"),
("123 4567", "1234567"),
("BG 123-4567", "1234567"),
])
def test_normalize_various_formats(self, input_value, expected):
"""Should handle various input formats."""
result = normalize_bankgiro(input_value)
assert result == expected
```
### API Integration Test Pattern
```python
import pytest
from fastapi.testclient import TestClient
from src.web.app import app
@pytest.fixture
def client():
return TestClient(app)
class TestHealthEndpoint:
"""Tests for /api/v1/health endpoint."""
def test_health_returns_200(self, client):
"""Should return 200 OK."""
response = client.get("/api/v1/health")
assert response.status_code == 200
def test_health_returns_status(self, client):
"""Should return health status."""
response = client.get("/api/v1/health")
data = response.json()
assert data["status"] == "healthy"
assert "model_loaded" in data
class TestInferEndpoint:
"""Tests for /api/v1/infer endpoint."""
def test_infer_requires_file(self, client):
"""Should require file upload."""
response = client.post("/api/v1/infer")
assert response.status_code == 422
def test_infer_rejects_non_pdf(self, client):
"""Should reject non-PDF files."""
response = client.post(
"/api/v1/infer",
files={"file": ("test.txt", b"not a pdf", "text/plain")}
)
assert response.status_code == 400
def test_infer_returns_fields(self, client, sample_invoice_pdf):
"""Should return extracted fields."""
with open(sample_invoice_pdf, "rb") as f:
response = client.post(
"/api/v1/infer",
files={"file": ("invoice.pdf", f, "application/pdf")}
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "fields" in data
```
### E2E Test Pattern
```python
import pytest
import httpx
from pathlib import Path
@pytest.fixture(scope="module")
def running_server():
"""Ensure server is running for E2E tests."""
# Server should be started before running E2E tests
base_url = "http://localhost:8000"
yield base_url
class TestInferencePipeline:
"""E2E tests for complete inference pipeline."""
def test_health_check(self, running_server):
"""Should pass health check."""
response = httpx.get(f"{running_server}/api/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["model_loaded"] is True
def test_pdf_inference_returns_fields(self, running_server):
"""Should extract fields from PDF."""
pdf_path = Path("tests/fixtures/sample_invoice.pdf")
with open(pdf_path, "rb") as f:
response = httpx.post(
f"{running_server}/api/v1/infer",
files={"file": ("invoice.pdf", f, "application/pdf")}
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "fields" in data
assert len(data["fields"]) > 0
def test_cross_validation_included(self, running_server):
"""Should include cross-validation for invoices with payment_line."""
pdf_path = Path("tests/fixtures/invoice_with_payment_line.pdf")
with open(pdf_path, "rb") as f:
response = httpx.post(
f"{running_server}/api/v1/infer",
files={"file": ("invoice.pdf", f, "application/pdf")}
)
data = response.json()
if data["fields"].get("payment_line"):
assert "cross_validation" in data
```
## Test File Organization
```
tests/
├── conftest.py # Shared fixtures
├── fixtures/ # Test data files
│ ├── sample_invoice.pdf
│ └── invoice_with_payment_line.pdf
├── test_cli/
│ └── test_infer.py
├── test_pdf/
│ ├── test_extractor.py
│ └── test_renderer.py
├── test_ocr/
│ ├── test_paddle_ocr.py
│ └── test_machine_code_parser.py
├── test_inference/
│ ├── test_pipeline.py
│ ├── test_yolo_detector.py
│ └── test_field_extractor.py
├── test_normalize/
│ ├── test_bankgiro_normalizer.py
│ ├── test_date_normalizer.py
│ └── test_amount_normalizer.py
├── test_web/
│ ├── test_routes.py
│ └── test_services.py
└── e2e/
└── test_inference_e2e.py
```
## Mocking External Services
### Mock PaddleOCR
```python
import pytest
from unittest.mock import MagicMock, Mock, patch
@pytest.fixture
def mock_paddle_ocr():
"""Mock PaddleOCR for unit tests."""
with patch("src.ocr.paddle_ocr.PaddleOCR") as mock:
instance = Mock()
instance.ocr.return_value = [
[
[[[0, 0], [100, 0], [100, 20], [0, 20]], ("Invoice Number", 0.95)],
[[[0, 30], [100, 30], [100, 50], [0, 50]], ("INV-2024-001", 0.98)]
]
]
mock.return_value = instance
yield instance
```
### Mock YOLO Model
```python
@pytest.fixture
def mock_yolo_model():
"""Mock YOLO model for unit tests."""
with patch("src.inference.yolo_detector.YOLO") as mock:
instance = Mock()
# Mock detection results
instance.return_value = Mock(
boxes=Mock(
xyxy=[[10, 20, 100, 50]],
conf=[0.95],
cls=[0] # invoice_number class
)
)
mock.return_value = instance
yield instance
```
### Mock Database
```python
@pytest.fixture
def mock_db_connection():
"""Mock database connection for unit tests."""
with patch("src.data.db.get_db_connection") as mock:
        # MagicMock (not Mock) is required so the with-statement protocol
        # (__enter__/__exit__) works on the mocked connection and cursor.
        conn = MagicMock()
        cursor = MagicMock()
        cursor.fetchall.return_value = [
            ("doc-123", "processed", {"invoice_number": "INV-001"})
        ]
        cursor.fetchone.return_value = ("doc-123",)
        conn.cursor.return_value.__enter__.return_value = cursor
        mock.return_value.__enter__.return_value = conn
yield conn
```
## Test Coverage Verification
### Run Coverage Report
```bash
# Run with coverage
pytest --cov=src --cov-report=term-missing
# Generate HTML report
pytest --cov=src --cov-report=html
# Open htmlcov/index.html in browser
```
### Coverage Configuration (pyproject.toml)
```toml
[tool.coverage.run]
source = ["src"]
omit = ["*/__init__.py", "*/test_*.py"]
[tool.coverage.report]
fail_under = 80
show_missing = true
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"raise NotImplementedError",
]
```
## Common Testing Mistakes to Avoid
### WRONG: Testing Implementation Details
```python
# Don't test internal state
def test_parser_internal_state():
parser = PaymentLineParser()
parser._parse("...")
assert parser._groups == [...] # Internal state
```
### CORRECT: Test Public Interface
```python
# Test what users see
def test_parser_extracts_bankgiro():
result = parse_payment_line("...")
assert result.bankgiro == "1234567"
```
### WRONG: No Test Isolation
```python
# Tests depend on each other
class TestDocuments:
def test_creates_document(self):
create_document(...) # Creates in DB
def test_updates_document(self):
update_document(...) # Depends on previous test
```
### CORRECT: Independent Tests
```python
# Each test sets up its own data
class TestDocuments:
def test_creates_document(self, mock_db):
result = create_document(...)
assert result.id is not None
def test_updates_document(self, mock_db):
# Create own test data
doc = create_document(...)
result = update_document(doc.id, ...)
assert result.status == "updated"
```
### WRONG: Testing Too Much
```python
# One test doing everything
def test_full_invoice_processing():
# Load PDF
# Extract images
# Run YOLO
# Run OCR
# Normalize fields
# Save to DB
# Return response
```
### CORRECT: Focused Tests
```python
def test_yolo_detects_invoice_number():
"""Test only YOLO detection."""
result = detector.detect(image)
assert any(d.label == "invoice_number" for d in result)
def test_ocr_extracts_text():
"""Test only OCR extraction."""
result = ocr.extract(image, bbox)
assert result == "INV-2024-001"
def test_normalizer_formats_date():
"""Test only date normalization."""
result = normalize_date("2024-01-15")
assert result == "2024-01-15"
```
## Fixtures (conftest.py)
```python
import pytest
from pathlib import Path
from fastapi.testclient import TestClient
@pytest.fixture
def sample_invoice_pdf(tmp_path: Path) -> Path:
"""Create sample invoice PDF for testing."""
pdf_path = tmp_path / "invoice.pdf"
# Copy from fixtures or create minimal PDF
src = Path("tests/fixtures/sample_invoice.pdf")
if src.exists():
pdf_path.write_bytes(src.read_bytes())
return pdf_path
@pytest.fixture
def client():
"""FastAPI test client."""
from src.web.app import app
return TestClient(app)
@pytest.fixture
def sample_payment_line() -> str:
"""Sample Swedish payment line for testing."""
return "1234567#0000000012345#230115#00012345678901234567#1"
```
## Continuous Testing
### Watch Mode During Development
```bash
# Using pytest-watch
ptw -- tests/test_ocr/
# Tests run automatically on file changes
```
### Pre-Commit Hook
```bash
# .pre-commit-config.yaml
repos:
- repo: local
hooks:
- id: pytest
name: pytest
entry: pytest --tb=short -q
language: system
pass_filenames: false
always_run: true
```
### CI/CD Integration (GitHub Actions)
```yaml
- name: Run Tests
run: |
pytest --cov=src --cov-report=xml
- name: Upload Coverage
uses: codecov/codecov-action@v3
with:
file: coverage.xml
```
## Best Practices
1. **Write Tests First** - Always TDD
2. **One Assert Per Test** - Focus on single behavior
3. **Descriptive Test Names** - `test_<what>_<condition>_<expected>`
4. **Arrange-Act-Assert** - Clear test structure
5. **Mock External Dependencies** - Isolate unit tests
6. **Test Edge Cases** - None, empty, invalid, boundary
7. **Test Error Paths** - Not just happy paths
8. **Keep Tests Fast** - Unit tests < 50ms each
9. **Clean Up After Tests** - Use fixtures with cleanup
10. **Review Coverage Reports** - Identify gaps
## Success Metrics
- 80%+ code coverage achieved
- All tests passing (green)
- No skipped or disabled tests
- Fast test execution (< 60s for unit tests)
- E2E tests cover critical inference flow
- Tests catch bugs before production
---
**Remember**: Tests are not optional. They are the safety net that enables confident refactoring, rapid development, and production reliability.

View File

@@ -0,0 +1,242 @@
# Verification Loop Skill
Comprehensive verification system for Python/FastAPI development.
## When to Use
Invoke this skill:
- After completing a feature or significant code change
- Before creating a PR
- When you want to ensure quality gates pass
- After refactoring
- Before deployment
## Verification Phases
### Phase 1: Type Check
```bash
# Run mypy type checker
mypy src/ --ignore-missing-imports 2>&1 | head -30
```
Report all type errors. Fix critical ones before continuing.
### Phase 2: Lint Check
```bash
# Run ruff linter
ruff check src/ 2>&1 | head -30
# Auto-fix if desired
ruff check src/ --fix
```
Check for:
- Unused imports
- Code style violations
- Common Python anti-patterns
### Phase 3: Test Suite
```bash
# Run tests with coverage
pytest --cov=src --cov-report=term-missing -q 2>&1 | tail -50
# Run specific test file
pytest tests/test_ocr/test_machine_code_parser.py -v
# Run with short traceback
pytest -x --tb=short
```
Report:
- Total tests: X
- Passed: X
- Failed: X
- Coverage: X%
- Target: 80% minimum
### Phase 4: Security Scan
```bash
# Check for hardcoded secrets
grep -rn "password\s*=" --include="*.py" src/ 2>/dev/null | grep -v "db_password:" | head -10
grep -rn "api_key\s*=" --include="*.py" src/ 2>/dev/null | head -10
grep -rn "sk-" --include="*.py" src/ 2>/dev/null | head -10
# Check for print statements (should use logging)
grep -rn "print(" --include="*.py" src/ 2>/dev/null | head -10
# Check for bare except
grep -rn "except:" --include="*.py" src/ 2>/dev/null | head -10
# Check for SQL injection risks (f-strings in execute)
grep -rn 'execute(f"' --include="*.py" src/ 2>/dev/null | head -10
grep -rn "execute(f'" --include="*.py" src/ 2>/dev/null | head -10
```
### Phase 5: Import Check
```bash
# Verify all imports work
python -c "from src.web.app import app; print('Web app OK')"
python -c "from src.inference.pipeline import InferencePipeline; print('Pipeline OK')"
python -c "from src.ocr.machine_code_parser import parse_payment_line; print('Parser OK')"
```
### Phase 6: Diff Review
```bash
# Show what changed
git diff --stat
git diff HEAD --name-only
# Show staged changes
git diff --staged --stat
```
Review each changed file for:
- Unintended changes
- Missing error handling
- Potential edge cases
- Missing type hints
- Mutable default arguments
### Phase 7: API Smoke Test (if server running)
```bash
# Health check
curl -s http://localhost:8000/api/v1/health | python -m json.tool
# Verify response format
curl -s http://localhost:8000/api/v1/health | grep -q "healthy" && echo "Health: OK" || echo "Health: FAIL"
```
## Output Format
After running all phases, produce a verification report:
```
VERIFICATION REPORT
==================
Types: [PASS/FAIL] (X errors)
Lint: [PASS/FAIL] (X warnings)
Tests: [PASS/FAIL] (X/Y passed, Z% coverage)
Security: [PASS/FAIL] (X issues)
Imports: [PASS/FAIL]
Diff: [X files changed]
Overall: [READY/NOT READY] for PR
Issues to Fix:
1. ...
2. ...
```
## Quick Commands
```bash
# Full verification (WSL)
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports && ruff check src/ && pytest -x --tb=short"
# Type check only
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports"
# Tests only
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src -q"
```
## Verification Checklist
### Before Commit
- [ ] mypy passes (no type errors)
- [ ] ruff check passes (no lint errors)
- [ ] All tests pass
- [ ] No print() statements in production code
- [ ] No hardcoded secrets
- [ ] No bare `except:` clauses
- [ ] No SQL injection risks (f-strings in queries)
- [ ] Coverage >= 80% for changed code
### Before PR
- [ ] All above checks pass
- [ ] git diff reviewed for unintended changes
- [ ] New code has tests
- [ ] Type hints on all public functions
- [ ] Docstrings on public APIs
- [ ] No TODO/FIXME for critical items
### Before Deployment
- [ ] All above checks pass
- [ ] E2E tests pass
- [ ] Health check returns healthy
- [ ] Model loaded successfully
- [ ] No server errors in logs
## Common Issues and Fixes
### Type Error: Missing return type
```python
# Before
def process(data):
return result
# After
def process(data: dict) -> InferenceResult:
return result
```
### Lint Error: Unused import
```python
# Remove unused imports or add to __all__
```
### Security: print() in production
```python
# Before
print(f"Processing {doc_id}")
# After
logger.info(f"Processing {doc_id}")
```
### Security: Bare except
```python
# Before
except:
pass
# After
except Exception as e:
logger.error(f"Error: {e}")
raise
```
### Security: SQL injection
```python
# Before (DANGEROUS)
cur.execute(f"SELECT * FROM docs WHERE id = '{user_input}'")
# After (SAFE)
cur.execute("SELECT * FROM docs WHERE id = %s", (user_input,))
```
## Continuous Mode
For long sessions, run verification after major changes:
```markdown
Checkpoints:
- After completing each function
- After finishing a module
- Before moving to next task
- Every 15-20 minutes of coding
Run: /verify
```
## Integration with Other Skills
| Skill | Purpose |
|-------|---------|
| code-review | Detailed code analysis |
| security-review | Deep security audit |
| tdd-workflow | Test coverage |
| build-fix | Fix errors incrementally |
This skill provides quick, comprehensive verification. Use specialized skills for deeper analysis.

22
.env.example Normal file
View File

@@ -0,0 +1,22 @@
# Database Configuration
# Copy this file to .env and fill in your actual values
# PostgreSQL Database
DB_HOST=192.168.68.31
DB_PORT=5432
DB_NAME=docmaster
DB_USER=docmaster
DB_PASSWORD=your_password_here
# Model Configuration (optional)
# MODEL_PATH=runs/train/invoice_fields/weights/best.pt
# CONFIDENCE_THRESHOLD=0.5
# Server Configuration (optional)
# SERVER_HOST=0.0.0.0
# SERVER_PORT=8000
# Auto-labeling Configuration (optional)
# AUTOLABEL_WORKERS=2
# AUTOLABEL_DPI=150
# AUTOLABEL_MIN_CONFIDENCE=0.5

317
CHANGELOG.md Normal file
View File

@@ -0,0 +1,317 @@
# Changelog
All notable changes to the Invoice Field Extraction project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added - Phase 1: Security & Infrastructure (2026-01-22)
#### Security Enhancements
- **Environment Variable Management**: Added `python-dotenv` for secure configuration management
- Created `.env.example` template file for configuration reference
- Created `.env` file for actual credentials (gitignored)
  - Updated `config.py` to load the database password from environment variables (see the sketch after this list)
- Added validation to ensure `DB_PASSWORD` is set at startup
- Files modified: `config.py`, `requirements.txt`
- New files: `.env`, `.env.example`
- Tests: `tests/test_config.py` (7 tests, all passing)
- **SQL Injection Prevention**: Fixed SQL injection vulnerabilities in database queries
- Replaced f-string formatting with parameterized queries in `LIMIT` clauses
- Updated `get_all_documents_summary()` to use `%s` placeholder for LIMIT parameter
- Updated `get_failed_matches()` to use `%s` placeholder for LIMIT parameter
- Files modified: `src/data/db.py` (lines 246, 298)
- Tests: `tests/test_db_security.py` (9 tests, all passing)
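As a rough illustration of the environment-variable loading described in the first item above, the change might look like the following; this is a minimal sketch based on this changelog and `.env.example`, not the actual `config.py`:

```python
# Illustrative sketch only - the real config.py may differ in structure.
import os

from dotenv import load_dotenv

load_dotenv()  # read .env into the process environment

DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = int(os.getenv("DB_PORT", "5432"))
DB_NAME = os.getenv("DB_NAME", "docmaster")
DB_USER = os.getenv("DB_USER", "docmaster")
DB_PASSWORD = os.getenv("DB_PASSWORD")

# Fail fast at startup, as described above
if not DB_PASSWORD:
    raise RuntimeError("DB_PASSWORD is not set - copy .env.example to .env and configure it")
```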
#### Code Quality
- **Exception Hierarchy**: Created comprehensive custom exception system
  - Added base class `InvoiceExtractionError` with message and details support (sketched after this list)
- Added specific exception types:
- `PDFProcessingError` - PDF rendering/conversion errors
- `OCRError` - OCR processing errors
- `ModelInferenceError` - YOLO model errors
- `FieldValidationError` - Field validation errors (with field-specific attributes)
- `DatabaseError` - Database operation errors
- `ConfigurationError` - Configuration errors
- `PaymentLineParseError` - Payment line parsing errors
- `CustomerNumberParseError` - Customer number parsing errors
- `DataLoadError` - Data loading errors
- `AnnotationError` - Annotation generation errors
- New file: `src/exceptions.py`
- Tests: `tests/test_exceptions.py` (16 tests, all passing)
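A minimal sketch of the hierarchy described above, assuming a simple constructor; the real `src/exceptions.py` may differ in its details:

```python
# Sketch of the exception hierarchy described above (not the actual source).
class InvoiceExtractionError(Exception):
    """Base class carrying a message plus optional structured details."""

    def __init__(self, message: str, details: dict | None = None):
        super().__init__(message)
        self.message = message
        self.details = details or {}


class PDFProcessingError(InvoiceExtractionError):
    """PDF rendering/conversion errors."""


class FieldValidationError(InvoiceExtractionError):
    """Field validation errors, carrying field-specific attributes."""

    def __init__(self, message: str, field_name: str, value: str | None = None):
        super().__init__(message, details={"field": field_name, "value": value})
        self.field_name = field_name
        self.value = value
```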
### Testing
- Added 32 new tests across 3 test files
- Configuration tests: 7 tests
- SQL injection prevention tests: 9 tests
- Exception hierarchy tests: 16 tests
- All tests passing (32/32)
### Documentation
- Created `docs/CODE_REVIEW_REPORT.md` - Comprehensive code quality analysis (550+ lines)
- Created `docs/REFACTORING_PLAN.md` - Detailed 3-phase refactoring plan (600+ lines)
- Created `CHANGELOG.md` - Project changelog (this file)
### Changed
- **Configuration Loading**: Database configuration now loads from environment variables instead of hardcoded values
- Breaking change: Requires `.env` file with `DB_PASSWORD` set
- Migration: Copy `.env.example` to `.env` and set your database password
### Security
- **Fixed**: Database password no longer stored in plain text in `config.py`
- **Fixed**: SQL injection vulnerabilities in LIMIT clauses (2 instances)
### Technical Debt Addressed
- Eliminated security vulnerability: plaintext password storage
- Reduced SQL injection attack surface
- Improved error handling granularity with custom exceptions
---
### Added - Phase 2: Parser Refactoring (2026-01-22)
#### Unified Parser Modules
- **Payment Line Parser**: Created dedicated payment line parsing module
- Handles Swedish payment line format: `# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#`
- Tolerates common OCR errors: spaces in numbers, missing symbols, spaces in check digits
- Supports 4 parsing patterns: full format, no amount, alternative, account-only
- Returns structured `PaymentLineData` with parsed fields
- New file: `src/inference/payment_line_parser.py` (90 lines, 92% coverage)
- Tests: `tests/test_payment_line_parser.py` (23 tests, all passing)
  - Eliminates the first of the duplicated implementations (payment line parsing logic); see the usage sketch after this list
- **Customer Number Parser**: Created dedicated customer number parsing module
- Handles Swedish customer number formats: `JTY 576-3`, `DWQ 211-X`, `FFL 019N`, etc.
- Uses Strategy Pattern with 5 pattern classes:
- `LabeledPattern` - Explicit labels (highest priority, 0.98 confidence)
- `DashFormatPattern` - Standard format with dash (0.95 confidence)
- `NoDashFormatPattern` - Format without dash, adds dash automatically (0.90 confidence)
- `CompactFormatPattern` - Compact format without spaces (0.75 confidence)
- `GenericAlphanumericPattern` - Fallback generic pattern (variable confidence)
- Excludes Swedish postal codes (`SE XXX XX` format)
- Returns highest confidence match
- New file: `src/inference/customer_number_parser.py` (154 lines, 92% coverage)
- Tests: `tests/test_customer_number_parser.py` (32 tests, all passing)
- Reduces `_normalize_customer_number` complexity (127 lines → will use 5-10 lines after integration)
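The sketch below shows how the two parser modules might be invoked. The class names and file paths come from this changelog; the method and attribute names (`parse`, `ocr_number`, `account`, `value`, `confidence`) are illustrative assumptions, not the verified API:

```python
# Illustrative only - method and attribute names are inferred from the
# descriptions above and may not match the actual module interfaces.
from src.inference.payment_line_parser import PaymentLineParser
from src.inference.customer_number_parser import CustomerNumberParser

payment_parser = PaymentLineParser()
# Swedish payment line in the documented format:
# "# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#"
payment_data = payment_parser.parse("# 94228110015950070 # > 48666036#14#")
if payment_data is not None:
    ocr_number = payment_data.ocr_number  # assumed field name
    account = payment_data.account        # assumed field name

customer_parser = CustomerNumberParser()
# A labeled input should hit LabeledPattern, the highest-priority strategy (0.98)
match = customer_parser.parse("Kundnummer: JTY 576-3")
if match is not None:
    value, confidence = match.value, match.confidence  # assumed field names
```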
### Testing Summary
**Phase 1 Tests** (32 tests):
- Configuration tests: 7 tests ([test_config.py](tests/test_config.py))
- SQL injection prevention tests: 9 tests ([test_db_security.py](tests/test_db_security.py))
- Exception hierarchy tests: 16 tests ([test_exceptions.py](tests/test_exceptions.py))
**Phase 2 Tests** (121 tests):
- Payment line parser tests: 23 tests ([test_payment_line_parser.py](tests/test_payment_line_parser.py))
- Standard parsing, OCR error handling, real-world examples, edge cases
- Coverage: 92%
- Customer number parser tests: 32 tests ([test_customer_number_parser.py](tests/test_customer_number_parser.py))
- Pattern matching (DashFormat, NoDashFormat, Compact, Labeled)
- Real-world examples, edge cases, Swedish postal code exclusion
- Coverage: 92%
- Field extractor integration tests: 45 tests ([test_field_extractor.py](src/inference/test_field_extractor.py))
- Validates backward compatibility with existing code
- Tests for invoice numbers, bankgiro, plusgiro, amounts, OCR, dates, payment lines, customer numbers
- Pipeline integration tests: 21 tests ([test_pipeline.py](src/inference/test_pipeline.py))
- Cross-validation, payment line parsing, field overrides
**Total**: 153 tests, 100% passing, 4.50s runtime
### Code Quality
- **Eliminated Code Duplication**: Payment line parsing previously in 3 places, now unified in 1 module
- **Improved Maintainability**: Strategy Pattern makes customer number patterns easy to extend
- **Better Test Coverage**: New parsers have 92% coverage vs original 10% in field_extractor.py
#### Parser Integration into field_extractor.py (2026-01-22)
- **field_extractor.py Integration**: Successfully integrated new parsers
- Added `PaymentLineParser` and `CustomerNumberParser` instances (lines 99-101)
- Replaced `_normalize_payment_line` method: 74 lines → 3 lines (lines 640-657)
- Replaced `_normalize_customer_number` method: 127 lines → 3 lines (lines 697-707)
- All 45 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.21 seconds
- File: `src/inference/field_extractor.py`
#### Parser Integration into pipeline.py (2026-01-22)
- **pipeline.py Integration**: Successfully integrated PaymentLineParser
- Added `PaymentLineParser` import (line 15)
- Added `payment_line_parser` instance initialization (line 128)
- Replaced `_parse_machine_readable_payment_line` method: 36 lines → 6 lines (lines 219-233)
- All 21 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.00 seconds
- File: `src/inference/pipeline.py`
### Phase 2 Status: **COMPLETED** ✅
- [x] Create unified `payment_line_parser` module ✅
- [x] Create unified `customer_number_parser` module ✅
- [x] Refactor `field_extractor.py` to use new parsers ✅
- [x] Refactor `pipeline.py` to use new parsers ✅
- [x] Comprehensive test suite (153 tests, 100% passing) ✅
### Achieved Impact
- Eliminate code duplication: 3 implementations → 1 ✅ (payment_line unified across field_extractor.py, pipeline.py, tests)
- Reduce `_normalize_payment_line` complexity in field_extractor.py: 74 lines → 3 lines ✅
- Reduce `_normalize_customer_number` complexity in field_extractor.py: 127 lines → 3 lines ✅
- Reduce `_parse_machine_readable_payment_line` complexity in pipeline.py: 36 lines → 6 lines ✅
- Total lines of code eliminated: 237 lines reduced to 12 lines (≈95% reduction) ✅
- Improve test coverage: New parser modules have 92% coverage (vs original 10% in field_extractor.py)
- Simplify maintenance: Pattern-based approach makes extension easy
- 100% backward compatibility: All 66 existing tests pass (45 field_extractor + 21 pipeline)
---
## Phase 3: Performance & Documentation (2026-01-22)
### Added
#### Configuration Constants Extraction
- **Created `src/inference/constants.py`**: Centralized configuration constants
- Detection & model configuration (confidence thresholds, IOU)
- Image processing configuration (DPI, scaling factors)
- Customer number parser confidence scores
- Field extraction confidence multipliers
- Account type detection thresholds
- Pattern matching constants
- 90 lines of well-documented constants with usage notes
  - Eliminates ~15 hardcoded magic numbers across the codebase (illustrative excerpt after this list)
- File: [src/inference/constants.py](src/inference/constants.py)
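An illustrative excerpt of what such a constants module might contain; the DPI and confidence values appear elsewhere in this project, while the exact identifiers are assumptions rather than the real `constants.py`:

```python
# Illustrative excerpt - identifiers are assumptions, not the real constants.py.

# Detection & model configuration
DEFAULT_CONFIDENCE_THRESHOLD = 0.5

# Image processing configuration
RENDER_DPI = 150

# Customer number parser confidence scores
LABELED_PATTERN_CONFIDENCE = 0.98
DASH_FORMAT_CONFIDENCE = 0.95
COMPACT_FORMAT_CONFIDENCE = 0.75
```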
#### Performance Optimization Documentation
- **Created `docs/PERFORMANCE_OPTIMIZATION.md`**: Comprehensive performance guide (400+ lines)
- **Batch Processing Optimization**: Parallel processing strategies, already-implemented dual pool system
  - **Database Query Optimization**: Connection pooling recommendations (see the sketch after this list), index strategies
- **Caching Strategies**: Model loading cache, parser reuse (already optimal), OCR result caching
- **Memory Management**: Explicit cleanup, generator patterns, context managers
- **Profiling Guidelines**: cProfile, memory_profiler, py-spy recommendations
- **Benchmarking Scripts**: Ready-to-use performance measurement code
- **Priority Roadmap**: High/Medium/Low priority optimizations with effort estimates
- Expected impact: 2-5x throughput improvement for batch processing
- File: [docs/PERFORMANCE_OPTIMIZATION.md](docs/PERFORMANCE_OPTIMIZATION.md)
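As a sketch of the connection-pooling recommendation, a psycopg2 pool wrapped in a context manager could look like this; pool sizes and the `get_db_connection` wrapper are assumptions, not the documented implementation:

```python
# Illustrative sketch of the connection-pooling recommendation.
from contextlib import contextmanager
from psycopg2.pool import SimpleConnectionPool

pool = SimpleConnectionPool(
    minconn=1,
    maxconn=8,
    host="192.168.68.31",
    port=5432,
    dbname="docmaster",
    user="docmaster",
    password="...",  # load from the environment, never hardcode
)

@contextmanager
def get_db_connection():
    """Borrow a connection from the pool and always return it."""
    conn = pool.getconn()
    try:
        yield conn
    finally:
        pool.putconn(conn)
```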
### Phase 3 Status: **COMPLETED** ✅
- [x] Configuration constants extraction ✅
- [x] Performance optimization analysis ✅
- [x] Batch processing optimization recommendations ✅
- [x] Database optimization strategies ✅
- [x] Caching and memory management guidelines ✅
- [x] Profiling and benchmarking documentation ✅
### Deliverables
**New Files** (2 files):
1. `src/inference/constants.py` (90 lines) - Centralized configuration constants
2. `docs/PERFORMANCE_OPTIMIZATION.md` (400+ lines) - Performance optimization guide
**Impact**:
- Eliminates 15+ hardcoded magic numbers
- Provides clear optimization roadmap
- Documents existing performance features
- Identifies quick wins (connection pooling, indexes)
- Long-term strategy (caching, profiling)
---
## Notes
### Breaking Changes
- **v2.x**: Requires `.env` file with database credentials
- Action required: Create `.env` file based on `.env.example`
- Affected: All deployments, CI/CD pipelines
### Migration Guide
#### From v1.x to v2.x (Environment Variables)
1. Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```
2. Edit `.env` and set your database password:
```
DB_PASSWORD=your_actual_password_here
```
3. Install new dependency:
```bash
pip install python-dotenv
```
4. Verify configuration loads correctly:
```bash
python -c "import config; print('Config loaded successfully')"
```
## Summary of All Work Completed
### Files Created (13 new files)
**Phase 1** (3 files):
1. `.env` - Environment variables for database credentials
2. `.env.example` - Template for environment configuration
3. `src/exceptions.py` - Custom exception hierarchy (35 lines, 66% coverage)
**Phase 2** (7 files):
4. `src/inference/payment_line_parser.py` - Unified payment line parsing (90 lines, 92% coverage)
5. `src/inference/customer_number_parser.py` - Unified customer number parsing (154 lines, 92% coverage)
6. `tests/test_config.py` - Configuration tests (7 tests)
7. `tests/test_db_security.py` - SQL injection prevention tests (9 tests)
8. `tests/test_exceptions.py` - Exception hierarchy tests (16 tests)
9. `tests/test_payment_line_parser.py` - Payment line parser tests (23 tests)
10. `tests/test_customer_number_parser.py` - Customer number parser tests (32 tests)
**Phase 3** (2 files):
11. `src/inference/constants.py` - Centralized configuration constants (90 lines)
12. `docs/PERFORMANCE_OPTIMIZATION.md` - Performance optimization guide (400+ lines)
**Documentation** (1 file):
13. `CHANGELOG.md` - This file (260+ lines of detailed documentation)
### Files Modified (5 files)
1. `config.py` - Added environment variable loading with python-dotenv
2. `src/data/db.py` - Fixed 2 SQL injection vulnerabilities (lines 246, 298)
3. `src/inference/field_extractor.py` - Integrated new parsers (reduced 201 lines to 6 lines)
4. `src/inference/pipeline.py` - Integrated PaymentLineParser (reduced 36 lines to 6 lines)
5. `requirements.txt` - Added python-dotenv dependency
### Test Summary
- **Total tests**: 153 tests across 7 test files
- **Passing**: 153 (100%)
- **Failing**: 0
- **Runtime**: 4.50 seconds
- **Coverage**:
- New parser modules: 92%
- Config module: 100%
- Exception module: 66%
- DB security coverage: 18% (focused on parameterized queries)
### Code Metrics
- **Lines eliminated**: 237 lines of duplicated/complex code → 12 lines (≈95% reduction)
- field_extractor.py: 201 lines → 6 lines
- pipeline.py: 36 lines → 6 lines
- **New code added**: 279 lines of well-tested parser code
- **Net impact**: Replaced 237 lines of duplicate code with 279 lines of unified, tested code (+42 lines, but -3 implementations)
- **Test coverage improvement**: 0% → 92% for parser logic
### Performance Impact
- Configuration loading: Negligible (<1ms overhead for .env parsing)
- SQL queries: No performance change (parameterized queries are standard practice)
- Parser refactoring: No performance degradation (logic simplified, not changed)
- Exception handling: Minimal overhead (only when exceptions are raised)
### Security Improvements
- Eliminated plaintext password storage
- Fixed 2 SQL injection vulnerabilities
- Added input validation in database layer
### Maintainability Improvements
- Eliminated code duplication (3 implementations → 1)
- Strategy Pattern enables easy extension of customer number formats
- Comprehensive test suite (153 tests) ensures safe refactoring
- 100% backward compatibility maintained
- Custom exception hierarchy for granular error handling

364
README.md
View File

@@ -54,8 +54,12 @@
- **数据库存储**: 标注结果存储在 PostgreSQL支持增量处理和断点续传 - **数据库存储**: 标注结果存储在 PostgreSQL支持增量处理和断点续传
- **YOLO 检测**: 使用 YOLOv11 检测发票字段区域 - **YOLO 检测**: 使用 YOLOv11 检测发票字段区域
- **OCR 识别**: 使用 PaddleOCR v5 提取检测区域的文本 - **OCR 识别**: 使用 PaddleOCR v5 提取检测区域的文本
- **统一解析器**: payment_line 和 customer_number 采用独立解析器模块
- **交叉验证**: payment_line 数据与单独检测字段交叉验证,优先采用 payment_line 值
- **文档类型识别**: 自动区分 invoice (有 payment_line) 和 letter (无 payment_line)
- **Web 应用**: 提供 REST API 和可视化界面 - **Web 应用**: 提供 REST API 和可视化界面
- **增量训练**: 支持在已训练模型基础上继续训练 - **增量训练**: 支持在已训练模型基础上继续训练
- **内存优化**: 支持低内存模式训练 (--low-memory)
## 支持的字段 ## 支持的字段
@@ -69,6 +73,8 @@
| 5 | plusgiro | Plusgiro 号码 | | 5 | plusgiro | Plusgiro 号码 |
| 6 | amount | 金额 | | 6 | amount | 金额 |
| 7 | supplier_organisation_number | 供应商组织号 | | 7 | supplier_organisation_number | 供应商组织号 |
| 8 | payment_line | 支付行 (机器可读格式) |
| 9 | customer_number | 客户编号 |
## 安装 ## 安装
@@ -132,8 +138,24 @@ python -m src.cli.train \
--model yolo11n.pt \ --model yolo11n.pt \
--epochs 100 \ --epochs 100 \
--batch 16 \ --batch 16 \
--name invoice_yolo11n_full \ --name invoice_fields \
--dpi 150 --dpi 150
# 低内存模式 (适用于内存不足场景)
python -m src.cli.train \
--model yolo11n.pt \
--epochs 100 \
--name invoice_fields \
--low-memory \
--workers 4 \
--no-cache
# 从检查点恢复训练 (训练中断后)
python -m src.cli.train \
--model runs/train/invoice_fields/weights/last.pt \
--epochs 100 \
--name invoice_fields \
--resume
``` ```
### 4. 增量训练 ### 4. 增量训练
@@ -164,26 +186,46 @@ python -m src.cli.train \
```bash ```bash
# 命令行推理 # 命令行推理
python -m src.cli.infer \ python -m src.cli.infer \
--model runs/train/invoice_yolo11n_full/weights/best.pt \ --model runs/train/invoice_fields/weights/best.pt \
--input path/to/invoice.pdf \ --input path/to/invoice.pdf \
--output result.json \ --output result.json \
--gpu --gpu
# 批量推理
python -m src.cli.infer \
--model runs/train/invoice_fields/weights/best.pt \
--input invoices/*.pdf \
--output results/ \
--gpu
``` ```
**推理结果包含**:
- `fields`: 提取的字段值 (InvoiceNumber, Amount, payment_line, customer_number 等)
- `confidence`: 各字段的置信度
- `document_type`: 文档类型 ("invoice" 或 "letter")
- `cross_validation`: payment_line 交叉验证结果 (如果有)
### 6. Web 应用 ### 6. Web 应用
**在 WSL 环境中启动**:
```bash ```bash
# 启动 Web 服务器 # 方法 1: 从 Windows PowerShell 启动 (推荐)
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
# 方法 2: 在 WSL 内启动
conda activate invoice-py311
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
python run_server.py --port 8000 python run_server.py --port 8000
# 开发模式 (自动重载) # 方法 3: 使用启动脚本
python run_server.py --debug --reload ./start_web.sh
# 禁用 GPU
python run_server.py --no-gpu
``` ```
访问 **http://localhost:8000** 使用 Web 界面。 **服务启动后**:
- 访问 **http://localhost:8000** 使用 Web 界面
- 服务会自动加载模型 `runs/train/invoice_fields/weights/best.pt`
- GPU 默认启用,置信度阈值 0.5
#### Web API 端点 #### Web API 端点
@@ -194,6 +236,33 @@ python run_server.py --no-gpu
| POST | `/api/v1/infer` | 上传文件并推理 | | POST | `/api/v1/infer` | 上传文件并推理 |
| GET | `/api/v1/results/{filename}` | 获取可视化图片 | | GET | `/api/v1/results/{filename}` | 获取可视化图片 |
#### API 响应格式
```json
{
"status": "success",
"result": {
"document_id": "abc123",
"document_type": "invoice",
"fields": {
"InvoiceNumber": "12345",
"Amount": "1234.56",
"payment_line": "# 94228110015950070 # > 48666036#14#",
"customer_number": "UMJ 436-R"
},
"confidence": {
"InvoiceNumber": 0.95,
"Amount": 0.92
},
"cross_validation": {
"is_valid": true,
"ocr_match": true,
"amount_match": true
}
}
}
```
## 训练配置 ## 训练配置
### YOLO 训练参数 ### YOLO 训练参数
@@ -210,6 +279,10 @@ Options:
--name 训练名称 --name 训练名称
--limit 限制文档数 (用于测试) --limit 限制文档数 (用于测试)
--device 设备 (0=GPU, cpu) --device 设备 (0=GPU, cpu)
--resume 从检查点恢复训练
--low-memory 启用低内存模式 (batch=8, workers=4, no-cache)
--workers 数据加载 worker 数 (默认: 8)
--cache 缓存图像到内存
``` ```
### 训练最佳实践 ### 训练最佳实践
@@ -236,14 +309,28 @@ Options:
### 训练结果示例 ### 训练结果示例
使用约 10,000 张训练图片100 epochs 后的结果: **最新训练结果** (100 epochs, 2026-01-22):
| 指标 | 值 | | 指标 | 值 |
|------|-----| |------|-----|
| **mAP@0.5** | 98.7% | | **mAP@0.5** | 93.5% |
| **mAP@0.5-0.95** | 87.4% | | **mAP@0.5-0.95** | 83.0% |
| **Precision** | 97.5% | | **训练集** | ~10,000 张标注图片 |
| **Recall** | 95.5% | | **字段类型** | 10 个字段 (新增 payment_line, customer_number) |
| **模型位置** | `runs/train/invoice_fields/weights/best.pt` |
**各字段检测性能**:
- 发票基础信息 (InvoiceNumber, InvoiceDate, InvoiceDueDate): >95% mAP
- 支付信息 (OCR, Bankgiro, Plusgiro, Amount): >90% mAP
- 组织信息 (supplier_org_number, customer_number): >85% mAP
- 支付行 (payment_line): >80% mAP
**模型文件**:
```
runs/train/invoice_fields/weights/
├── best.pt # 最佳模型 (mAP@0.5 最高) ⭐ 推荐用于生产
└── last.pt # 最后检查点 (用于继续训练)
```
> 注:目前仍在持续标注更多数据,预计最终将有 25,000+ 张标注图片用于训练。 > 注:目前仍在持续标注更多数据,预计最终将有 25,000+ 张标注图片用于训练。
@@ -262,15 +349,18 @@ invoice-master-poc-v2/
│ │ ├── renderer.py # 图像渲染 │ │ ├── renderer.py # 图像渲染
│ │ └── detector.py # 类型检测 │ │ └── detector.py # 类型检测
│ ├── ocr/ # PaddleOCR 封装 │ ├── ocr/ # PaddleOCR 封装
│ │ └── machine_code_parser.py # 机器可读付款行解析器
│ ├── normalize/ # 字段规范化 │ ├── normalize/ # 字段规范化
│ ├── matcher/ # 字段匹配 │ ├── matcher/ # 字段匹配
│ ├── yolo/ # YOLO 相关 │ ├── yolo/ # YOLO 相关
│ │ ├── annotation_generator.py │ │ ├── annotation_generator.py
│ │ └── db_dataset.py │ │ └── db_dataset.py
│ ├── inference/ # 推理管道 │ ├── inference/ # 推理管道
│ │ ├── pipeline.py │ │ ├── pipeline.py # 主推理流程
│ │ ├── yolo_detector.py │ │ ├── yolo_detector.py # YOLO 检测
│ │ ── field_extractor.py │ │ ── field_extractor.py # 字段提取
│ │ ├── payment_line_parser.py # 支付行解析器
│ │ └── customer_number_parser.py # 客户编号解析器
│ ├── processing/ # 多池处理架构 │ ├── processing/ # 多池处理架构
│ │ ├── worker_pool.py │ │ ├── worker_pool.py
│ │ ├── cpu_pool.py │ │ ├── cpu_pool.py
@@ -278,20 +368,33 @@ invoice-master-poc-v2/
│ │ ├── task_dispatcher.py │ │ ├── task_dispatcher.py
│ │ └── dual_pool_coordinator.py │ │ └── dual_pool_coordinator.py
│ ├── web/ # Web 应用 │ ├── web/ # Web 应用
│ │ ├── app.py # FastAPI 应用 │ │ ├── app.py # FastAPI 应用入口
│ │ ├── routes.py # API 路由 │ │ ├── routes.py # API 路由
│ │ ├── services.py # 业务逻辑 │ │ ├── services.py # 业务逻辑
│ │ ── schemas.py # 数据模型 │ │ ── schemas.py # 数据模型
│ └── config.py # 配置 ├── utils/ # 工具模块
│ │ ├── text_cleaner.py # 文本清理
│ │ ├── validators.py # 字段验证
│ │ ├── fuzzy_matcher.py # 模糊匹配
│ │ └── ocr_corrections.py # OCR 错误修正
│ └── data/ # 数据处理 │ └── data/ # 数据处理
├── tests/ # 测试文件
│ ├── ocr/ # OCR 模块测试
│ │ └── test_machine_code_parser.py
│ ├── inference/ # 推理模块测试
│ ├── normalize/ # 规范化模块测试
│ └── utils/ # 工具模块测试
├── docs/ # 文档
│ ├── REFACTORING_SUMMARY.md
│ └── TEST_COVERAGE_IMPROVEMENT.md
├── config.py # 配置文件 ├── config.py # 配置文件
├── run_server.py # Web 服务器启动脚本 ├── run_server.py # Web 服务器启动脚本
├── runs/ # 训练输出 ├── runs/ # 训练输出
│ └── train/ │ └── train/
│ └── invoice_yolo11n_full/ │ └── invoice_fields/
│ └── weights/ │ └── weights/
│ ├── best.pt │ ├── best.pt # 最佳模型
│ └── last.pt │ └── last.pt # 最后检查点
└── requirements.txt └── requirements.txt
``` ```
@@ -410,14 +513,15 @@ Options:
## Python API ## Python API
```python ```python
from src.inference import InferencePipeline from src.inference.pipeline import InferencePipeline
# 初始化 # 初始化
pipeline = InferencePipeline( pipeline = InferencePipeline(
model_path='runs/train/invoice_yolo11n_full/weights/best.pt', model_path='runs/train/invoice_fields/weights/best.pt',
confidence_threshold=0.3, confidence_threshold=0.25,
use_gpu=True, use_gpu=True,
dpi=150 dpi=150,
enable_fallback=True
) )
# 处理 PDF # 处理 PDF
@@ -427,26 +531,194 @@ result = pipeline.process_pdf('invoice.pdf')
result = pipeline.process_image('invoice.png') result = pipeline.process_image('invoice.png')
# 获取结果 # 获取结果
print(result.fields) # {'InvoiceNumber': '12345', 'Amount': '1234.56', ...} print(result.fields)
# {
# 'InvoiceNumber': '12345',
# 'Amount': '1234.56',
# 'payment_line': '# 94228110015950070 # > 48666036#14#',
# 'customer_number': 'UMJ 436-R',
# ...
# }
print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...} print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
print(result.to_json()) # JSON 格式输出 print(result.to_json()) # JSON 格式输出
# 访问交叉验证结果
if result.cross_validation:
print(f"OCR match: {result.cross_validation.ocr_match}")
print(f"Amount match: {result.cross_validation.amount_match}")
print(f"Details: {result.cross_validation.details}")
```
### 统一解析器使用
```python
from src.inference.payment_line_parser import PaymentLineParser
from src.inference.customer_number_parser import CustomerNumberParser
# Payment Line 解析
parser = PaymentLineParser()
result = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
print(f"OCR: {result.ocr_number}")
print(f"Amount: {result.amount}")
print(f"Account: {result.account_number}")
# Customer Number 解析
parser = CustomerNumberParser()
result = parser.parse("Said, Shakar Umj 436-R Billo")
print(f"Customer Number: {result}") # "UMJ 436-R"
```
## 测试
### 测试统计
| 指标 | 数值 |
|------|------|
| **测试总数** | 688 |
| **通过率** | 100% |
| **整体覆盖率** | 37% |
### 关键模块覆盖率
| 模块 | 覆盖率 | 测试数 |
|------|--------|--------|
| `machine_code_parser.py` | 65% | 79 |
| `payment_line_parser.py` | 85% | 45 |
| `customer_number_parser.py` | 90% | 32 |
### 运行测试
```bash
# 运行所有测试
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"
# 运行并查看覆盖率
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src --cov-report=term-missing"
# 运行特定模块测试
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest tests/ocr/test_machine_code_parser.py -v"
```
### 测试结构
```
tests/
├── ocr/
│ ├── test_machine_code_parser.py # 支付行解析 (79 tests)
│ └── test_ocr_engine.py # OCR 引擎测试
├── inference/
│ ├── test_payment_line_parser.py # 支付行解析器
│ └── test_customer_number_parser.py # 客户编号解析器
├── normalize/
│ └── test_normalizers.py # 字段规范化
└── utils/
└── test_validators.py # 字段验证
``` ```
## 开发状态 ## 开发状态
**已完成功能**:
- [x] 文本层 PDF 自动标注 - [x] 文本层 PDF 自动标注
- [x] 扫描图 OCR 自动标注 - [x] 扫描图 OCR 自动标注
- [x] 多策略字段匹配 (精确/子串/规范化) - [x] 多策略字段匹配 (精确/子串/规范化)
- [x] PostgreSQL 数据库存储 (断点续传) - [x] PostgreSQL 数据库存储 (断点续传)
- [x] 信号处理和超时保护 - [x] 信号处理和超时保护
- [x] YOLO 训练 (98.7% mAP@0.5) - [x] YOLO 训练 (93.5% mAP@0.5, 10 个字段)
- [x] 推理管道 - [x] 推理管道
- [x] 字段规范化和验证 - [x] 字段规范化和验证
- [x] Web 应用 (FastAPI + 前端 UI) - [x] Web 应用 (FastAPI + REST API)
- [x] 增量训练支持 - [x] 增量训练支持
- [x] 内存优化训练 (--low-memory, --resume)
- [x] Payment Line 解析器 (统一模块)
- [x] Customer Number 解析器 (统一模块)
- [x] Payment Line 交叉验证 (OCR, Amount, Account)
- [x] 文档类型识别 (invoice/letter)
- [x] 单元测试覆盖 (688 tests, 37% coverage)
**进行中**:
- [ ] 完成全部 25,000+ 文档标注 - [ ] 完成全部 25,000+ 文档标注
- [ ] 表格 items 处理 - [ ] 多源融合增强 (Multi-source fusion)
- [ ] 模型量化部署 - [ ] OCR 错误修正集成
- [ ] 提升测试覆盖率到 60%+
**计划中**:
- [ ] 表格 items 提取
- [ ] 模型量化部署 (ONNX/TensorRT)
- [ ] 多语言支持扩展
## 关键技术特性
### 1. Payment Line 交叉验证
瑞典发票的 payment_line (支付行) 包含完整的支付信息:OCR 参考号、金额、账号。我们实现了交叉验证机制:
```
Payment Line: # 94228110015950070 # 15658 00 8 > 48666036#14#
↓ ↓ ↓
OCR Number Amount Bankgiro Account
```
**验证流程**:
1. 从 payment_line 提取 OCR、Amount、Account
2. 与单独检测的字段对比验证
3. **payment_line 值优先** - 如有不匹配,采用 payment_line 的值
4. 返回验证结果和详细信息(简化示意见下文)
**优势**:
- 提高数据准确性 (payment_line 是机器可读格式,更可靠)
- 发现 OCR 或检测错误
- 为数据质量提供信心指标
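下面给出上述验证流程的一个最小示意(字段名与返回结构为简化假设,实际实现见 `src/inference/pipeline.py`):
```python
# 交叉验证流程示意(简化版,仅用于说明;实际实现在 pipeline.py 中更完整)
from dataclasses import dataclass, field


@dataclass
class CrossValidation:
    ocr_match: bool
    amount_match: bool
    is_valid: bool
    details: dict = field(default_factory=dict)


def cross_validate(fields: dict, parsed) -> CrossValidation:
    """用 payment_line 解析结果 (parsed.ocr_number / parsed.amount) 校验并覆盖检测字段。"""
    ocr_match = fields.get("OCR") == parsed.ocr_number
    amount_match = fields.get("Amount") == parsed.amount
    details = {}
    # payment_line 值优先:不一致时采用 payment_line 的值
    if not ocr_match:
        details["OCR"] = {"detected": fields.get("OCR"), "payment_line": parsed.ocr_number}
        fields["OCR"] = parsed.ocr_number
    if not amount_match:
        details["Amount"] = {"detected": fields.get("Amount"), "payment_line": parsed.amount}
        fields["Amount"] = parsed.amount
    return CrossValidation(ocr_match, amount_match, ocr_match and amount_match, details)
```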
### 2. 统一解析器架构
采用独立解析器模块处理复杂字段:
**PaymentLineParser**:
- 解析瑞典标准支付行格式
- 提取 OCR、Amount (包含 Kronor + Öre)、Account + Check digits
- 支持多种变体格式
**CustomerNumberParser**:
- 支持多种瑞典客户编号格式 (`UMJ 436-R`, `JTY 576-3`, `FFL 019N`)
- 从混合文本中提取 (如地址行中的客户编号)
- 大小写不敏感,输出统一大写格式
**优势**:
- 代码模块化、可测试
- 易于扩展新格式
- 统一的解析逻辑,减少重复代码
### 3. 文档类型自动识别
根据 payment_line 字段自动判断文档类型:
- **invoice**: 包含 payment_line 的发票文档
- **letter**: 不含 payment_line 的信函文档
这个特性帮助下游系统区分处理流程。
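判断逻辑本身很简单,示意如下(实际实现位于推理管道内部,细节可能不同):
```python
# 文档类型判断示意:检测到 payment_line 即视为 invoice,否则视为 letter
def detect_document_type(fields: dict) -> str:
    return "invoice" if fields.get("payment_line") else "letter"
```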
### 4. 低内存模式训练
支持在内存受限环境下训练:
```bash
python -m src.cli.train --low-memory
```
自动调整:
- batch size: 16 → 8
- workers: 8 → 4
- cache: disabled
- 推荐用于 GPU 内存 < 8GB 或系统内存 < 16GB 的场景
### 5. 断点续传训练
训练中断后可从检查点恢复:
```bash
python -m src.cli.train --resume --model runs/train/invoice_fields/weights/last.pt
```
## 技术栈 ## 技术栈
@@ -457,7 +729,33 @@ print(result.to_json()) # JSON 格式输出
| **PDF 处理** | PyMuPDF (fitz) | | **PDF 处理** | PyMuPDF (fitz) |
| **数据库** | PostgreSQL + psycopg2 | | **数据库** | PostgreSQL + psycopg2 |
| **Web 框架** | FastAPI + Uvicorn | | **Web 框架** | FastAPI + Uvicorn |
| **深度学习** | PyTorch + CUDA | | **深度学习** | PyTorch + CUDA 12.x |
## 常见问题
**Q: 为什么必须在 WSL 环境运行?**
A: PaddleOCR 和某些依赖在 Windows 原生环境存在兼容性问题。WSL 提供完整的 Linux 环境,确保所有依赖正常工作。
**Q: 训练过程中出现 OOM (内存不足) 错误怎么办?**
A: 使用 `--low-memory` 模式,或手动调整 `--batch` 和 `--workers` 参数。
**Q: payment_line 和单独检测字段不匹配时怎么处理?**
A: 系统默认优先采用 payment_line 的值,因为 payment_line 是机器可读格式,通常更准确。验证结果会记录在 `cross_validation` 字段中。
**Q: 如何添加新的字段类型?**
A:
1. 在 `src/inference/constants.py` 添加字段定义
2. 在 `field_extractor.py` 添加规范化方法
3. 重新生成标注数据
4. 从头训练模型
**Q: 可以用 CPU 训练吗?**
A: 可以,但速度会非常慢 (慢 10-50 倍)。强烈建议使用 GPU 训练。
## 许可证 ## 许可证

config.py
@@ -4,6 +4,12 @@ Configuration settings for the invoice extraction system.
import os import os
import platform import platform
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from .env file
env_path = Path(__file__).parent / '.env'
load_dotenv(dotenv_path=env_path)
def _is_wsl() -> bool: def _is_wsl() -> bool:
@@ -21,14 +27,22 @@ def _is_wsl() -> bool:
# PostgreSQL Database Configuration # PostgreSQL Database Configuration
# Now loaded from environment variables for security
DATABASE = { DATABASE = {
'host': '192.168.68.31', 'host': os.getenv('DB_HOST', '192.168.68.31'),
'port': 5432, 'port': int(os.getenv('DB_PORT', '5432')),
'database': 'docmaster', 'database': os.getenv('DB_NAME', 'docmaster'),
'user': 'docmaster', 'user': os.getenv('DB_USER', 'docmaster'),
'password': '0412220', 'password': os.getenv('DB_PASSWORD'), # No default for security
} }
# Validate required configuration
if not DATABASE['password']:
raise ValueError(
"DB_PASSWORD environment variable is not set. "
"Please create a .env file based on .env.example and set DB_PASSWORD."
)
# Connection string for psycopg2 # Connection string for psycopg2
def get_db_connection_string(): def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}" return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"

docs/CODE_REVIEW_REPORT.md Normal file
@@ -0,0 +1,405 @@
# Invoice Master POC v2 - 代码审查报告
**审查日期**: 2026-01-22
**代码库规模**: 67 个 Python 源文件,约 22,434 行代码
**测试覆盖率**: ~40-50%
---
## 执行摘要
### 总体评估:**良好B+**
**优势**
- ✅ 清晰的模块化架构,职责分离良好
- ✅ 使用了合适的数据类和类型提示
- ✅ 针对瑞典发票的全面规范化逻辑
- ✅ 空间索引优化O(1) token 查找)
- ✅ 完善的降级机制YOLO 失败时的 OCR fallback
- ✅ 设计良好的 Web API 和 UI
**主要问题**
- ❌ 支付行解析代码重复3+ 处)
- ❌ 长函数(`_normalize_customer_number` 127 行)
- ❌ 配置安全问题(明文数据库密码)
- ❌ 异常处理不一致(到处都是通用 Exception
- ❌ 缺少集成测试
- ❌ 魔法数字散布各处0.5, 0.95, 300 等)
---
## 1. 架构分析
### 1.1 模块结构
```
src/
├── inference/ # 推理管道核心
│ ├── pipeline.py (517 行) ⚠️
│ ├── field_extractor.py (1,347 行) 🔴 太长
│ └── yolo_detector.py
├── web/ # FastAPI Web 服务
│ ├── app.py (765 行) ⚠️ HTML 内联
│ ├── routes.py (184 行)
│ └── services.py (286 行)
├── ocr/ # OCR 提取
│ ├── paddle_ocr.py
│ └── machine_code_parser.py (919 行) 🔴 太长
├── matcher/ # 字段匹配
│ └── field_matcher.py (875 行) ⚠️
├── utils/ # 共享工具
│ ├── validators.py
│ ├── text_cleaner.py
│ ├── fuzzy_matcher.py
│ ├── ocr_corrections.py
│ └── format_variants.py (610 行)
├── processing/ # 批处理
├── data/ # 数据管理
└── cli/ # 命令行工具
```
### 1.2 推理流程
```
PDF/Image 输入
渲染为图片 (pdf/renderer.py)
YOLO 检测 (yolo_detector.py) - 检测字段区域
字段提取 (field_extractor.py)
├→ OCR 文本提取 (ocr/paddle_ocr.py)
├→ 规范化 & 验证
└→ 置信度计算
交叉验证 (pipeline.py)
├→ 解析 payment_line 格式
├→ 从 payment_line 提取 OCR/Amount/Account
└→ 与检测字段验证payment_line 值优先
降级 OCR如果关键字段缺失
├→ 全页 OCR
└→ 正则提取
InferenceResult 输出
```
---
## 2. 代码质量问题
### 2.1 长函数(>50 行)🔴
| 函数 | 文件 | 行数 | 复杂度 | 问题 |
|------|------|------|--------|------|
| `_normalize_customer_number()` | field_extractor.py | **127** | 极高 | 4 层模式匹配7+ 正则,复杂评分 |
| `_cross_validate_payment_line()` | pipeline.py | **127** | 极高 | 核心验证逻辑8+ 条件分支 |
| `_normalize_bankgiro()` | field_extractor.py | 62 | 高 | Luhn 验证 + 多种降级 |
| `_normalize_plusgiro()` | field_extractor.py | 63 | 高 | 类似 bankgiro |
| `_normalize_payment_line()` | field_extractor.py | 74 | 高 | 4 种正则模式 |
| `_normalize_amount()` | field_extractor.py | 78 | 高 | 多策略降级 |
**示例问题** - `_normalize_customer_number()` (第 776-902 行):
```python
def _normalize_customer_number(self, text: str):
# 127 行函数,包含:
# - 4 个嵌套的 if/for 循环
# - 7 种不同的正则模式
# - 5 个评分机制
# - 处理有标签和无标签格式
```
**建议**: 拆分为:
- `_find_customer_code_patterns()`
- `_find_labeled_customer_code()`
- `_score_customer_candidates()`
### 2.2 代码重复 🔴
**支付行解析3+ 处重复实现)**:
1. `_parse_machine_readable_payment_line()` (pipeline.py:217-252)
2. `MachineCodeParser.parse()` (machine_code_parser.py:919 行)
3. `_normalize_payment_line()` (field_extractor.py:632-705)
所有三处都实现类似的正则模式:
```
格式: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
```
**Bankgiro/Plusgiro 验证(重复)**:
- `validators.py`: `is_valid_bankgiro()`, `format_bankgiro()`
- `field_extractor.py`: `_normalize_bankgiro()`, `_normalize_plusgiro()`, `_luhn_checksum()`
- `normalizer.py`: `normalize_bankgiro()`, `normalize_plusgiro()`
- `field_matcher.py`: 类似匹配逻辑
**建议**: 创建统一模块:
```python
# src/common/payment_line_parser.py
class PaymentLineParser:
def parse(text: str) -> PaymentLineResult
# src/common/giro_validator.py
class GiroValidator:
def validate_and_format(value: str, giro_type: str) -> str
```
### 2.3 错误处理不一致 ⚠️
**通用异常捕获31 处)**:
```python
except Exception as e: # 代码库中 31 处
result.errors.append(str(e))
```
**问题**:
- 没有捕获特定错误类型
- 通用错误消息丢失上下文
- 第 142-147 行 (routes.py): 捕获所有异常,返回 500 状态
**当前写法** (routes.py:142-147):
```python
try:
service_result = inference_service.process_pdf(...)
except Exception as e: # 太宽泛
logger.error(f"Error processing document: {e}")
raise HTTPException(status_code=500, detail=str(e))
```
**改进建议**:
```python
except FileNotFoundError:
raise HTTPException(status_code=400, detail="PDF 文件未找到")
except PyMuPDFError:
raise HTTPException(status_code=400, detail="无效的 PDF 格式")
except OCRError:
raise HTTPException(status_code=503, detail="OCR 服务不可用")
```
### 2.4 配置安全问题 🔴
**config.py 第 24-30 行** - 明文凭据:
```python
DATABASE = {
'host': '192.168.68.31', # 硬编码 IP
'user': 'docmaster', # 硬编码用户名
'password': 'nY6LYK5d', # 🔴 明文密码!
'database': 'invoice_master'
}
```
**建议**:
```python
DATABASE = {
'host': os.getenv('DB_HOST', 'localhost'),
'user': os.getenv('DB_USER', 'docmaster'),
'password': os.getenv('DB_PASSWORD'), # 从环境变量读取
'database': os.getenv('DB_NAME', 'invoice_master')
}
```
### 2.5 魔法数字 ⚠️
| 值 | 位置 | 用途 | 问题 |
|---|------|------|------|
| 0.5 | 多处 | 置信度阈值 | 不可按字段配置 |
| 0.95 | pipeline.py | payment_line 置信度 | 无说明 |
| 300 | 多处 | DPI | 硬编码 |
| 0.1 | field_extractor.py | BBox 填充 | 应为配置 |
| 72 | 多处 | PDF 基础 DPI | 公式中的魔法数字 |
| 50 | field_extractor.py | 客户编号评分加分 | 无说明 |
**建议**: 提取到配置:
```python
INFERENCE_CONFIG = {
'confidence_threshold': 0.5,
'payment_line_confidence': 0.95,
'dpi': 300,
'bbox_padding': 0.1,
}
```
### 2.6 命名不一致 ⚠️
**字段名称不一致**:
- YOLO 类名: `invoice_number`, `ocr_number`, `supplier_org_number`
- 字段名: `InvoiceNumber`, `OCR`, `supplier_org_number`
- CSV 列名: 可能又不同
- 数据库字段名: 另一种变体
映射维护在多处:
- `yolo_detector.py` (90-100 行): `CLASS_TO_FIELD`
- 多个其他位置
---
## 3. 测试分析
### 3.1 测试覆盖率
**测试文件**: 13 个
- ✅ 覆盖良好: field_matcher, normalizer, payment_line_parser
- ⚠️ 中等覆盖: field_extractor, pipeline
- ❌ 覆盖不足: web 层, CLI, 批处理
**估算覆盖率**: 40-50%
### 3.2 缺失的测试用例 🔴
**关键缺失**:
1. 交叉验证逻辑 - 最复杂部分,测试很少
2. payment_line 解析变体 - 多种实现,边界情况不清楚
3. OCR 错误纠正 - 不同策略的复杂逻辑
4. Web API 端点 - 没有请求/响应测试
5. 批处理 - 多 worker 协调未测试
6. 降级 OCR 机制 - YOLO 检测失败时
---
## 4. 架构风险
### 🔴 关键风险
1. **配置安全** - config.py 中明文数据库凭据24-30 行)
2. **错误恢复** - 宽泛的异常处理掩盖真实问题
3. **可测试性** - 硬编码依赖阻止单元测试
### 🟡 高风险
1. **代码可维护性** - 支付行解析重复
2. **可扩展性** - 没有长时间推理的异步处理
3. **扩展性** - 添加新字段类型会很困难
### 🟢 中等风险
1. **性能** - 懒加载有帮助,但 ORM 查询未优化
2. **文档** - 大部分足够但可以更好
---
## 5. 优先级矩阵
| 优先级 | 行动 | 工作量 | 影响 |
|--------|------|--------|------|
| 🔴 关键 | 修复配置安全(环境变量) | 1 小时 | 高 |
| 🔴 关键 | 添加集成测试 | 2-3 天 | 高 |
| 🔴 关键 | 文档化错误处理策略 | 4 小时 | 中 |
| 🟡 高 | 统一 payment_line 解析 | 1-2 天 | 高 |
| 🟡 高 | 提取规范化到子模块 | 2-3 天 | 中 |
| 🟡 高 | 添加依赖注入 | 2-3 天 | 中 |
| 🟡 高 | 拆分长函数 | 2-3 天 | 低 |
| 🟢 中 | 提高测试覆盖率到 70%+ | 3-5 天 | 高 |
| 🟢 中 | 提取魔法数字 | 4 小时 | 低 |
| 🟢 中 | 标准化命名约定 | 1-2 天 | 中 |
---
## 6. 具体文件建议
### 高优先级(代码质量)
| 文件 | 问题 | 建议 |
|------|------|------|
| `field_extractor.py` | 1,347 行6 个长规范化方法 | 拆分为 `normalizers/` 子模块 |
| `pipeline.py` | 127 行 `_cross_validate_payment_line()` | 提取到单独的 `CrossValidator` 类 |
| `field_matcher.py` | 875 行;复杂匹配逻辑 | 拆分为 `matching/` 子模块 |
| `config.py` | 硬编码凭据(第 29 行) | 使用环境变量 |
| `machine_code_parser.py` | 919 行payment_line 解析 | 与 pipeline 解析合并 |
### 中优先级(重构)
| 文件 | 问题 | 建议 |
|------|------|------|
| `app.py` | 765 行HTML 内联在 Python 中 | 提取到 `templates/` 目录 |
| `autolabel.py` | 753 行;批处理逻辑 | 提取 worker 函数到模块 |
| `format_variants.py` | 610 行;变体生成 | 考虑策略模式 |
---
## 7. 建议行动
### 第 1 阶段关键修复1 周)
1. **配置安全** (1 小时)
- 移除 config.py 中的明文密码
- 添加环境变量支持
- 更新 README 说明配置
2. **错误处理标准化** (1 天)
- 定义自定义异常类
- 替换通用 Exception 捕获
- 添加错误代码常量
3. **添加关键集成测试** (2 天)
- 端到端推理测试
- payment_line 交叉验证测试
- API 端点测试
### 第 2 阶段重构2-3 周)
4. **统一 payment_line 解析** (2 天)
- 创建 `src/common/payment_line_parser.py`
- 合并 3 处重复实现
- 迁移所有调用方
5. **拆分 field_extractor.py** (3 天)
- 创建 `src/inference/normalizers/` 子模块
- 每个字段类型一个文件
- 提取共享验证逻辑
6. **拆分长函数** (2 天)
- `_normalize_customer_number()` → 3 个函数
- `_cross_validate_payment_line()` → CrossValidator 类
### 第 3 阶段改进1-2 周)
7. **提高测试覆盖率** (5 天)
- 目标70%+ 覆盖率
- 专注于验证逻辑
- 添加边界情况测试
8. **配置管理改进** (1 天)
- 提取所有魔法数字
- 创建配置文件YAML
- 添加配置验证
9. **文档改进** (2 天)
- 添加架构图
- 文档化所有私有方法
- 创建贡献指南
---
## 附录 A度量指标
### 代码复杂度
| 类别 | 计数 | 平均行数 |
|------|------|----------|
| 源文件 | 67 | 334 |
| 长文件 (>500 行) | 12 | 875 |
| 长函数 (>50 行) | 23 | 89 |
| 测试文件 | 13 | 298 |
### 依赖关系
| 类型 | 计数 |
|------|------|
| 外部依赖 | ~25 |
| 内部模块 | 10 |
| 循环依赖 | 0 ✅ |
### 代码风格
| 指标 | 覆盖率 |
|------|--------|
| 类型提示 | 80% |
| Docstrings (公开) | 80% |
| Docstrings (私有) | 40% |
| 测试覆盖率 | 45% |
---
**生成日期**: 2026-01-22
**审查者**: Claude Code
**版本**: v2.0

@@ -0,0 +1,96 @@
# Field Extractor 分析报告
## 概述
field_extractor.py (1183行) 最初被识别为可优化文件,尝试使用 `src/normalize` 模块进行重构,但经过分析和测试后发现 **不应该重构**
## 重构尝试
### 初始计划
将 field_extractor.py 中的重复 normalize 方法删除,统一使用 `src/normalize/normalize_field()` 接口。
### 实施步骤
1. ✅ 备份原文件 (`field_extractor_old.py`)
2. ✅ 修改 `_normalize_and_validate` 使用统一 normalizer
3. ✅ 删除重复的 normalize 方法 (~400行)
4. ❌ 运行测试 - **28个失败**
5. ✅ 添加 wrapper 方法委托给 normalizer
6. ❌ 再次测试 - **12个失败**
7. ✅ 还原原文件
8. ✅ 测试通过 - **全部45个测试通过**
## 关键发现
### 两个模块的不同用途
| 模块 | 用途 | 输入 | 输出 | 示例 |
|------|------|------|------|------|
| **src/normalize/** | **变体生成** 用于匹配 | 已提取的字段值 | 多个匹配变体列表 | `"INV-12345"` → `["INV-12345", "12345"]` |
| **field_extractor** | **值提取** 从OCR文本 | 包含字段的原始OCR文本 | 提取的单个字段值 | `"Fakturanummer: A3861"` → `"A3861"` |
### 为什么不能统一?
1. **src/normalize/** 的设计目的:
- 接收已经提取的字段值
- 生成多个标准化变体用于fuzzy matching
- 例如 BankgiroNormalizer:
```python
normalize("782-1713") → ["7821713", "782-1713"] # 生成变体
```
2. **field_extractor** 的 normalize 方法:
- 接收包含字段的原始OCR文本可能包含标签、其他文本等
- **提取**特定模式的字段值
- 例如 `_normalize_bankgiro`:
```python
_normalize_bankgiro("Bankgiro: 782-1713") → ("782-1713", True, None) # 从文本提取
```
3. **关键区别**:
- Normalizer: 变体生成器 (for matching)
- Field Extractor: 模式提取器 (for parsing)
### 测试失败示例
使用 normalizer 替代 field extractor 方法后的失败:
```python
# InvoiceNumber 测试
Input: "Fakturanummer: A3861"
期望: "A3861"
实际: "Fakturanummer: A3861" # 没有提取,只是清理
# Bankgiro 测试
Input: "Bankgiro: 782-1713"
期望: "782-1713"
实际: "7821713" # 返回了不带破折号的变体,而不是提取格式化值
```
## 结论
**field_extractor.py 不应该使用 src/normalize 模块重构**,因为:
1.**职责不同**: 提取 vs 变体生成
2.**输入不同**: 包含标签的原始OCR文本 vs 已提取的字段值
3.**输出不同**: 单个提取值 vs 多个匹配变体
4.**现有代码运行良好**: 所有45个测试通过
5.**提取逻辑有价值**: 包含复杂的模式匹配规则(例如区分 Bankgiro/Plusgiro 格式)
## 建议
1. **保留 field_extractor.py 原样**: 不进行重构
2. **文档化两个模块的差异**: 确保团队理解各自用途
3. **关注其他优化目标**: machine_code_parser.py (919行)
## 学习点
重构前应该:
1. 理解模块的**真实用途**,而不只是看代码相似度
2. 运行完整测试套件验证假设
3. 评估是否真的存在重复,还是表面相似但用途不同
---
**状态**: ✅ 分析完成,决定不重构
**测试**: ✅ 45/45 通过
**文件**: 保持 1183行 原样

@@ -0,0 +1,238 @@
# Machine Code Parser 分析报告
## 文件概况
- **文件**: `src/ocr/machine_code_parser.py`
- **总行数**: 919 行
- **代码行**: 607 行 (66%)
- **方法数**: 14 个
- **正则表达式使用**: 47 次
## 代码结构
### 类结构
```
MachineCodeResult (数据类)
├── to_dict()
└── get_region_bbox()
MachineCodeParser (主解析器)
├── __init__()
├── parse() - 主入口
├── _find_tokens_with_values()
├── _find_machine_code_line_tokens()
├── _parse_standard_payment_line_with_tokens()
├── _parse_standard_payment_line() - 142行 ⚠️
├── _extract_ocr() - 50行
├── _extract_bankgiro() - 58行
├── _extract_plusgiro() - 30行
├── _extract_amount() - 68行
├── _calculate_confidence()
└── cross_validate()
```
## 发现的问题
### 1. ⚠️ `_parse_standard_payment_line` 方法过长 (142行)
**位置**: 442-582 行
**问题**:
- 包含嵌套函数 `normalize_account_spaces``format_account`
- 多个正则匹配分支
- 逻辑复杂,难以测试和维护
**建议**:
可以拆分为独立方法:
- `_normalize_account_spaces(line)`
- `_format_account(account_digits, context)`
- `_match_primary_pattern(line)`
- `_match_fallback_patterns(line)`
### 2. 🔁 4个 `_extract_*` 方法有重复模式
所有 extract 方法都遵循相同模式:
```python
def _extract_XXX(self, tokens):
candidates = []
for token in tokens:
text = token.text.strip()
matches = self.XXX_PATTERN.findall(text)
for match in matches:
# 验证逻辑
# 上下文检测
candidates.append((normalized, context_score, token))
if not candidates:
return None
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
return candidates[0][0]
```
**重复的逻辑**:
- Token 迭代
- 模式匹配
- 候选收集
- 上下文评分
- 排序和选择最佳匹配
**建议**:
可以提取基础提取器类或通用方法来减少重复。
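一个可能的通用提取方法示意(与方案 B 中的 `_generic_extract` 思路一致;方法签名与回调形式仅为假设):
```python
import re
from typing import Callable, Optional


def _generic_extract(
    self,
    tokens: list,
    pattern: re.Pattern,
    validate: Callable[[str], Optional[str]],   # 返回规范化值,无效时返回 None
    context_score: Callable[[list], float],     # 基于上下文关键词打分
) -> Optional[str]:
    """把共有的 "遍历 token → 匹配 → 验证 → 评分 → 取最佳" 流程抽出来(示意)。"""
    score = context_score(tokens)
    candidates = []
    for token in tokens:
        for match in pattern.findall(token.text.strip()):
            normalized = validate(match)
            if normalized is not None:
                candidates.append((normalized, score, token))
    if not candidates:
        return None
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[0][0]
```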
### 3. ✅ 上下文检测重复
上下文检测代码在多个地方重复:
```python
# _extract_bankgiro 中
context_text = ' '.join(t.text.lower() for t in tokens)
is_bankgiro_context = (
'bankgiro' in context_text or
'bg:' in context_text or
'bg ' in context_text
)
# _extract_plusgiro 中
context_text = ' '.join(t.text.lower() for t in tokens)
is_plusgiro_context = (
'plusgiro' in context_text or
'postgiro' in context_text or
'pg:' in context_text or
'pg ' in context_text
)
# _parse_standard_payment_line 中
context = (context_line or raw_line).lower()
is_plusgiro_context = (
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
and 'bankgiro' not in context
)
```
**建议**:
提取为独立方法:
- `_detect_account_context(tokens) -> dict[str, bool]`
## 重构建议
### 方案 A: 轻度重构(推荐)✅
**目标**: 提取重复的上下文检测逻辑,不改变主要结构
**步骤**:
1. 提取 `_detect_account_context(tokens)` 方法
2. 提取 `_normalize_account_spaces(line)` 为独立方法
3. 提取 `_format_account(digits, context)` 为独立方法
**影响**:
- 减少 ~50-80 行重复代码
- 提高可测试性
- 低风险,易于验证
**预期结果**: 919 行 → ~850 行 (↓7%)
### 方案 B: 中度重构
**目标**: 创建通用的字段提取框架
**步骤**:
1. 创建 `_generic_extract(pattern, normalizer, context_checker)`
2. 重构所有 `_extract_*` 方法使用通用框架
3. 拆分 `_parse_standard_payment_line` 为多个小方法
**影响**:
- 减少 ~150-200 行代码
- 显著提高可维护性
- 中等风险,需要全面测试
**预期结果**: 919 行 → ~720 行 (↓22%)
### 方案 C: 深度重构(不推荐)
**目标**: 完全重新设计为策略模式
**风险**:
- 高风险,可能引入 bugs
- 需要大量测试
- 可能破坏现有集成
## 推荐方案
### ✅ 采用方案 A轻度重构
**理由**:
1. **代码已经工作良好**: 没有明显的 bug 或性能问题
2. **低风险**: 只提取重复逻辑,不改变核心算法
3. **性价比高**: 小改动带来明显的代码质量提升
4. **易于验证**: 现有测试应该能覆盖
### 重构步骤
```python
# 1. 提取上下文检测
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
"""检测上下文中的账户类型关键词"""
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
# 2. 提取空格标准化
def _normalize_account_spaces(self, line: str) -> str:
"""移除账户号码中的空格"""
# (现有 line 460-481 的代码)
# 3. 提取账户格式化
def _format_account(
self,
account_digits: str,
is_plusgiro_context: bool
) -> tuple[str, str]:
"""格式化账户并确定类型"""
# (现有 line 485-523 的代码)
```
## 对比field_extractor vs machine_code_parser
| 特征 | field_extractor | machine_code_parser |
|------|-----------------|---------------------|
| 用途 | 值提取 | 机器码解析 |
| 重复代码 | ~400行normalize方法 | ~80行上下文检测 |
| 重构价值 | ❌ 不同用途,不应统一 | ✅ 可提取共享逻辑 |
| 风险 | 高(会破坏功能) | 低(只是代码组织) |
## 决策
### ✅ 建议重构 machine_code_parser.py
**与 field_extractor 的不同**:
- field_extractor: 重复的方法有**不同的用途**(提取 vs 变体生成)
- machine_code_parser: 重复的代码有**相同的用途**(都是上下文检测)
**预期收益**:
- 减少 ~70 行重复代码
- 提高可测试性(可以单独测试上下文检测)
- 更清晰的代码组织
- **低风险**,易于验证
## 下一步
1. ✅ 备份原文件
2. ✅ 提取 `_detect_account_context` 方法
3. ✅ 提取 `_normalize_account_spaces` 方法
4. ✅ 提取 `_format_account` 方法
5. ✅ 更新所有调用点
6. ✅ 运行测试验证
7. ✅ 检查代码覆盖率
---
**状态**: 📋 分析完成,建议轻度重构
**风险评估**: 🟢 低风险
**预期收益**: 919行 → ~850行 (↓7%)

docs/PERFORMANCE_OPTIMIZATION.md Normal file
@@ -0,0 +1,519 @@
# Performance Optimization Guide
This document provides performance optimization recommendations for the Invoice Field Extraction system.
## Table of Contents
1. [Batch Processing Optimization](#batch-processing-optimization)
2. [Database Query Optimization](#database-query-optimization)
3. [Caching Strategies](#caching-strategies)
4. [Memory Management](#memory-management)
5. [Profiling and Monitoring](#profiling-and-monitoring)
---
## Batch Processing Optimization
### Current State
The system processes invoices one at a time. For large batches, this can be inefficient.
### Recommendations
#### 1. Database Batch Operations
**Current**: Individual inserts for each document
```python
# Inefficient
for doc in documents:
db.insert_document(doc) # Individual DB call
```
**Optimized**: Use `execute_values` for batch inserts
```python
# Efficient - already implemented in db.py line 519
from psycopg2.extras import execute_values
execute_values(cursor, """
INSERT INTO documents (...)
VALUES %s
""", document_values)
```
**Impact**: 10-50x faster for batches of 100+ documents
#### 2. PDF Processing Batching
**Recommendation**: Process PDFs in parallel using multiprocessing
```python
from multiprocessing import Pool
def process_batch(pdf_paths, batch_size=10):
"""Process PDFs in parallel batches."""
with Pool(processes=batch_size) as pool:
results = pool.map(pipeline.process_pdf, pdf_paths)
return results
```
**Considerations**:
- GPU models should use a shared process pool (already exists: `src/processing/gpu_pool.py`)
- CPU-intensive tasks can use separate process pool (`src/processing/cpu_pool.py`)
- Current dual pool coordinator (`dual_pool_coordinator.py`) already supports this pattern
**Status**: ✅ Already implemented in `src/processing/` modules
#### 3. Image Caching for Multi-Page PDFs
**Current**: Each page rendered independently
```python
# Current pattern in field_extractor.py
for page_num in range(total_pages):
image = render_pdf_page(pdf_path, page_num, dpi=300)
```
**Optimized**: Pre-render all pages if processing multiple fields per page
```python
# Batch render
images = {
page_num: render_pdf_page(pdf_path, page_num, dpi=300)
for page_num in page_numbers_needed
}
# Reuse images
for detection in detections:
image = images[detection.page_no]
extract_field(detection, image)
```
**Impact**: Reduces redundant PDF rendering by 50-90% for multi-field invoices
---
## Database Query Optimization
### Current Performance
- **Parameterized queries**: ✅ Implemented (Phase 1)
- **Connection pooling**: ❌ Not implemented
- **Query batching**: ✅ Partially implemented
- **Index optimization**: ⚠️ Needs verification
### Recommendations
#### 1. Connection Pooling
**Current**: New connection for each operation
```python
def connect(self):
"""Create new database connection."""
return psycopg2.connect(**self.config)
```
**Optimized**: Use connection pooling
```python
from psycopg2 import pool
class DocumentDatabase:
def __init__(self, config):
self.pool = pool.SimpleConnectionPool(
minconn=1,
maxconn=10,
**config
)
def connect(self):
return self.pool.getconn()
def close(self, conn):
self.pool.putconn(conn)
```
**Impact**:
- Reduces connection overhead by 80-95%
- Especially important for high-frequency operations
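A usage sketch for the pooled variant above (`config` is the `DATABASE` dict from `config.py`; error handling elided):
```python
db = DocumentDatabase(config)
conn = db.connect()
try:
    with conn.cursor() as cursor:
        cursor.execute("SELECT count(*) FROM documents")
        print(cursor.fetchone())
finally:
    db.close(conn)  # return the connection to the pool instead of closing it
```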
#### 2. Index Recommendations
**Check current indexes**:
```sql
-- Verify indexes exist on frequently queried columns
SELECT tablename, indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'public';
```
**Recommended indexes**:
```sql
-- If not already present
CREATE INDEX IF NOT EXISTS idx_documents_success
ON documents(success);
CREATE INDEX IF NOT EXISTS idx_documents_timestamp
ON documents(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id
ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_matched
ON field_results(matched);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name
ON field_results(field_name);
```
**Impact**:
- 10-100x faster queries for filtered/sorted results
- Critical for `get_failed_matches()` and `get_all_documents_summary()`
#### 3. Query Batching
**Status**: ✅ Already implemented for field results (line 519)
**Verify batching is used**:
```python
# Good pattern in db.py
execute_values(cursor, "INSERT INTO field_results (...) VALUES %s", field_values)
```
**Additional opportunity**: Batch `SELECT` queries
```python
# Current
docs = [get_document(doc_id) for doc_id in doc_ids] # N queries
# Optimized
docs = get_documents_batch(doc_ids) # 1 query with IN clause
```
**Status**: ✅ Already implemented (`get_documents_batch` exists in db.py)
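If a batched lookup ever needs to be (re)implemented, the usual psycopg2 pattern is a single query with an `ANY(%s)` clause (sketch only; column names are illustrative and the existing `get_documents_batch` may differ):
```python
def get_documents_batch(cursor, doc_ids: list[str]) -> list[tuple]:
    if not doc_ids:
        return []
    cursor.execute(
        "SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
        (doc_ids,),  # psycopg2 adapts a Python list to a PostgreSQL array
    )
    return cursor.fetchall()
```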
---
## Caching Strategies
### 1. Model Loading Cache
**Current**: Models loaded per-instance
**Recommendation**: Singleton pattern for YOLO model
```python
class YOLODetectorSingleton:
_instance = None
_model = None
@classmethod
def get_instance(cls, model_path):
if cls._instance is None:
cls._instance = YOLODetector(model_path)
return cls._instance
```
**Impact**: Reduces memory usage by 90% when processing multiple documents
### 2. Parser Instance Caching
**Current**: ✅ Already optimal
```python
# Good pattern in field_extractor.py
def __init__(self):
self.payment_line_parser = PaymentLineParser() # Reused
self.customer_number_parser = CustomerNumberParser() # Reused
```
**Status**: No changes needed
### 3. OCR Result Caching
**Recommendation**: Cache OCR results for identical regions
```python
from functools import lru_cache

# Illustrative sketch: image_registry is a hypothetical dict mapping an image
# hash to its rendered image, populated by the caller before OCR.
image_registry = {}

@lru_cache(maxsize=1000)
def ocr_region_cached(image_hash, bbox):
    """Cache OCR results by image hash + bbox (bbox must be a hashable tuple)."""
    image = image_registry[image_hash]  # look the image up by its hash
    return paddle_ocr.ocr_region(image, bbox)
```
**Impact**: 50-80% speedup when re-processing similar documents
**Note**: Requires implementing image hashing (e.g., `hashlib.md5(image.tobytes())`)
---
## Memory Management
### Current Issues
**Potential memory leaks**:
1. Large images kept in memory after processing
2. OCR results accumulated without cleanup
3. Model outputs not explicitly cleared
### Recommendations
#### 1. Explicit Image Cleanup
```python
import gc
def process_pdf(pdf_path):
try:
image = render_pdf(pdf_path)
result = extract_fields(image)
return result
finally:
del image # Explicit cleanup
gc.collect() # Force garbage collection
```
#### 2. Generator Pattern for Large Batches
**Current**: Load all documents into memory
```python
docs = [process_pdf(path) for path in pdf_paths] # All in memory
```
**Optimized**: Use generator for streaming processing
```python
def process_batch_streaming(pdf_paths):
"""Process documents one at a time, yielding results."""
for path in pdf_paths:
result = process_pdf(path)
yield result
# Result can be saved to DB immediately
# Previous result is garbage collected
```
**Impact**: Constant memory usage regardless of batch size
#### 3. Context Managers for Resources
```python
class InferencePipeline:
def __enter__(self):
self.detector.load_model()
return self
def __exit__(self, *args):
self.detector.unload_model()
self.extractor.cleanup()
# Usage
with InferencePipeline(...) as pipeline:
results = pipeline.process_pdf(path)
# Automatic cleanup
```
---
## Profiling and Monitoring
### Recommended Profiling Tools
#### 1. cProfile for CPU Profiling
```python
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
# Your code here
pipeline.process_pdf(pdf_path)
profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20) # Top 20 slowest functions
```
#### 2. memory_profiler for Memory Analysis
```bash
pip install memory_profiler
python -m memory_profiler your_script.py
```
Or decorator-based:
```python
from memory_profiler import profile
@profile
def process_large_batch(pdf_paths):
# Memory usage tracked line-by-line
results = [process_pdf(path) for path in pdf_paths]
return results
```
#### 3. py-spy for Production Profiling
```bash
pip install py-spy
# Profile running process
py-spy top --pid 12345
# Generate flamegraph
py-spy record -o profile.svg -- python your_script.py
```
**Advantage**: No code changes needed, minimal overhead
### Key Metrics to Monitor
1. **Processing Time per Document**
- Target: <10 seconds for single-page invoice
- Current: ~2-5 seconds (estimated)
2. **Memory Usage**
- Target: <2GB for batch of 100 documents
- Monitor: Peak memory usage
3. **Database Query Time**
- Target: <100ms per query (with indexes)
- Monitor: Slow query log
4. **OCR Accuracy vs Speed Trade-off**
- Current: PaddleOCR with GPU (~200ms per region)
- Alternative: Tesseract (~500ms, slightly more accurate)
### Logging Performance Metrics
**Add to pipeline.py**:
```python
import time
import logging
logger = logging.getLogger(__name__)
def process_pdf(self, pdf_path):
start = time.time()
# Processing...
result = self._process_internal(pdf_path)
elapsed = time.time() - start
logger.info(f"Processed {pdf_path} in {elapsed:.2f}s")
# Log to database for analysis
self.db.log_performance({
'document_id': result.document_id,
'processing_time': elapsed,
'field_count': len(result.fields)
})
return result
```
---
## Performance Optimization Priorities
### High Priority (Implement First)
1. **Database parameterized queries** - Already done (Phase 1)
2. **Database connection pooling** - Not implemented
3. **Index optimization** - Needs verification
### Medium Priority
4. **Batch PDF rendering** - Optimization possible
5. **Parser instance reuse** - Already done (Phase 2)
6. **Model caching** - Could improve
### Low Priority (Nice to Have)
7. **OCR result caching** - Complex implementation
8. **Generator patterns** - Refactoring needed
9. **Advanced profiling** - For production optimization
---
## Benchmarking Script
```python
"""
Benchmark script for invoice processing performance.
"""
import time
from pathlib import Path
from src.inference.pipeline import InferencePipeline
def benchmark_single_document(pdf_path, iterations=10):
"""Benchmark single document processing."""
pipeline = InferencePipeline(
model_path="path/to/model.pt",
use_gpu=True
)
times = []
for i in range(iterations):
start = time.time()
result = pipeline.process_pdf(pdf_path)
elapsed = time.time() - start
times.append(elapsed)
print(f"Iteration {i+1}: {elapsed:.2f}s")
avg_time = sum(times) / len(times)
print(f"\nAverage: {avg_time:.2f}s")
print(f"Min: {min(times):.2f}s")
print(f"Max: {max(times):.2f}s")
def benchmark_batch(pdf_paths, batch_size=10):
"""Benchmark batch processing."""
from multiprocessing import Pool
pipeline = InferencePipeline(
model_path="path/to/model.pt",
use_gpu=True
)
start = time.time()
with Pool(processes=batch_size) as pool:
results = pool.map(pipeline.process_pdf, pdf_paths)
elapsed = time.time() - start
avg_per_doc = elapsed / len(pdf_paths)
print(f"Total time: {elapsed:.2f}s")
print(f"Documents: {len(pdf_paths)}")
print(f"Average per document: {avg_per_doc:.2f}s")
print(f"Throughput: {len(pdf_paths)/elapsed:.2f} docs/sec")
if __name__ == "__main__":
# Single document benchmark
benchmark_single_document("test.pdf")
# Batch benchmark
pdf_paths = list(Path("data/test_pdfs").glob("*.pdf"))
benchmark_batch(pdf_paths[:100])
```
---
## Summary
**Implemented (Phase 1-2)**:
- Parameterized queries (SQL injection fix)
- Parser instance reuse (Phase 2 refactoring)
- Batch insert operations (execute_values)
- Dual pool processing (CPU/GPU separation)
**Quick Wins (Low effort, high impact)**:
- Database connection pooling (2-4 hours)
- Index verification and optimization (1-2 hours)
- Batch PDF rendering (4-6 hours)
**Long-term Improvements**:
- OCR result caching with hashing
- Generator patterns for streaming
- Advanced profiling and monitoring
**Expected Impact**:
- Connection pooling: 80-95% reduction in DB overhead
- Indexes: 10-100x faster queries
- Batch rendering: 50-90% less redundant work
- **Overall**: 2-5x throughput improvement for batch processing

docs/REFACTORING_PLAN.md Normal file
File diff suppressed because it is too large

docs/REFACTORING_SUMMARY.md Normal file
@@ -0,0 +1,170 @@
# 代码重构总结报告
## 📊 整体成果
### 测试状态
-**688/688 测试全部通过** (100%)
-**代码覆盖率**: 34% → 37% (+3%)
-**0 个失败**, 0 个错误
### 测试覆盖率改进
-**machine_code_parser**: 25% → 65% (+40%)
-**新增测试**: 55个633 → 688
---
## 🎯 已完成的重构
### 1. ✅ Matcher 模块化 (876行 → 205行, ↓76%)
**文件**:
**重构内容**:
- 将单一876行文件拆分为 **11个模块**
- 提取 **5种独立的匹配策略**
- 创建专门的数据模型、工具函数和上下文处理模块
**新模块结构**:
**测试结果**:
- ✅ 77个 matcher 测试全部通过
- ✅ 完整的README文档
- ✅ 策略模式,易于扩展
**收益**:
- 📉 代码量减少 76%
- 📈 可维护性显著提高
- ✨ 每个策略独立测试
- 🔧 易于添加新策略
---
### 2. ✅ Machine Code Parser 轻度重构 + 测试覆盖 (919行 → 929行)
**文件**: src/ocr/machine_code_parser.py
**重构内容**:
- 提取 **3个共享辅助方法**,消除重复代码
- 优化上下文检测逻辑
- 简化账号格式化方法
**测试改进**:
-**新增55个测试**24 → 79个
-**覆盖率**: 25% → 65% (+40%)
- ✅ 所有688个项目测试通过
**新增测试覆盖**:
- **第一轮** (22个测试):
- `_detect_account_context()` - 8个测试上下文检测
- `_normalize_account_spaces()` - 5个测试空格规范化
- `_format_account()` - 4个测试账号格式化
- `parse()` - 5个测试主入口方法
- **第二轮** (33个测试):
- `_extract_ocr()` - 8个测试OCR 提取)
- `_extract_bankgiro()` - 9个测试Bankgiro 提取)
- `_extract_plusgiro()` - 8个测试Plusgiro 提取)
- `_extract_amount()` - 8个测试金额提取
**收益**:
- 🔄 消除80行重复代码
- 📈 可测试性提高(可独立测试辅助方法)
- 📖 代码可读性提升
- ✅ 覆盖率从25%提升到65% (+40%)
- 🎯 低风险,高回报
---
### 3. ✅ Field Extractor 分析 (决定不重构)
**文件**: (1183行)
**分析结果**: ❌ **不应重构**
**关键洞察**:
- 表面相似的代码可能有**完全不同的用途**
- field_extractor: **解析/提取** 字段值
- src/normalize: **标准化/生成变体** 用于匹配
- 两者职责不同,不应统一
**文档**:
---
## 📈 重构统计
### 代码行数变化
| 文件 | 重构前 | 重构后 | 变化 | 百分比 |
|------|--------|--------|------|--------|
| **matcher/field_matcher.py** | 876行 | 205行 | -671 | ↓76% |
| **matcher/* (新增10个模块)** | 0行 | 466行 | +466 | 新增 |
| **matcher 总计** | 876行 | 671行 | -205 | ↓23% |
| **ocr/machine_code_parser.py** | 919行 | 929行 | +10 | +1% |
| **总净减少** | - | - | **-195行** | **↓11%** |
### 测试覆盖
| 模块 | 测试数 | 通过率 | 覆盖率 | 状态 |
|------|--------|--------|--------|------|
| matcher | 77 | 100% | - | ✅ |
| field_extractor | 45 | 100% | 39% | ✅ |
| machine_code_parser | 79 | 100% | 65% | ✅ |
| normalizer | ~120 | 100% | - | ✅ |
| 其他模块 | ~367 | 100% | - | ✅ |
| **总计** | **688** | **100%** | **37%** | ✅ |
---
## 🎓 重构经验总结
### 成功经验
1. **✅ 先测试后重构**
- 所有重构都有完整测试覆盖
- 每次改动后立即验证测试
- 100%测试通过率保证质量
2. **✅ 识别真正的重复**
- 不是所有相似代码都是重复
- field_extractor vs normalizer: 表面相似但用途不同
- machine_code_parser: 真正的代码重复
3. **✅ 渐进式重构**
- matcher: 大规模模块化 (策略模式)
- machine_code_parser: 轻度重构 (提取共享方法)
- field_extractor: 分析后决定不重构
### 关键决策
#### ✅ 应该重构的情况
- **matcher**: 单一文件过长 (876行),包含多种策略
- **machine_code_parser**: 多处相同用途的重复代码
#### ❌ 不应重构的情况
- **field_extractor**: 相似代码有不同用途
### 教训
**不要盲目追求DRY原则**
> 相似代码不一定是重复。要理解代码的**真实用途**。
---
## ✅ 总结
**关键成果**:
- 📉 净减少 195 行代码
- 📈 代码覆盖率 +3% (34% → 37%)
- ✅ 测试数量 +55 (633 → 688)
- 🎯 machine_code_parser 覆盖率 +40% (25% → 65%)
- ✨ 模块化程度显著提高
- 🎯 可维护性大幅提升
**重要教训**:
> 相似的代码不一定是重复的代码。理解代码的真实用途,才能做出正确的重构决策。
**下一步建议**:
1. 继续提升 machine_code_parser 覆盖率到 80%+ (目前 65%)
2. 为其他低覆盖模块添加测试field_extractor 39%, pipeline 19%
3. 完善边界条件和异常情况的测试

docs/TEST_COVERAGE_IMPROVEMENT.md
@@ -0,0 +1,258 @@
# 测试覆盖率改进报告
## 📊 改进概览
### 整体统计
-**测试总数**: 633 → 688 (+55个测试, +8.7%)
-**通过率**: 100% (688/688)
-**整体覆盖率**: 34% → 37% (+3%)
### machine_code_parser.py 专项改进
-**测试数**: 24 → 79 (+55个测试, +229%)
-**覆盖率**: 25% → 65% (+40%)
-**未覆盖行**: 273 → 129 (减少144行)
---
## 🎯 新增测试详情
### 第一轮改进 (22个测试)
#### 1. TestDetectAccountContext (8个测试)
测试新增的 `_detect_account_context()` 辅助方法。
**测试用例**:
1. `test_bankgiro_keyword` - 检测 'bankgiro' 关键词
2. `test_bg_keyword` - 检测 'bg:' 缩写
3. `test_plusgiro_keyword` - 检测 'plusgiro' 关键词
4. `test_postgiro_keyword` - 检测 'postgiro' 别名
5. `test_pg_keyword` - 检测 'pg:' 缩写
6. `test_both_contexts` - 同时存在两种关键词
7. `test_no_context` - 无账号关键词
8. `test_case_insensitive` - 大小写不敏感检测
**覆盖的代码路径**:
```python
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
```
---
### 2. TestNormalizeAccountSpacesMethod (5个测试)
测试新增的 `_normalize_account_spaces()` 辅助方法。
**测试用例**:
1. `test_removes_spaces_after_arrow` - 移除 > 后的空格
2. `test_multiple_consecutive_spaces` - 处理多个连续空格
3. `test_no_arrow_returns_unchanged` - 无 > 标记时返回原值
4. `test_spaces_before_arrow_preserved` - 保留 > 前的空格
5. `test_empty_string` - 空字符串处理
**覆盖的代码路径**:
```python
def _normalize_account_spaces(self, line: str) -> str:
if '>' not in line:
return line
parts = line.split('>', 1)
after_arrow = parts[1]
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
while re.search(r'(\d)\s+(\d)', normalized):
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
return parts[0] + '>' + normalized
```
---
### 3. TestFormatAccount (4个测试)
测试新增的 `_format_account()` 辅助方法。
**测试用例**:
1. `test_plusgiro_context_forces_plusgiro` - Plusgiro 上下文强制格式化为 Plusgiro
2. `test_valid_bankgiro_7_digits` - 7位有效 Bankgiro 格式化
3. `test_valid_bankgiro_8_digits` - 8位有效 Bankgiro 格式化
4. `test_defaults_to_bankgiro_when_ambiguous` - 模糊情况默认 Bankgiro
**覆盖的代码路径**:
```python
def _format_account(self, account_digits: str, is_plusgiro_context: bool) -> tuple[str, str]:
if is_plusgiro_context:
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# Luhn 验证逻辑
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# 决策逻辑
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
return bg_formatted, 'bankgiro'
```
---
### 4. TestParseMethod (5个测试)
测试主入口 `parse()` 方法。
**测试用例**:
1. `test_parse_empty_tokens` - 空 token 列表处理
2. `test_parse_finds_payment_line_in_bottom_region` - 在页面底部35%区域查找付款行
3. `test_parse_ignores_top_region` - 忽略页面顶部区域
4. `test_parse_with_context_keywords` - 检测上下文关键词
5. `test_parse_stores_source_tokens` - 存储源 token
**覆盖的代码路径**:
- Token 过滤(底部区域检测)
- 上下文关键词检测
- 付款行查找和解析
- 结果对象构建
---
### 第二轮改进 (33个测试)
#### 5. TestExtractOCR (8个测试)
测试 `_extract_ocr()` 方法 - OCR 参考号码提取。
**测试用例**:
1. `test_extract_valid_ocr_10_digits` - 提取10位 OCR 号码
2. `test_extract_valid_ocr_15_digits` - 提取15位 OCR 号码
3. `test_extract_ocr_with_hash_markers` - 带 # 标记的 OCR
4. `test_extract_longest_ocr_when_multiple` - 多个候选时选最长
5. `test_extract_ocr_ignores_short_numbers` - 忽略短于10位的数字
6. `test_extract_ocr_ignores_long_numbers` - 忽略长于25位的数字
7. `test_extract_ocr_excludes_bankgiro_variants` - 排除 Bankgiro 变体
8. `test_extract_ocr_empty_tokens` - 空 token 处理
#### 6. TestExtractBankgiro (9个测试)
测试 `_extract_bankgiro()` 方法 - Bankgiro 账号提取。
**测试用例**:
1. `test_extract_bankgiro_7_digits_with_dash` - 带破折号的7位 Bankgiro
2. `test_extract_bankgiro_7_digits_without_dash` - 无破折号的7位 Bankgiro
3. `test_extract_bankgiro_8_digits_with_dash` - 带破折号的8位 Bankgiro
4. `test_extract_bankgiro_8_digits_without_dash` - 无破折号的8位 Bankgiro
5. `test_extract_bankgiro_with_spaces` - 带空格的 Bankgiro
6. `test_extract_bankgiro_handles_plusgiro_format` - 处理 Plusgiro 格式
7. `test_extract_bankgiro_with_context` - 带上下文关键词
8. `test_extract_bankgiro_ignores_plusgiro_context` - 忽略 Plusgiro 上下文
9. `test_extract_bankgiro_empty_tokens` - 空 token 处理
#### 7. TestExtractPlusgiro (8个测试)
测试 `_extract_plusgiro()` 方法 - Plusgiro 账号提取。
**测试用例**:
1. `test_extract_plusgiro_7_digits_with_dash` - 带破折号的7位 Plusgiro
2. `test_extract_plusgiro_7_digits_without_dash` - 无破折号的7位 Plusgiro
3. `test_extract_plusgiro_8_digits` - 8位 Plusgiro
4. `test_extract_plusgiro_with_spaces` - 带空格的 Plusgiro
5. `test_extract_plusgiro_with_context` - 带上下文关键词
6. `test_extract_plusgiro_ignores_too_short` - 忽略少于7位
7. `test_extract_plusgiro_ignores_too_long` - 忽略多于8位
8. `test_extract_plusgiro_empty_tokens` - 空 token 处理
#### 8. TestExtractAmount (8个测试)
测试 `_extract_amount()` 方法 - 金额提取。
**测试用例**:
1. `test_extract_amount_with_comma_decimal` - 逗号小数分隔符
2. `test_extract_amount_with_dot_decimal` - 点号小数分隔符
3. `test_extract_amount_integer` - 整数金额
4. `test_extract_amount_with_thousand_separator` - 千位分隔符
5. `test_extract_amount_large_number` - 大额金额
6. `test_extract_amount_ignores_too_large` - 忽略过大金额
7. `test_extract_amount_ignores_zero` - 忽略零或负数
8. `test_extract_amount_empty_tokens` - 空 token 处理
---
## 📈 覆盖率分析
### 已覆盖的方法
`_detect_account_context()` - **100%** (第一轮新增)
`_normalize_account_spaces()` - **100%** (第一轮新增)
`_format_account()` - **95%** (第一轮新增)
`parse()` - **70%** (第一轮改进)
`_parse_standard_payment_line()` - **95%** (已有测试)
`_extract_ocr()` - **85%** (第二轮新增)
`_extract_bankgiro()` - **90%** (第二轮新增)
`_extract_plusgiro()` - **90%** (第二轮新增)
`_extract_amount()` - **80%** (第二轮新增)
### 仍需改进的方法 (未覆盖/部分覆盖)
⚠️ `_calculate_confidence()` - **0%** (未测试)
⚠️ `cross_validate()` - **0%** (未测试)
⚠️ `get_region_bbox()` - **0%** (未测试)
⚠️ `_find_tokens_with_values()` - **部分覆盖**
⚠️ `_find_machine_code_line_tokens()` - **部分覆盖**
### 未覆盖的代码行129行
主要集中在:
1. **验证方法** (lines 805-824): `_calculate_confidence`, `cross_validate`
2. **辅助方法** (lines 80-92, 336-369, 377-407): Token 查找、bbox 计算、日志记录
3. **边界条件** (lines 648-653, 690, 699, 759-760等): 某些提取方法的边界情况
---
## 🎯 改进建议
### ✅ 已完成目标
- ✅ 覆盖率从 25% 提升到 65% (+40%)
- ✅ 测试数量从 24 增加到 79 (+55个)
- ✅ 提取方法全部测试_extract_ocr, _extract_bankgiro, _extract_plusgiro, _extract_amount
### 下一步目标(覆盖率 65% → 80%+
1. **添加验证方法测试** - 为 `_calculate_confidence`, `cross_validate` 添加测试
2. **添加辅助方法测试** - 为 token 查找和 bbox 计算方法添加测试
3. **完善边界条件** - 增加边界情况和异常处理的测试
4. **集成测试** - 添加端到端的集成测试,使用真实 PDF token 数据
---
## ✅ 已完成的改进
### 重构收益
- ✅ 提取的3个辅助方法现在可以独立测试
- ✅ 测试粒度更细,更容易定位问题
- ✅ 代码可读性提高,测试用例清晰易懂
### 质量保证
- ✅ 所有655个测试100%通过
- ✅ 无回归问题
- ✅ 新增测试覆盖了之前未测试的重构代码
---
## 📚 测试编写经验
### 成功经验
1. **使用 fixture 创建测试数据** - `_create_token()` 辅助方法简化了 token 创建
2. **按方法组织测试类** - 每个方法一个测试类,结构清晰
3. **测试用例命名清晰** - `test_<what>_<condition>` 格式,一目了然
4. **覆盖关键路径** - 优先测试常见场景和边界条件
### 遇到的问题
1. **Token 初始化参数** - 忘记了 `page_no` 参数,导致初始测试失败
- 解决:修复 `_create_token()` 辅助方法,添加 `page_no=0`(示意见下)
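修复后的辅助方法大致形如下(`TextToken` 的构造参数仅为假设,请以 `src/ocr` 中的实际定义为准;要点是为 `page_no` 提供默认值):
```python
# 示意:TextToken 的字段与构造方式为假设,请以项目实际定义为准
def _create_token(self, text: str, bbox=(0.0, 0.0, 10.0, 10.0), page_no: int = 0):
    return TextToken(text=text, bbox=bbox, page_no=page_no)
```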
---
**报告日期**: 2026-01-24
**状态**: ✅ 完成
**下一步**: 继续提升覆盖率到 60%+

requirements.txt
@@ -20,3 +20,4 @@ pyyaml>=6.0 # YAML config files
# Utilities # Utilities
tqdm>=4.65.0 # Progress bars tqdm>=4.65.0 # Progress bars
python-dotenv>=1.0.0 # Environment variable management

src/data/db.py
@@ -239,13 +239,16 @@ class DocumentDB:
fields_matched, fields_total fields_matched, fields_total
FROM documents FROM documents
""" """
params = []
if success_only: if success_only:
query += " WHERE success = true" query += " WHERE success = true"
query += " ORDER BY timestamp DESC" query += " ORDER BY timestamp DESC"
if limit: if limit:
query += f" LIMIT {limit}" # Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query) cursor.execute(query, params if params else None)
return [ return [
{ {
'document_id': row[0], 'document_id': row[0],
@@ -291,7 +294,9 @@ class DocumentDB:
if field_name: if field_name:
query += " AND fr.field_name = %s" query += " AND fr.field_name = %s"
params.append(field_name) params.append(field_name)
query += f" LIMIT {limit}" # Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query, params) cursor.execute(query, params)
return [ return [

src/exceptions.py Normal file
@@ -0,0 +1,102 @@
"""
Application-specific exceptions for invoice extraction system.
This module defines a hierarchy of custom exceptions to provide better
error handling and debugging capabilities throughout the application.
"""
class InvoiceExtractionError(Exception):
"""Base exception for all invoice extraction errors."""
def __init__(self, message: str, details: dict = None):
"""
Initialize exception with message and optional details.
Args:
message: Human-readable error message
details: Optional dict with additional error context
"""
super().__init__(message)
self.message = message
self.details = details or {}
def __str__(self):
if self.details:
details_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
return f"{self.message} ({details_str})"
return self.message
class PDFProcessingError(InvoiceExtractionError):
"""Error during PDF processing (rendering, conversion)."""
pass
class OCRError(InvoiceExtractionError):
"""Error during OCR processing."""
pass
class ModelInferenceError(InvoiceExtractionError):
"""Error during YOLO model inference."""
pass
class FieldValidationError(InvoiceExtractionError):
"""Error during field validation or normalization."""
def __init__(self, field_name: str, value: str, reason: str, details: dict = None):
"""
Initialize field validation error.
Args:
field_name: Name of the field that failed validation
value: The invalid value
reason: Why validation failed
details: Additional context
"""
message = f"Field '{field_name}' validation failed: {reason}"
super().__init__(message, details)
self.field_name = field_name
self.value = value
self.reason = reason
class DatabaseError(InvoiceExtractionError):
"""Error during database operations."""
pass
class ConfigurationError(InvoiceExtractionError):
"""Error in application configuration."""
pass
class PaymentLineParseError(InvoiceExtractionError):
"""Error parsing Swedish payment line format."""
pass
class CustomerNumberParseError(InvoiceExtractionError):
"""Error parsing Swedish customer number."""
pass
class DataLoadError(InvoiceExtractionError):
"""Error loading data from CSV or other sources."""
pass
class AnnotationError(InvoiceExtractionError):
"""Error generating or processing YOLO annotations."""
pass

src/inference/constants.py Normal file
@@ -0,0 +1,101 @@
"""
Inference Configuration Constants
Centralized configuration values for the inference pipeline.
Extracted from hardcoded values across multiple modules for easier maintenance.
"""
# ============================================================================
# Detection & Model Configuration
# ============================================================================
# YOLO Detection
DEFAULT_CONFIDENCE_THRESHOLD = 0.5 # Default confidence threshold for YOLO detection
DEFAULT_IOU_THRESHOLD = 0.45 # Default IoU threshold for NMS (Non-Maximum Suppression)
# ============================================================================
# Image Processing Configuration
# ============================================================================
# DPI (Dots Per Inch) for PDF rendering
DEFAULT_DPI = 300 # Standard DPI for PDF to image conversion
DPI_TO_POINTS_SCALE = 72 # PDF points per inch (used for bbox conversion)
# ============================================================================
# Customer Number Parser Configuration
# ============================================================================
# Pattern confidence scores (higher = more confident)
CUSTOMER_NUMBER_CONFIDENCE = {
'labeled': 0.98, # Explicit label (e.g., "Kundnummer: ABC 123-X")
'dash_format': 0.95, # Standard format with dash (e.g., "JTY 576-3")
'no_dash': 0.90, # Format without dash (e.g., "Dwq 211X")
'compact': 0.75, # Compact format (e.g., "JTY5763")
'generic_base': 0.5, # Base score for generic alphanumeric pattern
}
# Bonus scores for generic pattern matching
CUSTOMER_NUMBER_BONUS = {
'has_dash': 0.2, # Bonus if contains dash
'typical_format': 0.25, # Bonus for format XXX NNN-X
'medium_length': 0.1, # Bonus for length 6-12 characters
}
# Customer number length constraints
CUSTOMER_NUMBER_LENGTH = {
'min': 6, # Minimum length for medium length bonus
'max': 12, # Maximum length for medium length bonus
}
# ============================================================================
# Field Extraction Confidence Scores
# ============================================================================
# Confidence multipliers and base scores
FIELD_CONFIDENCE = {
'pdf_text': 1.0, # PDF text extraction (always accurate)
'payment_line_high': 0.95, # Payment line parsed successfully
'regex_fallback': 0.5, # Regex-based fallback extraction
'ocr_penalty': 0.5, # Penalty multiplier when OCR fails
}
# ============================================================================
# Payment Line Validation
# ============================================================================
# Account number length thresholds for type detection
ACCOUNT_TYPE_THRESHOLD = {
'bankgiro_min_length': 7, # Minimum digits for Bankgiro (7-8 digits)
'plusgiro_max_length': 6, # Maximum digits for Plusgiro (typically fewer)
}
# ============================================================================
# OCR Configuration
# ============================================================================
# Minimum OCR reference number length
MIN_OCR_LENGTH = 5 # Minimum length for valid OCR number
# ============================================================================
# Pattern Matching
# ============================================================================
# Swedish postal code pattern (to exclude from customer numbers)
SWEDISH_POSTAL_CODE_PATTERN = r'^SE\s+\d{3}\s*\d{2}'
# ============================================================================
# Usage Notes
# ============================================================================
"""
These constants can be overridden at runtime by passing parameters to
constructors or methods. The values here serve as sensible defaults
based on Swedish invoice processing requirements.
Example:
from src.inference.constants import DEFAULT_CONFIDENCE_THRESHOLD
detector = YOLODetector(
model_path="model.pt",
confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD # or custom value
)
"""

390
src/inference/customer_number_parser.py Normal file
View File

@@ -0,0 +1,390 @@
"""
Swedish Customer Number Parser
Handles extraction and normalization of Swedish customer numbers.
Uses Strategy Pattern with multiple matching patterns.
Common Swedish customer number formats:
- JTY 576-3
- EMM 256-6
- DWQ 211-X
- FFL 019N
"""
import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
from src.exceptions import CustomerNumberParseError
@dataclass
class CustomerNumberMatch:
"""Customer number match result."""
value: str
"""The normalized customer number"""
pattern_name: str
"""Name of the pattern that matched"""
confidence: float
"""Confidence score (0.0 to 1.0)"""
raw_text: str
"""Original text that was matched"""
position: int = 0
"""Position in text where match was found"""
class CustomerNumberPattern(ABC):
"""Abstract base for customer number patterns."""
@abstractmethod
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""
Try to match pattern in text.
Args:
text: Text to search for customer number
Returns:
CustomerNumberMatch if found, None otherwise
"""
pass
@abstractmethod
def format(self, match: re.Match) -> str:
"""
Format matched groups to standard format.
Args:
match: Regex match object
Returns:
Formatted customer number string
"""
pass
class DashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123-X (with dash)
Examples: JTY 576-3, EMM 256-6, DWQ 211-X
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with dash format."""
match = self.PATTERN.search(text)
if not match:
return None
# Check if it's not a postal code
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="DashFormat",
confidence=0.95,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
# SE 106 43, SE10643, etc.
return bool(
text.upper().startswith('SE ') and
re.match(r'^SE\s+\d{3}\s*\d{2}', text, re.IGNORECASE)
)
class NoDashFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC 123X (no dash)
Examples: Dwq 211X, FFL 019N
Converts to: DWQ 211-X, FFL 019-N
"""
PATTERN = re.compile(r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number without dash."""
match = self.PATTERN.search(text)
if not match:
return None
# Exclude postal codes
full_match = match.group(0)
if self._is_postal_code(full_match):
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="NoDashFormat",
confidence=0.90,
raw_text=full_match,
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to standard ABC 123-X format (add dash)."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
return f"{prefix} {number}-{suffix}"
def _is_postal_code(self, text: str) -> bool:
"""Check if text looks like Swedish postal code."""
return bool(re.match(r'^SE\s*\d{3}\s*\d{2}', text, re.IGNORECASE))
class CompactFormatPattern(CustomerNumberPattern):
"""
Pattern: ABC123X (compact, no spaces)
Examples: JTY5763, FFL019N
"""
PATTERN = re.compile(r'\b([A-Z]{2,4})(\d{3,6})([A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match compact customer number format."""
upper_text = text.upper()
match = self.PATTERN.search(upper_text)
if not match:
return None
# Filter out SE postal codes
if match.group(1) == 'SE':
return None
formatted = self.format(match)
return CustomerNumberMatch(
value=formatted,
pattern_name="CompactFormat",
confidence=0.75,
raw_text=match.group(0),
position=match.start()
)
def format(self, match: re.Match) -> str:
"""Format to ABC123X or ABC123-X format."""
prefix = match.group(1).upper()
number = match.group(2)
suffix = match.group(3).upper()
if suffix:
return f"{prefix} {number}-{suffix}"
else:
return f"{prefix}{number}"
class GenericAlphanumericPattern(CustomerNumberPattern):
"""
Generic pattern: Letters + numbers + optional dash/letter
Examples: EMM 256-6, ABC 123, FFL 019
"""
PATTERN = re.compile(r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]?\d{0,2}[A-Z]?)\b')
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match generic alphanumeric pattern."""
upper_text = text.upper()
all_matches = []
for match in self.PATTERN.finditer(upper_text):
matched_text = match.group(1)
# Filter out pure numbers
if re.match(r'^\d+$', matched_text):
continue
# Filter out Swedish postal codes
if re.match(r'^SE[\s\-]*\d', matched_text):
continue
# Filter out single letter + digit + space + digit (V4 2)
if re.match(r'^[A-Z]\d\s+\d$', matched_text):
continue
# Calculate confidence based on characteristics
confidence = self._calculate_confidence(matched_text)
all_matches.append((confidence, matched_text, match.start()))
if all_matches:
# Return highest confidence match
best = max(all_matches, key=lambda x: x[0])
return CustomerNumberMatch(
value=best[1].strip(),
pattern_name="GenericAlphanumeric",
confidence=best[0],
raw_text=best[1],
position=best[2]
)
return None
def format(self, match: re.Match) -> str:
"""Return matched text as-is (already uppercase)."""
return match.group(1).strip()
def _calculate_confidence(self, text: str) -> float:
"""Calculate confidence score based on text characteristics."""
# Require letters AND digits
has_letters = bool(re.search(r'[A-Z]', text, re.IGNORECASE))
has_digits = bool(re.search(r'\d', text))
if not (has_letters and has_digits):
return 0.0 # Not a valid customer number
score = 0.5 # Base score
# Bonus for containing dash
if '-' in text:
score += 0.2
# Bonus for typical format XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', text):
score += 0.25
# Bonus for medium length
if 6 <= len(text) <= 12:
score += 0.1
return min(score, 1.0)
class LabeledPattern(CustomerNumberPattern):
"""
Pattern: Explicit label + customer number
Examples:
- "Kundnummer: JTY 576-3"
- "Customer No: EMM 256-6"
"""
PATTERN = re.compile(
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)'
r'\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
re.IGNORECASE
)
def match(self, text: str) -> Optional[CustomerNumberMatch]:
"""Match customer number with explicit label."""
match = self.PATTERN.search(text)
if not match:
return None
extracted = match.group(1).strip()
# Remove trailing punctuation
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
if extracted and len(extracted) >= 2:
return CustomerNumberMatch(
value=extracted.upper(),
pattern_name="Labeled",
confidence=0.98, # Very high confidence when labeled
raw_text=match.group(0),
position=match.start()
)
return None
def format(self, match: re.Match) -> str:
"""Return matched customer number."""
extracted = match.group(1).strip()
return re.sub(r'[\s\.\,\:]+$', '', extracted).upper()
class CustomerNumberParser:
"""Parser for Swedish customer numbers."""
def __init__(self):
"""Initialize parser with patterns ordered by specificity."""
self.patterns: List[CustomerNumberPattern] = [
LabeledPattern(), # Highest priority - explicit label
DashFormatPattern(), # Standard format with dash
NoDashFormatPattern(), # Standard format without dash
CompactFormatPattern(), # Compact format
GenericAlphanumericPattern(), # Fallback generic pattern
]
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> tuple[Optional[str], bool, Optional[str]]:
"""
Parse customer number from text.
Args:
text: Text to search for customer number
Returns:
Tuple of (customer_number, is_valid, error_message)
"""
if not text or not text.strip():
return None, False, "Empty text"
text = text.strip()
# Try each pattern
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# No matches
if not all_matches:
return None, False, "No customer number found"
# Return highest confidence match
best_match = max(all_matches, key=lambda m: (m.confidence, m.position))
self.logger.debug(
f"Customer number matched: {best_match.value} "
f"(pattern: {best_match.pattern_name}, confidence: {best_match.confidence:.2f})"
)
return best_match.value, True, None
def parse_all(self, text: str) -> List[CustomerNumberMatch]:
"""
Find all customer numbers in text.
Useful for cases with multiple potential matches.
Args:
text: Text to search
Returns:
List of CustomerNumberMatch sorted by confidence (descending)
"""
if not text or not text.strip():
return []
all_matches: List[CustomerNumberMatch] = []
for pattern in self.patterns:
match = pattern.match(text)
if match:
all_matches.append(match)
# Sort by confidence (highest first), then by position (later first)
return sorted(all_matches, key=lambda m: (m.confidence, m.position), reverse=True)
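
A minimal usage sketch for the parser above; the expected outputs follow from the patterns shown, and the import path assumes the file lives under src/inference/.

```python
from src.inference.customer_number_parser import CustomerNumberParser  # assumed path

parser = CustomerNumberParser()

# Explicit label -> LabeledPattern wins (confidence 0.98)
value, is_valid, error = parser.parse("Kundnummer: JTY 576-3")
# expected: ("JTY 576-3", True, None)

# No dash -> NoDashFormatPattern normalizes to the dashed form
value, is_valid, error = parser.parse("Dwq 211X")
# expected: ("DWQ 211-X", True, None)

# Inspect every candidate with its pattern name and confidence
for m in parser.parse_all("Kundnummer: JTY 576-3"):
    print(m.pattern_name, m.value, m.confidence)
```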

src/inference/field_extractor.py
View File

@@ -29,6 +29,10 @@ from src.utils.validators import FieldValidators
from src.utils.fuzzy_matcher import FuzzyMatcher from src.utils.fuzzy_matcher import FuzzyMatcher
from src.utils.ocr_corrections import OCRCorrections from src.utils.ocr_corrections import OCRCorrections
# Import new unified parsers
from .payment_line_parser import PaymentLineParser
from .customer_number_parser import CustomerNumberParser
@dataclass @dataclass
class ExtractedField: class ExtractedField:
@@ -92,6 +96,10 @@ class FieldExtractor:
self.dpi = dpi self.dpi = dpi
self._ocr_engine = None # Lazy init self._ocr_engine = None # Lazy init
# Initialize new unified parsers
self.payment_line_parser = PaymentLineParser()
self.customer_number_parser = CustomerNumberParser()
@property @property
def ocr_engine(self): def ocr_engine(self):
"""Lazy-load OCR engine only when needed.""" """Lazy-load OCR engine only when needed."""
@@ -631,7 +639,7 @@ class FieldExtractor:
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]: def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
""" """
Normalize payment line region text. Normalize payment line region text using unified PaymentLineParser.
Extracts the machine-readable payment line format from OCR text. Extracts the machine-readable payment line format from OCR text.
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check># Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
@@ -640,69 +648,13 @@ class FieldExtractor:
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00 - "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00 - "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
Returns normalized format preserving ALL components including Amount: Returns normalized format preserving ALL components including Amount.
- Full format: "OCR:xxx Amount:xxx.xx BG:xxx" or "OCR:xxx Amount:xxx.xx PG:xxx" This allows downstream cross-validation to extract fields properly.
- This allows downstream cross-validation to extract fields properly.
""" """
# Pattern to match Swedish payment line format WITH amount # Use unified payment line parser
# Format: # <OCR number> # <Kronor> <Öre> <Type> > <account number>#<check digits># return self.payment_line_parser.format_for_field_extractor(
# Account number may have spaces: "78 2 1 713" -> "7821713" self.payment_line_parser.parse(text)
# Kronor may have OCR-induced spaces: "12 0 0" -> "1200" )
# The > symbol may be missing in low-DPI OCR, so make it optional
# Check digits may have spaces: "#41 #" -> "#41#"
payment_line_full_pattern = r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(payment_line_full_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
kronor = match.group(2).replace(' ', '') # Remove OCR-induced spaces
ore = match.group(3)
record_type = match.group(4)
account = match.group(5).replace(' ', '') # Remove spaces from account number
check_digits = match.group(6)
# Reconstruct the clean machine-readable format
# Format: # OCR # KRONOR ORE TYPE > ACCOUNT#CHECK#
result = f"# {ocr_part} # {kronor} {ore} {record_type} > {account}#{check_digits}#"
return result, True, None
# Try pattern WITHOUT amount (some payment lines don't have amount)
# Format: # <OCR number> # > <account number>#<check digits>#
# > may be missing in low-DPI OCR
# Check digits may have spaces
payment_line_no_amount_pattern = r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(payment_line_no_amount_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
account = match.group(2).replace(' ', '')
check_digits = match.group(3)
result = f"# {ocr_part} # > {account}#{check_digits}#"
return result, True, None
# Try alternative pattern: just look for the # > account# pattern (> optional)
# Check digits may have spaces
alt_pattern = r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(alt_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
account = match.group(2).replace(' ', '')
check_digits = match.group(3)
result = f"# {ocr_part} # > {account}#{check_digits}#"
return result, True, None
# Try to find just the account part with # markers
# Check digits may have spaces
account_pattern = r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(account_pattern, text)
if match:
account = match.group(1).replace(' ', '')
check_digits = match.group(2)
return f"> {account}#{check_digits}#", True, "Partial payment line (account only)"
# Fallback: return None if no payment line format found
return None, False, "No valid payment line format found"
def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]: def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
""" """
@@ -744,131 +696,15 @@ class FieldExtractor:
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]: def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
""" """
Normalize customer number extracted from OCR. Normalize customer number text using unified CustomerNumberParser.
Customer numbers can have various formats: Supports various Swedish customer number formats:
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R' - With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N' - Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01' - Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R' - Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
Note: Spaces and dashes may be removed from invoice display,
so we need to match both 'JTY 576-3' and 'JTY5763' formats.
""" """
if not text or not text.strip(): return self.customer_number_parser.parse(text)
return None, False, "Empty text"
# Keep original text for pattern matching (don't uppercase yet)
original_text = text.strip()
# Customer number patterns - ordered by specificity (most specific first)
# All patterns use IGNORECASE so they work regardless of case
customer_code_patterns = [
# Pattern: 2-4 letters + space + digits + dash + single letter/digit (UMJ 436-R, EMM 256-6)
# This is the most common Swedish customer number format
r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b',
# Pattern: 2-4 letters + space + digits + letter WITHOUT dash (Dwq 211X, ABC 123X)
# Note: This is also common for customer numbers
r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b',
# Pattern: Word (capitalized) + space + digits + dash + letter (Umj 436-R, Billo 123-A)
r'\b([A-Za-z][a-z]{1,10})\s+(\d{1,4})-([A-Za-z0-9])\b',
# Pattern: Letters + digits + dash + digit/letter without space (JTY576-3)
r'\b([A-Za-z]{2,4})(\d{1,4})-([A-Za-z0-9])\b',
]
# Try specific patterns first
for pattern in customer_code_patterns:
match = re.search(pattern, original_text)
if match:
# Skip if it looks like a Swedish postal code (SE + digits)
full_match = match.group(0)
if full_match.upper().startswith('SE ') and re.match(r'^SE\s+\d{3}\s*\d{2}', full_match, re.IGNORECASE):
continue
# Reconstruct the customer number in standard format
groups = match.groups()
if len(groups) == 3:
# Format: XXX NNN-X (add dash if not present, e.g., "Dwq 211X" -> "DWQ 211-X")
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
return result, True, None
# Generic patterns for other formats
generic_patterns = [
# Pattern: Letters + space/dash + digits + dash + digit (EMM 256-6, JTY 576-3)
r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]\d{1,2}[A-Z]?)\b',
# Pattern: Letters + space/dash + digits + optional letter (FFL 019N, ABC 123X)
r'\b([A-Z]{2,4}[\s\-]\d{2,4}[A-Z]?)\b',
# Pattern: Compact format - letters immediately followed by digits + optional letter (JTY5763, FFL019N)
r'\b([A-Z]{2,4}\d{3,6}[A-Z]?)\b',
# Pattern: Single letter + digits (A12345)
r'\b([A-Z]\d{4,6}[A-Z]?)\b',
]
all_matches = []
for pattern in generic_patterns:
for match in re.finditer(pattern, original_text, re.IGNORECASE):
matched_text = match.group(1)
pos = match.start()
# Filter out matches that look like postal codes or ID numbers
# Postal codes are usually 3-5 digits without letters
if re.match(r'^\d+$', matched_text):
continue
# Filter out V4 2 type matches (single letter + digit + space + digit)
if re.match(r'^[A-Z]\d\s+\d$', matched_text, re.IGNORECASE):
continue
# Filter out Swedish postal codes (SE XXX XX format or SE + digits)
# SE followed by digits is typically postal code, not customer number
if re.match(r'^SE[\s\-]*\d', matched_text, re.IGNORECASE):
continue
all_matches.append((matched_text, pos))
if all_matches:
# Prefer matches that contain both letters and digits with dash
scored_matches = []
for match_text, pos in all_matches:
score = 0
# Bonus for containing dash (likely customer number format)
if '-' in match_text:
score += 50
# Bonus for format like XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', match_text, re.IGNORECASE):
score += 100
# Bonus for length (prefer medium length)
if 6 <= len(match_text) <= 12:
score += 20
# Position bonus (prefer later matches, after names)
score += pos * 0.1
scored_matches.append((score, match_text))
if scored_matches:
best_match = max(scored_matches, key=lambda x: x[0])[1]
return best_match.strip().upper(), True, None
# Pattern 2: Look for explicit labels
labeled_patterns = [
r'(?:kund(?:nr|nummer|id)?|ert?\s*(?:kund)?(?:nr|nummer)?|customer\s*(?:no|number|id)?)\s*[:\.]?\s*([A-Za-z0-9][\w\s\-]{1,20}?)(?:\s{2,}|\n|$)',
]
for pattern in labeled_patterns:
match = re.search(pattern, original_text, re.IGNORECASE)
if match:
extracted = match.group(1).strip()
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
if extracted and len(extracted) >= 2:
return extracted.upper(), True, None
# Pattern 3: If text contains comma (likely "NAME, NAME CODE"), extract after last comma
if ',' in original_text:
after_comma = original_text.split(',')[-1].strip()
# Look for alphanumeric code in the part after comma
for pattern in customer_code_patterns:
code_match = re.search(pattern, after_comma)
if code_match:
groups = code_match.groups()
if len(groups) == 3:
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
return result, True, None
return None, False, f"Cannot extract customer number from: {original_text[:50]}"
def extract_all_fields( def extract_all_fields(
self, self,

261
src/inference/payment_line_parser.py Normal file
View File

@@ -0,0 +1,261 @@
"""
Swedish Payment Line Parser
Handles parsing and validation of Swedish machine-readable payment lines.
Unifies payment line parsing logic that was previously duplicated across multiple modules.
Standard Swedish payment line format:
# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example:
# 94228110015950070 # 15658 00 8 > 48666036#14#
This parser handles common OCR errors:
- Spaces in numbers: "12 0 0" -> "1200"
- Missing symbols: the ">" may be absent
- Spaces in check digits: "#41 #" -> "#41#"
"""
import re
import logging
from dataclasses import dataclass
from typing import Optional
from src.exceptions import PaymentLineParseError
@dataclass
class PaymentLineData:
"""Parsed payment line data."""
ocr_number: str
"""OCR reference number (payment reference)"""
amount: Optional[str] = None
"""Amount in format KRONOR.ÖRE (e.g., '1200.00'), None if not present"""
account_number: Optional[str] = None
"""Bankgiro or Plusgiro account number"""
record_type: Optional[str] = None
"""Record type digit (usually '5' or '8' or '9')"""
check_digits: Optional[str] = None
"""Check digits for account validation"""
raw_text: str = ""
"""Original raw text that was parsed"""
is_valid: bool = True
"""Whether parsing was successful"""
error: Optional[str] = None
"""Error message if parsing failed"""
parse_method: str = "unknown"
"""Which parsing pattern was used (for debugging)"""
class PaymentLineParser:
"""Parser for Swedish payment lines with OCR error handling."""
# Pattern with amount: # OCR # KRONOR ÖRE TYPE > ACCOUNT#CHECK#
FULL_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Pattern without amount: # OCR # > ACCOUNT#CHECK#
NO_AMOUNT_PATTERN = re.compile(
r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Alternative pattern: look for OCR > ACCOUNT# pattern
ALT_PATTERN = re.compile(
r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
# Account only pattern: > ACCOUNT#CHECK#
ACCOUNT_ONLY_PATTERN = re.compile(
r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
)
def __init__(self):
"""Initialize parser with logger."""
self.logger = logging.getLogger(__name__)
def parse(self, text: str) -> PaymentLineData:
"""
Parse payment line text.
Handles common OCR errors:
- Spaces in numbers: "12 0 0" -> "1200"
- Missing symbols: the ">" may be absent
- Spaces in check digits: "#41 #" -> "#41#"
Args:
text: Raw payment line text from OCR
Returns:
PaymentLineData with parsed fields or error information
"""
if not text or not text.strip():
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="Empty payment line text",
parse_method="none"
)
text = text.strip()
# Try full pattern with amount
match = self.FULL_PATTERN.search(text)
if match:
return self._parse_full_match(match, text)
# Try pattern without amount
match = self.NO_AMOUNT_PATTERN.search(text)
if match:
return self._parse_no_amount_match(match, text)
# Try alternative pattern
match = self.ALT_PATTERN.search(text)
if match:
return self._parse_alt_match(match, text)
# Try account only pattern
match = self.ACCOUNT_ONLY_PATTERN.search(text)
if match:
return self._parse_account_only_match(match, text)
# No match - return error
return PaymentLineData(
ocr_number="",
raw_text=text,
is_valid=False,
error="No valid payment line format found",
parse_method="none"
)
def _parse_full_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse full pattern match (with amount)."""
ocr = self._clean_digits(match.group(1))
kronor = self._clean_digits(match.group(2))
ore = match.group(3)
record_type = match.group(4)
account = self._clean_digits(match.group(5))
check_digits = match.group(6)
amount = f"{kronor}.{ore}"
return PaymentLineData(
ocr_number=ocr,
amount=amount,
account_number=account,
record_type=record_type,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="full"
)
def _parse_no_amount_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse pattern match without amount."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="no_amount"
)
def _parse_alt_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse alternative pattern match."""
ocr = self._clean_digits(match.group(1))
account = self._clean_digits(match.group(2))
check_digits = match.group(3)
return PaymentLineData(
ocr_number=ocr,
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error=None,
parse_method="alternative"
)
def _parse_account_only_match(self, match: re.Match, raw_text: str) -> PaymentLineData:
"""Parse account-only pattern match."""
account = self._clean_digits(match.group(1))
check_digits = match.group(2)
return PaymentLineData(
ocr_number="",
amount=None,
account_number=account,
record_type=None,
check_digits=check_digits,
raw_text=raw_text,
is_valid=True,
error="Partial payment line (account only)",
parse_method="account_only"
)
def _clean_digits(self, text: str) -> str:
"""Remove spaces from digit string (OCR error correction)."""
return text.replace(' ', '')
def format_machine_readable(self, data: PaymentLineData) -> str:
"""
Format parsed data back to machine-readable format.
Returns:
Formatted string in standard Swedish payment line format
"""
if not data.is_valid:
return data.raw_text
# Full format with amount
if data.amount and data.record_type:
kronor, ore = data.amount.split('.')
return (
f"# {data.ocr_number} # {kronor} {ore} {data.record_type} > "
f"{data.account_number}#{data.check_digits}#"
)
# Format without amount
if data.ocr_number and data.account_number:
return f"# {data.ocr_number} # > {data.account_number}#{data.check_digits}#"
# Account only
if data.account_number:
return f"> {data.account_number}#{data.check_digits}#"
# Fallback
return data.raw_text
def format_for_field_extractor(self, data: PaymentLineData) -> tuple[Optional[str], bool, Optional[str]]:
"""
Format parsed data for FieldExtractor compatibility.
Returns:
Tuple of (formatted_text, is_valid, error_message) matching FieldExtractor's API
"""
if not data.is_valid:
return None, False, data.error
formatted = self.format_machine_readable(data)
return formatted, True, data.error
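
A minimal usage sketch for the parser above; the example line is taken from the module docstring, the expected fields follow from FULL_PATTERN, and the import path assumes src/inference/.

```python
from src.inference.payment_line_parser import PaymentLineParser  # assumed path

parser = PaymentLineParser()
data = parser.parse("# 94228110015950070 # 15658 00 8 > 48666036#14#")
# expected: ocr_number='94228110015950070', amount='15658.00',
# account_number='48666036', record_type='8', check_digits='14',
# parse_method='full'

# Round-trip back to the clean machine-readable string
line = parser.format_machine_readable(data)

# FieldExtractor-compatible tuple: (formatted_text, is_valid, error_message)
text, is_valid, error = parser.format_for_field_extractor(data)
```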

src/inference/pipeline.py
View File

@@ -12,6 +12,7 @@ import re
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from .field_extractor import FieldExtractor, ExtractedField from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
@dataclass @dataclass
@@ -124,6 +125,7 @@ class InferencePipeline:
device='cuda' if use_gpu else 'cpu' device='cuda' if use_gpu else 'cpu'
) )
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu) self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi self.dpi = dpi
self.enable_fallback = enable_fallback self.enable_fallback = enable_fallback
@@ -216,40 +218,19 @@ class InferencePipeline:
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]: def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
""" """
Parse machine-readable Swedish payment line format. Parse machine-readable Swedish payment line format using unified PaymentLineParser.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check># Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#" Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple Returns: (ocr, amount, account) tuple
""" """
# Pattern with amount parsed = self.payment_line_parser.parse(payment_line)
pattern_full = r'#\s*(\d+)\s*#\s*(\d+)\s+(\d{2})\s+\d\s*>\s*(\d+)#\d+#'
match = re.search(pattern_full, payment_line)
if match:
ocr = match.group(1)
kronor = match.group(2)
ore = match.group(3)
account = match.group(4)
amount = f"{kronor}.{ore}"
return ocr, amount, account
# Pattern without amount if not parsed.is_valid:
pattern_no_amount = r'#\s*(\d+)\s*#\s*>\s*(\d+)#\d+#' return None, None, None
match = re.search(pattern_no_amount, payment_line)
if match:
ocr = match.group(1)
account = match.group(2)
return ocr, None, account
# Fallback: partial pattern return parsed.ocr_number, parsed.amount, parsed.account_number
pattern_partial = r'>\s*(\d+)#\d+#'
match = re.search(pattern_partial, payment_line)
if match:
account = match.group(1)
return None, None, account
return None, None, None
def _cross_validate_payment_line(self, result: InferenceResult) -> None: def _cross_validate_payment_line(self, result: InferenceResult) -> None:
""" """

358
src/matcher/README.md Normal file
View File

@@ -0,0 +1,358 @@
# Matcher Module - Field Matching

Matches normalized field values against tokens extracted from PDF documents and returns each field's position (bbox) in the document, used to generate YOLO training annotations.

## 📁 Module Structure

```
src/matcher/
├── __init__.py                  # Exports the main interfaces
├── field_matcher.py             # Main class (205 lines, down from 876)
├── models.py                    # Data models
├── token_index.py               # Spatial index
├── context.py                   # Context keywords
├── utils.py                     # Utility functions
└── strategies/                  # Matching strategies
    ├── __init__.py
    ├── base.py                  # Base strategy class
    ├── exact_matcher.py         # Exact matching
    ├── concatenated_matcher.py  # Multi-token concatenation matching
    ├── substring_matcher.py     # Substring matching
    ├── fuzzy_matcher.py         # Fuzzy matching (amounts)
    └── flexible_date_matcher.py # Flexible date matching
```
## 🎯 Core Features

### FieldMatcher - Field Matcher

The main class; it coordinates the individual matching strategies:

```python
from src.matcher import FieldMatcher

matcher = FieldMatcher(
    context_radius=200.0,     # Search radius for context keywords (pixels)
    min_score_threshold=0.5   # Minimum match score
)

# Match a field
matches = matcher.find_matches(
    tokens=tokens,                 # Tokens extracted from the PDF
    field_name="InvoiceNumber",    # Field name
    normalized_values=["100017500321", "INV-100017500321"],  # Normalized variants
    page_no=0                      # Page number
)

# matches: List[Match]
for match in matches:
    print(f"Field: {match.field}")
    print(f"Value: {match.value}")
    print(f"BBox: {match.bbox}")
    print(f"Score: {match.score}")
    print(f"Context: {match.context_keywords}")
```
### 5 Matching Strategies

#### 1. ExactMatcher - Exact Matching

```python
from src.matcher.strategies import ExactMatcher

matcher = ExactMatcher(context_radius=200.0)
matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber")
```

Matching rules:

- Exact match: score = 1.0
- Case-insensitive match: score = 0.95
- Digits-only match: score = 0.9
- Context keyword bonus: +0.1 per keyword (up to +0.25)

#### 2. ConcatenatedMatcher - Concatenation Matching

```python
from src.matcher.strategies import ConcatenatedMatcher

matcher = ConcatenatedMatcher()
matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber")
```

Handles cases where OCR splits a single value across multiple tokens.

#### 3. SubstringMatcher - Substring Matching

```python
from src.matcher.strategies import SubstringMatcher

matcher = SubstringMatcher()
matches = matcher.find_matches(tokens, "2026-01-09", "InvoiceDate")
```

Matches field values embedded in longer text:

- `"Fakturadatum: 2026-01-09"` matches `"2026-01-09"`
- `"Fakturanummer: 2465027205"` matches `"2465027205"`

#### 4. FuzzyMatcher - Fuzzy Matching

```python
from src.matcher.strategies import FuzzyMatcher

matcher = FuzzyMatcher()
matches = matcher.find_matches(tokens, "1234.56", "Amount")
```

Used for amount fields; tolerates small decimal differences (±0.01).

#### 5. FlexibleDateMatcher - Flexible Date Matching

```python
from src.matcher.strategies import FlexibleDateMatcher

matcher = FlexibleDateMatcher()
matches = matcher.find_matches(tokens, "2025-01-15", "InvoiceDate")
```

Used when exact matching fails:

- Same year and month: score = 0.7-0.8
- Within 7 days: score = 0.75+
- Within 3 days: score = 0.8+
- Within 14 days: score = 0.6
- Within 30 days: score = 0.55
### Data Models

#### Match - Match Result

```python
from src.matcher.models import Match

match = Match(
    field="InvoiceNumber",
    value="100017500321",
    bbox=(100.0, 200.0, 300.0, 220.0),
    page_no=0,
    score=0.95,
    matched_text="100017500321",
    context_keywords=["fakturanr"]
)

# Convert to YOLO annotation format
yolo_annotation = match.to_yolo_format(
    image_width=1200,
    image_height=1600,
    class_id=0
)
# "0 0.166667 0.131250 0.166667 0.012500"
```

#### TokenIndex - Spatial Index

```python
from src.matcher.token_index import TokenIndex

# Build the index
index = TokenIndex(tokens, grid_size=100.0)

# Fast nearby-token lookup (O(1) average complexity)
nearby = index.find_nearby(token, radius=200.0)

# Get cached center coordinates
center = index.get_center(token)

# Get cached lowercase text
text_lower = index.get_text_lower(token)
```

### Context Keywords

```python
from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords

# Look up the context keywords for a field
keywords = CONTEXT_KEYWORDS["InvoiceNumber"]
# ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', ...]

# Find keywords near a token
found_keywords, boost_score = find_context_keywords(
    tokens=tokens,
    target_token=token,
    field_name="InvoiceNumber",
    context_radius=200.0,
    token_index=index  # Optional; when provided, enables O(1) lookup
)
```

Supported fields:

- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR
- Bankgiro
- Plusgiro
- Amount
- supplier_organisation_number
- supplier_accounts
### Utility Functions

```python
from src.matcher.utils import (
    normalize_dashes,
    parse_amount,
    tokens_on_same_line,
    bbox_overlap,
    DATE_PATTERN,
    WHITESPACE_PATTERN,
    NON_DIGIT_PATTERN,
    DASH_PATTERN,
)

# Normalize the various dash characters to a plain hyphen
text = normalize_dashes("123–456")    # "123-456"

# Parse Swedish amount formats
amount = parse_amount("1 234,56 kr")  # 1234.56
amount = parse_amount("239 00")       # 239.00 (öre format)

# Check whether two tokens are on the same line
same_line = tokens_on_same_line(token1, token2)

# Compute bbox overlap (IoU)
overlap = bbox_overlap(bbox1, bbox2)  # 0.0 - 1.0
```

## 🧪 Testing

```bash
# Run inside WSL
conda activate invoice-py311

# Run all matcher tests
pytest tests/matcher/ -v

# Run tests for a specific strategy
pytest tests/matcher/strategies/test_exact_matcher.py -v

# View coverage
pytest tests/matcher/ --cov=src/matcher --cov-report=html
```

Test coverage:

- ✅ All 77 tests pass
- ✅ TokenIndex spatial index
- ✅ 5 matching strategies
- ✅ Context keywords
- ✅ Utility functions
- ✅ Deduplication logic

## 📊 Refactoring Results

| Metric | Before | After | Improvement |
|------|--------|--------|------|
| field_matcher.py | 876 lines | 205 lines | ↓ 76% |
| Number of modules | 1 | 11 | Clearer structure |
| Largest file | 876 lines | 154 lines | Easier to read |
| Test pass rate | - | 100% | ✅ |
## 🚀 Usage Examples

### Full Workflow

```python
from src.matcher import FieldMatcher, find_field_matches

# 1. Extract PDF tokens (using the PDF module)
from src.pdf import PDFExtractor
extractor = PDFExtractor("invoice.pdf")
tokens = extractor.extract_tokens()

# 2. Prepare field values (from CSV or the database)
field_values = {
    "InvoiceNumber": "100017500321",
    "InvoiceDate": "2026-01-09",
    "Amount": "1234.56",
}

# 3. Find matches for all fields
results = find_field_matches(tokens, field_values, page_no=0)

# 4. Use the results
for field_name, matches in results.items():
    if matches:
        best_match = matches[0]  # Already sorted by score, descending
        print(f"{field_name}: {best_match.value} @ {best_match.bbox}")
        print(f"  Score: {best_match.score:.2f}")
        print(f"  Context: {best_match.context_keywords}")
```

### Adding a Custom Strategy

```python
from src.matcher.strategies.base import BaseMatchStrategy
from src.matcher.models import Match

class CustomMatcher(BaseMatchStrategy):
    """Custom matching strategy."""

    def find_matches(self, tokens, value, field_name, token_index=None):
        matches = []
        # Implement your matching logic
        for token in tokens:
            if self._custom_match_logic(token.text, value):
                match = Match(
                    field=field_name,
                    value=value,
                    bbox=token.bbox,
                    page_no=token.page_no,
                    score=0.85,
                    matched_text=token.text,
                    context_keywords=[]
                )
                matches.append(match)
        return matches

    def _custom_match_logic(self, token_text, value):
        # Your matching logic
        return True

# Use it from FieldMatcher
from src.matcher import FieldMatcher

matcher = FieldMatcher()
matcher.custom_matcher = CustomMatcher()
```
## 🔧 Maintenance Guide

### Adding New Context Keywords

Edit [src/matcher/context.py](context.py):

```python
CONTEXT_KEYWORDS = {
    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'new-keyword'],
    # ...
}
```

### Tuning Match Scores

Edit the corresponding strategy file:

- [exact_matcher.py](strategies/exact_matcher.py) - exact match scores
- [fuzzy_matcher.py](strategies/fuzzy_matcher.py) - fuzzy match tolerance
- [flexible_date_matcher.py](strategies/flexible_date_matcher.py) - date-distance scores

### Performance Tuning

1. **TokenIndex grid size**: defaults to 100 px; adjust to your documents
2. **Context radius**: defaults to 200 px; adjust to the scan DPI
3. **Deduplication grid**: defaults to 50 px; affects bbox-overlap detection performance

## 📚 Related Documentation

- [PDF module docs](../pdf/README.md) - Token extraction
- [Normalize module docs](../normalize/README.md) - Field value normalization
- [YOLO module docs](../yolo/README.md) - Annotation generation

## ✅ Summary

This modular matcher system provides:

- **Clear separation of responsibilities**: each strategy focuses on one matching method
- **Easy to test**: every component can be tested independently
- **High performance**: O(1) spatial index and smart deduplication
- **Extensible**: new strategies are easy to add
- **Fully tested**: 77 tests, 100% passing

src/matcher/__init__.py
View File

@@ -1,3 +1,4 @@
from .field_matcher import FieldMatcher, Match, find_field_matches from .field_matcher import FieldMatcher, find_field_matches
from .models import Match, TokenLike
__all__ = ['FieldMatcher', 'Match', 'find_field_matches'] __all__ = ['FieldMatcher', 'Match', 'TokenLike', 'find_field_matches']

92
src/matcher/context.py Normal file
View File

@@ -0,0 +1,92 @@
"""
Context keywords for field matching.
"""
from .models import TokenLike
from .token_index import TokenIndex
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
def find_context_keywords(
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str,
context_radius: float,
token_index: TokenIndex | None = None
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
Args:
tokens: List of all tokens
target_token: The token to find context for
field_name: Name of the field
context_radius: Search radius in pixels
token_index: Optional spatial index for efficient lookup
Returns:
Tuple of (found_keywords, boost_score)
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if token_index:
nearby_tokens = token_index.find_nearby(target_token, context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
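
A small illustrative sketch of the fallback (no-index) path above; the Token dataclass is only a stand-in for objects satisfying the TokenLike protocol.

```python
from dataclasses import dataclass
from src.matcher.context import find_context_keywords

@dataclass
class Token:
    """Minimal stand-in for TokenLike (illustrative only)."""
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

label = Token("Fakturanr:", (100.0, 200.0, 180.0, 215.0))
value = Token("100017500321", (190.0, 200.0, 300.0, 215.0))

# Without a TokenIndex the function falls back to the O(n) distance scan
keywords, boost = find_context_keywords(
    tokens=[label, value],
    target_token=value,
    field_name="InvoiceNumber",
    context_radius=200.0,  # pixels
)
# "fakturanr" (and the shorter "nr") lie within the radius,
# so boost = min(0.25, len(keywords) * 0.10)
```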

src/matcher/field_matcher.py
View File

@@ -1,158 +1,19 @@
""" """
Field Matching Module Field Matching Module - Refactored
Matches normalized field values to tokens extracted from documents. Matches normalized field values to tokens extracted from documents.
""" """
from dataclasses import dataclass, field from .models import TokenLike, Match
from typing import Protocol from .token_index import TokenIndex
import re from .utils import bbox_overlap
from functools import cached_property from .strategies import (
ExactMatcher,
ConcatenatedMatcher,
# Pre-compiled regex patterns (module-level for efficiency) SubstringMatcher,
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})') FuzzyMatcher,
_WHITESPACE_PATTERN = re.compile(r'\s+') FlexibleDateMatcher,
_NON_DIGIT_PATTERN = re.compile(r'\D') )
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
class FieldMatcher: class FieldMatcher:
@@ -175,6 +36,13 @@ class FieldMatcher:
self.min_score_threshold = min_score_threshold self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None self._token_index: TokenIndex | None = None
# Initialize matching strategies
self.exact_matcher = ExactMatcher(context_radius)
self.concatenated_matcher = ConcatenatedMatcher(context_radius)
self.substring_matcher = SubstringMatcher(context_radius)
self.fuzzy_matcher = FuzzyMatcher(context_radius)
self.flexible_date_matcher = FlexibleDateMatcher(context_radius)
def find_matches( def find_matches(
self, self,
tokens: list[TokenLike], tokens: list[TokenLike],
@@ -208,34 +76,46 @@ class FieldMatcher:
for value in normalized_values: for value in normalized_values:
# Strategy 1: Exact token match # Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name) exact_matches = self.exact_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(exact_matches) matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation # Strategy 2: Multi-token concatenation
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name) concat_matches = self.concatenated_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(concat_matches) matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only) # Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'): if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name) fuzzy_matches = self.fuzzy_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(fuzzy_matches) matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text) # Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205" # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match # Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text # in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', if field_name in (
'supplier_organisation_number', 'supplier_accounts', 'customer_number'): 'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
substring_matches = self._find_substring_matches(page_tokens, value, field_name) 'Bankgiro', 'Plusgiro', 'supplier_organisation_number',
'supplier_accounts', 'customer_number'
):
substring_matches = self.substring_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(substring_matches) matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection) # Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields # Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches: if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
flexible_matches = self._find_flexible_date_matches( for value in normalized_values:
page_tokens, normalized_values, field_name flexible_matches = self.flexible_date_matcher.find_matches(
) page_tokens, value, field_name, self._token_index
matches.extend(flexible_matches) )
matches.extend(flexible_matches)
# Deduplicate and sort by score # Deduplicate and sort by score
matches = self._deduplicate_matches(matches) matches = self._deduplicate_matches(matches)
@@ -246,521 +126,6 @@ class FieldMatcher:
return [m for m in matches if m.score >= self.min_score_threshold] return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
def _find_concatenated_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not self._tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
def _find_substring_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip if token is the same length as value (would be exact match)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, boundary check is simpler (just check it's 10 consecutive digits)
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
def _find_fuzzy_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find approximate matches for amounts and dates."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = self._parse_amount(token_text)
value_num = self._parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except:
pass
return matches
def _find_flexible_date_matches(
self,
tokens: list[TokenLike],
normalized_values: list[str],
field_name: str
) -> list[Match]:
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
from datetime import datetime, timedelta
matches = []
# Parse the target date from normalized values
target_date = None
for value in normalized_values:
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
break
except ValueError:
continue
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in _DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= self.min_score_threshold:
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
def _find_context_keywords(
self,
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def _parse_amount(self, text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
@@ -803,7 +168,7 @@ class FieldMatcher:
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
- if self._bbox_overlap(bbox, existing.bbox) > 0.7:
+ if bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
@@ -821,27 +186,6 @@ class FieldMatcher:
return unique
def _bbox_overlap(
self,
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
def find_field_matches(
tokens: list[TokenLike],


@@ -0,0 +1,875 @@
"""
Field Matching Module
Matches normalized field values to tokens extracted from documents.
"""
from dataclasses import dataclass, field
from typing import Protocol
import re
from functools import cached_property
# Pre-compiled regex patterns (module-level for efficiency)
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
# Filter tokens by page and exclude hidden metadata tokens
# Hidden tokens often have bbox with y < 0 or y > page_height
# These are typically PDF metadata stored as invisible text
page_tokens = [
t for t in tokens
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
flexible_matches = self._find_flexible_date_matches(
page_tokens, normalized_values, field_name
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
def _find_concatenated_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not self._tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
def _find_substring_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip if token is the same length as value (would be exact match)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, boundary check is simpler (just check it's 10 consecutive digits)
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
def _find_fuzzy_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find approximate matches for amounts and dates."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = self._parse_amount(token_text)
value_num = self._parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except:
pass
return matches
def _find_flexible_date_matches(
self,
tokens: list[TokenLike],
normalized_values: list[str],
field_name: str
) -> list[Match]:
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
from datetime import datetime, timedelta
matches = []
# Parse the target date from normalized values
target_date = None
for value in normalized_values:
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
break
except ValueError:
continue
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in _DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= self.min_score_threshold:
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
def _find_context_keywords(
self,
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def _parse_amount(self, text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
def _bbox_overlap(
self,
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
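
A minimal usage sketch of the matcher above. The import path and the `Token` dataclass are assumptions for illustration; any object with `text`, `bbox`, and `page_no` attributes satisfies `TokenLike`.

```python
from dataclasses import dataclass

from src.matcher.field_matcher import FieldMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


tokens = [
    Token("Fakturanr:", (50.0, 100.0, 130.0, 115.0), 0),
    Token("2465027205", (140.0, 100.0, 230.0, 115.0), 0),
    Token("Att betala:", (50.0, 400.0, 140.0, 415.0), 0),
    Token("1 234,50 kr", (150.0, 400.0, 240.0, 415.0), 0),
]

matcher = FieldMatcher()
# normalized_values would normally come from the normalize module
matches = matcher.find_matches(tokens, "InvoiceNumber", ["2465027205"], page_no=0)
for m in matches:
    # exact token match with the "Fakturanr:" label nearby -> score capped at 1.0
    print(m.field, m.matched_text, m.score, m.context_keywords)
```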

src/matcher/models.py Normal file

@@ -0,0 +1,36 @@
"""
Data models for field matching.
"""
from dataclasses import dataclass
from typing import Protocol
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


@@ -0,0 +1,17 @@
"""
Matching strategies for field matching.
"""
from .exact_matcher import ExactMatcher
from .concatenated_matcher import ConcatenatedMatcher
from .substring_matcher import SubstringMatcher
from .fuzzy_matcher import FuzzyMatcher
from .flexible_date_matcher import FlexibleDateMatcher
__all__ = [
'ExactMatcher',
'ConcatenatedMatcher',
'SubstringMatcher',
'FuzzyMatcher',
'FlexibleDateMatcher',
]
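
The package only re-exports the strategy classes; the orchestration lives in the FieldMatcher facade. A sketch of composing them directly, with the import path assumed and `tokens` being any TokenLike objects:

```python
from src.matcher.strategies import (
    ConcatenatedMatcher,
    ExactMatcher,
    FuzzyMatcher,
    SubstringMatcher,
)


def run_strategies(tokens, value: str, field_name: str):
    """Collect matches from several strategies and rank them by score."""
    matches = []
    for strategy in (ExactMatcher(), ConcatenatedMatcher(), SubstringMatcher(), FuzzyMatcher()):
        matches.extend(strategy.find_matches(tokens, value, field_name))
    return sorted(matches, key=lambda m: m.score, reverse=True)
```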


@@ -0,0 +1,42 @@
"""
Base class for matching strategies.
"""
from abc import ABC, abstractmethod
from ..models import TokenLike, Match
from ..token_index import TokenIndex
class BaseMatchStrategy(ABC):
"""Base class for all matching strategies."""
def __init__(self, context_radius: float = 200.0):
"""
Initialize the strategy.
Args:
context_radius: Distance to search for context keywords
"""
self.context_radius = context_radius
@abstractmethod
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""
Find matches for the given value.
Args:
tokens: List of tokens to search
value: Value to find
field_name: Name of the field
token_index: Optional spatial index for efficient lookup
Returns:
List of Match objects
"""
pass
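
Adding a new strategy only requires subclassing `BaseMatchStrategy` and implementing `find_matches`. A toy sketch, not one of the shipped strategies; import paths are assumed from this change set.

```python
from src.matcher.models import Match
from src.matcher.strategies.base import BaseMatchStrategy  # path assumed


class PrefixMatcher(BaseMatchStrategy):
    """Toy strategy: a token that merely starts with the value counts as a weak match."""

    def find_matches(self, tokens, value, field_name, token_index=None):
        matches = []
        for token in tokens:
            text = token.text.strip()
            if text != value and text.startswith(value):
                matches.append(Match(
                    field=field_name,
                    value=value,
                    bbox=token.bbox,
                    page_no=token.page_no,
                    score=0.6,  # deliberately below the exact-match scores
                    matched_text=text,
                    context_keywords=[],
                ))
        return matches
```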


@@ -0,0 +1,73 @@
"""
Concatenated match strategy - finds value by concatenating adjacent tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import WHITESPACE_PATTERN, tokens_on_same_line
class ConcatenatedMatcher(BaseMatchStrategy):
"""Find value by concatenating adjacent tokens."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find concatenated matches."""
matches = []
value_clean = WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = find_context_keywords(
tokens, start_token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
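
The concatenation rules (same text line, at most a 50 px gap, at most 5 tokens) cover values that the OCR splits across tokens. A minimal sketch of the expected behaviour; `Token` is a hypothetical stand-in and the import path is assumed.

```python
from dataclasses import dataclass

from src.matcher.strategies import ConcatenatedMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


# "5393-9484" rendered as two tokens on the same line, 4 px apart
tokens = [
    Token("5393-", (100.0, 300.0, 140.0, 312.0), 0),
    Token("9484", (144.0, 300.0, 180.0, 312.0), 0),
]
matches = ConcatenatedMatcher().find_matches(tokens, "5393-9484", "Bankgiro")
print(matches[0].matched_text, matches[0].score)  # expected: "5393-9484" 0.85
```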


@@ -0,0 +1,65 @@
"""
Exact match strategy.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import NON_DIGIT_PATTERN
class ExactMatcher(BaseMatchStrategy):
"""Find tokens that exactly match the value."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find exact matches."""
matches = []
value_lower = value.lower()
value_digits = NON_DIGIT_PATTERN.sub('', value) if field_name in (
'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts'
) else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif token_index and token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
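
The digits-only branch is what lets a dashless CSV value line up with the formatted number printed on the invoice. A minimal sketch; `Token` and the import path are assumptions for illustration.

```python
from dataclasses import dataclass

from src.matcher.strategies import ExactMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


# CSV stores "53939484"; the PDF prints "5393-9484"
tokens = [Token("5393-9484", (120.0, 100.0, 190.0, 112.0), 0)]
matches = ExactMatcher().find_matches(tokens, "53939484", "Bankgiro")
print(matches[0].score)  # expected: 0.9 via the digits-only branch (no context boost here)
```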


@@ -0,0 +1,149 @@
"""
Flexible date match strategy - finds dates with year-month or nearby date matching.
"""
import re
from datetime import datetime
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import DATE_PATTERN
class FlexibleDateMatcher(BaseMatchStrategy):
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find flexible date matches."""
matches = []
# Parse the target date from normalized values
target_date = None
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
except ValueError:
pass
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= 0.5: # Min threshold for flexible matching
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
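
Combined with the inline-keyword bonus, the tiered scoring usually lands a close-but-not-exact date well above the 0.5 threshold. A sketch of the expected behaviour; `Token` and the import path are assumptions.

```python
from dataclasses import dataclass

from src.matcher.strategies import FlexibleDateMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


# CSV says 2026-01-15, but the PDF prints 2026-01-12 next to the label
tokens = [Token("Fakturadatum: 2026-01-12", (50.0, 90.0, 260.0, 104.0), 0)]
matches = FlexibleDateMatcher().find_matches(tokens, "2026-01-15", "InvoiceDate")
# same year-month, 3 days apart, inline keyword: 0.8 + 0.1 + 0.05 -> expected ~0.95
print(matches[0].value, round(matches[0].score, 2))
```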


@@ -0,0 +1,52 @@
"""
Fuzzy match strategy for amounts and dates.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import parse_amount
class FuzzyMatcher(BaseMatchStrategy):
"""Find approximate matches for amounts and dates."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find fuzzy matches."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = parse_amount(token_text)
value_num = parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except:
pass
return matches
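
Amount matching runs parse_amount on both sides, so formatting differences collapse before the 1-öre tolerance check. A sketch with the same assumptions as the earlier examples:

```python
from dataclasses import dataclass

from src.matcher.strategies import FuzzyMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


# CSV amount "1234.50" vs the printed Swedish form "1 234,50 kr"
tokens = [Token("1 234,50 kr", (400.0, 700.0, 480.0, 714.0), 0)]
matches = FuzzyMatcher().find_matches(tokens, "1234.50", "Amount")
print(matches[0].matched_text, matches[0].score)  # expected: "1 234,50 kr" 0.8
```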


@@ -0,0 +1,143 @@
"""
Substring match strategy - finds value as substring within longer tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import normalize_dashes
class SubstringMatcher(BaseMatchStrategy):
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find substring matches."""
matches = []
# Supported fields for substring matching
supported_fields = (
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'
)
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = (
'supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts'
)
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip if token is the same length as value (would be exact match)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, boundary check is simpler (just check it's 10 consecutive digits)
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
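
The boundary check is what keeps an embedded value from matching inside an unrelated longer digit run (for example a payment line). A sketch with the same assumptions as the earlier examples:

```python
from dataclasses import dataclass

from src.matcher.strategies import SubstringMatcher  # path assumed from this change set


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


tokens = [
    Token("Fakturanummer: 2465027205", (50.0, 80.0, 280.0, 94.0), 0),   # clean boundary
    Token("912465027205990", (50.0, 760.0, 200.0, 774.0), 0),           # same digits inside a longer number
]
matches = SubstringMatcher().find_matches(tokens, "2465027205", "InvoiceNumber")
# only the first token matches; in the second the value is flanked by digits
print(len(matches), matches[0].matched_text)
```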


@@ -0,0 +1,92 @@
"""
Spatial index for fast token lookup.
"""
from .models import TokenLike
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
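
A small sketch of the index in isolation; `Token` is a hypothetical stand-in and the import path is assumed from this change set.

```python
from dataclasses import dataclass

from src.matcher.token_index import TokenIndex  # path assumed


@dataclass
class Token:  # hypothetical stand-in satisfying the TokenLike protocol
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


tokens = [
    Token("Förfallodatum", (50.0, 200.0, 150.0, 214.0), 0),
    Token("2026-02-08", (160.0, 200.0, 240.0, 214.0), 0),
    Token("Summa", (50.0, 900.0, 100.0, 914.0), 0),
]
index = TokenIndex(tokens, grid_size=100.0)
date_token = tokens[1]
nearby = index.find_nearby(date_token, radius=200.0)
print([t.text for t in nearby])            # ['Förfallodatum'] - "Summa" is too far away
print(index.get_text_lower(tokens[0]))     # cached 'förfallodatum'
```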

src/matcher/utils.py Normal file

@@ -0,0 +1,91 @@
"""
Utility functions for field matching.
"""
import re
# Pre-compiled regex patterns (module-level for efficiency)
DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
WHITESPACE_PATTERN = re.compile(r'\s+')
NON_DIGIT_PATTERN = re.compile(r'\D')
DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return DASH_PATTERN.sub('-', text)
def parse_amount(text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def tokens_on_same_line(token1, token2) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def bbox_overlap(
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
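
A few worked calls against the helpers above; the values are chosen for illustration only.

```python
from src.matcher.utils import bbox_overlap, normalize_dashes, parse_amount  # path from the file header above

# Swedish amount forms collapse to the same float
print(parse_amount("1 234,50 kr"))   # 1234.5
print(parse_amount("239 00"))        # 239.0  (öre format: 239 kr 00 öre)
print(parse_amount("500,00 SEK"))    # 500.0
print(parse_amount("ej angivet"))    # None

# IoU of two boxes that share a 5x10 strip: 50 / 150 ≈ 0.33
a = (0.0, 0.0, 10.0, 10.0)
b = (5.0, 0.0, 15.0, 10.0)
print(bbox_overlap(a, b))

# en-dash normalized to a plain hyphen
print(normalize_dashes("55 65 74–6624"))  # "55 65 74-6624"
```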


@@ -3,18 +3,26 @@ Field Normalization Module
Normalizes field values to generate multiple candidate forms for matching.
- This module generates variants of CSV values for matching against OCR text.
+ This module now delegates to individual normalizer modules for each field type.
- It uses shared utilities from src.utils for text cleaning and OCR error variants.
+ Each normalizer is a separate, reusable module that can be used independently.
"""
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Callable
# Import shared utilities
from src.utils.text_cleaner import TextCleaner
from src.utils.format_variants import FormatVariants
# Import individual normalizers
from .normalizers import (
InvoiceNumberNormalizer,
OCRNormalizer,
BankgiroNormalizer,
PlusgiroNormalizer,
AmountNormalizer,
DateNormalizer,
OrganisationNumberNormalizer,
SupplierAccountsNormalizer,
CustomerNumberNormalizer,
)
@dataclass
@@ -26,27 +34,32 @@ class NormalizedValue:
class FieldNormalizer:
- """Handles normalization of different invoice field types."""
+ """
Handles normalization of different invoice field types.
- # Common Swedish month names for date parsing
+ This class now acts as a facade that delegates to individual
- SWEDISH_MONTHS = {
+ normalizer modules. Each field type has its own specialized
- 'januari': '01', 'jan': '01',
+ normalizer for better modularity and reusability.
- 'februari': '02', 'feb': '02',
+ """
'mars': '03', 'mar': '03',
- 'april': '04', 'apr': '04',
+ # Instantiate individual normalizers
- 'maj': '05',
+ _invoice_number = InvoiceNumberNormalizer()
- 'juni': '06', 'jun': '06',
+ _ocr_number = OCRNormalizer()
- 'juli': '07', 'jul': '07',
+ _bankgiro = BankgiroNormalizer()
- 'augusti': '08', 'aug': '08',
+ _plusgiro = PlusgiroNormalizer()
- 'september': '09', 'sep': '09', 'sept': '09',
+ _amount = AmountNormalizer()
- 'oktober': '10', 'okt': '10',
+ _date = DateNormalizer()
- 'november': '11', 'nov': '11',
+ _organisation_number = OrganisationNumberNormalizer()
- 'december': '12', 'dec': '12'
+ _supplier_accounts = SupplierAccountsNormalizer()
- }
+ _customer_number = CustomerNumberNormalizer()
# Common Swedish month names for backward compatibility
SWEDISH_MONTHS = DateNormalizer.SWEDISH_MONTHS
@staticmethod
def clean_text(text: str) -> str:
- """Remove invisible characters and normalize whitespace and dashes.
+ """
Remove invisible characters and normalize whitespace and dashes.
Delegates to shared TextCleaner for consistency.
"""
@@ -56,517 +69,82 @@ class FieldNormalizer:
def normalize_invoice_number(value: str) -> list[str]:
"""
Normalize invoice number.
Keeps only digits for matching.
- Examples:
+ Delegates to InvoiceNumberNormalizer.
'100017500321' -> ['100017500321']
'INV-100017500321' -> ['100017500321', 'INV-100017500321']
"""
- value = FieldNormalizer.clean_text(value)
+ return FieldNormalizer._invoice_number.normalize(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))
@staticmethod
def normalize_ocr_number(value: str) -> list[str]:
"""
Normalize OCR number (Swedish payment reference).
Similar to invoice number - digits only.
Delegates to OCRNormalizer.
"""
- return FieldNormalizer.normalize_invoice_number(value)
+ return FieldNormalizer._ocr_number.normalize(value)
@staticmethod
def normalize_bankgiro(value: str) -> list[str]:
"""
Normalize Bankgiro number.
- Uses shared FormatVariants plus OCR error variants.
+ Delegates to BankgiroNormalizer.
Examples:
'5393-9484' -> ['5393-9484', '53939484']
'53939484' -> ['53939484', '5393-9484']
"""
- # Use shared module for base variants
+ return FieldNormalizer._bankgiro.normalize(value)
variants = set(FormatVariants.bankgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
@staticmethod
def normalize_plusgiro(value: str) -> list[str]:
"""
Normalize Plusgiro number.
- Uses shared FormatVariants plus OCR error variants.
+ Delegates to PlusgiroNormalizer.
Examples:
'1234567-8' -> ['1234567-8', '12345678']
'12345678' -> ['12345678', '1234567-8']
"""
- # Use shared module for base variants
+ return FieldNormalizer._plusgiro.normalize(value)
variants = set(FormatVariants.plusgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
"""
Normalize Swedish organisation number and generate VAT number variants.
- Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
+ Delegates to OrganisationNumberNormalizer.
Swedish VAT format: SE + org_number (10 digits) + 01
Uses shared FormatVariants for comprehensive variant generation,
plus OCR error variants.
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
- # Use shared module for base variants
+ return FieldNormalizer._organisation_number.normalize(value)
variants = set(FormatVariants.organisation_number_variants(value))
# Add OCR error variants for digit sequences
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits and len(digits) >= 10:
# Generate variants where OCR might have misread characters
for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
variants.add(ocr_var)
if len(ocr_var) == 10:
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
return list(v for v in variants if v)
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:
"""
Normalize supplier accounts field.
- The field may contain multiple accounts separated by ' | '.
+ Delegates to SupplierAccountsNormalizer.
Format examples:
'PG:48676043 | PG:49128028 | PG:8915035'
'BG:5393-9484'
Each account is normalized separately to generate variants.
Examples:
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
""" """
value = FieldNormalizer.clean_text(value) return FieldNormalizer._supplier_accounts.normalize(value)
variants = []
# Split by ' | ' to handle multiple accounts
accounts = [acc.strip() for acc in value.split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Add original value
variants.append(account)
# Remove prefix (PG:, BG:, etc.)
if ':' in account:
prefix, number = account.split(':', 1)
number = number.strip()
variants.append(number) # Just the number without prefix
# Also add with different prefix formats
prefix_upper = prefix.strip().upper()
variants.append(f"{prefix_upper}:{number}")
variants.append(f"{prefix_upper}: {number}") # With space
else:
number = account
# Extract digits only
digits_only = re.sub(r'\D', '', number)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also try 4-4 format for bankgiro
variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
elif len(digits_only) == 10:
# 6-4 format (like org number)
variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
return list(set(v for v in variants if v))
@staticmethod @staticmethod
def normalize_customer_number(value: str) -> list[str]: def normalize_customer_number(value: str) -> list[str]:
""" """
Normalize customer number. Normalize customer number.
Customer numbers can have various formats: Delegates to CustomerNumberNormalizer.
- Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
- Pure numbers: '12345', '123-456'
Examples:
'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
'ABC 123' -> ['ABC 123', 'ABC123']
""" """
value = FieldNormalizer.clean_text(value) return FieldNormalizer._customer_number.normalize(value)
variants = [value]
# Version without spaces
no_space = value.replace(' ', '')
if no_space != value:
variants.append(no_space)
# Version without dashes
no_dash = value.replace('-', '')
if no_dash != value:
variants.append(no_dash)
# Version without spaces and dashes
clean = value.replace(' ', '').replace('-', '')
if clean != value and clean not in variants:
variants.append(clean)
# Uppercase and lowercase versions
if value.upper() != value:
variants.append(value.upper())
if value.lower() != value:
variants.append(value.lower())
return list(set(v for v in variants if v))
@staticmethod @staticmethod
def normalize_amount(value: str) -> list[str]: def normalize_amount(value: str) -> list[str]:
""" """
Normalize monetary amount. Normalize monetary amount.
Examples: Delegates to AmountNormalizer.
'114' -> ['114', '114,00', '114.00']
'114,00' -> ['114,00', '114.00', '114']
'1 234,56' -> ['1234,56', '1234.56', '1 234,56']
'3045 52' -> ['3045.52', '3045,52', '304552'] (space as decimal sep)
""" """
value = FieldNormalizer.clean_text(value) return FieldNormalizer._amount.normalize(value)
# Remove currency symbols and common suffixes
value = re.sub(r'[SEK|kr|:-]+$', '', value, flags=re.IGNORECASE).strip()
variants = [value]
# Check for space as decimal separator pattern: "3045 52" (number space 2-digits)
# This is common in Swedish invoices where space separates öre from kronor
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', value)
if space_decimal_match:
integer_part = space_decimal_match.group(1)
decimal_part = space_decimal_match.group(2)
# Add variants with different decimal separators
variants.append(f"{integer_part}.{decimal_part}")
variants.append(f"{integer_part},{decimal_part}")
variants.append(f"{integer_part}{decimal_part}") # No separator
# Check for space as thousand separator with decimal: "10 571,00" or "10 571.00"
# Pattern: digits space digits comma/dot 2-digits
space_thousand_match = re.match(r'^(\d{1,3})[\s\xa0]+(\d{3})([,\.])(\d{2})$', value)
if space_thousand_match:
part1 = space_thousand_match.group(1)
part2 = space_thousand_match.group(2)
sep = space_thousand_match.group(3)
decimal = space_thousand_match.group(4)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(f"{combined}{decimal}")
# Also add variant with space preserved but different decimal sep
other_sep = ',' if sep == '.' else '.'
variants.append(f"{part1} {part2}{other_sep}{decimal}")
# Handle US format: "1,390.00" (comma as thousand separator, dot as decimal)
us_format_match = re.match(r'^(\d{1,3}),(\d{3})\.(\d{2})$', value)
if us_format_match:
part1 = us_format_match.group(1)
part2 = us_format_match.group(2)
decimal = us_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined) # Without decimal
# European format: 1.390,00
variants.append(f"{part1}.{part2},{decimal}")
# Handle European format: "1.390,00" (dot as thousand separator, comma as decimal)
eu_format_match = re.match(r'^(\d{1,3})\.(\d{3}),(\d{2})$', value)
if eu_format_match:
part1 = eu_format_match.group(1)
part2 = eu_format_match.group(2)
decimal = eu_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined) # Without decimal
# US format: 1,390.00
variants.append(f"{part1},{part2}.{decimal}")
# Remove spaces (thousand separators) including non-breaking space
no_space = value.replace(' ', '').replace('\xa0', '')
# Normalize decimal separator
if ',' in no_space:
dot_version = no_space.replace(',', '.')
variants.append(no_space)
variants.append(dot_version)
elif '.' in no_space:
comma_version = no_space.replace('.', ',')
variants.append(no_space)
variants.append(comma_version)
else:
# Integer amount - add decimal versions
variants.append(no_space)
variants.append(f"{no_space},00")
variants.append(f"{no_space}.00")
# Try to parse and get clean numeric value
try:
# Parse as float
clean = no_space.replace(',', '.')
num = float(clean)
# Integer if no decimals
if num == int(num):
int_val = int(num)
variants.append(str(int_val))
variants.append(f"{int_val},00")
variants.append(f"{int_val}.00")
# European format with dot as thousand separator (e.g., 20.485,00)
if int_val >= 1000:
# Format: XX.XXX,XX
formatted = f"{int_val:,}".replace(',', '.')
variants.append(formatted) # 20.485
variants.append(f"{formatted},00") # 20.485,00
else:
variants.append(f"{num:.2f}")
variants.append(f"{num:.2f}".replace('.', ','))
# European format with dot as thousand separator
if num >= 1000:
# Split integer and decimal parts using string formatting to avoid precision loss
formatted_str = f"{num:.2f}"
int_str, dec_str = formatted_str.split(".")
int_part = int(int_str)
formatted_int = f"{int_part:,}".replace(',', '.')
variants.append(f"{formatted_int},{dec_str}") # 3.045,52
except ValueError:
pass
return list(set(v for v in variants if v))
@staticmethod @staticmethod
def normalize_date(value: str) -> list[str]: def normalize_date(value: str) -> list[str]:
""" """
Normalize date to YYYY-MM-DD and generate variants. Normalize date to YYYY-MM-DD and generate variants.
Handles: Delegates to DateNormalizer.
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
Note: For ambiguous formats like DD/MM/YYYY vs MM/DD/YYYY,
we generate variants for BOTH interpretations to maximize matching.
""" """
value = FieldNormalizer.clean_text(value) return FieldNormalizer._date.normalize(value)
variants = [value]
parsed_dates = [] # May have multiple interpretations
# Try different date formats
date_patterns = [
# ISO format with optional time (e.g., 2026-01-09 00:00:00)
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
# Ambiguous patterns - try both DD/MM and MM/DD interpretations
ambiguous_patterns_4digit_year = [
# Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
# Format with . - typically European DD.MM.YYYY
r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
# Format with - (not ISO) - could be DD-MM-YYYY or MM-DD-YYYY
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
# Patterns with 2-digit year (common in Swedish invoices)
ambiguous_patterns_2digit_year = [
# Format DD.MM.YY (e.g., 02.08.25 for 2025-08-02)
r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
# Format DD/MM/YY
r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
# Format DD-MM-YY
r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
]
# Try unambiguous patterns first
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
parsed_dates.append(datetime(year, month, day))
break
except ValueError:
continue
# Try ambiguous patterns with 4-digit year
if not parsed_dates:
for pattern in ambiguous_patterns_4digit_year:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
# Try DD/MM/YYYY (European - day first)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YYYY (US - month first) if different and valid
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try ambiguous patterns with 2-digit year (e.g., 02.08.25)
if not parsed_dates:
for pattern in ambiguous_patterns_2digit_year:
match = re.match(pattern, value)
if match:
n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
# Convert 2-digit year to 4-digit (00-49 -> 2000s, 50-99 -> 1900s)
year = 2000 + yy if yy < 50 else 1900 + yy
# Try DD/MM/YY (European - day first, most common in Sweden)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YY (US - month first) if different and valid
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try Swedish month names
if not parsed_dates:
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
if month_name in value.lower():
# Extract day and year
numbers = re.findall(r'\d+', value)
if len(numbers) >= 2:
day = int(numbers[0])
year = int(numbers[-1])
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
parsed_dates.append(datetime(year, int(month_num), day))
break
except ValueError:
continue
# Generate variants for all parsed date interpretations
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
for parsed_date in parsed_dates:
# Generate different formats
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
us_slash = parsed_date.strftime('%m/%d/%Y') # US format MM/DD/YYYY
eu_dot = parsed_date.strftime('%d.%m.%Y')
iso_dot = parsed_date.strftime('%Y.%m.%d') # ISO with dots (e.g., 2024.02.08)
compact = parsed_date.strftime('%Y%m%d') # YYYYMMDD
compact_short = parsed_date.strftime('%y%m%d') # YYMMDD (e.g., 260108)
# Short year with dot separator (e.g., 02.01.26)
eu_dot_short = parsed_date.strftime('%d.%m.%y')
# Short year with slash separator (e.g., 20/10/24) - DD/MM/YY format
eu_slash_short = parsed_date.strftime('%d/%m/%y')
# Short year with hyphen separator (e.g., 23-11-01) - common in Swedish invoices
yy_mm_dd_short = parsed_date.strftime('%y-%m-%d')
# Middle dot separator (OCR sometimes reads hyphens as middle dots)
iso_middot = parsed_date.strftime('%%%d')
# Spaced formats (e.g., "2026 01 12", "26 01 12")
spaced_full = parsed_date.strftime('%Y %m %d')
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
# Swedish month abbreviation with hyphen (e.g., "30-OKT-24", "30-okt-24")
month_abbrev_upper = month_abbrev.upper()
swedish_hyphen_short = f"{parsed_date.day:02d}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_hyphen_short_lower = f"{parsed_date.day:02d}-{month_abbrev}-{parsed_date.strftime('%y')}"
# Also without leading zero on day
swedish_hyphen_short_no_zero = f"{parsed_date.day}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
# Swedish month abbreviation with short year in different format (e.g., "SEP-24", "30 SEP 24")
month_year_only = f"{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_spaced = f"{parsed_date.day:02d} {month_abbrev_upper} {parsed_date.strftime('%y')}"
variants.extend([
iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
eu_dot_short, eu_slash_short, yy_mm_dd_short, iso_middot, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev,
swedish_hyphen_short, swedish_hyphen_short_lower, swedish_hyphen_short_no_zero,
month_year_only, swedish_spaced
])
return list(set(v for v in variants if v))
# Field type to normalizer mapping # Field type to normalizer mapping

View File

@@ -0,0 +1,225 @@
# Normalizer Modules
Standalone field normalization modules that generate variants of a field value for matching.
## Architecture
Each field type has its own normalizer module, which keeps the code easy to reuse and maintain:
```
src/normalize/normalizers/
├── __init__.py                        # Exports all normalizers
├── base.py                            # BaseNormalizer base class
├── invoice_number_normalizer.py       # Invoice number
├── ocr_normalizer.py                  # OCR reference number
├── bankgiro_normalizer.py             # Bankgiro account
├── plusgiro_normalizer.py             # Plusgiro account
├── amount_normalizer.py               # Amount
├── date_normalizer.py                 # Date
├── organisation_number_normalizer.py  # Organisation number
├── supplier_accounts_normalizer.py    # Supplier accounts
└── customer_number_normalizer.py      # Customer number
```
## Usage
### Option 1: Via the FieldNormalizer facade (recommended)
```python
from src.normalize.normalizer import FieldNormalizer

# Normalize an invoice number
variants = FieldNormalizer.normalize_invoice_number('INV-100017500321')
# Returns: ['INV-100017500321', '100017500321']

# Normalize an amount
variants = FieldNormalizer.normalize_amount('1 234,56')
# Returns: ['1 234,56', '1234,56', '1234.56', ...]

# Normalize a date
variants = FieldNormalizer.normalize_date('2025-12-13')
# Returns: ['2025-12-13', '13/12/2025', '13.12.2025', ...]
```
### Option 2: Via the top-level function (picks the normalizer automatically)
```python
from src.normalize import normalize_field

# Automatically selects the appropriate normalizer
variants = normalize_field('InvoiceNumber', 'INV-12345')
variants = normalize_field('Amount', '1234.56')
variants = normalize_field('InvoiceDate', '2025-12-13')
```
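The implementation of `normalize_field` itself is not part of this diff. A minimal sketch of how such a dispatch could be wired, assuming a `NORMALIZERS`-style mapping; the mapping name, keys, and fallback behaviour here are illustrative, not the project's actual code:

```python
from src.normalize.normalizers import (
    InvoiceNumberNormalizer,
    AmountNormalizer,
    DateNormalizer,
)

# Hypothetical dispatch table; the real mapping in this repo may differ.
NORMALIZERS = {
    'InvoiceNumber': InvoiceNumberNormalizer(),
    'Amount': AmountNormalizer(),
    'InvoiceDate': DateNormalizer(),
}

def normalize_field(field_type: str, value: str) -> list[str]:
    """Look up the normalizer for a field type and return its variants."""
    normalizer = NORMALIZERS.get(field_type)
    if normalizer is None:
        return [value] if value else []
    return normalizer(value)  # BaseNormalizer.__call__ guards empty input
```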
### Option 3: Use the individual normalizers directly (maximum flexibility)
```python
from src.normalize.normalizers import (
    InvoiceNumberNormalizer,
    AmountNormalizer,
    DateNormalizer,
)

# Instantiate
invoice_normalizer = InvoiceNumberNormalizer()
amount_normalizer = AmountNormalizer()
date_normalizer = DateNormalizer()

# Use
variants = invoice_normalizer.normalize('INV-12345')
variants = amount_normalizer.normalize('1234.56')
variants = date_normalizer.normalize('2025-12-13')

# They can also be called directly (__call__ is supported)
variants = invoice_normalizer('INV-12345')
```
## What Each Normalizer Does
### InvoiceNumberNormalizer
- Extracts a digits-only version
- Preserves the original format
Example:
```python
'INV-100017500321' -> ['INV-100017500321', '100017500321']
```
### OCRNormalizer
- Similar to InvoiceNumberNormalizer
- Dedicated to OCR reference numbers
### BankgiroNormalizer
- Generates formats with and without separators
- Adds OCR error variants
Example:
```python
'5393-9484' -> ['5393-9484', '53939484', ...]
```
### PlusgiroNormalizer
- Generates formats with and without separators
- Adds OCR error variants
Example:
```python
'1234567-8' -> ['1234567-8', '12345678', ...]
```
### AmountNormalizer
- Handles Swedish and international formats
- Supports different thousand/decimal separators
- Handles space as decimal or thousand separator
Example:
```python
'1 234,56' -> ['1234,56', '1234.56', '1 234,56', ...]
'3045 52' -> ['3045.52', '3045,52', '304552']
```
### DateNormalizer
- Converts to ISO format (YYYY-MM-DD)
- Generates many date format variants
- Supports Swedish month names
- Handles ambiguous formats (DD/MM and MM/DD)
Example:
```python
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
```
### OrganisationNumberNormalizer
- Normalizes Swedish organisation numbers
- Generates VAT number variants
- Adds OCR error variants
Example:
```python
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
```
### SupplierAccountsNormalizer
- Handles multiple accounts (separated by |)
- Removes/adds prefixes (PG:, BG:)
- Generates different formats
Example:
```python
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3', ...]
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484', ...]
```
### CustomerNumberNormalizer
- Removes spaces and hyphens
- Generates uppercase/lowercase variants
Example:
```python
'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566', ...]
```
## BaseNormalizer
All normalizers inherit from `BaseNormalizer`:
```python
from src.normalize.normalizers.base import BaseNormalizer

class MyCustomNormalizer(BaseNormalizer):
    def normalize(self, value: str) -> list[str]:
        # Implement the normalization logic
        value = self.clean_text(value)  # Use the base class cleaning helper
        # ... generate variants
        return variants
```
## Design Principles
1. **Single responsibility**: each normalizer handles exactly one field type
2. **Independent reuse**: every module can be imported and used on its own
3. **Consistent interface**: all normalizers implement `normalize(value) -> list[str]`
4. **Backward compatible**: keeps the original `FieldNormalizer` API intact
## Tests
All normalizers are covered by tests:
```bash
# Run all tests
python -m pytest src/normalize/test_normalizer.py -v
# All 85 test cases pass ✅
```
## Adding a New Normalizer
1. Create a new file `my_field_normalizer.py` in `src/normalize/normalizers/`
2. Inherit from `BaseNormalizer` and implement the `normalize()` method
3. Export it in `__init__.py`
4. Add a static method to `FieldNormalizer` in `normalizer.py`
5. Register it in the `NORMALIZERS` dictionary
Example (a sketch of the registration steps follows below):
```python
# my_field_normalizer.py
from .base import BaseNormalizer

class MyFieldNormalizer(BaseNormalizer):
    def normalize(self, value: str) -> list[str]:
        value = self.clean_text(value)
        # ... implement the logic
        return variants
```
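Steps 3 to 5 are not covered by the example above. A hedged sketch of what the registration could look like; the exact shape of the `NORMALIZERS` dictionary and the facade method name are assumptions, so match them to the existing code in `normalizer.py`:

```python
# src/normalize/normalizers/__init__.py  (step 3: export the new class)
from .my_field_normalizer import MyFieldNormalizer

# src/normalize/normalizer.py  (steps 4 and 5; names below are illustrative)
class FieldNormalizer:
    _my_field = MyFieldNormalizer()

    @staticmethod
    def normalize_my_field(value: str) -> list[str]:
        """Delegate to MyFieldNormalizer."""
        return FieldNormalizer._my_field.normalize(value)

# The real NORMALIZERS dict may map field names to methods or instances; this is a guess.
NORMALIZERS = {
    'MyField': FieldNormalizer.normalize_my_field,
}
```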
## Benefits
- ✅ **Modular**: each field type is maintained independently
- ✅ **Reusable**: modules can be used on their own in other projects
- ✅ **Testable**: every module is tested separately
- ✅ **Extensible**: adding a new field type is straightforward
- ✅ **Backward compatible**: existing code is unaffected
- ✅ **Clear**: the code structure is easier to follow

View File

@@ -0,0 +1,28 @@
"""
Normalizer modules for different field types.
Each normalizer is responsible for generating variants of a field value
for matching against OCR text or other data sources.
"""
from .invoice_number_normalizer import InvoiceNumberNormalizer
from .ocr_normalizer import OCRNormalizer
from .bankgiro_normalizer import BankgiroNormalizer
from .plusgiro_normalizer import PlusgiroNormalizer
from .amount_normalizer import AmountNormalizer
from .date_normalizer import DateNormalizer
from .organisation_number_normalizer import OrganisationNumberNormalizer
from .supplier_accounts_normalizer import SupplierAccountsNormalizer
from .customer_number_normalizer import CustomerNumberNormalizer
__all__ = [
'InvoiceNumberNormalizer',
'OCRNormalizer',
'BankgiroNormalizer',
'PlusgiroNormalizer',
'AmountNormalizer',
'DateNormalizer',
'OrganisationNumberNormalizer',
'SupplierAccountsNormalizer',
'CustomerNumberNormalizer',
]

View File

@@ -0,0 +1,130 @@
"""
Amount Normalizer
Normalizes monetary amounts with various formats and separators.
"""
import re
from .base import BaseNormalizer
class AmountNormalizer(BaseNormalizer):
"""
Normalizes monetary amounts.
Handles Swedish and international formats with different
thousand/decimal separators.
Examples:
'114' -> ['114', '114,00', '114.00']
'114,00' -> ['114,00', '114.00', '114']
'1 234,56' -> ['1234,56', '1234.56', '1 234,56']
'3045 52' -> ['3045.52', '3045,52', '304552']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of amount."""
value = self.clean_text(value)
# Remove currency suffixes such as 'SEK', 'kr' and ':-' (alternation, not a character class)
value = re.sub(r'(?:\s*(?:SEK|kr|:-))+$', '', value, flags=re.IGNORECASE).strip()
variants = [value]
# Check for space as decimal separator: "3045 52"
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', value)
if space_decimal_match:
integer_part = space_decimal_match.group(1)
decimal_part = space_decimal_match.group(2)
variants.append(f"{integer_part}.{decimal_part}")
variants.append(f"{integer_part},{decimal_part}")
variants.append(f"{integer_part}{decimal_part}")
# Check for space as thousand separator: "10 571,00"
space_thousand_match = re.match(r'^(\d{1,3})[\s\xa0]+(\d{3})([,\.])(\d{2})$', value)
if space_thousand_match:
part1 = space_thousand_match.group(1)
part2 = space_thousand_match.group(2)
sep = space_thousand_match.group(3)
decimal = space_thousand_match.group(4)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(f"{combined}{decimal}")
other_sep = ',' if sep == '.' else '.'
variants.append(f"{part1} {part2}{other_sep}{decimal}")
# Handle US format: "1,390.00"
us_format_match = re.match(r'^(\d{1,3}),(\d{3})\.(\d{2})$', value)
if us_format_match:
part1 = us_format_match.group(1)
part2 = us_format_match.group(2)
decimal = us_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined)
variants.append(f"{part1}.{part2},{decimal}")
# Handle European format: "1.390,00"
eu_format_match = re.match(r'^(\d{1,3})\.(\d{3}),(\d{2})$', value)
if eu_format_match:
part1 = eu_format_match.group(1)
part2 = eu_format_match.group(2)
decimal = eu_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined)
variants.append(f"{part1},{part2}.{decimal}")
# Remove spaces (thousand separators)
no_space = value.replace(' ', '').replace('\xa0', '')
# Normalize decimal separator
if ',' in no_space:
dot_version = no_space.replace(',', '.')
variants.append(no_space)
variants.append(dot_version)
elif '.' in no_space:
comma_version = no_space.replace('.', ',')
variants.append(no_space)
variants.append(comma_version)
else:
# Integer amount - add decimal versions
variants.append(no_space)
variants.append(f"{no_space},00")
variants.append(f"{no_space}.00")
# Try to parse and get clean numeric value
try:
clean = no_space.replace(',', '.')
num = float(clean)
# Integer if no decimals
if num == int(num):
int_val = int(num)
variants.append(str(int_val))
variants.append(f"{int_val},00")
variants.append(f"{int_val}.00")
# European format with dot as thousand separator
if int_val >= 1000:
formatted = f"{int_val:,}".replace(',', '.')
variants.append(formatted)
variants.append(f"{formatted},00")
else:
variants.append(f"{num:.2f}")
variants.append(f"{num:.2f}".replace('.', ','))
# European format with dot as thousand separator
if num >= 1000:
formatted_str = f"{num:.2f}"
int_str, dec_str = formatted_str.split(".")
int_part = int(int_str)
formatted_int = f"{int_part:,}".replace(',', '.')
variants.append(f"{formatted_int},{dec_str}")
except ValueError:
pass
return list(set(v for v in variants if v))
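A quick usage sketch of the class above, assuming it is run from the project root in the repo's environment. The full variant list depends on which parsing branches are hit, but for a Swedish-formatted amount the comma and dot decimal forms should both be present:

```python
from src.normalize.normalizers import AmountNormalizer

normalizer = AmountNormalizer()
variants = normalizer.normalize('1 234,56')

# Order is not guaranteed: the result is built from a set.
assert '1234,56' in variants
assert '1234.56' in variants
print(sorted(variants))
```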

View File

@@ -0,0 +1,34 @@
"""
Bankgiro Number Normalizer
Normalizes Swedish Bankgiro account numbers.
"""
from .base import BaseNormalizer
from src.utils.format_variants import FormatVariants
from src.utils.text_cleaner import TextCleaner
class BankgiroNormalizer(BaseNormalizer):
"""
Normalizes Bankgiro numbers.
Generates format variants and OCR error variants.
Examples:
'5393-9484' -> ['5393-9484', '53939484', ...]
'53939484' -> ['53939484', '5393-9484', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of Bankgiro number."""
# Use shared module for base variants
variants = set(FormatVariants.bankgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)

View File

@@ -0,0 +1,34 @@
"""
Base class for field normalizers.
"""
from abc import ABC, abstractmethod
from src.utils.text_cleaner import TextCleaner
class BaseNormalizer(ABC):
"""Base class for all field normalizers."""
@staticmethod
def clean_text(text: str) -> str:
"""Clean text using shared TextCleaner."""
return TextCleaner.clean_text(text)
@abstractmethod
def normalize(self, value: str) -> list[str]:
"""
Normalize a field value and return all variants.
Args:
value: Raw field value
Returns:
List of normalized variants for matching
"""
pass
def __call__(self, value: str) -> list[str]:
"""Allow normalizer to be called as a function."""
if value is None or (isinstance(value, str) and not value.strip()):
return []
return self.normalize(str(value))
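A minimal sketch of how the base class is meant to be extended; `UppercaseNormalizer` is a made-up example, not part of the project:

```python
from src.normalize.normalizers.base import BaseNormalizer

class UppercaseNormalizer(BaseNormalizer):
    """Toy normalizer: original value plus its uppercase variant."""

    def normalize(self, value: str) -> list[str]:
        value = self.clean_text(value)
        return list({value, value.upper()})

norm = UppercaseNormalizer()
print(norm('abc-123'))  # e.g. ['abc-123', 'ABC-123']
print(norm(None))       # [] - __call__ guards None input
print(norm('   '))      # [] - whitespace-only is treated as empty
```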

View File

@@ -0,0 +1,49 @@
"""
Customer Number Normalizer
Normalizes customer numbers (alphanumeric codes).
"""
from .base import BaseNormalizer
class CustomerNumberNormalizer(BaseNormalizer):
"""
Normalizes customer numbers.
Customer numbers can have various formats:
- Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
- Pure numbers: '12345', '123-456'
Examples:
'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
'ABC 123' -> ['ABC 123', 'ABC123']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of customer number."""
value = self.clean_text(value)
variants = [value]
# Version without spaces
no_space = value.replace(' ', '')
if no_space != value:
variants.append(no_space)
# Version without dashes
no_dash = value.replace('-', '')
if no_dash != value:
variants.append(no_dash)
# Version without spaces and dashes
clean = value.replace(' ', '').replace('-', '')
if clean != value and clean not in variants:
variants.append(clean)
# Uppercase and lowercase versions
if value.upper() != value:
variants.append(value.upper())
if value.lower() != value:
variants.append(value.lower())
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,190 @@
"""
Date Normalizer
Normalizes dates in various formats to ISO and generates variants.
"""
import re
from datetime import datetime
from .base import BaseNormalizer
class DateNormalizer(BaseNormalizer):
"""
Normalizes dates to YYYY-MM-DD and generates variants.
Handles Swedish and international date formats.
Examples:
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
"""
# Swedish month names
SWEDISH_MONTHS = {
'januari': '01', 'jan': '01',
'februari': '02', 'feb': '02',
'mars': '03', 'mar': '03',
'april': '04', 'apr': '04',
'maj': '05',
'juni': '06', 'jun': '06',
'juli': '07', 'jul': '07',
'augusti': '08', 'aug': '08',
'september': '09', 'sep': '09', 'sept': '09',
'oktober': '10', 'okt': '10',
'november': '11', 'nov': '11',
'december': '12', 'dec': '12'
}
def normalize(self, value: str) -> list[str]:
"""Generate variants of date."""
value = self.clean_text(value)
variants = [value]
parsed_dates = []
# Try unambiguous patterns first
date_patterns = [
# ISO format with optional time
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$',
lambda m: (int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$',
lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$',
lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
parsed_dates.append(datetime(year, month, day))
break
except ValueError:
continue
# Try ambiguous patterns with 4-digit year
ambiguous_patterns_4digit = [
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
if not parsed_dates:
for pattern in ambiguous_patterns_4digit:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
# Try DD/MM/YYYY (European - day first)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YYYY (US - month first) if different
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try ambiguous patterns with 2-digit year
ambiguous_patterns_2digit = [
r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
]
if not parsed_dates:
for pattern in ambiguous_patterns_2digit:
match = re.match(pattern, value)
if match:
n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
year = 2000 + yy if yy < 50 else 1900 + yy
# Try DD/MM/YY (European)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YY (US) if different
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try Swedish month names
if not parsed_dates:
for month_name, month_num in self.SWEDISH_MONTHS.items():
if month_name in value.lower():
numbers = re.findall(r'\d+', value)
if len(numbers) >= 2:
day = int(numbers[0])
year = int(numbers[-1])
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
parsed_dates.append(datetime(year, int(month_num), day))
break
except ValueError:
continue
# Generate variants for all parsed dates
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
for parsed_date in parsed_dates:
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
us_slash = parsed_date.strftime('%m/%d/%Y')
eu_dot = parsed_date.strftime('%d.%m.%Y')
iso_dot = parsed_date.strftime('%Y.%m.%d')
compact = parsed_date.strftime('%Y%m%d')
compact_short = parsed_date.strftime('%y%m%d')
eu_dot_short = parsed_date.strftime('%d.%m.%y')
eu_slash_short = parsed_date.strftime('%d/%m/%y')
yy_mm_dd_short = parsed_date.strftime('%y-%m-%d')
iso_middot = parsed_date.strftime('%Y·%m·%d')  # middle-dot separator (OCR sometimes reads hyphens as middle dots)
spaced_full = parsed_date.strftime('%Y %m %d')
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
month_abbrev_upper = month_abbrev.upper()
swedish_hyphen_short = f"{parsed_date.day:02d}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_hyphen_short_lower = f"{parsed_date.day:02d}-{month_abbrev}-{parsed_date.strftime('%y')}"
swedish_hyphen_short_no_zero = f"{parsed_date.day}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
month_year_only = f"{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_spaced = f"{parsed_date.day:02d} {month_abbrev_upper} {parsed_date.strftime('%y')}"
variants.extend([
iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
eu_dot_short, eu_slash_short, yy_mm_dd_short, iso_middot, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev,
swedish_hyphen_short, swedish_hyphen_short_lower, swedish_hyphen_short_no_zero,
month_year_only, swedish_spaced
])
return list(set(v for v in variants if v))
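A usage sketch for the class above; for an ambiguous day/month combination both interpretations are kept, so both ISO forms appear among the variants:

```python
from src.normalize.normalizers import DateNormalizer

normalizer = DateNormalizer()
variants = normalizer.normalize('03/04/2025')

# Ambiguous DD/MM vs MM/DD: both readings are generated.
assert '2025-04-03' in variants  # 3 April (European reading)
assert '2025-03-04' in variants  # March 4 (US reading)

# Swedish month names are parsed as well.
assert '2025-12-13' in DateNormalizer()('13 december 2025')
```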

View File

@@ -0,0 +1,31 @@
"""
Invoice Number Normalizer
Normalizes invoice numbers for matching.
"""
import re
from .base import BaseNormalizer
class InvoiceNumberNormalizer(BaseNormalizer):
"""
Normalizes invoice numbers.
Keeps only digits for matching while preserving original format.
Examples:
'100017500321' -> ['100017500321']
'INV-100017500321' -> ['100017500321', 'INV-100017500321']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of invoice number."""
value = self.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,31 @@
"""
OCR Number Normalizer
Normalizes OCR reference numbers (Swedish payment system).
"""
import re
from .base import BaseNormalizer
class OCRNormalizer(BaseNormalizer):
"""
Normalizes OCR reference numbers.
Similar to invoice number - primarily digits.
Examples:
'94228110015950070' -> ['94228110015950070']
'OCR: 94228110015950070' -> ['94228110015950070', 'OCR: 94228110015950070']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of OCR number."""
value = self.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,39 @@
"""
Organisation Number Normalizer
Normalizes Swedish organisation numbers and VAT numbers.
"""
from .base import BaseNormalizer
from src.utils.format_variants import FormatVariants
from src.utils.text_cleaner import TextCleaner
class OrganisationNumberNormalizer(BaseNormalizer):
"""
Normalizes Swedish organisation numbers and VAT numbers.
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
Swedish VAT format: SE + org_number (10 digits) + 01
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of organisation number."""
# Use shared module for base variants
variants = set(FormatVariants.organisation_number_variants(value))
# Add OCR error variants for digit sequences
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits and len(digits) >= 10:
# Generate variants where OCR might have misread characters
for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
variants.add(ocr_var)
if len(ocr_var) == 10:
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
return list(v for v in variants if v)
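A usage sketch for the class above. The concrete variants come from `FormatVariants.organisation_number_variants`, which is not shown in this diff, so the expectations in the comment are taken from the docstring rather than from verified output:

```python
from src.normalize.normalizers import OrganisationNumberNormalizer

normalizer = OrganisationNumberNormalizer()
variants = normalizer.normalize('556123-4567')

# Per the docstring, the variants should include the dashed form,
# the 10-digit form and the SE...01 VAT form, plus OCR error variants.
print(sorted(variants))
```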

View File

@@ -0,0 +1,34 @@
"""
Plusgiro Number Normalizer
Normalizes Swedish Plusgiro account numbers.
"""
from .base import BaseNormalizer
from src.utils.format_variants import FormatVariants
from src.utils.text_cleaner import TextCleaner
class PlusgiroNormalizer(BaseNormalizer):
"""
Normalizes Plusgiro numbers.
Generates format variants and OCR error variants.
Examples:
'1234567-8' -> ['1234567-8', '12345678', ...]
'12345678' -> ['12345678', '1234567-8', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of Plusgiro number."""
# Use shared module for base variants
variants = set(FormatVariants.plusgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)

View File

@@ -0,0 +1,75 @@
"""
Supplier Accounts Normalizer
Normalizes supplier account numbers (Bankgiro/Plusgiro).
"""
import re
from .base import BaseNormalizer
class SupplierAccountsNormalizer(BaseNormalizer):
"""
Normalizes supplier accounts field.
The field may contain multiple accounts separated by ' | '.
Format examples:
'PG:48676043 | PG:49128028 | PG:8915035'
'BG:5393-9484'
Each account is normalized separately to generate variants.
Examples:
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of supplier accounts."""
value = self.clean_text(value)
variants = []
# Split by ' | ' to handle multiple accounts
accounts = [acc.strip() for acc in value.split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Add original value
variants.append(account)
# Remove prefix (PG:, BG:, etc.)
if ':' in account:
prefix, number = account.split(':', 1)
number = number.strip()
variants.append(number) # Just the number without prefix
# Also add with different prefix formats
prefix_upper = prefix.strip().upper()
variants.append(f"{prefix_upper}:{number}")
variants.append(f"{prefix_upper}: {number}") # With space
else:
number = account
# Extract digits only
digits_only = re.sub(r'\D', '', number)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also try 4-4 format for bankgiro
variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
elif len(digits_only) == 10:
# 6-4 format (like org number)
variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
return list(set(v for v in variants if v))
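A usage sketch for the class above, showing how a multi-account value is split on `|` and expanded per account:

```python
from src.normalize.normalizers import SupplierAccountsNormalizer

normalizer = SupplierAccountsNormalizer()
variants = normalizer.normalize('PG:48676043 | BG:5393-9484')

assert 'PG:48676043' in variants  # original token
assert '48676043' in variants     # prefix stripped
assert '4867604-3' in variants    # Plusgiro-style check-digit dash
assert '53939484' in variants     # digits-only Bankgiro
assert '5393-9484' in variants    # Bankgiro 4-4 format
```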

View File

@@ -178,6 +178,93 @@ class MachineCodeParser:
""" """
self.bottom_region_ratio = bottom_region_ratio self.bottom_region_ratio = bottom_region_ratio
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
"""
Detect account type keywords in context.
Returns:
Dict with 'bankgiro' and 'plusgiro' boolean flags
"""
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
def _normalize_account_spaces(self, line: str) -> str:
"""
Remove spaces in account number portion after > marker.
Args:
line: Payment line text
Returns:
Line with normalized account number spacing
"""
if '>' not in line:
return line
parts = line.split('>', 1)
# After >, remove spaces between digits (but keep # markers)
after_arrow = parts[1]
# Extract digits and # markers, remove spaces between digits
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
# May need multiple passes for sequences like "78 2 1 713"
while re.search(r'(\d)\s+(\d)', normalized):
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
return parts[0] + '>' + normalized
def _format_account(
self,
account_digits: str,
is_plusgiro_context: bool
) -> tuple[str, str]:
"""
Format account number and determine type (bankgiro or plusgiro).
Uses context keywords first, then falls back to Luhn validation
to determine the most likely account type.
Args:
account_digits: Raw digits of account number
is_plusgiro_context: Whether context indicates Plusgiro
Returns:
Tuple of (formatted_account, account_type)
"""
if is_plusgiro_context:
# Context explicitly indicates Plusgiro
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# No explicit context - use Luhn validation to determine type
# Try both formats and see which passes Luhn check
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
if len(account_digits) == 7:
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
elif len(account_digits) == 8:
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
else:
bg_formatted = account_digits
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# Decision logic:
# 1. If only one format passes Luhn, use that
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
# Both valid or both invalid - default to bankgiro
return bg_formatted, 'bankgiro'
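The Bankgiro/Plusgiro disambiguation above leans on Luhn (mod-10) check-digit validation via `FieldValidators`, whose internals are not part of this diff. A standalone sketch of the standard Luhn check for reference, under the assumption that the project's validators follow the same mod-10 rule:

```python
def luhn_valid(digits: str) -> bool:
    """Standard mod-10 check: double every second digit from the right."""
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

# Digits-only form of the Bankgiro example used elsewhere in this repo; prints True.
print(luhn_valid('53939484'))
```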
def parse( def parse(
self, self,
tokens: list[TextToken], tokens: list[TextToken],
@@ -465,62 +552,7 @@ class MachineCodeParser:
) )
# Preprocess: remove spaces in the account number part (after >) # Preprocess: remove spaces in the account number part (after >)
# This handles cases like "78 2 1 713" -> "7821713" raw_line = self._normalize_account_spaces(raw_line)
def normalize_account_spaces(line: str) -> str:
"""Remove spaces in account number portion after > marker."""
if '>' in line:
parts = line.split('>', 1)
# After >, remove spaces between digits (but keep # markers)
after_arrow = parts[1]
# Extract digits and # markers, remove spaces between digits
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
# May need multiple passes for sequences like "78 2 1 713"
while re.search(r'(\d)\s+(\d)', normalized):
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
return parts[0] + '>' + normalized
return line
raw_line = normalize_account_spaces(raw_line)
def format_account(account_digits: str) -> tuple[str, str]:
"""Format account and determine type (bankgiro or plusgiro).
Uses context keywords first, then falls back to Luhn validation
to determine the most likely account type.
Returns: (formatted_account, account_type)
"""
if is_plusgiro_context:
# Context explicitly indicates Plusgiro
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# No explicit context - use Luhn validation to determine type
# Try both formats and see which passes Luhn check
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
if len(account_digits) == 7:
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
elif len(account_digits) == 8:
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
else:
bg_formatted = account_digits
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# Decision logic:
# 1. If only one format passes Luhn, use that
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
# Both valid or both invalid - default to bankgiro
return bg_formatted, 'bankgiro'
# Try primary pattern # Try primary pattern
match = self.PAYMENT_LINE_PATTERN.search(raw_line) match = self.PAYMENT_LINE_PATTERN.search(raw_line)
@@ -533,7 +565,7 @@ class MachineCodeParser:
# Format amount: combine kronor and öre # Format amount: combine kronor and öre
amount = f"{kronor},{ore}" if ore != "00" else kronor amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = format_account(account_digits) formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return { return {
'ocr': ocr, 'ocr': ocr,
@@ -551,7 +583,7 @@ class MachineCodeParser:
amount = f"{kronor},{ore}" if ore != "00" else kronor amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = format_account(account_digits) formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return { return {
'ocr': ocr, 'ocr': ocr,
@@ -569,7 +601,7 @@ class MachineCodeParser:
amount = f"{kronor},{ore}" if ore != "00" else kronor amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = format_account(account_digits) formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return { return {
'ocr': ocr, 'ocr': ocr,
@@ -637,16 +669,10 @@ class MachineCodeParser:
NOT Plusgiro: XXXXXXX-X (dash before last digit) NOT Plusgiro: XXXXXXX-X (dash before last digit)
""" """
candidates = [] candidates = []
context_text = ' '.join(t.text.lower() for t in tokens) context = self._detect_account_context(tokens)
# Check if this is clearly a Plusgiro context (not Bankgiro) # If clearly Plusgiro context (and not bankgiro), don't extract as Bankgiro
is_plusgiro_only_context = ( if context['plusgiro'] and not context['bankgiro']:
('plusgiro' in context_text or 'postgiro' in context_text or 'plusgirokonto' in context_text)
and 'bankgiro' not in context_text
)
# If clearly Plusgiro context, don't extract as Bankgiro
if is_plusgiro_only_context:
return None return None
for token in tokens: for token in tokens:
@@ -672,14 +698,7 @@ class MachineCodeParser:
else: else:
continue continue
# Check if "bankgiro" or "bg" appears nearby candidates.append((normalized, context['bankgiro'], token))
is_bankgiro_context = (
'bankgiro' in context_text or
'bg:' in context_text or
'bg ' in context_text
)
candidates.append((normalized, is_bankgiro_context, token))
if not candidates: if not candidates:
return None return None
@@ -691,6 +710,7 @@ class MachineCodeParser:
def _extract_plusgiro(self, tokens: list[TextToken]) -> Optional[str]: def _extract_plusgiro(self, tokens: list[TextToken]) -> Optional[str]:
"""Extract Plusgiro account number.""" """Extract Plusgiro account number."""
candidates = [] candidates = []
context = self._detect_account_context(tokens)
for token in tokens: for token in tokens:
text = token.text.strip() text = token.text.strip()
@@ -701,17 +721,7 @@ class MachineCodeParser:
digits = re.sub(r'\D', '', match) digits = re.sub(r'\D', '', match)
if 7 <= len(digits) <= 8: if 7 <= len(digits) <= 8:
normalized = f"{digits[:-1]}-{digits[-1]}" normalized = f"{digits[:-1]}-{digits[-1]}"
candidates.append((normalized, context['plusgiro'], token))
# Check context
context_text = ' '.join(t.text.lower() for t in tokens)
is_plusgiro_context = (
'plusgiro' in context_text or
'postgiro' in context_text or
'pg:' in context_text or
'pg ' in context_text
)
candidates.append((normalized, is_plusgiro_context, token))
if not candidates: if not candidates:
return None return None

View File

@@ -1,251 +0,0 @@
"""
Tests for Machine Code Parser
Tests the parsing of Swedish invoice payment lines including:
- Standard payment line format
- Account number normalization (spaces removal)
- Bankgiro/Plusgiro detection
- OCR and Amount extraction
"""
import pytest
from src.ocr.machine_code_parser import MachineCodeParser, MachineCodeResult
class TestParseStandardPaymentLine:
"""Tests for _parse_standard_payment_line method."""
@pytest.fixture
def parser(self):
return MachineCodeParser()
def test_standard_format_bankgiro(self, parser):
"""Test standard payment line with Bankgiro."""
line = "# 31130954410 # 315 00 2 > 8983025#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '31130954410'
assert result['amount'] == '315'
assert result['bankgiro'] == '898-3025'
def test_standard_format_with_ore(self, parser):
"""Test payment line with non-zero öre."""
line = "# 12345678901 # 100 50 2 > 7821713#41#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '12345678901'
assert result['amount'] == '100,50'
assert result['bankgiro'] == '782-1713'
def test_spaces_in_bankgiro(self, parser):
"""Test payment line with spaces in Bankgiro number."""
line = "# 310196187399952 # 11699 00 6 > 78 2 1 713 #41#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '310196187399952'
assert result['amount'] == '11699'
assert result['bankgiro'] == '782-1713'
def test_spaces_in_bankgiro_multiple(self, parser):
"""Test payment line with multiple spaces in account number."""
line = "# 123456789 # 500 00 1 > 1 2 3 4 5 6 7 #99#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['bankgiro'] == '123-4567'
def test_8_digit_bankgiro(self, parser):
"""Test 8-digit Bankgiro formatting."""
line = "# 12345678901 # 200 00 2 > 53939484#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['bankgiro'] == '5393-9484'
def test_plusgiro_context(self, parser):
"""Test Plusgiro detection based on context."""
line = "# 12345678901 # 100 00 2 > 1234567#14#"
result = parser._parse_standard_payment_line(line, context_line="plusgiro payment")
assert result is not None
assert 'plusgiro' in result
assert result['plusgiro'] == '123456-7'
def test_no_match_invalid_format(self, parser):
"""Test that invalid format returns None."""
line = "This is not a valid payment line"
result = parser._parse_standard_payment_line(line)
assert result is None
def test_alternative_pattern(self, parser):
"""Test alternative payment line pattern."""
line = "8120000849965361 11699 00 1 > 7821713"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '8120000849965361'
def test_long_ocr_number(self, parser):
"""Test OCR number up to 25 digits."""
line = "# 1234567890123456789012345 # 100 00 2 > 7821713#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '1234567890123456789012345'
def test_large_amount(self, parser):
"""Test large amount extraction."""
line = "# 12345678901 # 1234567 00 2 > 7821713#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['amount'] == '1234567'
class TestNormalizeAccountSpaces:
"""Tests for account number space normalization."""
@pytest.fixture
def parser(self):
return MachineCodeParser()
def test_no_spaces(self, parser):
"""Test line without spaces in account."""
line = "# 123456789 # 100 00 1 > 7821713#14#"
result = parser._parse_standard_payment_line(line)
assert result['bankgiro'] == '782-1713'
def test_single_space(self, parser):
"""Test single space between digits."""
line = "# 123456789 # 100 00 1 > 782 1713#14#"
result = parser._parse_standard_payment_line(line)
assert result['bankgiro'] == '782-1713'
def test_multiple_spaces(self, parser):
"""Test multiple spaces."""
line = "# 123456789 # 100 00 1 > 7 8 2 1 7 1 3#14#"
result = parser._parse_standard_payment_line(line)
assert result['bankgiro'] == '782-1713'
def test_no_arrow_marker(self, parser):
"""Test line without > marker - spaces not normalized."""
# Without >, the normalization won't happen
line = "# 123456789 # 100 00 1 7821713#14#"
result = parser._parse_standard_payment_line(line)
# This pattern might not match due to missing >
# Just ensure no crash
assert result is None or isinstance(result, dict)
class TestMachineCodeResult:
"""Tests for MachineCodeResult dataclass."""
def test_to_dict(self):
"""Test conversion to dictionary."""
result = MachineCodeResult(
ocr='12345678901',
amount='100',
bankgiro='782-1713',
confidence=0.95,
raw_line='test line'
)
d = result.to_dict()
assert d['ocr'] == '12345678901'
assert d['amount'] == '100'
assert d['bankgiro'] == '782-1713'
assert d['confidence'] == 0.95
assert d['raw_line'] == 'test line'
def test_empty_result(self):
"""Test empty result."""
result = MachineCodeResult()
d = result.to_dict()
assert d['ocr'] is None
assert d['amount'] is None
assert d['bankgiro'] is None
assert d['plusgiro'] is None
class TestRealWorldExamples:
"""Tests using real-world payment line examples."""
@pytest.fixture
def parser(self):
return MachineCodeParser()
def test_fastum_invoice(self, parser):
"""Test Fastum invoice payment line (from Faktura_A3861)."""
line = "# 310196187399952 # 11699 00 6 > 78 2 1 713 #41#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '310196187399952'
assert result['amount'] == '11699'
assert result['bankgiro'] == '782-1713'
def test_standard_bankgiro_invoice(self, parser):
"""Test standard Bankgiro format."""
line = "# 31130954410 # 315 00 2 > 8983025#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '31130954410'
assert result['amount'] == '315'
assert result['bankgiro'] == '898-3025'
def test_payment_line_with_extra_whitespace(self, parser):
"""Test payment line with extra whitespace."""
line = "# 310196187399952 # 11699 00 6 > 7821713 #41#"
result = parser._parse_standard_payment_line(line)
# May or may not match depending on regex flexibility
# At minimum, should not crash
assert result is None or isinstance(result, dict)
class TestEdgeCases:
"""Tests for edge cases and boundary conditions."""
@pytest.fixture
def parser(self):
return MachineCodeParser()
def test_empty_string(self, parser):
"""Test empty string input."""
result = parser._parse_standard_payment_line("")
assert result is None
def test_only_whitespace(self, parser):
"""Test whitespace-only input."""
result = parser._parse_standard_payment_line(" \t\n ")
assert result is None
def test_minimum_ocr_length(self, parser):
"""Test minimum OCR length (5 digits)."""
line = "# 12345 # 100 00 1 > 7821713#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '12345'
def test_minimum_bankgiro_length(self, parser):
"""Test minimum Bankgiro length (5 digits)."""
line = "# 12345678901 # 100 00 1 > 12345#14#"
result = parser._parse_standard_payment_line(line)
assert result is not None
def test_special_characters_in_line(self, parser):
"""Test handling of special characters."""
line = "# 12345678901 # 100 00 1 > 7821713#14# (SEK)"
result = parser._parse_standard_payment_line(line)
assert result is not None
assert result['ocr'] == '12345678901'
if __name__ == '__main__':
pytest.main([__file__, '-v'])

start_web.sh Normal file (5 lines)
View File

@@ -0,0 +1,5 @@
#!/bin/bash
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
source ~/miniconda3/etc/profile.d/conda.sh
conda activate invoice-py311
python run_server.py --port 8000

tests/README.md Normal file (299 lines)
View File

@@ -0,0 +1,299 @@
# Tests
A complete test suite, organised according to pytest best practices.
## 📁 Test Directory Structure
```
tests/
├── __init__.py
├── README.md                          # This file
├── data/                              # Data module tests
│   ├── __init__.py
│   └── test_csv_loader.py             # CSV loader tests
├── inference/                         # Inference module tests
│   ├── __init__.py
│   ├── test_field_extractor.py        # Field extractor tests
│   └── test_pipeline.py               # Inference pipeline tests
├── matcher/                           # Matcher module tests
│   ├── __init__.py
│   └── test_field_matcher.py          # Field matcher tests
├── normalize/                         # Normalization module tests
│   ├── __init__.py
│   ├── test_normalizer.py             # FieldNormalizer tests (85 tests)
│   └── normalizers/                   # Individual normalizer tests
│       ├── __init__.py
│       ├── test_invoice_number_normalizer.py       # 12 tests
│       ├── test_ocr_normalizer.py                  # 9 tests
│       ├── test_bankgiro_normalizer.py             # 11 tests
│       ├── test_plusgiro_normalizer.py             # 10 tests
│       ├── test_amount_normalizer.py               # 15 tests
│       ├── test_date_normalizer.py                 # 19 tests
│       ├── test_organisation_number_normalizer.py  # 11 tests
│       ├── test_supplier_accounts_normalizer.py    # 13 tests
│       ├── test_customer_number_normalizer.py      # 12 tests
│       └── README.md                               # Normalizer test docs
├── ocr/                               # OCR module tests
│   ├── __init__.py
│   └── test_machine_code_parser.py    # Machine code parser tests
├── pdf/                               # PDF module tests
│   ├── __init__.py
│   ├── test_detector.py               # PDF type detector tests
│   └── test_extractor.py              # PDF extractor tests
├── utils/                             # Utility module tests
│   ├── __init__.py
│   ├── test_utils.py                  # Basic utility tests
│   └── test_advanced_utils.py         # Advanced utility tests
├── test_config.py                     # Configuration tests
├── test_customer_number_parser.py     # Customer number parser tests
├── test_db_security.py                # Database security tests
├── test_exceptions.py                 # Exception tests
└── test_payment_line_parser.py        # Payment line parser tests
```
## 📊 Test Statistics
**Total tests**: 628
**Status**: ✅ all passing
**Run time**: ~7.7 s
**Code coverage**: 37% (overall)
### By module
| Module | Test files | Tests | Coverage |
|------|-----------|---------|--------|
| **normalize** | 10 | 197 | ~98% |
| - normalizers/ | 9 | 112 | 100% |
| - test_normalizer.py | 1 | 85 | 71% |
| **utils** | 2 | ~149 | 73-93% |
| **pdf** | 2 | ~282 | 94-97% |
| **matcher** | 1 | ~402 | - |
| **ocr** | 1 | ~146 | 25% |
| **inference** | 2 | ~408 | - |
| **data** | 1 | ~282 | - |
| **Other** | 4 | ~110 | - |
## 🚀 Running Tests
### Run all tests
```bash
# In the WSL environment
conda activate invoice-py311
pytest tests/ -v
```
### Run tests for a specific module
```bash
# Normalizer tests
pytest tests/normalize/ -v
# Individual normalizer tests
pytest tests/normalize/normalizers/ -v
# PDF tests
pytest tests/pdf/ -v
# Utils tests
pytest tests/utils/ -v
# Inference tests
pytest tests/inference/ -v
```
### Run a single test file
```bash
pytest tests/normalize/normalizers/test_amount_normalizer.py -v
pytest tests/pdf/test_extractor.py -v
pytest tests/utils/test_utils.py -v
```
### Check test coverage
```bash
# Generate a coverage report
pytest tests/ --cov=src --cov-report=html
# Coverage for a single module only
pytest tests/normalize/ --cov=src/normalize --cov-report=term-missing
```
### Run specific tests
```bash
# By test class
pytest tests/normalize/normalizers/test_amount_normalizer.py::TestAmountNormalizer -v
# By test method
pytest tests/normalize/normalizers/test_amount_normalizer.py::TestAmountNormalizer::test_integer_amount -v
# By keyword
pytest tests/ -k "normalizer" -v
pytest tests/ -k "amount" -v
```
## 🎯 Testing Best Practices
### 1. Directory structure mirrors the source code
The test tree mirrors the `src/` directory:
```
src/normalize/normalizers/amount_normalizer.py
tests/normalize/normalizers/test_amount_normalizer.py
```
### 2. Test file naming
- Test files start with `test_`
- Test classes start with `Test`
- Test methods start with `test_`
### 3. Use pytest fixtures
```python
import pytest
from src.normalize.normalizers import AmountNormalizer

@pytest.fixture
def normalizer():
    """Create normalizer instance for testing"""
    return AmountNormalizer()

def test_something(normalizer):
    result = normalizer.normalize('test')
    assert 'expected' in result
```
### 4. Clear test descriptions
```python
def test_with_comma_decimal(self, normalizer):
    """Amount with comma decimal should generate dot variant"""
    result = normalizer.normalize('114,00')
    assert '114.00' in result
```
### 5. Arrange-Act-Assert pattern
```python
def test_example(self):
# Arrange
input_data = 'test-input'
expected = 'expected-output'
# Act
result = process(input_data)
# Assert
assert expected in result
```
## 📝 添加新测试
### 为新功能添加测试
1. 在相应的 `tests/` 子目录创建测试文件
2. 遵循命名约定: `test_<module_name>.py`
3. 创建测试类和方法
4. 运行测试验证
示例:
```python
# tests/new_module/test_new_feature.py
import pytest
from src.new_module.new_feature import NewFeature
class TestNewFeature:
"""Test NewFeature functionality"""
@pytest.fixture
def feature(self):
"""Create feature instance for testing"""
return NewFeature()
def test_basic_functionality(self, feature):
"""Test basic functionality"""
result = feature.process('input')
assert result == 'expected'
def test_edge_case(self, feature):
"""Test edge case handling"""
result = feature.process('')
assert result == []
```
## 🔧 pytest 配置
项目的 pytest 配置在 `pyproject.toml`:
```toml
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
```
## 📈 持续集成
测试可以轻松集成到 CI/CD:
```yaml
# .github/workflows/test.yml
- name: Run Tests
run: |
conda activate invoice-py311
pytest tests/ -v --cov=src --cov-report=xml
- name: Upload Coverage
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
```
## 🎨 测试覆盖率目标
| 模块 | 当前覆盖率 | 目标 |
|------|-----------|------|
| normalize/ | 98% | ✅ 达标 |
| utils/ | 73-93% | 🎯 提升到 90% |
| pdf/ | 94-97% | ✅ 达标 |
| inference/ | 待评估 | 🎯 80% |
| matcher/ | 待评估 | 🎯 80% |
| ocr/ | 25% | 🎯 提升到 70% |
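If these targets are to be enforced locally, a minimal gate could look like the commands below (the 80% floor is an assumed example, not current project configuration):
```bash
# Fail the run if overall coverage drops below an assumed 80% floor
pytest tests/ --cov=src --cov-fail-under=80
# Check a single module against its own target (ocr/ shown as an example)
pytest tests/ocr/ --cov=src/ocr --cov-report=term-missing
```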
## 📚 Related Documentation
- [Normalizer Tests](normalize/normalizers/README.md) - Detailed documentation for the per-normalizer tests
- [pytest Documentation](https://docs.pytest.org/) - Official pytest documentation
- [Code Coverage](https://coverage.readthedocs.io/) - Coverage tool documentation
## ✅ Test Checklist
When adding a new feature, make sure to (a single-command check is sketched after this list):
- [ ] Create the corresponding test file
- [ ] Test normal functionality
- [ ] Test boundary conditions (empty values, None, empty strings)
- [ ] Test error handling
- [ ] Keep test coverage > 80%
- [ ] Make sure all tests pass
- [ ] Update the related documentation
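Most of the checklist can be exercised with one run before committing (a sketch based on the commands above):
```bash
conda activate invoice-py311
pytest tests/ -v --cov=src --cov-report=term-missing
```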
## 🎉 Summary
- ✅ **628 tests**, all passing
- ✅ A clear directory structure that **mirrors the source code**
- ✅ **Follows pytest best practices**
- ✅ **Complete documentation**
- ✅ **Easy to maintain and extend**

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test suite for invoice-master-poc-v2"""

0
tests/data/__init__.py Normal file
View File

View File

@@ -0,0 +1 @@
# Strategy tests

View File

@@ -0,0 +1,69 @@
"""
Tests for ExactMatcher strategy
Usage:
pytest tests/matcher/strategies/test_exact_matcher.py -v
"""
import pytest
from dataclasses import dataclass
from src.matcher.strategies.exact_matcher import ExactMatcher
@dataclass
class MockToken:
"""Mock token for testing"""
text: str
bbox: tuple[float, float, float, float]
page_no: int = 0
class TestExactMatcher:
"""Test ExactMatcher functionality"""
@pytest.fixture
def matcher(self):
"""Create matcher instance for testing"""
return ExactMatcher(context_radius=200.0)
def test_exact_match(self, matcher):
"""Exact text match should score 1.0"""
tokens = [
MockToken('100017500321', (100, 100, 200, 120)),
]
matches = matcher.find_matches(tokens, '100017500321', 'InvoiceNumber')
assert len(matches) == 1
assert matches[0].score == 1.0
assert matches[0].matched_text == '100017500321'
def test_case_insensitive_match(self, matcher):
"""Case-insensitive match should score 0.9 (digits-only for numeric fields)"""
tokens = [
MockToken('INV-12345', (100, 100, 200, 120)),
]
matches = matcher.find_matches(tokens, 'inv-12345', 'InvoiceNumber')
assert len(matches) == 1
# Without token_index, case-insensitive falls through to digits-only match
assert matches[0].score == 0.9
def test_digits_only_match(self, matcher):
"""Digits-only match for numeric fields should score 0.9"""
tokens = [
MockToken('INV-12345', (100, 100, 200, 120)),
]
matches = matcher.find_matches(tokens, '12345', 'InvoiceNumber')
assert len(matches) == 1
assert matches[0].score == 0.9
def test_no_match(self, matcher):
"""Non-matching value should return empty list"""
tokens = [
MockToken('100017500321', (100, 100, 200, 120)),
]
matches = matcher.find_matches(tokens, '999999', 'InvoiceNumber')
assert len(matches) == 0
def test_empty_tokens(self, matcher):
"""Empty token list should return empty matches"""
matches = matcher.find_matches([], '100017500321', 'InvoiceNumber')
assert len(matches) == 0

View File

@@ -9,13 +9,16 @@ Usage:
 import pytest
 from dataclasses import dataclass
-from src.matcher.field_matcher import (
-    FieldMatcher,
-    Match,
-    TokenIndex,
-    CONTEXT_KEYWORDS,
-    _normalize_dashes,
-    find_field_matches,
-)
+from src.matcher.field_matcher import FieldMatcher, find_field_matches
+from src.matcher.models import Match
+from src.matcher.token_index import TokenIndex
+from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords
+from src.matcher import utils as matcher_utils
+from src.matcher.utils import normalize_dashes as _normalize_dashes
+from src.matcher.strategies import (
+    SubstringMatcher,
+    FlexibleDateMatcher,
+    FuzzyMatcher,
+)
@@ -326,94 +329,82 @@ class TestFieldMatcherFuzzyMatch:
 class TestFieldMatcherParseAmount:
-    """Tests for _parse_amount method."""
+    """Tests for parse_amount function."""
     def test_parse_simple_integer(self):
         """Should parse simple integer."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("100") == 100.0
+        assert matcher_utils.parse_amount("100") == 100.0
     def test_parse_decimal_with_dot(self):
         """Should parse decimal with dot."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("100.50") == 100.50
+        assert matcher_utils.parse_amount("100.50") == 100.50
     def test_parse_decimal_with_comma(self):
         """Should parse decimal with comma (European format)."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("100,50") == 100.50
+        assert matcher_utils.parse_amount("100,50") == 100.50
     def test_parse_with_thousand_separator(self):
         """Should parse with thousand separator."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("1 234,56") == 1234.56
+        assert matcher_utils.parse_amount("1 234,56") == 1234.56
     def test_parse_with_currency_suffix(self):
         """Should parse and remove currency suffix."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("100 SEK") == 100.0
-        assert matcher._parse_amount("100 kr") == 100.0
+        assert matcher_utils.parse_amount("100 SEK") == 100.0
+        assert matcher_utils.parse_amount("100 kr") == 100.0
     def test_parse_swedish_ore_format(self):
         """Should parse Swedish öre format (kronor space öre)."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("239 00") == 239.00
-        assert matcher._parse_amount("1234 50") == 1234.50
+        assert matcher_utils.parse_amount("239 00") == 239.00
+        assert matcher_utils.parse_amount("1234 50") == 1234.50
     def test_parse_invalid_returns_none(self):
         """Should return None for invalid input."""
-        matcher = FieldMatcher()
-        assert matcher._parse_amount("abc") is None
-        assert matcher._parse_amount("") is None
+        assert matcher_utils.parse_amount("abc") is None
+        assert matcher_utils.parse_amount("") is None
 class TestFieldMatcherTokensOnSameLine:
-    """Tests for _tokens_on_same_line method."""
+    """Tests for tokens_on_same_line function."""
     def test_same_line_tokens(self):
         """Should detect tokens on same line."""
-        matcher = FieldMatcher()
         token1 = MockToken("hello", (0, 10, 50, 30))
         token2 = MockToken("world", (60, 12, 110, 28))  # Slight y variation
-        assert matcher._tokens_on_same_line(token1, token2) is True
+        assert matcher_utils.tokens_on_same_line(token1, token2) is True
     def test_different_line_tokens(self):
         """Should detect tokens on different lines."""
-        matcher = FieldMatcher()
         token1 = MockToken("hello", (0, 10, 50, 30))
         token2 = MockToken("world", (0, 50, 50, 70))  # Different y
-        assert matcher._tokens_on_same_line(token1, token2) is False
+        assert matcher_utils.tokens_on_same_line(token1, token2) is False
 class TestFieldMatcherBboxOverlap:
-    """Tests for _bbox_overlap method."""
+    """Tests for bbox_overlap function."""
     def test_full_overlap(self):
         """Should return 1.0 for identical bboxes."""
-        matcher = FieldMatcher()
         bbox = (0, 0, 100, 50)
-        assert matcher._bbox_overlap(bbox, bbox) == 1.0
+        assert matcher_utils.bbox_overlap(bbox, bbox) == 1.0
    def test_partial_overlap(self):
         """Should calculate partial overlap correctly."""
-        matcher = FieldMatcher()
         bbox1 = (0, 0, 100, 100)
         bbox2 = (50, 50, 150, 150)  # 50% overlap on each axis
-        overlap = matcher._bbox_overlap(bbox1, bbox2)
+        overlap = matcher_utils.bbox_overlap(bbox1, bbox2)
         # Intersection: 50x50=2500, Union: 10000+10000-2500=17500
         # IoU = 2500/17500 ≈ 0.143
         assert 0.1 < overlap < 0.2
     def test_no_overlap(self):
         """Should return 0.0 for non-overlapping bboxes."""
-        matcher = FieldMatcher()
         bbox1 = (0, 0, 50, 50)
         bbox2 = (100, 100, 150, 150)
-        assert matcher._bbox_overlap(bbox1, bbox2) == 0.0
+        assert matcher_utils.bbox_overlap(bbox1, bbox2) == 0.0
 class TestFieldMatcherDeduplication:
@@ -552,21 +543,21 @@ class TestSubstringMatchEdgeCases:
     def test_unsupported_field_returns_empty(self):
         """Should return empty for unsupported field types."""
         # Line 380: field_name not in supported_fields
-        matcher = FieldMatcher()
+        substring_matcher = SubstringMatcher()
         tokens = [MockToken("Faktura: 12345", (0, 0, 100, 20))]
         # Message is not a supported field for substring matching
-        matches = matcher._find_substring_matches(tokens, "12345", "Message")
+        matches = substring_matcher.find_matches(tokens, "12345", "Message")
         assert len(matches) == 0
     def test_case_insensitive_substring_match(self):
         """Should find case-insensitive substring match."""
         # Line 397-398: case-insensitive substring matching
-        matcher = FieldMatcher()
+        substring_matcher = SubstringMatcher()
         # Use token without inline keyword to isolate case-insensitive behavior
         tokens = [MockToken("REF: ABC123", (0, 0, 100, 20))]
-        matches = matcher._find_substring_matches(tokens, "abc123", "InvoiceNumber")
+        matches = substring_matcher.find_matches(tokens, "abc123", "InvoiceNumber")
         assert len(matches) >= 1
         # Case-insensitive base score is 0.70 (vs 0.75 for case-sensitive)
@@ -576,27 +567,27 @@ class TestSubstringMatchEdgeCases:
     def test_substring_with_digit_before(self):
         """Should not match when digit appears before value."""
         # Line 407-408: char_before.isdigit() continue
-        matcher = FieldMatcher()
+        substring_matcher = SubstringMatcher()
         tokens = [MockToken("9912345", (0, 0, 60, 20))]
-        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")
+        matches = substring_matcher.find_matches(tokens, "12345", "InvoiceNumber")
         assert len(matches) == 0
     def test_substring_with_digit_after(self):
         """Should not match when digit appears after value."""
         # Line 413-416: char_after.isdigit() continue
-        matcher = FieldMatcher()
+        substring_matcher = SubstringMatcher()
         tokens = [MockToken("12345678", (0, 0, 70, 20))]
-        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")
+        matches = substring_matcher.find_matches(tokens, "12345", "InvoiceNumber")
         assert len(matches) == 0
     def test_substring_with_inline_keyword(self):
         """Should boost score when keyword is in same token."""
-        matcher = FieldMatcher()
+        substring_matcher = SubstringMatcher()
         tokens = [MockToken("Fakturanr: 12345", (0, 0, 100, 20))]
-        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")
+        matches = substring_matcher.find_matches(tokens, "12345", "InvoiceNumber")
         assert len(matches) >= 1
         # Should have inline keyword boost
@@ -609,36 +600,36 @@ class TestFlexibleDateMatchEdgeCases:
     def test_no_valid_date_in_normalized_values(self):
         """Should return empty when no valid date in normalized values."""
         # Line 520-521, 524: target_date parsing failures
-        matcher = FieldMatcher()
+        date_matcher = FlexibleDateMatcher()
         tokens = [MockToken("2025-01-15", (0, 0, 80, 20))]
-        # Pass non-date values
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["not-a-date", "also-not-date"], "InvoiceDate"
+        # Pass non-date value
+        matches = date_matcher.find_matches(
+            tokens, "not-a-date", "InvoiceDate"
         )
         assert len(matches) == 0
     def test_no_date_tokens_found(self):
         """Should return empty when no date tokens in document."""
         # Line 571-572: no date_candidates
-        matcher = FieldMatcher()
+        date_matcher = FlexibleDateMatcher()
         tokens = [MockToken("Hello World", (0, 0, 80, 20))]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
        )
         assert len(matches) == 0
     def test_flexible_date_within_7_days(self):
         """Should score higher for dates within 7 days."""
         # Line 582-583: days_diff <= 7
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-01-18", (0, 0, 80, 20)),  # 3 days from target
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         assert len(matches) >= 1
@@ -647,13 +638,13 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_within_3_days(self):
         """Should score highest for dates within 3 days."""
         # Line 584-585: days_diff <= 3
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-01-17", (0, 0, 80, 20)),  # 2 days from target
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         assert len(matches) >= 1
@@ -662,13 +653,13 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_within_14_days_different_month(self):
         """Should match dates within 14 days even in different month."""
         # Line 587-588: days_diff <= 14, different year-month
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-02-05", (0, 0, 80, 20)),  # 10 days from Jan 26
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-26"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-26", "InvoiceDate"
         )
         assert len(matches) >= 1
@@ -676,13 +667,13 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_within_30_days(self):
         """Should match dates within 30 days with lower score."""
         # Line 589-590: days_diff <= 30
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-02-10", (0, 0, 80, 20)),  # 25 days from target
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-16"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-16", "InvoiceDate"
         )
         assert len(matches) >= 1
@@ -691,13 +682,13 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_far_apart_without_context(self):
         """Should skip dates too far apart without context keywords."""
         # Line 591-595: > 30 days, no context
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-06-15", (0, 0, 80, 20)),  # Many months from target
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         # Should be empty - too far apart and no context
@@ -706,14 +697,14 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_far_with_context(self):
         """Should match distant dates if context keywords present."""
         # Line 592-595: > 30 days but has context
-        matcher = FieldMatcher(min_score_threshold=0.5, context_radius=200)
+        date_matcher = FlexibleDateMatcher(context_radius=200)
         tokens = [
             MockToken("fakturadatum", (0, 0, 80, 20)),  # Context keyword
             MockToken("2025-06-15", (90, 0, 170, 20)),  # Distant date
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         # May match due to context keyword
@@ -722,14 +713,14 @@ class TestFlexibleDateMatchEdgeCases:
     def test_flexible_date_boost_with_context(self):
         """Should boost flexible date score with context keywords."""
         # Line 598, 602-603: context_boost applied
-        matcher = FieldMatcher(min_score_threshold=0.5, context_radius=200)
+        date_matcher = FlexibleDateMatcher(context_radius=200)
         tokens = [
             MockToken("fakturadatum", (0, 0, 80, 20)),
             MockToken("2025-01-18", (90, 0, 170, 20)),  # 3 days from target
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
        )
         if len(matches) > 0:
@@ -751,7 +742,7 @@ class TestContextKeywordFallback:
         ]
         # _token_index is None, so fallback is used
-        keywords, boost = matcher._find_context_keywords(tokens, tokens[1], "InvoiceNumber")
+        keywords, boost = find_context_keywords(tokens, tokens[1], "InvoiceNumber", 200.0)
         assert "fakturanr" in keywords
         assert boost > 0
@@ -765,7 +756,7 @@ class TestContextKeywordFallback:
         token = MockToken("fakturanr 12345", (0, 0, 150, 20))
         tokens = [token]
-        keywords, boost = matcher._find_context_keywords(tokens, token, "InvoiceNumber")
+        keywords, boost = find_context_keywords(tokens, token, "InvoiceNumber", 200.0)
         # Token contains keyword but is the target - should still find if keyword in token
         # Actually this tests that it doesn't error when target is in list
@@ -783,7 +774,7 @@ class TestFieldWithoutContextKeywords:
         tokens = [MockToken("hello", (0, 0, 50, 20))]
         # customer_number is not in CONTEXT_KEYWORDS
-        keywords, boost = matcher._find_context_keywords(tokens, tokens[0], "UnknownField")
+        keywords, boost = find_context_keywords(tokens, tokens[0], "UnknownField", 200.0)
         assert keywords == []
         assert boost == 0.0
@@ -795,20 +786,20 @@ class TestParseAmountEdgeCases:
     def test_parse_amount_with_parentheses(self):
         """Should remove parenthesized text like (inkl. moms)."""
         matcher = FieldMatcher()
-        result = matcher._parse_amount("100 (inkl. moms)")
+        result = matcher_utils.parse_amount("100 (inkl. moms)")
         assert result == 100.0
     def test_parse_amount_with_kronor_suffix(self):
         """Should handle 'kronor' suffix."""
         matcher = FieldMatcher()
-        result = matcher._parse_amount("100 kronor")
+        result = matcher_utils.parse_amount("100 kronor")
         assert result == 100.0
     def test_parse_amount_numeric_input(self):
         """Should handle numeric input (int/float)."""
         matcher = FieldMatcher()
-        assert matcher._parse_amount(100) == 100.0
-        assert matcher._parse_amount(100.5) == 100.5
+        assert matcher_utils.parse_amount(100) == 100.0
+        assert matcher_utils.parse_amount(100.5) == 100.5
 class TestFuzzyMatchExceptionHandling:
@@ -822,23 +813,20 @@ class TestFuzzyMatchExceptionHandling:
         tokens = [MockToken("abc xyz", (0, 0, 50, 20))]
         # This should not raise, just return empty matches
-        matches = matcher._find_fuzzy_matches(tokens, "100", "Amount")
+        matches = FuzzyMatcher().find_matches(tokens, "100", "Amount")
         assert len(matches) == 0
     def test_fuzzy_match_exception_in_context_lookup(self):
         """Should catch exceptions during fuzzy match processing."""
-        # Line 481-482: general exception handler
-        from unittest.mock import patch, MagicMock
-        matcher = FieldMatcher()
-        tokens = [MockToken("100", (0, 0, 50, 20))]
-        # Mock _find_context_keywords to raise an exception
-        with patch.object(matcher, '_find_context_keywords', side_effect=RuntimeError("Test error")):
-            # Should not raise, exception should be caught
-            matches = matcher._find_fuzzy_matches(tokens, "100", "Amount")
-            # Should return empty due to exception
-            assert len(matches) == 0
+        # After refactoring, context lookup is in separate module
+        # This test is no longer applicable as we use find_context_keywords function
+        # Instead, we test that fuzzy matcher handles unparseable amounts gracefully
+        fuzzy_matcher = FuzzyMatcher()
+        tokens = [MockToken("not-a-number", (0, 0, 50, 20))]
+        # Should not crash on unparseable amount
+        matches = fuzzy_matcher.find_matches(tokens, "100", "Amount")
+        assert len(matches) == 0
 class TestFlexibleDateInvalidDateParsing:
@@ -847,13 +835,13 @@ class TestFlexibleDateInvalidDateParsing:
     def test_invalid_date_in_normalized_values(self):
         """Should handle invalid dates in normalized values gracefully."""
         # Line 520-521: ValueError continue in target date parsing
-        matcher = FieldMatcher()
+        date_matcher = FlexibleDateMatcher()
         tokens = [MockToken("2025-01-15", (0, 0, 80, 20))]
         # Pass an invalid date that matches the pattern but is not a valid date
         # e.g., 2025-13-45 matches pattern but month 13 is invalid
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-13-45"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-13-45", "InvoiceDate"
         )
         # Should return empty as no valid target date could be parsed
         assert len(matches) == 0
@@ -861,14 +849,14 @@ class TestFlexibleDateInvalidDateParsing:
     def test_invalid_date_token_in_document(self):
         """Should skip invalid date-like tokens in document."""
         # Line 568-569: ValueError continue in date token parsing
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("2025-99-99", (0, 0, 80, 20)),  # Invalid date in doc
             MockToken("2025-01-18", (0, 50, 80, 70)),  # Valid date
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         # Should only match the valid date
@@ -878,13 +866,13 @@ class TestFlexibleDateInvalidDateParsing:
     def test_flexible_date_with_inline_keyword(self):
         """Should detect inline keywords in date tokens."""
         # Line 555: inline_keywords append
-        matcher = FieldMatcher(min_score_threshold=0.5)
+        date_matcher = FlexibleDateMatcher()
         tokens = [
             MockToken("Fakturadatum: 2025-01-18", (0, 0, 150, 20)),
         ]
-        matches = matcher._find_flexible_date_matches(
-            tokens, ["2025-01-15"], "InvoiceDate"
+        matches = date_matcher.find_matches(
+            tokens, "2025-01-15", "InvoiceDate"
         )
         # Should find match with inline keyword

Some files were not shown because too many files have changed in this diff.