Compare commits

12 Commits

Author SHA1 Message Date
Yaojia Wang
d8f2acb762 fix: change default OCR language from English to Swedish
Project targets Swedish invoice extraction. PaddleOCR sv model provides
better recognition of Swedish-specific characters (å, ä, ö).
2026-02-12 23:19:51 +01:00
Yaojia Wang
58d36c8927 WIP 2026-02-12 23:06:00 +01:00
Yaojia Wang
ad5ed46b4c WIP 2026-02-11 23:40:38 +01:00
Yaojia Wang
f1a7bfe6b7 WIP 2026-02-07 13:56:00 +01:00
Yaojia Wang
0990239e9c feat: add field-specific bbox expansion strategies for YOLO training
Implement center-point based bbox scaling with directional compensation
to capture field labels that typically appear above or to the left of
field values. This improves YOLO training data quality by including
contextual information around field values.

Key changes:
- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function
- Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.)
- Support manual_mode for minimal padding (no scaling)
- Integrate expand_bbox into AnnotationGenerator
- Add FIELD_TO_CLASS mapping for field_name to class_name lookup
- Comprehensive tests with 100% coverage (45 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 22:56:52 +01:00
Yaojia Wang
8723ef4653 refactor: split line_items_extractor into smaller modules with comprehensive tests
- Extract models.py (LineItem, LineItemsResult dataclasses)
- Extract html_table_parser.py (ColumnMapper, HtmlTableParser)
- Extract merged_cell_handler.py (MergedCellHandler for PP-StructureV3 merged cells)
- Reduce line_items_extractor.py from 971 to 396 lines
- Add constants for magic numbers (MIN_AMOUNT_THRESHOLD, ROW_GROUPING_THRESHOLD, etc.)
- Fix row grouping algorithm in text_line_items_extractor.py
- Demote INFO logs to DEBUG level in structure_detector.py
- Add 209 tests achieving 85%+ coverage on main modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 23:02:00 +01:00
Yaojia Wang
c2c8f2dd04 WIP 2026-02-03 22:29:53 +01:00
Yaojia Wang
4c7fc3015c fix: add PDF magic bytes validation to prevent file type spoofing
Add validation that checks PDF files start with '%PDF' magic bytes
before accepting uploads. This prevents attackers from uploading
malicious files (executables, scripts) by renaming them to .pdf.

- Add validate_pdf_magic_bytes() function with clear error messages
- Integrate validation in upload_document endpoint after file read
- Add comprehensive test coverage (13 test cases)

Addresses medium-risk security issue from code review.
2026-02-03 22:28:24 +01:00
Yaojia Wang
183d3503ef Prepare for opencode 2026-02-03 22:03:44 +01:00
Yaojia Wang
729d96f59e Merge branch 'feature/paddleocr-upgrade' 2026-02-03 21:28:33 +01:00
Yaojia Wang
35988b1ebf Update paddle, and support invoice line item 2026-02-03 21:28:06 +01:00
Yaojia Wang
c4e3773df1 feat: upgrade PaddlePaddle and PaddleOCR to 3.x
- Update paddlepaddle from >=2.5.0 to >=3.0.0,<3.3.0
- Update paddleocr from >=2.7.0 to >=3.0.0
- Update paddlepaddle-gpu from >=2.5.0 to >=3.0.0,<3.3.0

Note: PaddlePaddle 3.3.0 has an OneDNN bug that breaks CPU inference
(ConvertPirAttribute2RuntimeAttribute not implemented). Using <3.3.0
until the bug is fixed upstream.

This upgrade enables PP-StructureV3 for table extraction and uses
PP-OCRv5 for improved text recognition accuracy. The existing codebase
is already compatible with the 3.x API (predict() method and new
response format).

Verified:
- PaddleOCR import works
- PPStructureV3 is available
- OCREngine initializes correctly
- Inference API returns correct field extractions
- 2117 unit tests pass

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 12:15:02 +01:00
186 changed files with 19537 additions and 4096 deletions

View File

@@ -1,85 +1,72 @@
# Invoice Master POC v2
Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
Swedish Invoice Field Extraction System - YOLO + PaddleOCR extracts structured data from Swedish PDF invoices.
## Architecture
```
PDF → PyMuPDF (DPI=150) → YOLO Detection → PaddleOCR → Field Extraction → Normalization → Output
```
### Project Structure
```
packages/
├── backend/ # FastAPI web server + inference pipeline
│ └── pipeline/ # YOLO detector → OCR → field extractor → value selector → normalizers
├── shared/ # Common utilities (bbox, OCR, field mappings)
└── training/ # YOLO training data generation (annotation, dataset)
tests/ # Mirrors packages/ structure
```
### Pipeline Flow (process_pdf)
1. YOLO detects field regions on rendered PDF page
2. PaddleOCR extracts text from detected bboxes
3. Field extractor maps detections to invoice fields via CLASS_TO_FIELD
4. Value selector picks best candidate per field (confidence + validation)
5. Normalizers clean values (dates, amounts, invoice numbers)
6. Fallback regex extraction if key fields missing
## Tech Stack
| Component | Technology |
|-----------|------------|
| Object Detection | YOLOv11 (Ultralytics) |
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
| PDF Processing | PyMuPDF (fitz) |
| Object Detection | YOLO (Ultralytics >= 8.4.0) |
| OCR | PaddleOCR v5 (PP-OCRv5) |
| PDF | PyMuPDF (fitz), DPI=150 |
| Database | PostgreSQL + psycopg2 |
| Web Framework | FastAPI + Uvicorn |
| Deep Learning | PyTorch + CUDA 12.x |
| Web | FastAPI + Uvicorn |
| ML | PyTorch + CUDA 12.x |
## WSL Environment (REQUIRED)
**Prefix ALL commands with:**
ALL Python commands MUST use this prefix:
```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <command>"
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && <command>"
```
**NEVER run Python commands directly in Windows PowerShell/CMD.**
NEVER run Python directly in Windows PowerShell/CMD.
## Project-Specific Rules
## Project Rules
- Python 3.11+ with type hints
- No print() in production - use logging
- Run tests: `pytest --cov=src`
- Python 3.10, type hints on all function signatures
- No `print()` in production code - use `logging` module
- Validation with `pydantic` or `dataclasses`
- Error handling with `try/except` (not try/catch)
- Run tests: `pytest --cov=packages tests/`
## File Structure
## Key Files
```
src/
├── cli/ # autolabel, train, infer, serve
├── pdf/ # extractor, renderer, detector
├── ocr/ # PaddleOCR wrapper, machine_code_parser
├── inference/ # pipeline, yolo_detector, field_extractor
├── normalize/ # Per-field normalizers
├── matcher/ # Exact, substring, fuzzy strategies
├── processing/ # CPU/GPU pool architecture
├── web/ # FastAPI app, routes, services, schemas
├── utils/ # validators, text_cleaner, fuzzy_matcher
└── data/ # Database operations
tests/ # Mirror of src structure
runs/train/ # Training outputs
```
## Supported Fields
| ID | Field | Description |
|----|-------|-------------|
| 0 | invoice_number | Invoice number |
| 1 | invoice_date | Invoice date |
| 2 | invoice_due_date | Due date |
| 3 | ocr_number | OCR reference (Swedish payment) |
| 4 | bankgiro | Bankgiro account |
| 5 | plusgiro | Plusgiro account |
| 6 | amount | Amount |
| 7 | supplier_organisation_number | Supplier org number |
| 8 | payment_line | Payment line (machine-readable) |
| 9 | customer_number | Customer number |
## Key Patterns
### Inference Result
```python
@dataclass
class InferenceResult:
document_id: str
document_type: str # "invoice" or "letter"
fields: dict[str, str]
confidence: dict[str, float]
cross_validation: CrossValidationResult | None
processing_time_ms: float
```
### API Schemas
See `src/web/schemas.py` for request/response models.
| File | Purpose |
|------|---------|
| `packages/backend/backend/pipeline/pipeline.py` | Main inference pipeline |
| `packages/backend/backend/pipeline/field_extractor.py` | YOLO → field mapping |
| `packages/backend/backend/pipeline/value_selector.py` | Best candidate selection |
| `packages/shared/shared/fields/mappings.py` | CLASS_TO_FIELD mapping |
| `packages/shared/shared/ocr/paddle_ocr.py` | OCRToken definition |
| `packages/shared/shared/bbox/` | Bbox expansion strategies |
## Environment Variables
@@ -98,46 +85,40 @@ SERVER_HOST=0.0.0.0
SERVER_PORT=8000
```
## CLI Commands
## Auto-trigger Rules (ALWAYS FOLLOW - even after context compaction)
```bash
# Auto-labeling
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
These rules MUST be followed regardless of conversation history:
# Training
python -m src.cli.train --model yolo11n.pt --epochs 100 --batch 16 --name invoice_fields
- New feature or bug fix → MUST use **tdd-guide** agent (write tests first)
- When writing code → MUST follow coding standards skill for the target language:
- Python → `python-patterns` (PEP 8, type hints, Pythonic idioms)
- C# → `dotnet-skills:coding-standards` (records, pattern matching, modern C#)
- TS/JS → `coding-standards` (universal best practices)
- After writing/modifying code → MUST use **code-reviewer** agent
- Before git commit → MUST use **security-reviewer** agent
- When build/test fails → MUST use **build-error-resolver** agent
- After context compaction → read MEMORY.md to restore session state
# Inference
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt --input invoice.pdf --gpu
## Plan Completion Protocol
# Web Server
python run_server.py --port 8000
```
After completing any plan or major task:
## API Endpoints
1. **Test** - Run `pytest` to confirm all tests pass
2. **Security review** - Use **security-reviewer** agent on changed files
3. **Fix loop** - If security review reports CRITICAL or HIGH issues:
- Fix the issues
- Re-run tests (back to step 1)
- Re-run security review (back to step 2)
- Repeat until no CRITICAL/HIGH issues remain
4. **Commit** - Auto-commit with conventional commit message (`feat:`, `fix:`, `refactor:`, etc.). Stage only the files changed in this task, not unrelated files
5. **Save** - Write a summary to MEMORY.md including: what was done, files changed, decisions made, remaining work
6. **Suggest clear** - Tell the user: "Plan complete. Recommend `/clear` to free context for the next task."
7. **Do NOT start a new task** in the same context - wait for user to /clear first
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/` | Web UI |
| GET | `/api/v1/health` | Health check |
| POST | `/api/v1/infer` | Process invoice |
| GET | `/api/v1/results/{filename}` | Get visualization |
This keeps each plan in a fresh context window for maximum quality.
## Current Status
## Known Issues
- **Tests**: 688 passing
- **Coverage**: 37%
- **Model**: 93.5% mAP@0.5
- **Documents Labeled**: 9,738
## Quick Start
```bash
# Start server
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py"
# Run tests
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"
# Access UI: http://localhost:8000
```
- Pre-existing test failures: `test_s3.py`, `test_azure.py` (missing boto3/azure) - safe to ignore
- Always re-run dedup/validation after fallback adds new fields
- PDF DPI must be 150 (not 300) for correct bbox alignment

View File

@@ -1,22 +0,0 @@
# Build and Fix
Incrementally fix Python errors and test failures.
## Workflow
1. Run check: `mypy src/ --ignore-missing-imports` or `pytest -x --tb=short`
2. Parse errors, group by file, sort by severity (ImportError > TypeError > other)
3. For each error:
- Show context (5 lines)
- Explain and propose fix
- Apply fix
- Re-run test for that file
- Verify resolved
4. Stop if: fix introduces new errors, same error after 3 attempts, or user pauses
5. Show summary: fixed / remaining / new errors
## Rules
- Fix ONE error at a time
- Re-run tests after each fix
- Never batch multiple unrelated fixes

View File

@@ -1,74 +0,0 @@
# Checkpoint Command
Create or verify a checkpoint in your workflow.
## Usage
`/checkpoint [create|verify|list] [name]`
## Create Checkpoint
When creating a checkpoint:
1. Run `/verify quick` to ensure current state is clean
2. Create a git stash or commit with checkpoint name
3. Log checkpoint to `.claude/checkpoints.log`:
```bash
echo "$(date +%Y-%m-%d-%H:%M) | $CHECKPOINT_NAME | $(git rev-parse --short HEAD)" >> .claude/checkpoints.log
```
4. Report checkpoint created
## Verify Checkpoint
When verifying against a checkpoint:
1. Read checkpoint from log
2. Compare current state to checkpoint:
- Files added since checkpoint
- Files modified since checkpoint
- Test pass rate now vs then
- Coverage now vs then
3. Report:
```
CHECKPOINT COMPARISON: $NAME
============================
Files changed: X
Tests: +Y passed / -Z failed
Coverage: +X% / -Y%
Build: [PASS/FAIL]
```
## List Checkpoints
Show all checkpoints with:
- Name
- Timestamp
- Git SHA
- Status (current, behind, ahead)
## Workflow
Typical checkpoint flow:
```
[Start] --> /checkpoint create "feature-start"
|
[Implement] --> /checkpoint create "core-done"
|
[Test] --> /checkpoint verify "core-done"
|
[Refactor] --> /checkpoint create "refactor-done"
|
[PR] --> /checkpoint verify "feature-start"
```
## Arguments
$ARGUMENTS:
- `create <name>` - Create named checkpoint
- `verify <name>` - Verify against named checkpoint
- `list` - Show all checkpoints
- `clear` - Remove old checkpoints (keeps last 5)

View File

@@ -1,46 +0,0 @@
# Code Review
Security and quality review of uncommitted changes.
## Workflow
1. Get changed files: `git diff --name-only HEAD` and `git diff --staged --name-only`
2. Review each file for issues (see checklist below)
3. Run automated checks: `mypy src/`, `ruff check src/`, `pytest -x`
4. Generate report with severity, location, description, suggested fix
5. Block commit if CRITICAL or HIGH issues found
## Checklist
### CRITICAL (Block)
- Hardcoded credentials, API keys, tokens, passwords
- SQL injection (must use parameterized queries)
- Path traversal risks
- Missing input validation on API endpoints
- Missing authentication/authorization
### HIGH (Block)
- Functions > 50 lines, files > 800 lines
- Nesting depth > 4 levels
- Missing error handling or bare `except:`
- `print()` in production code (use logging)
- Mutable default arguments
### MEDIUM (Warn)
- Missing type hints on public functions
- Missing tests for new code
- Duplicate code, magic numbers
- Unused imports/variables
- TODO/FIXME comments
## Report Format
```
[SEVERITY] file:line - Issue description
Suggested fix: ...
```
## Never Approve Code With Security Vulnerabilities!

View File

@@ -1,40 +0,0 @@
# E2E Testing
End-to-end testing for the Invoice Field Extraction API.
## When to Use
- Testing complete inference pipeline (PDF -> Fields)
- Verifying API endpoints work end-to-end
- Validating YOLO + OCR + field extraction integration
- Pre-deployment verification
## Workflow
1. Ensure server is running: `python run_server.py`
2. Run health check: `curl http://localhost:8000/api/v1/health`
3. Run E2E tests: `pytest tests/e2e/ -v`
4. Verify results and capture any failures
## Critical Scenarios (Must Pass)
1. Health check returns `{"status": "healthy", "model_loaded": true}`
2. PDF upload returns valid response with fields
3. Fields extracted with confidence scores
4. Visualization image generated
5. Cross-validation included for invoices with payment_line
## Checklist
- [ ] Server running on http://localhost:8000
- [ ] Health check passes
- [ ] PDF inference returns valid JSON
- [ ] At least one field extracted
- [ ] Visualization URL returns image
- [ ] Response time < 10 seconds
- [ ] No server errors in logs
## Test Location
E2E tests: `tests/e2e/`
Sample fixtures: `tests/fixtures/`

View File

@@ -1,174 +0,0 @@
# Eval Command
Evaluate model performance and field extraction accuracy.
## Usage
`/eval [model|accuracy|compare|report]`
## Model Evaluation
`/eval model`
Evaluate YOLO model performance on test dataset:
```bash
# Run model evaluation
python -m src.cli.train --model runs/train/invoice_fields/weights/best.pt --eval-only
# Or use ultralytics directly
yolo val model=runs/train/invoice_fields/weights/best.pt data=data.yaml
```
Output:
```
Model Evaluation: invoice_fields/best.pt
========================================
mAP@0.5: 93.5%
mAP@0.5-0.95: 83.0%
Per-class AP:
- invoice_number: 95.2%
- invoice_date: 94.8%
- invoice_due_date: 93.1%
- ocr_number: 91.5%
- bankgiro: 92.3%
- plusgiro: 90.8%
- amount: 88.7%
- supplier_org_num: 85.2%
- payment_line: 82.4%
- customer_number: 81.1%
```
## Accuracy Evaluation
`/eval accuracy`
Evaluate field extraction accuracy against ground truth:
```bash
# Run accuracy evaluation on labeled data
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt \
--input ~/invoice-data/test/*.pdf \
--ground-truth ~/invoice-data/test/labels.csv \
--output eval_results.json
```
Output:
```
Field Extraction Accuracy
=========================
Documents tested: 500
Per-field accuracy:
- InvoiceNumber: 98.9% (494/500)
- InvoiceDate: 95.5% (478/500)
- InvoiceDueDate: 95.9% (480/500)
- OCR: 99.1% (496/500)
- Bankgiro: 99.0% (495/500)
- Plusgiro: 99.4% (497/500)
- Amount: 91.3% (457/500)
- supplier_org: 78.2% (391/500)
Overall: 94.8%
```
## Compare Models
`/eval compare`
Compare two model versions:
```bash
# Compare old vs new model
python -m src.cli.eval compare \
--model-a runs/train/invoice_v1/weights/best.pt \
--model-b runs/train/invoice_v2/weights/best.pt \
--test-data ~/invoice-data/test/
```
Output:
```
Model Comparison
================
Model A Model B Delta
mAP@0.5: 91.2% 93.5% +2.3%
Accuracy: 92.1% 94.8% +2.7%
Speed (ms): 1850 1520 -330
Per-field improvements:
- amount: +4.2%
- payment_line: +3.8%
- customer_num: +2.1%
Recommendation: Deploy Model B
```
## Generate Report
`/eval report`
Generate comprehensive evaluation report:
```bash
python -m src.cli.eval report --output eval_report.md
```
Output:
```markdown
# Evaluation Report
Generated: 2026-01-25
## Model Performance
- Model: runs/train/invoice_fields/weights/best.pt
- mAP@0.5: 93.5%
- Training samples: 9,738
## Field Extraction Accuracy
| Field | Accuracy | Errors |
|-------|----------|--------|
| InvoiceNumber | 98.9% | 6 |
| Amount | 91.3% | 43 |
...
## Error Analysis
### Common Errors
1. Amount: OCR misreads comma as period
2. supplier_org: Missing from some invoices
3. payment_line: Partially obscured by stamps
## Recommendations
1. Add more training data for low-accuracy fields
2. Implement OCR error correction for amounts
3. Consider confidence threshold tuning
```
## Quick Commands
```bash
# Evaluate model metrics
yolo val model=runs/train/invoice_fields/weights/best.pt
# Test inference on sample
python -m src.cli.infer --input sample.pdf --output result.json --gpu
# Check test coverage
pytest --cov=src --cov-report=html
```
## Evaluation Metrics
| Metric | Target | Current |
|--------|--------|---------|
| mAP@0.5 | >90% | 93.5% |
| Overall Accuracy | >90% | 94.8% |
| Test Coverage | >60% | 37% |
| Tests Passing | 100% | 100% |
## When to Evaluate
- After training a new model
- Before deploying to production
- After adding new training data
- When accuracy complaints arise
- Weekly performance monitoring

View File

@@ -1,70 +0,0 @@
# /learn - Extract Reusable Patterns
Analyze the current session and extract any patterns worth saving as skills.
## Trigger
Run `/learn` at any point during a session when you've solved a non-trivial problem.
## What to Extract
Look for:
1. **Error Resolution Patterns**
- What error occurred?
- What was the root cause?
- What fixed it?
- Is this reusable for similar errors?
2. **Debugging Techniques**
- Non-obvious debugging steps
- Tool combinations that worked
- Diagnostic patterns
3. **Workarounds**
- Library quirks
- API limitations
- Version-specific fixes
4. **Project-Specific Patterns**
- Codebase conventions discovered
- Architecture decisions made
- Integration patterns
## Output Format
Create a skill file at `~/.claude/skills/learned/[pattern-name].md`:
```markdown
# [Descriptive Pattern Name]
**Extracted:** [Date]
**Context:** [Brief description of when this applies]
## Problem
[What problem this solves - be specific]
## Solution
[The pattern/technique/workaround]
## Example
[Code example if applicable]
## When to Use
[Trigger conditions - what should activate this skill]
```
## Process
1. Review the session for extractable patterns
2. Identify the most valuable/reusable insight
3. Draft the skill file
4. Ask user to confirm before saving
5. Save to `~/.claude/skills/learned/`
## Notes
- Don't extract trivial fixes (typos, simple syntax errors)
- Don't extract one-time issues (specific API outages, etc.)
- Focus on patterns that will save time in future sessions
- Keep skills focused - one pattern per skill

View File

@@ -1,172 +0,0 @@
# Orchestrate Command
Sequential agent workflow for complex tasks.
## Usage
`/orchestrate [workflow-type] [task-description]`
## Workflow Types
### feature
Full feature implementation workflow:
```
planner -> tdd-guide -> code-reviewer -> security-reviewer
```
### bugfix
Bug investigation and fix workflow:
```
explorer -> tdd-guide -> code-reviewer
```
### refactor
Safe refactoring workflow:
```
architect -> code-reviewer -> tdd-guide
```
### security
Security-focused review:
```
security-reviewer -> code-reviewer -> architect
```
## Execution Pattern
For each agent in the workflow:
1. **Invoke agent** with context from previous agent
2. **Collect output** as structured handoff document
3. **Pass to next agent** in chain
4. **Aggregate results** into final report
## Handoff Document Format
Between agents, create handoff document:
```markdown
## HANDOFF: [previous-agent] -> [next-agent]
### Context
[Summary of what was done]
### Findings
[Key discoveries or decisions]
### Files Modified
[List of files touched]
### Open Questions
[Unresolved items for next agent]
### Recommendations
[Suggested next steps]
```
## Example: Feature Workflow
```
/orchestrate feature "Add user authentication"
```
Executes:
1. **Planner Agent**
- Analyzes requirements
- Creates implementation plan
- Identifies dependencies
- Output: `HANDOFF: planner -> tdd-guide`
2. **TDD Guide Agent**
- Reads planner handoff
- Writes tests first
- Implements to pass tests
- Output: `HANDOFF: tdd-guide -> code-reviewer`
3. **Code Reviewer Agent**
- Reviews implementation
- Checks for issues
- Suggests improvements
- Output: `HANDOFF: code-reviewer -> security-reviewer`
4. **Security Reviewer Agent**
- Security audit
- Vulnerability check
- Final approval
- Output: Final Report
## Final Report Format
```
ORCHESTRATION REPORT
====================
Workflow: feature
Task: Add user authentication
Agents: planner -> tdd-guide -> code-reviewer -> security-reviewer
SUMMARY
-------
[One paragraph summary]
AGENT OUTPUTS
-------------
Planner: [summary]
TDD Guide: [summary]
Code Reviewer: [summary]
Security Reviewer: [summary]
FILES CHANGED
-------------
[List all files modified]
TEST RESULTS
------------
[Test pass/fail summary]
SECURITY STATUS
---------------
[Security findings]
RECOMMENDATION
--------------
[SHIP / NEEDS WORK / BLOCKED]
```
## Parallel Execution
For independent checks, run agents in parallel:
```markdown
### Parallel Phase
Run simultaneously:
- code-reviewer (quality)
- security-reviewer (security)
- architect (design)
### Merge Results
Combine outputs into single report
```
## Arguments
$ARGUMENTS:
- `feature <description>` - Full feature workflow
- `bugfix <description>` - Bug fix workflow
- `refactor <description>` - Refactoring workflow
- `security <description>` - Security review workflow
- `custom <agents> <description>` - Custom agent sequence
## Custom Workflow Example
```
/orchestrate custom "architect,tdd-guide,code-reviewer" "Redesign caching layer"
```
## Tips
1. **Start with planner** for complex features
2. **Always include code-reviewer** before merge
3. **Use security-reviewer** for auth/payment/PII
4. **Keep handoffs concise** - focus on what next agent needs
5. **Run verification** between agents if needed

View File

@@ -1,113 +0,0 @@
---
description: Restate requirements, assess risks, and create step-by-step implementation plan. WAIT for user CONFIRM before touching any code.
---
# Plan Command
This command invokes the **planner** agent to create a comprehensive implementation plan before writing any code.
## What This Command Does
1. **Restate Requirements** - Clarify what needs to be built
2. **Identify Risks** - Surface potential issues and blockers
3. **Create Step Plan** - Break down implementation into phases
4. **Wait for Confirmation** - MUST receive user approval before proceeding
## When to Use
Use `/plan` when:
- Starting a new feature
- Making significant architectural changes
- Working on complex refactoring
- Multiple files/components will be affected
- Requirements are unclear or ambiguous
## How It Works
The planner agent will:
1. **Analyze the request** and restate requirements in clear terms
2. **Break down into phases** with specific, actionable steps
3. **Identify dependencies** between components
4. **Assess risks** and potential blockers
5. **Estimate complexity** (High/Medium/Low)
6. **Present the plan** and WAIT for your explicit confirmation
## Example Usage
```
User: /plan I need to add real-time notifications when markets resolve
Agent (planner):
# Implementation Plan: Real-Time Market Resolution Notifications
## Requirements Restatement
- Send notifications to users when markets they're watching resolve
- Support multiple notification channels (in-app, email, webhook)
- Ensure notifications are delivered reliably
- Include market outcome and user's position result
## Implementation Phases
### Phase 1: Database Schema
- Add notifications table with columns: id, user_id, market_id, type, status, created_at
- Add user_notification_preferences table for channel preferences
- Create indexes on user_id and market_id for performance
### Phase 2: Notification Service
- Create notification service in lib/notifications.ts
- Implement notification queue using BullMQ/Redis
- Add retry logic for failed deliveries
- Create notification templates
### Phase 3: Integration Points
- Hook into market resolution logic (when status changes to "resolved")
- Query all users with positions in market
- Enqueue notifications for each user
### Phase 4: Frontend Components
- Create NotificationBell component in header
- Add NotificationList modal
- Implement real-time updates via Supabase subscriptions
- Add notification preferences page
## Dependencies
- Redis (for queue)
- Email service (SendGrid/Resend)
- Supabase real-time subscriptions
## Risks
- HIGH: Email deliverability (SPF/DKIM required)
- MEDIUM: Performance with 1000+ users per market
- MEDIUM: Notification spam if markets resolve frequently
- LOW: Real-time subscription overhead
## Estimated Complexity: MEDIUM
- Backend: 4-6 hours
- Frontend: 3-4 hours
- Testing: 2-3 hours
- Total: 9-13 hours
**WAITING FOR CONFIRMATION**: Proceed with this plan? (yes/no/modify)
```
## Important Notes
**CRITICAL**: The planner agent will **NOT** write any code until you explicitly confirm the plan with "yes" or "proceed" or similar affirmative response.
If you want changes, respond with:
- "modify: [your changes]"
- "different approach: [alternative]"
- "skip phase 2 and do phase 3 first"
## Integration with Other Commands
After planning:
- Use `/tdd` to implement with test-driven development
- Use `/build-and-fix` if build errors occur
- Use `/code-review` to review completed implementation
## Related Agents
This command invokes the `planner` agent located at:
`~/.claude/agents/planner.md`

View File

@@ -1,28 +0,0 @@
# Refactor Clean
Safely identify and remove dead code with test verification:
1. Run dead code analysis tools:
- knip: Find unused exports and files
- depcheck: Find unused dependencies
- ts-prune: Find unused TypeScript exports
2. Generate comprehensive report in .reports/dead-code-analysis.md
3. Categorize findings by severity:
- SAFE: Test files, unused utilities
- CAUTION: API routes, components
- DANGER: Config files, main entry points
4. Propose safe deletions only
5. Before each deletion:
- Run full test suite
- Verify tests pass
- Apply change
- Re-run tests
- Rollback if tests fail
6. Show summary of cleaned items
Never delete code without running tests first!

View File

@@ -1,80 +0,0 @@
---
description: Configure your preferred package manager (npm/pnpm/yarn/bun)
disable-model-invocation: true
---
# Package Manager Setup
Configure your preferred package manager for this project or globally.
## Usage
```bash
# Detect current package manager
node scripts/setup-package-manager.js --detect
# Set global preference
node scripts/setup-package-manager.js --global pnpm
# Set project preference
node scripts/setup-package-manager.js --project bun
# List available package managers
node scripts/setup-package-manager.js --list
```
## Detection Priority
When determining which package manager to use, the following order is checked:
1. **Environment variable**: `CLAUDE_PACKAGE_MANAGER`
2. **Project config**: `.claude/package-manager.json`
3. **package.json**: `packageManager` field
4. **Lock file**: Presence of package-lock.json, yarn.lock, pnpm-lock.yaml, or bun.lockb
5. **Global config**: `~/.claude/package-manager.json`
6. **Fallback**: First available package manager (pnpm > bun > yarn > npm)
## Configuration Files
### Global Configuration
```json
// ~/.claude/package-manager.json
{
"packageManager": "pnpm"
}
```
### Project Configuration
```json
// .claude/package-manager.json
{
"packageManager": "bun"
}
```
### package.json
```json
{
"packageManager": "pnpm@8.6.0"
}
```
## Environment Variable
Set `CLAUDE_PACKAGE_MANAGER` to override all other detection methods:
```bash
# Windows (PowerShell)
$env:CLAUDE_PACKAGE_MANAGER = "pnpm"
# macOS/Linux
export CLAUDE_PACKAGE_MANAGER=pnpm
```
## Run the Detection
To see current package manager detection results, run:
```bash
node scripts/setup-package-manager.js --detect
```

View File

@@ -1,326 +0,0 @@
---
description: Enforce test-driven development workflow. Scaffold interfaces, generate tests FIRST, then implement minimal code to pass. Ensure 80%+ coverage.
---
# TDD Command
This command invokes the **tdd-guide** agent to enforce test-driven development methodology.
## What This Command Does
1. **Scaffold Interfaces** - Define types/interfaces first
2. **Generate Tests First** - Write failing tests (RED)
3. **Implement Minimal Code** - Write just enough to pass (GREEN)
4. **Refactor** - Improve code while keeping tests green (REFACTOR)
5. **Verify Coverage** - Ensure 80%+ test coverage
## When to Use
Use `/tdd` when:
- Implementing new features
- Adding new functions/components
- Fixing bugs (write test that reproduces bug first)
- Refactoring existing code
- Building critical business logic
## How It Works
The tdd-guide agent will:
1. **Define interfaces** for inputs/outputs
2. **Write tests that will FAIL** (because code doesn't exist yet)
3. **Run tests** and verify they fail for the right reason
4. **Write minimal implementation** to make tests pass
5. **Run tests** and verify they pass
6. **Refactor** code while keeping tests green
7. **Check coverage** and add more tests if below 80%
## TDD Cycle
```
RED → GREEN → REFACTOR → REPEAT
RED: Write a failing test
GREEN: Write minimal code to pass
REFACTOR: Improve code, keep tests passing
REPEAT: Next feature/scenario
```
## Example Usage
```
User: /tdd I need a function to calculate market liquidity score
Agent (tdd-guide):
# TDD Session: Market Liquidity Score Calculator
## Step 1: Define Interface (SCAFFOLD)
```typescript
// lib/liquidity.ts
export interface MarketData {
totalVolume: number
bidAskSpread: number
activeTraders: number
lastTradeTime: Date
}
export function calculateLiquidityScore(market: MarketData): number {
// TODO: Implementation
throw new Error('Not implemented')
}
```
## Step 2: Write Failing Test (RED)
```typescript
// lib/liquidity.test.ts
import { calculateLiquidityScore } from './liquidity'
describe('calculateLiquidityScore', () => {
it('should return high score for liquid market', () => {
const market = {
totalVolume: 100000,
bidAskSpread: 0.01,
activeTraders: 500,
lastTradeTime: new Date()
}
const score = calculateLiquidityScore(market)
expect(score).toBeGreaterThan(80)
expect(score).toBeLessThanOrEqual(100)
})
it('should return low score for illiquid market', () => {
const market = {
totalVolume: 100,
bidAskSpread: 0.5,
activeTraders: 2,
lastTradeTime: new Date(Date.now() - 86400000) // 1 day ago
}
const score = calculateLiquidityScore(market)
expect(score).toBeLessThan(30)
expect(score).toBeGreaterThanOrEqual(0)
})
it('should handle edge case: zero volume', () => {
const market = {
totalVolume: 0,
bidAskSpread: 0,
activeTraders: 0,
lastTradeTime: new Date()
}
const score = calculateLiquidityScore(market)
expect(score).toBe(0)
})
})
```
## Step 3: Run Tests - Verify FAIL
```bash
npm test lib/liquidity.test.ts
FAIL lib/liquidity.test.ts
✕ should return high score for liquid market (2 ms)
Error: Not implemented
1 test failed, 0 passed
```
✅ Tests fail as expected. Ready to implement.
## Step 4: Implement Minimal Code (GREEN)
```typescript
// lib/liquidity.ts
export function calculateLiquidityScore(market: MarketData): number {
// Handle zero volume edge case
if (market.totalVolume === 0) {
return 0
}
// Calculate component scores (0-100 scale)
const volumeScore = Math.min(market.totalVolume / 1000, 100)
const spreadScore = Math.max(100 - (market.bidAskSpread * 1000), 0)
const traderScore = Math.min(market.activeTraders / 10, 100)
// Recent activity bonus
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
const recencyScore = Math.max(100 - (hoursSinceLastTrade * 10), 0)
// Weighted average
const score = (
volumeScore * 0.4 +
spreadScore * 0.3 +
traderScore * 0.2 +
recencyScore * 0.1
)
return Math.min(Math.max(score, 0), 100) // Clamp to 0-100
}
```
## Step 5: Run Tests - Verify PASS
```bash
npm test lib/liquidity.test.ts
PASS lib/liquidity.test.ts
✓ should return high score for liquid market (3 ms)
✓ should return low score for illiquid market (2 ms)
✓ should handle edge case: zero volume (1 ms)
3 tests passed
```
✅ All tests passing!
## Step 6: Refactor (IMPROVE)
```typescript
// lib/liquidity.ts - Refactored with constants and better readability
const WEIGHTS = {
VOLUME: 0.4,
SPREAD: 0.3,
TRADERS: 0.2,
RECENCY: 0.1,
} as const
const SCALE_FACTORS = {
VOLUME: 1000,
SPREAD: 1000,
TRADERS: 10,
RECENCY_PENALTY: 10,
} as const
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max)
}
export function calculateLiquidityScore(market: MarketData): number {
if (market.totalVolume === 0) return 0
const volumeScore = Math.min(market.totalVolume / SCALE_FACTORS.VOLUME, 100)
const spreadScore = clamp(100 - (market.bidAskSpread * SCALE_FACTORS.SPREAD), 0, 100)
const traderScore = Math.min(market.activeTraders / SCALE_FACTORS.TRADERS, 100)
const hoursSinceLastTrade = (Date.now() - market.lastTradeTime.getTime()) / (1000 * 60 * 60)
const recencyScore = clamp(100 - (hoursSinceLastTrade * SCALE_FACTORS.RECENCY_PENALTY), 0, 100)
const weightedScore =
volumeScore * WEIGHTS.VOLUME +
spreadScore * WEIGHTS.SPREAD +
traderScore * WEIGHTS.TRADERS +
recencyScore * WEIGHTS.RECENCY
return clamp(weightedScore, 0, 100)
}
```
## Step 7: Verify Tests Still Pass
```bash
npm test lib/liquidity.test.ts
PASS lib/liquidity.test.ts
✓ should return high score for liquid market (3 ms)
✓ should return low score for illiquid market (2 ms)
✓ should handle edge case: zero volume (1 ms)
3 tests passed
```
✅ Refactoring complete, tests still passing!
## Step 8: Check Coverage
```bash
npm test -- --coverage lib/liquidity.test.ts
File | % Stmts | % Branch | % Funcs | % Lines
---------------|---------|----------|---------|--------
liquidity.ts | 100 | 100 | 100 | 100
Coverage: 100% ✅ (Target: 80%)
```
✅ TDD session complete!
```
## TDD Best Practices
**DO:**
- ✅ Write the test FIRST, before any implementation
- ✅ Run tests and verify they FAIL before implementing
- ✅ Write minimal code to make tests pass
- ✅ Refactor only after tests are green
- ✅ Add edge cases and error scenarios
- ✅ Aim for 80%+ coverage (100% for critical code)
**DON'T:**
- ❌ Write implementation before tests
- ❌ Skip running tests after each change
- ❌ Write too much code at once
- ❌ Ignore failing tests
- ❌ Test implementation details (test behavior)
- ❌ Mock everything (prefer integration tests)
## Test Types to Include
**Unit Tests** (Function-level):
- Happy path scenarios
- Edge cases (empty, null, max values)
- Error conditions
- Boundary values
**Integration Tests** (Component-level):
- API endpoints
- Database operations
- External service calls
- React components with hooks
**E2E Tests** (use `/e2e` command):
- Critical user flows
- Multi-step processes
- Full stack integration
## Coverage Requirements
- **80% minimum** for all code
- **100% required** for:
- Financial calculations
- Authentication logic
- Security-critical code
- Core business logic
## Important Notes
**MANDATORY**: Tests must be written BEFORE implementation. The TDD cycle is:
1. **RED** - Write failing test
2. **GREEN** - Implement to pass
3. **REFACTOR** - Improve code
Never skip the RED phase. Never write code before tests.
## Integration with Other Commands
- Use `/plan` first to understand what to build
- Use `/tdd` to implement with tests
- Use `/build-and-fix` if build errors occur
- Use `/code-review` to review implementation
- Use `/test-coverage` to verify coverage
## Related Agents
This command invokes the `tdd-guide` agent located at:
`~/.claude/agents/tdd-guide.md`
And can reference the `tdd-workflow` skill at:
`~/.claude/skills/tdd-workflow/`

View File

@@ -1,27 +0,0 @@
# Test Coverage
Analyze test coverage and generate missing tests:
1. Run tests with coverage: npm test --coverage or pnpm test --coverage
2. Analyze coverage report (coverage/coverage-summary.json)
3. Identify files below 80% coverage threshold
4. For each under-covered file:
- Analyze untested code paths
- Generate unit tests for functions
- Generate integration tests for APIs
- Generate E2E tests for critical flows
5. Verify new tests pass
6. Show before/after coverage metrics
7. Ensure project reaches 80%+ overall coverage
Focus on:
- Happy path scenarios
- Error handling
- Edge cases (null, undefined, empty)
- Boundary conditions

View File

@@ -1,17 +0,0 @@
# Update Codemaps
Analyze the codebase structure and update architecture documentation:
1. Scan all source files for imports, exports, and dependencies
2. Generate token-lean codemaps in the following format:
- codemaps/architecture.md - Overall architecture
- codemaps/backend.md - Backend structure
- codemaps/frontend.md - Frontend structure
- codemaps/data.md - Data models and schemas
3. Calculate diff percentage from previous version
4. If changes > 30%, request user approval before updating
5. Add freshness timestamp to each codemap
6. Save reports to .reports/codemap-diff.txt
Use TypeScript/Node.js for analysis. Focus on high-level structure, not implementation details.

View File

@@ -1,31 +0,0 @@
# Update Documentation
Sync documentation from source-of-truth:
1. Read package.json scripts section
- Generate scripts reference table
- Include descriptions from comments
2. Read .env.example
- Extract all environment variables
- Document purpose and format
3. Generate docs/CONTRIB.md with:
- Development workflow
- Available scripts
- Environment setup
- Testing procedures
4. Generate docs/RUNBOOK.md with:
- Deployment procedures
- Monitoring and alerts
- Common issues and fixes
- Rollback procedures
5. Identify obsolete documentation:
- Find docs not modified in 90+ days
- List for manual review
6. Show diff summary
Single source of truth: package.json and .env.example

View File

@@ -1,59 +0,0 @@
# Verification Command
Run comprehensive verification on current codebase state.
## Instructions
Execute verification in this exact order:
1. **Build Check**
- Run the build command for this project
- If it fails, report errors and STOP
2. **Type Check**
- Run TypeScript/type checker
- Report all errors with file:line
3. **Lint Check**
- Run linter
- Report warnings and errors
4. **Test Suite**
- Run all tests
- Report pass/fail count
- Report coverage percentage
5. **Console.log Audit**
- Search for console.log in source files
- Report locations
6. **Git Status**
- Show uncommitted changes
- Show files modified since last commit
## Output
Produce a concise verification report:
```
VERIFICATION: [PASS/FAIL]
Build: [OK/FAIL]
Types: [OK/X errors]
Lint: [OK/X issues]
Tests: [X/Y passed, Z% coverage]
Secrets: [OK/X found]
Logs: [OK/X console.logs]
Ready for PR: [YES/NO]
```
If any critical issues, list them with fix suggestions.
## Arguments
$ARGUMENTS can be:
- `quick` - Only build + types
- `full` - All checks (default)
- `pre-commit` - Checks relevant for commits
- `pre-pr` - Full checks plus security scan

View File

@@ -1,157 +0,0 @@
{
"$schema": "https://json.schemastore.org/claude-code-settings.json",
"hooks": {
"PreToolUse": [
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm run dev|pnpm( run)? dev|yarn dev|bun run dev)\"",
"hooks": [
{
"type": "command",
"command": "node -e \"console.error('[Hook] BLOCKED: Dev server must run in tmux for log access');console.error('[Hook] Use: tmux new-session -d -s dev \\\"npm run dev\\\"');console.error('[Hook] Then: tmux attach -t dev');process.exit(1)\""
}
],
"description": "Block dev servers outside tmux - ensures you can access logs"
},
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"(npm (install|test)|pnpm (install|test)|yarn (install|test)?|bun (install|test)|cargo build|make|docker|pytest|vitest|playwright)\"",
"hooks": [
{
"type": "command",
"command": "node -e \"if(!process.env.TMUX){console.error('[Hook] Consider running in tmux for session persistence');console.error('[Hook] tmux new -s dev | tmux attach -t dev')}\""
}
],
"description": "Reminder to use tmux for long-running commands"
},
{
"matcher": "tool == \"Bash\" && tool_input.command matches \"git push\"",
"hooks": [
{
"type": "command",
"command": "node -e \"console.error('[Hook] Review changes before push...');console.error('[Hook] Continuing with push (remove this hook to add interactive review)')\""
}
],
"description": "Reminder before git push to review changes"
},
{
"matcher": "tool == \"Write\" && tool_input.file_path matches \"\\\\.(md|txt)$\" && !(tool_input.file_path matches \"README\\\\.md|CLAUDE\\\\.md|AGENTS\\\\.md|CONTRIBUTING\\\\.md\")",
"hooks": [
{
"type": "command",
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path||'';if(/\\.(md|txt)$/.test(p)&&!/(README|CLAUDE|AGENTS|CONTRIBUTING)\\.md$/.test(p)){console.error('[Hook] BLOCKED: Unnecessary documentation file creation');console.error('[Hook] File: '+p);console.error('[Hook] Use README.md for documentation instead');process.exit(1)}console.log(d)})\""
}
],
"description": "Block creation of random .md files - keeps docs consolidated"
},
{
"matcher": "tool == \"Edit\" || tool == \"Write\"",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/suggest-compact.js\""
}
],
"description": "Suggest manual compaction at logical intervals"
}
],
"PreCompact": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/pre-compact.js\""
}
],
"description": "Save state before context compaction"
}
],
"SessionStart": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-start.js\""
}
],
"description": "Load previous context and detect package manager on new session"
}
],
"PostToolUse": [
{
"matcher": "tool == \"Bash\"",
"hooks": [
{
"type": "command",
"command": "node -e \"let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const cmd=i.tool_input?.command||'';if(/gh pr create/.test(cmd)){const out=i.tool_output?.output||'';const m=out.match(/https:\\/\\/github.com\\/[^/]+\\/[^/]+\\/pull\\/\\d+/);if(m){console.error('[Hook] PR created: '+m[0]);const repo=m[0].replace(/https:\\/\\/github.com\\/([^/]+\\/[^/]+)\\/pull\\/\\d+/,'$1');const pr=m[0].replace(/.*\\/pull\\/(\\d+)/,'$1');console.error('[Hook] To review: gh pr review '+pr+' --repo '+repo)}}console.log(d)})\""
}
],
"description": "Log PR URL and provide review command after PR creation"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){try{execSync('npx prettier --write \"'+p+'\"',{stdio:['pipe','pipe','pipe']})}catch(e){}}console.log(d)})\""
}
],
"description": "Auto-format JS/TS files with Prettier after edits"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');const path=require('path');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){let dir=path.dirname(p);while(dir!==path.dirname(dir)&&!fs.existsSync(path.join(dir,'tsconfig.json'))){dir=path.dirname(dir)}if(fs.existsSync(path.join(dir,'tsconfig.json'))){try{const r=execSync('npx tsc --noEmit --pretty false 2>&1',{cwd:dir,encoding:'utf8',stdio:['pipe','pipe','pipe']});const lines=r.split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}catch(e){const lines=(e.stdout||'').split('\\n').filter(l=>l.includes(p)).slice(0,10);if(lines.length)console.error(lines.join('\\n'))}}}console.log(d)})\""
}
],
"description": "TypeScript check after editing .ts/.tsx files"
},
{
"matcher": "tool == \"Edit\" && tool_input.file_path matches \"\\\\.(ts|tsx|js|jsx)$\"",
"hooks": [
{
"type": "command",
"command": "node -e \"const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{const i=JSON.parse(d);const p=i.tool_input?.file_path;if(p&&fs.existsSync(p)){const c=fs.readFileSync(p,'utf8');const lines=c.split('\\n');const matches=[];lines.forEach((l,idx)=>{if(/console\\.log/.test(l))matches.push((idx+1)+': '+l.trim())});if(matches.length){console.error('[Hook] WARNING: console.log found in '+p);matches.slice(0,5).forEach(m=>console.error(m));console.error('[Hook] Remove console.log before committing')}}console.log(d)})\""
}
],
"description": "Warn about console.log statements after edits"
}
],
"Stop": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node -e \"const{execSync}=require('child_process');const fs=require('fs');let d='';process.stdin.on('data',c=>d+=c);process.stdin.on('end',()=>{try{execSync('git rev-parse --git-dir',{stdio:'pipe'})}catch{console.log(d);process.exit(0)}try{const files=execSync('git diff --name-only HEAD',{encoding:'utf8',stdio:['pipe','pipe','pipe']}).split('\\n').filter(f=>/\\.(ts|tsx|js|jsx)$/.test(f)&&fs.existsSync(f));let hasConsole=false;for(const f of files){if(fs.readFileSync(f,'utf8').includes('console.log')){console.error('[Hook] WARNING: console.log found in '+f);hasConsole=true}}if(hasConsole)console.error('[Hook] Remove console.log statements before committing')}catch(e){}console.log(d)})\""
}
],
"description": "Check for console.log in modified files after each response"
}
],
"SessionEnd": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/session-end.js\""
}
],
"description": "Persist session state on end"
},
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "node \"${CLAUDE_PLUGIN_ROOT}/scripts/hooks/evaluate-session.js\""
}
],
"description": "Evaluate session for extractable patterns"
}
]
}
}

View File

@@ -1,36 +0,0 @@
#!/bin/bash
# PreCompact Hook - Save state before context compaction
#
# Runs before Claude compacts context, giving you a chance to
# preserve important state that might get lost in summarization.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "PreCompact": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/pre-compact.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
COMPACTION_LOG="${SESSIONS_DIR}/compaction-log.txt"
mkdir -p "$SESSIONS_DIR"
# Log compaction event with timestamp
echo "[$(date '+%Y-%m-%d %H:%M:%S')] Context compaction triggered" >> "$COMPACTION_LOG"
# If there's an active session file, note the compaction
ACTIVE_SESSION=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
if [ -n "$ACTIVE_SESSION" ]; then
echo "" >> "$ACTIVE_SESSION"
echo "---" >> "$ACTIVE_SESSION"
echo "**[Compaction occurred at $(date '+%H:%M')]** - Context was summarized" >> "$ACTIVE_SESSION"
fi
echo "[PreCompact] State saved before compaction" >&2

View File

@@ -1,61 +0,0 @@
#!/bin/bash
# Stop Hook (Session End) - Persist learnings when session ends
#
# Runs when Claude session ends. Creates/updates session log file
# with timestamp for continuity tracking.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "Stop": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/session-end.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
TODAY=$(date '+%Y-%m-%d')
SESSION_FILE="${SESSIONS_DIR}/${TODAY}-session.tmp"
mkdir -p "$SESSIONS_DIR"
# If session file exists for today, update the end time
if [ -f "$SESSION_FILE" ]; then
# Update Last Updated timestamp
sed -i '' "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null || \
sed -i "s/\*\*Last Updated:\*\*.*/\*\*Last Updated:\*\* $(date '+%H:%M')/" "$SESSION_FILE" 2>/dev/null
echo "[SessionEnd] Updated session file: $SESSION_FILE" >&2
else
# Create new session file with template
cat > "$SESSION_FILE" << EOF
# Session: $(date '+%Y-%m-%d')
**Date:** $TODAY
**Started:** $(date '+%H:%M')
**Last Updated:** $(date '+%H:%M')
---
## Current State
[Session context goes here]
### Completed
- [ ]
### In Progress
- [ ]
### Notes for Next Session
-
### Context to Load
\`\`\`
[relevant files]
\`\`\`
EOF
echo "[SessionEnd] Created session file: $SESSION_FILE" >&2
fi

View File

@@ -1,37 +0,0 @@
#!/bin/bash
# SessionStart Hook - Load previous context on new session
#
# Runs when a new Claude session starts. Checks for recent session
# files and notifies Claude of available context to load.
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "SessionStart": [{
# "matcher": "*",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/hooks/memory-persistence/session-start.sh"
# }]
# }]
# }
# }
SESSIONS_DIR="${HOME}/.claude/sessions"
LEARNED_DIR="${HOME}/.claude/skills/learned"
# Check for recent session files (last 7 days)
recent_sessions=$(find "$SESSIONS_DIR" -name "*.tmp" -mtime -7 2>/dev/null | wc -l | tr -d ' ')
if [ "$recent_sessions" -gt 0 ]; then
latest=$(ls -t "$SESSIONS_DIR"/*.tmp 2>/dev/null | head -1)
echo "[SessionStart] Found $recent_sessions recent session(s)" >&2
echo "[SessionStart] Latest: $latest" >&2
fi
# Check for learned skills
learned_count=$(find "$LEARNED_DIR" -name "*.md" 2>/dev/null | wc -l | tr -d ' ')
if [ "$learned_count" -gt 0 ]; then
echo "[SessionStart] $learned_count learned skill(s) available in $LEARNED_DIR" >&2
fi

View File

@@ -107,7 +107,11 @@
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")"
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")",
"Skill(tdd)",
"Skill(tdd:*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m training.cli.train --model runs/train/invoice_fields/weights/best.pt --device 0 --epochs 100\")",
"Bash(git commit -m \"$\\(cat <<''EOF''\nfeat: add field-specific bbox expansion strategies for YOLO training\n\nImplement center-point based bbox scaling with directional compensation\nto capture field labels that typically appear above or to the left of\nfield values. This improves YOLO training data quality by including\ncontextual information around field values.\n\nKey changes:\n- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function\n- Define field-specific strategies \\(ocr_number, bankgiro, invoice_date, etc.\\)\n- Support manual_mode for minimal padding \\(no scaling\\)\n- Integrate expand_bbox into AnnotationGenerator\n- Add FIELD_TO_CLASS mapping for field_name to class_name lookup\n- Comprehensive tests with 100% coverage \\(45 tests\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
],
"deny": [],
"ask": [],

View File

@@ -1,52 +0,0 @@
#!/bin/bash
# Strategic Compact Suggester
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
#
# Why manual over auto-compact:
# - Auto-compact happens at arbitrary points, often mid-task
# - Strategic compacting preserves context through logical phases
# - Compact after exploration, before execution
# - Compact after completing a milestone, before starting next
#
# Hook config (in ~/.claude/settings.json):
# {
# "hooks": {
# "PreToolUse": [{
# "matcher": "Edit|Write",
# "hooks": [{
# "type": "command",
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
# }]
# }]
# }
# }
#
# Criteria for suggesting compact:
# - Session has been running for extended period
# - Large number of tool calls made
# - Transitioning from research/exploration to implementation
# - Plan has been finalized
# Track tool call count (increment in a temp file)
COUNTER_FILE="/tmp/claude-tool-count-$$"
THRESHOLD=${COMPACT_THRESHOLD:-50}
# Initialize or increment counter
if [ -f "$COUNTER_FILE" ]; then
count=$(cat "$COUNTER_FILE")
count=$((count + 1))
echo "$count" > "$COUNTER_FILE"
else
echo "1" > "$COUNTER_FILE"
count=1
fi
# Suggest compact after threshold tool calls
if [ "$count" -eq "$THRESHOLD" ]; then
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
fi
# Suggest at regular intervals after threshold
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
fi

BIN
.coverage

Binary file not shown.

View File

@@ -396,7 +396,7 @@ def extract_invoice_fields(
) -> InferenceResult:
"""Extract structured fields from Swedish invoice PDF.
Uses YOLOv11 for field detection and PaddleOCR for text extraction.
Uses YOLO26 for field detection and PaddleOCR for text extraction.
Applies field-specific normalization and validation.
Args:

59
=3.0.0 Normal file
View File

@@ -0,0 +1,59 @@
Requirement already satisfied: paddleocr in /home/kai/.local/lib/python3.10/site-packages (3.3.1)
Requirement already satisfied: pyyaml in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (6.0.2)
Requirement already satisfied: urllib3 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (2.6.3)
Requirement already satisfied: paddlex<3.4.0,>=3.3.0 in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (3.3.6)
Requirement already satisfied: requests in /home/kai/.local/lib/python3.10/site-packages (from paddleocr) (2.32.5)
Requirement already satisfied: typing-extensions>=4.12 in /home/kai/.local/lib/python3.10/site-packages (from paddleocr) (4.15.0)
Requirement already satisfied: aistudio-sdk>=0.3.5 in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.3.8)
Requirement already satisfied: chardet in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (5.2.0)
Requirement already satisfied: colorlog in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (6.10.1)
Requirement already satisfied: filelock in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (3.20.0)
Requirement already satisfied: huggingface-hub in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.3.1)
Requirement already satisfied: modelscope>=1.28.0 in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.31.0)
Requirement already satisfied: numpy>=1.24 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.2.6)
Requirement already satisfied: packaging in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (25.0)
Requirement already satisfied: pandas>=1.3 in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.3.3)
Requirement already satisfied: pillow in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (12.1.0)
Requirement already satisfied: prettytable in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (3.16.0)
Requirement already satisfied: py-cpuinfo in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (9.0.0)
Requirement already satisfied: pydantic>=2 in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.12.3)
Requirement already satisfied: ruamel.yaml in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.18.16)
Requirement already satisfied: ujson in /home/kai/.local/lib/python3.10/site-packages (from paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (5.11.0)
Requirement already satisfied: imagesize in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.4.1)
Requirement already satisfied: opencv-contrib-python==4.10.0.84 in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (4.10.0.84)
Requirement already satisfied: pyclipper in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.3.0.post6)
Requirement already satisfied: pypdfium2>=4 in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (5.0.0)
Requirement already satisfied: python-bidi in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.6.7)
Requirement already satisfied: shapely in /home/kai/.local/lib/python3.10/site-packages (from paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.1.2)
Requirement already satisfied: psutil in /home/kai/.local/lib/python3.10/site-packages (from aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (7.2.1)
Requirement already satisfied: tqdm in /home/kai/.local/lib/python3.10/site-packages (from aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (4.67.1)
Requirement already satisfied: bce-python-sdk in /home/kai/.local/lib/python3.10/site-packages (from aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.9.46)
Requirement already satisfied: click in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (8.3.1)
Requirement already satisfied: setuptools in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from modelscope>=1.28.0->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (80.10.1)
Requirement already satisfied: python-dateutil>=2.8.2 in /home/kai/.local/lib/python3.10/site-packages (from pandas>=1.3->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from pandas>=1.3->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /home/kai/.local/lib/python3.10/site-packages (from pandas>=1.3->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2025.2)
Requirement already satisfied: annotated-types>=0.6.0 in /home/kai/.local/lib/python3.10/site-packages (from pydantic>=2->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.7.0)
Requirement already satisfied: pydantic-core==2.41.4 in /home/kai/.local/lib/python3.10/site-packages (from pydantic>=2->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2.41.4)
Requirement already satisfied: typing-inspection>=0.4.2 in /home/kai/.local/lib/python3.10/site-packages (from pydantic>=2->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.4.2)
Requirement already satisfied: six>=1.5 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas>=1.3->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.17.0)
Requirement already satisfied: charset_normalizer<4,>=2 in /home/kai/.local/lib/python3.10/site-packages (from requests->paddleocr) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from requests->paddleocr) (3.11)
Requirement already satisfied: certifi>=2017.4.17 in /home/kai/miniconda3/envs/invoice-py310-sm120/lib/python3.10/site-packages (from requests->paddleocr) (2026.1.4)
Requirement already satisfied: pycryptodome>=3.8.0 in /home/kai/.local/lib/python3.10/site-packages (from bce-python-sdk->aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (3.23.0)
Requirement already satisfied: future>=0.6.0 in /home/kai/.local/lib/python3.10/site-packages (from bce-python-sdk->aistudio-sdk>=0.3.5->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.0.0)
Requirement already satisfied: fsspec>=2023.5.0 in /home/kai/.local/lib/python3.10/site-packages (from huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (2025.9.0)
Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /home/kai/.local/lib/python3.10/site-packages (from huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.2.0)
Requirement already satisfied: httpx<1,>=0.23.0 in /home/kai/.local/lib/python3.10/site-packages (from huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.28.1)
Requirement already satisfied: shellingham in /home/kai/.local/lib/python3.10/site-packages (from huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.5.4)
Requirement already satisfied: typer-slim in /home/kai/.local/lib/python3.10/site-packages (from huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.20.0)
Requirement already satisfied: anyio in /home/kai/.local/lib/python3.10/site-packages (from httpx<1,>=0.23.0->huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (4.11.0)
Requirement already satisfied: httpcore==1.* in /home/kai/.local/lib/python3.10/site-packages (from httpx<1,>=0.23.0->huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.0.9)
Requirement already satisfied: h11>=0.16 in /home/kai/.local/lib/python3.10/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.16.0)
Requirement already satisfied: exceptiongroup>=1.0.2 in /home/kai/.local/lib/python3.10/site-packages (from anyio->httpx<1,>=0.23.0->huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.3.0)
Requirement already satisfied: sniffio>=1.1 in /home/kai/.local/lib/python3.10/site-packages (from anyio->httpx<1,>=0.23.0->huggingface-hub->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (1.3.1)
Requirement already satisfied: wcwidth in /home/kai/.local/lib/python3.10/site-packages (from prettytable->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.2.14)
Requirement already satisfied: ruamel.yaml.clib>=0.2.7 in /home/kai/.local/lib/python3.10/site-packages (from ruamel.yaml->paddlex<3.4.0,>=3.3.0->paddlex[ocr-core]<3.4.0,>=3.3.0->paddleocr) (0.2.14)
[notice] A new release of pip is available: 25.3 -> 26.0
[notice] To update, run: pip install --upgrade pip

93
AGENTS.md Normal file
View File

@@ -0,0 +1,93 @@
# Invoice Master POC v2
Swedish Invoice Field Extraction System - YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
## Tech Stack
| Component | Technology |
|-----------|------------|
| Object Detection | YOLO26 (Ultralytics >= 8.4.0) |
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
| PDF Processing | PyMuPDF (fitz) |
| Database | PostgreSQL + psycopg2 |
| Web Framework | FastAPI + Uvicorn |
| Deep Learning | PyTorch + CUDA 12.x |
## WSL Environment (REQUIRED)
**Prefix ALL commands with:**
```bash
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && <command>"
```
**NEVER run Python commands directly in Windows PowerShell/CMD.**
## Project-Specific Rules
- Python 3.11+ with type hints
- No print() in production - use logging
- Run tests: `pytest --cov=src`
## Critical Rules
### Code Organization
- Many small files over few large files
- High cohesion, low coupling
- 200-400 lines typical, 800 max per file
- Organize by feature/domain, not by type
### Code Style
- No emojis in code, comments, or documentation
- Immutability always - never mutate objects or arrays
- No console.log in production code
- Proper error handling with try/catch
- Input validation with Zod or similar
### Testing
- TDD: Write tests first
- 80% minimum coverage
- Unit tests for utilities
- Integration tests for APIs
- E2E tests for critical flows
### Security
- No hardcoded secrets
- Environment variables for sensitive data
- Validate all user inputs
- Parameterized queries only
- CSRF protection enabled
## Environment Variables
```bash
# Required
DB_PASSWORD=
# Optional (with defaults)
DB_HOST=192.168.68.31
DB_PORT=5432
DB_NAME=docmaster
DB_USER=docmaster
MODEL_PATH=runs/train/invoice_fields/weights/best.pt
CONFIDENCE_THRESHOLD=0.5
SERVER_HOST=0.0.0.0
SERVER_PORT=8000
```
## Available Commands
- `/tdd` - Test-driven development workflow
- `/plan` - Create implementation plan
- `/code-review` - Review code quality
- `/build-fix` - Fix build errors
## Git Workflow
- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `test:`
- Never commit to main directly
- PRs require review
- All tests must pass before merge

View File

@@ -1,666 +0,0 @@
# Invoice Master POC v2 - 总体架构审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `/Users/yiukai/Documents/git/invoice-master-poc-v2`
---
## 架构概述
### 整体架构图
```
┌─────────────────────────────────────────────────────────────────┐
│ Frontend (React) │
│ Vite + TypeScript + TailwindCSS │
└─────────────────────────────┬───────────────────────────────────┘
│ HTTP/REST
┌─────────────────────────────▼───────────────────────────────────┐
│ Inference Service (FastAPI) │
│ ┌──────────────┬──────────────┬──────────────┬──────────────┐ │
│ │ Public API │ Admin API │ Training API│ Batch API │ │
│ └──────────────┴──────────────┴──────────────┴──────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Service Layer │ │
│ │ InferenceService │ AsyncProcessing │ BatchUpload │ Dataset │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Data Layer │ │
│ │ AdminDB │ AsyncRequestDB │ SQLModel │ PostgreSQL │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Core Components │ │
│ │ RateLimiter │ Schedulers │ TaskQueues │ Auth │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────┬───────────────────────────────────┘
│ PostgreSQL
┌─────────────────────────────▼───────────────────────────────────┐
│ Training Service (GPU) │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ CLI: train │ autolabel │ analyze │ validate │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ YOLO: db_dataset │ annotation_generator │ │
│ └────────────────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Processing: CPU Pool │ GPU Pool │ Task Dispatcher │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
┌─────────┴─────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ Shared │ │ Storage │
│ PDF │ OCR │ │ Local/Azure/ │
│ Normalize │ │ S3 │
└──────────────┘ └──────────────┘
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React + Vite + TypeScript + TailwindCSS | ✅ 现代栈 |
| **API 框架** | FastAPI | ✅ 高性能,类型安全 |
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 ORM |
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
| **OCR** | PaddleOCR v5 | ✅ 支持瑞典语 |
| **部署** | Docker + Azure/AWS | ✅ 云原生 |
---
## 架构优势
### 1. Monorepo 结构 ✅
```
packages/
├── shared/ # 共享库 - 无外部依赖
├── training/ # 训练服务 - 依赖 shared
└── inference/ # 推理服务 - 依赖 shared
```
**优点**:
- 清晰的包边界,无循环依赖
- 独立部署training 按需启动
- 代码复用率高
### 2. 分层架构 ✅
```
API Routes (web/api/v1/)
Service Layer (web/services/)
Data Layer (data/)
Database (PostgreSQL)
```
**优点**:
- 职责分离明确
- 便于单元测试
- 可替换底层实现
### 3. 依赖注入 ✅
```python
# FastAPI Depends 使用得当
@router.post("/infer")
async def infer(
file: UploadFile,
db: AdminDB = Depends(get_admin_db), # 注入
token: str = Depends(validate_admin_token),
):
```
### 4. 存储抽象层 ✅
```python
# 统一接口,支持多后端
class StorageBackend(ABC):
def upload(self, source: Path, destination: str) -> None: ...
def download(self, source: str, destination: Path) -> None: ...
def get_presigned_url(self, path: str) -> str: ...
# 实现: LocalStorageBackend, AzureStorageBackend, S3StorageBackend
```
### 5. 动态模型管理 ✅
```python
# 数据库驱动的模型切换
def get_active_model_path() -> Path | None:
db = AdminDB()
active_model = db.get_active_model_version()
return active_model.model_path if active_model else None
inference_service = InferenceService(
model_path_resolver=get_active_model_path,
)
```
### 6. 任务队列分离 ✅
```python
# 不同类型任务使用不同队列
- AsyncTaskQueue: 异步推理任务
- BatchQueue: 批量上传任务
- TrainingScheduler: 训练任务调度
- AutoLabelScheduler: 自动标注调度
```
---
## 架构问题与风险
### 1. 数据库层职责过重 ⚠️ **中风险**
**问题**: `AdminDB` 类过大,违反单一职责原则
```python
# packages/inference/inference/data/admin_db.py
class AdminDB:
# Token 管理 (5 个方法)
def is_valid_admin_token(self, token: str) -> bool: ...
def create_admin_token(self, token: str, name: str): ...
# 文档管理 (8 个方法)
def create_document(self, ...): ...
def get_document(self, doc_id: str): ...
# 标注管理 (6 个方法)
def create_annotation(self, ...): ...
def get_annotations(self, doc_id: str): ...
# 训练任务 (7 个方法)
def create_training_task(self, ...): ...
def update_training_task(self, ...): ...
# 数据集 (6 个方法)
def create_dataset(self, ...): ...
def get_dataset(self, dataset_id: str): ...
# 模型版本 (5 个方法)
def create_model_version(self, ...): ...
def activate_model_version(self, ...): ...
# 批处理 (4 个方法)
# 锁管理 (3 个方法)
# ... 总计 50+ 方法
```
**影响**:
- 类过大,难以维护
- 测试困难
- 不同领域变更互相影响
**建议**: 按领域拆分为 Repository 模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
def create(self, token: Token) -> None: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
def save(self, document: Document) -> None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
def update_task_status(self, task_id: str, status: TaskStatus): ...
class ModelRepository:
def get_active(self) -> ModelVersion | None: ...
def activate(self, version_id: str) -> None: ...
```
---
### 2. Service 层混合业务逻辑与技术细节 ⚠️ **中风险**
**问题**: `InferenceService` 既处理业务逻辑又处理技术实现
```python
# packages/inference/inference/web/services/inference.py
class InferenceService:
def process(self, image_bytes: bytes) -> ServiceResult:
# 1. 技术细节: 图像解码
image = Image.open(io.BytesIO(image_bytes))
# 2. 业务逻辑: 字段提取
fields = self._extract_fields(image)
# 3. 技术细节: 模型推理
detections = self._model.predict(image)
# 4. 业务逻辑: 结果验证
if not self._validate_fields(fields):
raise ValidationError()
```
**影响**:
- 难以测试业务逻辑
- 技术变更影响业务代码
- 无法切换技术实现
**建议**: 引入领域层和适配器模式
```python
# 领域层 - 纯业务逻辑
@dataclass
class InvoiceDocument:
document_id: str
pages: list[Page]
class InvoiceExtractor:
"""纯业务逻辑,不依赖技术实现"""
def extract(self, document: InvoiceDocument) -> InvoiceFields:
# 只处理业务规则
pass
# 适配器层 - 技术实现
class YoloFieldDetector:
"""YOLO 技术适配器"""
def __init__(self, model_path: Path):
self._model = YOLO(model_path)
def detect(self, image: np.ndarray) -> list[FieldRegion]:
return self._model.predict(image)
class PaddleOcrEngine:
"""PaddleOCR 技术适配器"""
def __init__(self):
self._ocr = PaddleOCR()
def recognize(self, image: np.ndarray, region: BoundingBox) -> str:
return self._ocr.ocr(image, region)
# 应用服务 - 协调领域和适配器
class InvoiceProcessingService:
def __init__(
self,
extractor: InvoiceExtractor,
detector: FieldDetector,
ocr: OcrEngine,
):
self._extractor = extractor
self._detector = detector
self._ocr = ocr
```
---
### 3. 调度器设计分散 ⚠️ **中风险**
**问题**: 多个独立调度器缺乏统一协调
```python
# 当前设计 - 4 个独立调度器
# 1. TrainingScheduler (core/scheduler.py)
# 2. AutoLabelScheduler (core/autolabel_scheduler.py)
# 3. AsyncTaskQueue (workers/async_queue.py)
# 4. BatchQueue (workers/batch_queue.py)
# app.py 中分别启动
start_scheduler() # 训练调度器
start_autolabel_scheduler() # 自动标注调度器
init_batch_queue() # 批处理队列
```
**影响**:
- 资源竞争风险
- 难以监控和追踪
- 任务优先级难以管理
- 重启时任务丢失
**建议**: 使用 Celery + Redis 统一任务队列
```python
# 建议重构
from celery import Celery
app = Celery('invoice_master')
@app.task(bind=True, max_retries=3)
def process_inference(self, document_id: str):
"""异步推理任务"""
try:
service = get_inference_service()
result = service.process(document_id)
return result
except Exception as exc:
raise self.retry(exc=exc, countdown=60)
@app.task
def train_model(dataset_id: str, config: dict):
"""训练任务"""
training_service = get_training_service()
return training_service.train(dataset_id, config)
@app.task
def auto_label_documents(document_ids: list[str]):
"""批量自动标注"""
for doc_id in document_ids:
auto_label_document.delay(doc_id)
# 优先级队列
app.conf.task_routes = {
'tasks.process_inference': {'queue': 'high_priority'},
'tasks.train_model': {'queue': 'gpu_queue'},
'tasks.auto_label_documents': {'queue': 'low_priority'},
}
```
---
### 4. 配置分散 ⚠️ **低风险**
**问题**: 配置分散在多个文件
```python
# packages/shared/shared/config.py
DATABASE = {...}
PATHS = {...}
AUTOLABEL = {...}
# packages/inference/inference/web/config.py
@dataclass
class ModelConfig: ...
@dataclass
class ServerConfig: ...
@dataclass
class FileConfig: ...
# 环境变量
# .env 文件
```
**影响**:
- 配置难以追踪
- 可能出现不一致
- 缺少配置验证
**建议**: 使用 Pydantic Settings 集中管理
```python
# config/settings.py
from pydantic_settings import BaseSettings, SettingsConfigDict
class DatabaseSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='DB_')
host: str = 'localhost'
port: int = 5432
name: str = 'docmaster'
user: str = 'docmaster'
password: str # 无默认值,必须设置
class StorageSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix='STORAGE_')
backend: str = 'local'
base_path: str = '~/invoice-data'
azure_connection_string: str | None = None
s3_bucket: str | None = None
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
)
database: DatabaseSettings = DatabaseSettings()
storage: StorageSettings = StorageSettings()
# 验证
@field_validator('database')
def validate_database(cls, v):
if not v.password:
raise ValueError('Database password is required')
return v
# 全局配置实例
settings = Settings()
```
---
### 5. 内存队列单点故障 ⚠️ **中风险**
**问题**: AsyncTaskQueue 和 BatchQueue 基于内存
```python
# workers/async_queue.py
class AsyncTaskQueue:
def __init__(self):
self._queue = Queue() # 内存队列
self._workers = []
def enqueue(self, task: AsyncTask) -> None:
self._queue.put(task) # 仅存储在内存
```
**影响**:
- 服务重启丢失所有待处理任务
- 无法水平扩展
- 任务持久化困难
**建议**: 使用 Redis/RabbitMQ 持久化队列
---
### 6. 缺少 API 版本迁移策略 ❓ **低风险**
**问题**: 有 `/api/v1/` 版本,但缺少升级策略
```
当前: /api/v1/admin/documents
未来: /api/v2/admin/documents ?
```
**建议**:
- 制定 API 版本升级流程
- 使用 Header 版本控制
- 维护版本兼容性文档
---
## 关键架构风险矩阵
| 风险项 | 概率 | 影响 | 风险等级 | 优先级 |
|--------|------|------|----------|--------|
| 内存队列丢失任务 | 中 | 高 | **高** | 🔴 P0 |
| AdminDB 职责过重 | 高 | 中 | **中** | 🟡 P1 |
| Service 层混合 | 高 | 中 | **中** | 🟡 P1 |
| 调度器资源竞争 | 中 | 中 | **中** | 🟡 P1 |
| 配置分散 | 高 | 低 | **低** | 🟢 P2 |
| API 版本策略 | 低 | 低 | **低** | 🟢 P2 |
---
## 改进建议路线图
### Phase 1: 立即执行 (本周)
#### 1.1 拆分 AdminDB
```python
# 创建 repositories 包
inference/data/repositories/
├── __init__.py
├── base.py # Repository 基类
├── token.py # TokenRepository
├── document.py # DocumentRepository
├── annotation.py # AnnotationRepository
├── training.py # TrainingRepository
├── dataset.py # DatasetRepository
└── model.py # ModelRepository
```
#### 1.2 统一配置
```python
# 创建统一配置模块
inference/config/
├── __init__.py
├── settings.py # Pydantic Settings
└── validators.py # 配置验证
```
### Phase 2: 短期执行 (本月)
#### 2.1 引入消息队列
```yaml
# docker-compose.yml 添加
services:
redis:
image: redis:7-alpine
ports:
- "6379:6379"
celery_worker:
build: .
command: celery -A inference.tasks worker -l info
depends_on:
- redis
- postgres
```
#### 2.2 添加缓存层
```python
# 使用 Redis 缓存热点数据
from redis import Redis
redis_client = Redis(host='localhost', port=6379)
class CachedDocumentRepository(DocumentRepository):
def find_by_id(self, doc_id: str) -> Document | None:
# 先查缓存
cached = redis_client.get(f"doc:{doc_id}")
if cached:
return Document.parse_raw(cached)
# 再查数据库
doc = super().find_by_id(doc_id)
if doc:
redis_client.setex(f"doc:{doc_id}", 3600, doc.json())
return doc
```
### Phase 3: 长期执行 (本季度)
#### 3.1 数据库读写分离
```python
# 配置主从数据库
class DatabaseManager:
def __init__(self):
self._master = create_engine(MASTER_DB_URL)
self._replica = create_engine(REPLICA_DB_URL)
def get_session(self, readonly: bool = False) -> Session:
engine = self._replica if readonly else self._master
return Session(engine)
```
#### 3.2 事件驱动架构
```python
# 引入事件总线
from event_bus import EventBus
bus = EventBus()
# 发布事件
@router.post("/documents")
async def create_document(...):
doc = document_repo.save(document)
bus.publish('document.created', {'document_id': doc.id})
return doc
# 订阅事件
@bus.subscribe('document.created')
def on_document_created(event):
# 触发自动标注
auto_label_task.delay(event['document_id'])
```
---
## 架构演进建议
### 当前架构 (适合 1-10 用户)
```
Single Instance
├── FastAPI App
├── Memory Queues
└── PostgreSQL
```
### 目标架构 (适合 100+ 用户)
```
Load Balancer
├── FastAPI Instance 1
├── FastAPI Instance 2
└── FastAPI Instance N
┌───────┴───────┐
▼ ▼
Redis Cluster PostgreSQL
(Celery + Cache) (Master + Replica)
```
---
## 总结
### 总体评分
| 维度 | 评分 | 说明 |
|------|------|------|
| **模块化** | 8/10 | 包结构清晰,但部分类过大 |
| **可扩展性** | 7/10 | 水平扩展良好,垂直扩展受限 |
| **可维护性** | 8/10 | 分层合理,但职责边界需细化 |
| **可靠性** | 7/10 | 内存队列是单点故障 |
| **性能** | 8/10 | 异步处理良好 |
| **安全性** | 8/10 | 基础安全到位 |
| **总体** | **7.7/10** | 良好的架构基础,需优化细节 |
### 关键结论
1. **架构设计合理**: Monorepo + 分层架构适合当前规模
2. **主要风险**: 内存队列和数据库职责过重
3. **演进路径**: 引入消息队列和缓存层
4. **投入产出**: 当前架构可支撑到 100+ 用户,无需大规模重构
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 🔴 P0 | 引入 Celery + Redis | 3 天 | 解决任务丢失问题 |
| 🟡 P1 | 拆分 AdminDB | 2 天 | 提升可维护性 |
| 🟡 P1 | 统一配置管理 | 1 天 | 减少配置错误 |
| 🟢 P2 | 添加缓存层 | 2 天 | 提升性能 |
| 🟢 P2 | 数据库读写分离 | 3 天 | 提升扩展性 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)

View File

@@ -1,317 +0,0 @@
# Changelog
All notable changes to the Invoice Field Extraction project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added - Phase 1: Security & Infrastructure (2026-01-22)
#### Security Enhancements
- **Environment Variable Management**: Added `python-dotenv` for secure configuration management
- Created `.env.example` template file for configuration reference
- Created `.env` file for actual credentials (gitignored)
- Updated `config.py` to load database password from environment variables
- Added validation to ensure `DB_PASSWORD` is set at startup
- Files modified: `config.py`, `requirements.txt`
- New files: `.env`, `.env.example`
- Tests: `tests/test_config.py` (7 tests, all passing)
- **SQL Injection Prevention**: Fixed SQL injection vulnerabilities in database queries
- Replaced f-string formatting with parameterized queries in `LIMIT` clauses
- Updated `get_all_documents_summary()` to use `%s` placeholder for LIMIT parameter
- Updated `get_failed_matches()` to use `%s` placeholder for LIMIT parameter
- Files modified: `src/data/db.py` (lines 246, 298)
- Tests: `tests/test_db_security.py` (9 tests, all passing)
#### Code Quality
- **Exception Hierarchy**: Created comprehensive custom exception system
- Added base class `InvoiceExtractionError` with message and details support
- Added specific exception types:
- `PDFProcessingError` - PDF rendering/conversion errors
- `OCRError` - OCR processing errors
- `ModelInferenceError` - YOLO model errors
- `FieldValidationError` - Field validation errors (with field-specific attributes)
- `DatabaseError` - Database operation errors
- `ConfigurationError` - Configuration errors
- `PaymentLineParseError` - Payment line parsing errors
- `CustomerNumberParseError` - Customer number parsing errors
- `DataLoadError` - Data loading errors
- `AnnotationError` - Annotation generation errors
- New file: `src/exceptions.py`
- Tests: `tests/test_exceptions.py` (16 tests, all passing)
### Testing
- Added 32 new tests across 3 test files
- Configuration tests: 7 tests
- SQL injection prevention tests: 9 tests
- Exception hierarchy tests: 16 tests
- All tests passing (32/32)
### Documentation
- Created `docs/CODE_REVIEW_REPORT.md` - Comprehensive code quality analysis (550+ lines)
- Created `docs/REFACTORING_PLAN.md` - Detailed 3-phase refactoring plan (600+ lines)
- Created `CHANGELOG.md` - Project changelog (this file)
### Changed
- **Configuration Loading**: Database configuration now loads from environment variables instead of hardcoded values
- Breaking change: Requires `.env` file with `DB_PASSWORD` set
- Migration: Copy `.env.example` to `.env` and set your database password
### Security
- **Fixed**: Database password no longer stored in plain text in `config.py`
- **Fixed**: SQL injection vulnerabilities in LIMIT clauses (2 instances)
### Technical Debt Addressed
- Eliminated security vulnerability: plaintext password storage
- Reduced SQL injection attack surface
- Improved error handling granularity with custom exceptions
---
### Added - Phase 2: Parser Refactoring (2026-01-22)
#### Unified Parser Modules
- **Payment Line Parser**: Created dedicated payment line parsing module
- Handles Swedish payment line format: `# <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#`
- Tolerates common OCR errors: spaces in numbers, missing symbols, spaces in check digits
- Supports 4 parsing patterns: full format, no amount, alternative, account-only
- Returns structured `PaymentLineData` with parsed fields
- New file: `src/inference/payment_line_parser.py` (90 lines, 92% coverage)
- Tests: `tests/test_payment_line_parser.py` (23 tests, all passing)
- Eliminates 1st code duplication (payment line parsing logic)
- **Customer Number Parser**: Created dedicated customer number parsing module
- Handles Swedish customer number formats: `JTY 576-3`, `DWQ 211-X`, `FFL 019N`, etc.
- Uses Strategy Pattern with 5 pattern classes:
- `LabeledPattern` - Explicit labels (highest priority, 0.98 confidence)
- `DashFormatPattern` - Standard format with dash (0.95 confidence)
- `NoDashFormatPattern` - Format without dash, adds dash automatically (0.90 confidence)
- `CompactFormatPattern` - Compact format without spaces (0.75 confidence)
- `GenericAlphanumericPattern` - Fallback generic pattern (variable confidence)
- Excludes Swedish postal codes (`SE XXX XX` format)
- Returns highest confidence match
- New file: `src/inference/customer_number_parser.py` (154 lines, 92% coverage)
- Tests: `tests/test_customer_number_parser.py` (32 tests, all passing)
- Reduces `_normalize_customer_number` complexity (127 lines → will use 5-10 lines after integration)
### Testing Summary
**Phase 1 Tests** (32 tests):
- Configuration tests: 7 tests ([test_config.py](tests/test_config.py))
- SQL injection prevention tests: 9 tests ([test_db_security.py](tests/test_db_security.py))
- Exception hierarchy tests: 16 tests ([test_exceptions.py](tests/test_exceptions.py))
**Phase 2 Tests** (121 tests):
- Payment line parser tests: 23 tests ([test_payment_line_parser.py](tests/test_payment_line_parser.py))
- Standard parsing, OCR error handling, real-world examples, edge cases
- Coverage: 92%
- Customer number parser tests: 32 tests ([test_customer_number_parser.py](tests/test_customer_number_parser.py))
- Pattern matching (DashFormat, NoDashFormat, Compact, Labeled)
- Real-world examples, edge cases, Swedish postal code exclusion
- Coverage: 92%
- Field extractor integration tests: 45 tests ([test_field_extractor.py](src/inference/test_field_extractor.py))
- Validates backward compatibility with existing code
- Tests for invoice numbers, bankgiro, plusgiro, amounts, OCR, dates, payment lines, customer numbers
- Pipeline integration tests: 21 tests ([test_pipeline.py](src/inference/test_pipeline.py))
- Cross-validation, payment line parsing, field overrides
**Total**: 153 tests, 100% passing, 4.50s runtime
### Code Quality
- **Eliminated Code Duplication**: Payment line parsing previously in 3 places, now unified in 1 module
- **Improved Maintainability**: Strategy Pattern makes customer number patterns easy to extend
- **Better Test Coverage**: New parsers have 92% coverage vs original 10% in field_extractor.py
#### Parser Integration into field_extractor.py (2026-01-22)
- **field_extractor.py Integration**: Successfully integrated new parsers
- Added `PaymentLineParser` and `CustomerNumberParser` instances (lines 99-101)
- Replaced `_normalize_payment_line` method: 74 lines → 3 lines (lines 640-657)
- Replaced `_normalize_customer_number` method: 127 lines → 3 lines (lines 697-707)
- All 45 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.21 seconds
- File: `src/inference/field_extractor.py`
#### Parser Integration into pipeline.py (2026-01-22)
- **pipeline.py Integration**: Successfully integrated PaymentLineParser
- Added `PaymentLineParser` import (line 15)
- Added `payment_line_parser` instance initialization (line 128)
- Replaced `_parse_machine_readable_payment_line` method: 36 lines → 6 lines (lines 219-233)
- All 21 existing tests pass (100% backward compatibility maintained)
- Tests run time: 4.00 seconds
- File: `src/inference/pipeline.py`
### Phase 2 Status: **COMPLETED** ✅
- [x] Create unified `payment_line_parser` module ✅
- [x] Create unified `customer_number_parser` module ✅
- [x] Refactor `field_extractor.py` to use new parsers ✅
- [x] Refactor `pipeline.py` to use new parsers ✅
- [x] Comprehensive test suite (153 tests, 100% passing) ✅
### Achieved Impact
- Eliminate code duplication: 3 implementations → 1 ✅ (payment_line unified across field_extractor.py, pipeline.py, tests)
- Reduce `_normalize_payment_line` complexity in field_extractor.py: 74 lines → 3 lines ✅
- Reduce `_normalize_customer_number` complexity in field_extractor.py: 127 lines → 3 lines ✅
- Reduce `_parse_machine_readable_payment_line` complexity in pipeline.py: 36 lines → 6 lines ✅
- Total lines of code eliminated: 201 lines reduced to 12 lines (94% reduction) ✅
- Improve test coverage: New parser modules have 92% coverage (vs original 10% in field_extractor.py)
- Simplify maintenance: Pattern-based approach makes extension easy
- 100% backward compatibility: All 66 existing tests pass (45 field_extractor + 21 pipeline)
---
## Phase 3: Performance & Documentation (2026-01-22)
### Added
#### Configuration Constants Extraction
- **Created `src/inference/constants.py`**: Centralized configuration constants
- Detection & model configuration (confidence thresholds, IOU)
- Image processing configuration (DPI, scaling factors)
- Customer number parser confidence scores
- Field extraction confidence multipliers
- Account type detection thresholds
- Pattern matching constants
- 90 lines of well-documented constants with usage notes
- Eliminates ~15 hardcoded magic numbers across codebase
- File: [src/inference/constants.py](src/inference/constants.py)
#### Performance Optimization Documentation
- **Created `docs/PERFORMANCE_OPTIMIZATION.md`**: Comprehensive performance guide (400+ lines)
- **Batch Processing Optimization**: Parallel processing strategies, already-implemented dual pool system
- **Database Query Optimization**: Connection pooling recommendations, index strategies
- **Caching Strategies**: Model loading cache, parser reuse (already optimal), OCR result caching
- **Memory Management**: Explicit cleanup, generator patterns, context managers
- **Profiling Guidelines**: cProfile, memory_profiler, py-spy recommendations
- **Benchmarking Scripts**: Ready-to-use performance measurement code
- **Priority Roadmap**: High/Medium/Low priority optimizations with effort estimates
- Expected impact: 2-5x throughput improvement for batch processing
- File: [docs/PERFORMANCE_OPTIMIZATION.md](docs/PERFORMANCE_OPTIMIZATION.md)
### Phase 3 Status: **COMPLETED** ✅
- [x] Configuration constants extraction ✅
- [x] Performance optimization analysis ✅
- [x] Batch processing optimization recommendations ✅
- [x] Database optimization strategies ✅
- [x] Caching and memory management guidelines ✅
- [x] Profiling and benchmarking documentation ✅
### Deliverables
**New Files** (2 files):
1. `src/inference/constants.py` (90 lines) - Centralized configuration constants
2. `docs/PERFORMANCE_OPTIMIZATION.md` (400+ lines) - Performance optimization guide
**Impact**:
- Eliminates 15+ hardcoded magic numbers
- Provides clear optimization roadmap
- Documents existing performance features
- Identifies quick wins (connection pooling, indexes)
- Long-term strategy (caching, profiling)
---
## Notes
### Breaking Changes
- **v2.x**: Requires `.env` file with database credentials
- Action required: Create `.env` file based on `.env.example`
- Affected: All deployments, CI/CD pipelines
### Migration Guide
#### From v1.x to v2.x (Environment Variables)
1. Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```
2. Edit `.env` and set your database password:
```
DB_PASSWORD=your_actual_password_here
```
3. Install new dependency:
```bash
pip install python-dotenv
```
4. Verify configuration loads correctly:
```bash
python -c "import config; print('Config loaded successfully')"
```
## Summary of All Work Completed
### Files Created (13 new files)
**Phase 1** (3 files):
1. `.env` - Environment variables for database credentials
2. `.env.example` - Template for environment configuration
3. `src/exceptions.py` - Custom exception hierarchy (35 lines, 66% coverage)
**Phase 2** (7 files):
4. `src/inference/payment_line_parser.py` - Unified payment line parsing (90 lines, 92% coverage)
5. `src/inference/customer_number_parser.py` - Unified customer number parsing (154 lines, 92% coverage)
6. `tests/test_config.py` - Configuration tests (7 tests)
7. `tests/test_db_security.py` - SQL injection prevention tests (9 tests)
8. `tests/test_exceptions.py` - Exception hierarchy tests (16 tests)
9. `tests/test_payment_line_parser.py` - Payment line parser tests (23 tests)
10. `tests/test_customer_number_parser.py` - Customer number parser tests (32 tests)
**Phase 3** (2 files):
11. `src/inference/constants.py` - Centralized configuration constants (90 lines)
12. `docs/PERFORMANCE_OPTIMIZATION.md` - Performance optimization guide (400+ lines)
**Documentation** (1 file):
13. `CHANGELOG.md` - This file (260+ lines of detailed documentation)
### Files Modified (4 files)
1. `config.py` - Added environment variable loading with python-dotenv
2. `src/data/db.py` - Fixed 2 SQL injection vulnerabilities (lines 246, 298)
3. `src/inference/field_extractor.py` - Integrated new parsers (reduced 201 lines to 6 lines)
4. `src/inference/pipeline.py` - Integrated PaymentLineParser (reduced 36 lines to 6 lines)
5. `requirements.txt` - Added python-dotenv dependency
### Test Summary
- **Total tests**: 153 tests across 7 test files
- **Passing**: 153 (100%)
- **Failing**: 0
- **Runtime**: 4.50 seconds
- **Coverage**:
- New parser modules: 92%
- Config module: 100%
- Exception module: 66%
- DB security coverage: 18% (focused on parameterized queries)
### Code Metrics
- **Lines eliminated**: 237 lines of duplicated/complex code → 18 lines (92% reduction)
- field_extractor.py: 201 lines → 6 lines
- pipeline.py: 36 lines → 6 lines
- **New code added**: 279 lines of well-tested parser code
- **Net impact**: Replaced 237 lines of duplicate code with 279 lines of unified, tested code (+42 lines, but -3 implementations)
- **Test coverage improvement**: 0% → 92% for parser logic
### Performance Impact
- Configuration loading: Negligible (<1ms overhead for .env parsing)
- SQL queries: No performance change (parameterized queries are standard practice)
- Parser refactoring: No performance degradation (logic simplified, not changed)
- Exception handling: Minimal overhead (only when exceptions are raised)
### Security Improvements
- Eliminated plaintext password storage
- Fixed 2 SQL injection vulnerabilities
- Added input validation in database layer
### Maintainability Improvements
- Eliminated code duplication (3 implementations 1)
- Strategy Pattern enables easy extension of customer number formats
- Comprehensive test suite (153 tests) ensures safe refactoring
- 100% backward compatibility maintained
- Custom exception hierarchy for granular error handling

View File

@@ -1,805 +0,0 @@
# Invoice Master POC v2 - 详细代码审查报告
**审查日期**: 2026-02-01
**审查人**: Claude Code
**项目路径**: `C:\Users\yaoji\git\ColaCoder\invoice-master-poc-v2`
**代码统计**:
- Python文件: 200+ 个
- 测试文件: 97 个
- TypeScript/React文件: 39 个
- 总测试数: 1,601 个
- 测试覆盖率: 28%
---
## 目录
1. [执行摘要](#执行摘要)
2. [架构概览](#架构概览)
3. [详细模块审查](#详细模块审查)
4. [代码质量问题](#代码质量问题)
5. [安全风险分析](#安全风险分析)
6. [性能问题](#性能问题)
7. [改进建议](#改进建议)
8. [总结与评分](#总结与评分)
---
## 执行摘要
### 总体评估
| 维度 | 评分 | 状态 |
|------|------|------|
| **代码质量** | 7.5/10 | 良好,但有改进空间 |
| **安全性** | 7/10 | 基础安全到位,需加强 |
| **可维护性** | 8/10 | 模块化良好 |
| **测试覆盖** | 5/10 | 偏低,需提升 |
| **性能** | 8/10 | 异步处理良好 |
| **文档** | 8/10 | 文档详尽 |
| **总体** | **7.3/10** | 生产就绪,需小幅改进 |
### 关键发现
**优势:**
- 清晰的Monorepo架构三包分离合理
- 类型注解覆盖率高(>90%
- 存储抽象层设计优秀
- FastAPI使用规范依赖注入模式良好
- 异常处理完善,自定义异常层次清晰
**风险:**
- 测试覆盖率仅28%,远低于行业标准
- AdminDB类过大50+方法),违反单一职责原则
- 内存队列存在单点故障风险
- 部分安全细节需加强(时序攻击、文件上传验证)
- 前端状态管理简单,可能难以扩展
---
## 架构概览
### 项目结构
```
invoice-master-poc-v2/
├── packages/
│ ├── shared/ # 共享库 (74个Python文件)
│ │ ├── pdf/ # PDF处理
│ │ ├── ocr/ # OCR封装
│ │ ├── normalize/ # 字段规范化
│ │ ├── matcher/ # 字段匹配
│ │ ├── storage/ # 存储抽象层
│ │ ├── training/ # 训练组件
│ │ └── augmentation/# 数据增强
│ ├── training/ # 训练服务 (26个Python文件)
│ │ ├── cli/ # 命令行工具
│ │ ├── yolo/ # YOLO数据集
│ │ └── processing/ # 任务处理
│ └── inference/ # 推理服务 (100个Python文件)
│ ├── web/ # FastAPI应用
│ ├── pipeline/ # 推理管道
│ ├── data/ # 数据层
│ └── cli/ # 命令行工具
├── frontend/ # React前端 (39个TS/TSX文件)
│ ├── src/
│ │ ├── components/ # UI组件
│ │ ├── hooks/ # React Query hooks
│ │ └── api/ # API客户端
└── tests/ # 测试 (97个Python文件)
```
### 技术栈
| 层级 | 技术 | 评估 |
|------|------|------|
| **前端** | React 18 + TypeScript + Vite + TailwindCSS | 现代栈,类型安全 |
| **API框架** | FastAPI + Uvicorn | 高性能,异步支持 |
| **数据库** | PostgreSQL + SQLModel | 类型安全ORM |
| **目标检测** | YOLOv11 (Ultralytics) | 业界标准 |
| **OCR** | PaddleOCR v5 | 支持瑞典语 |
| **部署** | Docker + Azure/AWS | 云原生 |
---
## 详细模块审查
### 1. Shared Package
#### 1.1 配置模块 (`shared/config.py`)
**文件位置**: `packages/shared/shared/config.py`
**代码行数**: 82行
**优点:**
- 使用环境变量加载配置,无硬编码敏感信息
- DPI配置统一管理DEFAULT_DPI = 150
- 密码无默认值,强制要求设置
**问题:**
```python
# 问题1: 配置分散,缺少验证
DATABASE = {
'host': os.getenv('DB_HOST', '192.168.68.31'), # 硬编码IP
'port': int(os.getenv('DB_PORT', '5432')),
# ...
}
# 问题2: 缺少类型安全
# 建议使用 Pydantic Settings
```
**严重程度**: 中
**建议**: 使用 Pydantic Settings 集中管理配置,添加验证逻辑
---
#### 1.2 存储抽象层 (`shared/storage/`)
**文件位置**: `packages/shared/shared/storage/`
**包含文件**: 8个
**优点:**
- 设计优秀的抽象接口 `StorageBackend`
- 支持 Local/Azure/S3 多后端
- 预签名URL支持
- 异常层次清晰
**代码示例 - 优秀设计:**
```python
class StorageBackend(ABC):
@abstractmethod
def upload(self, local_path: Path, remote_path: str, overwrite: bool = False) -> str:
pass
@abstractmethod
def get_presigned_url(self, remote_path: str, expires_in_seconds: int = 3600) -> str:
pass
```
**问题:**
- `upload_bytes``download_bytes` 默认实现使用临时文件,效率较低
- 缺少文件类型验证(魔术字节检查)
**严重程度**: 低
**建议**: 子类可重写bytes方法以提高效率添加文件类型验证
---
#### 1.3 异常定义 (`shared/exceptions.py`)
**文件位置**: `packages/shared/shared/exceptions.py`
**代码行数**: 103行
**优点:**
- 清晰的异常层次结构
- 所有异常继承自 `InvoiceExtractionError`
- 包含详细的错误上下文
**代码示例:**
```python
class InvoiceExtractionError(Exception):
def __init__(self, message: str, details: dict = None):
super().__init__(message)
self.message = message
self.details = details or {}
```
**评分**: 9/10 - 设计优秀
---
#### 1.4 数据增强 (`shared/augmentation/`)
**文件位置**: `packages/shared/shared/augmentation/`
**包含文件**: 10个
**功能:**
- 12种数据增强策略
- 透视变换、皱纹、边缘损坏、污渍等
- 高斯模糊、运动模糊、噪声等
**代码质量**: 良好,模块化设计
---
### 2. Inference Package
#### 2.1 认证模块 (`inference/web/core/auth.py`)
**文件位置**: `packages/inference/inference/web/core/auth.py`
**代码行数**: 61行
**优点:**
- 使用FastAPI依赖注入模式
- Token过期检查
- 记录最后使用时间
**安全问题:**
```python
# 问题: 时序攻击风险 (第46行)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, detail="Invalid or expired admin token.")
# 建议: 使用 constant-time 比较
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
**严重程度**: 中
**建议**: 使用 `hmac.compare_digest()` 进行constant-time比较
---
#### 2.2 限流器 (`inference/web/core/rate_limiter.py`)
**文件位置**: `packages/inference/inference/web/core/rate_limiter.py`
**代码行数**: 212行
**优点:**
- 滑动窗口算法实现
- 线程安全使用Lock
- 支持并发任务限制
- 可配置的限流策略
**代码示例 - 优秀设计:**
```python
@dataclass(frozen=True)
class RateLimitConfig:
requests_per_minute: int = 10
max_concurrent_jobs: int = 3
min_poll_interval_ms: int = 1000
```
**问题:**
- 内存存储,服务重启后限流状态丢失
- 分布式部署时无法共享限流状态
**严重程度**: 中
**建议**: 生产环境使用Redis实现分布式限流
---
#### 2.3 AdminDB (`inference/data/admin_db.py`)
**文件位置**: `packages/inference/inference/data/admin_db.py`
**代码行数**: 1300+行
**严重问题 - 类过大:**
```python
class AdminDB:
# Token管理 (5个方法)
# 文档管理 (8个方法)
# 标注管理 (6个方法)
# 训练任务 (7个方法)
# 数据集 (6个方法)
# 模型版本 (5个方法)
# 批处理 (4个方法)
# 锁管理 (3个方法)
# ... 总计50+方法
```
**影响:**
- 违反单一职责原则
- 难以维护
- 测试困难
- 不同领域变更互相影响
**严重程度**: 高
**建议**: 按领域拆分为Repository模式
```python
# 建议重构
class TokenRepository:
def validate(self, token: str) -> bool: ...
class DocumentRepository:
def find_by_id(self, doc_id: str) -> Document | None: ...
class TrainingRepository:
def create_task(self, config: TrainingConfig) -> TrainingTask: ...
```
---
#### 2.4 文档路由 (`inference/web/api/v1/admin/documents.py`)
**文件位置**: `packages/inference/inference/web/api/v1/admin/documents.py`
**代码行数**: 692行
**优点:**
- FastAPI使用规范
- 输入验证完善
- 响应模型定义清晰
- 错误处理良好
**问题:**
```python
# 问题1: 文件上传缺少魔术字节验证 (第127-131行)
content = await file.read()
# 建议: 验证PDF魔术字节 %PDF
# 问题2: 路径遍历风险 (第494-498行)
filename = Path(document.file_path).name
# 建议: 使用 Path.name 并验证路径范围
# 问题3: 函数过长,职责过多
# _convert_pdf_to_images 函数混合了PDF处理和存储操作
```
**严重程度**: 中
**建议**: 添加文件类型验证,拆分大函数
---
#### 2.5 推理服务 (`inference/web/services/inference.py`)
**文件位置**: `packages/inference/inference/web/services/inference.py`
**代码行数**: 361行
**优点:**
- 支持动态模型加载
- 懒加载初始化
- 模型热重载支持
**问题:**
```python
# 问题1: 混合业务逻辑和技术实现
def process_image(self, image_path: Path, ...) -> ServiceResult:
# 1. 技术细节: 图像解码
# 2. 业务逻辑: 字段提取
# 3. 技术细节: 模型推理
# 4. 业务逻辑: 结果验证
# 问题2: 可视化方法重复加载模型
model = YOLO(str(self.model_config.model_path)) # 第316行
# 应该在初始化时加载避免重复IO
# 问题3: 临时文件未使用上下文管理器
temp_path = results_dir / f"{doc_id}_temp.png"
# 建议使用 tempfile 上下文管理器
```
**严重程度**: 中
**建议**: 引入领域层和适配器模式,分离业务和技术逻辑
---
#### 2.6 异步队列 (`inference/web/workers/async_queue.py`)
**文件位置**: `packages/inference/inference/web/workers/async_queue.py`
**代码行数**: 213行
**优点:**
- 线程安全实现
- 优雅关闭支持
- 任务状态跟踪
**严重问题:**
```python
# 问题: 内存队列,服务重启丢失任务 (第42行)
self._queue: Queue[AsyncTask] = Queue(maxsize=max_size)
# 问题: 无法水平扩展
# 问题: 任务持久化困难
```
**严重程度**: 高
**建议**: 使用Redis/RabbitMQ持久化队列
---
### 3. Training Package
#### 3.1 整体评估
**文件数量**: 26个Python文件
**优点:**
- CLI工具设计良好
- 双池协调器CPU + GPU设计优秀
- 数据增强策略丰富
**总体评分**: 8/10
---
### 4. Frontend
#### 4.1 API客户端 (`frontend/src/api/client.ts`)
**文件位置**: `frontend/src/api/client.ts`
**代码行数**: 42行
**优点:**
- Axios配置清晰
- 请求/响应拦截器
- 认证token自动添加
**问题:**
```typescript
// 问题1: Token存储在localStorage存在XSS风险
const token = localStorage.getItem('admin_token')
// 问题2: 401错误处理不完整
if (error.response?.status === 401) {
console.warn('Authentication required...')
// 应该触发重新登录或token刷新
}
```
**严重程度**: 中
**建议**: 考虑使用http-only cookie存储token完善错误处理
---
#### 4.2 Dashboard组件 (`frontend/src/components/Dashboard.tsx`)
**文件位置**: `frontend/src/components/Dashboard.tsx`
**代码行数**: 301行
**优点:**
- React hooks使用规范
- 类型定义清晰
- UI响应式设计
**问题:**
```typescript
// 问题1: 硬编码的进度值
const getAutoLabelProgress = (doc: DocumentItem): number | undefined => {
if (doc.auto_label_status === 'running') {
return 45 // 硬编码!
}
// ...
}
// 问题2: 搜索功能未实现
// 没有onChange处理
// 问题3: 缺少错误边界处理
// 组件应该包裹在Error Boundary中
```
**严重程度**: 低
**建议**: 实现真实的进度获取,添加搜索功能
---
#### 4.3 整体评估
**优点:**
- TypeScript类型安全
- React Query状态管理
- TailwindCSS样式一致
**问题:**
- 缺少错误边界
- 部分功能硬编码
- 缺少单元测试
**总体评分**: 7.5/10
---
### 5. Tests
#### 5.1 测试统计
- **测试文件数**: 97个
- **测试总数**: 1,601个
- **测试覆盖率**: 28%
#### 5.2 覆盖率分析
| 模块 | 估计覆盖率 | 状态 |
|------|-----------|------|
| `shared/` | 35% | 偏低 |
| `inference/web/` | 25% | 偏低 |
| `inference/pipeline/` | 20% | 严重不足 |
| `training/` | 30% | 偏低 |
| `frontend/` | 15% | 严重不足 |
#### 5.3 测试质量问题
**优点:**
- 使用了pytest框架
- 有conftest.py配置
- 部分集成测试
**问题:**
- 覆盖率远低于行业标准80%
- 缺少端到端测试
- 部分测试可能过于简单
**严重程度**: 高
**建议**: 制定测试计划,优先覆盖核心业务逻辑
---
## 代码质量问题
### 高优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| AdminDB类过大 | `inference/data/admin_db.py` | 维护困难 | 拆分为Repository模式 |
| 内存队列单点故障 | `inference/web/workers/async_queue.py` | 任务丢失 | 使用Redis持久化 |
| 测试覆盖率过低 | 全项目 | 代码风险 | 提升至60%+ |
### 中优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 时序攻击风险 | `inference/web/core/auth.py` | 安全漏洞 | 使用hmac.compare_digest |
| 限流器内存存储 | `inference/web/core/rate_limiter.py` | 分布式问题 | 使用Redis |
| 配置分散 | `shared/config.py` | 难以管理 | 使用Pydantic Settings |
| 文件上传验证不足 | `inference/web/api/v1/admin/documents.py` | 安全风险 | 添加魔术字节验证 |
| 推理服务混合职责 | `inference/web/services/inference.py` | 难以测试 | 分离业务和技术逻辑 |
### 低优先级问题
| 问题 | 位置 | 影响 | 建议 |
|------|------|------|------|
| 前端搜索未实现 | `frontend/src/components/Dashboard.tsx` | 功能缺失 | 实现搜索功能 |
| 硬编码进度值 | `frontend/src/components/Dashboard.tsx` | 用户体验 | 获取真实进度 |
| Token存储方式 | `frontend/src/api/client.ts` | XSS风险 | 考虑http-only cookie |
---
## 安全风险分析
### 已识别的安全风险
#### 1. 时序攻击 (中风险)
**位置**: `inference/web/core/auth.py:46`
```python
# 当前实现(有风险)
if not admin_db.is_valid_admin_token(x_admin_token):
raise HTTPException(status_code=401, ...)
# 安全实现
import hmac
if not hmac.compare_digest(token, expected_token):
raise HTTPException(status_code=401, ...)
```
#### 2. 文件上传验证不足 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:127-131`
```python
# 建议添加魔术字节验证
ALLOWED_EXTENSIONS = {".pdf"}
MAX_FILE_SIZE = 10 * 1024 * 1024
if not content.startswith(b"%PDF"):
raise HTTPException(400, "Invalid PDF file format")
```
#### 3. 路径遍历风险 (中风险)
**位置**: `inference/web/api/v1/admin/documents.py:494-498`
```python
# 建议实现
from pathlib import Path
def get_safe_path(filename: str, base_dir: Path) -> Path:
safe_name = Path(filename).name
full_path = (base_dir / safe_name).resolve()
if not full_path.is_relative_to(base_dir):
raise HTTPException(400, "Invalid file path")
return full_path
```
#### 4. CORS配置 (低风险)
**位置**: FastAPI中间件配置
```python
# 建议生产环境配置
ALLOWED_ORIGINS = [
"http://localhost:5173",
"https://your-domain.com",
]
```
#### 5. XSS风险 (低风险)
**位置**: `frontend/src/api/client.ts:13`
```typescript
// 当前实现
const token = localStorage.getItem('admin_token')
// 建议考虑
// 使用http-only cookie存储敏感token
```
### 安全评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 认证 | 8/10 | 基础良好,需加强时序攻击防护 |
| 输入验证 | 7/10 | 基本验证到位,需加强文件验证 |
| 数据保护 | 8/10 | 无敏感信息硬编码 |
| 传输安全 | 8/10 | 使用HTTPS生产环境 |
| 总体 | 7.5/10 | 基础安全良好,需加强细节 |
---
## 性能问题
### 已识别的性能问题
#### 1. 重复模型加载
**位置**: `inference/web/services/inference.py:316`
```python
# 问题: 每次可视化都重新加载模型
model = YOLO(str(self.model_config.model_path))
# 建议: 复用已加载的模型
```
#### 2. 临时文件处理
**位置**: `shared/storage/base.py:178-203`
```python
# 问题: bytes操作使用临时文件
def upload_bytes(self, data: bytes, ...):
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(data)
temp_path = Path(f.name)
# ...
# 建议: 子类重写为直接上传
```
#### 3. 数据库查询优化
**位置**: `inference/data/admin_db.py`
```python
# 问题: N+1查询风险
for doc in documents:
annotations = db.get_annotations_for_document(str(doc.document_id))
# ...
# 建议: 使用join预加载
```
### 性能评分
| 类别 | 评分 | 说明 |
|------|------|------|
| 响应时间 | 8/10 | 异步处理良好 |
| 资源使用 | 7/10 | 有优化空间 |
| 可扩展性 | 7/10 | 内存队列限制 |
| 并发处理 | 8/10 | 线程池设计良好 |
| 总体 | 7.5/10 | 良好,有优化空间 |
---
## 改进建议
### 立即执行 (本周)
1. **拆分AdminDB**
- 创建 `repositories/` 目录
- 按领域拆分TokenRepository, DocumentRepository, TrainingRepository
- 估计工时: 2天
2. **修复安全漏洞**
- 添加 `hmac.compare_digest()` 时序攻击防护
- 添加文件魔术字节验证
- 估计工时: 0.5天
3. **提升测试覆盖率**
- 优先测试 `inference/pipeline/`
- 添加API集成测试
- 目标: 从28%提升至50%
- 估计工时: 3天
### 短期执行 (本月)
4. **引入消息队列**
- 添加Redis服务
- 使用Celery替换内存队列
- 估计工时: 3天
5. **统一配置管理**
- 使用 Pydantic Settings
- 集中验证逻辑
- 估计工时: 1天
6. **添加缓存层**
- Redis缓存热点数据
- 缓存文档、模型配置
- 估计工时: 2天
### 长期执行 (本季度)
7. **数据库读写分离**
- 配置主从数据库
- 读操作使用从库
- 估计工时: 3天
8. **事件驱动架构**
- 引入事件总线
- 解耦模块依赖
- 估计工时: 5天
9. **前端优化**
- 添加错误边界
- 实现真实搜索功能
- 添加E2E测试
- 估计工时: 3天
---
## 总结与评分
### 各维度评分
| 维度 | 评分 | 权重 | 加权得分 |
|------|------|------|----------|
| **代码质量** | 7.5/10 | 20% | 1.5 |
| **安全性** | 7.5/10 | 20% | 1.5 |
| **可维护性** | 8/10 | 15% | 1.2 |
| **测试覆盖** | 5/10 | 15% | 0.75 |
| **性能** | 7.5/10 | 15% | 1.125 |
| **文档** | 8/10 | 10% | 0.8 |
| **架构设计** | 8/10 | 5% | 0.4 |
| **总体** | **7.3/10** | 100% | **7.275** |
### 关键结论
1. **架构设计优秀**: Monorepo + 三包分离架构清晰,便于维护和扩展
2. **代码质量良好**: 类型注解完善,文档详尽,结构清晰
3. **安全基础良好**: 没有严重的安全漏洞,基础防护到位
4. **测试是短板**: 28%覆盖率是最大风险点
5. **生产就绪**: 经过小幅改进后可以投入生产使用
### 下一步行动
| 优先级 | 任务 | 预计工时 | 影响 |
|--------|------|----------|------|
| 高 | 拆分AdminDB | 2天 | 提升可维护性 |
| 高 | 引入Redis队列 | 3天 | 解决任务丢失问题 |
| 高 | 提升测试覆盖率 | 5天 | 降低代码风险 |
| 中 | 修复安全漏洞 | 0.5天 | 提升安全性 |
| 中 | 统一配置管理 | 1天 | 减少配置错误 |
| 低 | 前端优化 | 3天 | 提升用户体验 |
---
## 附录
### 关键文件清单
| 文件 | 职责 | 问题 |
|------|------|------|
| `inference/data/admin_db.py` | 数据库操作 | 类过大,需拆分 |
| `inference/web/services/inference.py` | 推理服务 | 混合业务和技术 |
| `inference/web/workers/async_queue.py` | 异步队列 | 内存存储,易丢失 |
| `inference/web/core/scheduler.py` | 任务调度 | 缺少统一协调 |
| `shared/shared/config.py` | 共享配置 | 分散管理 |
### 参考资源
- [Repository Pattern](https://martinfowler.com/eaaCatalog/repository.html)
- [Celery Documentation](https://docs.celeryproject.org/)
- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/)
- [FastAPI Best Practices](https://fastapi.tiangolo.com/tutorial/bigger-applications/)
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
---
**报告生成时间**: 2026-02-01
**审查工具**: Claude Code + AST-grep + LSP

View File

@@ -26,7 +26,7 @@
### 项目现状
Invoice Master是一个基于YOLOv11 + PaddleOCR的瑞典发票字段自动提取系统具备以下核心能力
Invoice Master是一个基于YOLO26 + PaddleOCR的瑞典发票字段自动提取系统具备以下核心能力
| 指标 | 数值 | 评估 |
|------|------|------|

View File

@@ -0,0 +1,314 @@
# Inference Analysis Report
Date: 2026-02-11
Sample: 39 PDFs (diverse sizes from 1783 total), invoice-sm120 environment
## Executive Summary
| Metric | Value |
|--------|-------|
| Total PDFs tested | 39 |
| Successful responses | 35 (89.7%) |
| Timeouts (>120s) | 4 (10.3%) |
| Pure fallback (all fields conf=0.500) | 15/35 (42.9%) |
| Full extraction (all expected fields) | 6/35 (17.1%) |
| supplier_org_number extraction rate | 0% |
| InvoiceDate extraction rate | 31.4% |
| OCR extraction rate | 31.4% |
**Root Cause**: A critical DPI mismatch bug causes 43% of documents to lose all YOLO-detected field data, falling back to inaccurate regex patterns.
---
## Problem #1 (CRITICAL): DPI Mismatch - Field Extraction Failures
### Symptom
- 15/35 documents (43%) have ALL extracted fields at confidence=0.500 (fallback)
- YOLO detects fields correctly (6+ detections at conf 0.8-0.97) but text extraction returns nothing
- Examples: `4f822b0d` has 6 YOLO detections but only 1 field extracted via fallback
### Root Cause
**DPI not passed from pipeline to FieldExtractor** causing 2x coordinate scaling error.
```
pipeline.py:237 -> self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
^^^ DPI NOT PASSED! Defaults to 300
```
The chain:
1. `shared/config.py:22` defines `DEFAULT_DPI = 150`
2. `InferencePipeline.__init__()` receives `dpi=150` from `ModelConfig`
3. PDF rendered at **150 DPI** -> YOLO detections in 150 DPI pixel coordinates
4. `FieldExtractor` defaults to `dpi=300` (never receives the actual 150)
5. Coordinate conversion: `scale = 72 / self.dpi` = `72/300 = 0.24` instead of `72/150 = 0.48`
6. Bounding boxes are **halved** in PDF point space -> no tokens match -> empty extraction
7. Fallback regex triggers with conf=0.500
### Fix
**File**: `packages/backend/backend/pipeline/pipeline.py`, line 237
```python
# BEFORE (broken):
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
# AFTER (fixed):
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
```
### Impact
This single-line fix will recover ~43% of documents from degraded fallback to proper YOLO+OCR extraction.
---
## Problem #2 (HIGH): Fallback Amount Extraction Grabs Wrong Values
### Symptom
- 3 documents extracted Amount=1.00 when actual amounts are 7500.00, etc.
- Fallback regex matches table column header "Summa" followed by row quantity "1,00" instead of total
### Example
Document `2b7e4103` (Astra Football Club):
- Actual amount: **7 500,00 SEK**
- Extracted: **1.00** (from "Summa 1" where "1" is the article number in the next row)
### Root Cause
The fallback Amount regex in `pipeline.py:676`:
```python
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?'
```
matches "Summa" (column header) followed by "1" (first data in next row), because PaddleOCR produces tokens in position order. The greedy `[\d\s,\.]` captures "1" and stops at "Medlemsavgift".
### Fix
**File**: `packages/backend/backend/pipeline/pipeline.py`, lines 674-688
1. Require minimum amount value in fallback (e.g., > 10.00)
2. Require the matched amount to have a decimal separator (`,` or `.`) to avoid matching integers
3. Prefer "ATT BETALA" over "Summa" as the keyword (less ambiguous)
```python
'Amount': [
r'(?:att\s+betala)\s*[:.]?\s*([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)?',
r'([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s]+[,\.]\d{2})\s*(?:SEK|kr)?',
],
```
---
## Problem #3 (HIGH): Fallback Bankgiro Regex False Positives
### Symptom
- Document `2b7e4103` extracts Bankgiro=2546-1610 but the actual document has NO Bankgiro
- The document has Plusgiro=2131575-9 and Org.nr=802546-1610
### Root Cause
Fallback Bankgiro regex in `pipeline.py:681`:
```python
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)'
```
matches the LAST 8 digits of org number "802546-1610" as "2546-1610".
### Fix
**File**: `packages/backend/backend/pipeline/pipeline.py`, line 681
Add negative lookbehind to avoid matching within longer numbers:
```python
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)', # Must not be preceded/followed by digits
],
```
---
## Problem #4 (MEDIUM): OCR Number Minimum 5-Digit Requirement
### Symptom
- Document `2b7e4103` has OCR=3046 (4 digits) which is valid but rejected by normalizer
- `OcrNumberNormalizer` requires minimum 5 digits
### Root Cause
**File**: `packages/backend/backend/pipeline/normalizers/ocr_number.py`, line 32:
```python
if len(digits) < 5:
return NormalizationResult.failure(f"Too few digits for OCR: {len(digits)}")
```
Swedish OCR numbers can be 2-25 digits. The 5-digit minimum is too restrictive.
### Fix
Lower minimum to 2 digits (or possibly 1 for very short OCR references):
```python
if len(digits) < 2:
return NormalizationResult.failure(f"Too few digits for OCR: {len(digits)}")
```
---
## Problem #5 (MEDIUM): InvoiceNumber Extracts Year (2025, 2026)
### Symptom
- 2 documents extract year as invoice number: "2025", "2026"
- `dc35ee8e`: actual invoice number visible in PDF but normalizer picks up year
- `56cabf73`: InvoiceNumber=2026
### Root Cause
**File**: `packages/backend/backend/pipeline/normalizers/invoice_number.py`, lines 54-72
The "Pattern 3: Short digit sequence" strategy prefers shorter sequences. When the YOLO bbox contains both the year "2025" and the actual invoice number, the shorter "2025" (4 digits) wins over a longer sequence.
### Fix
Add year exclusion to Pattern 3:
```python
for seq in digit_sequences:
if len(seq) == 8 and seq.startswith("20"):
continue # Skip YYYYMMDD dates
if len(seq) == 4 and seq.startswith("20"):
continue # Skip year-only values (2024, 2025, 2026)
if len(seq) > 10:
continue
valid_sequences.append(seq)
```
---
## Problem #6 (MEDIUM): InvoiceNumber vs OCR Mismatch
### Symptom
- 5 documents show InvoiceNumber different from OCR number
- Example: `87f470da` InvoiceNumber=852460234111905 vs OCR=524602341119055
- Example: `8b0674be` InvoiceNumber=508021404131 vs OCR=50802140413
### Root Cause
These are legitimate: InvoiceNumber and OCR are detected from DIFFERENT YOLO bounding boxes (different regions of the invoice). The InvoiceNumber normalizer picks a shorter sequence from the invoice_number bbox, while the OCR normalizer extracts from the ocr_number bbox. Cross-validation from payment_line should reconcile these but cross-validation isn't running (0 documents show cross_validation results).
### Diagnosis Needed
Check why cross-validation / payment_line parsing isn't populating `result.cross_validation` even when payment_line is extracted.
---
## Problem #7 (MEDIUM): supplier_org_number 0% Extraction Rate
### Symptom
- 0/35 documents extract supplier_org_number
- YOLO detects supplier_org_number in many documents (visible in detection classes)
- When extracted, the field appears as `supplier_organisation_number` (different name)
### Root Cause
This is actually a reporting issue. The API returns the field as `supplier_organisation_number` (full spelling) from `CLASS_TO_FIELD` mapping, but the analysis expected `supplier_org_number`. Looking at the actual data, 8/35 documents DO have `supplier_organisation_number` extracted.
However, the underlying issue is that even when YOLO detects `supplier_org_number`, the DPI bug prevents text extraction for text PDFs.
### Fix
Already addressed by Problem #1 (DPI fix). Additionally, ensure consistent field naming in API documentation.
---
## Problem #8 (LOW): Timeout Failures (4/39 documents)
### Symptom
- 4 PDFs timed out at 120 seconds
- File sizes: 89KB, 169KB, 239KB, 970KB (not correlated with size)
### Root Cause
Likely multi-page PDFs or PDFs with complex layouts requiring extensive OCR. The 120s timeout in the test script may be too short for multi-page documents + full-page OCR fallback.
### Fix
1. Increase API timeout for multi-page PDFs
2. Add page limit or early termination for very large documents
3. Log page count in response to correlate with processing time
---
## Problem #9 (LOW): Non-Invoice Documents in Dataset
### Symptom
- `dccf6655`: 0 detections, 0 fields - this is a screenshot of UI buttons, NOT an invoice
### Fix
Add document classification as a pre-processing step to reject non-invoice documents before running the expensive YOLO + OCR pipeline.
---
## Problem #10 (LOW): InvoiceDueDate Before InvoiceDate
### Symptom
- Document `11de4d07`: InvoiceDate=2026-01-16, InvoiceDueDate=2025-12-01
- Due date is BEFORE invoice date, which is illogical
### Root Cause
Either the date normalizer swapped the values, or the YOLO model detected the wrong region for one of the dates. The DPI bug (Problem #1) may also affect date extraction from the correct regions.
### Fix
Add post-extraction validation: if InvoiceDueDate < InvoiceDate, either swap them or flag for review.
---
## Priority Fix Order
| Priority | Fix | Impact | Effort |
|----------|-----|--------|--------|
| 1 | DPI mismatch (Problem #1) | 43% of docs recovered | 1 line change |
| 2 | Fallback amount regex (Problem #2) | 3+ docs with wrong amounts | Small regex fix |
| 3 | Fallback bankgiro regex (Problem #3) | False positive bankgiro | Small regex fix |
| 4 | OCR min digits (Problem #4) | Short OCR numbers supported | 1 line change |
| 5 | Year as InvoiceNumber (Problem #5) | 2+ docs | Small logic add |
| 6 | Date validation (Problem #10) | Logical consistency | Small validation add |
| 7 | Cross-validation (Problem #6) | Better field reconciliation | Investigation needed |
| 8 | Timeouts (Problem #8) | 4 docs | Config change |
| 9 | Document classification (Problem #9) | Filter non-invoices | Feature addition |
---
## Re-run Expected After Fix #1
After fixing the DPI mismatch alone, re-running the same 39 PDFs should show:
- Pure fallback rate dropping from 43% to near 0%
- InvoiceDate extraction rate improving from 31% to ~70%+
- OCR extraction rate improving from 31% to ~60%+
- Average confidence scores increasing significantly
- supplier_organisation_number extraction improving from 23% to ~60%+
---
## Detailed Per-PDF Results Summary
| PDF | Size | Time | Fields | Confidence | Issues |
|-----|------|------|--------|------------|--------|
| dccf6655 | 10KB | 17s | 0/0 | - | Not an invoice |
| 4f822b0d | 183KB | 37s | 1/6 | ALL 0.500 | DPI bug: 6 detections, 5 lost |
| d4af7848 | 55KB | 41s | 1/6 | ALL 0.500 | DPI bug: 6 detections, 5 lost |
| 19533483 | 262KB | 39s | 1/9 | ALL 0.500 | DPI bug: 9 detections, 8 lost |
| 2b7e4103 | 25KB | 47s | 3/6 | ALL 0.500 | DPI bug + Amount=1.00 wrong |
| 7717d293 | 34KB | 16s | 3/6 | ALL 0.500 | DPI bug + Amount=1.00 wrong |
| 3226ac59 | 66KB | 42s | 3/5 | ALL 0.500 | DPI bug + Amount=1.00 wrong |
| 0553e5c2 | 31KB | 18s | 3/6 | ALL 0.500 | DPI bug + BG=5000-0000 suspicious |
| 32e90db8 | 136KB | 40s | 3/7 | Mixed | Amount=2026.00 (year?) |
| dc35ee8e | 567KB | 83s | 7/9 | YOLO | InvoiceNumber=2025 (year) |
| 56cabf73 | 67KB | 19s | 5/6 | YOLO | InvoiceNumber=2026 (year) |
| 87f470da | 784KB | 42s | 9/14 | YOLO | InvNum vs OCR mismatch |
| 11de4d07 | 356KB | 68s | 5/3 | Mixed | DueDate < InvoiceDate |
| 0f9047a9 | 415KB | 22s | 8/6 | YOLO | Good extraction |
| 9d0b793c | 286KB | 18s | 8/8 | YOLO | Good extraction |
| 5604d375 | 915KB | 51s | 9/10 | YOLO | Good extraction |
| 87f470da | 784KB | 42s | 9/14 | YOLO | Good extraction |
| f40fd418 | 523KB | 90s | 9/9 | YOLO | Good extraction |
---
## Field Extraction Rate Summary
| Field | Present | Missing | Rate | Avg Conf |
|-------|---------|---------|------|----------|
| Bankgiro | 32 | 3 | 91.4% | 0.681 |
| InvoiceNumber | 28 | 7 | 80.0% | 0.695 |
| Amount | 27 | 8 | 77.1% | 0.726 |
| InvoiceDueDate | 13 | 22 | 37.1% | 0.883 |
| InvoiceDate | 11 | 24 | 31.4% | 0.879 |
| OCR | 11 | 24 | 31.4% | 0.900 |
| customer_number | 11 | 24 | 31.4% | 0.926 |
| payment_line | 9 | 26 | 25.7% | 0.938 |
| Plusgiro | 3 | 32 | 8.6% | 0.948 |
| supplier_org_number | 0 | 35 | 0.0% | 0.000 |
Note: Fields with high confidence but low extraction rate (InvoiceDate 0.879, OCR 0.900, payment_line 0.938) indicate the DPI bug: when extraction works (via YOLO), confidence is high. The low rate is because most documents fall back and these fields have no fallback regex pattern.

View File

@@ -0,0 +1,257 @@
# Semi-Automatic Labeling Strategy Analysis
## 1. Current Pipeline Overview
```
CSV (field values)
|
v
Autolabel CLI
|- PDF render (300 DPI)
|- Text extraction (PDF text layer or PaddleOCR)
|- FieldMatcher.find_matches() [5 strategies]
| |- ExactMatcher (priority 1)
| |- ConcatenatedMatcher (multi-token)
| |- FuzzyMatcher (Amount, dates only)
| |- SubstringMatcher (prevents false positives)
| |- FlexibleDateMatcher (fallback)
|
|- AnnotationGenerator
| |- PDF points -> pixels
| |- expand_bbox() [field-specific strategy]
| |- pixels -> YOLO normalized (0-1)
| |- Save to database
|
v
DBYOLODataset (training)
|- Load images + bboxes from DB
|- Re-apply expand_bbox()
|- YOLO training
|
v
Inference
|- YOLO detect -> pixel bboxes
|- Crop region -> OCR extract text
|- Normalize & validate
```
---
## 2. Current Expansion Strategy Analysis
### 2.1 Field-Specific Parameters
| Field | Scale X | Scale Y | Extra Top | Extra Left | Extra Right | Max Pad X | Max Pad Y |
|---|---|---|---|---|---|---|---|
| ocr_number | 1.15 | 1.80 | 0.60 | - | - | 50 | 140 |
| bankgiro | 1.45 | 1.35 | - | 0.80 | - | 160 | 90 |
| plusgiro | 1.45 | 1.35 | - | 0.80 | - | 160 | 90 |
| invoice_date | 1.25 | 1.55 | 0.40 | - | - | 80 | 110 |
| invoice_due_date | 1.30 | 1.65 | 0.45 | 0.35 | - | 100 | 120 |
| amount | 1.20 | 1.35 | - | - | 0.30 | 70 | 80 |
| invoice_number | 1.20 | 1.50 | 0.40 | - | - | 80 | 100 |
| supplier_org_number | 1.25 | 1.40 | 0.30 | 0.20 | - | 90 | 90 |
| customer_number | 1.25 | 1.45 | 0.35 | 0.25 | - | 90 | 100 |
| payment_line | 1.10 | 1.20 | - | - | - | 40 | 40 |
### 2.2 Design Rationale
The expansion is designed based on Swedish invoice layout patterns:
- **Dates**: Labels ("Fakturadatum") typically sit **above** the value -> extra top
- **Giro accounts**: Prefix ("BG:", "PG:") sits **to the left** -> extra left
- **Amount**: Currency suffix ("SEK", "kr") to the **right** -> extra right
- **Payment line**: Machine-readable, self-contained -> minimal expansion
### 2.3 Strengths
1. **Field-specific directional expansion** - matches Swedish invoice conventions
2. **Max padding caps** - prevents runaway expansion into neighboring fields
3. **Center-point scaling** with directional compensation - geometrically sound
4. **Image boundary clamping** - prevents out-of-bounds coordinates
### 2.4 Potential Issues
| Issue | Risk Level | Description |
|---|---|---|
| Over-expansion | HIGH | OCR number 1.80x Y-scale could capture adjacent fields |
| Inconsistent training vs inference bbox | MEDIUM | Model trained on expanded boxes, inference returns raw detection |
| No expansion at inference OCR crop | MEDIUM | Detected bbox may clip text edges without post-expansion |
| Max padding in pixels vs DPI-dependent | LOW | 140px at 300DPI != 140px at 150DPI |
---
## 3. Industry Best Practices (Research Findings)
### 3.1 Labeling: Tight vs. Loose Bounding Boxes
**Consensus**: Annotate **tight bounding boxes around the value text only**.
- FUNSD/CORD benchmarks annotate keys and values as **separate entities**
- Loose boxes "introduce background noise and can mislead the model" (V7 Labs, LabelVisor)
- IoU discrepancies from loose boxes degrade mAP during training
**However**, for YOLO + OCR pipelines, tight-only creates a problem:
- YOLO predicts slightly imprecise boxes (typical IoU 0.7-0.9)
- If the predicted box clips even slightly, OCR misses characters
- Solution: **Label tight, expand at inference** OR **label with controlled padding**
### 3.2 The Two Dominant Strategies
**Strategy A: Tight Label + Inference-Time Expansion** (Recommended by research)
```
Label: [ 2024-01-15 ] (tight around value)
Inference: [ [2024-01-15] ] + pad -> OCR
```
- Clean, consistent annotations
- Requires post-detection padding before OCR crop
- Used by: Microsoft Document Intelligence, Nanonets
**Strategy B: Expanded Label at Training Time** (Current project approach)
```
Label: [ Fakturadatum: 2024-01-15 ] (includes context)
Inference: YOLO detects full region -> OCR extracts from region
```
- Model learns spatial context (label + value)
- Larger, more variable boxes
- OCR must filter out label text from extracted content
### 3.3 OCR Padding Requirements
**Tesseract**: Requires ~10px white border for reliable segmentation (PSM 7-10).
**PaddleOCR**: `det_db_unclip_ratio` parameter (default 1.5) controls detection expansion.
Key insight: Even after YOLO detection, OCR engines need some padding around text to work reliably.
### 3.4 State-of-the-Art Comparison
| System | Bbox Strategy | Field Definition |
|---|---|---|
| **LayoutLM** | Word-level bboxes from OCR | Token classification (BIO tagging) |
| **Donut** | No bboxes (end-to-end) | Internal attention mechanism |
| **Microsoft DocAI** | Field-level, tight | Post-expansion for OCR |
| **YOLO + OCR (this project)** | Field-level, expanded | Field-specific directional expansion |
---
## 4. Recommendations
### 4.1 Short-Term (Current Architecture)
#### A. Add Inference-Time OCR Padding
Currently, the detected bbox is sent directly to OCR. Add a small uniform padding (5-10%) before cropping for OCR:
```python
# In field_extractor.py, before OCR crop:
pad_ratio = 0.05 # 5% expansion
w_pad = (x2 - x1) * pad_ratio
h_pad = (y2 - y1) * pad_ratio
crop_x1 = max(0, x1 - w_pad)
crop_y1 = max(0, y1 - h_pad)
crop_x2 = min(img_w, x2 + w_pad)
crop_y2 = min(img_h, y2 + h_pad)
```
#### B. Reduce Training-Time Expansion Ratios
Current ratios (especially OCR number 1.80x Y, Bankgiro 1.45x X) are aggressive. Proposed reduction:
| Field | Current Scale Y | Proposed Scale Y | Rationale |
|---|---|---|---|
| ocr_number | 1.80 | 1.40 | 1.80 is too aggressive, captures neighbors |
| bankgiro | 1.35 | 1.25 | Reduce vertical over-expansion |
| invoice_due_date | 1.65 | 1.45 | Tighten vertical |
Principle: **shift expansion work from training-time to inference-time**.
#### C. Add Label Visualization Quality Check
Before training, sample 50-100 annotated images and visually inspect:
- Are expanded bboxes capturing only the target field?
- Are any bboxes overlapping with adjacent fields?
- Are any values being clipped?
### 4.2 Medium-Term (Architecture Improvements)
#### D. Two-Stage Detection Strategy
```
Stage 1: YOLO detects field regions (current)
Stage 2: Within each detection, use PaddleOCR text detection
to find the precise text boundary
Stage 3: Extract text from refined boundary
```
Benefits:
- YOLO handles field classification (what)
- PaddleOCR handles text localization (where exactly)
- Eliminates the "tight vs loose" dilemma entirely
#### E. Label Both Key and Value Separately
Add new annotation classes: `invoice_date_label`, `invoice_date_value`
- Model learns to find both the label and value
- Use spatial relationship (label -> value) for more robust extraction
- Aligns with FUNSD benchmark approach
#### F. Confidence-Weighted Expansion
Scale expansion by detection confidence:
```python
# Higher confidence = tighter crop (model is sure)
# Lower confidence = wider crop (give OCR more context)
expansion = base_expansion * (1.5 - confidence)
```
### 4.3 Long-Term (Next Generation)
#### G. Move to LayoutLM-Style Token Classification
- Replace YOLO field detection with token-level classification
- Each OCR word gets classified as B-field/I-field/O
- Eliminates bbox expansion entirely
- Better for fields with complex layouts
#### H. End-to-End with Donut/Pix2Struct
- No separate OCR step
- Model directly outputs structured fields from image
- Zero bbox concerns
- Requires more training data and compute
---
## 5. Recommended Action Plan
### Phase 1: Validate Current Labels (1-2 days)
- [ ] Build label visualization script
- [ ] Sample 100 documents across all field types
- [ ] Identify over-expansion and clipping cases
- [ ] Document per-field accuracy of current expansion
### Phase 2: Tune Expansion Parameters (2-3 days)
- [ ] Reduce aggressive expansion ratios (OCR number, bankgiro)
- [ ] Add inference-time OCR padding (5-10%)
- [ ] Re-train model with adjusted labels
- [ ] Compare mAP and field extraction accuracy
### Phase 3: Two-Stage Refinement (1 week)
- [ ] Implement PaddleOCR text detection within YOLO detection
- [ ] Use text detection bbox for precise OCR crop
- [ ] Keep YOLO expansion for classification only
### Phase 4: Evaluation (ongoing)
- [ ] Track per-field extraction accuracy on test set
- [ ] A/B test tight vs expanded labels
- [ ] Build regression test suite for labeling quality
---
## 6. Summary
| Aspect | Current Approach | Best Practice | Gap |
|---|---|---|---|
| **Labeling** | Value + expansion at label time | Tight value + inference expansion | Medium |
| **Expansion** | Field-specific directional | Field-specific directional | Aligned |
| **Inference OCR crop** | Raw detection bbox | Detection + padding | Needs padding |
| **Expansion ratios** | Aggressive (up to 1.80x) | Moderate (1.10-1.30x) | Over-expanded |
| **Visualization QC** | None | Regular sampling | Missing |
| **Coordinate consistency** | PDF points -> pixels | Consistent DPI | Check needed |
**Bottom line**: The architecture (field-specific directional expansion) is sound and aligns with best practices. The main improvements are:
1. **Reduce expansion aggressiveness** during training labels
2. **Add OCR padding** at inference time
3. **Add label quality visualization** for validation
4. Longer term: consider **two-stage detection** or **token classification**

153
PLAN_TWO_STAGE_DETECTION.md Normal file
View File

@@ -0,0 +1,153 @@
# Plan: Two-Stage Detection (YOLO + PaddleOCR Value Selection)
## Context
Current inference flow: YOLO detects field region -> crop region -> PaddleOCR reads ALL text -> concatenate -> normalizer extracts value via regex.
Problem: When training labels include label+value (e.g., "Fakturadatum 2024-01-15"), the detected region contains both label text and value text. Currently all OCR tokens are concatenated, and normalizers must regex out the value. This works for most fields but is fragile.
Solution: After PaddleOCR returns individual text lines from the detected region, add a **value selection** step that picks the most likely value token(s) BEFORE sending to normalizer. This gives normalizers cleaner input and provides a more precise value bbox.
## Key Insight
PaddleOCR already returns individual `OCRToken` objects (text + bbox + confidence). The current code just concatenates them all (line 227 of `field_extractor.py`):
```python
raw_text = ' '.join(t.text for t in ocr_tokens)
```
The change: replace blind concatenation with field-aware token selection.
## Architecture
```
Current:
YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> concat -> normalizer
New:
YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> value_selector -> normalizer
| |
individual selected
text lines value token(s)
```
## Scope: Inference Only
| Pipeline Stage | Changed? | Reason |
|---|---|---|
| **Labeling** (autolabel, expansion) | NO | Expanded bbox gives YOLO strong visual patterns for generalization |
| **Training** (YOLO26s) | NO | Model learns field regions correctly with current labels |
| **Inference - Detection** (YOLO) | NO | Detection output is correct |
| **Inference - Extraction** (OCR -> text) | **YES** | Add ValueSelector between OCR tokens and normalizer |
This design works with ANY model -- tight-label model (15K docs) or expanded-label model (58 images). ValueSelector is model-agnostic.
## Files to Modify
| File | Change |
|------|--------|
| `packages/backend/backend/pipeline/value_selector.py` | **NEW** - Value selection logic per field type |
| `packages/backend/backend/pipeline/field_extractor.py` | Use ValueSelector in both extraction paths |
| `tests/pipeline/test_value_selector.py` | **NEW** - Tests for value selection |
## Implementation
### 1. New File: `value_selector.py`
Core class `ValueSelector` with method:
```python
def select_value_tokens(
self,
tokens: list[OCRToken],
field_name: str
) -> list[OCRToken]:
```
Field-specific selection rules:
| Field | Strategy | Pattern |
|-------|----------|---------|
| InvoiceDate, InvoiceDueDate | Date pattern match | `\d{4}[-./]\d{2}[-./]\d{2}` or `\d{2}[-./]\d{2}[-./]\d{4}` or `\d{8}` |
| Amount | Number pattern match | `\d[\d\s]*[,.]\d{2}` (prefer tokens with comma/decimal) |
| Bankgiro | Giro pattern match | `\d{3,4}-\d{4}` or 7-8 consecutive digits |
| Plusgiro | Giro pattern match | `\d+-\d` or 2-8 digits |
| OCR | Longest digit sequence | Token with most consecutive digits (min 5) |
| InvoiceNumber | Exclude known labels | Remove tokens matching Swedish label keywords, keep rest |
| supplier_org_number | Org number pattern | `\d{6}-?\d{4}` |
| customer_number | Exclude labels | Remove label keywords, keep alphanumeric tokens |
| payment_line | Full concatenation | Keep all tokens (payment line needs full text) |
Label keyword exclusion list (Swedish):
```python
LABEL_KEYWORDS = {
"fakturanummer", "fakturadatum", "forfallodag", "forfalldatum",
"bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa",
"total", "att", "betala", "kundnummer", "organisationsnummer",
"org", "nr", "datum", "nummer", "ref", "referens",
"momsreg", "vat", "moms", "sek", "kr",
}
```
Selection algorithm:
1. Try field-specific pattern matching on individual tokens
2. If match found -> return matched token(s) only
3. If no match -> **fallback to ALL tokens** (current behavior, so we never lose data)
### 2. Modify: `field_extractor.py`
**OCR path** (`extract_from_detection`, line 224-228):
```python
# Before (current):
ocr_tokens = self.ocr_engine.extract_from_image(region)
raw_text = ' '.join(t.text for t in ocr_tokens)
# After:
ocr_tokens = self.ocr_engine.extract_from_image(region)
value_tokens = self._value_selector.select_value_tokens(ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
```
**PDF text path** (`extract_from_detection_with_pdf`, line 172-174):
```python
# Before (current):
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# After:
matching_tokens.sort(key=lambda x: -x[1])
all_text_tokens = [OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0) for t in matching_tokens]
value_tokens = self._value_selector.select_value_tokens(all_text_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
```
**Constructor**: Add `ValueSelector` instance.
### 3. Tests: `test_value_selector.py`
Test cases per field type:
- Date: "Fakturadatum 2024-01-15" -> selects "2024-01-15"
- Amount: "Belopp 1 234,56 kr" -> selects "1 234,56"
- Bankgiro: "BG: 123-4567" -> selects "123-4567"
- OCR: "OCR 1234567890" -> selects "1234567890"
- InvoiceNumber: "Fakturanr INV-2024-001" -> selects "INV-2024-001"
- Fallback: Unknown pattern -> returns all tokens (no data loss)
## Key Design Decisions
1. **Fallback to full text**: If value selection can't identify the value, return ALL tokens. This means the change can never make things worse than current behavior.
2. **ValueSelector is stateless**: Pure function, no side effects. Easy to test.
3. **No training changes**: Training labels stay as-is (expanded bboxes). Only inference pipeline changes.
4. **No normalizer changes**: Normalizers still work the same. They just get cleaner input.
## Verification
1. Run existing tests: `pytest tests/pipeline/ -v`
2. Run new tests: `pytest tests/pipeline/test_value_selector.py -v`
3. Manual validation: Run inference on a few invoices and compare raw_text before/after
4. Regression check: Ensure no field extraction accuracy drops on existing test documents

View File

@@ -8,11 +8,11 @@
## 项目概述
**Invoice Master POC v2** - 基于 YOLOv11 + PaddleOCR 的瑞典发票字段自动提取系统
**Invoice Master POC v2** - 基于 YOLO26 + PaddleOCR 的瑞典发票字段自动提取系统
### 核心功能
- **自动标注**: 利用 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
- **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
- **模型训练**: 使用 YOLO26 训练字段检测模型,支持数据增强
- **推理提取**: 检测字段区域 → OCR 提取文本 → 字段规范化
- **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
@@ -175,7 +175,7 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
| 组件 | 技术选择 | 评估 |
|------|----------|------|
| **目标检测** | YOLOv11 (Ultralytics) | ✅ 业界标准 |
| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) | ✅ 业界标准 |
| **OCR 引擎** | PaddleOCR v5 | ✅ 支持瑞典语 |
| **PDF 处理** | PyMuPDF (fitz) | ✅ 功能强大 |
| **数据库** | PostgreSQL + SQLModel | ✅ 类型安全 |

139
README.md
View File

@@ -1,14 +1,14 @@
# Invoice Master POC v2
自动发票字段提取系统 - 使用 YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
自动发票字段提取系统 - 使用 YOLO26 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
## 项目概述
本项目实现了一个完整的发票字段自动提取流程:
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注
2. **模型训练**: 使用 YOLOv11 训练字段检测模型,支持数据增强
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> 字段规范化
1. **自动标注**: 利用已有 CSV 结构化数据 + OCR 自动生成 YOLO 训练标注(统一 15px 填充)
2. **模型训练**: 使用 YOLO26 训练字段检测模型,支持数据增强
3. **推理提取**: 检测字段区域 -> OCR 提取文本 -> ValueSelector 过滤标签 -> 字段规范化
4. **Web 管理**: React 前端 + FastAPI 后端,支持文档管理、数据集构建、模型训练和版本管理
### 架构
@@ -37,8 +37,8 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|------|------|
| **已标注文档** | 9,738 (9,709 成功) |
| **总体字段匹配率** | 94.8% (82,604/87,121) |
| **测试** | 2,058 passed |
| **测试覆盖率** | 60% |
| **测试** | 2,047 passed |
| **测试覆盖率** | 72% |
| **模型 mAP@0.5** | 93.5% |
**各字段匹配率:**
@@ -64,10 +64,10 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
| 环境 | 要求 |
|------|------|
| **WSL** | WSL 2 + Ubuntu 22.04 |
| **WSL** | WSL 2 + Ubuntu 22.04 (或 24.04 for RTX 50 系列) |
| **Conda** | Miniconda 或 Anaconda |
| **Python** | 3.11+ (通过 Conda 管理) |
| **GPU** | NVIDIA GPU + CUDA 12.x (强烈推荐) |
| **Python** | 3.11+ (通过 Conda 管理), 3.10 for RTX 50 系列 |
| **GPU** | NVIDIA GPU + CUDA 12.x (RTX 50 系列见 SM120 章节) |
| **数据库** | PostgreSQL (存储标注结果) |
## 安装
@@ -77,8 +77,8 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
wsl -d Ubuntu-22.04
# 2. 创建 Conda 环境
conda create -n invoice-py311 python=3.11 -y
conda activate invoice-py311
conda create -n invoice-sm120 python=3.11 -y
conda activate invoice-sm120
# 3. 进入项目目录
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
@@ -89,6 +89,85 @@ pip install -e packages/training
pip install -e packages/backend
```
## RTX 5080 (Blackwell SM 120) GPU 设置
RTX 50 系列 (Blackwell 架构) 使用 SM 120 计算能力,官方 PaddlePaddle 仅支持到 SM 90。需要使用社区编译的 SM120 wheel。
### 系统要求
| 要求 | 版本 |
|------|------|
| **WSL** | Ubuntu 24.04 (glibc 2.39+) |
| **Python** | 3.10 (wheel 限制) |
| **CUDA** | 13.0+ (通过 pip nvidia 包) |
### 升级 WSL 到 Ubuntu 24.04
```bash
# 检查当前版本
lsb_release -a
# 如果是 22.04,需要升级
sudo sed -i 's/Prompt=lts/Prompt=normal/g' /etc/update-manager/release-upgrades
sudo apt update && sudo apt upgrade -y
sudo do-release-upgrade
```
### 创建 SM120 环境
```bash
# 1. 创建 Python 3.10 环境
conda create -n invoice-sm120 python=3.10 -y
conda activate invoice-sm120
# 2. 安装 SM120 PaddlePaddle wheel
pip install https://github.com/horhe-dvlp/paddlepaddle-sm120-wheels/releases/download/v3.0.0/paddlepaddle_gpu-3.0.0-cp310-cp310-linux_x86_64.whl
# 3. 安装项目依赖
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
pip install -e packages/shared
pip install -e packages/training
pip install -e packages/backend
```
### 配置环境变量
`~/.bashrc` 中添加:
```bash
# PaddlePaddle SM120 (RTX 50 series) environment
export PADDLE_SM120_LIBS=/home/kai/.local/lib/python3.10/site-packages/nvidia
alias activate-sm120='export LD_LIBRARY_PATH=$PADDLE_SM120_LIBS/cublas/lib:$PADDLE_SM120_LIBS/cudnn/lib:$PADDLE_SM120_LIBS/cuda_runtime/lib:/usr/lib/wsl/lib:$LD_LIBRARY_PATH && export PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK=True && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120'
```
### 使用
```bash
# 激活 SM120 环境
source ~/.bashrc
activate-sm120
# 验证 GPU
python -c "import paddle; paddle.utils.run_check()"
# 运行服务
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
python run_server.py --port 8000
```
### 故障排除
| 错误 | 解决方案 |
|------|---------|
| `GLIBCXX_3.4.32 not found` | 升级到 Ubuntu 24.04 |
| `GLIBC_2.38 not found` | 升级到 Ubuntu 24.04 |
| `cublasLtCreate` 失败 | 检查 LD_LIBRARY_PATH 包含 nvidia 库路径 |
| `Mismatched GPU Architecture` | 使用 SM120 wheel不要用官方 paddle |
### 云部署
Azure/AWS GPU 实例 (A100, H100, T4, V100) 使用官方 PaddlePaddle无需 SM120 wheel。
## 项目结构
```
@@ -125,7 +204,7 @@ invoice-master-poc-v2/
│ ├── run_server.py # Web 服务器入口
│ └── backend/
│ ├── cli/ # infer, serve
│ ├── pipeline/ # YOLO 检测, 字段提取, 解析器
│ ├── pipeline/ # YOLO 检测, 字段提取, ValueSelector, 解析器
│ ├── web/ # FastAPI 应用
│ │ ├── api/v1/ # REST API (admin, public, batch)
│ │ ├── schemas/ # Pydantic 数据模型
@@ -199,7 +278,7 @@ python -m training.cli.autolabel --workers 4
```bash
# 从预训练模型开始训练
python -m training.cli.train \
--model yolo11n.pt \
--model yolo26s.pt \
--epochs 100 \
--batch 16 \
--name invoice_fields \
@@ -207,7 +286,7 @@ python -m training.cli.train \
# 低内存模式
python -m training.cli.train \
--model yolo11n.pt \
--model yolo26s.pt \
--epochs 100 \
--name invoice_fields \
--low-memory
@@ -235,7 +314,7 @@ python -m backend.cli.infer \
```bash
# 从 Windows PowerShell 启动
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
# 启动前端
cd frontend && npm install && npm run dev
@@ -364,6 +443,30 @@ result = parser.parse("Said, Shakar Umj 436-R Billo")
print(f"Customer Number: {result}") # "UMJ 436-R"
```
## 推理流水线 (Two-Stage Detection)
```
YOLO bbox -> crop -> PaddleOCR -> [all tokens] -> ValueSelector -> normalizer
| |
individual selected
text lines value token(s)
```
**BBox 扩展**: 所有字段统一使用 15px 填充150 DPI 下约 2.5mm),不做方向性扩展,不依赖布局假设。
**ValueSelector**: 在 OCR 和 normalizer 之间按字段类型过滤标签文本,只保留值 token
| 字段 | 选择策略 | 示例输入 -> 输出 |
|------|---------|-----------------|
| InvoiceDate / DueDate | 日期模式匹配 | "Fakturadatum 2024-01-15" -> "2024-01-15" |
| Amount | 金额模式匹配 | "Belopp 1 234,56 kr" -> "1 234,56" |
| Bankgiro / Plusgiro | Giro 号码模式 | "BG: 123-4567" -> "123-4567" |
| OCR | 最长数字序列 (>=5位) | "OCR 94228110015950070" -> "94228110015950070" |
| InvoiceNumber | 排除瑞典语标签 | "Fakturanr INV-2024-001" -> "INV-2024-001" |
| payment_line | 保留全部 | 不过滤 |
如果没有匹配到任何模式,回退返回全部 token永远不会比之前更差
## DPI 配置
系统所有组件统一使用 **150 DPI**。DPI 必须在训练和推理时保持一致。
@@ -427,9 +530,9 @@ DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
| 指标 | 数值 |
|------|------|
| **测试总数** | 2,058 |
| **测试总数** | 2,047 |
| **通过率** | 100% |
| **覆盖率** | 60% |
| **覆盖率** | 72% |
## 存储抽象层
@@ -540,7 +643,7 @@ npm run dev
| 组件 | 技术 |
|------|------|
| **目标检测** | YOLOv11 (Ultralytics) |
| **目标检测** | YOLO26 (Ultralytics >= 8.4.0) |
| **OCR 引擎** | PaddleOCR v5 (PP-OCRv5) |
| **PDF 处理** | PyMuPDF (fitz) |
| **数据库** | PostgreSQL + SQLModel |

View File

@@ -76,7 +76,7 @@ matching:
# YOLO Training
yolo:
model: yolov8s # Model architecture (yolov8n/s/m/l/x)
model: yolo26s # Model architecture (yolo26n/s/m/l/x)
epochs: 100
batch_size: 16
img_size: 1280 # Image size for training

View File

@@ -2,7 +2,7 @@
# Use with: yolo train data=dataset.yaml cfg=training.yaml
# Model
model: yolov8s.pt
model: yolo26s.pt
# Training hyperparameters
epochs: 100
@@ -57,3 +57,12 @@ name: invoice_fields
exist_ok: true
pretrained: true
verbose: true
# Fine-tuning profile (overrides when task_type == finetune)
finetune:
epochs: 10
lr0: 0.001
freeze: 10
warmup_epochs: 1
cos_lr: true
patience: 5

View File

@@ -32,7 +32,7 @@
### 1.1 项目背景
Invoice Master是一个基于YOLOv11 + PaddleOCR的发票字段自动提取系统当前准确率达到94.8%。本方案设计将Invoice Master作为Fortnox会计软件的插件/扩展,实现无缝的发票数据导入功能。
Invoice Master是一个基于YOLO26 + PaddleOCR的发票字段自动提取系统当前准确率达到94.8%。本方案设计将Invoice Master作为Fortnox会计软件的插件/扩展,实现无缝的发票数据导入功能。
### 1.2 目标

View File

@@ -500,7 +500,7 @@ estimator = PyTorch(
hyperparameters={
"epochs": 100,
"batch-size": 16,
"model": "yolo11n.pt"
"model": "yolo26s.pt"
}
)
```

View File

@@ -152,7 +152,7 @@ rclone mount azure:training-images Z: --vfs-cache-mode full
### 推荐: Container Apps (CPU)
对于 YOLO 推理,**CPU 足够**,不需要 GPU
- YOLOv11n 在 CPU 上推理时间 ~200-500ms
- YOLO26s 在 CPU 上推理时间 ~200-500ms
- 比 GPU 便宜很多,适合中低流量
```yaml
@@ -335,7 +335,7 @@ az containerapp create \
│ ~$30/月 │ │ ~$1-5/次训练 │ │ │
│ │ │ │ │ │
│ ┌───────────────────┐ │ │ ┌───────────────────┐ │ │ ┌───────────────────┐ │
│ │ FastAPI + YOLO │ │ │ │ YOLOv11 Training │ │ │ │ React/Vue 前端 │ │
│ │ FastAPI + YOLO │ │ │ │ YOLO26 Training │ │ │ │ React/Vue 前端 │ │
│ │ /api/v1/infer │ │ │ │ 100 epochs │ │ │ │ 上传发票界面 │ │
│ └───────────────────┘ │ │ └───────────────────┘ │ │ └───────────────────┘ │
└───────────┬───────────┘ └───────────┬───────────┘ └───────────┬───────────┘

View File

@@ -0,0 +1,185 @@
# YOLO Model Fine-Tuning Best Practices
Production guide for continuous fine-tuning of YOLO object detection models with user feedback.
## Overview
When users report failed detections, those documents are collected, reviewed, and used to incrementally improve the model without degrading performance on existing data.
Key risks:
- **Catastrophic forgetting**: model forgets original training after fine-tuning on small new data
- **Cumulative drift**: repeated fine-tuning sessions cause progressive degradation
- **Overfitting**: few samples + many epochs = memorizing noise
## 1. Data Management
```
Original training set (25K) --> permanently retained as "anchor dataset"
|
User-reported failures --> human review & labeling --> "fine-tune pool"
|
Fine-tune pool accumulates over time, never deleted
```
Every new sample MUST be human-verified before entering the fine-tune pool. Incorrect labels are more harmful than no labels.
### Data Mixing Ratios
| Accumulated New Samples | Old Data Multiplier | Total Training Size |
|------------------------|--------------------|--------------------|
| 10 | 50x (500) | 510 |
| 50 | 20x (1,000) | 1,050 |
| 200 | 10x (2,000) | 2,200 |
| 500+ | 5x (2,500) | 3,000 |
Principle: fewer new samples require higher old data ratio. Stabilize at 5x once pool reaches 500+.
Old samples are randomly sampled from the original 25K each time, ensuring broad coverage.
## 2. Model Version Management
```
base_v1.pt (original 25K training)
+-- ft_v1.1.pt (base + fine-tune batch 1)
+-- ft_v1.2.pt (base + fine-tune batch 1+2)
+-- ...
When fine-tune pool reaches 2000+ samples:
base_v2.pt (original 25K + all accumulated samples, trained from scratch)
+-- ft_v2.1.pt
+-- ...
```
CRITICAL: Never chain fine-tunes (ft_v1.1 -> ft_v1.2 -> ft_v1.3). Always start from the base model to avoid cumulative drift.
## 3. Fine-Tuning Parameters
```yaml
base_model: best.pt # always start from base model
epochs: 10 # few epochs are sufficient
lr0: 0.001 # 1/10 of base training lr
freeze: 10 # freeze first 10 backbone layers
warmup_epochs: 1
cos_lr: true
# data mixing
new_samples: all # entire fine-tune pool
old_samples: min(5x_new, 3000) # old data sampling, cap at 3000
```
### Why These Settings
| Parameter | Rationale |
|-----------|-----------|
| `epochs: 10` | More than enough for small datasets; prevents overfitting |
| `lr0: 0.001` | Low learning rate preserves base model knowledge |
| `freeze: 10` | Backbone features are general; only fine-tune detection head and later layers |
| `cos_lr: true` | Smooth decay prevents sharp weight updates |
## 4. Deployment Gating (Most Important)
Every fine-tuned model MUST pass three gates before deployment:
### Gate 1: Regression Validation
Run evaluation on the original test set (held out from the 25K training data).
| mAP50 Change | Action |
|-------------|--------|
| Drop < 1% | PASS - deploy |
| Drop 1-3% | REVIEW - human inspection required |
| Drop > 3% | REJECT - do not deploy |
### Gate 2: New Sample Validation
Run inference on the new failure documents.
| Detection Rate | Action |
|---------------|--------|
| > 80% correct | PASS |
| < 80% correct | REVIEW - check label quality or increase training |
### Gate 3: A/B Comparison (Optional)
Sample 100 production documents, run both old and new models:
- New model must not be worse on any field type
- Compare per-class mAP to detect targeted regressions
## 5. Fine-Tuning Frequency
| Strategy | Trigger | Recommendation |
|----------|---------|---------------|
| **By volume (recommended)** | Pool reaches 50+ new samples | Best signal-to-noise ratio |
| By schedule | Weekly or monthly | Predictable but may trigger with insufficient data |
| By performance | Monitored accuracy drops below threshold | Reactive, requires monitoring infrastructure |
Do NOT fine-tune daily with fewer than 50 samples. The noise outweighs the signal.
## 6. Complete Workflow
```
User marks failed document
|
v
Human reviews and labels annotations
|
v
Add to fine-tune pool
|
v
Pool >= 50 samples? --NO--> Wait for more samples
|
YES
|
v
Prepare mixed dataset:
- All samples from fine-tune pool
- Random sample 5x from original 25K
|
v
Fine-tune from base.pt:
- 10 epochs
- lr0 = 0.001
- freeze first 10 layers
|
v
Gate 1: Original test set mAP drop < 1%?
|
PASS
|
v
Gate 2: New sample detection rate > 80%?
|
PASS
|
v
Deploy new model, retain old model for rollback
|
v
Pool accumulated 2000+ samples?
|
YES --> Merge all data, train new base from scratch
```
## 7. Monitoring in Production
Track these metrics continuously:
| Metric | Purpose | Alert Threshold |
|--------|---------|----------------|
| Detection rate per field | Catch field-specific regressions | < 90% for any field |
| Average confidence score | Detect model uncertainty drift | Drop > 5% from baseline |
| User-reported failures / week | Measure improvement trend | Increasing over 3 weeks |
| Inference latency | Ensure model size hasn't bloated | > 2x baseline |
## 8. Summary of Rules
| Rule | Practice |
|------|----------|
| Never chain fine-tunes | Always start from base.pt |
| Never use only new data | Must mix with old data |
| Never fine-tune on < 50 samples | Accumulate before triggering |
| Never auto-deploy | Must pass gating validation |
| Never discard old models | Retain versions for rollback |
| Periodically retrain base | Merge all data at 2000+ new samples |
| Always human-review labels | Bad labels are worse than no labels |

View File

@@ -39,27 +39,50 @@ PDF/Image
**Goal**: 在独立分支验证 PP-StructureV3 能否正确检测瑞典发票表格
**Tasks**:
1. 创建 `feature/business-invoice` 分支
2. 升级依赖:
- `paddlepaddle>=3.0.0`
- `paddleocr>=3.0.0`
3. 创建 PP-StructureV3 wrapper:
- `src/table/structure_detector.py`
4. 用 5-10 张真实发票测试表格检测效果
5. 验证与现有 YOLO pipeline 的兼容性
**Status**: COMPLETED
**Critical Files**:
- [requirements.txt](../../requirements.txt)
- [pyproject.toml](../../pyproject.toml)
- New: `src/table/structure_detector.py`
**Completed**:
- [x] Created `TableDetector` wrapper class with TDD approach
- [x] 29 unit tests passing, 84% coverage
- [x] Supports wired and wireless table detection
- [x] Lazy initialization pattern for PP-StructureV3
- [x] PaddleX 3.x API support (LayoutParsingResultV2)
- [x] Used existing `invoice-sm120` conda environment (PaddlePaddle 3.3, PaddleOCR 3.3.1)
- [x] Tested with real Swedish invoices - 10 tables detected across 5 PDFs
- [x] HTML table structure extraction working (pred_html)
- [x] Cell-level OCR text extraction working (table_ocr_pred)
**Verification**:
**Files Created**:
- `packages/backend/backend/table/__init__.py`
- `packages/backend/backend/table/structure_detector.py`
- `tests/table/__init__.py`
- `tests/table/test_structure_detector.py`
- `scripts/ppstructure_poc.py` (POC test script)
**POC Results**:
```
Total PDFs tested: 5
Total tables detected: 10
12d321cb-4a3a-47c6-90aa-890cecd13d91.pdf: 4 tables (14, 20, 10, 12 cells)
3c8d2673-42f7-4474-82ff-4480d6aee632.pdf: 1 table (25 cells)
52bb76c4-5a43-4c5a-81e0-d9a04002fcb1.pdf: 0 tables (letter, not invoice)
7d18a79e-7b1e-4daf-8560-f10ab04f265d.pdf: 4 tables (14, 20, 10, 12 cells)
87b95d60-d980-4037-b1b5-ba2b5d14ecc8.pdf: 1 table (25 cells)
```
**Verification Commands**:
```bash
# WSL 环境测试
# Run tests
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && \
conda activate invoice-py311 && \
python -c 'from paddleocr import PPStructureV3; print(\"OK\")'"
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && \
pytest tests/table/ -v"
# Run POC with real invoices
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && \
conda activate invoice-sm120 && \
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && \
python scripts/ppstructure_poc.py"
```
---
@@ -68,6 +91,22 @@ wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && \
**Goal**: 从检测到的表格区域提取结构化行项目数据
**Status**: COMPLETED
**Completed**:
- [x] Created `LineItemsExtractor` class with TDD approach
- [x] 19 unit tests passing, 93% coverage
- [x] Supports reversed tables (header at bottom - PP-StructureV3 quirk)
- [x] Swedish column name mapping (Beskrivning, Antal, Belopp, etc.)
- [x] HTMLTableParser for table structure parsing
- [x] Automatic header detection from row content
- [x] Tested with real Swedish invoices
**Files Created**:
- `packages/backend/backend/table/line_items_extractor.py`
- `tests/table/test_line_items_extractor.py`
- `scripts/ppstructure_line_items_poc.py` (POC test script)
**Data Structures**:
```python
@dataclass
@@ -122,6 +161,23 @@ class LineItemsResult:
**Goal**: 从 OCR 全文提取多税率 VAT 信息
**Status**: COMPLETED
**Completed**:
- [x] Created `VATExtractor` class with TDD approach
- [x] 21 unit tests passing, 96% coverage
- [x] `AmountParser` for Swedish/European number formats
- [x] Multiple VAT rate extraction (25%, 12%, 6%, 0%)
- [x] Multiple regex patterns for different Swedish formats
- [x] Confidence score calculation based on extracted data
- [x] Mathematical consistency verification
**Files Created**:
- `packages/backend/backend/vat/__init__.py`
- `packages/backend/backend/vat/vat_extractor.py`
- `tests/vat/__init__.py`
- `tests/vat/test_vat_extractor.py`
**Data Structures**:
```python
@dataclass
@@ -177,6 +233,23 @@ class VATSummary:
**Goal**: 多源交叉验证,确保 99%+ 精度
**Status**: COMPLETED
**Completed**:
- [x] Created `VATValidator` class with TDD approach
- [x] 15 unit tests passing, 90% coverage
- [x] Mathematical verification (base × rate = vat)
- [x] Total amount check (excl + vat = incl)
- [x] Line items comparison
- [x] Amount consistency check with existing YOLO extraction
- [x] Configurable tolerance
- [x] Confidence score calculation
**Files Created**:
- `packages/backend/backend/validation/vat_validator.py`
- `tests/validation/__init__.py`
- `tests/validation/test_vat_validator.py`
**Data Structures**:
```python
@dataclass

View File

@@ -546,7 +546,7 @@ Request:
"description": "First training run with 500 documents",
"document_ids": ["uuid1", "uuid2", "uuid3"],
"config": {
"model_name": "yolo11n.pt",
"model_name": "yolo26s.pt",
"epochs": 100,
"batch_size": 16,
"image_size": 640
@@ -1036,7 +1036,7 @@ Response:
| | Name: [Training Run 2024-01____________] | |
| | Description: [First training with 500 documents_________] | |
| | | |
| | Base Model: [yolo11n.pt v] Epochs: [100] Batch: [16] | |
| | Base Model: [yolo26s.pt v] Epochs: [100] Batch: [16] | |
| | Image Size: [640] Device: [GPU 0 v] | |
| | | |
| | [ ] Schedule for later: [2024-01-20] [22:00] | |
@@ -1088,7 +1088,7 @@ Response:
| | - Recall: 92% | |
| | | |
| | Configuration: | |
| | - Base: yolo11n.pt Epochs: 100 Batch: 16 Size: 640 | |
| | - Base: yolo26s.pt Epochs: 100 Batch: 16 Size: 640 | |
| | | |
| | Documents Used: [View 600 documents] | |
| +--------------------------------------------------------------+ |

View File

@@ -27,7 +27,7 @@ flowchart TD
I --> I1{--resume?}
I1 -- Yes --> I2[Load last.pt checkpoint]
I1 -- No --> I3[Load pretrained model\ne.g. yolo11n.pt]
I1 -- No --> I3[Load pretrained model\ne.g. yolo26s.pt]
I2 --> J[Configure Training]
I3 --> J

View File

@@ -5,4 +5,5 @@ export { inferenceApi } from './inference'
export { datasetsApi } from './datasets'
export { augmentationApi } from './augmentation'
export { modelsApi } from './models'
export { poolApi } from './pool'
export { dashboardApi } from './dashboard'

View File

@@ -1,15 +1,30 @@
import apiClient from '../client'
import type { InferenceResponse } from '../types'
export interface ProcessDocumentOptions {
extractLineItems?: boolean
}
// Longer timeout for inference - line items extraction can take 60+ seconds
const INFERENCE_TIMEOUT_MS = 120000
export const inferenceApi = {
processDocument: async (file: File): Promise<InferenceResponse> => {
processDocument: async (
file: File,
options: ProcessDocumentOptions = {}
): Promise<InferenceResponse> => {
const formData = new FormData()
formData.append('file', file)
if (options.extractLineItems) {
formData.append('extract_line_items', 'true')
}
const { data } = await apiClient.post('/api/v1/infer', formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
timeout: INFERENCE_TIMEOUT_MS,
})
return data
},

View File

@@ -0,0 +1,40 @@
import apiClient from '../client'
import type {
PoolListResponse,
PoolStatsResponse,
PoolEntryResponse,
} from '../types'
export const poolApi = {
addToPool: async (documentId: string, reason?: string): Promise<PoolEntryResponse> => {
const { data } = await apiClient.post('/api/v1/admin/training/pool', {
document_id: documentId,
reason: reason ?? 'manual_addition',
})
return data
},
listEntries: async (params?: {
verified_only?: boolean
limit?: number
offset?: number
}): Promise<PoolListResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/pool', { params })
return data
},
getStats: async (): Promise<PoolStatsResponse> => {
const { data } = await apiClient.get('/api/v1/admin/training/pool/stats')
return data
},
verifyEntry: async (entryId: string): Promise<PoolEntryResponse> => {
const { data } = await apiClient.post(`/api/v1/admin/training/pool/${entryId}/verify`)
return data
},
removeEntry: async (entryId: string): Promise<{ message: string }> => {
const { data } = await apiClient.delete(`/api/v1/admin/training/pool/${entryId}`)
return data
},
}

View File

@@ -111,6 +111,9 @@ export interface ModelVersionItem {
is_active: boolean
metrics_mAP: number | null
document_count: number
model_type?: string
base_model_version_id?: string | null
gating_status?: string
trained_at: string | null
activated_at: string | null
created_at: string
@@ -182,6 +185,62 @@ export interface CrossValidationResult {
details: string[]
}
// Business Features Types (Line Items, VAT)
export interface LineItem {
row_index: number
description: string | null
quantity: string | null
unit: string | null
unit_price: string | null
amount: string | null
article_number: string | null
vat_rate: string | null
is_deduction: boolean
confidence: number
}
export interface LineItemsResult {
items: LineItem[]
header_row: string[]
total_amount: string | null
}
export interface VATBreakdown {
rate: number
base_amount: string | null
vat_amount: string
source: string
}
export interface VATSummary {
breakdowns: VATBreakdown[]
total_excl_vat: string | null
total_vat: string | null
total_incl_vat: string | null
confidence: number
}
export interface MathCheckResult {
rate: number
base_amount: number | null
expected_vat: number | null
actual_vat: number | null
is_valid: boolean
tolerance: number
}
export interface VATValidationResult {
is_valid: boolean
confidence_score: number
math_checks: MathCheckResult[]
total_check: boolean
line_items_vs_summary: boolean | null
amount_consistency: boolean | null
needs_review: boolean
review_reasons: string[]
}
export interface InferenceResult {
document_id: string
document_type: string
@@ -193,6 +252,10 @@ export interface InferenceResult {
visualization_url: string | null
errors: string[]
fallback_used: boolean
// Business features (optional, only when extract_line_items=true)
line_items: LineItemsResult | null
vat_summary: VATSummary | null
vat_validation: VATValidationResult | null
}
export interface InferenceResponse {
@@ -310,19 +373,6 @@ export interface TrainingTaskResponse {
// Model Version types
export interface ModelVersionItem {
version_id: string
version: string
name: string
status: string
is_active: boolean
metrics_mAP: number | null
document_count: number
trained_at: string | null
activated_at: string | null
created_at: string
}
export interface ModelVersionDetailResponse {
version_id: string
version: string
@@ -337,6 +387,10 @@ export interface ModelVersionDetailResponse {
metrics_precision: number | null
metrics_recall: number | null
document_count: number
model_type?: string
base_model_version_id?: string | null
base_training_dataset_id?: string | null
gating_status?: string
training_config: Record<string, unknown> | null
file_size: number | null
trained_at: string | null
@@ -345,6 +399,39 @@ export interface ModelVersionDetailResponse {
updated_at: string
}
// Fine-Tune Pool types
export interface PoolEntryItem {
entry_id: string
document_id: string
added_by: string | null
reason: string | null
is_verified: boolean
verified_at: string | null
verified_by: string | null
created_at: string
}
export interface PoolListResponse {
total: number
limit: number
offset: number
entries: PoolEntryItem[]
}
export interface PoolStatsResponse {
total_entries: number
verified_entries: number
unverified_entries: number
is_ready: boolean
min_required: number
}
export interface PoolEntryResponse {
entry_id: string
message: string
}
export interface ModelVersionListResponse {
total: number
limit: number

View File

@@ -1,7 +1,9 @@
import React, { useState, useRef } from 'react'
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock } from 'lucide-react'
import { UploadCloud, FileText, Loader2, CheckCircle2, AlertCircle, Clock, Table2 } from 'lucide-react'
import { Button } from './Button'
import { inferenceApi } from '../api/endpoints'
import { LineItemsTable } from './LineItemsTable'
import { VATSummaryCard } from './VATSummaryCard'
import type { InferenceResult } from '../api/types'
export const InferenceDemo: React.FC = () => {
@@ -10,6 +12,7 @@ export const InferenceDemo: React.FC = () => {
const [isProcessing, setIsProcessing] = useState(false)
const [result, setResult] = useState<InferenceResult | null>(null)
const [error, setError] = useState<string | null>(null)
const [extractLineItems, setExtractLineItems] = useState(false)
const fileInputRef = useRef<HTMLInputElement>(null)
const handleFileSelect = (file: File | null) => {
@@ -50,9 +53,9 @@ export const InferenceDemo: React.FC = () => {
setError(null)
try {
const response = await inferenceApi.processDocument(selectedFile)
console.log('API Response:', response)
console.log('Visualization URL:', response.result?.visualization_url)
const response = await inferenceApi.processDocument(selectedFile, {
extractLineItems,
})
setResult(response.result)
} catch (err) {
setError(err instanceof Error ? err.message : 'Processing failed')
@@ -65,6 +68,7 @@ export const InferenceDemo: React.FC = () => {
setSelectedFile(null)
setResult(null)
setError(null)
setExtractLineItems(false)
}
const formatFieldName = (field: string): string => {
@@ -183,12 +187,35 @@ export const InferenceDemo: React.FC = () => {
)}
{selectedFile && !isProcessing && (
<div className="mt-6 flex gap-3 justify-end">
<div className="mt-6 space-y-4">
{/* Business Features Checkbox */}
<label className="flex items-center gap-3 p-4 bg-warm-bg/50 rounded-lg border border-warm-divider cursor-pointer hover:bg-warm-hover/50 transition-colors">
<input
type="checkbox"
checked={extractLineItems}
onChange={(e) => setExtractLineItems(e.target.checked)}
className="w-5 h-5 rounded border-warm-border text-warm-text-secondary focus:ring-warm-text-secondary"
/>
<div className="flex items-center gap-2">
<Table2 size={18} className="text-warm-text-secondary" />
<div>
<span className="font-medium text-warm-text-primary">
Extract Line Items & VAT
</span>
<p className="text-xs text-warm-text-muted mt-0.5">
Extract product/service rows, VAT breakdown, and cross-validation
</p>
</div>
</div>
</label>
<div className="flex gap-3 justify-end">
<Button variant="secondary" onClick={handleReset}>
Cancel
</Button>
<Button onClick={handleProcess}>Process Invoice</Button>
</div>
</div>
)}
</div>
</div>
@@ -274,6 +301,21 @@ export const InferenceDemo: React.FC = () => {
</div>
</div>
{/* Line Items */}
{result.line_items && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-5 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
<Table2 size={20} className="text-warm-text-secondary" />
Line Items
<span className="ml-auto text-sm font-normal text-warm-text-muted">
{result.line_items.items.length} item(s)
</span>
</h3>
<LineItemsTable lineItems={result.line_items} />
</div>
)}
{/* Visualization */}
{result.visualization_url && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
@@ -437,6 +479,20 @@ export const InferenceDemo: React.FC = () => {
</div>
)}
{/* VAT Summary */}
{result.vat_summary && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">
<h3 className="text-lg font-bold text-warm-text-primary mb-4 flex items-center gap-2">
<span className="w-1 h-5 bg-warm-text-secondary rounded-full"></span>
VAT Summary
</h3>
<VATSummaryCard
vatSummary={result.vat_summary}
vatValidation={result.vat_validation}
/>
</div>
)}
{/* Errors */}
{result.errors.length > 0 && (
<div className="bg-warm-card rounded-xl border border-warm-border p-6 shadow-sm">

View File

@@ -0,0 +1,128 @@
import React from 'react'
import { CheckCircle2, MinusCircle } from 'lucide-react'
import type { LineItemsResult } from '../api/types'
interface LineItemsTableProps {
lineItems: LineItemsResult
}
export const LineItemsTable: React.FC<LineItemsTableProps> = ({ lineItems }) => {
if (!lineItems.items || lineItems.items.length === 0) {
return (
<div className="text-center py-8 text-warm-text-muted">
No line items found in this document
</div>
)
}
return (
<div className="space-y-4">
<div className="overflow-x-auto">
<table className="w-full text-sm">
<thead>
<tr className="border-b border-warm-divider">
<th className="text-left py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
#
</th>
<th className="text-left py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
Description
</th>
<th className="text-right py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
Qty
</th>
<th className="text-right py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
Unit Price
</th>
<th className="text-right py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
Amount
</th>
<th className="text-right py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
VAT %
</th>
<th className="text-center py-3 px-4 font-semibold text-warm-text-muted text-xs uppercase tracking-wide">
Conf.
</th>
</tr>
</thead>
<tbody>
{lineItems.items.map((item) => (
<tr
key={`row-${item.row_index}`}
className={`border-b border-warm-divider hover:bg-warm-hover/50 transition-colors ${
item.is_deduction ? 'bg-red-50' : ''
}`}
>
<td className="py-3 px-4 text-warm-text-muted font-mono text-xs">
{item.row_index}
</td>
<td className="py-3 px-4 font-medium max-w-xs truncate">
<div className="flex items-center gap-2">
{item.is_deduction && (
<MinusCircle size={14} className="text-red-500 flex-shrink-0" />
)}
<span className={item.is_deduction ? 'text-red-600' : 'text-warm-text-primary'}>
{item.description || '-'}
</span>
</div>
</td>
<td className="py-3 px-4 text-right text-warm-text-primary font-mono">
{item.quantity || '-'}
{item.unit && (
<span className="text-warm-text-muted ml-1">{item.unit}</span>
)}
</td>
<td className="py-3 px-4 text-right text-warm-text-primary font-mono">
{item.unit_price || '-'}
</td>
<td className={`py-3 px-4 text-right font-bold font-mono ${
item.is_deduction ? 'text-red-600' : 'text-warm-text-primary'
}`}>
{item.amount || '-'}
</td>
<td className="py-3 px-4 text-right text-warm-text-secondary font-mono">
{item.vat_rate ? `${item.vat_rate}%` : '-'}
</td>
<td className="py-3 px-4 text-center">
<div className="flex items-center justify-center gap-1">
<CheckCircle2
size={14}
className={
item.confidence >= 0.8
? 'text-green-500'
: item.confidence >= 0.5
? 'text-yellow-500'
: 'text-red-500'
}
/>
<span
className={`text-xs font-medium ${
item.confidence >= 0.8
? 'text-green-600'
: item.confidence >= 0.5
? 'text-yellow-600'
: 'text-red-600'
}`}
>
{(item.confidence * 100).toFixed(0)}%
</span>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
{lineItems.total_amount && (
<div className="flex justify-end pt-4 border-t border-warm-divider">
<div className="text-right">
<span className="text-sm text-warm-text-muted mr-4">Total:</span>
<span className="text-lg font-bold text-warm-text-primary font-mono">
{lineItems.total_amount} SEK
</span>
</div>
</div>
)}
</div>
)
}

View File

@@ -74,6 +74,25 @@ export const Models: React.FC = () => {
</h4>
<p className="text-sm text-warm-text-muted">Trained {formatDate(model.trained_at)}</p>
</div>
<div className="flex gap-1.5 items-center">
{(model.model_type ?? 'base') === 'finetune' && (
<span className="px-2 py-0.5 rounded text-xs font-medium bg-purple-100 text-purple-700">
Fine-tuned
</span>
)}
{model.gating_status && model.gating_status !== 'skipped' && (
<span className={`px-2 py-0.5 rounded text-xs font-medium ${
model.gating_status === 'pass' ? 'bg-green-100 text-green-700'
: model.gating_status === 'review' ? 'bg-yellow-100 text-yellow-700'
: model.gating_status === 'reject' ? 'bg-red-100 text-red-700'
: 'bg-gray-100 text-gray-600'
}`}>
{model.gating_status === 'pass' ? 'PASS'
: model.gating_status === 'review' ? 'REVIEW'
: model.gating_status === 'reject' ? 'REJECT'
: model.gating_status.toUpperCase()}
</span>
)}
<span className={`px-3 py-1 rounded-full text-xs font-medium ${
model.is_active
? 'bg-warm-state-info/10 text-warm-state-info'
@@ -82,6 +101,7 @@ export const Models: React.FC = () => {
{model.is_active ? 'Active' : model.status}
</span>
</div>
</div>
<div className="mt-4 flex gap-8">
<div>

View File

@@ -1,15 +1,15 @@
import React, { useState, useMemo } from 'react'
import { useQuery } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle } from 'lucide-react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import { Database, Plus, Trash2, Eye, Play, Check, Loader2, AlertCircle, Shield, CheckCircle, XCircle } from 'lucide-react'
import { Button } from './Button'
import { AugmentationConfig } from './AugmentationConfig'
import { useDatasets } from '../hooks/useDatasets'
import { useTrainingDocuments } from '../hooks/useTraining'
import { trainingApi } from '../api/endpoints'
import type { DatasetListItem } from '../api/types'
import { trainingApi, poolApi } from '../api/endpoints'
import type { DatasetListItem, PoolEntryItem } from '../api/types'
import type { AugmentationConfig as AugmentationConfigType } from '../api/endpoints/augmentation'
type Tab = 'datasets' | 'create'
type Tab = 'datasets' | 'create' | 'pool'
interface TrainingProps {
onNavigate?: (view: string, id?: string) => void
@@ -72,19 +72,23 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
const [augmentationConfig, setAugmentationConfig] = useState<Partial<AugmentationConfigType>>({})
const [augmentationMultiplier, setAugmentationMultiplier] = useState(2)
const isFineTune = baseModelType === 'existing'
// Fetch available trained models (active or inactive, not archived)
const { data: modelsData } = useQuery({
queryKey: ['training', 'models', 'available'],
queryFn: () => trainingApi.getModels(),
})
// Filter out archived models - only show active/inactive models for base model selection
const availableModels = (modelsData?.models ?? []).filter(m => m.status !== 'archived')
// Only show base models (not fine-tuned) for selection - prevents chaining fine-tunes
const availableModels = (modelsData?.models ?? []).filter(
m => m.status !== 'archived' && (m.model_type ?? 'base') === 'base'
)
const handleSubmit = () => {
onSubmit({
name,
config: {
model_name: baseModelType === 'pretrained' ? 'yolo11n.pt' : undefined,
model_name: baseModelType === 'pretrained' ? 'yolo26s.pt' : undefined,
base_model_version_id: baseModelType === 'existing' ? baseModelVersionId : null,
epochs,
batch_size: batchSize,
@@ -121,14 +125,16 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
if (e.target.value === 'pretrained') {
setBaseModelType('pretrained')
setBaseModelVersionId(null)
setEpochs(100)
} else {
setBaseModelType('existing')
setBaseModelVersionId(e.target.value)
setEpochs(10) // Fine-tune: fewer epochs per best practices
}
}}
className="w-full h-10 px-3 rounded-md border border-warm-divider bg-white text-warm-text-primary focus:outline-none focus:ring-1 focus:ring-warm-state-info"
>
<option value="pretrained">yolo11n.pt (Pretrained)</option>
<option value="pretrained">yolo26s.pt (Pretrained)</option>
{availableModels.map(m => (
<option key={m.version_id} value={m.version_id}>
{m.name} v{m.version} ({m.metrics_mAP ? `${(m.metrics_mAP * 100).toFixed(1)}% mAP` : 'No metrics'})
@@ -138,10 +144,23 @@ const TrainDialog: React.FC<TrainDialogProps> = ({ dataset, onClose, onSubmit, i
<p className="text-xs text-warm-text-muted mt-1">
{baseModelType === 'pretrained'
? 'Start from pretrained YOLO model'
: 'Continue training from an existing model (incremental training)'}
: 'Fine-tune from base model (freeze=10, cos_lr, data mixing)'}
</p>
</div>
{/* Fine-tune info panel */}
{isFineTune && (
<div className="bg-warm-state-info/5 border border-warm-state-info/20 rounded-lg p-3 text-xs text-warm-text-secondary">
<p className="font-medium text-warm-state-info mb-1">Fine-Tune Mode</p>
<ul className="space-y-0.5 text-warm-text-muted">
<li>Epochs: 10 (auto-set), Backbone frozen (10 layers)</li>
<li>Cosine LR scheduler, Pool data mixed with old data</li>
<li>Requires 50+ verified pool entries</li>
<li>Deployment gating runs automatically after training</li>
</ul>
</div>
)}
<div className="flex gap-4">
<div className="flex-1">
<label htmlFor="train-epochs" className="block text-sm font-medium text-warm-text-secondary mb-1">Epochs</label>
@@ -455,6 +474,148 @@ const CreateDataset: React.FC<{ onSwitchTab: (tab: Tab) => void }> = ({ onSwitch
)
}
// --- Fine-Tune Pool ---
const FineTunePool: React.FC = () => {
const queryClient = useQueryClient()
const { data: statsData, isLoading: isLoadingStats } = useQuery({
queryKey: ['pool', 'stats'],
queryFn: () => poolApi.getStats(),
})
const { data: entriesData, isLoading: isLoadingEntries } = useQuery({
queryKey: ['pool', 'entries'],
queryFn: () => poolApi.listEntries({ limit: 50 }),
})
const verifyMutation = useMutation({
mutationFn: (entryId: string) => poolApi.verifyEntry(entryId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['pool'] })
},
})
const removeMutation = useMutation({
mutationFn: (entryId: string) => poolApi.removeEntry(entryId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['pool'] })
},
})
const stats = statsData
const entries = entriesData?.entries ?? []
return (
<div className="space-y-6">
{/* Pool Stats */}
<div className="grid grid-cols-4 gap-4">
{isLoadingStats ? (
<div className="col-span-4 flex items-center justify-center py-8 text-warm-text-muted">
<Loader2 size={20} className="animate-spin mr-2" />Loading stats...
</div>
) : (
<>
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase mb-1">Total Entries</p>
<p className="text-2xl font-bold font-mono text-warm-text-primary">{stats?.total_entries ?? 0}</p>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase mb-1">Verified</p>
<p className="text-2xl font-bold font-mono text-warm-state-success">{stats?.verified_entries ?? 0}</p>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase mb-1">Unverified</p>
<p className="text-2xl font-bold font-mono text-warm-state-warning">{stats?.unverified_entries ?? 0}</p>
</div>
<div className="bg-warm-card border border-warm-border rounded-lg p-4">
<p className="text-xs text-warm-text-muted uppercase mb-1">Ready for Fine-Tune</p>
<div className="flex items-center gap-2">
{stats?.is_ready ? (
<CheckCircle size={20} className="text-warm-state-success" />
) : (
<AlertCircle size={20} className="text-warm-state-warning" />
)}
<p className="text-lg font-medium text-warm-text-primary">
{stats?.is_ready ? 'Yes' : `Need ${(stats?.min_required ?? 50) - (stats?.verified_entries ?? 0)} more`}
</p>
</div>
</div>
</>
)}
</div>
{/* Pool Entries Table */}
{isLoadingEntries ? (
<div className="flex items-center justify-center py-12 text-warm-text-muted">
<Loader2 size={20} className="animate-spin mr-2" />Loading pool entries...
</div>
) : entries.length === 0 ? (
<div className="flex flex-col items-center justify-center py-16 text-warm-text-muted">
<Shield size={48} className="mb-4 opacity-40" />
<p className="text-lg mb-2">Fine-tune pool is empty</p>
<p className="text-sm">Add documents with extraction failures to the pool for future fine-tuning.</p>
</div>
) : (
<div className="bg-warm-card border border-warm-border rounded-lg overflow-hidden shadow-sm">
<table className="w-full text-left">
<thead className="bg-white border-b border-warm-border">
<tr>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Document ID</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Reason</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Status</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Added</th>
<th className="py-3 px-4 text-xs font-semibold text-warm-text-muted uppercase">Actions</th>
</tr>
</thead>
<tbody>
{entries.map((entry: PoolEntryItem) => (
<tr key={entry.entry_id} className="border-b border-warm-border hover:bg-warm-hover transition-colors">
<td className="py-3 px-4 text-sm font-mono text-warm-text-secondary">{entry.document_id.slice(0, 8)}...</td>
<td className="py-3 px-4 text-sm text-warm-text-muted">{entry.reason ?? '-'}</td>
<td className="py-3 px-4">
<span className={`inline-flex items-center px-2.5 py-1 rounded-full text-xs font-medium ${
entry.is_verified
? 'bg-warm-state-success/10 text-warm-state-success'
: 'bg-warm-state-warning/10 text-warm-state-warning'
}`}>
{entry.is_verified ? <Check size={12} className="mr-1" /> : <AlertCircle size={12} className="mr-1" />}
{entry.is_verified ? 'Verified' : 'Unverified'}
</span>
</td>
<td className="py-3 px-4 text-sm text-warm-text-muted">{new Date(entry.created_at).toLocaleDateString()}</td>
<td className="py-3 px-4">
<div className="flex gap-1">
{!entry.is_verified && (
<button
title="Verify"
onClick={() => verifyMutation.mutate(entry.entry_id)}
disabled={verifyMutation.isPending}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-success transition-colors"
>
<CheckCircle size={14} />
</button>
)}
<button
title="Remove"
onClick={() => removeMutation.mutate(entry.entry_id)}
disabled={removeMutation.isPending}
className="p-1.5 rounded hover:bg-warm-selected text-warm-text-muted hover:text-warm-state-error transition-colors"
>
<Trash2 size={14} />
</button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</div>
)
}
// --- Main Training Component ---
export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
@@ -468,7 +629,7 @@ export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
{/* Tabs */}
<div className="flex gap-1 mb-6 border-b border-warm-border">
{([['datasets', 'Datasets'], ['create', 'Create Dataset']] as const).map(([key, label]) => (
{([['datasets', 'Datasets'], ['create', 'Create Dataset'], ['pool', 'Fine-Tune Pool']] as const).map(([key, label]) => (
<button key={key} onClick={() => setActiveTab(key)}
className={`px-4 py-2.5 text-sm font-medium border-b-2 transition-colors ${
activeTab === key
@@ -482,6 +643,7 @@ export const Training: React.FC<TrainingProps> = ({ onNavigate }) => {
{activeTab === 'datasets' && <DatasetList onNavigate={onNavigate} onSwitchTab={setActiveTab} />}
{activeTab === 'create' && <CreateDataset onSwitchTab={setActiveTab} />}
{activeTab === 'pool' && <FineTunePool />}
</div>
)
}

View File

@@ -0,0 +1,188 @@
import React from 'react'
import { CheckCircle2, AlertCircle, AlertTriangle } from 'lucide-react'
import type { VATSummary, VATValidationResult } from '../api/types'
interface VATSummaryCardProps {
vatSummary: VATSummary
vatValidation?: VATValidationResult | null
}
export const VATSummaryCard: React.FC<VATSummaryCardProps> = ({
vatSummary,
vatValidation,
}) => {
const hasBreakdowns = vatSummary.breakdowns && vatSummary.breakdowns.length > 0
return (
<div className="space-y-4">
{/* VAT Breakdowns by Rate */}
{hasBreakdowns && (
<div className="space-y-2">
<h4 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide">
VAT Breakdown
</h4>
<div className="space-y-2">
{vatSummary.breakdowns.map((breakdown, index) => (
<div
key={index}
className="p-3 bg-warm-bg/70 rounded-lg border border-warm-divider"
>
<div className="flex justify-between items-center">
<span className="text-sm font-bold text-warm-text-secondary">
{breakdown.rate}% Moms
</span>
<span className="text-xs text-warm-text-muted px-2 py-0.5 bg-warm-selected rounded">
{breakdown.source}
</span>
</div>
<div className="mt-2 grid grid-cols-2 gap-4 text-sm">
<div>
<span className="text-warm-text-muted">Base: </span>
<span className="font-mono text-warm-text-primary">
{breakdown.base_amount ?? 'N/A'}
</span>
</div>
<div>
<span className="text-warm-text-muted">VAT: </span>
<span className="font-mono font-bold text-warm-text-primary">
{breakdown.vat_amount ?? 'N/A'}
</span>
</div>
</div>
</div>
))}
</div>
</div>
)}
{/* Totals */}
<div className="pt-4 border-t border-warm-divider space-y-2">
{vatSummary.total_excl_vat && (
<div className="flex justify-between text-sm">
<span className="text-warm-text-muted">Excl. VAT:</span>
<span className="font-mono text-warm-text-primary">
{vatSummary.total_excl_vat}
</span>
</div>
)}
{vatSummary.total_vat && (
<div className="flex justify-between text-sm">
<span className="text-warm-text-muted">Total VAT:</span>
<span className="font-mono font-bold text-warm-text-secondary">
{vatSummary.total_vat}
</span>
</div>
)}
{vatSummary.total_incl_vat && (
<div className="flex justify-between text-sm pt-2 border-t border-warm-divider">
<span className="font-semibold text-warm-text-primary">Incl. VAT:</span>
<span className="font-mono font-bold text-warm-text-primary text-base">
{vatSummary.total_incl_vat}
</span>
</div>
)}
</div>
{/* Confidence */}
<div className="flex items-center gap-2 text-xs">
<CheckCircle2 size={14} className="text-warm-text-secondary" />
<span className="text-warm-text-muted">
Confidence: {(vatSummary.confidence * 100).toFixed(1)}%
</span>
</div>
{/* Validation Results */}
{vatValidation && (
<div className="pt-4 border-t border-warm-divider">
<h4 className="text-sm font-semibold text-warm-text-muted uppercase tracking-wide mb-3">
VAT Validation
</h4>
<div
className={`
p-3 rounded-lg mb-3 flex items-center gap-3
${
vatValidation.is_valid
? 'bg-green-50 border border-green-200'
: vatValidation.needs_review
? 'bg-yellow-50 border border-yellow-200'
: 'bg-red-50 border border-red-200'
}
`}
>
{vatValidation.is_valid ? (
<>
<CheckCircle2 size={20} className="text-green-600 flex-shrink-0" />
<span className="font-bold text-green-800 text-sm">
VAT Calculation Valid
</span>
</>
) : vatValidation.needs_review ? (
<>
<AlertTriangle size={20} className="text-yellow-600 flex-shrink-0" />
<span className="font-bold text-yellow-800 text-sm">
Needs Manual Review
</span>
</>
) : (
<>
<AlertCircle size={20} className="text-red-600 flex-shrink-0" />
<span className="font-bold text-red-800 text-sm">
Validation Failed
</span>
</>
)}
</div>
{/* Math Checks */}
{vatValidation.math_checks && vatValidation.math_checks.length > 0 && (
<div className="space-y-2 mb-3">
{vatValidation.math_checks.map((check, index) => (
<div
key={index}
className={`
p-2 rounded text-xs flex items-center justify-between
${
check.is_valid
? 'bg-green-50 border border-green-200'
: 'bg-red-50 border border-red-200'
}
`}
>
<span className={check.is_valid ? 'text-green-700' : 'text-red-700'}>
{check.rate}%: {check.base_amount?.toFixed(2) ?? 'N/A'} x {check.rate}% ={' '}
{check.expected_vat?.toFixed(2) ?? 'N/A'}
</span>
{check.is_valid ? (
<CheckCircle2 size={14} className="text-green-600" />
) : (
<AlertCircle size={14} className="text-red-600" />
)}
</div>
))}
</div>
)}
{/* Review Reasons */}
{vatValidation.review_reasons && vatValidation.review_reasons.length > 0 && (
<div className="space-y-1">
{vatValidation.review_reasons.map((reason, index) => (
<div
key={index}
className="text-xs text-yellow-700 bg-yellow-50 p-2 rounded border border-yellow-200"
>
{reason}
</div>
))}
</div>
)}
{/* Confidence Score */}
<div className="mt-3 text-xs text-warm-text-muted">
Validation confidence: {(vatValidation.confidence_score * 100).toFixed(1)}%
</div>
</div>
)}
</div>
)
}

View File

@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
import argparse
import json
import logging
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main():
@@ -50,8 +54,8 @@ def main():
)
parser.add_argument(
'--lang',
default='en',
help='OCR language (default: en)'
default='sv',
help='OCR language (default: sv)'
)
parser.add_argument(
'--gpu',
@@ -66,10 +70,13 @@ def main():
args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
logger.error("Model not found: %s", model_path)
sys.exit(1)
# Get input files
@@ -79,16 +86,16 @@ def main():
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
logger.error("Input not found: %s", input_path)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
logger.error("No PDF files found")
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
logger.info("Processing %d PDF file(s)", len(pdf_files))
logger.info("Model: %s", model_path)
from backend.pipeline import InferencePipeline
@@ -107,18 +114,18 @@ def main():
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
logger.info("Processing: %s", pdf_path.name)
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
logger.info(" Success: %s", result.success)
logger.info(" Fields: %d", len(result.fields))
if result.fallback_used:
print(f" Fallback used: Yes")
logger.info(" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
logger.info(" Errors: %s", result.errors)
# Output results
if len(results) == 1:
@@ -132,9 +139,11 @@ def main():
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
logger.info("Results written to: %s", args.output)
else:
print(json_output)
# Output JSON to stdout (not logged)
sys.stdout.write(json_output)
sys.stdout.write('\n')
if __name__ == '__main__':

View File

@@ -289,6 +289,16 @@ class ModelVersion(SQLModel, table=True):
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Model lineage
model_type: str = Field(default="base", max_length=20, index=True)
# "base" = trained from pretrained YOLO, "finetune" = fine-tuned from base model
base_model_version_id: UUID | None = Field(default=None, index=True)
# Points to the base model this was fine-tuned from (None for base models)
base_training_dataset_id: UUID | None = Field(default=None, index=True)
# The dataset used for original base training (for data mixing old samples)
gating_status: str = Field(default="pending", max_length=20, index=True)
# Deployment gating: pending, pass, review, reject, skipped
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
@@ -317,6 +327,64 @@ class ModelVersion(SQLModel, table=True):
updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Fine-Tune Pool
# =============================================================================
class FineTunePoolEntry(SQLModel, table=True):
"""Document in the fine-tune pool for incremental model improvement."""
__tablename__ = "finetune_pool_entries"
entry_id: UUID = Field(default_factory=uuid4, primary_key=True)
document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
added_by: str | None = Field(default=None, max_length=255)
reason: str | None = Field(default=None, max_length=255)
# Reason: user_reported_failure, manual_addition
is_verified: bool = Field(default=False, index=True)
verified_at: datetime | None = Field(default=None)
verified_by: str | None = Field(default=None, max_length=255)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Deployment Gating
# =============================================================================
class GatingResult(SQLModel, table=True):
"""Model deployment gating validation result."""
__tablename__ = "gating_results"
result_id: UUID = Field(default_factory=uuid4, primary_key=True)
model_version_id: UUID = Field(foreign_key="model_versions.version_id", index=True)
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id")
# Gate 1: Regression validation (original test set mAP)
gate1_status: str = Field(default="pending", max_length=20)
# pending, pass, review, reject
gate1_original_mAP: float | None = Field(default=None)
gate1_new_mAP: float | None = Field(default=None)
gate1_mAP_drop: float | None = Field(default=None)
# Gate 2: New sample validation (detection rate on pool docs)
gate2_status: str = Field(default="pending", max_length=20)
gate2_detection_rate: float | None = Field(default=None)
gate2_total_samples: int | None = Field(default=None)
gate2_detected_samples: int | None = Field(default=None)
# Overall
overall_status: str = Field(default="pending", max_length=20)
# pending, pass, review, reject
reviewer_notes: str | None = Field(default=None)
reviewed_by: str | None = Field(default=None, max_length=255)
reviewed_at: datetime | None = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -13,6 +13,7 @@ from backend.data.repositories.training_task_repository import TrainingTaskRepos
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
from backend.data.repositories.finetune_pool_repository import FineTunePoolRepository
__all__ = [
"BaseRepository",
@@ -23,4 +24,5 @@ __all__ = [
"DatasetRepository",
"ModelVersionRepository",
"BatchUploadRepository",
"FineTunePoolRepository",
]

View File

@@ -0,0 +1,131 @@
"""
Fine-Tune Pool Repository
Manages the fine-tune pool: accumulated user-reported failure documents
for incremental model improvement.
"""
import logging
from datetime import datetime
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import FineTunePoolEntry
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class FineTunePoolRepository(BaseRepository[FineTunePoolEntry]):
"""Repository for fine-tune pool management."""
def add_document(
self,
document_id: str | UUID,
added_by: str | None = None,
reason: str | None = None,
) -> FineTunePoolEntry:
"""Add a document to the fine-tune pool."""
with get_session_context() as session:
entry = FineTunePoolEntry(
document_id=UUID(str(document_id)),
added_by=added_by,
reason=reason,
)
session.add(entry)
session.commit()
session.refresh(entry)
session.expunge(entry)
return entry
def get_entry(self, entry_id: str | UUID) -> FineTunePoolEntry | None:
"""Get a pool entry by ID."""
with get_session_context() as session:
entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
if entry:
session.expunge(entry)
return entry
def get_by_document(self, document_id: str | UUID) -> FineTunePoolEntry | None:
"""Get pool entry for a document."""
with get_session_context() as session:
result = session.exec(
select(FineTunePoolEntry).where(
FineTunePoolEntry.document_id == UUID(str(document_id))
)
).first()
if result:
session.expunge(result)
return result
def get_paginated(
self,
verified_only: bool = False,
limit: int = 20,
offset: int = 0,
) -> tuple[list[FineTunePoolEntry], int]:
"""List pool entries with pagination."""
with get_session_context() as session:
query = select(FineTunePoolEntry)
count_query = select(func.count()).select_from(FineTunePoolEntry)
if verified_only:
query = query.where(FineTunePoolEntry.is_verified == True)
count_query = count_query.where(FineTunePoolEntry.is_verified == True)
total = session.exec(count_query).one()
entries = session.exec(
query.order_by(FineTunePoolEntry.created_at.desc())
.offset(offset)
.limit(limit)
).all()
for e in entries:
session.expunge(e)
return list(entries), total
def get_pool_count(self, verified_only: bool = True) -> int:
"""Get count of entries in the pool."""
with get_session_context() as session:
query = select(func.count()).select_from(FineTunePoolEntry)
if verified_only:
query = query.where(FineTunePoolEntry.is_verified == True)
return session.exec(query).one()
def get_all_document_ids(self, verified_only: bool = True) -> list[UUID]:
"""Get all document IDs in the pool."""
with get_session_context() as session:
query = select(FineTunePoolEntry.document_id)
if verified_only:
query = query.where(FineTunePoolEntry.is_verified == True)
results = session.exec(query).all()
return list(results)
def verify_entry(
self,
entry_id: str | UUID,
verified_by: str | None = None,
) -> FineTunePoolEntry | None:
"""Mark a pool entry as verified."""
with get_session_context() as session:
entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
if not entry:
return None
entry.is_verified = True
entry.verified_at = datetime.utcnow()
entry.verified_by = verified_by
session.add(entry)
session.commit()
session.refresh(entry)
session.expunge(entry)
return entry
def remove_entry(self, entry_id: str | UUID) -> bool:
"""Remove an entry from the pool."""
with get_session_context() as session:
entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
if not entry:
return False
session.delete(entry)
session.commit()
return True

View File

@@ -43,6 +43,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
model_type: str = "base",
base_model_version_id: str | UUID | None = None,
base_training_dataset_id: str | UUID | None = None,
gating_status: str = "pending",
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
@@ -60,6 +64,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
model_type=model_type,
base_model_version_id=UUID(str(base_model_version_id)) if base_model_version_id else None,
base_training_dataset_id=UUID(str(base_training_dataset_id)) if base_training_dataset_id else None,
gating_status=gating_status,
)
session.add(model)
session.commit()

View File

@@ -0,0 +1,25 @@
"""
Domain Layer
Business logic separated from technical implementation.
Contains document classification and invoice validation logic.
"""
from backend.domain.document_classifier import (
ClassificationResult,
DocumentClassifier,
)
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationIssue,
ValidationResult,
)
from backend.domain.utils import has_value
__all__ = [
"ClassificationResult",
"DocumentClassifier",
"InvoiceValidator",
"ValidationIssue",
"ValidationResult",
"has_value",
]

View File

@@ -0,0 +1,108 @@
"""
Document Classifier
Business logic for classifying documents based on extracted fields.
Separates classification logic from inference pipeline.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ClassificationResult:
"""
Immutable result of document classification.
Attributes:
document_type: Either "invoice" or "letter"
confidence: Confidence score between 0.0 and 1.0
reason: Human-readable explanation of classification
"""
document_type: str
confidence: float
reason: str
class DocumentClassifier:
"""
Classifies documents as invoice or letter based on extracted fields.
Classification Rules:
1. If payment_line is present -> invoice (high confidence)
2. If 2+ invoice indicators present -> invoice (medium confidence)
3. If 1 invoice indicator present -> invoice (lower confidence)
4. Otherwise -> letter
Invoice indicator fields:
- payment_line (strongest indicator)
- OCR
- Amount
- Bankgiro
- Plusgiro
- InvoiceNumber
"""
INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
{
"payment_line",
"OCR",
"Amount",
"Bankgiro",
"Plusgiro",
"InvoiceNumber",
}
)
def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
"""
Classify document type based on extracted fields.
Args:
fields: Dictionary of field names to extracted values.
Empty strings or whitespace-only strings are treated as missing.
Returns:
Immutable ClassificationResult with type, confidence, and reason.
"""
# Rule 1: payment_line is the strongest indicator
if has_value(fields.get("payment_line")):
return ClassificationResult(
document_type="invoice",
confidence=0.95,
reason="payment_line detected",
)
# Count present invoice indicators (excluding payment_line already checked)
present_indicators = [
field
for field in self.INVOICE_INDICATOR_FIELDS
if field != "payment_line" and has_value(fields.get(field))
]
indicator_count = len(present_indicators)
# Rule 2: Multiple indicators -> invoice with medium-high confidence
if indicator_count >= 2:
return ClassificationResult(
document_type="invoice",
confidence=0.8,
reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
)
# Rule 3: Single indicator -> invoice with lower confidence
if indicator_count == 1:
return ClassificationResult(
document_type="invoice",
confidence=0.6,
reason=f"1 invoice indicator present: {present_indicators[0]}",
)
# Rule 4: No indicators -> letter
return ClassificationResult(
document_type="letter",
confidence=0.7,
reason="no invoice indicators found",
)

View File

@@ -0,0 +1,141 @@
"""
Invoice Validator
Business logic for validating extracted invoice fields.
Checks for required fields, format validity, and confidence thresholds.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ValidationIssue:
"""
Single validation issue.
Attributes:
field: Name of the field with the issue
severity: One of "error", "warning", "info"
message: Human-readable description of the issue
"""
field: str
severity: str
message: str
@dataclass(frozen=True)
class ValidationResult:
"""
Immutable result of invoice validation.
Attributes:
is_valid: True if no errors (warnings are allowed)
issues: Tuple of validation issues found
confidence: Average confidence score of validated fields
"""
is_valid: bool
issues: tuple[ValidationIssue, ...]
confidence: float
class InvoiceValidator:
"""
Validates extracted invoice fields for completeness and consistency.
Validation Rules:
1. Required fields must be present (Amount)
2. At least one payment reference should be present (warning if missing)
3. Field confidence should be above threshold (warning if below)
Required fields:
- Amount
Payment reference fields (at least one expected):
- OCR
- Bankgiro
- Plusgiro
- payment_line
"""
REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
DEFAULT_MIN_CONFIDENCE: float = 0.5
def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
"""
Initialize validator.
Args:
min_confidence: Minimum confidence threshold for valid fields.
Fields below this threshold produce warnings.
"""
self._min_confidence = min_confidence
def validate(
self,
fields: dict[str, str | None],
confidence: dict[str, float],
) -> ValidationResult:
"""
Validate extracted invoice fields.
Args:
fields: Dictionary of field names to extracted values
confidence: Dictionary of field names to confidence scores
Returns:
Immutable ValidationResult with validity status and issues
"""
issues: list[ValidationIssue] = []
# Check required fields
for field in self.REQUIRED_FIELDS:
if not has_value(fields.get(field)):
issues.append(
ValidationIssue(
field=field,
severity="error",
message=f"Required field '{field}' is missing",
)
)
# Check payment reference (at least one expected)
has_payment_ref = any(
has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS
)
if not has_payment_ref:
issues.append(
ValidationIssue(
field="payment_reference",
severity="warning",
message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
)
)
# Check confidence thresholds
for field, conf in confidence.items():
if conf < self._min_confidence:
issues.append(
ValidationIssue(
field=field,
severity="warning",
message=f"Low confidence ({conf:.2f}) for field '{field}'",
)
)
# Calculate overall validity
has_errors = any(i.severity == "error" for i in issues)
avg_confidence = (
sum(confidence.values()) / len(confidence) if confidence else 0.0
)
return ValidationResult(
is_valid=not has_errors,
issues=tuple(issues),
confidence=avg_confidence,
)

View File

@@ -0,0 +1,23 @@
"""
Domain Layer Utilities
Shared helper functions for domain layer classes.
"""
from __future__ import annotations
def has_value(value: str | None) -> bool:
"""
Check if a field value is present and non-empty.
Args:
value: Field value to check
Returns:
True if value is a non-empty, non-whitespace string
"""
if value is None:
return False
if not isinstance(value, str):
return bool(value)
return bool(value.strip())

View File

@@ -1,5 +1,18 @@
from .pipeline import InferencePipeline, InferenceResult
from .pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
BUSINESS_FEATURES_AVAILABLE,
)
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
__all__ = [
'InferencePipeline',
'InferenceResult',
'CrossValidationResult',
'YOLODetector',
'Detection',
'FieldExtractor',
'BUSINESS_FEATURES_AVAILABLE',
]

View File

@@ -40,6 +40,7 @@ from .normalizers import (
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
from .value_selector import ValueSelector
@dataclass
@@ -84,7 +85,7 @@ class FieldExtractor:
def __init__(
self,
ocr_lang: str = 'en',
ocr_lang: str = 'sv',
use_gpu: bool = False,
bbox_padding: float = 0.1,
dpi: int = 300,
@@ -169,13 +170,21 @@ class FieldExtractor:
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
matching_tokens.append((token, overlap_ratio))
# Sort by overlap ratio and combine text
# Sort by overlap ratio
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Convert to OCRTokens for value selection, then filter
from shared.ocr.paddle_ocr import OCRToken
pdf_ocr_tokens = [
OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0)
for t in matching_tokens
]
value_tokens = ValueSelector.select_value_tokens(pdf_ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
@@ -223,13 +232,14 @@ class FieldExtractor:
# Run OCR on region
ocr_tokens = self.ocr_engine.extract_from_image(region)
# Combine all OCR text
raw_text = ' '.join(t.text for t in ocr_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Select value tokens (filter out label text)
value_tokens = ValueSelector.select_value_tokens(ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text

View File

@@ -20,26 +20,98 @@ class AmountNormalizer(BaseNormalizer):
Handles various Swedish amount formats:
- With decimal: 1 234,56 kr
- With SEK suffix: 1234.56 SEK
- Payment line kronor/ore: 590 00 (space = decimal separator)
- Multiple amounts (returns the last one, usually the total)
"""
# Payment line kronor/ore pattern: "590 00" means 590.00 SEK
# Only matches when no comma/dot is present (pure digit-space-2digit format)
_KRONOR_ORE_PATTERN = re.compile(r'^(\d+)\s+(\d{2})$')
@property
def field_name(self) -> str:
return "Amount"
@classmethod
def _try_kronor_ore(cls, text: str) -> NormalizationResult | None:
"""Try to parse as payment line kronor/ore format.
Swedish payment lines separate kronor and ore with a space:
"590 00" = 590.00 SEK, "15658 00" = 15658.00 SEK
Only applies when text has no comma or dot (otherwise it's
a normal amount format with explicit decimal separator).
Returns NormalizationResult on success, None if not matched.
"""
if ',' in text or '.' in text:
return None
match = cls._KRONOR_ORE_PATTERN.match(text.strip())
if not match:
return None
kronor = match.group(1)
ore = match.group(2)
try:
amount = float(f"{kronor}.{ore}")
if amount > 0:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
return None
@staticmethod
def _parse_amount_str(match: str) -> float | None:
"""Convert matched amount string to float, detecting European vs Anglo format.
European: 2.254,50 -> 2254.50 (dot=thousand, comma=decimal)
Anglo: 1,234.56 -> 1234.56 (comma=thousand, dot=decimal)
Swedish: 1 234,56 -> 1234.56 (space=thousand, comma=decimal)
"""
has_comma = ',' in match
has_dot = '.' in match
if has_comma and has_dot:
if match.rfind(',') > match.rfind('.'):
# European: 2.254,50
cleaned = match.replace(" ", "").replace(".", "").replace(",", ".")
else:
# Anglo: 1,234.56
cleaned = match.replace(" ", "").replace(",", "")
elif has_comma:
cleaned = match.replace(" ", "").replace(",", ".")
else:
cleaned = match.replace(" ", "")
try:
return float(cleaned)
except ValueError:
return None
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Split by newlines and process line by line to get the last valid amount
lines = text.split("\n")
# Collect all valid amounts from all lines
all_amounts: list[float] = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
# Separate patterns for European and Anglo formats
# (?!\d) lookahead prevents partial matches (e.g. "1,23" in "1,234.56")
# European: dot=thousand, comma=decimal (2.254,50 or 1 234,56)
# Anglo: comma=thousand, dot=decimal (1,234.56 or 1234.56)
amount_pattern = (
r"(\d[\d\s.]*,\d{2})(?!\d)\s*(?:kr|SEK)?"
r"|"
r"(\d[\d\s,]*\.\d{2})(?!\d)\s*(?:kr|SEK)?"
)
for line in lines:
line = line.strip()
@@ -47,15 +119,13 @@ class AmountNormalizer(BaseNormalizer):
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
for m in re.finditer(amount_pattern, line, re.IGNORECASE):
match = m.group(1) or m.group(2)
if not match:
continue
amount = self._parse_amount_str(match)
if amount is not None and 0 < amount < 10_000_000:
all_amounts.append(amount)
# Return the last amount found (usually the total)
if all_amounts:
@@ -64,7 +134,7 @@ class AmountNormalizer(BaseNormalizer):
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
if amount is not None and 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
# Try to find any decimal number
@@ -74,7 +144,7 @@ class AmountNormalizer(BaseNormalizer):
amount_str = matches[-1].replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -86,7 +156,7 @@ class AmountNormalizer(BaseNormalizer):
if match:
try:
amount = float(match.group(1))
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -98,7 +168,7 @@ class AmountNormalizer(BaseNormalizer):
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
if 0 < amount < 10_000_000:
return NormalizationResult.success(f"{amount:.2f}")
except ValueError:
pass
@@ -122,31 +192,33 @@ class EnhancedAmountNormalizer(AmountNormalizer):
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
# Use two capture groups: group(1) = European, group(2) = Anglo
labeled_patterns = [
# Swedish patterns
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
# Swedish patterns ((?!\d) prevents partial matches like "1,23" in "1,234.56")
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 1.0),
(
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
r"(?:moms|vat)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))",
0.8,
), # Lower priority for VAT
# Generic pattern
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
(r"(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))\s*(?:kr|sek|kronor)?", 0.7),
]
candidates: list[tuple[float, float, int]] = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
amount = self._parse_amount_str(match.group(1))
if amount is not None and 0 < amount < 10_000_000:
candidates.append((amount, priority, match.start()))
except ValueError:
continue
if candidates:
# Sort by priority (desc), then by position (later is usually total)

View File

@@ -62,14 +62,25 @@ class InvoiceNumberNormalizer(BaseNormalizer):
# Skip if it looks like a date (YYYYMMDD)
if len(seq) == 8 and seq.startswith("20"):
continue
# Skip year-only values (2024, 2025, 2026, etc.)
if len(seq) == 4 and seq.startswith("20"):
continue
# Skip if too long (likely OCR number)
if len(seq) > 10:
continue
valid_sequences.append(seq)
if valid_sequences:
# Return shortest valid sequence
return NormalizationResult.success(min(valid_sequences, key=len))
# Prefer 4-8 digit sequences (typical invoice numbers),
# then closest to 6 digits within that range.
# This avoids picking short fragments like "775" from amounts.
def _score(seq: str) -> tuple[int, int]:
length = len(seq)
if 4 <= length <= 8:
return (1, -abs(length - 6))
return (0, -length)
return NormalizationResult.success(max(valid_sequences, key=_score))
# Fallback: extract all digits if nothing else works
digits = re.sub(r"\D", "", text)

View File

@@ -14,7 +14,7 @@ class OcrNumberNormalizer(BaseNormalizer):
Normalizes OCR (Optical Character Recognition) reference numbers.
OCR numbers in Swedish payment systems:
- Minimum 5 digits
- Minimum 2 digits
- Used for automated payment matching
"""
@@ -29,7 +29,7 @@ class OcrNumberNormalizer(BaseNormalizer):
digits = re.sub(r"\D", "", text)
if len(digits) < 5:
if len(digits) < 2:
return NormalizationResult.failure(
f"Too few digits for OCR: {len(digits)}"
)

View File

@@ -2,19 +2,39 @@
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
Supports both basic field extraction and business invoice features
(line items, VAT extraction, cross-validation).
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import logging
import time
import re
logger = logging.getLogger(__name__)
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
# Business invoice feature imports (optional - for extract_line_items mode)
try:
from ..table.line_items_extractor import LineItem, LineItemsResult, LineItemsExtractor
from ..table.structure_detector import TableDetector
from ..vat.vat_extractor import VATSummary, VATExtractor
from ..validation.vat_validator import VATValidationResult, VATValidator
BUSINESS_FEATURES_AVAILABLE = True
except ImportError:
BUSINESS_FEATURES_AVAILABLE = False
LineItem = None
LineItemsResult = None
TableDetector = None
VATSummary = None
VATValidationResult = None
@dataclass
class CrossValidationResult:
@@ -45,6 +65,10 @@ class InferenceResult:
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
cross_validation: CrossValidationResult | None = None
# Business invoice features (optional)
line_items: Any | None = None # LineItemsResult when available
vat_summary: Any | None = None # VATSummary when available
vat_validation: Any | None = None # VATValidationResult when available
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
@@ -81,8 +105,89 @@ class InferenceResult:
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
# Add business invoice features if present
if self.line_items is not None:
result['line_items'] = self._line_items_to_json()
if self.vat_summary is not None:
result['vat_summary'] = self._vat_summary_to_json()
if self.vat_validation is not None:
result['vat_validation'] = self._vat_validation_to_json()
return result
def _line_items_to_json(self) -> dict | None:
"""Convert LineItemsResult to JSON."""
if self.line_items is None:
return None
li = self.line_items
return {
'items': [
{
'row_index': item.row_index,
'description': item.description,
'quantity': item.quantity,
'unit': item.unit,
'unit_price': item.unit_price,
'amount': item.amount,
'article_number': item.article_number,
'vat_rate': item.vat_rate,
'is_deduction': item.is_deduction,
'confidence': item.confidence,
}
for item in li.items
],
'header_row': li.header_row,
'total_amount': li.total_amount,
}
def _vat_summary_to_json(self) -> dict | None:
"""Convert VATSummary to JSON."""
if self.vat_summary is None:
return None
vs = self.vat_summary
return {
'breakdowns': [
{
'rate': b.rate,
'base_amount': b.base_amount,
'vat_amount': b.vat_amount,
'source': b.source,
}
for b in vs.breakdowns
],
'total_excl_vat': vs.total_excl_vat,
'total_vat': vs.total_vat,
'total_incl_vat': vs.total_incl_vat,
'confidence': vs.confidence,
}
def _vat_validation_to_json(self) -> dict | None:
"""Convert VATValidationResult to JSON."""
if self.vat_validation is None:
return None
vv = self.vat_validation
return {
'is_valid': vv.is_valid,
'confidence_score': vv.confidence_score,
'math_checks': [
{
'rate': mc.rate,
'base_amount': mc.base_amount,
'expected_vat': mc.expected_vat,
'actual_vat': mc.actual_vat,
'is_valid': mc.is_valid,
'tolerance': mc.tolerance,
}
for mc in vv.math_checks
],
'total_check': vv.total_check,
'line_items_vs_summary': vv.line_items_vs_summary,
'amount_consistency': vv.amount_consistency,
'needs_review': vv.needs_review,
'review_reasons': vv.review_reasons,
}
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
@@ -104,10 +209,12 @@ class InferencePipeline:
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
ocr_lang: str = 'en',
ocr_lang: str = 'sv',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
enable_fallback: bool = True,
enable_business_features: bool = False,
vat_tolerance: float = 0.5
):
"""
Initialize inference pipeline.
@@ -119,21 +226,46 @@ class InferencePipeline:
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
enable_business_features: Enable line items/VAT extraction
vat_tolerance: Tolerance for VAT math checks (in currency units)
"""
self.detector = YOLODetector(
model_path,
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
self.enable_business_features = enable_business_features
self.vat_tolerance = vat_tolerance
# Initialize business feature components if enabled and available
self.line_items_extractor = None
self.vat_extractor = None
self.vat_validator = None
self._business_ocr_engine = None # Lazy-initialized for VAT text extraction
self._table_detector = None # Shared TableDetector for line items extraction
if enable_business_features:
if not BUSINESS_FEATURES_AVAILABLE:
raise ImportError(
"Business features require table, vat, and validation modules. "
"Please ensure they are properly installed."
)
# Create shared TableDetector for performance (PP-StructureV3 init is slow)
self._table_detector = TableDetector()
# Pass shared detector to LineItemsExtractor
self.line_items_extractor = LineItemsExtractor(table_detector=self._table_detector)
self.vat_extractor = VATExtractor()
self.vat_validator = VATValidator(tolerance=vat_tolerance)
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
document_id: str | None = None,
extract_line_items: bool | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
@@ -141,6 +273,8 @@ class InferencePipeline:
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
extract_line_items: Whether to extract line items and VAT info.
If None, uses the enable_business_features setting from __init__.
Returns:
InferenceResult with extracted fields
@@ -156,9 +290,37 @@ class InferencePipeline:
document_id=document_id or Path(pdf_path).stem
)
# Determine if business features should be used
use_business_features = (
extract_line_items if extract_line_items is not None
else self.enable_business_features
)
try:
all_detections = []
all_extracted = []
all_ocr_text = [] # Collect OCR text for VAT extraction
# Check if PDF has readable text layer (avoids OCR for text PDFs)
from shared.pdf.extractor import PDFDocument
is_text_pdf = False
pdf_tokens_by_page: dict[int, list] = {}
try:
with PDFDocument(pdf_path) as pdf_doc:
is_text_pdf = pdf_doc.is_text_pdf()
if is_text_pdf:
for pg in range(pdf_doc.page_count):
pdf_tokens_by_page[pg] = list(
pdf_doc.extract_text_tokens(pg)
)
logger.info(
"Text PDF detected, extracted %d tokens from %d pages",
sum(len(t) for t in pdf_tokens_by_page.values()),
len(pdf_tokens_by_page),
)
except Exception as e:
logger.warning("PDF text detection failed, falling back to OCR: %s", e)
is_text_pdf = False
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
@@ -172,9 +334,24 @@ class InferencePipeline:
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
if is_text_pdf and page_no in pdf_tokens_by_page:
extracted = self.extractor.extract_from_detection_with_pdf(
detection,
pdf_tokens_by_page[page_no],
image.width,
image.height,
)
else:
extracted = self.extractor.extract_from_detection(
detection, image_array
)
all_extracted.append(extracted)
# Collect full-page OCR text for VAT extraction (only if business features enabled)
if use_business_features:
page_text = self._get_full_page_text(image_array)
all_ocr_text.append(page_text)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
@@ -184,6 +361,11 @@ class InferencePipeline:
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
self._dedup_invoice_number(result)
# Extract business invoice features if enabled
if use_business_features:
self._extract_business_features(pdf_path, result, '\n'.join(all_ocr_text))
result.success = len(result.fields) > 0
@@ -194,8 +376,85 @@ class InferencePipeline:
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _get_full_page_text(self, image_array) -> str:
"""Extract full page text using OCR for VAT extraction."""
from shared.ocr import OCREngine
import logging
logger = logging.getLogger(__name__)
try:
# Lazy initialize OCR engine to avoid repeated model loading
if self._business_ocr_engine is None:
self._business_ocr_engine = OCREngine()
tokens = self._business_ocr_engine.extract_from_image(image_array, page_no=0)
return ' '.join(t.text for t in tokens)
except Exception as e:
logger.warning(f"OCR extraction for VAT failed: {e}")
return ""
def _extract_business_features(
self,
pdf_path: str | Path,
result: InferenceResult,
full_text: str
) -> None:
"""
Extract line items, VAT summary, and perform cross-validation.
Args:
pdf_path: Path to PDF file
result: InferenceResult to populate
full_text: Full OCR text from all pages
"""
if not BUSINESS_FEATURES_AVAILABLE:
result.errors.append("Business features not available")
return
if not self.line_items_extractor or not self.vat_extractor or not self.vat_validator:
result.errors.append("Business feature extractors not initialized")
return
try:
# Extract line items from tables
logger.info(f"Extracting line items from PDF: {pdf_path}")
line_items_result = self.line_items_extractor.extract_from_pdf(str(pdf_path))
logger.info(f"Line items extraction result: {line_items_result is not None}, items={len(line_items_result.items) if line_items_result else 0}")
if line_items_result and line_items_result.items:
result.line_items = line_items_result
logger.info(f"Set result.line_items with {len(line_items_result.items)} items")
# Extract VAT summary from text
logger.info(f"Extracting VAT summary from text ({len(full_text)} chars)")
vat_summary = self.vat_extractor.extract(full_text)
logger.info(f"VAT summary extraction result: {vat_summary is not None}")
if vat_summary:
result.vat_summary = vat_summary
# Cross-validate VAT information
existing_amount = result.fields.get('Amount')
vat_validation = self.vat_validator.validate(
vat_summary,
line_items=line_items_result,
existing_amount=str(existing_amount) if existing_amount else None
)
result.vat_validation = vat_validation
logger.info(f"VAT validation completed: is_valid={vat_validation.is_valid if vat_validation else None}")
except Exception as e:
import traceback
error_detail = f"{type(e).__name__}: {e}"
logger.error(f"Business feature extraction failed: {error_detail}\n{traceback.format_exc()}")
result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
"""Merge extracted fields, keeping best candidate for each field.
Selection priority:
1. Prefer candidates without validation errors
2. Among equal validity, prefer higher confidence
"""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
@@ -208,15 +467,59 @@ class InferencePipeline:
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
# Sort by: (no validation error, confidence) - descending
# This prefers candidates without errors, then by confidence
best = max(
candidates,
key=lambda x: (x.validation_error is None, x.confidence)
)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Validate date consistency
self._validate_dates(result)
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
self._dedup_invoice_number(result)
def _validate_dates(self, result: InferenceResult) -> None:
"""Remove InvoiceDueDate if it is earlier than InvoiceDate."""
invoice_date = result.fields.get('InvoiceDate')
due_date = result.fields.get('InvoiceDueDate')
if invoice_date and due_date and due_date < invoice_date:
del result.fields['InvoiceDueDate']
result.confidence.pop('InvoiceDueDate', None)
result.bboxes.pop('InvoiceDueDate', None)
def _dedup_invoice_number(self, result: InferenceResult) -> None:
"""Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
inv_num = result.fields.get('InvoiceNumber')
if not inv_num:
return
inv_digits = re.sub(r'\D', '', str(inv_num))
# Check against OCR
ocr = result.fields.get('OCR')
if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
return
# Check against Bankgiro (exact or substring match)
bg = result.fields.get('Bankgiro')
if bg:
bg_digits = re.sub(r'\D', '', str(bg))
if inv_digits == bg_digits or inv_digits in bg_digits:
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
@@ -375,10 +678,14 @@ class InferencePipeline:
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
key_missing = sum(1 for f in key_fields if f not in result.fields)
important_missing = sum(1 for f in important_fields if f not in result.fields)
# Fallback if any key field missing OR 2+ important fields missing
return key_missing >= 1 or important_missing >= 2
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
@@ -410,12 +717,13 @@ class InferencePipeline:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
@@ -423,6 +731,20 @@ class InferencePipeline:
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
'InvoiceDate': [
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'InvoiceDueDate': [
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'supplier_organisation_number': [
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
],
'Plusgiro': [
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
],
}
for field_name, field_patterns in patterns.items():
@@ -445,6 +767,22 @@ class InferencePipeline:
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Normalize DD/MM/YYYY to YYYY-MM-DD
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
if date_match:
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
# Replace / with -
value = value.replace('/', '-')
elif field_name == 'InvoiceNumber':
# Skip year-like values (2024, 2025, 2026, etc.)
if re.match(r'^20\d{2}$', value):
continue
elif field_name == 'supplier_organisation_number':
# Ensure NNNNNN-NNNN format
digits = re.sub(r'\D', '', value)
if len(digits) == 10:
value = f"{digits[:6]}-{digits[6:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex

View File

@@ -0,0 +1,172 @@
"""
Value Selector Module.
Selects the most likely value token(s) from OCR output per field type,
filtering out label text before sending to normalizer.
Stateless and pure -- easy to test, no side effects.
"""
import re
from typing import Final
from shared.ocr.paddle_ocr import OCRToken
# Swedish label keywords commonly found near field values
LABEL_KEYWORDS: Final[frozenset[str]] = frozenset({
"fakturanummer", "fakturanr", "fakturadatum", "forfallodag", "forfalldatum",
"bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa",
"total", "att", "betala", "kundnummer", "organisationsnummer",
"org", "nr", "datum", "nummer", "ref", "referens",
"momsreg", "vat", "moms", "sek", "kr",
"org.nr", "kund", "faktura", "invoice",
})
# Patterns
_DATE_PATTERN = re.compile(
r"\d{4}[-./]\d{2}[-./]\d{2}" # 2024-01-15, 2024.01.15
r"|"
r"\d{2}[-./]\d{2}[-./]\d{4}" # 15/01/2024
r"|"
r"\d{8}" # 20240115
)
_AMOUNT_PATTERN = re.compile(
r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$" # European: 2.254,50 SEK
r"|"
r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$" # Anglo: 1,234.56 SEK
)
_BANKGIRO_PATTERN = re.compile(
r"^\d{3,4}-\d{4}$" # 123-4567
r"|"
r"^\d{7,8}$" # 1234567 or 12345678
)
_PLUSGIRO_PATTERN = re.compile(
r"^\d+-\d$" # 12345-6
r"|"
r"^\d{2,8}$" # 1234567
)
_ORG_NUMBER_PATTERN = re.compile(
r"\d{6}-?\d{4}" # 556123-4567 or 5561234567
)
def _is_label(text: str) -> bool:
"""Check if token text is a known Swedish label keyword."""
cleaned = text.lower().rstrip(":").strip()
return cleaned in LABEL_KEYWORDS
def _count_digits(text: str) -> int:
"""Count digit characters in text."""
return sum(c.isdigit() for c in text)
class ValueSelector:
"""Selects value token(s) from OCR output, filtering label text.
Pure static methods -- no state, no side effects.
Fallback: always returns all tokens if no pattern matches,
so this can never make results worse than current behavior.
"""
@staticmethod
def select_value_tokens(
tokens: list[OCRToken],
field_name: str,
) -> list[OCRToken]:
"""Select the most likely value token(s) for a given field.
Args:
tokens: OCR tokens from the detected region.
field_name: Normalized field name (e.g. "InvoiceDate", "Amount").
Returns:
Filtered list of value tokens. Falls back to all tokens
if no field-specific pattern matches.
"""
if not tokens:
return []
selector = _FIELD_SELECTORS.get(field_name, _fallback_selector)
selected = selector(tokens)
# Safety: never return empty if we had input tokens
if not selected:
return list(tokens)
return selected
@staticmethod
def _select_date(tokens: list[OCRToken]) -> list[OCRToken]:
return _select_by_pattern(tokens, _DATE_PATTERN)
@staticmethod
def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]:
return _select_by_pattern(tokens, _AMOUNT_PATTERN)
@staticmethod
def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]:
return _select_by_pattern(tokens, _BANKGIRO_PATTERN)
@staticmethod
def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]:
return _select_by_pattern(tokens, _PLUSGIRO_PATTERN)
@staticmethod
def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]:
return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN)
@staticmethod
def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
"""Select token with the longest digit sequence (min 2 digits)."""
best: OCRToken | None = None
best_count = 0
for token in tokens:
digit_count = _count_digits(token.text)
if digit_count >= 2 and digit_count > best_count:
best = token
best_count = digit_count
return [best] if best else []
@staticmethod
def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]:
"""Remove label keywords, keep remaining tokens."""
return [t for t in tokens if not _is_label(t.text)]
@staticmethod
def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]:
"""Payment line keeps all tokens (needs full text)."""
return list(tokens)
def _select_by_pattern(
tokens: list[OCRToken],
pattern: re.Pattern[str],
) -> list[OCRToken]:
"""Select tokens matching a regex pattern."""
return [t for t in tokens if pattern.search(t.text.strip())]
def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]:
"""Default: return all tokens unchanged."""
return list(tokens)
# Map field names to selector functions
_FIELD_SELECTORS: Final[dict[str, callable]] = {
"InvoiceDate": ValueSelector._select_date,
"InvoiceDueDate": ValueSelector._select_date,
"Amount": ValueSelector._select_amount,
"Bankgiro": ValueSelector._select_bankgiro,
"Plusgiro": ValueSelector._select_plusgiro,
"OCR": ValueSelector._select_ocr_number,
"InvoiceNumber": ValueSelector._select_by_label_exclusion,
"supplier_org_number": ValueSelector._select_org_number,
"customer_number": ValueSelector._select_by_label_exclusion,
"payment_line": ValueSelector._select_payment_line,
}

View File

@@ -0,0 +1,32 @@
"""
Table detection and extraction module.
This module provides PP-StructureV3-based table detection for invoices,
and line items extraction from detected tables.
"""
from .structure_detector import (
TableDetectionResult,
TableDetector,
TableDetectorConfig,
)
from .line_items_extractor import (
LineItem,
LineItemsResult,
LineItemsExtractor,
ColumnMapper,
HTMLTableParser,
)
__all__ = [
# Structure detection
"TableDetectionResult",
"TableDetector",
"TableDetectorConfig",
# Line items extraction
"LineItem",
"LineItemsResult",
"LineItemsExtractor",
"ColumnMapper",
"HTMLTableParser",
]

View File

@@ -0,0 +1,204 @@
"""
HTML Table Parser
Parses HTML tables into structured data and maps columns to field names.
"""
from html.parser import HTMLParser
import logging
logger = logging.getLogger(__name__)
# Configuration constants
# Minimum pattern length to avoid false positives from short substrings
MIN_PATTERN_MATCH_LENGTH = 3
# Exact match bonus for column mapping priority
EXACT_MATCH_BONUS = 100
# Swedish column name mappings
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
COLUMN_MAPPINGS = {
"article_number": [
"art nummer",
"artikelnummer",
"artikel",
"artnr",
"art.nr",
"art nr",
"objektnummer", # Rental: property reference
"objekt",
],
"description": [
"beskrivning",
"produktbeskrivning",
"produkt",
"tjänst",
"text",
"benämning",
"vara/tjänst",
"vara",
# Rental invoice specific
"specifikation",
"spec",
"hyresperiod", # Rental period
"period",
"typ", # Type of charge
# Utility bills
"förbrukning", # Consumption
"avläsning", # Meter reading
],
"quantity": ["antal", "qty", "st", "pcs", "kvantitet", "", "kvm"],
"unit": ["enhet", "unit"],
"unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
"amount": [
"belopp",
"summa",
"total",
"netto",
"rad summa",
# Rental specific
"hyra", # Rent
"avgift", # Fee
"kostnad", # Cost
"debitering", # Charge
"totalt", # Total
],
"vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
# Additional field for rental: deductions/adjustments
"deduction": [
"avdrag", # Deduction
"rabatt", # Discount
"kredit", # Credit
],
}
# Keywords that indicate NOT a line items table
SUMMARY_KEYWORDS = [
"frakt",
"faktura.avg",
"fakturavg",
"exkl.moms",
"att betala",
"öresavr",
"bankgiro",
"plusgiro",
"ocr",
"forfallodatum",
"förfallodatum",
]
class _TableHTMLParser(HTMLParser):
"""Internal HTML parser for tables."""
def __init__(self):
super().__init__()
self.rows: list[list[str]] = []
self.current_row: list[str] = []
self.current_cell: str = ""
self.in_td = False
self.in_thead = False
self.header_row: list[str] = []
def handle_starttag(self, tag, attrs):
if tag == "tr":
self.current_row = []
elif tag in ("td", "th"):
self.in_td = True
self.current_cell = ""
elif tag == "thead":
self.in_thead = True
def handle_endtag(self, tag):
if tag in ("td", "th"):
self.in_td = False
self.current_row.append(self.current_cell.strip())
elif tag == "tr":
if self.current_row:
if self.in_thead:
self.header_row = self.current_row
else:
self.rows.append(self.current_row)
elif tag == "thead":
self.in_thead = False
def handle_data(self, data):
if self.in_td:
self.current_cell += data
class HTMLTableParser:
"""Parse HTML tables into structured data."""
def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
"""
Parse HTML table and return header and rows.
Args:
html: HTML string containing table.
Returns:
Tuple of (header_row, data_rows).
"""
parser = _TableHTMLParser()
parser.feed(html)
return parser.header_row, parser.rows
class ColumnMapper:
"""Map column headers to field names."""
def __init__(self, mappings: dict[str, list[str]] | None = None):
"""
Initialize column mapper.
Args:
mappings: Custom column mappings. Uses Swedish defaults if None.
"""
self.mappings = mappings or COLUMN_MAPPINGS
def map(self, headers: list[str]) -> dict[int, str]:
"""
Map column indices to field names.
Args:
headers: List of column header strings.
Returns:
Dictionary mapping column index to field name.
"""
mapping = {}
for idx, header in enumerate(headers):
normalized = self._normalize(header)
if not normalized.strip():
continue
best_match = None
best_match_len = 0
for field_name, patterns in self.mappings.items():
for pattern in patterns:
if pattern == normalized:
# Exact match gets highest priority
best_match = field_name
best_match_len = len(pattern) + EXACT_MATCH_BONUS
break
elif pattern in normalized and len(pattern) > best_match_len:
# Partial match requires minimum length to avoid false positives
if len(pattern) >= MIN_PATTERN_MATCH_LENGTH:
best_match = field_name
best_match_len = len(pattern)
if best_match_len > EXACT_MATCH_BONUS:
# Found exact match, no need to check other fields
break
if best_match:
mapping[idx] = best_match
return mapping
def _normalize(self, header: str) -> str:
"""Normalize header text for matching."""
return header.lower().strip().replace(".", "").replace("-", " ")

View File

@@ -0,0 +1,395 @@
"""
Line Items Extractor
Extracts structured line items from HTML tables produced by PP-StructureV3.
Handles Swedish invoice formats including reversed tables (header at bottom).
Includes fallback text-based extraction for invoices without detectable table structures.
"""
from pathlib import Path
import re
import logging
logger = logging.getLogger(__name__)
# Import models
from .models import LineItem, LineItemsResult
# Import parsers
from .html_table_parser import (
HTMLTableParser,
ColumnMapper,
COLUMN_MAPPINGS,
SUMMARY_KEYWORDS,
)
# Import merged cell handler
from .merged_cell_handler import MergedCellHandler
# Re-export for backward compatibility
__all__ = [
"LineItem",
"LineItemsResult",
"LineItemsExtractor",
"ColumnMapper",
"HTMLTableParser",
"COLUMN_MAPPINGS",
"SUMMARY_KEYWORDS",
]
# Configuration constants
# Minimum keyword matches required to detect a header row
MIN_HEADER_KEYWORD_MATCHES = 2
class LineItemsExtractor:
"""Extract structured line items from HTML tables."""
def __init__(
self,
column_mapper: ColumnMapper | None = None,
table_detector: "TableDetector | None" = None,
enable_text_fallback: bool = True,
):
"""
Initialize extractor.
Args:
column_mapper: Custom column mapper. Uses default Swedish mappings if None.
table_detector: Optional shared TableDetector instance (avoids slow re-init).
enable_text_fallback: Enable text-based extraction as fallback.
"""
self.parser = HTMLTableParser()
self.mapper = column_mapper or ColumnMapper()
self.merged_cell_handler = MergedCellHandler(self.mapper)
self._table_detector = table_detector
self.enable_text_fallback = enable_text_fallback
def extract(self, html: str) -> LineItemsResult:
"""
Extract line items from HTML table.
Args:
html: HTML string containing table.
Returns:
LineItemsResult with extracted items.
"""
header, rows = self.parser.parse(html)
# Check for merged header (rental invoice pattern)
if self.merged_cell_handler.has_merged_header(header):
logger.debug("Detected merged header, using merged cell extraction")
items = self.merged_cell_handler.extract_from_merged_cells(header, rows)
return LineItemsResult(
items=items,
header_row=header,
raw_html=html,
is_reversed=False,
)
# Check if merged header in first row (no explicit header)
if not header and rows and self.merged_cell_handler.has_merged_header(rows[0]):
logger.debug("Detected merged header in first row")
items = self.merged_cell_handler.extract_from_merged_cells(rows[0], rows[1:])
return LineItemsResult(
items=items,
header_row=rows[0],
raw_html=html,
is_reversed=False,
)
# Check for vertically merged cells
if self.merged_cell_handler.has_vertically_merged_cells(rows):
logger.debug("Detected vertically merged cells, splitting rows")
header, rows = self.merged_cell_handler.split_merged_rows(rows)
# If no explicit header, try to detect it
is_reversed = False
if not header and rows:
header_idx, detected_header, is_at_end = self._detect_header_row(rows)
if header_idx >= 0:
header = detected_header
is_reversed = is_at_end
if is_at_end:
# Reversed table: header at bottom
rows = rows[:header_idx]
else:
rows = rows[header_idx + 1:]
# Map columns
column_map = self.mapper.map(header)
if not column_map:
# Couldn't identify columns
return LineItemsResult(
items=[],
header_row=header,
raw_html=html,
is_reversed=is_reversed,
)
# Extract items
items = self._extract_items(rows, column_map)
return LineItemsResult(
items=items,
header_row=header,
raw_html=html,
is_reversed=is_reversed,
)
def extract_from_pdf(self, pdf_path: str | Path) -> LineItemsResult | None:
"""
Extract line items from PDF using table detection.
Args:
pdf_path: Path to PDF file.
Returns:
LineItemsResult if tables found, None otherwise.
"""
from .structure_detector import TableDetector
# Use shared detector or create new one
detector = self._table_detector or TableDetector()
# Detect tables in PDF
tables, parsing_res_list = self._detect_tables_with_parsing(detector, str(pdf_path))
# Try structured table extraction first
for table_result in tables:
if not table_result.html:
continue
# Check if this looks like a line items table
header, _ = self.parser.parse(table_result.html)
if self.is_line_items_table(header):
result = self.extract(table_result.html)
if result.items:
return result
# Fallback to text-based extraction if enabled
if self.enable_text_fallback and parsing_res_list:
return self._try_text_fallback(parsing_res_list)
return None
def _detect_tables_with_parsing(
self, detector: "TableDetector", pdf_path: str
) -> tuple[list, list]:
"""
Detect tables in PDF and return both table results and parsing_res.
Returns:
Tuple of (table_results, parsing_res_list)
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
pdf_path = Path(pdf_path)
if not pdf_path.exists():
logger.warning(f"PDF not found: {pdf_path}")
return [], []
if not pdf_path.is_file():
logger.warning(f"Path is not a file: {pdf_path}")
return [], []
# Ensure detector is initialized
detector._ensure_initialized()
# Render first page
parsing_res_list = []
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=300):
if page_no == 0:
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Detect tables using shared detector
tables = detector.detect(image_array)
# Also get parsing results for text fallback
if detector._pipeline is not None:
try:
result = detector._pipeline.predict(image_array)
# Extract parsing_res from result (API varies by version)
if isinstance(result, dict) and "parsing_res_list" in result:
parsing_res_list = result.get("parsing_res_list", [])
elif hasattr(result, "parsing_res_list"):
parsing_res_list = result.parsing_res_list or []
except Exception as e:
logger.debug(f"Could not get parsing_res: {e}")
return tables, parsing_res_list
return [], []
def _try_text_fallback(self, parsing_res_list: list) -> LineItemsResult | None:
"""
Try text-based extraction from parsing results.
Args:
parsing_res_list: Parsing results from PP-StructureV3.
Returns:
LineItemsResult if extraction successful, None otherwise.
"""
from .text_line_items_extractor import TextLineItemsExtractor, convert_text_line_item
text_extractor = TextLineItemsExtractor()
text_result = text_extractor.extract_from_parsing_res(parsing_res_list)
if text_result and text_result.items:
# Convert TextLineItem to LineItem
items = [convert_text_line_item(item) for item in text_result.items]
return LineItemsResult(
items=items,
header_row=text_result.header_row,
raw_html="", # No HTML for text-based extraction
is_reversed=False,
)
return None
def is_line_items_table(self, header: list[str]) -> bool:
"""
Check if header indicates a line items table (vs summary/payment table).
Args:
header: List of column header strings.
Returns:
True if this appears to be a line items table.
"""
if not header:
return False
header_text = " ".join(h.lower() for h in header)
# Check for summary table keywords (NOT a line items table)
for keyword in SUMMARY_KEYWORDS:
if keyword in header_text:
return False
# Check for line items keywords
column_map = self.mapper.map(header)
has_description = "description" in column_map.values()
has_amount = "amount" in column_map.values()
return has_description or has_amount
def _detect_header_row(
self, rows: list[list[str]]
) -> tuple[int, list[str], bool]:
"""
Detect which row is the header row.
PP-StructureV3 sometimes places headers at the bottom (reversed tables).
Returns:
(header_index, header_row, is_at_end)
"""
header_keywords = set()
for patterns in COLUMN_MAPPINGS.values():
header_keywords.update(patterns)
best_match = (-1, [], 0) # (index, row, match_count)
for i, row in enumerate(rows):
# Skip empty rows
if not any(cell.strip() for cell in row):
continue
row_text = " ".join(cell.lower() for cell in row)
matches = sum(1 for kw in header_keywords if kw in row_text)
if matches > best_match[2]:
best_match = (i, row, matches)
if best_match[2] >= MIN_HEADER_KEYWORD_MATCHES:
header_idx = best_match[0]
is_at_end = header_idx == len(rows) - 1 or header_idx > len(rows) // 2
return header_idx, best_match[1], is_at_end
return -1, [], False
def _extract_items(
self, rows: list[list[str]], column_map: dict[int, str]
) -> list[LineItem]:
"""
Extract line items from rows using column mapping.
Args:
rows: Data rows (excluding header).
column_map: Mapping of column index to field name.
Returns:
List of LineItem objects.
"""
items = []
for row_idx, row in enumerate(rows):
# Skip empty rows
if not any(cell.strip() for cell in row):
continue
item_data = {"row_index": row_idx}
for col_idx, field_name in column_map.items():
if col_idx < len(row):
value = row[col_idx].strip()
if value:
item_data[field_name] = value
# Check for deduction
is_deduction = False
description = item_data.get("description", "")
amount = item_data.get("amount", "")
if description:
desc_lower = description.lower()
if any(kw in desc_lower for kw in ["avdrag", "rabatt", "kredit"]):
is_deduction = True
if amount and amount.startswith("-"):
is_deduction = True
# Create line item if we have at least description or amount
if item_data.get("description") or item_data.get("amount"):
item = LineItem(
row_index=row_idx,
description=item_data.get("description"),
quantity=item_data.get("quantity"),
unit=item_data.get("unit"),
unit_price=item_data.get("unit_price"),
amount=item_data.get("amount"),
article_number=item_data.get("article_number"),
vat_rate=item_data.get("vat_rate"),
is_deduction=is_deduction,
)
items.append(item)
return items
# Backward compatibility: expose merged cell handler methods
def _has_merged_header(self, header: list[str] | None) -> bool:
"""Check if header appears to be merged. Delegates to MergedCellHandler."""
return self.merged_cell_handler.has_merged_header(header)
def _has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
"""Check for vertically merged cells. Delegates to MergedCellHandler."""
return self.merged_cell_handler.has_vertically_merged_cells(rows)
def _split_merged_rows(
self, rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
"""Split merged rows. Delegates to MergedCellHandler."""
return self.merged_cell_handler.split_merged_rows(rows)
def _extract_from_merged_cells(
self, header: list[str], rows: list[list[str]]
) -> list[LineItem]:
"""Extract from merged cells. Delegates to MergedCellHandler."""
return self.merged_cell_handler.extract_from_merged_cells(header, rows)

View File

@@ -0,0 +1,423 @@
"""
Merged Cell Handler
Handles detection and extraction of data from tables with merged cells,
a common issue with PP-StructureV3 OCR output.
"""
import re
import logging
from typing import TYPE_CHECKING
from .models import LineItem
if TYPE_CHECKING:
from .html_table_parser import ColumnMapper
logger = logging.getLogger(__name__)
# Minimum positive amount to consider as line item (filters noise like row indices)
MIN_AMOUNT_THRESHOLD = 100
class MergedCellHandler:
"""Handles tables with vertically merged cells from PP-StructureV3."""
def __init__(self, mapper: "ColumnMapper"):
"""
Initialize handler.
Args:
mapper: ColumnMapper instance for header keyword detection.
"""
self.mapper = mapper
def has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
"""
Check if table rows contain vertically merged data in single cells.
PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]
Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
"""
if not rows:
return False
for row in rows:
for cell in row:
if not cell or len(cell) < 20:
continue
# Check for multiple product numbers (7+ digit patterns)
product_nums = re.findall(r"\b\d{7}\b", cell)
if len(product_nums) >= 2:
logger.debug(f"has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
return True
# Check for multiple prices (Swedish format: 123,45 or 1 234,56)
prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
if len(prices) >= 3:
logger.debug(f"has_vertically_merged_cells: found {len(prices)} prices in cell")
return True
# Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
if len(quantities) >= 2:
logger.debug(f"has_vertically_merged_cells: found {len(quantities)} quantities in cell")
return True
return False
def split_merged_rows(
self, rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
"""
Split vertically merged cells back into separate rows.
Handles complex cases where PP-StructureV3 merges content across
multiple HTML rows. For example, 5 line items might be spread across
3 HTML rows with content mixed together.
Strategy:
1. Merge all row content per column
2. Detect how many actual data rows exist (by counting product numbers)
3. Split each column's content into that many lines
Returns header and data rows.
"""
if not rows:
return [], []
# Filter out completely empty rows
non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
if not non_empty_rows:
return [], rows
# Determine column count
col_count = max(len(r) for r in non_empty_rows)
# Merge content from all rows for each column
merged_columns = []
for col_idx in range(col_count):
col_content = []
for row in non_empty_rows:
if col_idx < len(row) and row[col_idx].strip():
col_content.append(row[col_idx].strip())
merged_columns.append(" ".join(col_content))
logger.debug(f"split_merged_rows: merged columns = {merged_columns}")
# Count how many actual data rows we should have
# Use the column with most product numbers as reference
expected_rows = self._count_expected_rows(merged_columns)
logger.debug(f"split_merged_rows: expecting {expected_rows} data rows")
if expected_rows <= 1:
# Not enough data for splitting
return [], rows
# Split each column based on expected row count
split_columns = []
for col_idx, col_text in enumerate(merged_columns):
if not col_text.strip():
split_columns.append([""] * (expected_rows + 1)) # +1 for header
continue
lines = self._split_cell_content_for_rows(col_text, expected_rows)
split_columns.append(lines)
# Ensure all columns have same number of lines (immutable approach)
max_lines = max(len(col) for col in split_columns)
split_columns = [
col + [""] * (max_lines - len(col))
for col in split_columns
]
logger.debug(f"split_merged_rows: split into {max_lines} lines total")
# First line is header, rest are data rows
header = [col[0] for col in split_columns]
data_rows = []
for line_idx in range(1, max_lines):
row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
if any(cell.strip() for cell in row):
data_rows.append(row)
logger.debug(f"split_merged_rows: header={header}, data_rows count={len(data_rows)}")
return header, data_rows
def _count_expected_rows(self, merged_columns: list[str]) -> int:
"""
Count how many data rows should exist based on content patterns.
Returns the maximum count found from:
- Product numbers (7 digits)
- Quantity patterns (number + ST/PCS)
- Amount patterns (in columns likely to be totals)
"""
max_count = 0
for col_text in merged_columns:
if not col_text:
continue
# Count product numbers (most reliable indicator)
product_nums = re.findall(r"\b\d{7}\b", col_text)
max_count = max(max_count, len(product_nums))
# Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
max_count = max(max_count, len(quantities))
return max_count
def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
"""
Split cell content knowing how many data rows we expect.
This is smarter than split_cell_content because it knows the target count.
"""
cell = cell.strip()
# Try product number split first
product_pattern = re.compile(r"(\b\d{7}\b)")
products = product_pattern.findall(cell)
if len(products) == expected_rows:
parts = product_pattern.split(cell)
header = parts[0].strip() if parts else ""
# Include description text after each product number
values = []
for i in range(1, len(parts), 2): # Odd indices are product numbers
if i < len(parts):
prod_num = parts[i].strip()
# Check if there's description text after
desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
# If description looks like text (not another pattern), include it
if desc and not re.match(r"^\d{7}$", desc):
# Truncate at next product number pattern if any
desc_clean = re.split(r"\d{7}", desc)[0].strip()
if desc_clean:
values.append(f"{prod_num} {desc_clean}")
else:
values.append(prod_num)
else:
values.append(prod_num)
if len(values) == expected_rows:
return [header] + values
# Try quantity split
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
quantities = qty_pattern.findall(cell)
if len(quantities) == expected_rows:
parts = qty_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
if len(values) == expected_rows:
return [header] + values
# Try amount split for discount+totalsumma columns
cell_lower = cell.lower()
has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])
if has_discount and has_total:
# Extract only amounts (3+ digit numbers), skip discount percentages
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
amounts = amount_pattern.findall(cell)
if len(amounts) >= expected_rows:
# Take the last expected_rows amounts (they are likely the totals)
return ["Totalsumma"] + amounts[:expected_rows]
# Try price split
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
prices = price_pattern.findall(cell)
if len(prices) >= expected_rows:
parts = price_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
if len(values) >= expected_rows:
return [header] + values[:expected_rows]
# Fall back to original single-value behavior
return [cell]
def split_cell_content(self, cell: str) -> list[str]:
"""
Split a cell containing merged multi-line content.
Strategies:
1. Look for product number patterns (7 digits)
2. Look for quantity patterns (number + ST/PCS)
3. Look for price patterns (with decimal)
4. Handle interleaved discount+amount patterns
"""
cell = cell.strip()
# Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
product_pattern = re.compile(r"(\b\d{7}\b)")
products = product_pattern.findall(cell)
if len(products) >= 2:
# Extract header (text before first product number) and values
parts = product_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
return [header] + values
# Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
quantities = qty_pattern.findall(cell)
if len(quantities) >= 2:
parts = qty_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
return [header] + values
# Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
# Check if header contains two keywords indicating merged columns
cell_lower = cell.lower()
has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])
if has_discount_header and has_amount_header:
# Extract all numbers and pair them (discount, amount, discount, amount, ...)
# Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
amounts = amount_pattern.findall(cell)
if len(amounts) >= 2:
# Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
# This avoids the "Rabatt" keyword causing is_deduction=True
header = "Totalsumma"
return [header] + amounts
# Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
prices = price_pattern.findall(cell)
if len(prices) >= 2:
parts = price_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
return [header] + values
# No pattern detected, return as single value
return [cell]
def has_merged_header(self, header: list[str] | None) -> bool:
"""
Check if header appears to be a merged cell containing multiple column names.
This happens when OCR merges table headers into a single cell, e.g.:
"Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.
Also handles cases where PP-StructureV3 produces headers like:
["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
"""
if header is None or not header:
return False
# Filter out empty cells to find the actual content
non_empty_cells = [h for h in header if h.strip()]
# Check if we have a single non-empty cell that contains multiple keywords
if len(non_empty_cells) == 1:
header_text = non_empty_cells[0].lower()
# Count how many column keywords are in this single cell
keyword_count = 0
for patterns in self.mapper.mappings.values():
for pattern in patterns:
if pattern in header_text:
keyword_count += 1
break # Only count once per field type
logger.debug(f"has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
return keyword_count >= 2
return False
def extract_from_merged_cells(
self, header: list[str], rows: list[list[str]]
) -> list[LineItem]:
"""
Extract line items from tables with merged cells.
For poorly OCR'd tables like:
Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
Row 1: ["", "", "", "8159"] <- amount row
Row 2: ["", "", "", "-2 000"] <- deduction row (separate line item)
Or:
Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items
Each amount becomes its own line item. Negative amounts are marked as is_deduction=True.
"""
items = []
# Amount pattern for Swedish format - match numbers like "8159" or "8 159" or "-2000" or "-2 000"
amount_pattern = re.compile(
r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
)
# Try to parse header cell for description info
header_text = " ".join(h for h in header if h.strip()) if header else ""
logger.debug(f"extract_from_merged_cells: header_text='{header_text}'")
logger.debug(f"extract_from_merged_cells: rows={rows}")
# Extract description from header
description = None
article_number = None
# Look for object number pattern (e.g., "0218103-1201")
obj_match = re.search(r"(\d{7}-\d{4})", header_text)
if obj_match:
article_number = obj_match.group(1)
# Look for description after object number
desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
if desc_match:
description = desc_match.group(1).strip()
row_index = 0
for row in rows:
# Combine all non-empty cells in the row
row_text = " ".join(cell.strip() for cell in row if cell.strip())
logger.debug(f"extract_from_merged_cells: row text='{row_text}'")
if not row_text:
continue
# Find all amounts in the row
amounts = amount_pattern.findall(row_text)
logger.debug(f"extract_from_merged_cells: amounts={amounts}")
for amt_str in amounts:
# Clean the amount string
cleaned = amt_str.replace(" ", "").strip()
if not cleaned or cleaned == "-":
continue
is_deduction = cleaned.startswith("-")
# Skip small positive numbers that are likely not amounts
# (e.g., row indices, small percentages)
if not is_deduction:
try:
val = float(cleaned.replace(",", "."))
if val < MIN_AMOUNT_THRESHOLD:
continue
except ValueError:
continue
# Create a line item for each amount
item = LineItem(
row_index=row_index,
description=description if row_index == 0 else "Avdrag" if is_deduction else None,
article_number=article_number if row_index == 0 else None,
amount=cleaned,
is_deduction=is_deduction,
confidence=0.7,
)
items.append(item)
row_index += 1
logger.debug(f"extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")
return items

View File

@@ -0,0 +1,61 @@
"""
Line Items Data Models
Dataclasses for line item extraction results.
"""
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
@dataclass
class LineItem:
"""Single line item from invoice."""
row_index: int
description: str | None = None
quantity: str | None = None
unit: str | None = None
unit_price: str | None = None
amount: str | None = None
article_number: str | None = None
vat_rate: str | None = None
is_deduction: bool = False # True if this row is a deduction/discount
confidence: float = 0.9
@dataclass
class LineItemsResult:
"""Result of line items extraction."""
items: list[LineItem]
header_row: list[str]
raw_html: str
is_reversed: bool = False
@property
def total_amount(self) -> str | None:
"""Calculate total amount from line items (deduction rows have negative amounts)."""
if not self.items:
return None
total = Decimal("0")
for item in self.items:
if item.amount:
try:
# Parse Swedish number format (1 234,56)
amount_str = item.amount.replace(" ", "").replace(",", ".")
total += Decimal(amount_str)
except InvalidOperation:
pass
if total == 0:
return None
# Format back to Swedish format
formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
# Fix the space/comma swap
parts = formatted.rsplit(",", 1)
if len(parts) == 2:
return parts[0].replace(" ", " ") + "," + parts[1]
return formatted

View File

@@ -0,0 +1,480 @@
"""
PP-StructureV3 Table Detection Wrapper
Provides automatic table detection in invoice images using PaddleOCR's
PP-StructureV3 pipeline. Supports both wired (bordered) and wireless
(borderless) tables commonly found in Swedish invoices.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
import logging
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class TableDetectorConfig:
"""Configuration for TableDetector."""
device: str = "gpu:0"
use_doc_orientation_classify: bool = False
use_doc_unwarping: bool = False
use_textline_orientation: bool = False
# Use SLANeXt models for better table recognition accuracy
# SLANeXt_wireless has ~6% higher accuracy than SLANet for borderless tables
wired_table_model: str = "SLANeXt_wired"
wireless_table_model: str = "SLANeXt_wireless"
layout_model: str = "PP-DocLayout_plus-L"
min_confidence: float = 0.5
@dataclass
class TableDetectionResult:
"""Result of table detection."""
bbox: tuple[float, float, float, float] # x1, y1, x2, y2 in pixels
html: str # Table structure as HTML
confidence: float
table_type: str # 'wired' or 'wireless'
cells: list[dict[str, Any]] = field(default_factory=list) # Cell-level data
class PPStructureProtocol(Protocol):
"""Protocol for PP-StructureV3 pipeline interface."""
def predict(self, image: str | np.ndarray, **kwargs: Any) -> Any:
"""Run prediction on image."""
...
class TableDetector:
"""
Table detector using PP-StructureV3.
Detects tables in invoice images and returns their bounding boxes,
HTML structure, and cell-level data.
"""
def __init__(
self,
config: TableDetectorConfig | None = None,
pipeline: PPStructureProtocol | None = None,
):
"""
Initialize table detector.
Args:
config: Configuration options. Uses defaults if None.
pipeline: Optional pre-initialized PP-StructureV3 pipeline.
If None, will be lazily initialized on first use.
"""
self.config = config or TableDetectorConfig()
self._pipeline = pipeline
self._initialized = pipeline is not None
def _ensure_initialized(self) -> None:
"""Lazily initialize PP-Structure pipeline."""
if self._initialized:
return
# Try PPStructureV3 first (paddleocr >= 3.0.0), fall back to PPStructure (2.x)
try:
from paddleocr import PPStructureV3
self._pipeline = PPStructureV3(
layout_detection_model_name=self.config.layout_model,
wired_table_structure_recognition_model_name=self.config.wired_table_model,
wireless_table_structure_recognition_model_name=self.config.wireless_table_model,
use_doc_orientation_classify=self.config.use_doc_orientation_classify,
use_doc_unwarping=self.config.use_doc_unwarping,
use_textline_orientation=self.config.use_textline_orientation,
device=self.config.device,
)
self._initialized = True
logger.info("PP-StructureV3 pipeline initialized successfully")
except ImportError:
# Fall back to PPStructure (paddleocr 2.x)
try:
from paddleocr import PPStructure
# Map device config to use_gpu for PPStructure 2.x
use_gpu = "gpu" in self.config.device.lower()
self._pipeline = PPStructure(
table=True,
ocr=True,
use_gpu=use_gpu,
show_log=False,
)
self._initialized = True
logger.info("PPStructure (2.x) pipeline initialized successfully")
except ImportError as e:
raise ImportError(
"PPStructure requires paddleocr. "
"Install with: pip install paddleocr"
) from e
def detect(
self,
image: np.ndarray | str | Path,
) -> list[TableDetectionResult]:
"""
Detect tables in an image.
Args:
image: Input image as numpy array, file path, or Path object.
Returns:
List of TableDetectionResult for each detected table.
"""
self._ensure_initialized()
if self._pipeline is None:
raise RuntimeError("Pipeline not initialized")
# Convert Path to string
if isinstance(image, Path):
image = str(image)
# Run detection
results = self._pipeline.predict(image)
return self._parse_results(results)
def _parse_results(self, results: Any) -> list[TableDetectionResult]:
"""Parse PP-StructureV3 output into TableDetectionResult list.
Supports both:
- PaddleX 3.x API: dict-like LayoutParsingResultV2 with table_res_list
- Legacy API: objects with layout_elements attribute
"""
tables: list[TableDetectionResult] = []
if results is None:
logger.warning("PP-StructureV3 returned None results")
return tables
# Log raw result type for debugging
logger.debug(f"PP-StructureV3 raw results type: {type(results).__name__}")
# Handle case where results is a single dict-like object (PaddleX 3.x)
# rather than a list of results
if hasattr(results, "get") and not isinstance(results, list):
# Single result object - wrap in list for uniform processing
logger.debug("Results is dict-like, wrapping in list")
results = [results]
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
# Iterator or generator - convert to list
try:
results = list(results)
logger.debug(f"Converted iterator to list with {len(results)} items")
except Exception as e:
logger.warning(f"Failed to convert results to list: {e}")
return tables
logger.debug(f"Processing {len(results)} result(s)")
for i, result in enumerate(results):
try:
result_type = type(result).__name__
has_get = hasattr(result, "get")
has_layout = hasattr(result, "layout_elements")
logger.debug(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
# Try PaddleX 3.x API first (dict-like with table_res_list)
if has_get:
parsed = self._parse_paddlex_result(result)
logger.debug(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
tables.extend(parsed)
continue
# Fall back to legacy API (layout_elements)
if has_layout:
legacy_count = 0
for element in result.layout_elements:
if not self._is_table_element(element):
continue
table_result = self._extract_table_data(element)
if table_result and table_result.confidence >= self.config.min_confidence:
tables.append(table_result)
legacy_count += 1
logger.debug(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
else:
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
except Exception as e:
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
continue
logger.debug(f"Total tables detected: {len(tables)}")
return tables
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
"""Parse PaddleX 3.x LayoutParsingResultV2."""
tables: list[TableDetectionResult] = []
try:
# Log result structure for debugging
result_type = type(result).__name__
result_keys = []
if hasattr(result, "keys"):
result_keys = list(result.keys())
elif hasattr(result, "__dict__"):
result_keys = list(result.__dict__.keys())
logger.debug(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
# Get table results from PaddleX 3.x API
# Handle both dict.get() and attribute access
if hasattr(result, "get"):
table_res_list = result.get("table_res_list")
parsing_res_list = result.get("parsing_res_list", [])
else:
table_res_list = getattr(result, "table_res_list", None)
parsing_res_list = getattr(result, "parsing_res_list", [])
logger.debug(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
logger.debug(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
if not table_res_list:
# Log available keys/attributes for debugging
logger.warning(f"No table_res_list found in result: {result_type}, available: {result_keys}")
return tables
# Get parsing_res_list to find table bounding boxes
table_bboxes = {}
for elem in parsing_res_list or []:
try:
if isinstance(elem, dict):
label = elem.get("label", "")
bbox = elem.get("bbox", [])
else:
label = getattr(elem, "label", "")
bbox = getattr(elem, "bbox", [])
# Check bbox has items (handles numpy arrays safely)
has_bbox = False
try:
has_bbox = len(bbox) >= 4 if hasattr(bbox, "__len__") else False
except (TypeError, ValueError):
pass
if label == "table" and has_bbox:
# Map by index (parsing_res_list tables appear in order)
idx = len(table_bboxes)
table_bboxes[idx] = bbox
except Exception as e:
logger.debug(f"Failed to parse parsing_res element: {e}")
continue
for i, table_res in enumerate(table_res_list):
try:
# Extract from PaddleX 3.x table result format
# Handle both dict and object access (SingleTableRecognitionResult)
if isinstance(table_res, dict):
cell_boxes = table_res.get("cell_box_list", [])
html = table_res.get("pred_html", "")
ocr_data = table_res.get("table_ocr_pred", {})
else:
cell_boxes = getattr(table_res, "cell_box_list", [])
html = getattr(table_res, "pred_html", "")
ocr_data = getattr(table_res, "table_ocr_pred", {})
# table_ocr_pred can be dict (PaddleOCR 3.x) or list (older versions)
# For dict format: {"rec_texts": [...], "rec_scores": [...], ...}
ocr_texts = []
if isinstance(ocr_data, dict):
ocr_texts = ocr_data.get("rec_texts", [])
elif isinstance(ocr_data, list):
ocr_texts = ocr_data
# Try to get bbox from parsing_res_list
bbox = table_bboxes.get(i, [0.0, 0.0, 0.0, 0.0])
# Handle numpy arrays - check length explicitly to avoid boolean ambiguity
try:
bbox_len = len(bbox) if hasattr(bbox, "__len__") else 0
if bbox_len < 4:
bbox = [0.0, 0.0, 0.0, 0.0]
except (TypeError, ValueError):
bbox = [0.0, 0.0, 0.0, 0.0]
# Build cells from cell_box_list and OCR text
cells = []
# Check cell_boxes length explicitly to avoid numpy array boolean issues
has_cell_boxes = False
try:
has_cell_boxes = len(cell_boxes) > 0 if hasattr(cell_boxes, "__len__") else bool(cell_boxes)
except (TypeError, ValueError):
pass
if has_cell_boxes:
# Check ocr_texts length safely for numpy arrays
ocr_texts_len = 0
try:
ocr_texts_len = len(ocr_texts) if hasattr(ocr_texts, "__len__") else 0
except (TypeError, ValueError):
pass
for j, cell_bbox in enumerate(cell_boxes):
cell_text = ocr_texts[j] if ocr_texts_len > j else ""
# Convert cell_bbox to list safely (may be numpy array)
cell_bbox_list = []
try:
cell_bbox_list = list(cell_bbox) if hasattr(cell_bbox, "__iter__") else []
except (TypeError, ValueError):
pass
cells.append({
"text": cell_text,
"bbox": cell_bbox_list,
"row": 0, # Row/col info not directly available
"col": j,
})
# Default confidence for PaddleX 3.x results
confidence = 0.9
logger.debug(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
tables.append(TableDetectionResult(
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
html=html,
confidence=confidence,
table_type="wired", # PaddleX 3.x handles both types
cells=cells,
))
except Exception as e:
import traceback
logger.warning(f"Failed to parse table_res {i}: {e}\n{traceback.format_exc()}")
continue
except Exception as e:
logger.warning(f"Failed to parse PaddleX result: {type(e).__name__}: {e}")
return tables
def _is_table_element(self, element: Any) -> bool:
"""Check if element is a table."""
if hasattr(element, "label"):
return element.label.lower() in ("table", "wired_table", "wireless_table")
if hasattr(element, "type"):
return element.type.lower() in ("table", "wired_table", "wireless_table")
return False
def _extract_table_data(self, element: Any) -> TableDetectionResult | None:
"""Extract table data from PP-StructureV3 element."""
try:
# Get bounding box
bbox = self._get_bbox(element)
if bbox is None:
return None
# Get HTML content
html = self._get_html(element)
# Get confidence
confidence = getattr(element, "score", 0.9)
if isinstance(confidence, (list, tuple)):
confidence = float(confidence[0]) if confidence else 0.9
# Determine table type
table_type = self._get_table_type(element)
# Get cells if available
cells = self._get_cells(element)
return TableDetectionResult(
bbox=bbox,
html=html,
confidence=float(confidence),
table_type=table_type,
cells=cells,
)
except Exception as e:
logger.warning(f"Failed to extract table data: {e}")
return None
def _get_bbox(self, element: Any) -> tuple[float, float, float, float] | None:
"""Extract bounding box from element."""
if hasattr(element, "bbox"):
bbox = element.bbox
if len(bbox) >= 4:
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
if hasattr(element, "box"):
box = element.box
if len(box) >= 4:
return (float(box[0]), float(box[1]), float(box[2]), float(box[3]))
return None
def _get_html(self, element: Any) -> str:
"""Extract HTML content from element."""
if hasattr(element, "html"):
return str(element.html)
if hasattr(element, "table_html"):
return str(element.table_html)
if hasattr(element, "res") and isinstance(element.res, dict):
return element.res.get("html", "")
return ""
def _get_table_type(self, element: Any) -> str:
"""Determine table type (wired or wireless)."""
label = ""
if hasattr(element, "label"):
label = str(element.label).lower()
elif hasattr(element, "type"):
label = str(element.type).lower()
if "wireless" in label or "borderless" in label:
return "wireless"
return "wired"
def _get_cells(self, element: Any) -> list[dict[str, Any]]:
"""Extract cell-level data from element."""
cells: list[dict[str, Any]] = []
if hasattr(element, "cells"):
for cell in element.cells:
cell_data = {
"text": getattr(cell, "text", ""),
"row": getattr(cell, "row", 0),
"col": getattr(cell, "col", 0),
"row_span": getattr(cell, "row_span", 1),
"col_span": getattr(cell, "col_span", 1),
}
if hasattr(cell, "bbox"):
cell_data["bbox"] = cell.bbox
cells.append(cell_data)
return cells
def detect_from_pdf(
self,
pdf_path: str | Path,
page_number: int = 0,
dpi: int = 300,
) -> list[TableDetectionResult]:
"""
Detect tables from a PDF page.
Args:
pdf_path: Path to PDF file.
page_number: Page number (0-indexed).
dpi: Resolution for rendering.
Returns:
List of TableDetectionResult for the specified page.
"""
from shared.pdf.renderer import render_pdf_to_images
from PIL import Image
import io
pdf_path = Path(pdf_path)
if not pdf_path.exists():
raise FileNotFoundError(f"PDF not found: {pdf_path}")
logger.debug(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
# Render specific page
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
if page_no == page_number:
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
logger.debug(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
return self.detect(image_array)
raise ValueError(f"Page {page_number} not found in PDF")

View File

@@ -0,0 +1,475 @@
"""
Text-Based Line Items Extractor
Fallback extraction for invoices where PP-StructureV3 cannot detect table structures
(e.g., borderless/wireless tables). Uses spatial analysis of OCR text elements to
identify and group line items.
"""
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
import re
from typing import Any
import logging
logger = logging.getLogger(__name__)
# Configuration constants
DEFAULT_ROW_TOLERANCE = 15.0 # Max vertical distance (pixels) to consider same row
MIN_ITEMS_FOR_VALID_EXTRACTION = 2 # Minimum items required for valid extraction
MIN_TEXT_ELEMENTS_FOR_EXTRACTION = 5 # Minimum text elements needed to attempt extraction
@dataclass
class TextElement:
"""Single text element from OCR."""
text: str
bbox: tuple[float, float, float, float] # x1, y1, x2, y2
confidence: float = 1.0
@property
def center_y(self) -> float:
"""Vertical center of the element."""
return (self.bbox[1] + self.bbox[3]) / 2
@property
def center_x(self) -> float:
"""Horizontal center of the element."""
return (self.bbox[0] + self.bbox[2]) / 2
@property
def height(self) -> float:
"""Height of the element."""
return self.bbox[3] - self.bbox[1]
@dataclass
class TextLineItem:
"""Line item extracted from text elements."""
row_index: int
description: str | None = None
quantity: str | None = None
unit: str | None = None
unit_price: str | None = None
amount: str | None = None
article_number: str | None = None
vat_rate: str | None = None
is_deduction: bool = False # True if this row is a deduction/discount
confidence: float = 0.7 # Lower default confidence for text-based extraction
@dataclass
class TextLineItemsResult:
"""Result of text-based line items extraction."""
items: list[TextLineItem]
header_row: list[str]
extraction_method: str = "text_spatial"
# Amount pattern matches Swedish, US, and simple numeric formats
# Handles: "1 234,56", "1,234.56", "1234.56", "100 kr", "50:-", "-100,00"
# Does NOT handle: amounts with more than 2 decimal places, scientific notation
# See tests in test_text_line_items_extractor.py::TestAmountPattern
AMOUNT_PATTERN = re.compile(
r"(?<![0-9])(?:"
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
r"|-?\d{1,3}(?:,\d{3})*(?:\.\d{2})?" # US: 1,234.56
r"|-?\d+(?:[.,]\d{2})?" # Simple: 1234,56 or 1234.56
r")(?:\s*(?:kr|SEK|:-))?" # Optional currency suffix
r"(?![0-9])"
)
# Quantity patterns
QUANTITY_PATTERN = re.compile(
r"^(?:"
r"\d+(?:[.,]\d+)?\s*(?:st|pcs|m|kg|l|h|tim|timmar)?" # Number with optional unit
r")$",
re.IGNORECASE,
)
# VAT rate patterns
VAT_RATE_PATTERN = re.compile(r"(\d+)\s*%")
# Keywords indicating a line item area
LINE_ITEM_KEYWORDS = [
"beskrivning",
"artikel",
"produkt",
"belopp",
"summa",
"antal",
"pris",
"á-pris",
"a-pris",
"moms",
]
# Keywords indicating NOT line items (summary area)
SUMMARY_KEYWORDS = [
"att betala",
"total",
"summa att betala",
"betalningsvillkor",
"förfallodatum",
"bankgiro",
"plusgiro",
"ocr-nummer",
"fakturabelopp",
"exkl. moms",
"inkl. moms",
"varav moms",
]
class TextLineItemsExtractor:
"""
Extract line items from text elements using spatial analysis.
This is a fallback for when PP-StructureV3 cannot detect table structures.
It groups text elements by vertical position and identifies patterns
that match line item rows.
"""
def __init__(
self,
row_tolerance: float = DEFAULT_ROW_TOLERANCE,
min_items_for_valid: int = MIN_ITEMS_FOR_VALID_EXTRACTION,
):
"""
Initialize extractor.
Args:
row_tolerance: Maximum vertical distance (pixels) between elements
to consider them on the same row. Default: 15.0
min_items_for_valid: Minimum number of line items required for
extraction to be considered successful. Default: 2
"""
self.row_tolerance = row_tolerance
self.min_items_for_valid = min_items_for_valid
def extract_from_parsing_res(
self, parsing_res_list: list[dict[str, Any]]
) -> TextLineItemsResult | None:
"""
Extract line items from PP-StructureV3 parsing_res_list.
Args:
parsing_res_list: List of parsed elements from PP-StructureV3.
Returns:
TextLineItemsResult if line items found, None otherwise.
"""
if not parsing_res_list:
logger.debug("No parsing_res_list provided")
return None
# Extract text elements from parsing results
text_elements = self._extract_text_elements(parsing_res_list)
logger.debug(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
if len(text_elements) < MIN_TEXT_ELEMENTS_FOR_EXTRACTION:
logger.debug(
f"Too few text elements ({len(text_elements)}) for line item extraction, "
f"need at least {MIN_TEXT_ELEMENTS_FOR_EXTRACTION}"
)
return None
return self.extract_from_text_elements(text_elements)
def extract_from_text_elements(
self, text_elements: list[TextElement]
) -> TextLineItemsResult | None:
"""
Extract line items from a list of text elements.
Args:
text_elements: List of TextElement objects.
Returns:
TextLineItemsResult if line items found, None otherwise.
"""
# Group elements by row
rows = self._group_by_row(text_elements)
logger.debug(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
# Find the line items section
item_rows = self._identify_line_item_rows(rows)
logger.debug(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
if len(item_rows) < self.min_items_for_valid:
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
return None
# Extract structured items
items = self._parse_line_items(item_rows)
logger.debug(f"TextLineItemsExtractor: extracted {len(items)} line items")
if len(items) < self.min_items_for_valid:
return None
return TextLineItemsResult(
items=items,
header_row=[], # No explicit header in text-based extraction
extraction_method="text_spatial",
)
def _extract_text_elements(
self, parsing_res_list: list[dict[str, Any]]
) -> list[TextElement]:
"""Extract TextElement objects from parsing_res_list.
Handles both dict and LayoutBlock object formats from PP-StructureV3.
Gracefully skips invalid elements with appropriate logging.
"""
elements = []
for elem in parsing_res_list:
try:
# Get label and bbox - handle both dict and LayoutBlock objects
if isinstance(elem, dict):
label = elem.get("label", "")
bbox = elem.get("bbox", [])
# Try both 'text' and 'content' keys
text = elem.get("text", "") or elem.get("content", "")
elif hasattr(elem, "label"):
label = getattr(elem, "label", "")
bbox = getattr(elem, "bbox", [])
# LayoutBlock objects use 'content' attribute
text = getattr(elem, "content", "") or getattr(elem, "text", "")
else:
# Element is neither dict nor has expected attributes
logger.debug(f"Skipping element with unexpected type: {type(elem).__name__}")
continue
# Only process text elements (skip images, tables, etc.)
if label not in ("text", "paragraph_title", "aside_text"):
continue
# Validate bbox
if not self._valid_bbox(bbox):
logger.debug(f"Skipping element with invalid bbox: {bbox}")
continue
# Clean text
text = str(text).strip() if text else ""
if not text:
continue
elements.append(
TextElement(
text=text,
bbox=(
float(bbox[0]),
float(bbox[1]),
float(bbox[2]),
float(bbox[3]),
),
)
)
except (KeyError, TypeError, ValueError, AttributeError) as e:
# Expected format issues - log at debug level
logger.debug(f"Skipping element due to format issue: {e}")
continue
except Exception as e:
# Unexpected errors - log at warning level for visibility
logger.warning(f"Unexpected error parsing element: {type(e).__name__}: {e}")
continue
return elements
def _valid_bbox(self, bbox: Any) -> bool:
"""Check if bbox is valid (has 4 elements)."""
try:
return len(bbox) >= 4 if hasattr(bbox, "__len__") else False
except (TypeError, ValueError):
return False
def _group_by_row(
self, elements: list[TextElement]
) -> list[list[TextElement]]:
"""
Group text elements into rows based on vertical position.
Elements within row_tolerance of each other are considered same row.
Uses dynamic average center_y to handle varying element heights more accurately.
"""
if not elements:
return []
# Sort by vertical position
sorted_elements = sorted(elements, key=lambda e: e.center_y)
rows: list[list[TextElement]] = []
current_row: list[TextElement] = [sorted_elements[0]]
for elem in sorted_elements[1:]:
# Calculate dynamic average center_y for current row
avg_center_y = sum(e.center_y for e in current_row) / len(current_row)
if abs(elem.center_y - avg_center_y) <= self.row_tolerance:
# Same row - add element and recalculate average on next iteration
current_row.append(elem)
else:
# New row - finalize current row
# Sort row by horizontal position (left to right)
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
current_row = [elem]
# Don't forget last row
if current_row:
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
return rows
def _identify_line_item_rows(
self, rows: list[list[TextElement]]
) -> list[list[TextElement]]:
"""
Identify which rows are likely line items.
Line item rows typically have:
- Multiple elements per row
- At least one amount-like value
- Description text
"""
item_rows = []
in_item_section = False
for row in rows:
row_text = " ".join(e.text for e in row).lower()
# Check if we're entering summary section
if any(kw in row_text for kw in SUMMARY_KEYWORDS):
in_item_section = False
continue
# Check if this looks like a header row
if any(kw in row_text for kw in LINE_ITEM_KEYWORDS):
in_item_section = True
continue # Skip header row itself
# Check if row looks like a line item
if in_item_section or self._looks_like_line_item(row):
if self._looks_like_line_item(row):
item_rows.append(row)
return item_rows
def _looks_like_line_item(self, row: list[TextElement]) -> bool:
"""Check if a row looks like a line item."""
if len(row) < 2:
return False
row_text = " ".join(e.text for e in row)
# Must have at least one amount
amounts = AMOUNT_PATTERN.findall(row_text)
if not amounts:
return False
# Should have some description text (not just numbers)
has_description = any(
len(e.text) > 3 and not AMOUNT_PATTERN.fullmatch(e.text.strip())
for e in row
)
return has_description
def _parse_line_items(
self, item_rows: list[list[TextElement]]
) -> list[TextLineItem]:
"""Parse line item rows into structured items."""
items = []
for idx, row in enumerate(item_rows):
item = self._parse_single_row(row, idx)
if item:
items.append(item)
return items
def _parse_single_row(
self, row: list[TextElement], row_index: int
) -> TextLineItem | None:
"""Parse a single row into a line item."""
if not row:
return None
# Combine all text for analysis
all_text = " ".join(e.text for e in row)
# Find amounts (rightmost is usually the total)
amounts = list(AMOUNT_PATTERN.finditer(all_text))
if not amounts:
return None
# Last amount is typically line total
amount_match = amounts[-1]
amount = amount_match.group(0).strip()
# Second to last might be unit price
unit_price = None
if len(amounts) >= 2:
unit_price = amounts[-2].group(0).strip()
# Look for quantity
quantity = None
for elem in row:
text = elem.text.strip()
if QUANTITY_PATTERN.match(text):
quantity = text
break
# Look for VAT rate
vat_rate = None
vat_match = VAT_RATE_PATTERN.search(all_text)
if vat_match:
vat_rate = vat_match.group(1)
# Description is typically the longest non-numeric text
description = None
max_len = 0
for elem in row:
text = elem.text.strip()
# Skip if it looks like a number/amount
if AMOUNT_PATTERN.fullmatch(text):
continue
if QUANTITY_PATTERN.match(text):
continue
if len(text) > max_len:
description = text
max_len = len(text)
return TextLineItem(
row_index=row_index,
description=description,
quantity=quantity,
unit_price=unit_price,
amount=amount,
vat_rate=vat_rate,
confidence=0.7,
)
def convert_text_line_item(item: TextLineItem) -> "LineItem":
"""Convert TextLineItem to standard LineItem dataclass."""
from .line_items_extractor import LineItem
return LineItem(
row_index=item.row_index,
description=item.description,
quantity=item.quantity,
unit=item.unit,
unit_price=item.unit_price,
amount=item.amount,
article_number=item.article_number,
vat_rate=item.vat_rate,
is_deduction=item.is_deduction,
confidence=item.confidence,
)

View File

@@ -1,7 +1,19 @@
"""
Cross-validation module for verifying field extraction using LLM.
Cross-validation module for verifying field extraction.
Includes LLM validation and VAT cross-validation.
"""
from .llm_validator import LLMValidator
from .vat_validator import (
VATValidationResult,
VATValidator,
MathCheckResult,
)
__all__ = ['LLMValidator']
__all__ = [
"LLMValidator",
"VATValidationResult",
"VATValidator",
"MathCheckResult",
]

View File

@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
import json
import base64
import logging
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
from datetime import datetime
import psycopg2
logger = logging.getLogger(__name__)
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
docs = self.get_documents_with_failed_matches(limit=limit)
if verbose:
print(f"Found {len(docs)} documents with failed matches to validate")
logger.info("Found %d documents with failed matches to validate", len(docs))
results = []
for i, doc in enumerate(docs):
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
result = self.validate_document(doc_id, provider, model)
results.append(result)
if verbose:
if result.error:
print(f" ERROR: {result.error}")
logger.error(" ERROR: %s", result.error)
else:
print(f" OK ({result.processing_time_ms:.0f}ms)")
logger.info(" OK (%.0fms)", result.processing_time_ms)
return results

View File

@@ -0,0 +1,267 @@
"""
VAT Validator
Cross-validates VAT information from multiple sources:
- Mathematical verification (base × rate = vat)
- Line items vs VAT summary comparison
- Consistency with existing amount field
"""
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
from backend.vat.vat_extractor import VATSummary, AmountParser
from backend.table.line_items_extractor import LineItemsResult
@dataclass
class MathCheckResult:
"""Result of a single VAT rate mathematical check."""
rate: float
base_amount: float | None
expected_vat: float | None
actual_vat: float
is_valid: bool
tolerance: float
@dataclass
class VATValidationResult:
"""Complete VAT validation result."""
is_valid: bool
confidence_score: float # 0.0 - 1.0
# Mathematical verification
math_checks: list[MathCheckResult]
total_check: bool # incl = excl + total_vat?
# Source comparison
line_items_vs_summary: bool | None # line items total = VAT summary?
amount_consistency: bool | None # total_incl_vat = existing amount field?
# Review flags
needs_review: bool
review_reasons: list[str] = field(default_factory=list)
class VATValidator:
"""Validates VAT information using multiple cross-checks."""
def __init__(self, tolerance: float = 0.02):
"""
Initialize validator.
Args:
tolerance: Acceptable difference for math checks (default 0.02 = 2 cents)
"""
self.tolerance = tolerance
self.amount_parser = AmountParser()
def validate(
self,
vat_summary: VATSummary,
line_items: LineItemsResult | None = None,
existing_amount: str | None = None,
) -> VATValidationResult:
"""
Validate VAT information.
Args:
vat_summary: Extracted VAT summary.
line_items: Optional line items for comparison.
existing_amount: Optional existing amount field from YOLO extraction.
Returns:
VATValidationResult with all check results.
"""
review_reasons: list[str] = []
# Handle empty summary
if not vat_summary.breakdowns and not vat_summary.total_vat:
return VATValidationResult(
is_valid=False,
confidence_score=0.0,
math_checks=[],
total_check=False,
line_items_vs_summary=None,
amount_consistency=None,
needs_review=True,
review_reasons=["No VAT information found"],
)
# Run all checks
math_checks = self._run_math_checks(vat_summary)
total_check = self._check_totals(vat_summary)
line_items_check = self._check_line_items(vat_summary, line_items)
amount_check = self._check_amount_consistency(vat_summary, existing_amount)
# Collect review reasons
math_failures = [c for c in math_checks if not c.is_valid]
if math_failures:
review_reasons.append(f"Math check failed for {len(math_failures)} VAT rate(s)")
if not total_check:
review_reasons.append("Total amount mismatch (excl + vat != incl)")
if line_items_check is False:
review_reasons.append("Line items total doesn't match VAT summary")
if amount_check is False:
review_reasons.append("VAT total doesn't match existing amount field")
# Calculate overall validity and confidence
all_math_valid = all(c.is_valid for c in math_checks) if math_checks else True
is_valid = all_math_valid and total_check and (amount_check is not False)
confidence_score = self._calculate_confidence(
vat_summary, math_checks, total_check, line_items_check, amount_check
)
needs_review = len(review_reasons) > 0 or confidence_score < 0.7
return VATValidationResult(
is_valid=is_valid,
confidence_score=confidence_score,
math_checks=math_checks,
total_check=total_check,
line_items_vs_summary=line_items_check,
amount_consistency=amount_check,
needs_review=needs_review,
review_reasons=review_reasons,
)
def _run_math_checks(self, vat_summary: VATSummary) -> list[MathCheckResult]:
"""Run mathematical verification for each VAT rate."""
results = []
for breakdown in vat_summary.breakdowns:
actual_vat = self.amount_parser.parse(breakdown.vat_amount)
if actual_vat is None:
continue
base_amount = None
expected_vat = None
is_valid = True
if breakdown.base_amount:
base_amount = self.amount_parser.parse(breakdown.base_amount)
if base_amount is not None:
expected_vat = base_amount * (breakdown.rate / 100)
is_valid = abs(expected_vat - actual_vat) <= self.tolerance
results.append(
MathCheckResult(
rate=breakdown.rate,
base_amount=base_amount,
expected_vat=expected_vat,
actual_vat=actual_vat,
is_valid=is_valid,
tolerance=self.tolerance,
)
)
return results
def _check_totals(self, vat_summary: VATSummary) -> bool:
"""Check if total_excl + total_vat = total_incl."""
if not vat_summary.total_excl_vat or not vat_summary.total_incl_vat:
# Can't verify without both values
return True # Assume ok if we can't check
excl = self.amount_parser.parse(vat_summary.total_excl_vat)
incl = self.amount_parser.parse(vat_summary.total_incl_vat)
if excl is None or incl is None:
return True # Can't verify
# Calculate expected VAT
if vat_summary.total_vat:
vat = self.amount_parser.parse(vat_summary.total_vat)
if vat is not None:
expected_incl = excl + vat
return abs(expected_incl - incl) <= self.tolerance
# Can't verify if vat parsing failed
return True
else:
# Sum up breakdown VAT amounts
total_vat = sum(
self.amount_parser.parse(b.vat_amount) or 0
for b in vat_summary.breakdowns
)
expected_incl = excl + total_vat
return abs(expected_incl - incl) <= self.tolerance
def _check_line_items(
self, vat_summary: VATSummary, line_items: LineItemsResult | None
) -> bool | None:
"""Check if line items total matches VAT summary."""
if line_items is None or not line_items.items:
return None # No comparison possible
# Sum line item amounts
line_total = 0.0
for item in line_items.items:
if item.amount:
amount = self.amount_parser.parse(item.amount)
if amount is not None:
line_total += amount
# Compare with VAT summary total
if vat_summary.total_excl_vat:
summary_total = self.amount_parser.parse(vat_summary.total_excl_vat)
if summary_total is not None:
# Allow larger tolerance for line items (rounding errors)
return abs(line_total - summary_total) <= 1.0
return None
def _check_amount_consistency(
self, vat_summary: VATSummary, existing_amount: str | None
) -> bool | None:
"""Check if VAT total matches existing amount field."""
if existing_amount is None:
return None # No comparison possible
existing = self.amount_parser.parse(existing_amount)
if existing is None:
return None
if vat_summary.total_incl_vat:
vat_total = self.amount_parser.parse(vat_summary.total_incl_vat)
if vat_total is not None:
return abs(existing - vat_total) <= self.tolerance
return None
def _calculate_confidence(
self,
vat_summary: VATSummary,
math_checks: list[MathCheckResult],
total_check: bool,
line_items_check: bool | None,
amount_check: bool | None,
) -> float:
"""Calculate overall confidence score."""
score = vat_summary.confidence # Start with extraction confidence
# Adjust based on validation results
if math_checks:
math_valid_ratio = sum(1 for c in math_checks if c.is_valid) / len(math_checks)
score = score * (0.5 + 0.5 * math_valid_ratio)
if not total_check:
score *= 0.5
if line_items_check is True:
score = min(score * 1.1, 1.0) # Boost if line items match
elif line_items_check is False:
score *= 0.7
if amount_check is True:
score = min(score * 1.1, 1.0) # Boost if amount matches
elif amount_check is False:
score *= 0.6
return round(score, 2)

View File

@@ -0,0 +1,19 @@
"""
VAT extraction module.
Extracts VAT (Moms) information from Swedish invoices using regex patterns.
"""
from .vat_extractor import (
VATBreakdown,
VATSummary,
VATExtractor,
AmountParser,
)
__all__ = [
"VATBreakdown",
"VATSummary",
"VATExtractor",
"AmountParser",
]

View File

@@ -0,0 +1,350 @@
"""
VAT Extractor
Extracts VAT (Moms) information from Swedish invoice text using regex patterns.
Supports multiple VAT rates (25%, 12%, 6%, 0%) and various Swedish formats.
"""
from dataclasses import dataclass
import re
from decimal import Decimal, InvalidOperation
@dataclass
class VATBreakdown:
"""Single VAT rate breakdown."""
rate: float # 25.0, 12.0, 6.0, 0.0
base_amount: str | None # Tax base (excl VAT)
vat_amount: str # VAT amount
source: str # 'regex' | 'line_items'
@dataclass
class VATSummary:
"""Complete VAT summary."""
breakdowns: list[VATBreakdown]
total_excl_vat: str | None
total_vat: str | None
total_incl_vat: str | None
confidence: float
class AmountParser:
"""Parse Swedish and European number formats."""
# Patterns to clean amount strings
CURRENCY_PATTERN = re.compile(r"(SEK|kr|:-)\s*", re.IGNORECASE)
def parse(self, amount_str: str) -> float | None:
"""
Parse amount string to float.
Handles:
- Swedish: 1 234,56
- European: 1.234,56
- US: 1,234.56
Args:
amount_str: Amount string to parse.
Returns:
Parsed float value or None if invalid.
"""
if not amount_str or not amount_str.strip():
return None
# Clean the string
cleaned = amount_str.strip()
# Remove currency
cleaned = self.CURRENCY_PATTERN.sub("", cleaned).strip()
cleaned = re.sub(r"^SEK\s*", "", cleaned, flags=re.IGNORECASE)
if not cleaned:
return None
# Check for negative
is_negative = cleaned.startswith("-")
if is_negative:
cleaned = cleaned[1:].strip()
try:
# Remove spaces (Swedish thousands separator)
cleaned = cleaned.replace(" ", "")
# Detect format
# Swedish/European: comma is decimal separator
# US: period is decimal separator
has_comma = "," in cleaned
has_period = "." in cleaned
if has_comma and has_period:
# Both present - check position
comma_pos = cleaned.rfind(",")
period_pos = cleaned.rfind(".")
if comma_pos > period_pos:
# European: 1.234,56
cleaned = cleaned.replace(".", "")
cleaned = cleaned.replace(",", ".")
else:
# US: 1,234.56
cleaned = cleaned.replace(",", "")
elif has_comma:
# Swedish: 1234,56
cleaned = cleaned.replace(",", ".")
# else: US format or integer
value = float(cleaned)
return -value if is_negative else value
except (ValueError, InvalidOperation):
return None
class VATExtractor:
"""Extract VAT information from invoice text."""
# VAT extraction patterns
# Note: Amount pattern uses [^\n] to avoid crossing line boundaries
VAT_PATTERNS = [
# Moms 25%: 2 500,00 or Moms 25% 2 500,00
re.compile(
r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
re.MULTILINE,
),
# Varav moms 25% 2 500,00
re.compile(
r"[Vv]arav\s+moms\s+(\d+(?:[,\.]\d+)?)\s*%\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
re.MULTILINE,
),
# 25% moms: 2 500,00 (at line start or after whitespace)
re.compile(
r"(?:^|\s)(\d+(?:[,\.]\d+)?)\s*%\s*moms\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
re.MULTILINE,
),
# Moms (25%): 2 500,00
re.compile(
r"[Mm]oms\s*\((\d+(?:[,\.]\d+)?)\s*%\)\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
re.MULTILINE,
),
]
# Pattern with base amount (Underlag)
VAT_WITH_BASE_PATTERN = re.compile(
r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d\s,\.]+)"
r"(?:.*?[Uu]nderlag\s*([\d\s,\.]+))?",
re.MULTILINE | re.DOTALL,
)
# Total patterns
TOTAL_EXCL_PATTERN = re.compile(
r"(?:[Ss]umma|[Tt]otal(?:t)?|[Nn]etto)\s*(?:exkl\.?\s*)?(?:moms)?\s*:?\s*([\d\s,\.]+)",
re.MULTILINE,
)
TOTAL_VAT_PATTERN = re.compile(
r"(?:[Ss]umma|[Tt]otal(?:t)?)\s*moms\s*:?\s*([\d\s,\.]+)",
re.MULTILINE,
)
TOTAL_INCL_PATTERN = re.compile(
r"(?:[Ss]umma|[Tt]otal(?:t)?|[Bb]rutto)\s*(?:inkl\.?\s*)?(?:moms|att\s*betala)?\s*:?\s*([\d\s,\.]+)",
re.MULTILINE,
)
def __init__(self):
self.amount_parser = AmountParser()
def extract(self, text: str) -> VATSummary:
"""
Extract VAT information from text.
Args:
text: Invoice text (OCR output).
Returns:
VATSummary with extracted information.
"""
if not text or not text.strip():
return VATSummary(
breakdowns=[],
total_excl_vat=None,
total_vat=None,
total_incl_vat=None,
confidence=0.0,
)
breakdowns = self._extract_breakdowns(text)
total_excl = self._extract_total_excl(text)
total_vat = self._extract_total_vat(text)
total_incl = self._extract_total_incl(text)
confidence = self._calculate_confidence(
breakdowns, total_excl, total_vat, total_incl
)
return VATSummary(
breakdowns=breakdowns,
total_excl_vat=total_excl,
total_vat=total_vat,
total_incl_vat=total_incl,
confidence=confidence,
)
def _extract_breakdowns(self, text: str) -> list[VATBreakdown]:
"""Extract individual VAT rate breakdowns."""
breakdowns = []
seen_rates = set()
# Try pattern with base amount first
for match in self.VAT_WITH_BASE_PATTERN.finditer(text):
rate = self._parse_rate(match.group(1))
vat_amount = self._clean_amount(match.group(2))
base_amount = (
self._clean_amount(match.group(3)) if match.group(3) else None
)
if rate is not None and vat_amount and rate not in seen_rates:
seen_rates.add(rate)
breakdowns.append(
VATBreakdown(
rate=rate,
base_amount=base_amount,
vat_amount=vat_amount,
source="regex",
)
)
# Try other patterns
for pattern in self.VAT_PATTERNS:
for match in pattern.finditer(text):
rate = self._parse_rate(match.group(1))
vat_amount = self._clean_amount(match.group(2))
if rate is not None and vat_amount and rate not in seen_rates:
seen_rates.add(rate)
breakdowns.append(
VATBreakdown(
rate=rate,
base_amount=None,
vat_amount=vat_amount,
source="regex",
)
)
return breakdowns
def _extract_total_excl(self, text: str) -> str | None:
"""Extract total excluding VAT."""
# Look for specific patterns first
patterns = [
re.compile(r"[Ss]umma\s+exkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Nn]etto\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Ee]xkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
]
for pattern in patterns:
match = pattern.search(text)
if match:
return self._clean_amount(match.group(1))
return None
def _extract_total_vat(self, text: str) -> str | None:
"""Extract total VAT amount."""
patterns = [
re.compile(r"[Ss]umma\s+moms\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Tt]otal(?:t)?\s+moms\s*:?\s*([\d\s,\.]+)"),
# Generic "Moms:" without percentage
re.compile(r"^[Mm]oms\s*:?\s*([\d\s,\.]+)", re.MULTILINE),
]
for pattern in patterns:
match = pattern.search(text)
if match:
return self._clean_amount(match.group(1))
return None
def _extract_total_incl(self, text: str) -> str | None:
"""Extract total including VAT."""
patterns = [
re.compile(r"[Ss]umma\s+inkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Tt]otal(?:t)?\s+att\s+betala\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Bb]rutto\s*:?\s*([\d\s,\.]+)"),
re.compile(r"[Aa]tt\s+betala\s*:?\s*([\d\s,\.]+)"),
]
for pattern in patterns:
match = pattern.search(text)
if match:
return self._clean_amount(match.group(1))
return None
def _parse_rate(self, rate_str: str) -> float | None:
"""Parse VAT rate string to float."""
try:
rate_str = rate_str.replace(",", ".")
return float(rate_str)
except (ValueError, TypeError):
return None
def _clean_amount(self, amount_str: str) -> str | None:
"""Clean and validate amount string."""
if not amount_str:
return None
cleaned = amount_str.strip()
# Remove trailing non-numeric chars (except comma/period)
cleaned = re.sub(r"[^\d\s,\.]+$", "", cleaned).strip()
if not cleaned:
return None
# Validate it parses as a number
if self.amount_parser.parse(cleaned) is None:
return None
return cleaned
def _calculate_confidence(
self,
breakdowns: list[VATBreakdown],
total_excl: str | None,
total_vat: str | None,
total_incl: str | None,
) -> float:
"""Calculate confidence score based on extracted data."""
score = 0.0
# Has VAT breakdowns
if breakdowns:
score += 0.3
# Has total excluding VAT
if total_excl:
score += 0.2
# Has total VAT
if total_vat:
score += 0.2
# Has total including VAT
if total_incl:
score += 0.15
# Mathematical consistency check
if total_excl and total_vat and total_incl:
excl = self.amount_parser.parse(total_excl)
vat = self.amount_parser.parse(total_vat)
incl = self.amount_parser.parse(total_incl)
if excl and vat and incl:
expected = excl + vat
if abs(expected - incl) < 0.02: # Allow 2 cent tolerance
score += 0.15
return min(score, 1.0)

View File

@@ -39,6 +39,26 @@ from backend.web.schemas.common import ErrorResponse
logger = logging.getLogger(__name__)
# PDF magic bytes - all valid PDF files must start with this sequence
PDF_MAGIC_BYTES = b"%PDF"
def validate_pdf_magic_bytes(content: bytes) -> None:
"""Validate that file content has valid PDF magic bytes.
PDF files must start with the bytes '%PDF' (0x25 0x50 0x44 0x46).
This validation prevents attackers from uploading malicious files
(executables, scripts) by simply renaming them to .pdf extension.
Args:
content: The raw file content to validate.
Raises:
ValueError: If the content does not start with valid PDF magic bytes.
"""
if not content or not content.startswith(PDF_MAGIC_BYTES):
raise ValueError("Invalid PDF file: does not have valid PDF header")
def _validate_uuid(value: str, name: str = "ID") -> None:
"""Validate UUID format."""
@@ -135,6 +155,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
logger.error(f"Failed to read uploaded file: {e}")
raise HTTPException(status_code=400, detail="Failed to read file")
# Validate PDF magic bytes (only for PDF files)
if file_ext == ".pdf":
try:
validate_pdf_magic_bytes(content)
except ValueError as e:
logger.warning(f"PDF magic bytes validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
# Get page count (for PDF)
page_count = 1
if file_ext == ".pdf":

View File

@@ -12,6 +12,7 @@ from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
from .pool import register_pool_routes
def create_training_router() -> APIRouter:
@@ -23,6 +24,7 @@ def create_training_router() -> APIRouter:
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
register_pool_routes(router)
return router

Some files were not shown because too many files have changed in this diff Show More